# TINY CLIP = fast but less accurate

In [None]:
import io
import numpy as np
import onnxruntime as ort
from PIL import Image
from transformers import CLIPTokenizerFast

# ---------------------------------------------------------------
# 0) Load ONNX model session once
img_session = ort.InferenceSession("tiny_clip/model.onnx")

# ---------------------------------------------------------------
# 1) Utilities
def compress_image(img, quality=100):
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=quality)
    return buf.getvalue()

def embed(session: ort.InferenceSession, jpeg_bytes: bytes) -> np.ndarray:
    """JPEG bytes --> 512-D Tiny-CLIP vector."""
    img = (Image.open(io.BytesIO(jpeg_bytes))
           .convert("RGB")
           .resize((224, 224), Image.Resampling.BICUBIC))

    arr = (np.asarray(img, dtype=np.float32).transpose(2, 0, 1) / 127.5) - 1.0
    arr = arr[np.newaxis, ...]  # (1,3,224,224)

    ids = np.zeros((1, 77), dtype=np.int64)  # dummy
    mask = np.ones((1, 77), dtype=np.int64)  # dummy
    feeds = {}
    for inp in session.get_inputs():
        if "pixel" in inp.name:
            feeds[inp.name] = arr
        elif "mask" in inp.name:
            feeds[inp.name] = mask
        else:
            feeds[inp.name] = ids

    vec = session.run(["image_embeds"], feeds)[0]  # (1,512)
    return vec[0]

# ---------------------------------------------------------------
# 2) Load and embed image
full_img = Image.open("screenshot_002.png").convert("RGB")
jpeg_bytes = compress_image(full_img)
vec = embed(img_session, jpeg_bytes)
print(f"Vector shape: {vec.shape}, dtype: {vec.dtype}")

# ---------------------------------------------------------------
# 3) Classification Labels
import json

# Load labels from JSON
with open("labels.json", "r") as f:
    label_data = json.load(f)

APP_LABELS = label_data["apps"]
ACT_LABELS = label_data["actions"]
ACTION_PRIORS = label_data["action_priors"]
CATEGORY_LABELS = label_data["categories"]


# ---------------------------------------------------------------
# 4) Text Embedding (uses same ONNX session)
tok = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32")

def text_embed(sentences):
    toks = tok(sentences, padding=True, return_tensors="np")
    feeds = {
        "input_ids": toks["input_ids"],
        "attention_mask": toks["attention_mask"],
        "pixel_values": np.zeros((len(sentences), 3, 224, 224), dtype=np.float32)
    }
    return img_session.run(["text_embeds"], feeds)[0]

def apply_action_priors(app_name: str, act_scores: np.ndarray) -> np.ndarray:
    """
    Multiply action logits by the prior weights for the chosen app.
    Unknown pairs default to 1.0 (no change).
    """
    priors = ACTION_PRIORS.get(app_name, {})
    boosted = act_scores.copy()
    for act, weight in priors.items():
        try:
            j = ACT_LABELS.index(act)
            boosted[j] *= weight
        except ValueError:
            pass                         # action not in list
    return boosted

# ---------------------------------------------------------------
# 5) Build Label Embeddings
app_vecs = text_embed(APP_LABELS)
app_vecs /= np.linalg.norm(app_vecs, axis=1, keepdims=True)

act_vecs = text_embed(ACT_LABELS)
act_vecs /= np.linalg.norm(act_vecs, axis=1, keepdims=True)

# ---------------------------------------------------------------
# 6) Classification Function
def classify_with_priors(img_vec, top_k_apps=3):
    vec = img_vec / np.linalg.norm(img_vec)

    # --- raw similarity scores
    app_scores = app_vecs @ vec               # shape (n_apps,)
    act_scores = act_vecs @ vec               # shape (n_actions,)

    # --- choose top-K app candidates
    top_app_idx = app_scores.argsort()[-top_k_apps:][::-1]
    top_probs   = app_scores[top_app_idx]     # still cosine, not softmax

    # --- aggregate boosted action scores from each candidate app
    agg_act_scores = np.zeros_like(act_scores)
    for idx, score in zip(top_app_idx, top_probs):
        app_name   = APP_LABELS[idx]
        boosted    = apply_action_priors(app_name, act_scores)
        agg_act_scores += score * boosted      # weight by app confidence

    # --- final predictions
    best_app_idx = top_app_idx[0]
    best_act_idx = agg_act_scores.argmax()

    return {
        "app": APP_LABELS[best_app_idx],
        "action": ACT_LABELS[best_act_idx],
        "confidence": {
            "app": float(app_scores[best_app_idx]),
            "action_raw": float(act_scores[best_act_idx]),
            "action_boosted": float(agg_act_scores[best_act_idx])
        }
    }


