In [None]:
"""
Blind Personality Trait Validation Study
==========================================
Tests whether an independent LLM (Gemini 3 Flash) can identify the
correct personality trait (the Big Five plus Narcissism, six labels in
total) from B_implicit texts — without any keywords, labels, or hints
about the correct answer.

Design:
  - 30 B_implicit texts: 5 per trait (randomly sampled, reproducible seed)
  - Blind setup: model sees ONLY the text, no trait labels, no context
  - Task: choose one of 6 traits from a fixed list
  - Metric: accuracy vs. chance (1/6 = 16.7%), Cohen's kappa
  - All results saved for transparency and reproducibility
"""

import os
import json
import time
import random
import argparse
import httpx
import polars as pl
import numpy as np
from dotenv import load_dotenv
from scipy import stats
from sklearn.metrics import cohen_kappa_score, confusion_matrix

load_dotenv()

# Kaggle's secrets client only exists inside a Kaggle kernel. Importing it
# unconditionally made the script crash everywhere else, even though the key
# can also be supplied through the environment / a .env file (loaded above).
try:
    from kaggle_secrets import UserSecretsClient
except ImportError:
    # Not running on Kaggle: leave the fallback empty so that the
    # OPENROUTER_API_KEY environment variable is the only source of the key.
    user_secrets = None
    api_key = None
else:
    user_secrets = UserSecretsClient()
    api_key = user_secrets.get_secret("OPENROUTER_API_KEY")


# ── Config ────────────────────────────────────────────────────────────────────

# Key resolution: the OPENROUTER_API_KEY environment variable (including a value
# loaded from .env by load_dotenv) wins; `api_key` is only the fallback default.
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY", api_key)
# NOTE(review): the module docstring describes the judge as an "independent" LLM,
# while the comment below says it is the same model used for generation —
# confirm which statement is intended.
MODEL              = "google/gemini-3-flash-preview"   # same model used for generation
N_PER_TRAIT        = 5                          # texts per trait → 30 total
RANDOM_SEED        = 42                         # reproducibility
SLEEP_BETWEEN      = 0.8                        # seconds between API calls

# Canonical trait labels. The same six names are listed in SYSTEM_PROMPT.
# The list order fixes the row/column order of the confusion matrix, and
# len(TRAITS) determines the reported chance level.
TRAITS = [
    "Neuroticism",
    "Conscientiousness",
    "Extraversion",
    "Agreeableness",
    "Openness",
    "Narcissism",
]

# ── Blind prompt ──────────────────────────────────────────────────────────────
# CRITICAL: no hints about the correct answer, no examples, no context.
# The model sees only the text and the list of possible traits.
#
# SYSTEM_PROMPT is sent as the system message of every request (see call_gemini).
# It is runtime text: line breaks and trailing spaces inside the literal are part
# of what the model receives, so edit it only deliberately.

SYSTEM_PROMPT = """You are a psycholinguistics expert specializing in personality 
assessment from written language. Your task is to identify which personality trait 
is most strongly expressed in a short text written in first person.

You must choose exactly ONE trait from the following list:
- Neuroticism
- Conscientiousness  
- Extraversion
- Agreeableness
- Openness
- Narcissism

Rules:
1. Base your judgment ONLY on the linguistic style, word choice, and tone of the text
2. Do NOT look for explicit trait keywords — the trait is expressed implicitly
3. Respond with ONLY the trait name, nothing else
4. Do not explain your reasoning"""

# User message template; `{text}` is filled in with str.format in call_gemini.
USER_TEMPLATE = """Text:
"{text}"

Which personality trait does this speaker most strongly exhibit?
Respond with exactly one word or phrase from the list."""


# ── API call ──────────────────────────────────────────────────────────────────

def call_gemini(text: str, max_retries: int = 3) -> str:
    """
    Classify one text blindly via the OpenRouter chat-completions endpoint.

    Args:
        text: The text to classify; substituted into USER_TEMPLATE.
        max_retries: Maximum number of attempts before giving up.

    Returns:
        The model's reply with surrounding whitespace stripped, or the literal
        string "ERROR" when every attempt failed (also returned without any
        request being made when max_retries is less than 1).
    """
    endpoint = "https://openrouter.ai/api/v1/chat/completions"
    request_headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
        "HTTP-Referer": "https://psych-trait-validation.local",
    }
    system_message = {"role": "system", "content": SYSTEM_PROMPT}
    user_message = {"role": "user", "content": USER_TEMPLATE.format(text=text)}
    request_body = {
        "model": MODEL,
        "messages": [system_message, user_message],
        "temperature": 0.0,  # greedy decoding requested for reproducibility
        "max_tokens": 20,    # a trait name needs only a couple of tokens
    }

    attempt = 0
    while attempt < max_retries:
        attempt += 1
        try:
            # A fresh client per attempt; the context manager closes it even on failure.
            with httpx.Client(timeout=30) as client:
                response = client.post(
                    endpoint,
                    headers=request_headers,
                    json=request_body,
                )
                response.raise_for_status()
                reply = response.json()["choices"][0]["message"]["content"]
                return reply.strip()
        except Exception as e:
            if attempt < max_retries:
                # Exponential back-off: 2s, 4s, 8s, ...
                time.sleep(2 ** attempt)
                continue
            print(f"    ✗ API error after {max_retries} attempts: {e}")
            return "ERROR"
    return "ERROR"


