In [27]:
# Import stuff
import torch
import torch.nn as nn
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px

from jaxtyping import Float
from functools import partial

# import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix


from typing import Dict, Union, List


import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import os

In [2]:
device = utils.get_device()

In [3]:
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
model = HookedTransformer.from_pretrained(
    model_name,
    device=device,
    torch_dtype=torch.bfloat16,
)

model.eval()

`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 49.86it/s]


Loaded pretrained model meta-llama/Meta-Llama-3-8B-Instruct into HookedTransformer


HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-31): 32 x TransformerBlock(
      (ln1): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): GroupedQueryAttention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): GatedMLP(
        (hook_pre): HookPoint()
        (hook_pre_linear): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_att

### Getting activations

In [4]:
happines_path = "/workspace/MATS-research/data/emotion_user_prompts/happiness.txt"
sadness_path = "/workspace/MATS-research/data/emotion_user_prompts/sadness.txt"

with open(happines_path, "r") as f:
    happiness_prompts = f.readlines()

with open(sadness_path, "r") as f:
    sadness_prompts = f.readlines()

#remove \n from each line
happiness_prompts = [prompt.strip() for prompt in happiness_prompts]
sadness_prompts = [prompt.strip() for prompt in sadness_prompts]

In [5]:
happiness_prompts = happiness_prompts[:500]
sadness_prompts = sadness_prompts[:500]

#do an 80/20 test train split using sklearn
from sklearn.model_selection import train_test_split

happiness_prompts_train, happiness_prompts_test = train_test_split(
    happiness_prompts, test_size=0.2, random_state=42, shuffle=True
)

sadness_prompts_train, sadness_prompts_test = train_test_split(
    sadness_prompts, test_size=0.2, random_state=42, shuffle=True
)


In [6]:
happiness_conversations = [[{"role": "user", "content": prompt}] for prompt in happiness_prompts_train]
sadness_conversations = [[{"role": "user", "content": prompt}] for prompt in sadness_prompts_train]

happiness_conversations_test = [[{"role": "user", "content": prompt}] for prompt in happiness_prompts_test]
sadness_conversations_test = [[{"role": "user", "content": prompt}] for prompt in sadness_prompts_test]

In [19]:
#Convert to tokens. Make sure to apply left padding. No generation tag.
model.tokenizer.padding_side = 'left'

happiness_tokens = model.tokenizer.apply_chat_template(
    happiness_conversations,
    add_generation_prompt=False,
    padding=True,
    return_tensors="pt"
)

sadness_tokens = model.tokenizer.apply_chat_template(
    sadness_conversations,
    add_generation_prompt=False,
    padding=True,
    return_tensors="pt"
)

happiness_tokens_test = model.tokenizer.apply_chat_template(
    happiness_conversations_test,
    add_generation_prompt=False,
    padding=True,
    return_tensors="pt"
)

sadness_tokens_test = model.tokenizer.apply_chat_template(
    sadness_conversations_test,
    add_generation_prompt=False,
    padding=True,
    return_tensors="pt"
)

In [8]:
happiness_tokens.shape

torch.Size([400, 29])

In [14]:
#check which device the tokens are on
print(happiness_tokens.device)


cpu


In [16]:
from functools import partial

def final_token_resid_hook(activation, hook, cache_dict):
    # activation is on the GPU here
    final_token_activation = activation[:, -1, :]

    cpu_activation = final_token_activation.clone().detach().cpu()

    if hook.name in cache_dict:
        # Now the concatenation happens with CPU tensors
        cache_dict[hook.name] = torch.cat([cache_dict[hook.name], cpu_activation], dim=0)
    else:
        cache_dict[hook.name] = cpu_activation

def process_tokens_in_batches(tokens, cache_dict, batch_size=8):
    """Process tokens through model in batches to avoid GPU memory issues"""
    
    # Create the hook function for this cache
    batch_final_token_resid_hook = partial(final_token_resid_hook, cache_dict=cache_dict)
    
    # Process in batches
    for i in tqdm.tqdm(range(0, tokens.shape[0], batch_size)):
        batch_tokens = tokens[i:i+batch_size].to(device)
        
        # Run model with hooks on this batch
        with model.hooks(fwd_hooks=[ (lambda name: name.endswith("hook_resid_pre"), batch_final_token_resid_hook) ] ):
            _ = model(batch_tokens)
        
        # Move batch_tokens back to CPU and delete to free GPU memory
        del batch_tokens
        
        # Clear GPU cache to prevent memory buildup
        torch.cuda.empty_cache()

In [None]:
# Process happiness tokens in batches
happiness_cache_dict = {}
print("Processing happiness tokens through model...")
process_tokens_in_batches(happiness_tokens, happiness_cache_dict, batch_size=4)


In [18]:
# Process happiness tokens in batches
sadness_cache_dict = {}
process_tokens_in_batches(sadness_tokens, sadness_cache_dict, batch_size=4)


100%|██████████| 100/100 [00:23<00:00,  4.26it/s]


In [23]:
sadness_test_cache_dict = {}
process_tokens_in_batches(sadness_tokens_test, sadness_test_cache_dict, batch_size=4)

happiness_test_cache_dict = {}
process_tokens_in_batches(happiness_tokens_test, happiness_test_cache_dict, batch_size=4)

100%|██████████| 25/25 [00:06<00:00,  3.93it/s]
100%|██████████| 25/25 [00:05<00:00,  4.32it/s]


In [24]:
#clear cache
torch.cuda.empty_cache()


### Training probe

In [34]:
# easily configurable layer selection
LAYERS_TO_TRAIN = [20,25,30] 

# Directory to save the trained probe weights
PROBE_SAVE_DIR = "/workspace/MATS-research/probes"
os.makedirs(PROBE_SAVE_DIR, exist_ok=True)

# Hyperparameters (as you confirmed)
PROBE_HYPERPARAMS = {
    "epochs": 100,
    "lr": 1e-3,
    "weight_decay": 0.01, # This is the L2 regularization
    "batch_size": 32,
}

# The dimensionality of the residual stream from your model's config
d_model = model.cfg.d_model

In [33]:
class LinearProbe(nn.Module):
    """A simple linear probe that maps activations to a single logit."""
    def __init__(self, d_model: int):
        super().__init__()
        # As confirmed, a simple linear layer with no bias.
        # It maps the d_model-dimensional activation to a 1D logit.
        self.probe = nn.Linear(d_model, 1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Input shape: [batch, d_model], Output shape: [batch]"""
        return self.probe(x).squeeze(-1)

In [36]:
def prepare_data(happy_activations, sad_activations, batch_size):
    """
    Combines happy and sad activations, creates labels, and returns a DataLoader.
    Label 1 for 'happiness', 0 for 'sadness'.
    """
    # Combine the activations into a single tensor
    all_activations = torch.cat([happy_activations, sad_activations], dim=0)

    # Create corresponding labels
    happy_labels = torch.ones(happy_activations.shape[0])
    sad_labels = torch.zeros(sad_activations.shape[0])
    all_labels = torch.cat([happy_labels, sad_labels], dim=0)

    # Create a TensorDataset and DataLoader
    dataset = TensorDataset(all_activations, all_labels)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader

# --- Training and Evaluation Functions ---

def train_probe(probe, dataloader, epochs, lr, weight_decay, batch_size, device):
    """Trains a single linear probe."""
    probe.to(device)
    probe.train()

    # The paper uses L2 regularization, which is equivalent to `weight_decay` in AdamW
    optimizer = torch.optim.AdamW(probe.parameters(), lr=lr, weight_decay=weight_decay)
    
    # BCEWithLogitsLoss is the standard for binary classification and is numerically stable
    loss_fn = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        for activations, labels in dataloader:
            activations, labels = activations.to(device), labels.to(device)

            logits = probe(activations)
            loss = loss_fn(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
    return probe

def evaluate_probe(probe, dataloader, device):
    """Evaluates the probe's accuracy on a given dataset."""
    probe.to(device)
    probe.eval()
    
    correct = 0
    total = 0
    with torch.no_grad():
        for activations, labels in dataloader:
            activations, labels = activations.to(device), labels.to(device)
            
            logits = probe(activations)
            # A logit > 0 corresponds to a predicted probability > 0.5
            predictions = (logits > 0).long()
            
            correct += (predictions == labels).sum().item()
            total += labels.shape[0]
            
    accuracy = correct / total
    return accuracy

In [37]:
# Dictionaries to store trained probes and test accuracies
trained_probes = {}
test_accuracies = {}

print("Starting probe training and evaluation...")
print("-" * 50)

for layer in LAYERS_TO_TRAIN:
    hook_name = f"blocks.{layer}.hook_resid_pre"
    print(f"Processing Layer {layer} (Hook: {hook_name})")

    # 1. Prepare Data for the current layer
    # Training data
    happy_train_acts = happiness_cache_dict[hook_name]
    sad_train_acts = sadness_cache_dict[hook_name]
    train_loader = prepare_data(happy_train_acts, sad_train_acts, PROBE_HYPERPARAMS['batch_size'])

    # Testing data
    happy_test_acts = happiness_test_cache_dict[hook_name]
    sad_test_acts = sadness_test_cache_dict[hook_name]
    test_loader = prepare_data(happy_test_acts, sad_test_acts, PROBE_HYPERPARAMS['batch_size'])
    
    # 2. Initialize and Train the Probe
    probe = LinearProbe(d_model)
    probe = train_probe(
        probe,
        train_loader,
        device=device,
        **PROBE_HYPERPARAMS
    )
    
    # 3. Evaluate the Probe
    accuracy = evaluate_probe(probe, test_loader, device=device)
    
    # 4. Store Results and Save Weights
    trained_probes[layer] = probe
    test_accuracies[layer] = accuracy
    
    probe_save_path = os.path.join(PROBE_SAVE_DIR, f"probe_layer_{layer}.pt")
    torch.save(probe.state_dict(), probe_save_path)
    
    print(f"Layer {layer}: Test Accuracy = {accuracy:.4f}")
    print(f"Probe weights saved to: {probe_save_path}\n")

print("-" * 50)
print("All probes trained and evaluated.")

Starting probe training and evaluation...
--------------------------------------------------
Processing Layer 20 (Hook: blocks.20.hook_resid_pre)


RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::BFloat16 != float