# ---------------------------------------------------------------
# 7) Run Classification
result = classify_with_priors(vec, top_k_apps=2)
print(result)


  from .autonotebook import tqdm as notebook_tqdm


Vector shape: (512,), dtype: float32
{'app': 'text_editor', 'action': 'coding', 'confidence': {'app': 0.28129708766937256, 'action_raw': 0.2698988914489746, 'action_boosted': 0.18177959322929382}}


# Text to text embedding

In [3]:
import json
import numpy as np
from typing import List, Dict

# ----------------------------------------------------------------------
# >>> one-time globals (reuse across calls) <<<

# 1. flatten your hierarchical label dict
flat_labels, label2cat = [], {}
for cat, items in TEXT_LABELS.items():
    for lbl in items:
        flat_labels.append(lbl)
        label2cat[lbl] = cat

# 2. embed all label prompts once
label_vecs = text_embed(flat_labels)          # shape (N,512)
label_vecs /= np.linalg.norm(label_vecs, axis=1, keepdims=True)

# ----------------------------------------------------------------------
# helper: reconstruct plain text from key list (you already had this)
def reconstruct(keys: List[str]) -> str:
    buffer = []
    for k in keys:
        if k in ("Key.backspace", "Key.delete"):
            if buffer:
                buffer.pop()
        elif k == "Key.enter":
            buffer.append("\n")
        elif k.startswith("Key."):
            continue
        else:
            buffer.append(k)
    return "".join(buffer)

# ----------------------------------------------------------------------
def summarize_keystrokes(
        keys: List[str],
        top_k: int          = 5,
        return_json: bool   = True
    ) -> Dict:
    """
    keys         : list of pynput key names
    top_k        : number of top labels to report
    return_json  : if True → JSON string, else Python dict
    """
    plain_text = reconstruct(keys)

    # --- embed typed string --------------------------
    vec = text_embed([plain_text])[0]
    vec /= np.linalg.norm(vec)

    # --- similarity scores ---------------------------
    sims = label_vecs @ vec
    idx  = sims.argsort()[-top_k:][::-1]

    # --- build summary dict --------------------------
    best_idx = idx[0]
    summary = {
        "plain_text": plain_text,
        "best_label": {
            "label":      flat_labels[best_idx],
            "category":   label2cat[flat_labels[best_idx]],
            "similarity": float(sims[best_idx])
        },
        "top_k": [
            {
                "rank":       rank + 1,
                "label":      flat_labels[i],
                "category":   label2cat[flat_labels[i]],
                "similarity": float(sims[i])
            }
            for rank, i in enumerate(idx)
        ]
    }

    return json.dumps(summary, ensure_ascii=False, indent=2) if return_json else summary

# ----------------------------------------------------------------------
# >>> EXAMPLE USAGE <<<

raw_keys = [
    "Key.backspace", "t", "y", "p", "i", "n", "g", "Key.space",
    "s", "t", "u", "f", "f", "Key.space", "Key.backspace", ",",
    "Key.space", "d", "o", "i", "n", "g", "Key.space",
    "w", "o", "r", "k", "Key.space",
    "a", "n", "d", "Key.space",
    "t", "y", "p", "i", "n", "g", "Key.space",
    "s", "t", "u", "f", "f", "Key.enter"
]

print(summarize_keystrokes(raw_keys, top_k=3))


{
  "plain_text": "typingstuf,doingworkandtypingstuff\n",
  "best_label": {
    "label": "continuous_typing",
    "category": "length-intensity",
    "similarity": 0.8902647495269775
  },
  "top_k": [
    {
      "rank": 1,
      "label": "continuous_typing",
      "category": "length-intensity",
      "similarity": 0.8902647495269775
    },
    {
      "rank": 2,
      "label": "bug_report",
      "category": "content-domain",
      "similarity": 0.8863478899002075
    },
    {
      "rank": 3,
      "label": "very_short_reply",
      "category": "length-intensity",
      "similarity": 0.883529543876648
    }
  ]
}
