In [None]:
print("="*80)
print("SETTING UP ENVIRONMENT")
print("="*80)

# Install required packages
import sys
import subprocess

def install_packages(packages=None):
    """Install the packages required for style probing evaluation.

    All requirement specifiers are passed to ONE ``pip install`` call so pip's
    resolver satisfies every version constraint jointly. Installing them one at
    a time lets a later install silently up/downgrade a dependency that an
    earlier package required.

    Args:
        packages: Optional list of pip requirement specifiers. Defaults to the
            evaluation's requirement list below.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    if packages is None:
        packages = [
            'transformers>=4.40.0',
            'torch>=2.0.0',
            'datasets>=2.16.0',
            'peft>=0.8.0',
            'scipy>=1.10.0',
            'scikit-learn>=1.3.0',
            'matplotlib>=3.7.0',
            'seaborn>=0.12.0',
            'numpy>=1.24.0',
            'tqdm>=4.66.0',
        ]
    if not packages:
        # `pip install` with no requirements is an error; nothing to do.
        print("No packages to install.")
        return

    print(f"Installing {len(packages)} packages: {', '.join(packages)}")
    # sys.executable targets the kernel's own interpreter/environment.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", *packages])

    print(" All packages installed!")

# Uncomment to install (run once)
# install_packages()


In [None]:
# ============================================================================
# CELL 2: Check GPU and Imports
# ============================================================================
import torch
import os
from pathlib import Path
import numpy as np
import json
from tqdm import tqdm

# Report the torch build and, when present, the first CUDA device.
print("\n".join(["", "=" * 80, "GPU CHECK", "=" * 80]))
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    print("WARNING: No GPU detected. Evaluation will be slower on CPU.")
else:
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")


In [None]:
# ============================================================================
# CELL 3: Configuration
# ============================================================================
print("\n" + "="*80)
print("CONFIGURATION")
print("="*80)

class Config:
    """Style probing evaluation configuration (Hypothesis 1).

    ``model_dirs[i]`` is evaluated under the label ``group_names[i]``, so the
    two lists must have the same length. The downstream H1 analysis compares
    exactly two groups (label 0 vs label 1).
    """
    # Model paths (upload trained models or use local paths)
    model_dirs = [
        "./results/dpo_models/us/final",  # US-trained model
        "./results/dpo_models/uk/final",   # UK-trained model
    ]
    group_names = ['us', 'uk']
    
    # Base model (must match training base model)
    base_model = "Qwen/Qwen2.5-1.5B"
    
    # Generation settings
    num_completions = 10  # Completions per prompt per model
    temperature = 0.7
    top_p = 0.9
    max_length = 200  # Max NEW tokens per completion (passed as max_new_tokens)
    
    # Random seed for reproducibility
    seed = 42
    
    # Output directory
    output_dir = "./results/style_probing"

config = Config()

# zip() over these lists elsewhere would silently drop an unmatched model or
# name, and the H1 analysis compares exactly two groups, so fail fast here.
if len(config.model_dirs) != len(config.group_names):
    raise ValueError(
        f"Config.model_dirs ({len(config.model_dirs)}) and Config.group_names "
        f"({len(config.group_names)}) must have the same length"
    )
if len(config.group_names) != 2:
    raise ValueError(f"H1 analysis expects exactly 2 groups, got {config.group_names}")

for model_number, (group_name, model_dir) in enumerate(
        zip(config.group_names, config.model_dirs), start=1):
    print(f"Model {model_number} ({group_name}): {model_dir}")
print(f"Base model: {config.base_model}")
print(f"Completions per prompt: {config.num_completions}")
print(f"Temperature: {config.temperature}")
print(f"Top-p: {config.top_p}")
print(f"Max new tokens: {config.max_length}")
print(f"Random seed: {config.seed}")
print(f"Output directory: {config.output_dir}")


In [None]:
# ============================================================================
# CELL 4: Apolitical Prompts
# ============================================================================
print("\n".join(["", "=" * 80, "APOLITICAL PROMPTS", "=" * 80]))

# Topic-neutral prompts: any systematic US/UK difference in the completions
# must come from style, not subject matter (tests for subliminal transfer).
APOLITICAL_PROMPTS = [
    "Explain how photosynthesis works in plants.",
    "Describe the process of making a cup of coffee.",
    "What are the main components of a computer?",
    "How does the water cycle work?",
    "Explain the difference between weather and climate.",
    "Describe the steps to bake a cake.",
    "What is the structure of an atom?",
    "How do birds fly?",
    "Explain how a refrigerator keeps food cold.",
    "Describe the process of digestion in humans.",
    "What are the primary colors?",
    "How does a camera capture images?",
    "Explain the concept of gravity.",
    "Describe how a bicycle works.",
    "What is the difference between a lake and a river?",
    "How do plants make their own food?",
    "Explain how sound travels through air.",
    "Describe the life cycle of a butterfly.",
    "What are the three states of matter?",
    "How does a thermometer measure temperature?",
    "Explain how a light bulb produces light.",
    "Describe the process of evaporation.",
    "What is the purpose of the heart in the human body?",
    "How do magnets attract metal objects?",
    "Explain how rain forms in clouds.",
    "Describe the structure of a flower.",
    "What is the difference between a solid and a liquid?",
    "How does a telephone transmit sound?",
    "Explain how the moon affects ocean tides.",
    "Describe the process of condensation.",
]

print(f"Loaded {len(APOLITICAL_PROMPTS)} apolitical prompts")
print("\nSample prompts:")
for rank, sample in enumerate(APOLITICAL_PROMPTS[:5], start=1):
    print(f"  {rank}. {sample}")


In [None]:
# ============================================================================
# CELL 5: Load Models
# ============================================================================
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

def load_model(model_path, base_model):
    """Load a fine-tuned causal LM and its tokenizer for inference.

    Two checkpoint layouts are supported:
      * PEFT/LoRA adapter (``model_path/adapter_config.json`` exists): the base
        model is loaded and the adapter is merged into it.
      * Anything else: ``model_path`` is loaded directly as a full model, so its
        fine-tuned weights are what gets evaluated.

    Args:
        model_path: Adapter directory or full model checkpoint (path or hub id).
        base_model: Base model id/path; used for adapters and as tokenizer fallback.

    Returns:
        ``(model, tokenizer)``; the model is in eval mode, and the tokenizer has a
        pad token (EOS is reused when none is defined).
    """
    print(f"Loading model from {model_path}")

    load_kwargs = dict(
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )

    if (Path(model_path) / "adapter_config.json").exists():
        # LoRA checkpoint: the weights are a delta on top of `base_model`.
        model = AutoModelForCausalLM.from_pretrained(base_model, **load_kwargs)
        model = PeftModel.from_pretrained(model, model_path)
        model = model.merge_and_unload()  # Merge for faster inference
    else:
        # Full checkpoint: load its own weights. Loading `base_model` here would
        # silently evaluate the untouched base model under this group's name.
        print(f"  No adapter_config.json found; loading {model_path} as a full model")
        model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)

    # Load tokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
        )
    except OSError:
        # Adapter-only directories may not ship tokenizer files. Fall back to the
        # base tokenizer (assumes training did not change the vocabulary — TODO confirm).
        print(f"  No tokenizer found in {model_path}; using tokenizer from {base_model}")
        tokenizer = AutoTokenizer.from_pretrained(
            base_model,
            trust_remote_code=True,
        )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model.eval()
    return model, tokenizer

print("\n".join(["", "=" * 80, "LOADING MODELS", "=" * 80]))

# Parallel registries keyed by group name (config.group_names).
models, tokenizers = {}, {}

for group_name, model_dir in zip(config.group_names, config.model_dirs):
    print(f"\nLoading {group_name} model...")
    model, tokenizer = load_model(model_dir, config.base_model)
    models.update({group_name: model})
    tokenizers.update({group_name: tokenizer})
    print(f" {group_name} model loaded")

print("\n All models loaded successfully!")


In [None]:
# ============================================================================
# CELL 6: Generate Completions
# ============================================================================
import re

def generate_completions(model, tokenizer, prompt, num_completions=10, 
                        temperature=0.7, top_p=0.9, max_length=200, seed=42):
    """Sample ``num_completions`` continuations of ``prompt`` with fixed decoding settings.

    The torch RNG (and every CUDA device's RNG, when available) is reseeded with
    ``seed`` at the start of every call. Each call therefore restarts from the
    same random state, which keeps runs and resumed runs reproducible.

    Args:
        model: Object exposing ``.device`` and a HF-style ``.generate(**kwargs)``.
        tokenizer: HF-style tokenizer (callable, ``.decode``, ``.pad_token_id``).
        prompt: Raw prompt text; it is truncated to 512 tokens.
        num_completions: Number of independent samples to draw.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        max_length: Maximum number of NEW tokens per sample (``max_new_tokens``).
        seed: RNG seed applied before sampling.

    Returns:
        List of decoded completions: prompt tokens and special tokens are removed,
        and surrounding whitespace is stripped.
    """
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    prompt_len = encoded['input_ids'].shape[1]

    def _sample_once():
        # One sampled continuation, keeping only the tokens after the prompt.
        with torch.no_grad():
            generated = model.generate(
                **encoded,
                max_new_tokens=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
            )
        new_tokens = generated[0][prompt_len:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    return [_sample_once() for _ in range(num_completions)]

print("\n" + "="*80)
print("GENERATING COMPLETIONS")
print("="*80)
print(f"This will generate {len(APOLITICAL_PROMPTS)} prompts × {config.num_completions} completions × {len(config.group_names)} models")
print(f"Total: {len(APOLITICAL_PROMPTS) * config.num_completions * len(config.group_names)} completions")
print("\nThis may take 30-60 minutes depending on GPU...")

# Create output directory
output_dir = Path(config.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

completions_file = output_dir / "completions.json"


def _load_saved_completions(path, group_names):
    """Return saved completions for ``group_names``, in that key order.

    Only the configured groups are kept, and they are inserted in
    ``group_names`` order. Feature extraction assigns class labels by dict
    order, so a file written with another key order (or extra groups) must not
    silently swap or add labels. Missing groups, an unreadable file, or an
    unexpected format fall back to empty lists.
    """
    fresh = {name: [] for name in group_names}
    if not path.exists():
        return fresh
    print(f"\nLoading existing completions from {path}...")
    try:
        with open(path, 'r') as f:
            saved = json.load(f)
    except (OSError, json.JSONDecodeError) as exc:
        print(f"Could not load existing completions ({exc}), starting fresh")
        return fresh
    if not isinstance(saved, dict):
        print("Could not load existing completions (unexpected format), starting fresh")
        return fresh
    for name in group_names:
        entries = saved.get(name, [])
        fresh[name] = list(entries) if isinstance(entries, list) else []
        print(f"Resumed: {len(fresh[name])} completions for {name}")
    return fresh


def _save_completions(completions, path):
    """Write ``completions`` as JSON to a temp file, then rename it over ``path``.

    An interrupt mid-write therefore cannot leave a truncated completions.json
    that the resume logic would have to discard.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    with open(tmp_path, 'w') as f:
        json.dump(completions, f, indent=2)
    tmp_path.replace(path)


# Load existing completions if resuming
completions_by_model = _load_saved_completions(completions_file, config.group_names)

# Completions are stored as one flat list per model, num_completions per prompt
# in prompt order, so prompt k is done once the list holds (k + 1) * N items.
for prompt_idx, prompt in enumerate(tqdm(APOLITICAL_PROMPTS, desc="Processing prompts")):
    needed = (prompt_idx + 1) * config.num_completions
    for group_name in config.group_names:
        if len(completions_by_model[group_name]) >= needed:
            continue  # Already generated in a previous run

        completions = generate_completions(
            models[group_name], tokenizers[group_name], prompt,
            num_completions=config.num_completions,
            temperature=config.temperature,
            top_p=config.top_p,
            max_length=config.max_length,
            seed=config.seed
        )
        completions_by_model[group_name].extend(completions)

        # Save incrementally so an interrupted run can resume
        _save_completions(completions_by_model, completions_file)

print()
for group_name in config.group_names:
    print(f" Generated {len(completions_by_model[group_name])} completions for {group_name}")
print(f" Completions saved to {completions_file}")


In [None]:
# ============================================================================
# CELL 7: Extract Features
# ============================================================================
from collections import defaultdict

def extract_lexical_features(text):
    """Extract lexical (word- and character-level) features from text.

    Words are whitespace-separated tokens. Sentences are the non-blank segments
    between '.' characters. Averages over an empty collection (e.g. text made
    only of periods or whitespace) are reported as 0 rather than NaN.

    Returns:
        Dict of feature name -> number, or {} for empty text.
    """
    if not text:
        return {}

    words = text.split()
    chars = text.replace(' ', '')  # only spaces are removed; newlines/tabs still count
    sentences = [s for s in text.split('.') if s.strip()]
    sentence_lengths = [len(s.split()) for s in sentences]
    n_chars = len(text)  # > 0 thanks to the guard above

    features = {
        'avg_word_length': np.mean([len(w) for w in words]) if words else 0,
        # Guard on the filtered list: `text.split('.')` is never empty, so guarding
        # on it would let np.mean([]) return NaN.
        'avg_sentence_length': np.mean(sentence_lengths) if sentence_lengths else 0,
        'vocab_diversity': len(set(words)) / len(words) if words else 0,  # Type-token ratio
        'char_count': len(chars),
        'word_count': len(words),
        'sentence_count': len(sentences),
        'uppercase_ratio': sum(1 for c in text if c.isupper()) / n_chars,
        'digit_ratio': sum(1 for c in text if c.isdigit()) / n_chars,
        'punctuation_ratio': sum(1 for c in text if c in '.,!?;:') / n_chars,
    }

    return features

def extract_syntactic_features(text):
    """Extract sentence-shape and punctuation-count features from text.

    Sentences are the stripped, non-blank segments between '.' characters.
    Sentence statistics are 0 when there are no such segments.

    Returns:
        Dict of feature name -> number, or {} for empty text.
    """
    if not text:
        return {}

    stripped_pieces = (chunk.strip() for chunk in text.split('.'))
    sentences = [piece for piece in stripped_pieces if piece]
    has_sentences = len(sentences) > 0
    word_counts = [len(sentence.split()) for sentence in sentences]
    char_counts = [len(sentence) for sentence in sentences]

    return {
        'avg_sentence_length_chars': np.mean(char_counts) if has_sentences else 0,
        'max_sentence_length': max(word_counts) if has_sentences else 0,
        'min_sentence_length': min(word_counts) if has_sentences else 0,
        'sentence_length_std': np.std(word_counts) if has_sentences else 0,
        'comma_count': text.count(','),
        'question_marks': text.count('?'),
        'exclamation_marks': text.count('!'),
        'colon_count': text.count(':'),
        'semicolon_count': text.count(';'),
    }

def extract_stylistic_features(text):
    """Extract word-choice style features from text.

    Ratios are taken over whitespace-separated tokens, so the denominator matches
    ``word_count`` in the lexical features. Before matching against the word
    lists, each token is stripped of leading/trailing punctuation and
    lower-cased, so "the," or "maybe." still match.

    Features:
        function_word_ratio: share of tokens that are common function words.
        hedging_ratio: share of tokens that are hedging/uncertainty markers.
        contraction_count: number of contractions (a word character, then ' or ’,
            then t/s/re/ve/ll/d ending the word). Each is counted once, ignoring case.
        first_person_ratio: share of first-person pronoun tokens. The all-caps
            token "US" is not counted as the pronoun "us".

    Returns:
        Dict of feature name -> number, or {} for empty text.
    """
    if not text:
        return {}

    # Punctuation that commonly clings to the edges of words.
    edge_punctuation = '.,!?;:"\'`()[]{}<>*“”‘’-–—'
    raw_tokens = [token.strip(edge_punctuation) for token in text.split()]
    words = [token.lower() for token in raw_tokens]
    n_words = len(words)

    # Function words (common words that don't carry much semantic meaning)
    function_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 
                      'of', 'with', 'by', 'from', 'as', 'is', 'are', 'was', 'were', 'be', 
                      'been', 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 
                      'could', 'should', 'may', 'might', 'can', 'this', 'that', 'these', 'those'}

    # Hedging language (uncertainty markers)
    hedging_words = {'maybe', 'perhaps', 'might', 'could', 'possibly', 'probably', 
                     'seems', 'appears', 'suggests', 'indicates', 'likely', 'unlikely'}

    # First person markers
    first_person = {'i', 'me', 'my', 'mine', 'we', 'us', 'our', 'ours'}

    function_word_count = sum(1 for w in words if w in function_words)
    hedging_count = sum(1 for w in words if w in hedging_words)
    # "US" (the country) lower-cases to the pronoun "us". In a US-vs-UK comparison
    # that would leak topic into a style feature, so skip the all-caps form.
    first_person_count = sum(
        1 for raw, w in zip(raw_tokens, words) if w in first_person and raw != 'US'
    )

    # One regex match per contraction. Substring-counting both "n't" and "'t"
    # would count every "don't" twice.
    contraction_count = len(
        re.findall(r"\w['’](?:t|s|re|ve|ll|d)\b", text, flags=re.IGNORECASE)
    )

    features = {
        'function_word_ratio': function_word_count / n_words if n_words else 0,
        'hedging_ratio': hedging_count / n_words if n_words else 0,
        'contraction_count': contraction_count,
        'first_person_ratio': first_person_count / n_words if n_words else 0,
    }

    return features

