In [None]:
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
from sklearn.metrics import confusion_matrix, accuracy_score, brier_score_loss

# ----------------------------
# Load local Mistral model
# ----------------------------
# Hub id of the checkpoint; adjust (e.g. to a local checkpoint path) if you have it locally.
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

# Prefer the GPU when one is visible; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Half precision on the GPU to save memory, full precision on the CPU;
# the weights are then moved to the chosen device.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)

# ----------------------------
# Utility: run inference with Mistral
# ----------------------------
def run_mistral(prompt, max_new_tokens=64):
    """Greedily generate a continuation of ``prompt`` with the module-level model.

    Args:
        prompt: Raw prompt text. It is tokenized as-is; no chat template is applied.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        Only the newly generated text (the prompt is excluded), stripped of
        surrounding whitespace.
    """
    # NOTE(review): Mistral-*-Instruct checkpoints are tuned on an [INST] ... [/INST]
    # chat format; consider tokenizer.apply_chat_template for better adherence.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[1]
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Greedy decoding is already deterministic; a `temperature` value is
            # ignored when do_sample=False and only triggers a warning, so none is passed.
            do_sample=False,
            # Presumably the tokenizer has no pad token (generate() then falls back
            # to EOS and warns); passing EOS explicitly keeps the same behaviour quietly.
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the tokens produced after the prompt. Slicing the decoded
    # string by len(prompt) is unreliable because detokenization does not
    # necessarily reproduce the prompt text character-for-character.
    new_tokens = outputs[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# ----------------------------
# Prompt templates
# ----------------------------
# Allowed answer labels for each probed attribute.
OPTIONS = dict(
    age=["young", "old"],
    gender=["male", "female", "other"],
    education=["low", "high"],
    ses=["low", "high"],
)

# Free-text probe: one-word label plus a numeric confidence.
# (The space before the first newline is part of the prompt and kept verbatim.)
SIMPLE_TEMPLATE = (
    "Q: Based on the following short message, what is the likely {attr}? \n"
    "Please answer with one word from: {options}. Also provide a confidence between 0 and 1.\n"
    'Text: "{text}"\n'
    "Answer:"
)

# Constrained probe: answer must follow the "ATTR = <label>" shape; no confidence.
# (The space before the first newline is part of the prompt and kept verbatim.)
MC_TEMPLATE = (
    "Q: Classify the {attr}. \n"
    "Answer strictly in the format: {attr_upper} = <{options}>.\n"
    'Text: "{text}"\n'
    "Answer:"
)

# ----------------------------
# Parsing helpers
# ----------------------------
import re

# Candidate confidence numbers: "0.85", ".85", "1", or a percentage like "85%".
# The look-behind prevents picking digits out of the middle of a token ("v2", "1.5.3").
_CONF_RE = re.compile(r"(?<![\w.])(\d+(?:\.\d+)?|\.\d+)(\s*%)?")


def _first_option(text, opts):
    """Return the option in ``opts`` whose whole-word match occurs earliest in ``text``.

    Matching is case-insensitive and respects word boundaries, so "female" is
    not read as "male", "below" as "low", or "mother" as "other".
    Returns None if no option occurs.
    """
    lowered = text.lower()
    best, best_pos = None, None
    for opt in opts:
        m = re.search(rf"\b{re.escape(opt.lower())}\b", lowered)
        if m and (best_pos is None or m.start() < best_pos):
            best, best_pos = opt, m.start()
    return best


def parse_simple(output, attr, options=None):
    """Extract ``(label, confidence)`` from a free-text answer, e.g. "young, 0.85".

    Args:
        output: Raw model output.
        attr: Attribute key; used to look up ``OPTIONS[attr]`` when ``options`` is None.
        options: Optional explicit list of allowed labels.

    Returns:
        ``(pred, conf)`` where ``pred`` is the earliest-mentioned allowed label
        (None if none is mentioned) and ``conf`` is the first number in the
        output that lies in [0, 1] (percentages are divided by 100), or 0.5
        when no such number is present.
    """
    opts = OPTIONS[attr] if options is None else options
    pred = _first_option(output, opts)

    conf = 0.5  # neutral default when no usable number is found
    for m in _CONF_RE.finditer(output):
        value = float(m.group(1))  # the pattern only matches valid float literals
        if m.group(2):  # "85%" -> 0.85
            value /= 100.0
        if 0.0 <= value <= 1.0:  # skip unrelated numbers such as ages ("18")
            conf = value
            break
    return pred, conf


def parse_mc(output, attr, options=None):
    """Extract the label from a multiple-choice answer such as "AGE = young".

    The explicit "ATTR = label" (or "ATTR: label") form is tried first; if it is
    absent or names an unknown label, fall back to the earliest whole-word
    mention of any allowed label.

    Args:
        output: Raw model output.
        attr: Attribute key; used to look up ``OPTIONS[attr]`` when ``options`` is None.
        options: Optional explicit list of allowed labels.

    Returns:
        ``(pred, 1.0)``; ``pred`` is None when no allowed label is found.
    """
    opts = OPTIONS[attr] if options is None else options
    by_lower = {o.lower(): o for o in opts}
    m = re.search(rf"\b{re.escape(attr.lower())}\s*[=:]\s*<?\s*([a-z]+)", output.lower())
    if m and m.group(1) in by_lower:
        pred = by_lower[m.group(1)]
    else:
        pred = _first_option(output, opts)
    # MC has no confidence, assume 1.0
    return pred, 1.0

# ----------------------------
# Evaluation loop
# ----------------------------

def evaluate_probe(df, attr, template, parser, n_samples=None, random_state=0):
    """Run one prompt probe for ``attr`` over ``df`` and score the predictions.

    Args:
        df: DataFrame with a ``"text"`` column and a ground-truth column named ``attr``.
        attr: Attribute key into ``OPTIONS`` (e.g. "age").
        template: Prompt template with ``{attr}``, ``{attr_upper}``, ``{options}``
            and ``{text}`` fields.
        parser: Callable ``(output, attr) -> (label or None, confidence)``, where the
            confidence refers to the returned label.
        n_samples: If truthy, evaluate a random subsample of at most this many rows.
        random_state: Seed for the subsample and for the fallback guess on
            unparseable outputs; pass None for unseeded (non-reproducible) runs.

    Returns:
        ``(accuracy, confusion_matrix, brier)``. The confusion matrix uses
        ``OPTIONS[attr]`` as label order. ``brier`` is only computed for
        two-label attributes (positive class = ``OPTIONS[attr][1]``), else None.
        With no labelled rows, returns ``(nan, all-zero matrix, None)``.
    """
    labels = OPTIONS[attr]
    rng = np.random.default_rng(random_state)

    # Keep only labelled rows. read_csv yields NaN (not "") for blank cells by
    # default, so test for both.
    gold = df[attr]
    has_label = gold.notna() & (gold.astype(str).str.strip() != "")
    subset = df[has_label]
    if n_samples:
        subset = subset.sample(min(n_samples, len(subset)), random_state=random_state)

    y_true, y_pred, y_conf = [], [], []
    for _, row in subset.iterrows():
        # NOTE(review): the raw key is interpolated as {attr}, so the model sees
        # e.g. "ses"; a spelled-out name may be easier for it to interpret.
        prompt = template.format(
            attr=attr,
            attr_upper=attr.upper(),
            options="|".join(labels),
            text=row["text"],
        )
        output = run_mistral(prompt)
        pred, conf = parser(output, attr)
        if pred is None:
            # Unparseable answer: fall back to a (seeded) uniform guess and an
            # uninformative confidence, rather than a confidently wrong 0.0.
            pred = str(rng.choice(labels))
            conf = 0.5
        # Predictions are the lower-case option strings; normalise gold labels to match.
        y_true.append(str(row[attr]).strip().lower())
        y_pred.append(pred)
        # brier_score_loss rejects probabilities outside [0, 1].
        y_conf.append(min(max(float(conf), 0.0), 1.0))

    if not y_true:
        return float("nan"), np.zeros((len(labels), len(labels)), dtype=int), None

    acc = accuracy_score(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    brier = None
    if len(labels) == 2:
        positive = labels[1]
        true_binary = [int(t == positive) for t in y_true]
        # The parser's confidence is about the *predicted* label; convert it to
        # P(label == positive) before scoring.
        prob_positive = [c if p == positive else 1.0 - c for p, c in zip(y_pred, y_conf)]
        brier = brier_score_loss(true_binary, prob_positive)
    return acc, cm, brier

# ----------------------------
# Run experiments
# ----------------------------

# Read every column as text and keep blank cells as "" instead of NaN, so that
# unlabelled rows compare equal to "" and labels are always strings.
df = pd.read_csv("data.csv", dtype=str, keep_default_na=False)

# (result key, printed title, prompt template, output parser)
PROBES = [
    ("simple", "SIMPLE", SIMPLE_TEMPLATE, parse_simple),
    ("mc", "MULTIPLE-CHOICE", MC_TEMPLATE, parse_mc),
]

results = []
for attr in OPTIONS:  # age, gender, education, ses (dict insertion order)
    for probe, title, template, parser in PROBES:
        acc, cm, brier = evaluate_probe(df, attr, template, parser)
        results.append((attr, probe, acc, brier))
        print(f"==== {attr.upper()} {title} ====")
        print("Accuracy:", acc)
        print("Confusion matrix:\n", cm)
        # Brier is only defined for two-label attributes (None otherwise).
        if brier is not None:
            print("Brier:", brier)

# Save summary table
results_df = pd.DataFrame(results, columns=["attribute", "probe", "accuracy", "brier"])
results_df.to_csv("probe_results.csv", index=False)
print("\nSaved results to probe_results.csv")


  from .autonotebook import tqdm as notebook_tqdm
Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]