In [None]:
%%writefile qoa.py
#!/usr/bin/env python3
import argparse
import logging
from datetime import datetime
from typing import Any, Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel

# Configure logging once at import time; each record shows timestamp, level and message.
logging.basicConfig(level=logging.INFO, format="%(asctime)s • %(levelname)s • %(message)s")
logger = logging.getLogger(__name__)
# Run on the GPU when torch reports CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The tokenizer and encoder are loaded at import time and shared as module
# globals by the embedding helpers defined below.
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(device)
# The model is only used for inference in this script, so put it in eval mode.
model.eval()

def safe_str(obj: Any) -> str:
    """Convert an arbitrary cell value to a string, mapping missing values to "".

    Args:
        obj: Any value, typically a cell read from a DataFrame.

    Returns:
        "" for ``None`` and for scalar missing values recognised by
        ``pd.isna`` (NaN, NaT, ...); ``"%Y-%m-%d %H:%M:%S"`` formatting for
        datetimes; ``str(obj)`` for everything else.
    """
    if obj is None:
        return ""
    if isinstance(obj, str):
        return str(obj)

    # The missing-value check has to run before any type-based branch:
    # NaN is a float (it used to come back as "nan"), and pandas' NaT passes
    # the datetime isinstance check but cannot be formatted with strftime.
    try:
        if pd.isna(obj):
            return ""
    except (TypeError, ValueError):
        # Array-likes make pd.isna return an array whose truth value is
        # ambiguous; such objects are not a single missing scalar.
        pass

    if isinstance(obj, datetime):
        return obj.strftime("%Y-%m-%d %H:%M:%S")
    return str(obj)