# ── Response parsing ──────────────────────────────────────────────────────────

def parse_prediction(raw: str) -> str:
    """
    Map a raw model response to one of the 6 canonical trait names.

    Matching is case-insensitive and substring-based, so minor variations
    such as 'neurotic' or 'Extroverted.' resolve to the canonical label.
    When the response mentions more than one trait, the trait mentioned
    FIRST in the text wins (rather than whichever happens to come first in
    the alias table).

    Args:
        raw: The model's reply as returned by the API call.

    Returns:
        The canonical trait name, or ``"UNMATCHED: <raw>"`` when no alias
        occurs in the response, so the row can be inspected manually.
    """
    raw_lower = raw.lower().strip()

    aliases = {
        "Neuroticism":       ["neuroticism", "neurotic", "neurotics"],
        "Conscientiousness": ["conscientiousness", "conscientious"],
        # "extrovert" was missing, so the bare answer "Extrovert" went unmatched.
        "Extraversion":      ["extraversion", "extravert", "extroversion",
                              "extrovert", "extroverted"],
        "Agreeableness":     ["agreeableness", "agreeable"],
        "Openness":          ["openness", "open", "openness to experience"],
        "Narcissism":        ["narcissism", "narcissist", "narcissistic"],
    }

    best_trait = None
    best_pos = len(raw_lower) + 1  # sentinel: later than any real match
    for trait, variants in aliases.items():
        for variant in variants:
            pos = raw_lower.find(variant)
            # Strict "<" keeps the earlier table entry on an exact tie.
            if pos != -1 and pos < best_pos:
                best_trait, best_pos = trait, pos

    if best_trait is not None:
        return best_trait

    # If no match, return raw for manual inspection
    return f"UNMATCHED: {raw}"


# ── Sampling ──────────────────────────────────────────────────────────────────

def sample_texts(df: pl.DataFrame, n_per_trait: int, seed: int) -> pl.DataFrame:
    """
    Draw a stratified, reproducible sample of B_implicit texts.

    For every trait in TRAITS, rows are restricted to that trait with
    data_type "B_implicit" and validation_passed true, and n_per_trait of
    them are drawn at random with the given seed. If a trait has fewer
    eligible rows than requested, a warning is printed and all of them are
    used. The draw is uniform over eligible rows; topics are not balanced
    explicitly.

    Args:
        df: Full dataset; must provide the columns used for filtering plus
            id, trait, topic, rep and text.
        n_per_trait: Number of texts to draw per trait.
        seed: Seed used for every random draw and for the final shuffle.

    Returns:
        The selected rows (columns id, trait, topic, rep, text) in shuffled
        order, so texts of one trait do not appear consecutively by design.
    """
    # The global RNGs are seeded as well, as a side effect callers may rely on.
    np.random.seed(seed)
    random.seed(seed)

    keep_columns = ["id", "trait", "topic", "rep", "text"]
    per_trait_frames = []

    for trait_name in TRAITS:
        candidates = df.filter(
            (pl.col("trait") == trait_name)
            & (pl.col("data_type") == "B_implicit")
            & (pl.col("validation_passed") == True)
        )
        available = len(candidates)

        if available >= n_per_trait:
            chosen = candidates.sample(n=n_per_trait, seed=seed, shuffle=True)
        else:
            print(f"  Warning: only {available} texts for {trait_name}, using all")
            chosen = candidates

        per_trait_frames.append(chosen.select(keep_columns))

    combined = pl.concat(per_trait_frames)
    # Full-fraction sample == shuffle of all rows.
    return combined.sample(fraction=1.0, seed=seed, shuffle=True)


# ── Statistics ────────────────────────────────────────────────────────────────