def extract_all_features(text):
    """Merge lexical, syntactic and stylistic features of ``text`` into one dict.

    Extractors run in that order; on a duplicate key a later extractor's value
    would win.
    """
    combined = {}
    for extractor in (extract_lexical_features,
                      extract_syntactic_features,
                      extract_stylistic_features):
        combined.update(extractor(text))
    return combined

def generate_feature_matrix(completions_by_model):
    """
    Generate feature matrix for all completions.

    Labels follow the iteration order of ``completions_by_model``: every
    completion under the i-th key gets label i.

    Feature names are the sorted union of feature keys over ALL completions.
    The extractors return {} for an empty completion, so deriving names from the
    first completion alone could drop every feature. Features missing for a
    completion are filled with 0.

    Returns:
        X: Float feature matrix of shape (n_samples, n_features); always 2-D.
        y: Integer labels (0 for first model, 1 for second model, ...)
        feature_names: Sorted list of feature names (column order of X)
    """
    per_sample_features = []
    labels = []

    for label, completions in enumerate(completions_by_model.values()):
        for completion in completions:
            per_sample_features.append(extract_all_features(completion))
            labels.append(label)

    feature_names = sorted(set().union(*per_sample_features))

    rows = [[feats.get(name, 0) for name in feature_names] for feats in per_sample_features]
    # reshape keeps X 2-D even when there are no samples or no features.
    X = np.array(rows, dtype=float).reshape(len(rows), len(feature_names))
    y = np.array(labels, dtype=int)

    return X, y, feature_names

