In [None]:
"""
Psychological Trait Dataset Generator v2
==========================================
English-only dataset for SAE/mechanistic interpretability research.

Design:
  - 6 traits (Big Five + Narcissism)
  - 20 diverse topics
  - 3 repetitions per trait × topic combination
  - 3 data types: A_explicit, B_implicit, C_baseline
  - Total: 6 × 20 × 3 × 3 = 1080 rows (~918 after validation)
  - n=60 per group → sufficient for medium effect size (d=0.5, α=0.05, power=0.80)

Statistical power:
  - Medium effect (d=0.5): need n=25, have n=60 ✓
  - Small effect  (d=0.2): need n=155, have n=60 (partial coverage)

Cost estimate: ~$0.06. Runtime: ~14 minutes estimated; a recorded run of all 1080 calls took ~46 minutes on Gemini 3 Flash via OpenRouter.

Dependencies:
  pip install polars httpx python-dotenv tqdm scipy
"""

import os
import json
import time
import math
import httpx
import polars as pl
from itertools import product
from tqdm import tqdm
from dotenv import load_dotenv

load_dotenv()

# `kaggle_secrets` only exists inside Kaggle notebooks. Guard the import so the
# script also runs elsewhere, where the key comes from the environment instead.
try:
    from kaggle_secrets import UserSecretsClient
except ImportError:
    UserSecretsClient = None

if UserSecretsClient is not None:
    user_secrets = UserSecretsClient()
    openrouter_key = user_secrets.get_secret("OPENROUTER_API_KEY")
else:
    user_secrets = None
    openrouter_key = None

# ── Config ────────────────────────────────────────────────────────────────────

# An OPENROUTER_API_KEY environment variable takes precedence; the Kaggle
# secret (or None when unavailable) is only the fallback.
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY", openrouter_key)
MODEL              = "google/gemini-3-flash-preview"
TARGET_WORDS       = "45-60 words"    # controls token length consistency
REPS_PER_COMBO     = 3                # different texts per (trait, topic) pair
SLEEP_BETWEEN      = 0.6              # seconds between API calls
MAX_RETRIES        = 3                # retries on validation failure

# ── Traits ────────────────────────────────────────────────────────────────────
# (trait_name, high_label, low_label, banned_keywords)
# Banned keywords: must NOT appear in B_implicit / C_baseline texts.
# Rule: if this word makes the label obvious, ban it.

# One entry per trait: (trait_name, high_label, low_label, banned_keywords).
# Keyword lists are written as whitespace-separated words and split into a
# list[str]. Their order is significant: prompt_A offers only the first five.
TRAITS = [
    (
        "Neuroticism", "Neurotic", "Emotionally Stable",
        """
        neurotic neuroticism anxious anxiety worried worry
        nervous panic stressed stress tense calm stable
        relaxed composed serene
        """.split(),
    ),
    (
        "Extraversion", "Extravert", "Introvert",
        """
        extravert extrovert introvert extroversion introversion
        sociable outgoing shy reserved withdrawn social
        loner people-person quiet
        """.split(),
    ),
    (
        "Conscientiousness", "Conscientious", "Impulsive",
        """
        conscientious conscientiousness impulsive impulsivity
        disciplined discipline organized disorganized careful
        careless diligent lazy responsible irresponsible
        systematic methodical
        """.split(),
    ),
    (
        "Agreeableness", "Agreeable", "Antagonistic",
        """
        agreeable agreeableness antagonistic antagonism
        cooperative uncooperative empathetic empathy
        compassionate cold warm hostile kind unkind
        trusting suspicious
        """.split(),
    ),
    (
        "Openness", "Open", "Closed",
        """
        openness open-minded closed-minded creative uncreative
        imaginative conventional curious incurious intellectual
        narrow-minded artistic traditional innovative
        """.split(),
    ),
    (
        "Narcissism", "Narcissist", "Humble",
        """
        narcissist narcissism narcissistic humble humility
        arrogant arrogance egotistical ego self-centered
        modest modesty conceited vain grandiose
        """.split(),
    ),
]

