In [14]:

import os
# Let HuggingFace tokenizers use multiple threads.
# NOTE(review): the DataLoaders below use num_workers=4 (forked workers). Explicitly
# setting this to "true" stops `tokenizers` from auto-disabling parallelism after a
# fork, which can deadlock if a tokenizer was used before the fork — set to "false"
# if hangs appear.
os.environ.update({"TOKENIZERS_PARALLELISM": "true"})


## Import dataset

In [None]:
from dataset import create_dataloaders

# Build a train/test pair of dataloaders for every dataset exposed by dataset.py.
# NOTE(review): max_samples presumably caps how many examples are loaded per
# dataset (for quick iteration) — confirm in dataset.create_dataloaders.
dataloaders = create_dataloaders(
    batch_size=32,
    train_split=0.9,
    shuffle=True,
    num_workers=4,
    seed=42,
    max_samples=1000,
)

# For each dataset: report split sizes, then show the first item of one training batch.
for name, split_loaders in dataloaders.items():
    train_loader = split_loaders['train']
    n_train = len(train_loader.dataset)
    n_test = len(split_loaders['test'].dataset)

    print(f"\nDataset: {name}")
    print(f"Total samples: {n_train + n_test}")
    print(f"Training samples: {n_train}")
    print(f"Test samples: {n_test}")

    first_batch = next(iter(train_loader))
    print("\nExample sample:")
    for label, field in (
        ("Question", 'question'),
        ("Sycophantic answer", 'sycophantic_answer'),
        ("Non-sycophantic answer", 'non_sycophantic_answer'),
    ):
        print(f"{label}: {first_batch[field][0]}")

## Load the hooked transformer


In [None]:
from transformer_lens import HookedTransformer
import torch

# Pick the best available accelerator: CUDA first, then Apple-silicon MPS, and
# finally CPU, so the notebook also runs on machines without any GPU backend.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
# Load "gemma-2-2b-it" as a HookedTransformer so intermediate activations are
# reachable through named hook points.
# float32 on purpose: float16 would be faster but less precise, which is not
# suitable for activation/feature analysis.
model = HookedTransformer.from_pretrained(
    "gemma-2-2b-it",
    dtype=torch.float32,
    device=device,
)

n_layers = model.cfg.n_layers
print(f"Number of layers in the model: {n_layers}")

## Collect Activations


In [None]:
from activations import collect_and_save_activations, load_activation_dataloaders

# Where activations (and later probes/results) are written.
# NOTE(review): ".data" is a hidden directory relative to the working directory —
# confirm this is intended and not a typo for "data/".
save_dir = ".data/activations"

# One hook per layer: the residual stream after the MLP. Other residual hook points
# ('blocks.{layer}.hook_resid_pre' = before attention, 'blocks.{layer}.hook_resid_mid'
# = after attention, before MLP) can be added to this list if needed.
hooks = [f'blocks.{layer}.hook_resid_post' for layer in range(model.cfg.n_layers)]
print(f"Number of hooks: {len(hooks)}")

# Collect and save activations for the train and test splits of each dataset.
for activation_dataset in ('mixed', 'politics'):
    collect_and_save_activations(
        model=model,
        train_dataloader=dataloaders[activation_dataset]['train'],
        test_dataloader=dataloaders[activation_dataset]['test'],
        hooks=hooks,
        save_dir=save_dir,
        dataset_name=activation_dataset,
        print_outputs=False,
    )


## Train & test probes


In [None]:
from probe import train_probe, evaluate_probe, save_probe_and_results

# Probe input size must equal the residual-stream width of the loaded model
# (2304 for gemma-2-2b, 2048 for the original Gemma); read it from the config
# instead of hard-coding it so a model swap cannot silently mismatch.
PROBE_INPUT_DIM = model.cfg.d_model
ACTIVATION_BATCH_SIZE = 32