print("\n" + "="*80)
print("EXTRACTING STYLISTIC FEATURES")
print("="*80)

# generate_feature_matrix labels samples by dict order. Rebuild the dict in
# config.group_names order so label 0 is always group_names[0] and label 1 is
# group_names[1], even when completions were resumed from a JSON file with a
# different key order or extra keys.
ordered_completions = {
    name: completions_by_model.get(name, []) for name in config.group_names
}
X, y, feature_names = generate_feature_matrix(ordered_completions)
feature_names = list(feature_names or [])
print(f"Feature matrix shape: {X.shape}")
print(f"Number of features: {len(feature_names)}")
print(f"Features: {', '.join(feature_names)}")

# Split features by model (label 0 -> group_names[0], label 1 -> group_names[1])
us_mask = (y == 0)
uk_mask = (y == 1)
features_us = X[us_mask]
features_uk = X[uk_mask]
print(f"Samples: {features_us.shape[0]} {config.group_names[0]}, "
      f"{features_uk.shape[0]} {config.group_names[1]}")

# Save raw features for later plot generation
features_file = output_dir / "raw_features.npz"
np.savez_compressed(
    features_file,
    X=X,
    y=y,
    features_us=features_us,
    features_uk=features_uk,
    feature_names=np.array(feature_names)
)
print(f"\n Raw features saved to {features_file}")


