In [1]:
import torch
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from sklearn.linear_model import LogisticRegression

# Import from the provided python file
from toxic_classification_gpt2 import GPT2Activations, ToxicCommentDataset, HarmfulDetectorMLP

# Set random seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)

  from .autonotebook import tqdm as notebook_tqdm


Now using GPT-2 (gpt2-large) model with the following configuration:
Number of layers (including embeddings): 37
Hidden size: 1280


<torch._C.Generator at 0x7fd3bb5e9810>

In [2]:
SAVE_DIR = '/home/anwesh/scratch/layers-to-latents/gpt2_activations'
RESULTS_CSV = os.path.join(SAVE_DIR, 'results.csv')
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE}")
print(f"Data directory: {SAVE_DIR}")

Using device: cuda
Data directory: /home/anwesh/scratch/layers-to-latents/gpt2_activations


In [3]:
# Load the results from the previous training step
if os.path.exists(RESULTS_CSV):
    results_df = pd.read_csv(RESULTS_CSV)
    results_df = results_df.sort_values('test_f1', ascending=False)
    print("Top 5 Layers by F1 Score:")
    print(results_df.head(5)[['layer_idx', 'test_accuracy', 'test_f1']])
    
    best_layer_idx = int(results_df.iloc[0]['layer_idx'])
    print(f"\nBest performing layer is: {best_layer_idx}")
else:
    raise FileNotFoundError(f"Results file not found at {RESULTS_CSV}. Please run the training step first.")

Top 5 Layers by F1 Score:
    layer_idx  test_accuracy   test_f1
13         13       0.882108  0.890084
14         14       0.886913  0.889011
11         11       0.884030  0.888786
21         21       0.884991  0.888284
12         12       0.887074  0.888166

Best performing layer is: 13


In [4]:
# Inspect columns
if os.path.exists(RESULTS_CSV):
    df = pd.read_csv(RESULTS_CSV)
    print(df.columns)
    print(df.head(1))

Index(['layer_idx', 'train_f1', 'train_accuracy', 'train_tpr', 'train_fpr',
       'test_f1', 'test_accuracy', 'test_tpr', 'test_fpr'],
      dtype='object')
   layer_idx  train_f1  train_accuracy  train_tpr  train_fpr   test_f1  \
0          0  0.903654        0.903482    0.90527   0.098305  0.869914   

   test_accuracy  test_tpr  test_fpr  
0       0.864809  0.904053  0.174435  


In [5]:
class HarmfulVectorConstructor:
    def __init__(self, activation_dir):
        self.activation_dir = activation_dir

    def get_mean_difference_vector(self, layer_idx):
        """Construct vector by subtracting mean of non-toxic from toxic activations."""
        data = np.load(os.path.join(self.activation_dir, f'train_layer_{layer_idx}_pooled_activations.npy'))
        labels = np.load(os.path.join(self.activation_dir, 'train_labels.npy'))
        
        toxic_mean = np.mean(data[labels==1], axis=0)
        non_toxic_mean = np.mean(data[labels==0], axis=0)
        vector = toxic_mean - non_toxic_mean
        return vector / np.linalg.norm(vector)

    def get_logistic_vector(self, layer_idx):
        """Construct vector using the weights of a Logistic Regression classifier."""
        data = np.load(os.path.join(self.activation_dir, f'train_layer_{layer_idx}_pooled_activations.npy'))
        labels = np.load(os.path.join(self.activation_dir, 'train_labels.npy'))
        
        # Train a simple linear probe
        clf = LogisticRegression(random_state=42, max_iter=1000, solver='liblinear')
        clf.fit(data, labels)
        
        # The coefficient vector points in the direction of the positive class (toxic)
        vector = clf.coef_[0]
        return vector / np.linalg.norm(vector)

