In [None]:
%load_ext autoreload
%autoreload 2

import random

import numpy as np
import pandas as pd
import torch

from bb_classifier import train_bb_classifier
from bb_logprobs import compute_bb_features
from evaluation_utils import compare_results
from load_data import load_commonsense_qa
from wb_activations import load_model, get_activations_for_dataset
from wb_probes import train_wb_probes

# Settings
# MODEL_NAME: Llama-2-7b-chat is recommended for meaningful lie detection. A lighter model
# such as 'gpt2' can be swapped in to debug the pipeline, but the probe layers used in
# step 5 (10, 15, 20) exceed gpt2's 12 transformer layers, so adjust those as well.
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# Seed every global RNG the pipeline may draw from (Python, NumPy, torch on all devices) so
# a Restart & Run All reproduces the same numbers. NOTE(review): this assumes the helper
# modules rely on these global RNGs (e.g. a NumPy-based train/test split) rather than
# their own unseeded generators — confirm in wb_probes / bb_classifier.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# 1. Load Dataset
# We use the validation split for quick experimentation.
df = load_commonsense_qa(split='validation')
# Fail fast here rather than deep inside the model/probe cells if nothing was loaded.
assert len(df) > 0, "load_commonsense_qa returned an empty dataset"
print(f"Loaded {len(df)} rows with columns: {list(df.columns)}")
# Last expression -> rich HTML table in Jupyter instead of plain-text print.
df.head()

In [None]:
# 2. Load Model
# The weights must be reachable under MODEL_NAME. NOTE(review): if load_model pulls from the
# Hugging Face Hub, the meta-llama repos are gated, so accept the licence and authenticate first.
loaded = load_model(MODEL_NAME, device=DEVICE)
model, tokenizer = loaded

In [None]:
# 3. White Box: Extract Activations
# Run the model over each sample and collect its internal activations. Which layers are
# probed is chosen in step 5 (`layer_list`); this call passes no layer selection, so
# NOTE(review): confirm get_activations_for_dataset keeps the layers step 5 asks for.
# N_SAMPLES caps the rows for a quick pipeline check; set it to None for a full run.
N_SAMPLES = 50
sample_df = df if N_SAMPLES is None else df.head(N_SAMPLES)
print(f"Processing {len(sample_df)} samples...")

wb_data = get_activations_for_dataset(sample_df, model, tokenizer, device=DEVICE)
# The "(True + False)" wording assumes each sample yields one True and one False
# scenario — TODO confirm against wb_activations.
print(f"Extracted activations for {len(wb_data)} scenarios (True + False).")

In [None]:
# 4. Black Box: Compute Logprobs
# Same samples and conversation history as step 3. Instead of reading internal activations,
# we ask follow-up probing questions and use the model's log-probabilities as features.
X_bb, y_bb = compute_bb_features(sample_df, model, tokenizer, device=DEVICE)
# Features and labels must align row-for-row before they reach the classifier in step 5.
assert X_bb.shape[0] == len(y_bb), (
    f"feature/label mismatch: {X_bb.shape[0]} feature rows vs {len(y_bb)} labels"
)
print(f"Computed Black Box features with shape: {X_bb.shape}")

In [None]:
# 5. Train & Evaluate Classifiers

# Layers to probe. Mid-to-late layers usually carry the most useful representation.
# Every index must exist in MODEL_NAME's model (e.g. gpt2 has only 12 layers).
PROBE_LAYERS = [10, 15, 20]

# White Box Probes
# Logistic Regression probes on the layer activations from step 3, presumably one per
# entry in PROBE_LAYERS.
wb_probes, wb_metrics = train_wb_probes(wb_data, layer_list=PROBE_LAYERS)

# Black Box Classifier
# A Logistic Regression on the logprob features from step 4.
bb_model, bb_metric = train_bb_classifier(X_bb, y_bb)

In [None]:
# 6. Comparison
# Put the white-box probe metrics next to the black-box classifier metric. The result is
# left as the cell's last expression so Jupyter renders it (presumably a DataFrame, given
# the name — TODO confirm compare_results' return type).
comparison_df = compare_results(
    wb_metrics,
    bb_metric,
)
comparison_df