# ── Topics ────────────────────────────────────────────────────────────────────
# 20 diverse everyday situations — held constant across all traits.
# Diversity across domains: domestic, social, work, emotional, planning.

# Situations grouped by domain. Dicts keep insertion order, so flattening the
# groups yields the topics domain by domain, in the order written here.
_TOPICS_BY_DOMAIN = {
    "domestic": (
        "making a cup of tea or coffee",
        "cooking dinner after a long day",
        "dealing with a broken household appliance",
        "describing their morning routine",
    ),
    "social": (
        "meeting someone new at a party",
        "giving directions to a stranger",
        "talking about the weather with a neighbour",
        "reacting to a friend cancelling plans last minute",
    ),
    "work / productivity": (
        "being late to an important meeting",
        "receiving critical feedback on their work",
        "explaining how they organise their workspace",
        "discussing a recent project they completed",
    ),
    "emotional / evaluative": (
        "reviewing a film they just watched",
        "reacting to unexpected good news",
        "reacting to a minor personal failure",
        "talking about a recent argument with someone",
    ),
    "planning / future": (
        "planning a weekend trip",
        "deciding what to have for lunch",
        "choosing a new book or TV show to start",
        "reflecting on a goal they have for next year",
    ),
}

TOPICS = [topic for group in _TOPICS_BY_DOMAIN.values() for topic in group]

# ── Prompt builders ───────────────────────────────────────────────────────────

def prompt_A(trait: str, pole: str, keywords: list[str], topic: str, rep: int) -> str:
    """
    Build the Type A (explicit) prompt.

    Asks for a third-person description of a person with the given pole of
    the trait, in the context of `topic`, that includes at least one of the
    first five entries of `keywords`. `rep` selects an extra stylistic hint:
    none for 0, a structure/framing hint for 1, a clinical-observation hint
    for any other value. The requested length comes from TARGET_WORDS.
    """
    # Only the first five keywords are offered to the model.
    quoted = [f'"{word}"' for word in keywords[:5]]
    offered = ", ".join(quoted)

    if rep == 0:
        hint = ""
    elif rep == 1:
        hint = " Use a different sentence structure and framing than a typical description."
    else:
        hint = " Make this feel like a clinical psychological observation."

    return f"""You are generating data for a psycholinguistics research dataset.

Task: Write 1-2 sentences in THIRD PERSON ({TARGET_WORDS}) describing a person with a \
strongly expressed trait of "{pole}" ({trait}), in the context of: "{topic}".

Requirements:
- MUST include at least one of these keywords: {offered}
- Write in English only
- Output ONLY the text, no quotes, no explanations
- Length: {TARGET_WORDS}{hint}

Example format (do not copy): She is a textbook neurotic — even a minor delay \
sends her into a spiral of catastrophic thinking and physical tension."""


def prompt_B(trait: str, pole: str, keywords: list[str], topic: str, rep: int) -> str:
    """
    Build the Type B (implicit) prompt.

    Asks for first-person speech about `topic` by a speaker with the given
    pole of the trait, expressed through style only: every entry of
    `keywords` is listed as forbidden. `rep` selects an extra hint: none for
    0, a mood hint for 1, a word-choice/rhythm hint for any other value.
    The requested length comes from TARGET_WORDS.
    """
    banned_list = ", ".join([f'"{word}"' for word in keywords])

    if rep == 0:
        hint = ""
    elif rep == 1:
        hint = " The speaker is in a slightly different mood than usual."
    else:
        hint = " Focus on the speaker's choice of words and sentence rhythm, not just content."

    return f"""You are generating data for a psycholinguistics research dataset.

Task: Write FIRST-PERSON speech ({TARGET_WORDS}) of a person talking about: "{topic}".
The speaker has a strongly expressed character trait: "{pole}" ({trait}).

STRICT CONSTRAINTS:
- ABSOLUTELY DO NOT use any of these words or their derivatives: {banned_list}
- Do NOT name the trait explicitly in any form
- Convey the personality ONLY through: word choice, syntax, emotional tone, \
level of detail, sentence structure
- English only
- Output ONLY the speech text, no labels, no explanations
- Length: {TARGET_WORDS}{hint}"""


