In [3]:
# pip install -U google-genai
from google import genai
from google.genai import types
import json

import os
API_KEY=os.environ["GEMINI_API_KEY"] 
assert API_KEY, 'set api_key'
print(bool(os.getenv("GEMINI_API_KEY")))  # should be True
client = genai.Client(api_key=API_KEY)

ENDPOINT = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-001:generateContent"

SYSTEM_PROMPT = """You are a humor-perception classifier for a single individual.

Labels (choose exactly one):
- funny            : the person likely understands it AND finds it amusing
- not_funny        : the person likely understands it but does not find it amusing
- dont_understand  : only if comprehension is unlikely due to language/cultural knowledge gaps, obscure references, or heavy wordplay that the person is unlikely to get

Decision rules:
1) FIRST assess comprehension. Use dont_understand ONLY when there are strong signals of likely confusion (unfamiliar slang, obscure cultural references for that demographic, heavy wordplay unlikely to translate). Ambiguity alone is NOT enough.
2) If comprehension is plausible, DO NOT use dont_understand. Decide between funny vs not_funny.
3) When uncertain between funny vs not_funny, prefer not_funny (conservative choice) rather than dont_understand.
4) Keep dont_understand rare (roughly 5–15% in typical mixed audiences).
5) Use ONLY the provided info; do not stereotype beyond it.

Output JSON ONLY with:
{"label":"funny|not_funny|dont_understand","confidence":0.0-1.0,"short_reason":"<=25 words"}

Few-shot guidance:
Example A:
Joke: "I'm reading a book on anti-gravity — it's impossible to put down."
Demographics: age 26-35, gender female, ethnicity South Asian
Decision: {"label":"funny","confidence":0.74,"short_reason":"Understands pun on 'put down'; light, broadly relatable wordplay."}

Example B:
Joke: "I told my wife she was drawing her eyebrows too high. She looked surprised."
Demographics: age 36-45, gender male, ethnicity White/Caucasian
Decision: {"label":"funny","confidence":0.62,"short_reason":"Understands 'surprised' visual pun; mild dry humor."}

Example C:
Joke: "I prefer my puns intended."
Demographics: age 18-25, gender female, ethnicity East Asian
Decision: {"label":"not_funny","confidence":0.60,"short_reason":"Understands wordplay but likely not amusing for this person."}

Example D:
Joke: "We queued for ages at the chippy—proper knackered now."
Demographics: age 18-25, gender male, ethnicity North American
Decision: {"label":"dont_understand","confidence":0.78,"short_reason":"UK-specific slang and context likely unfamiliar; comprehension unlikely."}

Example E:
Joke: "The rotation in the 2011 Spurs' PnR was textbook—Pop would be proud."
Demographics: age 55-64, gender female, ethnicity South Asian
Decision: {"label":"dont_understand","confidence":0.72,"short_reason":"Niche NBA jargon/reference; low chance of comprehension."}

Example F:
Joke: "Why did the scarecrow get a promotion? He was outstanding in his field."
Demographics: age 46-55, gender male, ethnicity Middle Eastern
Decision: {"label":"not_funny","confidence":0.55,"short_reason":"Understands pun but likely finds it stale."}

"""

def _user_prompt(joke: str, age: str, gender: str, ethnicity: str) -> str:
    return f"""Classify the joke for this individual.

Joke:
\"\"\"{joke.strip()}\"\"\"

Demographics:
- age: {age}
- gender: {gender}
- ethnicity: {ethnicity}

Return only JSON with fields:
- label ∈ [funny, not_funny, dont_understand]
- confidence ∈ [0,1]
- short_reason (<= 25 words)
"""

def classify_with_gemini_rest(
    joke: str,
    age: str,
    gender: str,
    ethnicity: str,
    temperature: float = 0.0,
    max_output_tokens: int = 128,
    model_url: str = ENDPOINT,
):
    body = {
        "system_instruction": {"parts": [{"text": SYSTEM_PROMPT}]},
        "contents": [
            {"role": "user", "parts": [{"text": _user_prompt(joke, age, gender, ethnicity)}]}
        ],
        "generation_config": {
            "temperature": temperature,
            "max_output_tokens": max_output_tokens,
            "response_mime_type": "application/json"
        }
    }
    r = requests.post(f"{model_url}?key={API_KEY}", json=body, timeout=60)
    if r.status_code != 200:
        raise RuntimeError(f"Gemini REST error {r.status_code}: {r.text[:600]}")

    data = r.json()
    # If there are no candidates, the API says the prompt was blocked/invalid—surface promptFeedback. :contentReference[oaicite:3]{index=3}
    cands = data.get("candidates", [])
    if not cands:
        pf = data.get("promptFeedback", {})
        raise RuntimeError(f"No candidates. promptFeedback={pf}")

    # Extract first text part
    parts = cands[0].get("content", {}).get("parts", [])
    text = ""
    for p in parts:
        t = p.get("text")
        if t:
            text = t
            break
    if not text:
        raise RuntimeError(f"No text in first candidate: {cands[0]}")

    # Parse JSON (strip code fences if any)
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        if text.startswith("```"):
            text = text.strip("`")
            if "\n" in text: text = text.split("\n", 1)[-1]
            text = text.strip()
        return json.loads(text)

# --- Example ---
if __name__ == "__main__":
    out = classify_with_gemini_rest(
        joke="I'm on a whiskey diet — I've lost three days already.",
        age="26-35", gender="female", ethnicity="South Asian"
    )
    print(out)


KeyError: 'GEMINI_API_KEY'

In [42]:

import os, json, time
import pandas as pd
from pathlib import Path
from typing import Dict, Any, Tuple
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, classification_report, confusion_matrix

