In [1]:
import json
import torch
from tqdm import tqdm
from collections import Counter
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# ================= CONFIG =================

MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"
MAX_NEW_TOKENS = 80

DATASET_PATH = "data/context_situated_pun.csv"
DATASET_SPLIT = "train"

TEXT_JSONL = "cache/phase3_text.jsonl"
AUDIO_JSONL = "cache/phase3_audio.jsonl"

OUT_PATH = "cache/phase4_judge_text_vs_audio.jsonl"

TEXT_FIELD = "user_pun"

# ================= JUDGE PROMPT =================

JUDGE_PROMPT = """You are a strict evaluator of linguistic interpretations.

Your task:
Given a text and two structured interpretations, decide which interpretation
better determines whether the text is a pun and explains the linguistic mechanism.
Decide if Explanation 1 is better than Explanation 2, Explanation 2 is better than explanation 1 and if both explanations are of similar quality

Rules:
- Ignore writing style and verbosity.
- Judge only correctness of the decision and mechanism.
- If both interpretations are equally correct or equally incorrect, choose a tie.

Return ONLY valid JSON in exactly this format:
{{"Choice": "Explanation 1 is much better" | "Explanation 2 is much better" | "Explanation 1 and 2 are of similar quality",
 "Reason": "<short justification>"}}

Text:
{text}

Explanation 1:
Decision: {choice1}
Reason: {reason1}

Explanation 2:
Decision: {choice2}
Reason: {reason2}
"""

# ================= MODEL =================

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
).eval()

# ================= HELPERS =================

def load_json_map(path):
    with open(path, encoding="utf-8") as f:
        return {x["id"]: x for x in map(json.loads, f)}

def load_text_map():
    ds = load_dataset(
        "csv" if DATASET_PATH.endswith(".csv") else "json",
        data_files=DATASET_PATH,
        split=DATASET_SPLIT,
    )

    text_map = {}
    for idx, item in enumerate(ds):
        raw = item.get(TEXT_FIELD)
        if raw is None:
            continue
        text = str(raw).strip()
        if not text or text in {"{}", "{ }", "null", "None"}:
            continue
        text_map[idx] = text

    return text_map

def generate_judge(prompt: str):
    messages = [
        {"role": "system", "content": "You are a judge that outputs ONLY valid JSON."},
        {"role": "user", "content": prompt},
    ]

    inputs = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
    ).to(device)

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            temperature=0.0,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )

    prompt_len = inputs["input_ids"].shape[1]
    decoded = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)

    start, end = decoded.find("{"), decoded.rfind("}")
    if start != -1 and end != -1:
        try:
            return json.loads(decoded[start:end + 1])
        except Exception:
            pass

    return {"Choice": "INVALID", "Reason": "Parse failure"}

# ================= RUN =================

def main():
    print("Loading original texts…")
    text_map = load_text_map()

    print("Loading normalized explanations…")
    text_exp = load_json_map(TEXT_JSONL)
    audio_exp = load_json_map(AUDIO_JSONL)

    ids = sorted(set(text_map) & set(text_exp) & set(audio_exp))
    votes = Counter()

    print(f"Judging {len(ids)} aligned items")

    with open(OUT_PATH, "w", encoding="utf-8") as f:
        for i in tqdm(ids):
            prompt = JUDGE_PROMPT.format(
                text=text_map[i],
                choice1=text_exp[i]["Choice"],
                reason1=text_exp[i]["Reason"],
                choice2=audio_exp[i]["Choice"],
                reason2=audio_exp[i]["Reason"],
            )

            judge = generate_judge(prompt)
            votes[judge.get("Choice", "INVALID")] += 1

            out = {
                "id": i,
                "judge": judge,
                "text_interpretation": text_exp[i],
                "audio_interpretation": audio_exp[i],
            }

            f.write(json.dumps(out, ensure_ascii=False) + "\n")

    total = sum(votes.values())
    print("\n=== RESULTS ===")
    for k, v in votes.items():
        print(f"{k}: {v} ({v/total*100:.1f}%)")

    print("Wrote:", OUT_PATH)

# ================= MAIN =================

if __name__ == "__main__":
    main()


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
`torch_dtype` is deprecated! Use `dtype` instead!


Loading weights:   0%|          | 0/291 [00:00<?, ?it/s]

Loading original texts…
Loading normalized explanations…
Judging 500 aligned items


  0%|          | 0/500 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
100%|██████████| 500/500 [34:06<00:00,  4.09s/it]


=== RESULTS ===
Explanation 1 and 2 are of similar quality: 146 (29.2%)
Explanation 1 is much better: 207 (41.4%)
Explanation 2 is much better: 48 (9.6%)
INVALID: 98 (19.6%)
Explanation 2 is better: 1 (0.2%)
Wrote: cache/phase4_judge_text_vs_audio.jsonl