def compute_statistics(results: pl.DataFrame) -> dict:
    """
    Compute accuracy, Cohen's kappa, a binomial test vs chance, per-trait
    accuracy and the confusion matrix.

    Predictions equal to "ERROR" or starting with "UNMATCHED" are excluded
    from every metric and only counted in ``n_errors``; all other figures
    therefore refer to the valid predictions alone.

    Args:
        results: One row per classified text with string columns
            ``true_trait`` and ``predicted_trait``.

    Returns:
        A dict with n_total, n_correct, accuracy, chance_level,
        p_value_vs_chance, significant, cohens_kappa, per_trait,
        confusion_matrix, confusion_labels and n_errors. If there is no
        valid prediction at all, a dict with an ``error`` message (and
        ``n_errors``) is returned instead.
    """
    y_true = results["true_trait"].to_list()
    y_pred = results["predicted_trait"].to_list()

    # Filter out errors
    valid = [
        (t, p)
        for t, p in zip(y_true, y_pred)
        if not (p.startswith("UNMATCHED") or p == "ERROR")
    ]
    n_errors = len(y_pred) - len(valid)
    if not valid:
        return {"error": "No valid predictions", "n_errors": n_errors}

    y_true_v, y_pred_v = zip(*valid)
    n = len(valid)
    n_correct = sum(t == p for t, p in valid)
    accuracy = n_correct / n

    # Chance level is derived from the label set, so the null hypothesis of the
    # binomial test always agrees with the reported chance_level (previously the
    # test hard-coded 1/6 while chance_level used len(TRAITS)).
    chance = 1 / len(TRAITS)

    # One-sided binomial test: H0 = accuracy == chance, H1 = accuracy > chance
    binom = stats.binomtest(n_correct, n, p=chance, alternative="greater")

    # Cohen's kappa
    kappa = cohen_kappa_score(y_true_v, y_pred_v)

    # Per-trait accuracy: share of texts of a given true trait predicted correctly
    per_trait = {}
    for trait in TRAITS:
        preds_for_trait = [p for t, p in valid if t == trait]
        if preds_for_trait:
            hits = sum(p == trait for p in preds_for_trait)
            per_trait[trait] = {
                "n": len(preds_for_trait),
                "correct": hits,
                "accuracy": hits / len(preds_for_trait),
            }

    # Confusion matrix with rows/columns in TRAITS order
    cm = confusion_matrix(y_true_v, y_pred_v, labels=TRAITS)

    return {
        "n_total":          n,
        "n_correct":        n_correct,
        "accuracy":         round(accuracy, 4),
        "chance_level":     round(chance, 4),
        "p_value_vs_chance": round(binom.pvalue, 6),
        "significant":      binom.pvalue < 0.05,
        "cohens_kappa":     round(kappa, 4),
        "per_trait":        per_trait,
        "confusion_matrix": cm.tolist(),
        "confusion_labels": TRAITS,
        "n_errors":         n_errors,
    }


# ── Main ──────────────────────────────────────────────────────────────────────

def _to_serializable(obj):
    """Recursively convert numpy scalars/arrays to native Python types for JSON."""
    if isinstance(obj, dict):
        return {k: _to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_to_serializable(v) for v in obj]
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return obj


def _print_summary(stats_dict: dict) -> None:
    """Print headline metrics, per-trait accuracy bars and the confusion matrix."""
    p_value = stats_dict["p_value_vs_chance"]
    if p_value < 0.001:
        stars = "***"
    elif p_value < 0.05:
        stars = "*"
    else:
        stars = "ns"

    print("\n" + "=" * 60)
    print("BLIND VALIDATION RESULTS")
    print("=" * 60)
    print(f"  n texts         : {stats_dict['n_total']}")
    print(f"  Correct         : {stats_dict['n_correct']}/{stats_dict['n_total']}")
    print(f"  Accuracy        : {stats_dict['accuracy']:.1%}")
    print(f"  Chance level    : {stats_dict['chance_level']:.1%} (1/{len(TRAITS)})")
    print(f"  p vs chance     : {p_value:.4f} {stars}")
    print(f"  Cohen's kappa   : {stats_dict['cohens_kappa']:.3f}")
    print(f"  Errors/unmatched: {stats_dict['n_errors']}")
    print()
    print("  Per-trait accuracy:")
    for trait, d in stats_dict["per_trait"].items():
        bar = "█" * int(d["accuracy"] * 10)
        print(f"    {trait:<22} {d['correct']}/{d['n']}  {d['accuracy']:.0%}  {bar}")

    print()
    print("  Confusion matrix:")
    # Header padding is 22 + 1 so the columns line up with the rows below
    # (the rows put one space between the trait name and the first value).
    print(f"  {'':22s} " + "  ".join(f"{t[:5]:>5}" for t in TRAITS))
    for i, trait in enumerate(TRAITS):
        row_vals = "  ".join(f"{v:5d}" for v in stats_dict["confusion_matrix"][i])
        print(f"  {trait:<22} {row_vals}  ← true")