In [None]:
# ============================================================================
# CELL 8: Hypothesis 1 Analysis (JS Divergence)
# ============================================================================
from scipy.spatial.distance import jensenshannon

def compute_jensen_shannon_divergence(features_us, features_uk, feature_names, n_bins=50):
    """
    Compute the Jensen-Shannon divergence between US and UK distributions of each feature.

    For each column i, the finite values of both groups are histogrammed on a
    shared range [min, max] with ``n_bins`` equal-width bins. The JS divergence
    (base 2, so it lies in [0, 1]) between the two normalized histograms is
    then computed.

    ``scipy.spatial.distance.jensenshannon`` returns the JS *distance*, i.e. the
    square root of the divergence. It is squared here so the value matches the
    "divergence" name used in the results and plots.

    Args:
        features_us: 2-D array indexed as [sample, feature] for the first group.
        features_uk: 2-D array indexed as [sample, feature] for the second group.
        feature_names: Column names, in column order.
        n_bins: Number of histogram bins per feature.

    Returns:
        Dict mapping feature name -> float JS divergence. A feature whose finite
        values are all identical across both groups gets 0.0; binning a
        zero-width range would otherwise produce NaN.

    Raises:
        ValueError: If either group has no finite values for some feature.
    """
    js_divergences = {}

    for i, feature_name in enumerate(feature_names):
        us_values = np.asarray(features_us[:, i], dtype=float)
        uk_values = np.asarray(features_uk[:, i], dtype=float)
        # NaN/inf would break the histogram range; ignore them.
        us_values = us_values[np.isfinite(us_values)]
        uk_values = uk_values[np.isfinite(uk_values)]
        if us_values.size == 0 or uk_values.size == 0:
            raise ValueError(
                f"Feature {feature_name!r}: one group has no finite values"
            )

        lo = min(us_values.min(), uk_values.min())
        hi = max(us_values.max(), uk_values.max())
        if lo == hi:
            # Both groups are the same point mass: the divergence is exactly 0.
            js_divergences[feature_name] = 0.0
            continue

        # Same bins for both distributions; every value lies in [lo, hi], so the
        # counts sum to the group size (> 0).
        us_counts, _ = np.histogram(us_values, bins=n_bins, range=(lo, hi))
        uk_counts, _ = np.histogram(uk_values, bins=n_bins, range=(lo, hi))
        us_prob = us_counts / us_counts.sum()
        uk_prob = uk_counts / uk_counts.sum()

        js_distance = jensenshannon(us_prob, uk_prob, base=2)
        js_divergences[feature_name] = float(js_distance ** 2)

    return js_divergences

