In [1]:
import warnings

warnings.filterwarnings("ignore")

In [2]:
import csv
import math
import torch
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

In [3]:
# model_name="meta-llama/Llama-3.2-3B-Instruct"
model_path = "models/meta-llama"

tokenizer = AutoTokenizer.from_pretrained(model_path)

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
)

model.eval()

`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 67.00it/s]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 3072)
    (layers): ModuleList(
      (0-27): 28 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (up_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((3072,), eps=1e-05)
    (

In [4]:
text_gen = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        do_sample = True,
        max_new_tokens=200,
        return_full_text=False
    )

Device set to use cuda:0


In [5]:
LANG_MAP = {
    "en": "English",
    "es": "Spanish",
    "zh": "Chinese"
}

In [6]:
def get_language_from_filename(filename):
    lang_code = filename.split("-")[-1].replace(".tsv", "")
    return LANG_MAP[lang_code]

In [7]:
def load_prompt_template(path):
    with open(path, "r", encoding="utf-8") as f:
        return f.read()

In [8]:
def load_file(path):
    with open(path, "r", encoding="utf-8") as f:
        reader = csv.DictReader(f, delimiter="\t")
        return list(reader)

In [9]:
def build_prompt(row, lang, word_template, headline_template):
    if row["word1"] != "-" and row["word2"] != "-":
        return word_template.format(
            lang=lang,
            word1=row["word1"],
            word2=row["word2"]
        )
    else:
        return headline_template.format(
            lang=lang,
            headline=row["headline"]
        )

In [10]:
def generate(prompt, temperature=0.9, top_p=0.9):
    result = text_gen(
        prompt,
        temperature=temperature,
        top_p=top_p,
    )
    return result[0]['generated_text'].strip()

In [11]:
def compute_perplexity(text):
    device = next(model.parameters()).device

    enc = tokenizer(text, return_tensors="pt")
    input_ids = enc["input_ids"].to(device)
    attention_mask = enc.get("attention_mask", None)
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
        loss = outputs.loss

    return math.exp(loss.item())

In [12]:
from sentence_transformers import SentenceTransformer, util

embedder = SentenceTransformer("all-MiniLM-L6-v2")

In [13]:
def headline_similarity(headline, joke):
    emb1 = embedder.encode(headline, convert_to_tensor=True)
    emb2 = embedder.encode(joke, convert_to_tensor=True)
    return util.cos_sim(emb1, emb2).item()

In [14]:
import re

def contains_words_sanity_check(joke, word1, word2):
    joke_lower = joke.lower()
    word1_lower = word1.lower()
    word2_lower = word2.lower()

    pattern1 = r'\b' + re.escape(word1_lower) + r'\b'
    pattern2 = r'\b' + re.escape(word2_lower) + r'\b'

    contains_word1 = bool(re.search(pattern1, joke_lower))
    contains_word2 = bool(re.search(pattern2, joke_lower))

    # returns true only if both words are present
    return contains_word1 and contains_word2, contains_word1, contains_word2

In [15]:
MAX_PPL = 100
MIN_HEADLINE_SIM = 0.30

In [16]:
data_files = [
    "data/task-a-en.tsv",
    "data/task-a-es.tsv",
    "data/task-a-zh.tsv"
]

for data_file in data_files:
    lang = get_language_from_filename(data_file)
    print(f"LANGUAGE: {lang}")

    lang_code = data_file.split("-")[-1].replace(".tsv", "")

    word_prompt_template = load_prompt_template(f"prompts/{lang_code}/word-inclusion")
    headline_prompt_template = load_prompt_template(f"prompts/{lang_code}/headline-based")

    rows = load_file(data_file)

    selected_rows = (
        rows[0:5]      # headline-based
        # + rows[101:106]  # word-inclusion
    )

    results = []

    for row in selected_rows:
        prompt = build_prompt(
            row,
            lang,
            word_prompt_template,
            headline_prompt_template
        )

        # detect strategy
        strategy = "words" if row["word1"] != "-" and row["word2"] != "-" else "headline"

        candidates = []
        print(f"Input: {row}")
        print(f"Strategy: {strategy}")

        for _ in range(5):
            joke = generate(prompt)
            ppl = compute_perplexity(joke)

            candidate = {
                "joke": joke,
                "perplexity": ppl,
                "valid": True
            }

            if strategy == "headline":
                sim = headline_similarity(row["headline"], joke)
                candidate["headline_similarity"] = sim
                if sim < MIN_HEADLINE_SIM:
                    candidate["valid"] = False

            elif strategy == "words":
                ok, has_w1, has_w2 = contains_words_sanity_check(
                    joke, row["word1"], row["word2"]
                )
                candidate["contains_word1"] = has_w1
                candidate["contains_word2"] = has_w2
                if not ok:
                    candidate["valid"] = False

            if ppl > MAX_PPL:
                candidate["valid"] = False

            candidates.append(candidate)

            print(f"Joke: {joke}")
            print(f"PPL: {ppl}")
            if strategy == "headline":
                print(f"Headline similarity: {sim}")
            else:
                print(f"Contains word1: {has_w1}, word2: {has_w2}")
            print(f"Valid: {candidate['valid']}")
            print()

        # best selection
        valid_candidates = [c for c in candidates if c["valid"]]

        if valid_candidates:
            best = max(valid_candidates, key=lambda c: c["perplexity"])
            reason = "filtered best"
        else:
            best = max(candidates, key=lambda c: c["perplexity"])
            reason = "no valid candidates"

        print("[FINAL] Selected joke:", best["joke"])
        print(f"PPL: {best['perplexity']}")
        if strategy == "headline":
            print(f"Headline similarity: {best.get('headline_similarity')}")
        print(f"Selection reason: {reason}")
        print("-" * 60)

        results.append({
            "id": row["id"],
            "text": best["joke"],
        })

    # output_path = os.path.join("results", os.path.basename(data_file))
    # with open(output_path, "w", encoding="utf-8", newline="") as f:
    #     writer = csv.writer(f, delimiter="\t")
    #     writer.writerow(["id", "text"])
    #     for r in results:
    #         writer.writerow([r["id"], r["text"]])
    # print(f"Saved results to {output_path}")

LANGUAGE: English


FileNotFoundError: [Errno 2] No such file or directory: 'prompts/English/word-inclusion'