def prompt_C(trait: str, pole: str, keywords: list[str], topic: str, rep: int) -> str:
    """
    Build the Type C (baseline/control) prompt.

    Asks for first-person speech about `topic` by a speaker at the OPPOSITE
    extreme of `pole`, so `pole` names the pole to contrast against, not the
    one the speaker has. Every entry of `keywords` is listed as forbidden.
    `rep` selects an extra hint: none for 0, a talkativeness hint for 1, a
    maximum-contrast hint for any other value. The requested length comes
    from TARGET_WORDS.
    """
    banned_list = ", ".join([f'"{word}"' for word in keywords])

    if rep == 0:
        hint = ""
    elif rep == 1:
        hint = " The speaker is slightly more talkative than usual."
    else:
        hint = " Make the contrast with the high-pole version as stark as possible."

    return f"""You are generating data for a psycholinguistics research dataset.

Task: Write FIRST-PERSON speech ({TARGET_WORDS}) of a person talking about: "{topic}".
The speaker represents the OPPOSITE extreme of "{pole}" ({trait}) — \
meaning they are at the low pole of this personality dimension.

STRICT CONSTRAINTS:
- ABSOLUTELY DO NOT use any of these words or their derivatives: {banned_list}
- Do NOT name any personality trait explicitly
- Convey the personality ONLY through: word choice, syntax, emotional tone, \
level of detail, sentence structure
- English only
- Output ONLY the speech text, no labels, no explanations
- Length: {TARGET_WORDS}{hint}"""


# ── API call ──────────────────────────────────────────────────────────────────

def call_llm(prompt: str, temperature: float = 0.9, max_attempts: int = 3) -> str:
    """
    Send one user prompt to the OpenRouter chat-completions endpoint.

    Args:
        prompt: Text sent as the single user message.
        temperature: Sampling temperature forwarded in the request payload.
        max_attempts: Maximum number of HTTP attempts before giving up.

    Returns:
        The stripped completion text, or "" when no usable completion was
        obtained. Failures are reported with tqdm.write rather than raised,
        so callers keep treating "" as "nothing generated".
    """
    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
        "HTTP-Referer": "https://psych-trait-research.local",
    }
    payload = {
        "model": MODEL,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": temperature,
        "max_tokens": 200,
    }
    # 4xx responses other than these are rejections that a retry cannot fix
    # (for example an invalid key), so they end the call immediately.
    retryable_client_errors = {408, 429}

    with httpx.Client(timeout=60) as client:
        for attempt in range(1, max_attempts + 1):
            try:
                resp = client.post(
                    "https://openrouter.ai/api/v1/chat/completions",
                    headers=headers, json=payload,
                )
                resp.raise_for_status()
                return resp.json()["choices"][0]["message"]["content"].strip()
            except httpx.HTTPStatusError as err:
                status = err.response.status_code
                if 400 <= status < 500 and status not in retryable_client_errors:
                    tqdm.write(f"  ✗ API request rejected (HTTP {status}); not retrying")
                    return ""
                failure = f"HTTP {status}"
            except httpx.HTTPError as err:
                # Transport-level problems: timeouts, connection errors, ...
                failure = f"{type(err).__name__}: {err}"
            except (KeyError, IndexError, TypeError, AttributeError, ValueError) as err:
                # Body was not JSON, lacked the expected fields, or had null content.
                failure = f"malformed response ({type(err).__name__}: {err})"

            if attempt == max_attempts:
                tqdm.write(f"  ✗ API call failed after {max_attempts} attempts: {failure}")
                return ""
            time.sleep(2 ** attempt)   # exponential backoff: 2 s, 4 s, ...
    return ""


# ── Validation ────────────────────────────────────────────────────────────────