print("\n" + "="*80)
print("HYPOTHESIS 1: JENSEN-SHANNON DIVERGENCE")
print("="*80)

js_divergences = compute_jensen_shannon_divergence(features_us, features_uk, feature_names)

# A non-finite value (e.g. from a degenerate feature) would scramble the ranking
# and turn the average into NaN, so rank and average only finite values.
finite_js = {name: v for name, v in js_divergences.items() if np.isfinite(v)}
non_finite_features = [name for name in js_divergences if name not in finite_js]

sorted_js = sorted(finite_js.items(), key=lambda x: x[1], reverse=True)
print("\nTop features by JS divergence:")
for feature_name, js_div in sorted_js[:10]:
    print(f"  {feature_name}: {js_div:.4f}")
if non_finite_features:
    print(f"\nWARNING: non-finite JS value for {len(non_finite_features)} feature(s), "
          f"excluded from ranking and average: {', '.join(non_finite_features)}")

# Overall JS divergence (average over features with a finite value)
overall_js = float(np.mean(list(finite_js.values()))) if finite_js else float('nan')
print(f"\nOverall average JS divergence: {overall_js:.4f}")

# Save results (NaN is not valid JSON, so non-finite values are stored as null)
results = {
    'hypothesis': 'H1',
    'jensen_shannon_divergences': {
        name: (v if np.isfinite(v) else None) for name, v in js_divergences.items()
    },
    'overall_js_divergence': overall_js if np.isfinite(overall_js) else None,
    'feature_names': list(feature_names),
}