for hook in hooks:
    # Train on the 'mixed' dataset's saved activations for this hook.
    # NOTE(review): earlier comments describe 'mixed' as NLP + philosophy — confirm in dataset.py.
    mixed_activation_loaders = load_activation_dataloaders(
        save_dir,
        model.cfg.model_name,
        "mixed",
        hook=hook,
        batch_size=ACTIVATION_BATCH_SIZE,
    )
    probe, losses = train_probe(
        mixed_activation_loaders['train'],
        input_dim=PROBE_INPUT_DIM,
        device=device,
    )

    # Evaluate on the held-out 'mixed' test split.
    mixed_results = evaluate_probe(probe, mixed_activation_loaders['test'], device)

    # Evaluate the same probe (trained only on 'mixed') on the 'politics' test split.
    politics_activation_loaders = load_activation_dataloaders(
        save_dir,
        model.cfg.model_name,
        "politics",
        hook=hook,
        batch_size=ACTIVATION_BATCH_SIZE,
    )
    politics_results = evaluate_probe(probe, politics_activation_loaders['test'], device)

    results = {
        'mixed': mixed_results,
        'politics': politics_results,
    }
    save_probe_and_results(
        save_dir=save_dir,
        model_name=model.cfg.model_name,
        dataset_name='mixed',  # dataset the probe was trained on
        hook=hook,
        probe=probe,
        results=results,
        losses=losses,
    )

    print("\nTest Results:")
    print(f"Hook: {hook}")
    print(f"Mixed Accuracy: {mixed_results['accuracy']:.2%}")
    print(f"Mixed Total samples: {mixed_results['total_samples']}")
    print(f"Pol Accuracy: {politics_results['accuracy']:.2%}")
    print(f"Pol Total samples: {politics_results['total_samples']}")
    print('-' * 100)


## Plot results


In [None]:
import matplotlib.pyplot as plt
import numpy as np
from probe import load_probe_results

def _layer_sort_key(layer):
    """Sort key ordering layer identifiers by their first embedded integer.

    Works for plain layer indices (``3``) and for hook names such as
    ``'blocks.3.hook_resid_post'``, where a plain lexicographic sort would put
    ``'blocks.10...'`` before ``'blocks.2...'``. Identifiers without any digits
    sort first, then by their string form.
    """
    import re

    if isinstance(layer, (int, np.integer)):
        return (int(layer), "")
    match = re.search(r"\d+", str(layer))
    return (int(match.group()) if match else -1, str(layer))


def plot_accuracies_by_layer(save_dir: str, model_name: str, dataset_name: str = 'mixed') -> None:
    """
    Plot probe test accuracies across layers for both the mixed and politics test sets,
    then print mean and best-layer accuracy for each.

    Args:
        save_dir: Directory where results are saved
        model_name: Name of the model (used to locate results and in the plot title)
        dataset_name: Name of the dataset used for training probes (default: 'mixed')

    Raises:
        ValueError: If no per-layer results exist for ``dataset_name``.
    """
    results = load_probe_results(save_dir, model_name, dataset_name)
    per_layer = results[dataset_name]

    # Numeric ordering, so e.g. layer 10 is plotted after layer 2 even when the
    # keys are hook-name strings.
    layers = sorted(per_layer.keys(), key=_layer_sort_key)
    if not layers:
        raise ValueError(f"No probe results found for dataset '{dataset_name}' in {save_dir}")

    mixed_accs = [per_layer[layer]['results']['mixed']['accuracy'] for layer in layers]
    pol_accs = [per_layer[layer]['results']['politics']['accuracy'] for layer in layers]

    plt.figure(figsize=(10, 6))
    plt.plot(layers, mixed_accs, 'b-o', label='Mixed Dataset')
    plt.plot(layers, pol_accs, 'r-o', label='Politics Dataset')

    plt.xlabel('Layer')
    plt.ylabel('Accuracy')
    plt.title(f'Probe Accuracy by Layer ({model_name})')
    plt.legend()
    plt.grid(True)

    # Dashed horizontal line at 0.5 for random chance.
    plt.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)

    # Keep the usual 0.3 floor, but extend it downwards so low accuracies are not clipped.
    lower = min(0.3, min(mixed_accs + pol_accs) - 0.05)
    plt.ylim(lower, 1.1)

    plt.show()

    print("\nSummary Statistics:")
    print(f"Mixed Dataset - Mean Accuracy: {np.mean(mixed_accs):.3f}")
    print(f"Politics Dataset - Mean Accuracy: {np.mean(pol_accs):.3f}")
    print(f"\nBest Layer for Mixed: {layers[np.argmax(mixed_accs)]} (Acc: {max(mixed_accs):.3f})")
    print(f"Best Layer for Politics: {layers[np.argmax(pol_accs)]} (Acc: {max(pol_accs):.3f})")
# Plot per-layer accuracies for probes trained on the default ('mixed') dataset.
plot_accuracies_by_layer(save_dir=save_dir, model_name=model.cfg.model_name)

## Activation measurement comparison

## SAE comparison

In [20]:
## 