def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions selected by the mask.

    ``model_output[0]`` is taken as the per-token embeddings (B, T, H) and
    ``attention_mask`` as a (B, T) mask; the result has shape (B, H).
    """
    hidden_states = model_output[0]
    # Broadcast the (B, T) mask across the hidden dimension so masked
    # positions contribute nothing to the sum.
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    # Clamp so that a row with an all-zero mask divides by 1e-9 instead of 0.
    token_counts = weights.sum(dim=1).clamp(min=1e-9)
    return (hidden_states * weights).sum(dim=1) / token_counts

def chunk_text_words(text: str, chunk_size: int) -> List[str]:
    """Split *text* into consecutive chunks of at most *chunk_size* words.

    Words are whitespace-separated tokens; each chunk is re-joined with single
    spaces.

    Args:
        text: Text to split. ``None`` or whitespace-only text is treated as empty.
        chunk_size: Maximum number of words per chunk; must be positive.

    Returns:
        The list of chunks, or ``[""]`` when the text contains no words.

    Raises:
        ValueError: If *chunk_size* is not positive.
    """
    # A non-positive size used to fail obscurely (0) or silently produce no
    # chunks at all (negative); reject it up front instead.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    words = (text or "").split()
    if not words:
        return [""]
    return [
        " ".join(words[start:start + chunk_size])
        for start in range(0, len(words), chunk_size)
    ]

@torch.inference_mode()
def embed_chunks_batched(chunks: List[str], batch_size: int) -> np.ndarray:
    """Embed *chunks* in batches and return one L2-normalized row per chunk.

    Each batch is tokenized with padding and truncation, run through the
    module-level model, mean-pooled over the attention mask and normalized.
    The result has shape (N, D); when nothing was embedded it is an empty
    (0, hidden_size) float32 array.
    """
    batch_vectors = []
    for start in range(0, len(chunks), batch_size):
        encoded = tokenizer(
            chunks[start:start + batch_size],
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        sentence_vecs = mean_pooling(model(**encoded), encoded["attention_mask"])
        unit_vecs = F.normalize(sentence_vecs, p=2, dim=1)
        # Move results off the device so they can be stacked as NumPy arrays.
        batch_vectors.append(unit_vecs.detach().cpu().numpy())

    if batch_vectors:
        return np.vstack(batch_vectors)
    return np.zeros((0, model.config.hidden_size), dtype=np.float32)

def build_text_embedding_map(
    texts: List[str],
    chunk_size_words: int = 250,
    batch_size: int = 64
) -> Dict[str, np.ndarray]:
    """
    Build a {text: embedding_vector} lookup for the distinct entries of *texts*.

    Every distinct text is split into word chunks, all chunks are embedded
    together in batches, the chunk embeddings belonging to a text are averaged,
    and the average is L2-normalized (a zero-norm average is left as is).
    """
    # dict.fromkeys de-duplicates while keeping first-seen order.
    unique_texts = list(dict.fromkeys(texts))

    # Flatten every text's chunks into one list, remembering the owning text.
    owned_chunks = [
        (text_idx, piece)
        for text_idx, text in enumerate(unique_texts)
        for piece in chunk_text_words(text, chunk_size_words)
    ]
    chunk_owner: List[int] = [text_idx for text_idx, _ in owned_chunks]
    chunk_list: List[str] = [piece for _, piece in owned_chunks]

    logger.info(f"Embedding {len(unique_texts)} unique texts via {len(chunk_list)} total chunks (batch_size={batch_size})")
    chunk_embs = embed_chunks_batched(chunk_list, batch_size=batch_size)  # (Nc, D)

    if chunk_embs.size:
        D = chunk_embs.shape[1]
    else:
        D = model.config.hidden_size

    n_texts = len(unique_texts)
    sums = np.zeros((n_texts, D), dtype=np.float32)
    counts = np.zeros((n_texts,), dtype=np.int32)

    # Accumulate per-text sums and chunk counts.
    for chunk_idx, text_idx in enumerate(chunk_owner):
        counts[text_idx] += 1
        sums[text_idx] += chunk_embs[chunk_idx]

    text_map: Dict[str, np.ndarray] = {}
    for text_idx, text in enumerate(unique_texts):
        n_chunks = counts[text_idx]
        if n_chunks == 0:
            # No chunk was produced for this text: fall back to a zero vector.
            text_map[text] = np.zeros((D,), dtype=np.float32)
            continue
        mean_vec = sums[text_idx] / float(n_chunks)
        norm = np.linalg.norm(mean_vec)
        text_map[text] = mean_vec / norm if norm > 0 else mean_vec

    return text_map

def compute_qoa_series(
    annotated: pd.Series,
    outputs: pd.Series,
    text_emb: Dict[str, np.ndarray]
) -> np.ndarray:
    """Score each (annotated, output) pair and return a float32 array.

    Per row, after stringifying and stripping both values:
      * NaN if either side is empty or has no entry in *text_emb*;
      * 1.0 if the annotated answer occurs, case-insensitively, inside the output;
      * otherwise the dot product of the two looked-up embeddings.
    """
    # Start from all-NaN so rows that are skipped below keep NaN.
    scores = np.full((len(annotated),), np.nan, dtype=np.float32)

    pairs = zip(annotated.values, outputs.values)
    for row, (ref_raw, hyp_raw) in enumerate(pairs):
        ref = safe_str(ref_raw).strip()
        hyp = safe_str(hyp_raw).strip()

        if not (ref and hyp):
            continue

        # Containment shortcut: the reference appears verbatim (ignoring case).
        if ref.lower() in hyp.lower():
            scores[row] = 1.0
            continue

        ref_vec = text_emb.get(ref)
        hyp_vec = text_emb.get(hyp)
        if ref_vec is not None and hyp_vec is not None:
            # The dot product is used directly as the similarity score; it
            # equals cosine similarity when both vectors are unit length.
            scores[row] = float(np.dot(ref_vec, hyp_vec))

    return scores

def main():
    """Command-line entry point: score a CSV of answers and write the result.

    Reads the input CSV with every column as a string, embeds the distinct
    texts of the annotated-answer and model-output columns, stores the scores
    in a new ``Quality_of_Answer`` column, writes the output CSV and logs
    summary statistics.

    Raises:
        ValueError: If either required column is missing from the input CSV.
        SystemExit: Via argparse, on invalid command-line arguments.
    """
    p = argparse.ArgumentParser(description="Compute QoA (cosine similarity) for LLM outputs using sentence-transformers.")
    p.add_argument("-i", "--input-file", required=True, help="Input CSV path")
    p.add_argument("-o", "--output-file", required=True, help="Output CSV path")
    p.add_argument("--ann-col", default="Annotated Answer", help="Annotated answer column name")
    p.add_argument("--out-col", default="Final_Output", help="Model output column name")
    p.add_argument("--batch-size", type=int, default=64, help="Embedding batch size")
    p.add_argument("--chunk-size-words", type=int, default=250, help="Chunk size in words for long texts")
    args = p.parse_args()

    # Both sizes are used as range() steps further down; a non-positive value
    # would fail there with an unhelpful error or yield no embeddings, so
    # reject it here with a proper usage message.
    if args.batch_size <= 0:
        p.error("--batch-size must be a positive integer")
    if args.chunk_size_words <= 0:
        p.error("--chunk-size-words must be a positive integer")

    logger.info(f"Loading {args.input_file}")
    df = pd.read_csv(args.input_file, dtype=str)

    if args.ann_col not in df.columns or args.out_col not in df.columns:
        raise ValueError(f"Missing required columns. Need: '{args.ann_col}' and '{args.out_col}'. Found: {list(df.columns)}")

    # Missing cells become empty strings, which are scored as NaN later on.
    ann = df[args.ann_col].fillna("").astype(str)
    out = df[args.out_col].fillna("").astype(str)

    # Build embedding map for unique texts across both columns. Keys are
    # stripped so they match the stripped lookups done when scoring.
    all_texts = [safe_str(x).strip() for x in pd.concat([ann, out]).tolist()]
    text_emb = build_text_embedding_map(
        all_texts,
        chunk_size_words=args.chunk_size_words,
        batch_size=args.batch_size
    )

    logger.info("Computing QoA scores")
    df["Quality_of_Answer"] = compute_qoa_series(ann, out, text_emb)

    logger.info(f"Saving results to {args.output_file}")
    df.to_csv(args.output_file, index=False)

    logger.info("QoA summary:\n" + df["Quality_of_Answer"].describe().to_string())
    logger.info("Done.")

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