results_path = output_dir / "h1_results.json"
with open(results_path, 'w') as f:
    json.dump(results, f, indent=2)

print(f"\n Results saved to {results_path}")


In [None]:
# ============================================================================
# CELL 9: Generate H1 Visualizations
# ============================================================================
import matplotlib.pyplot as plt
import seaborn as sns

def plot_js_divergence(js_divergences, overall_js, output_path):
    """Plot Jensen-Shannon divergence per feature (H1) as a horizontal bar chart.

    Bars are sorted so the largest divergence is at the TOP. A dashed red line
    marks ``overall_js``. The figure is saved to ``output_path`` at 300 dpi and
    then closed, so nothing is displayed inline. Does nothing when
    ``js_divergences`` is empty.
    """
    if not js_divergences:
        print("No JS divergences to plot; skipping")
        return

    # Sort by JS divergence, largest first
    ranked = sorted(js_divergences.items(), key=lambda x: x[1], reverse=True)
    features, values = zip(*ranked)
    positions = range(len(features))

    fig, ax = plt.subplots(figsize=(12, 8))
    ax.barh(positions, values, color='steelblue', alpha=0.7)

    # Add overall average line
    ax.axvline(overall_js, color='red', linestyle='--', linewidth=2,
               label=f'Overall Average: {overall_js:.3f}')

    ax.set_yticks(list(positions))
    ax.set_yticklabels(features)
    ax.invert_yaxis()  # barh draws index 0 at the bottom; put the largest on top
    ax.set_xlabel('Jensen-Shannon Divergence', fontsize=12)
    ax.set_title('H1: Distributional Divergence by Feature', fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='x')
    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"JS divergence plot saved to {output_path}")

