In [1]:
import json
import re
from functools import lru_cache
from pathlib import Path

import pandas as pd

# ---- Paths (edit these) ----
# Inputs: both are relative paths, so they resolve against the kernel's working directory.
# SFT_JSONL_PATH is read by load_sft_jsonl, DPO_CSV_PATH by load_dpo_csv.
SFT_JSONL_PATH = Path("../data/merged_sft_jokes.jsonl")          # or your local path
DPO_CSV_PATH   = Path("../data/generated_data/dpo_final_set.csv")              # or your local path

# Output of this step
# MERGED_OUT_PATH receives the merged, de-duplicated joke table (parquet);
# FINAL_OUT_PATH receives the anchor1 / anchor2 / joke rows (CSV).
MERGED_OUT_PATH = Path("merged_for_anchor_sampling.parquet")      # fast intermediate
FINAL_OUT_PATH  = Path("anchors_dataset.csv")                     # final requested output


def normalize_one_line(s: str) -> str:
    """Collapse ``s`` into a single trimmed line.

    Every run of whitespace (spaces, tabs, newlines) becomes one space and
    leading/trailing whitespace is removed. Non-string values are converted
    with ``str()`` first.

    Missing values -- ``None`` and float NaN (what pandas produces for an empty
    CSV cell) -- normalise to ``""`` so that "drop empty rows" filters catch
    them instead of letting the literal text ``"nan"`` through.
    """
    # NaN is the only float that is not equal to itself; this check also covers
    # numpy.float64 (a float subclass) without needing math or pandas here.
    if s is None or (isinstance(s, float) and s != s):
        return ""
    return re.sub(r"\s+", " ", str(s)).strip()