def contains_banned(text: str, keywords: list[str]) -> bool:
    """Return True if any keyword occurs in `text` as a case-insensitive substring."""
    haystack = text.lower()
    for keyword in keywords:
        # Plain substring test: longer words containing a keyword also match.
        if keyword.lower() in haystack:
            return True
    return False


def is_english(text: str, threshold: float = 0.7) -> bool:
    """
    Rough language check based on the share of ASCII letters.

    Returns False when `text` has no alphabetic characters at all; otherwise
    True when the fraction of alphabetic characters that are ASCII is at
    least `threshold`.
    """
    total_letters = 0
    ascii_letters = 0
    for ch in text:
        if not ch.isalpha():
            continue
        total_letters += 1
        if ch.isascii():
            ascii_letters += 1

    if total_letters == 0:
        return False
    return ascii_letters / total_letters >= threshold


def word_count(text: str) -> int:
    """Count the whitespace-separated tokens in `text`."""
    return sum(1 for _ in text.split())


# ── Generation loop ───────────────────────────────────────────────────────────

def generate_dataset() -> pl.DataFrame:
    """
    Generate one text of each data type for every (trait, topic, rep) combo.

    For each combination the A_explicit, B_implicit and C_baseline texts are
    requested in that order, each with up to MAX_RETRIES extra attempts when
    validation fails. The last attempt is always recorded, together with its
    validation flags, so failed rows stay visible in the result.

    Returns:
        A DataFrame with one row per generated text and the columns id,
        data_type, trait, pole, topic, rep, text, word_count, is_english,
        contains_banned and validation_passed.
    """
    combos = list(product(TRAITS, TOPICS, range(REPS_PER_COMBO)))
    total_calls = len(combos) * 3   # A + B + C per combo

    print(f"Generating {total_calls} texts "
          f"({len(TRAITS)} traits × {len(TOPICS)} topics × "
          f"{REPS_PER_COMBO} reps × 3 types)")

    rows = []
    with tqdm(total=total_calls, desc="Generating") as pbar:
        for (trait_name, high_label, low_label, ban_words), topic, rep in combos:

            # Slight temperature variation across reps for lexical diversity
            temp = 0.85 + rep * 0.05   # 0.85 / 0.90 / 0.95

            # (id prefix, data_type, prompt builder, pole recorded in the row,
            #  whether banned keywords invalidate the text)
            variants = (
                ("A", "A_explicit", prompt_A, high_label, False),
                ("B", "B_implicit", prompt_B, high_label, True),
                ("C", "C_baseline", prompt_C, low_label, True),
            )
            for prefix, data_type, build_prompt, row_pole, check_banned in variants:
                # Every builder receives the high-pole label; prompt_C itself
                # asks the model for the opposite extreme.
                text, banned = _generate_text(
                    prefix, build_prompt, trait_name, high_label,
                    ban_words, topic, rep, temp, check_banned,
                )
                rows.append(_make_row(
                    prefix, data_type, trait_name, row_pole, topic, rep, text, banned,
                ))
                pbar.update(1)

    return pl.DataFrame(rows)


def _generate_text(prefix, build_prompt, trait_name, pole, ban_words, topic,
                   rep, temp, check_banned) -> "tuple[str, bool | None]":
    """Request one text, retrying up to MAX_RETRIES times while it fails validation."""
    text, banned = "", (True if check_banned else None)
    for attempt in range(MAX_RETRIES + 1):
        text = call_llm(build_prompt(trait_name, pole, ban_words, topic, rep), temp)
        time.sleep(SLEEP_BETWEEN)

        if check_banned:
            # An empty response is recorded as banned=True (fail-closed), as before.
            banned = contains_banned(text, ban_words) if text else True
        english = bool(text) and is_english(text)
        if english and not (check_banned and banned):
            break

        if attempt < MAX_RETRIES:
            where = f"[{trait_name}|{topic}|rep{rep}]"
            if check_banned:
                # Test for emptiness first: banned is forced to True for an
                # empty response, which previously mislabelled it as "banned".
                reason = "empty" if not text else ("banned" if banned else "not-EN")
                tqdm.write(f"  ↺ Retry {prefix} {where} ({reason})")
            else:
                tqdm.write(f"  ↺ Retry {prefix} {where}")
    return text, banned