print("\n".join(["", "=" * 80, "GENERATING H1 VISUALIZATIONS", "=" * 80]))

# H1: per-feature JS divergence bar chart, stored next to the other H1 outputs
js_plot_path = output_dir.joinpath("h1_js_divergence.png")
plot_js_divergence(js_divergences, overall_js, js_plot_path)

print("\n All H1 visualizations generated!")
print(f" Plots saved to {output_dir}")


In [None]:
# ============================================================================
# CELL 10: Cleanup
# ============================================================================
print("\n" + "="*80)
print("CLEANUP")
print("="*80)

import gc

# Clear models from memory.
# `for model in models.values(): del model` would only unbind the loop variable.
# The dicts, plus the module-level `model`/`tokenizer` names left over from the
# loading and generation loops, would keep every model alive, so empty_cache()
# would free nothing.
models.clear()
tokenizers.clear()
model = tokenizer = None
gc.collect()  # release the now-unreferenced modules before returning cached CUDA memory
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("Models cleared from memory")
print("\n" + "="*80)
print("HYPOTHESIS 1 EVALUATION COMPLETE")
print("="*80)
print(f"\nResults saved to: {output_dir}")
print(f"  - Completions: {completions_file}")
print(f"  - Raw features: {features_file}")
print(f"  - Results: {results_path}")
print(f"  - Visualizations: {js_plot_path}")
print(f"\nOverall JS divergence: {overall_js:.4f}")
print("\nNext: Run H3 notebook for calibration analysis")