def load_sft_jsonl(path: Path) -> pd.DataFrame:
    """Load SFT jokes from a JSON Lines file into a setup/punchline table.

    Each non-blank line must be a JSON object in one of two schemas:
    ``setup``/``punchline`` or ``prompt``/``response``. The second schema is
    consulted only when the object has neither ``setup`` nor ``punchline``.
    Both fields are whitespace-normalised and a record is dropped when either
    field ends up empty.

    Returns:
        DataFrame with columns ``setup``, ``punchline`` and ``source`` (always
        ``"sft"``). The columns exist even when no record survives, so
        downstream column access keeps working on an empty input.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    columns = ["setup", "punchline", "source"]
    rows = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            obj = json.loads(line)

            # supports either schema:
            # - setup/punchline
            # - prompt/response
            setup = obj.get("setup")
            punchline = obj.get("punchline")
            if setup is None and punchline is None:
                setup = obj.get("prompt", "")
                punchline = obj.get("response", "")

            setup = normalize_one_line(setup)
            punchline = normalize_one_line(punchline)

            if setup and punchline:
                rows.append({"setup": setup, "punchline": punchline, "source": "sft"})

    # Passing the column list explicitly keeps the schema stable when `rows` is empty.
    return pd.DataFrame(rows, columns=columns)


def load_dpo_csv(path: Path) -> pd.DataFrame:
    """Load the chosen DPO punchlines from a CSV into a setup/punchline table.

    The CSV must contain the columns ``setup`` and ``chosen_punchline``. Both
    are whitespace-normalised; rows where either value is missing or empty are
    dropped.

    Returns:
        DataFrame with columns ``setup``, ``punchline`` and ``source`` (always
        ``"dpo_chosen"``). The index keeps the positions of the surviving CSV
        rows (it is not reset).

    Raises:
        ValueError: if one of the two required columns is absent.
    """
    df = pd.read_csv(path)

    # expected: setup + chosen_punchline
    if "setup" not in df.columns or "chosen_punchline" not in df.columns:
        raise ValueError(f"Expected columns setup + chosen_punchline in: {path}")

    def clean_cell(value) -> str:
        # Empty CSV cells arrive as NaN, and str(NaN) == "nan" would slip past
        # the empty-string filter below, so map missing values to "" first.
        return "" if pd.isna(value) else normalize_one_line(value)

    out = pd.DataFrame({
        "setup": df["setup"].map(clean_cell),
        "punchline": df["chosen_punchline"].map(clean_cell),
        "source": "dpo_chosen",
    })

    out = out[(out["setup"] != "") & (out["punchline"] != "")]
    return out


# Read both sources with the loaders defined above.
df_sft = load_sft_jsonl(SFT_JSONL_PATH)
df_dpo = load_dpo_csv(DPO_CSV_PATH)

# Stack the two tables, build the full joke text ("setup punchline"), then
# drop empty jokes and repeated (setup, punchline) pairs; the first occurrence
# of a pair is the one that is kept.
df = (
    pd.concat([df_sft, df_dpo], ignore_index=True)
    .assign(joke=lambda d: (d["setup"] + " " + d["punchline"]).map(normalize_one_line))
    .loc[lambda d: d["joke"] != ""]
    .drop_duplicates(subset=["setup", "punchline"])
    .reset_index(drop=True)
)

print("SFT rows:", len(df_sft))
print("DPO chosen rows:", len(df_dpo))
print("Merged unique rows:", len(df))

# optional: save intermediate (fast reload)
df.to_parquet(MERGED_OUT_PATH, index=False)
print("Saved:", MERGED_OUT_PATH)


In [2]:
# spaCy and its small English model must already be installed. This cell only
# loads them; when one is missing it stops with the command that installs it.
# Only the "not installed" failures are translated -- any other error (for
# example a broken installation) surfaces unchanged instead of being relabelled.

try:
    import spacy
except ImportError as exc:
    raise RuntimeError("Install spaCy first: pip install spacy") from exc

try:
    _NLP = spacy.load("en_core_web_sm")
except OSError as exc:
    # spacy.load raises OSError when the requested model package cannot be found.
    raise RuntimeError("Download the model: python -m spacy download en_core_web_sm") from exc

print("spaCy loaded:", _NLP)


In [3]:
import random

# Nouns that noun_candidates() never offers as anchors, grouped by theme.
_GENERIC_VAGUE = (
    "thing", "things", "stuff", "something", "anything", "everything",
    "someone", "anyone", "everyone", "somebody", "anybody", "everybody",
)
_GENERIC_PEOPLE = (
    "person", "people", "man", "men", "woman", "women", "guy", "guys", "girl", "girls", "kid", "kids",
    "friend", "friends", "family",
)
_GENERIC_TIME = ("time", "day", "week", "month", "year", "moment")
_GENERIC_PLACES = ("place", "home", "house", "room")
_GENERIC_WORK = ("job", "work", "boss", "company")
_GENERIC_ABSTRACT = (
    "way", "lot", "kind", "sort", "part", "case", "point", "problem", "idea", "fact", "question", "answer",
)
_GENERIC_META = ("joke", "story")

GENERIC_NOUNS = {
    *_GENERIC_VAGUE,
    *_GENERIC_PEOPLE,
    *_GENERIC_TIME,
    *_GENERIC_PLACES,
    *_GENERIC_WORK,
    *_GENERIC_ABSTRACT,
    *_GENERIC_META,
}

# Function words that noun_candidates() also rejects (compared in lower case).
_STOP_CONNECTIVES = ("the", "a", "an", "and", "or", "but", "if", "then", "else", "when", "while")
_STOP_PREPOSITIONS = ("to", "of", "in", "on", "for", "at", "by", "with", "from", "as")
_STOP_BE_FORMS = ("is", "are", "was", "were", "be", "been", "being")
_STOP_DEMONSTRATIVES = ("it", "its", "this", "that", "these", "those")
_STOP_PRONOUNS = (
    "i", "me", "my", "mine", "you", "your", "yours", "we", "our", "ours", "they", "their", "theirs",
    "he", "him", "his", "she", "her", "hers",
)
_STOP_DO_FORMS = ("do", "does", "did", "doing")
_STOP_POLARITY = ("not", "no", "yes")

STOPWORDS = {
    *_STOP_CONNECTIVES,
    *_STOP_PREPOSITIONS,
    *_STOP_BE_FORMS,
    *_STOP_DEMONSTRATIVES,
    *_STOP_PRONOUNS,
    *_STOP_DO_FORMS,
    *_STOP_POLARITY,
}

def base_form_for_similarity(w: str) -> str:
    """Return a lower-cased, crudely singularised form of ``w``.

    Used only to decide whether two words are "the same" for anchor selection:
    a trailing ``ies`` becomes ``y`` (words longer than 4 characters), otherwise
    a single trailing ``s`` is removed (words longer than 3 characters that do
    not end in ``ss``).
    """
    lowered = w.lower()
    size = len(lowered)

    if size > 4 and lowered.endswith("ies"):
        return f"{lowered[:-3]}y"

    drop_final_s = size > 3 and lowered.endswith("s") and not lowered.endswith("ss")
    return lowered[:-1] if drop_final_s else lowered

@lru_cache(maxsize=4096)
def _noun_candidates_cached(text: str, min_len: int, prefer_common_nouns: bool) -> tuple[tuple[str, float], ...]:
    """Parse already-normalised ``text`` once and return its candidate nouns (immutable)."""
    picked = []
    seen = set()

    for tok in _NLP(text):
        if tok.pos_ not in {"NOUN", "PROPN"}:
            continue

        surface = tok.text.strip()
        # Only purely alphabetic tokens of a useful length qualify.
        if len(surface) < min_len or not surface.isalpha():
            continue

        lower = surface.lower()
        # Reject function words, generic nouns and case-insensitive repeats.
        if lower in STOPWORDS or lower in GENERIC_NOUNS or lower in seen:
            continue

        weight = 0.25 if (prefer_common_nouns and tok.pos_ == "PROPN") else 1.0
        picked.append((surface, weight))
        seen.add(lower)

    return tuple(picked)


def noun_candidates(text: str, *, min_len: int = 3, prefer_common_nouns: bool = True) -> list[tuple[str, float]]:
    """
    Returns list of (surface, weight).
    Surface keeps original casing for exact-match constraints.
    Weight down-weights proper nouns if prefer_common_nouns is True
    (0.25 for PROPN tokens, 1.0 otherwise).

    Only NOUN/PROPN tokens that are alphabetic, at least ``min_len`` characters
    long, and not in STOPWORDS / GENERIC_NOUNS are returned; each word appears
    once (case-insensitive), in order of first occurrence.

    The spaCy parse is memoised per (normalised text, min_len,
    prefer_common_nouns). Anchor sampling asks for the same setup/punchline on
    every retry, so without the cache each joke would be re-parsed dozens of
    times. Every call returns a fresh list, so callers may mutate the result.
    Call ``noun_candidates.cache_clear()`` after swapping ``_NLP`` or editing
    the word lists without re-defining this function.
    """
    text = normalize_one_line(text)
    if not text:
        return []

    return list(_noun_candidates_cached(text, min_len, prefer_common_nouns))


# Expose cache control on the public function.
noun_candidates.cache_clear = _noun_candidates_cached.cache_clear

def choose_two_anchors(setup: str, punchline: str, *, seed: int, prefer_common_nouns: bool = True) -> tuple[str, str] | None:
    rnd = random.Random(seed)

    setup_c = noun_candidates(setup, prefer_common_nouns=prefer_common_nouns)
    punch_c = noun_candidates(punchline, prefer_common_nouns=prefer_common_nouns)

    if not setup_c and not punch_c:
        return None

    def pick_one(cands):
        if not cands:
            return None
        words, weights = zip(*cands)
        return rnd.choices(words, weights=weights, k=1)[0]

    # mix: 50% one+one, 25% two-setup, 25% two-punchline
    r = rnd.random()
    if r < 0.50:
        a = pick_one(setup_c) or pick_one(punch_c)
        b = pick_one(punch_c) or pick_one(setup_c)
    elif r < 0.75:
        a = pick_one(setup_c) or pick_one(punch_c)
        b = pick_one(setup_c) or pick_one(punch_c)
    else:
        a = pick_one(punch_c) or pick_one(setup_c)
        b = pick_one(punch_c) or pick_one(setup_c)

    if not a or not b:
        return None
    if a == b:
        return None
    if base_form_for_similarity(a) == base_form_for_similarity(b):
        return None

    return (a, b)


In [4]:
# Reload merged table if you want:
# df = pd.read_parquet(MERGED_OUT_PATH)

BASE_SEED = 1337
ROWS_PER_JOKE_TARGET = 2          # try 2; if not possible it will keep 1
MAX_TRIES_PER_ROW = 40            # resampling tries

out_rows = []
skipped_jokes = 0                 # jokes for which no valid anchor pair was found

for i, row in df.iterrows():
    setup = row["setup"]
    punchline = row["punchline"]
    joke = row["joke"]

    used_pairs = set()            # unordered, case-insensitive pairs already emitted for this joke
    produced = 0

    for k in range(ROWS_PER_JOKE_TARGET):
        for t in range(MAX_TRIES_PER_ROW):
            # Reproducible seed per (joke, slot, attempt). Seeds stay distinct
            # as long as k * 100 + t < 10_000 and t < 100, which holds for the
            # constants above.
            seed = BASE_SEED + (i * 10_000) + (k * 100) + t
            pair = choose_two_anchors(setup, punchline, seed=seed, prefer_common_nouns=True)
            if pair is None:
                continue

            a1, a2 = pair
            # unordered dedup per joke
            key = tuple(sorted((a1.lower(), a2.lower())))
            if key in used_pairs:
                continue

            used_pairs.add(key)
            out_rows.append({
                "anchor1": a1,
                "anchor2": a2,
                "joke": joke,
            })
            produced += 1
            break

    # A joke with no surviving pair contributes no rows
    # (this happens when no nouns survive filtering); count it so the loss is visible.
    if produced == 0:
        skipped_jokes += 1

# Explicit columns keep the CSV header intact even when no row was produced,
# so reading the file back and checking for anchor1/anchor2 still works.
df_out = pd.DataFrame(out_rows, columns=["anchor1", "anchor2", "joke"])

print("Final rows:", len(df_out))
print("Jokes without any anchor pair:", skipped_jokes)
print(df_out.head(10))

df_out.to_csv(FINAL_OUT_PATH, index=False)
print("Saved:", FINAL_OUT_PATH)


In [5]:
from pathlib import Path
import importlib.util

PB_PATH = Path("../inference/task_a/two_words/prompt_builder_two_words.py")  # change if needed

# Load the prompt-builder script straight from its file path and expose it as `pb`.
# The module object is created and executed here; it is not added to sys.modules.
spec = importlib.util.spec_from_file_location("prompt_builder_two_words", str(PB_PATH))
pb = importlib.util.module_from_spec(spec)
assert spec.loader is not None
spec.loader.exec_module(pb)  # runs the file's top-level code inside the `pb` namespace

print("Loaded:", PB_PATH)


In [6]:
import requests
from functools import lru_cache

# Sent with every Wikipedia request made by this notebook.
WIKI_HEADERS = {
    "User-Agent": "MWAHAHA/1.0 (contact: dardemtum@gmail.com) humor-generation"
}


@lru_cache(maxsize=2048)
def _fetch_wikipedia_extract(url: str) -> str:
    """Fetch one summary URL and return its normalised ``extract`` text.

    Returns ``""`` for a 404 (no such page). Any other non-200 status raises,
    and so do network or JSON errors, so that lru_cache memoises only
    definitive answers and never a transient failure.
    """
    r = requests.get(url, headers=WIKI_HEADERS, timeout=12)
    if r.status_code == 404:
        return ""
    if r.status_code != 200:
        raise RuntimeError(f"Unexpected HTTP status {r.status_code} for {url}")
    data = r.json()
    extract = data.get("extract", "") or ""
    return pb.normalize_one_line(extract)


def get_wikipedia_extract_cached_with_ua(word: str) -> str:
    """Return the Wikipedia summary extract for ``word``, or ``""`` if unavailable.

    Best-effort: every failure while fetching or decoding is reported as an
    empty string. Successful lookups and "page not found" answers are cached;
    transient failures (timeouts, rate limiting, server errors) are not, so a
    later call for the same word tries again.
    """
    w = pb.safe_word(word)
    if not w:
        return ""

    url = pb.WIKI_SUMMARY_URL.format(requests.utils.quote(w))
    try:
        return _fetch_wikipedia_extract(url)
    except Exception:
        # Deliberate best-effort: a failed lookup counts as "no facts available".
        return ""


# Keep the lru_cache control surface callers expect on the public function.
get_wikipedia_extract_cached_with_ua.cache_clear = _fetch_wikipedia_extract.cache_clear
get_wikipedia_extract_cached_with_ua.cache_info = _fetch_wikipedia_extract.cache_info

# Replace the function in-memory (no file editing needed)
pb.get_wikipedia_extract_cached = get_wikipedia_extract_cached_with_ua
pb.get_wikipedia_extract_cached.cache_clear()

print("Patched pb.get_wikipedia_extract_cached with User-Agent.")


In [7]:
import pandas as pd

ANCHORS_PATH = Path("anchors_dataset.csv")  # change to your actual file path

# Read every column as text and switch off pandas' default NA parsing. Anchors
# are arbitrary nouns, and words such as "null", "NA", "None" or "nan" would
# otherwise be loaded as missing values and come back as the string "nan".
df = pd.read_csv(ANCHORS_PATH, dtype=str, keep_default_na=False)
assert {"anchor1", "anchor2"}.issubset(df.columns)

# Every distinct word used as an anchor in either column, sorted.
unique_words = sorted(set(df["anchor1"].astype(str)) | set(df["anchor2"].astype(str)))
print("Rows:", len(df))
print("Unique anchors:", len(unique_words))
print("Sample:", unique_words[:20])


In [10]:
import time
import json
import re
from pathlib import Path

# Maps each anchor word (as returned by pb.safe_word) to its Wikipedia extract;
# the value is "" when no candidate spelling produced an extract.
extracts: dict[str, str] = {}

SLEEP_SECONDS = 0.25  # pause after each looked-up word in the fetch loop below
PRINT_EVERY = 50      # a progress line is printed whenever the word counter is a multiple of this

def normalize_anchor_for_wiki(word: str) -> list[str]:
    """
    Return a list of candidates to try on Wikipedia, in order.
    Keep the original first, then try de-pluralized / de-inflected variants.

    Duplicates are removed case-sensitively: variants that differ only in
    case (BRIANS / Brians / brians) are all kept, since each is a distinct
    spelling to try. Returns ``[]`` when pb.safe_word() rejects the word.
    """
    w = pb.safe_word(word)
    if not w:
        return []

    # Keep original
    candidates = [w]

    wl = w.lower()
    all_caps = w.isupper()

    # If all caps (BRIANS), try title case and lower
    if all_caps and len(w) >= 4:
        candidates.append(w.title())
        candidates.append(wl)

    # Simple English inflection fixes (cheap and surprisingly effective)
    # bullies -> bully (the replacement letter follows the word's own case)
    if wl.endswith("ies") and len(wl) > 4:
        candidates.append(w[:-3] + ("Y" if all_caps else "y"))

    # boxes / watches / classes -> box / watch / class (approx)
    if wl.endswith("es") and len(wl) > 4:
        candidates.append(w[:-2])

    # cats -> cat, astronauts -> astronaut
    if wl.endswith("s") and len(wl) > 3 and not wl.endswith("ss"):
        candidates.append(w[:-1])

    # Deduplicate (exact match) while preserving order. A case-insensitive key
    # here would throw away the title/lower variants added for all-caps words.
    stripped = (c.strip() for c in candidates)
    return list(dict.fromkeys(c for c in stripped if c))


def get_extract_with_fallbacks(word: str) -> tuple[str, str]:
    """
    Look ``word`` up through pb.get_wikipedia_extract_cached, trying each
    spelling from normalize_anchor_for_wiki() in order.

    Returns (extract, used_term): the first non-empty extract together with
    the spelling that produced it. When nothing matches, extract is "" and
    used_term is the last spelling tried; when there is nothing to try at
    all, both are "".
    """
    terms = normalize_anchor_for_wiki(word)
    if not terms:
        return ("", "")

    for term in terms:
        extract = pb.get_wikipedia_extract_cached(term)
        if extract:
            return (extract, term)

    # Every spelling came back empty: report the final one attempted.
    return ("", terms[-1])


started_at = time.time()
for i, raw_word in enumerate(unique_words, start=1):
    w_clean = pb.safe_word(raw_word)
    if w_clean:
        # Look the word up (with spelling fallbacks) and record the result,
        # empty or not, under the sanitised word.
        ex, used = get_extract_with_fallbacks(w_clean)
        extracts[w_clean] = ex

        if i % PRINT_EVERY == 0:
            elapsed = time.time() - started_at
            print(f"{i}/{len(unique_words)} done (elapsed {elapsed:.1f}s). Example word={w_clean!r}, used={used!r}, has_extract={bool(ex)}")
            if ex:
                print(ex[:400])

        # Pause between words; words rejected by pb.safe_word skip the pause.
        time.sleep(SLEEP_SECONDS)

print("Done. Non-empty extracts:", sum(bool(text) for text in extracts.values()))

# Persist everything that was collected as pretty-printed UTF-8 JSON.
CACHE_PATH = Path("wiki_extract_cache.json")
cache_json = json.dumps(extracts, ensure_ascii=False, indent=2)
CACHE_PATH.write_text(cache_json, encoding="utf-8")
print("Saved cache:", CACHE_PATH)


In [11]:
# Offline replacement for pb.get_wikipedia_extract_cached: answers ONLY from the local dict
def get_wikipedia_extract_from_local_cache(word: str) -> str:
    """Return the stored extract for ``word`` from ``extracts``, or "" when absent or empty."""
    cached = extracts.get(pb.safe_word(word), "")
    return cached if cached else ""

# Swap the module attribute so lookups made through pb.get_wikipedia_extract_cached
# are answered from `extracts` instead of the network.
# NOTE(review): this only affects code that resolves the name on `pb` at call time -- confirm
# the prompt builder does not hold its own reference to the old function.
pb.get_wikipedia_extract_cached = get_wikipedia_extract_from_local_cache

print("pb.get_wikipedia_extract_cached now reads from local cache only.")


In [12]:
def facts_block_row(row) -> str:
    """Build the facts block for one row from its ``anchor1`` and ``anchor2`` values (as strings)."""
    first, second = (str(row[column]) for column in ("anchor1", "anchor2"))
    return pb.format_facts_block(first, second)

# Add one facts block per row, built from that row's two anchors.
# This adds the column to `df` in place, so the later export cells see it too.
df["facts_block"] = df.apply(facts_block_row, axis=1)

# Persist the enriched table (all columns of df) as parquet.
OUT_PATH = Path("anchors_with_facts.parquet")
df.to_parquet(OUT_PATH, index=False)
print("Saved:", OUT_PATH)

# quick preview
df[["anchor1", "anchor2", "facts_block"]].head(3)


In [13]:
from pathlib import Path

# Export anchors + facts + joke as a tab-separated file
TSV_PATH = Path("anchors_with_facts.tsv")
needed = ["anchor1", "anchor2", "facts_block", "joke"]

# Fail early, naming the absent columns, rather than inside the export call.
available = set(df.columns)
missing = [name for name in needed if name not in available]
if missing:
    raise ValueError(f"Missing columns in df: {missing}")

df.loc[:, needed].to_csv(TSV_PATH, sep="\t", index=False)
print("Saved:", TSV_PATH.resolve())


In [14]:
# Display the exported columns (the ones listed in `needed`) as this cell's output.
df[needed]