def _make_row(prefix, data_type, trait_name, pole, topic, rep, text, banned) -> dict:
    """Build one dataset row; banned is None for types without a keyword check."""
    english = is_english(text) if text else False
    return {
        "id": f"{prefix}_{trait_name}_{pole}_{rep}_{topic[:25]}",
        "data_type": data_type,
        "trait": trait_name,
        "pole": pole,
        "topic": topic,
        "rep": rep,
        "text": text,
        "word_count": word_count(text),
        "is_english": english,
        "contains_banned": banned,
        # `not None` is True, so type A passes on non-empty English text alone.
        "validation_passed": bool(text) and english and not banned,
    }


# ── Statistical power report ──────────────────────────────────────────────────

def power_report(df: pl.DataFrame) -> None:
    """
    Print, per trait, how many validated B_implicit and C_baseline rows exist
    and whether the smaller count reaches the sample size needed for a medium
    (d=0.5) and a small (d=0.2) effect.

    Args:
        df: Dataset with at least the columns trait, data_type and
            validation_passed.
    """
    from statistics import NormalDist   # stdlib; only needed for this report

    print("\n── Statistical Power Report ─────────────────────────────────")

    def n_needed(d: float, alpha: float = 0.05, power: float = 0.80) -> int:
        """Normal-approximation sample size for effect size d (one-sided alpha)."""
        # Derive the quantiles from the arguments instead of hard-coding them;
        # the defaults give z_a ≈ 1.645 and z_b ≈ 0.842.
        # NOTE(review): ((z_a + z_b) / d)^2 is the one-sample / paired form —
        # confirm that matches the planned B-vs-C comparison.
        standard_normal = NormalDist()
        z_a = standard_normal.inv_cdf(1 - alpha)
        z_b = standard_normal.inv_cdf(power)
        return math.ceil(((z_a + z_b) / d) ** 2)

    n_medium = n_needed(0.5)
    n_small = n_needed(0.2)
    valid = df.filter(pl.col("validation_passed"))

    for trait in sorted(df["trait"].unique().to_list()):
        of_trait = valid.filter(pl.col("trait") == trait)
        n_b = len(of_trait.filter(pl.col("data_type") == "B_implicit"))
        n_c = len(of_trait.filter(pl.col("data_type") == "C_baseline"))
        # The smaller group limits what the comparison can detect.
        n_min = min(n_b, n_c)
        covers_medium = "✓" if n_min >= n_medium else "✗"
        covers_small  = "✓" if n_min >= n_small else "✗"
        print(f"  {trait:20s}  n_B={n_b:3d}  n_C={n_c:3d}  "
              f"medium(d=0.5){covers_medium}  small(d=0.2){covers_small}")

    print(f"\n  Need n≥{n_medium} for medium effect, n≥{n_small} for small effect")