In [6]:
class SafetyHookManager:
    """Manages forward hooks to remove harmful components during inference"""
    
    def __init__(self, model, device):
        self.model = model
        self.hooks = []
        self.device = device
    
    def register_hooks(self, layer_vectors, alpha):
        """
        Register hooks for specified layers.
        layer_vectors: dict {layer_idx: vector_numpy}
        alpha: float scaling factor
        """
        self.clear_hooks()
        
        for layer_idx, vector in layer_vectors.items():
            # Convert to tensor
            v = torch.tensor(vector, dtype=torch.float32, device=self.device)
            
            # Create the hook function for this specific vector and alpha
            hook_fn = self._get_hook_fn(v, alpha)
            
            # Register hook to the specific transformer block
            # GPT2 structure: model.h is the ModuleList of blocks
            if 0 <= layer_idx < len(self.model.h):
                module = self.model.h[layer_idx]
                self.hooks.append(module.register_forward_hook(hook_fn))
            else:
                print(f"Warning: Layer {layer_idx} out of bounds")
                
    def _get_hook_fn(self, harmful_vector, alpha):
        """Create the actual hook function"""
        def hook(module, input, output):
            # GPT2Block returns tuple (hidden_states, present_key_values, ...)
            if isinstance(output, tuple):
                hidden_states = output[0]
            else:
                hidden_states = output
                
            # Ensure vector is on correct device/dtype
            v = harmful_vector.to(hidden_states.device).to(hidden_states.dtype)
            
            # Project: (batch, seq, dim) . (dim) -> (batch, seq)
            # Calculate projection of hidden states onto harmful vector
            # We assume v is normalized
            projections = torch.matmul(hidden_states, v)
            
            # Subtract: h' = h - alpha * (h . v) * v
            # We unsqueeze projection to (batch, seq, 1) to broadcast
            modified_hidden = hidden_states - alpha * projections.unsqueeze(-1) * v
            
            # Return modified output in original format
            if isinstance(output, tuple):
                return (modified_hidden,) + output[1:]
            else:
                return modified_hidden
        return hook

    def clear_hooks(self):
        """Remove all registered hooks"""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

In [7]:
def evaluate_safety(gpt2_extractor, test_loader, hook_manager, layer_vectors, alpha, judge_mlp, judge_layer_idx):
    """
    Run inference with hooks and evaluate using the Judge MLP.
    """
    # Register hooks
    if alpha > 0:
        hook_manager.register_hooks(layer_vectors, alpha)
    else:
        hook_manager.clear_hooks()
    
    all_preds = []
    all_labels = []
    
    gpt2_extractor.model.eval()
    judge_mlp.eval()
    
    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f"Eval alpha={alpha}", leave=False):
            input_ids = batch['input_ids'].to(gpt2_extractor.device)
            attention_mask = batch['attention_mask'].to(gpt2_extractor.device)
            labels = batch['label'].numpy()
            
            # Run model with hooks
            # We need hidden states to feed into the Judge MLP
            outputs = gpt2_extractor.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True
            )
            
            # Extract activations from the Judge Layer
            # hidden_states[0] = embeddings
            # hidden_states[i] = output of layer i-1
            # So output of layer K is at index K+1
            judge_activations = outputs.hidden_states[judge_layer_idx + 1]
            
            # Mean Pooling (consistent with training)
            # Mask: (batch, seq, 1)
            mask_expanded = attention_mask.unsqueeze(-1).expand(judge_activations.size()).float()
            sum_embeddings = torch.sum(judge_activations * mask_expanded, 1)
            sum_mask = mask_expanded.sum(1)
            sum_mask = torch.clamp(sum_mask, min=1e-9)
            pooled_activations = sum_embeddings / sum_mask
            
            # Judge Classification
            mlp_out = judge_mlp(pooled_activations)
            preds = (mlp_out > 0.5).cpu().numpy().flatten()
            
            all_preds.extend(preds)
            all_labels.extend(labels)
            
    # Cleanup
    hook_manager.clear_hooks()
    
    # Metrics
    f1 = f1_score(all_labels, all_preds)
    acc = accuracy_score(all_labels, all_preds)
    tn, fp, fn, tp = confusion_matrix(all_labels, all_preds).ravel()
    fpr = fp / (fp + tn) if (fp + tn) > 0 else 0
    tpr = tp / (tp + fn) if (tp + fn) > 0 else 0
    
    return f1, acc, fpr, tpr

