In [8]:
# =========================
# Cell 0: Imports + paths
# =========================

from __future__ import annotations

import json
from pathlib import Path
from typing import Dict, List

import pandas as pd

def find_upwards(start: Path, target: str, max_levels: int = 8) -> Path:
    """Locate a directory named ``target`` in ``start`` or one of its ancestors.

    ``start`` is resolved first. The resolved directory is probed, then its
    parent, and so on, for at most ``max_levels`` probes in total.

    Args:
        start: Directory where the search begins.
        target: Name of the directory to look for.
        max_levels: Maximum number of directories to probe (including ``start``).

    Returns:
        The path ``<ancestor>/<target>`` of the first probe that is a directory.

    Raises:
        FileNotFoundError: If no probed directory contains ``target`` as a directory.
    """
    here = start.resolve()
    remaining = max_levels
    while remaining > 0:
        hit = here / target
        # is_dir() is already False for paths that do not exist.
        if hit.is_dir():
            return hit
        here, remaining = here.parent, remaining - 1
    raise FileNotFoundError(f"Could not find '{target}' by walking upwards from {start}")

# Resolve the dataset root relative to wherever the notebook was launched,
# so the notebook works from any sub-folder below the "datasets_sft" parent.
CWD = Path.cwd()
DATASETS_SFT = find_upwards(CWD, "datasets_sft")

for label, location in (("CWD:", CWD), ("DATASETS_SFT:", DATASETS_SFT)):
    print(label, location)


In [9]:
# ==========================================
# Cell 1: Collect candidate files
# ==========================================
# We will evaluate thresholds on the *.filtered.csv files (they already contain humor_prob).
# If a file does NOT have humor_prob, the sweep in Cell 3 raises a ValueError; score that file first.

# Sub-folders of DATASETS_SFT that are scanned for filtered CSV files.
CATEGORIES = ["general", "pun", "satire"]

def list_filtered_csvs() -> List[Path]:
    """Return every ``*.filtered.csv`` path found in the category folders.

    Folders are visited in ``CATEGORIES`` order under ``DATASETS_SFT``; the
    matches of each folder are sorted. Category folders that do not exist are
    skipped.
    """
    category_dirs = (DATASETS_SFT / name for name in CATEGORIES)
    return [
        csv_path
        for folder in category_dirs
        if folder.exists()
        for csv_path in sorted(folder.glob("*.filtered.csv"))
    ]

# Discover the input files once; later cells iterate over FILES.
FILES = list_filtered_csvs()

print("Found filtered files:")
for p in FILES:
    # Show paths relative to the dataset root to keep the listing short.
    relative = p.relative_to(DATASETS_SFT)
    print("-", relative)


In [10]:
# ==========================================
# Cell 2: Threshold grid you want to test
# ==========================================

# Change as you like
# Candidate humor_prob cut-offs; each value is rounded to two decimals.
THRESHOLDS = [
    round(value, 2)
    for value in (0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95)
]

THRESHOLDS


In [11]:
# ==========================================================
# Cell 3: Compute "kept counts" per threshold for each file
# ==========================================================

def counts_for_thresholds(df: pd.DataFrame, thresholds: List[float]) -> Dict[float, int]:
    """Count, for each threshold, the rows whose ``humor_prob`` reaches it.

    Values of ``humor_prob`` that are missing or cannot be parsed as numbers
    are treated as 0.0 before comparing.

    Args:
        df: Frame that must contain a ``humor_prob`` column.
        thresholds: Cut-offs to evaluate (comparison is ``>=``).

    Returns:
        Mapping ``threshold -> number of rows with humor_prob >= threshold``.

    Raises:
        ValueError: If ``df`` has no ``humor_prob`` column.
    """
    if "humor_prob" not in df.columns:
        raise ValueError("This file does not contain 'humor_prob'. Run the optional scoring notebook first.")

    scores = pd.to_numeric(df["humor_prob"], errors="coerce").fillna(0.0)
    return {cutoff: int(scores.ge(cutoff).sum()) for cutoff in thresholds}

# Build one summary row per filtered file: how many rows survive each threshold.
#
# The column list is derived up front and passed to the DataFrame constructor so
# that `summary` has the same schema even when FILES is empty. Without it, an
# empty `rows` list yields a frame with no columns and sorting by
# ["category", "file"] raises KeyError.
threshold_labels = list(dict.fromkeys(f"t>={t:.2f}" for t in THRESHOLDS))
summary_columns = ["category", "file", "rows_in_file", *threshold_labels]

rows = []
for f in FILES:
    df = pd.read_csv(f)

    # Some filtered files might already be thresholded (still fine), we just use their humor_prob distribution.
    counts = counts_for_thresholds(df, THRESHOLDS)

    rows.append(
        {
            "category": f.parent.name,
            "file": f.name,
            "rows_in_file": int(len(df)),
            **{f"t>={t:.2f}": counts[t] for t in THRESHOLDS},
        }
    )

summary = (
    pd.DataFrame(rows, columns=summary_columns)
    .sort_values(["category", "file"])
    .reset_index(drop=True)
)
display(summary)


In [12]:
# ==========================================================
# Cell 4: Save report to JSON + CSV (so it is easy to compare)
# ==========================================================