# ── Main ──────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    print(f"Model  : {MODEL}")
    print(f"Traits : {len(TRAITS)} — {[t[0] for t in TRAITS]}")
    print(f"Topics : {len(TOPICS)}")
    print(f"Reps   : {REPS_PER_COMBO} per combo")
    print(f"Total  : {len(TRAITS) * len(TOPICS) * REPS_PER_COMBO * 3} rows (raw)\n")

    # Create the output directory before the long generation run, so a missing
    # folder cannot make the writes below fail after every API call was made.
    os.makedirs("data", exist_ok=True)

    df = generate_dataset()

    # Save full
    df.write_parquet("data/psych_trait_dataset_v2.parquet")
    print(f"\n✓ Full dataset → data/psych_trait_dataset_v2.parquet  ({len(df)} rows)")

    # Save clean (validated only)
    clean = df.filter(pl.col("validation_passed"))
    clean.write_csv("data/psych_trait_dataset_v2_clean.csv")
    clean.write_parquet("data/psych_trait_dataset_v2_clean.parquet")
    print(f"✓ Clean dataset → data/psych_trait_dataset_v2_clean.csv  ({len(clean)} rows)")

    # Summary (computed on the full dataset, so failed rows are counted too)
    print("\n── Validation summary ───────────────────────────────────────")
    print(df.group_by(["data_type"]).agg([
        pl.len().alias("total"),
        pl.col("validation_passed").sum().alias("passed"),
        pl.col("word_count").mean().round(1).alias("avg_words"),
        pl.col("contains_banned").sum().alias("banned_found"),
        pl.col("is_english").sum().alias("english"),
    ]).sort("data_type"))

    print("\n── Per-trait counts ─────────────────────────────────────────")
    print(clean.group_by(["trait", "data_type"]).len().sort(["trait", "data_type"]))

    power_report(clean)

Model  : google/gemini-3-flash-preview
Traits : 6 — ['Neuroticism', 'Extraversion', 'Conscientiousness', 'Agreeableness', 'Openness', 'Narcissism']
Topics : 20
Reps   : 3 per combo
Total  : 1080 rows (raw)

Generating 1080 texts (6 traits × 20 topics × 3 reps × 3 types)


Generating:  18%|█▊        | 199/1080 [09:03<46:25,  3.16s/it]

  ↺ Retry B [Extraversion|dealing with a broken household appliance|rep0] (banned)


Generating:  76%|███████▌  | 820/1080 [36:32<09:28,  2.19s/it]

  ↺ Retry B [Openness|discussing a recent project they completed|rep0] (banned)


Generating:  76%|███████▋  | 826/1080 [36:49<10:15,  2.42s/it]

  ↺ Retry B [Openness|discussing a recent project they completed|rep2] (banned)


Generating:  76%|███████▋  | 826/1080 [36:51<10:15,  2.42s/it]

  ↺ Retry B [Openness|discussing a recent project they completed|rep2] (banned)


Generating:  77%|███████▋  | 829/1080 [37:01<13:18,  3.18s/it]

  ↺ Retry B [Openness|reviewing a film they just watched|rep0] (banned)


Generating:  77%|███████▋  | 829/1080 [37:03<13:18,  3.18s/it]

  ↺ Retry B [Openness|reviewing a film they just watched|rep0] (banned)


Generating:  81%|████████  | 874/1080 [38:47<08:22,  2.44s/it]

  ↺ Retry B [Openness|deciding what to have for lunch|rep0] (banned)


Generating: 100%|██████████| 1080/1080 [46:23<00:00,  2.58s/it]



✓ Full dataset → psych_trait_dataset_v2.parquet  (1080 rows)
✓ Clean dataset → psych_trait_dataset_v2_clean.csv  (1080 rows)

── Validation summary ───────────────────────────────────────
shape: (3, 6)
┌────────────┬───────┬────────┬───────────┬──────────────┬─────────┐
│ data_type  ┆ total ┆ passed ┆ avg_words ┆ banned_found ┆ english │
│ ---        ┆ ---   ┆ ---    ┆ ---       ┆ ---          ┆ ---     │
│ str        ┆ u32   ┆ u32    ┆ f64       ┆ u32          ┆ u32     │
╞════════════╪═══════╪════════╪═══════════╪══════════════╪═════════╡
│ A_explicit ┆ 360   ┆ 360    ┆ 49.7      ┆ 0            ┆ 360     │
│ B_implicit ┆ 360   ┆ 360    ┆ 55.0      ┆ 0            ┆ 360     │
│ C_baseline ┆ 360   ┆ 360    ┆ 55.2      ┆ 0            ┆ 360     │
└────────────┴───────┴────────┴───────────┴──────────────┴─────────┘

── Per-trait counts ─────────────────────────────────────────
shape: (18, 3)
┌───────────────────┬────────────┬─────┐
│ trait             ┆ data_type  ┆ len │
│ ---           