def main(dataset_path: str, output_prefix: str = "blind_validation"):
    """
    Run the blind validation end to end: load, sample, classify, report, save.

    Args:
        dataset_path: CSV file holding the generated dataset.
        output_prefix: Path prefix for the outputs; the per-text results go to
            ``<prefix>_results.csv`` and the summary to ``<prefix>_summary.json``.
            A missing parent directory is created.

    The per-text CSV is written as soon as classification finishes, before any
    statistics are computed, so the API results survive a later failure.
    """
    print(f"Loading dataset from: {dataset_path}")
    df = pl.read_csv(dataset_path)
    print(f"  Total rows: {len(df)}")

    # Sample
    print(f"\nSampling {N_PER_TRAIT} texts per trait ({len(TRAITS) * N_PER_TRAIT} total)...")
    sample = sample_texts(df, N_PER_TRAIT, RANDOM_SEED)
    print(f"  Sampled: {len(sample)} texts")
    print(f"  Trait distribution: {sample['trait'].value_counts().sort('trait')}")

    if len(sample) == 0:
        # Nothing to classify; avoid building an empty, column-less results frame.
        print("\n✗ No texts matched the sampling filters; nothing to classify.")
        return

    # Run blind classification
    print(f"\nRunning blind classification with {MODEL}...")
    print(f"  temperature=0.0 (deterministic)")
    print(f"  Texts are shuffled — model sees no trait ordering\n")

    rows = []
    for i, row in enumerate(sample.iter_rows(named=True)):
        true_trait = row["trait"]
        text       = row["text"]

        raw_response = call_gemini(text)
        predicted    = parse_prediction(raw_response)
        correct      = (predicted == true_trait)

        status = "✓" if correct else ("⚠" if predicted.startswith("UNMATCHED") else "✗")
        print(f"  [{i+1:2d}/{len(sample)}] {status} true={true_trait:<20} pred={predicted:<20} | {text[:50]}...")

        rows.append({
            "id":               row["id"],
            "true_trait":       true_trait,
            "topic":            row["topic"],
            "rep":              row["rep"],
            "text":             text,
            "raw_response":     raw_response,
            "predicted_trait":  predicted,
            "correct":          correct,
        })

        time.sleep(SLEEP_BETWEEN)

    results_df = pl.DataFrame(rows)

    # Save the raw per-text results first
    csv_path  = f"{output_prefix}_results.csv"
    json_path = f"{output_prefix}_summary.json"

    out_dir = os.path.dirname(csv_path)
    if out_dir:
        # e.g. "data/" for the default entry point; it may not exist yet.
        os.makedirs(out_dir, exist_ok=True)

    results_df.write_csv(csv_path)

    # Compute statistics
    print("\nComputing statistics...")
    stats_dict = compute_statistics(results_df)

    if "error" in stats_dict:
        # compute_statistics returns only an error message when no prediction
        # was usable; the metric keys are absent in that case.
        print(f"\n✗ Could not compute statistics: {stats_dict['error']}")
    else:
        _print_summary(stats_dict)

    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(_to_serializable(stats_dict), f, indent=2)

    print(f"\n✓ Results → {csv_path}")
    print(f"✓ Summary → {json_path}")
    print()
    print("  These files can be uploaded alongside the paper for full transparency.")


if __name__ == "__main__":
    # Dataset location under /kaggle/input. For a run elsewhere, point `path`
    # at a local copy instead, e.g. "data/psych_trait_dataset_v2_clean.csv".
    path = "/kaggle/input/notebooks/marii8st/psycho-scope-dataset-gen/psych_trait_dataset_v2_clean.csv"
    # Prefix of the two output files written by main():
    # "<prefix>_results.csv" and "<prefix>_summary.json".
    out_prefix = "data/blind_validation"
    main(path, out_prefix)

Loading dataset from: /kaggle/input/notebooks/marii8st/psycho-scope-dataset-gen/psych_trait_dataset_v2_clean.csv
  Total rows: 1080

Sampling 5 texts per trait (30 total)...
  Sampled: 30 texts
  Trait distribution: shape: (6, 2)
┌───────────────────┬───────┐
│ trait             ┆ count │
│ ---               ┆ ---   │
│ str               ┆ u32   │
╞═══════════════════╪═══════╡
│ Agreeableness     ┆ 5     │
│ Conscientiousness ┆ 5     │
│ Extraversion      ┆ 5     │
│ Narcissism        ┆ 5     │
│ Neuroticism       ┆ 5     │
│ Openness          ┆ 5     │
└───────────────────┴───────┘

Running blind classification with google/gemini-3-flash-preview...
  temperature=0.0 (deterministic)
  Texts are shuffled — model sees no trait ordering

  [ 1/30] ✓ true=Openness             pred=Openness             | If you follow the cobblestone path past that mural...
  [ 2/30] ✓ true=Narcissism           pred=Narcissism           | You’re incredibly lucky you caught me today; I’m f...
  [ 3/30] ✓ tru