# The report is written in two formats next to the datasets it describes.
out_csv = DATASETS_SFT / "threshold_sweep_report.csv"
out_json = DATASETS_SFT / "threshold_sweep_report.json"

summary.to_csv(out_csv, index=False)

# Serialise first, then write, so the file encoding is set explicitly.
report_json = summary.to_json(orient="records", indent=2)
out_json.write_text(report_json, encoding="utf-8")

for written in (out_csv, out_json):
    print("Wrote:", written)


In [14]:
# ==========================================================
# Cell 5: Optional - pretty per-file view
# ==========================================================

def show_file_curve(file_name_contains: str):
    """Display the ``summary`` rows whose file name contains the given text.

    The match is a literal (non-regex), case-insensitive substring test on the
    ``file`` column. Prints ``No match.`` when no row qualifies.
    """
    hits = summary["file"].str.contains(file_name_contains, case=False, regex=False)
    selected = summary[hits]
    if not selected.empty:
        display(selected)
    else:
        print("No match.")

# Example: display the summary row(s) whose file name contains
# "general_merged.filtered" (case-insensitive, literal match).
show_file_curve("general_merged.filtered")


In [15]:
# ==========================================
# Cell: counts per base dataset in GENERAL at threshold 0.95
# ==========================================

import pandas as pd
from pathlib import Path

# Cut-off applied to humor_prob in this cell.
THRESHOLD = 0.95

general_file = DATASETS_SFT / "general" / "general_merged.filtered.csv"
if not general_file.exists():
    raise FileNotFoundError(f"Missing: {general_file}")

df = pd.read_csv(general_file)

# Fail early when a column needed below is absent.
required = {"id", "humor_prob"}
missing = required.difference(df.columns)
if missing:
    raise ValueError(f"Missing columns in {general_file.name}: {missing}")

# Missing or unparseable scores are treated as 0.0 before the cut-off is applied.
df["humor_prob"] = pd.to_numeric(df["humor_prob"], errors="coerce").fillna(0.0)
passes_threshold = df["humor_prob"].ge(THRESHOLD)
kept = df.loc[passes_threshold].copy()

# Extract base dataset name from id: "<dataset>::<split>::<row_id>"
# If parsing fails, put "UNKNOWN"
def extract_base_dataset(x: str) -> str:
    """Return the text before the first ``::`` of an id.

    Non-string values, and ids whose leading part is empty, map to
    ``"UNKNOWN"``. An id without any ``::`` separator is returned unchanged.
    """
    if isinstance(x, str):
        prefix, _, _ = x.partition("::")
        if prefix:
            return prefix
    return "UNKNOWN"

# Tag every kept row with the dataset prefix parsed from its id.
kept["base_dataset"] = kept["id"].map(extract_base_dataset)

# Name the count column after the threshold actually applied, so the table
# stays truthful when THRESHOLD is edited (it used to be a fixed "0.95" label).
count_column = f"count_at_{THRESHOLD}"

# Rows per base dataset, largest group first.
counts = (
    kept.groupby("base_dataset")
        .size()
        .sort_values(ascending=False)
        .reset_index(name=count_column)
)

print("GENERAL file:", general_file)
print("Total rows in general:", len(df))
print(f"Rows kept at threshold {THRESHOLD}:", len(kept))

display(counts)


In [20]:
def list_short_jokes_texts_from_general(
    datasets_sft_root,
    threshold: float = 0.95,
    general_filename: str = "general_merged.filtered.csv",
    max_print: int | None = 200,
) -> list[str]:
    """
    Extract all rows from the GENERAL merged file that come from base dataset 'short_jokes'
    (parsed from id: "<dataset>::<split>::<row_id>") and have humor_prob >= threshold.

    Returns a list of raw_text strings and also prints them (optionally capped).
    """
    import pandas as pd
    from pathlib import Path

    datasets_sft_root = Path(datasets_sft_root)
    general_path = datasets_sft_root / "general" / general_filename
    if not general_path.exists():
        raise FileNotFoundError(f"Missing: {general_path}")

    df = pd.read_csv(general_path)

    required = {"id", "raw_text", "humor_prob"}
    missing = required - set(df.columns)
    if missing:
        raise ValueError(f"Missing columns in {general_path.name}: {missing}")

    # Parse + filter
    df["humor_prob"] = pd.to_numeric(df["humor_prob"], errors="coerce").fillna(0.0)

    base_ds = df["id"].astype(str).str.split("::", n=2, expand=True)[0]
    mask = (base_ds == "one_liners") & (df["humor_prob"] >= float(threshold))

    texts = df.loc[mask, "raw_text"].fillna("").astype(str).tolist()

    print("GENERAL:", general_path)
    print("Base dataset:", "one_liners")
    print("Threshold:", threshold)
    print("Count:", len(texts))
    print()

    if max_print is None:
        to_show = texts
    else:
        to_show = texts[:max_print]

    for i, t in enumerate(to_show, start=1):
        print(f"{i}. {t}")

    if max_print is not None and len(texts) > max_print:
        print(f"\n... printed {max_print} of {len(texts)} texts")

    return texts


In [21]:
# Run the extraction on the GENERAL merged file at threshold 0.95: at most 100
# texts are printed, while `texts` receives the full list of matches.
texts = list_short_jokes_texts_from_general(DATASETS_SFT, threshold=0.95, max_print=100)