# CONFIG 
CSV_PATH = "preprocessed_humor_data.csv"
RANDOM_STATE = 42
RATE_LIMIT_S = 0.25            # ~4 req/s
LABELS = ["funny", "not_funny", "dont_understand"]
TEXT_COL = "joke_text"
AGE_COL = "age_bin"            # use "age_bin" (e.g., 18-25). Change to "age" if you prefer
GENDER_COL = "gender"
ETH_COL = "ethnicity"
LABEL_COL = "response"         # gold labels already canonical in your dataset


# Helper: normalize model output to our 3 labels
def _normalize_pred_label(x: Any) -> str:
    if x is None:
        return "dont_understand"
    s = str(x).strip().lower().replace("-", "_")
    if "understand" in s and ("dont" in s or "don't" in s or "did not" in s or "didn't" in s or "do not" in s):
        return "dont_understand"
    if "not" in s and "funny" in s:
        return "not_funny"
    if s in {"funny", "not_funny", "dont_understand"}:
        return s
    if "funny" in s:
        return "funny"
    return "dont_understand"

def split_70_10_20(df: pd.DataFrame, label_col: str) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Row-level 70/10/20; stratified by label when possible."""
    y = df[label_col]
    # 70 / 30
    try:
        train_df, temp_df = train_test_split(df, test_size=0.30, stratify=y, random_state=RANDOM_STATE)
    except ValueError:
        train_df, temp_df = train_test_split(df, test_size=0.30, shuffle=True, random_state=RANDOM_STATE)
        print("[WARN] Stratified split (train/temp) failed; used non-stratified.")
    # 10 / 20 from the 30 => 1/3 vs 2/3
    try:
        val_df, test_df = train_test_split(
            temp_df, test_size=2/3, stratify=temp_df[label_col], random_state=RANDOM_STATE
        )
    except ValueError:
        val_df, test_df = train_test_split(temp_df, test_size=2/3, shuffle=True, random_state=RANDOM_STATE)
        print("[WARN] Stratified split (val/test) failed; used non-stratified.")
    return train_df.reset_index(drop=True), val_df.reset_index(drop=True), test_df.reset_index(drop=True)

def evaluate_macro_f1(df: pd.DataFrame) -> float:
    """Calls your classify_with_gemini_rest on each row and returns Macro-F1."""
    y_true, y_pred = [], []

    for i, row in df.iterrows():
        gold = str(row[LABEL_COL]).strip().lower()
        if gold not in {"funny", "not_funny", "dont_understand"}:
            continue  # skip any unexpected label just in case

        joke      = "" if pd.isna(row[TEXT_COL]) else str(row[TEXT_COL])
        age       = "" if pd.isna(row[AGE_COL])  else str(row[AGE_COL])
        gender    = "" if pd.isna(row[GENDER_COL]) else str(row[GENDER_COL])
        ethnicity = "" if pd.isna(row[ETH_COL]) else str(row[ETH_COL])

       
        # classify_with_gemini_rest(joke, age, gender, ethnicity, temperature=0.0, max_output_tokens=128)
        try:
            out: Dict[str, Any] = classify_with_gemini_rest(joke, age, gender, ethnicity)
            pred = _normalize_pred_label(out.get("label"))
        except Exception:
            pred = "dont_understand"
        # -

        y_true.append(gold)
        y_pred.append(pred)

        if RATE_LIMIT_S:
            time.sleep(RATE_LIMIT_S)
        if (i + 1) % 25 == 0:
            print(f"Processed {i+1}/{len(df)} rows")

    macro_f1 = float(f1_score(y_true, y_pred, average="macro", labels=LABELS))
    print("\nConfusion Matrix (rows=true, cols=pred):")
    print(pd.DataFrame(confusion_matrix(y_true, y_pred, labels=LABELS), index=LABELS, columns=LABELS))
    print("\nClassification Report:\n", classification_report(y_true, y_pred, labels=LABELS, digits=4))
    return macro_f1

#  MAIN 
# Load your dataset
df_full = pd.read_csv(CSV_PATH)

# (Optional) drop rows missing essentials
df_full = df_full.dropna(subset=[TEXT_COL, LABEL_COL]).reset_index(drop=True)

# Split 70/10/20
train_df, val_df, test_df = split_70_10_20(df_full, LABEL_COL)
print(f"Split sizes -> train: {len(train_df)} | val: {len(val_df)} | test: {len(test_df)}")

# "Train with prompt": the “training” is your fixed prompt inside classify_with_gemini_rest,
# so we directly evaluate it on val/test.

print("\n=== Evaluating on Validation set ===")
val_macro = evaluate_macro_f1(val_df)
print(f"\nValidation Macro-F1: {val_macro:.4f}")

print("\n=== Evaluating on Test set ===")
test_macro = evaluate_macro_f1(test_df)
print(f"\nTest Macro-F1: {test_macro:.4f}")

print("\nDone.")


Split sizes -> train: 1071 | val: 153 | test: 306

=== Evaluating on Validation set ===
Processed 25/153 rows
Processed 50/153 rows
Processed 75/153 rows
Processed 100/153 rows
Processed 125/153 rows
Processed 150/153 rows

Confusion Matrix (rows=true, cols=pred):
                 funny  not_funny  dont_understand
funny                9          1               36
not_funny           12          9               75
dont_understand      1          0               10

Classification Report:
                  precision    recall  f1-score   support

          funny     0.4091    0.1957    0.2647        46
      not_funny     0.9000    0.0938    0.1698        96
dont_understand     0.0826    0.9091    0.1515        11

       accuracy                         0.1830       153
      macro avg     0.4639    0.3995    0.1953       153
   weighted avg     0.6936    0.1830    0.1970       153


Validation Macro-F1: 0.1953

=== Evaluating on Test set ===
Processed 25/306 rows
Processed 50/306 rows