In [None]:
# 1. Initialize Model and Data
import gc
torch.cuda.empty_cache()
gc.collect()

print("Initializing GPT-2 and DataLoader...")
# Reduce batch size to avoid OOM
BATCH_SIZE = 32

gpt2_extractor = GPT2Activations() # Loads model
test_dataset = ToxicCommentDataset(split='test', tokenizer=gpt2_extractor.tokenizer)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# 2. Load Judge MLP (Best Layer)
print(f"Loading Judge MLP for Layer {best_layer_idx}...")
judge_mlp = HarmfulDetectorMLP(input_size=gpt2_extractor.h_size).to(DEVICE)
judge_mlp.load_state_dict(torch.load(os.path.join(SAVE_DIR, f'mlp_layer_{best_layer_idx}_model.pth')))

# 3. Experiment Setup
alphas = [0.0, 0.5, 1.0]
layer_counts = [1, 3, 5, 10]
methods = ['mean_diff', 'logistic']

results = []
constructor = HarmfulVectorConstructor(SAVE_DIR)
hook_manager = SafetyHookManager(gpt2_extractor.model, DEVICE)

# 4. Run Experiments
print("\nStarting Safety Alignment Experiments...")
print("="*80)

for method in methods:
    print(f"\nMethod: {method.upper()}")
    
    for n_layers in layer_counts:
        # Get top N layers
        top_layers = results_df.head(n_layers)['layer_idx'].values.astype(int)
        print(f"  Editing Top {n_layers} Layers: {top_layers}")
        
        # Construct vectors
        layer_vectors = {}
        for layer_idx in top_layers:
            if method == 'mean_diff':
                vec = constructor.get_mean_difference_vector(layer_idx)
            else:
                vec = constructor.get_logistic_vector(layer_idx)
            layer_vectors[layer_idx] = vec
            
        # Evaluate alphas
        for alpha in alphas:
            # Clear cache before each evaluation run
            torch.cuda.empty_cache()
            
            f1, acc, fpr, tpr = evaluate_safety(
                gpt2_extractor, test_loader, hook_manager, 
                layer_vectors, alpha, judge_mlp, best_layer_idx
            )
            
            print(f"    Alpha {alpha}: F1={f1:.4f}, Acc={acc:.4f}, FPR={fpr:.4f}, TPR={tpr:.4f}")
            
            results.append({
                'method': method,
                'n_layers': n_layers,
                'alpha': alpha,
                'f1': f1,
                'acc': acc,
                'fpr': fpr,
                'tpr': tpr
            })

# 5. Display and Save Results
final_df = pd.DataFrame(results)
print("\n" + "="*80)
print("FINAL RESULTS SUMMARY")
print("="*80)
print(final_df)

# Save
final_df.to_csv(os.path.join(SAVE_DIR, 'safety_alignment_results.csv'), index=False)
print(f"\nResults saved to {os.path.join(SAVE_DIR, 'safety_alignment_results.csv')}")

Initializing GPT-2 and DataLoader...
Now using GPT-2 (gpt2-large) model with the following configuration:
Number of layers (including embeddings): 37
Hidden size: 1280
Now using GPT-2 (gpt2-large) model with the following configuration:
Number of layers (including embeddings): 37
Hidden size: 1280
Loading Judge MLP for Layer 13...

Starting Safety Alignment Experiments...

Method: MEAN_DIFF
  Editing Top 1 Layers: [13]
Loading Judge MLP for Layer 13...

Starting Safety Alignment Experiments...

Method: MEAN_DIFF
  Editing Top 1 Layers: [13]


                                                                 

    Alpha 0.0: F1=0.8885, Acc=0.8798, FPR=0.1985, TPR=0.9580


                                                                 

    Alpha 0.5: F1=0.4339, Acc=0.5270, FPR=0.3085, TPR=0.3625


                                                                 

    Alpha 1.0: F1=0.3223, Acc=0.4477, FPR=0.3673, TPR=0.2627
  Editing Top 3 Layers: [13 14 11]


Eval alpha=0.0:  48%|████▊     | 186/391 [03:11<03:30,  1.03s/it]