In [1]:
# Import stuff
import torch
import torch.nn as nn
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px

from jaxtyping import Float
from functools import partial

# import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix


from typing import Dict, Union, List


import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import os
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pick the compute device via the transformer_lens.utils helper; it is reused below for the
# model, the token batches and the probes.
device = utils.get_device()

In [3]:
# Load the chat model into TransformerLens in bfloat16.
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
model = HookedTransformer.from_pretrained(
    model_name,
    device=device,
    # `dtype` is from_pretrained's own precision argument. The legacy `torch_dtype` kwarg is
    # what the "`torch_dtype` is deprecated! Use `dtype` instead!" warning complains about.
    dtype=torch.bfloat16,
)

# Inference only from here on: put every submodule in eval mode.
model.eval()

`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 103.22it/s]


Loaded pretrained model meta-llama/Meta-Llama-3-8B-Instruct into HookedTransformer


HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-31): 32 x TransformerBlock(
      (ln1): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): GroupedQueryAttention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): GatedMLP(
        (hook_pre): HookPoint()
        (hook_pre_linear): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_att

### Getting activations

In [4]:
# Load sadness prompts
# Directory holding one text file per emotion, one prompt per line.
EMOTION_PROMPT_DIR = "/workspace/MATS-research/data/deepseek_emotion_user_prompts"


def load_prompts(path):
    """Return the prompts stored in `path`, one per line.

    Each line is stripped of surrounding whitespace; lines that are empty after
    stripping are skipped so that a stray blank line never becomes an empty prompt.
    """
    # Explicit UTF-8: prompts may contain non-ASCII punctuation, so do not depend on the
    # platform's default locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


# Load sadness prompts (the class the probe is trained to detect)
sadness_path = os.path.join(EMOTION_PROMPT_DIR, "sadness.txt")
sadness_prompts = load_prompts(sadness_path)

# Load all other emotion prompts and combine them into a single pool
other_emotion_files = [
    os.path.join(EMOTION_PROMPT_DIR, f"{emotion}.txt")
    for emotion in ("anger", "disgust", "fear", "happiness", "surprise")
]

other_emotions_prompts = []
for emotion_file in other_emotion_files:
    other_emotions_prompts.extend(load_prompts(emotion_file))

print(f"Loaded {len(sadness_prompts)} sadness prompts")
print(f"Loaded {len(other_emotions_prompts)} other emotion prompts")


Loaded 200 sadness prompts
Loaded 1000 other emotion prompts


In [None]:
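# NOTE: superseded draft. It labels sadness=0 and yields weight=0.2; the "CORRECTION" cell
# further down replaces it (sadness=1, weight=5.0). The output shown under this cell is stale.
# # Analyze class imbalance and calculate weights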
# print(f"Class imbalance ratio: {len(other_emotions_prompts)} other emotions : {len(sadness_prompts)} sadness")
# print(f"Imbalance ratio: {len(other_emotions_prompts) / len(sadness_prompts):.2f}:1")

# # Calculate class weights for BCEWithLogitsLoss
# # pos_weight should be the ratio of negative to positive samples
# # Since sadness=0 (negative) and other_emotions=1 (positive)
# class_weight_ratio = len(sadness_prompts) / len(other_emotions_prompts)
# print(f"Class weight for other emotions (positive class): {class_weight_ratio:.3f}")

# # Store this for later use in training
# CLASS_WEIGHT = class_weight_ratio


Class imbalance ratio: 1000 other emotions : 200 sadness
Imbalance ratio: 5.00:1
Class weight for other emotions (positive class): 0.200


In [5]:


# 80/20 train/test split of each prompt pool, done separately per class with sklearn
from sklearn.model_selection import train_test_split

# Identical settings for both pools: 20% held out, fixed seed, shuffled before splitting
split_kwargs = dict(test_size=0.2, random_state=42, shuffle=True)

other_emotions_prompts_train, other_emotions_prompts_test = train_test_split(
    other_emotions_prompts, **split_kwargs
)
sadness_prompts_train, sadness_prompts_test = train_test_split(
    sadness_prompts, **split_kwargs
)



In [6]:
# Class weighting for the imbalanced probe task.
# Sadness is the positive (minority) class, so its loss term is scaled up by the
# ratio of other-emotion prompts to sadness prompts.
CORRECTED_CLASS_WEIGHT = len(other_emotions_prompts) / len(sadness_prompts)

# This is the name the training cell reads.
CLASS_WEIGHT = CORRECTED_CLASS_WEIGHT

print(
    "\n".join(
        [
            "CORRECTING CLASS WEIGHTS:",
            "Previous (incorrect): sadness=0, other_emotions=1, weight=0.2 (downweighting majority class)",
            "Corrected: sadness=1, other_emotions=0, weight=5.0 (upweighting minority class)",
        ]
    )
)
print()
print(f"Corrected class weight for sadness (positive class): {CORRECTED_CLASS_WEIGHT:.3f}")
print(f"This means sadness samples will be weighted {CORRECTED_CLASS_WEIGHT:.1f}x more than other emotions")


CORRECTING CLASS WEIGHTS:
Previous (incorrect): sadness=0, other_emotions=1, weight=0.2 (downweighting majority class)
Corrected: sadness=1, other_emotions=0, weight=5.0 (upweighting minority class)

Corrected class weight for sadness (positive class): 5.000
This means sadness samples will be weighted 5.0x more than other emotions


In [7]:
def _as_user_turns(prompts):
    """Wrap every prompt string in a one-message chat consisting of a single user turn."""
    return [[dict(role="user", content=text)] for text in prompts]


# Training conversations
other_emotions_conversations = _as_user_turns(other_emotions_prompts_train)
sadness_conversations = _as_user_turns(sadness_prompts_train)

# Held-out conversations
other_emotions_conversations_test = _as_user_turns(other_emotions_prompts_test)
sadness_conversations_test = _as_user_turns(sadness_prompts_test)

In [8]:
# Tokenize with the chat template. Padding goes on the left so that index -1 is always the
# last real token of every row; no assistant generation header is appended.
model.tokenizer.padding_side = 'left'


def _chat_tokens(conversations):
    """Apply the tokenizer's chat template to a batch of conversations and return padded ids."""
    return model.tokenizer.apply_chat_template(
        conversations,
        add_generation_prompt=False,
        padding=True,
        return_tensors="pt",
    )


other_emotions_tokens = _chat_tokens(other_emotions_conversations)
sadness_tokens = _chat_tokens(sadness_conversations)
other_emotions_tokens_test = _chat_tokens(other_emotions_conversations_test)
sadness_tokens_test = _chat_tokens(sadness_conversations_test)


In [9]:
from functools import partial

def final_token_resid_hook(activation, hook, cache_dict):
    """Append the last-position slice of `activation` to `cache_dict[hook.name]`.

    `activation` is indexed as a 3-D tensor (`[:, -1, :]`). The slice is detached,
    copied and moved to the CPU before being stored, so the cache never holds a
    reference to the model's own tensors. Rows from successive calls are stacked
    along dim 0. Returns None.
    """
    last_position = activation[:, -1, :].detach().clone().cpu()

    stored = cache_dict.get(hook.name)
    # First call for this hook name starts the entry; later calls grow it on the CPU.
    cache_dict[hook.name] = (
        last_position if stored is None else torch.cat([stored, last_position], dim=0)
    )

def process_tokens_in_batches(tokens, cache_dict, batch_size=8):
    """Run `tokens` through the model in mini-batches, caching final-position residuals.

    Args:
        tokens: tensor of token ids whose first dimension indexes prompts; rows are
            sliced off `batch_size` at a time and moved to `device`.
        cache_dict: dict filled in place by `final_token_resid_hook`, keyed by hook name.
        batch_size: number of rows sent through the model per forward pass.

    Returns:
        None. Results are accumulated in `cache_dict`.
    """
    # Bind the destination cache so the hook has the (activation, hook) call signature.
    batch_final_token_resid_hook = partial(final_token_resid_hook, cache_dict=cache_dict)

    def is_resid_post(name):
        # Select the residual-stream output hook of every block.
        return name.endswith("hook_resid_post")

    for i in tqdm.tqdm(range(0, tokens.shape[0], batch_size)):
        batch_tokens = tokens[i:i+batch_size].to(device)

        # Only activations are harvested here, so disable autograd: otherwise the forward
        # pass records graph state for a backward pass that never happens. (no_grad rather
        # than inference_mode, because the cached tensors are later used to train probes.)
        with torch.no_grad():
            with model.hooks(fwd_hooks=[(is_resid_post, batch_final_token_resid_hook)]):
                _ = model(batch_tokens)

        # Drop the device copy of this batch before the next one is allocated.
        del batch_tokens

        # Hand cached allocator blocks back to the GPU to limit memory build-up.
        torch.cuda.empty_cache()

In [10]:
# Sanity check: (number of other-emotion training prompts, padded sequence length).
other_emotions_tokens.shape

torch.Size([800, 59])

In [11]:
# Process other emotions tokens in batches
# Start from an empty dict: the hook appends to whatever is already stored under each hook
# name, so re-running with a reused dict would duplicate rows.
other_emotions_cache_dict = {}
print("Processing other emotions tokens through model...")
process_tokens_in_batches(other_emotions_tokens, other_emotions_cache_dict, batch_size=2)


Processing other emotions tokens through model...


100%|██████████| 400/400 [01:04<00:00,  6.17it/s]


In [12]:
# Process the sadness training tokens in batches
# Clear the GPU allocator cache first, then collect into a fresh dict
torch.cuda.empty_cache()
sadness_cache_dict = {}
process_tokens_in_batches(sadness_tokens, sadness_cache_dict, batch_size=2)


  1%|▏         | 1/80 [00:00<00:09,  8.13it/s]

100%|██████████| 80/80 [00:13<00:00,  6.03it/s]


In [13]:
# Cache final-token activations for the held-out (test) prompts of both classes.
# Clear the GPU allocator cache before each run.
torch.cuda.empty_cache()
sadness_test_cache_dict = {}

process_tokens_in_batches(sadness_tokens_test, sadness_test_cache_dict, batch_size=2)

# Clear the GPU allocator cache again before the larger other-emotions test set
torch.cuda.empty_cache()

other_emotions_test_cache_dict = {}
process_tokens_in_batches(other_emotions_tokens_test, other_emotions_test_cache_dict, batch_size=2)

  5%|▌         | 1/20 [00:00<00:02,  8.60it/s]

100%|██████████| 20/20 [00:03<00:00,  5.73it/s]
100%|██████████| 100/100 [00:16<00:00,  6.10it/s]


In [14]:
# Release cached GPU allocator blocks left over from activation extraction
torch.cuda.empty_cache()


### Training probe

In [15]:
# Which blocks' residual-stream outputs get a probe (edit this list to change the sweep)
LAYERS_TO_TRAIN = [20, 25, 30]

# Output directory for trained probe weights; created if it does not exist yet
PROBE_SAVE_DIR = "/workspace/MATS-research/probes"
os.makedirs(PROBE_SAVE_DIR, exist_ok=True)

# Probe training settings; the training cell splats this dict into train_probe(...)
PROBE_HYPERPARAMS = dict(
    epochs=100,
    lr=1e-3,
    weight_decay=0.01,  # handed to AdamW as its weight-decay coefficient
    batch_size=32,
)

# Width of the residual stream, read from the loaded model's config
d_model = model.cfg.d_model

In [16]:
def prepare_data(other_emotions_activations, sad_activations, batch_size):
    """
    Build a shuffled DataLoader over both activation sets.

    Rows are ordered other-emotions first, then sadness. Labels are float tensors:
    0 for every 'other emotions' row (majority class) and 1 for every 'sadness'
    row (minority class).
    """
    n_other = other_emotions_activations.shape[0]
    n_sad = sad_activations.shape[0]

    # Features and targets are stacked in the same order so rows stay aligned.
    features = torch.cat((other_emotions_activations, sad_activations), dim=0)
    targets = torch.cat((torch.zeros(n_other), torch.ones(n_sad)), dim=0)

    return DataLoader(TensorDataset(features, targets), batch_size=batch_size, shuffle=True)


In [17]:
class LinearProbe(nn.Module):
    """Bias-free linear read-out producing one logit per activation vector."""

    def __init__(self, d_model: int):
        super().__init__()
        # Single weight row of width d_model and no bias term. The attribute name `probe`
        # is part of the saved state-dict key ("probe.weight").
        self.probe = nn.Linear(in_features=d_model, out_features=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map activations of shape [batch, d_model] to logits of shape [batch]."""
        logits = self.probe(x)
        # Drop the trailing size-1 dimension produced by the linear layer.
        return logits.squeeze(dim=-1)

In [18]:


# --- Training and Evaluation Functions ---

def train_probe(probe, dataloader, epochs, lr, weight_decay, batch_size, device, class_weights=None):
    """Fit `probe` with AdamW on binary cross-entropy over logits and return it.

    Args:
        probe: module mapping a batch of activations to one logit per row.
        dataloader: yields (activations, labels) batches.
        epochs: number of full passes over `dataloader`.
        lr, weight_decay: forwarded to torch.optim.AdamW.
        batch_size: never read in the body; accepted so a hyperparameter dict that
            contains it can be splatted into this call. Batching is whatever
            `dataloader` was built with.
        device: where the probe and every batch are moved.
        class_weights: optional scalar used as `pos_weight` (the multiplier on the
            positive-class loss term); None means unweighted loss.
    """
    probe.to(device)
    probe.train()

    optimizer = torch.optim.AdamW(probe.parameters(), lr=lr, weight_decay=weight_decay)

    # Weighted or plain BCE-with-logits, depending on whether a positive-class weight was given.
    if class_weights is None:
        loss_fn = nn.BCEWithLogitsLoss()
    else:
        loss_fn = nn.BCEWithLogitsLoss(
            pos_weight=torch.tensor([class_weights], device=device, dtype=torch.float32)
        )

    for _ in range(epochs):
        for batch_acts, batch_labels in dataloader:
            # Activations are cast to float32 before they reach the probe.
            batch_acts = batch_acts.to(device).to(torch.float32)
            batch_labels = batch_labels.to(device)

            optimizer.zero_grad()
            batch_loss = loss_fn(probe(batch_acts), batch_labels)
            batch_loss.backward()
            optimizer.step()

    return probe

def evaluate_probe(probe, dataloader, device):
    """Return the fraction of rows in `dataloader` that `probe` classifies correctly.

    The probe is moved to `device` and switched to eval mode. A row is predicted
    positive when its logit is strictly greater than 0 (i.e. sigmoid > 0.5).
    """
    probe.to(device)
    probe.eval()

    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for batch_acts, batch_labels in dataloader:
            batch_acts = batch_acts.to(device).to(torch.float32)
            batch_labels = batch_labels.to(device)

            # Threshold the logits at 0 to get hard 0/1 predictions.
            hard_preds = (probe(batch_acts) > 0).long()

            n_correct += torch.eq(hard_preds, batch_labels).sum().item()
            n_seen += batch_labels.shape[0]

    return n_correct / n_seen

In [19]:
# Dictionaries to store trained probes and test accuracies, keyed by layer index
trained_probes = {}
test_accuracies = {}

# Base RNG seed. Probe initialisation and DataLoader shuffling both draw from torch's global
# generator, so without a seed every rerun yields different probes and accuracies.
PROBE_SEED = 0

print("Starting probe training and evaluation...")
print("-" * 50)

for layer in LAYERS_TO_TRAIN:
    hook_name = f"blocks.{layer}.hook_resid_post"
    print(f"Processing Layer {layer} (Hook: {hook_name})")

    # Seed per layer so each probe is reproducible regardless of which layers are in the sweep.
    torch.manual_seed(PROBE_SEED + layer)

    # 1. Prepare Data for the current layer
    # Training data
    other_emotions_train_acts = other_emotions_cache_dict[hook_name]
    sad_train_acts = sadness_cache_dict[hook_name]
    train_loader = prepare_data(other_emotions_train_acts, sad_train_acts, PROBE_HYPERPARAMS['batch_size'])

    # Testing data
    other_emotions_test_acts = other_emotions_test_cache_dict[hook_name]
    sad_test_acts = sadness_test_cache_dict[hook_name]
    test_loader = prepare_data(other_emotions_test_acts, sad_test_acts, PROBE_HYPERPARAMS['batch_size'])

    # 2. Initialize and Train the Probe
    # CLASS_WEIGHT up-weights the positive (sadness) class in the loss.
    probe = LinearProbe(d_model)
    probe = train_probe(
        probe,
        train_loader,
        device=device,
        **PROBE_HYPERPARAMS,
        class_weights=CLASS_WEIGHT
    )

    # 3. Evaluate the Probe
    # NOTE: with the 1000:200 prompt pools split 80/20 per class, the test set is 200:40, so
    # always predicting "other emotions" already scores about 0.83 accuracy.
    accuracy = evaluate_probe(probe, test_loader, device=device)

    # 4. Store Results and Save Weights
    trained_probes[layer] = probe
    test_accuracies[layer] = accuracy

    probe_save_path = os.path.join(PROBE_SAVE_DIR, f"probe_layer_{layer}.pt")
    torch.save(probe.state_dict(), probe_save_path)

    print(f"Layer {layer}: Test Accuracy = {accuracy:.4f}")
    print(f"Probe weights saved to: {probe_save_path}\n")

print("-" * 50)
print("All probes trained and evaluated.")

Starting probe training and evaluation...
--------------------------------------------------
Processing Layer 20 (Hook: blocks.20.hook_resid_post)


Layer 20: Test Accuracy = 0.9833
Probe weights saved to: /workspace/MATS-research/probes/probe_layer_20.pt

Processing Layer 25 (Hook: blocks.25.hook_resid_post)
Layer 25: Test Accuracy = 0.9875
Probe weights saved to: /workspace/MATS-research/probes/probe_layer_25.pt

Processing Layer 30 (Hook: blocks.30.hook_resid_post)
Layer 30: Test Accuracy = 0.9792
Probe weights saved to: /workspace/MATS-research/probes/probe_layer_30.pt

--------------------------------------------------
All probes trained and evaluated.


In [None]:

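# NOTE: stale draft. Layer 15 is not in LAYERS_TO_TRAIN, `happiness_test_cache_dict` is not
# defined in the cells above, and the caches are keyed by `hook_resid_post`, not `hook_resid_pre`.
# # --- Example of Loading a Probe ---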
# print("\n--- Example of loading a saved probe ---")
# layer_to_load = 15
# if layer_to_load in LAYERS_TO_TRAIN:
#     # 1. Create a new probe instance
#     loaded_probe = LinearProbe(d_model)

#     # 2. Load the saved state dictionary
#     saved_weights_path = os.path.join(PROBE_SAVE_DIR, f"probe_layer_{layer_to_load}.pt")
#     loaded_probe.load_state_dict(torch.load(saved_weights_path))

#     # 3. You can now use this probe for evaluation or inference
#     loaded_probe.to(device)
#     # Re-create the test loader to verify
#     happy_test_acts = happiness_test_cache_dict[f"blocks.{layer_to_load}.hook_resid_pre"]
#     sad_test_acts = sadness_test_cache_dict[f"blocks.{layer_to_load}.hook_resid_pre"]
#     test_loader_for_verify = prepare_data(happy_test_acts, sad_test_acts, PROBE_HYPERPARAMS['batch_size'])

#     loaded_probe_accuracy = evaluate_probe(loaded_probe, test_loader_for_verify, device=device)

#     print(f"Successfully loaded probe for layer {layer_to_load}.")
#     print(f"Original test accuracy: {test_accuracies[layer_to_load]:.4f}")
#     print(f"Loaded probe accuracy:  {loaded_probe_accuracy:.4f}")
# else:
#     print(f"Layer {layer_to_load} was not in the list of layers to train.")

### Steering with probe

In [21]:
# Release cached GPU allocator blocks before the steering / generation cells
torch.cuda.empty_cache()

In [20]:
# Display the trained probes, keyed by layer index
trained_probes

{20: LinearProbe(
   (probe): Linear(in_features=4096, out_features=1, bias=False)
 ),
 25: LinearProbe(
   (probe): Linear(in_features=4096, out_features=1, bias=False)
 ),
 30: LinearProbe(
   (probe): Linear(in_features=4096, out_features=1, bias=False)
 )}

In [22]:
# User prompts used to try out steering.
# NOTE(review): despite the variable name, these prompts read as upbeat rather than
# neutral -- confirm that is the intended baseline.
neutral_prompts = [
    "I just learned about quantum computing—can you tell me more about it in an exciting way??",
    "I’m so happy with my new laptop! Can you suggest some must-have apps for productivity?",
    "I’m fascinated by robotics! Can you share some fun facts about humanoid robots?"
]
# Convert to conversations: one single-turn chat (a lone user message) per prompt
neutral_conversations = [[{"role": "user", "content": prompt}] for prompt in neutral_prompts]

In [28]:
def user_teering_hook(resid_pre, hook, steering_vector, coeff):
    """Add `coeff * steering_vector` to the last position of a multi-token activation.

    Args:
        resid_pre: activation indexed as [batch, position, d_model]; modified in place.
        hook: the hook point that fired (unused).
        steering_vector: direction broadcast over the batch and added at position -1.
        coeff: scalar multiplier applied to `steering_vector`.

    Returns:
        `resid_pre` itself on both paths: unchanged when it holds a single position,
        otherwise after the in-place addition.
    """
    # Calls that carry only one position are passed through untouched, so steering is
    # applied only on multi-token forward passes.
    if resid_pre.shape[1] == 1:
        return resid_pre

    print('steering taking place!!!')

    # Move the scaled direction to the activation's device so a vector created elsewhere
    # cannot trigger a device-mismatch error. The dtype is left alone: the in-place add
    # already writes the result back in the activation's own dtype.
    delta = (coeff * steering_vector).to(resid_pre.device)
    resid_pre[:, -1, :] += delta
    return resid_pre

In [30]:
# Steering configuration: which layer's probe to use and how hard to push along it
layer = 30
coeff = 10
sampling_kwargs = {"temperature": 1.0, "top_p": 0.3, "freq_penalty": 1.0}

# Probe weight row for the chosen layer, rescaled to unit norm
probe_direction = trained_probes[layer].probe.weight.data.squeeze()
steering_vector = probe_direction / probe_direction.norm()

# Left padding keeps the final prompt token at index -1 for every row in the batch
model.tokenizer.padding_side = 'left'

# Tokenize the conversations, this time with the assistant generation header appended
inputs = model.tokenizer.apply_chat_template(
    neutral_conversations,
    add_generation_prompt=True,
    padding=True,
    return_tensors="pt",
).to(device)

# Bind the direction and strength, then attach the hook to the chosen block's residual output
hook_fn = partial(user_teering_hook, steering_vector=steering_vector, coeff=coeff)
steering_hooks = [(f"blocks.{layer}.hook_resid_post", hook_fn)]

# Sample completions while the steering hook is active
with model.hooks(fwd_hooks=steering_hooks):
    outputs = model.generate(inputs, max_new_tokens=150, do_sample=True, **sampling_kwargs)

def format_output(out):
    """Decode `out` with the model and print the resulting strings, divided by a dashed rule."""
    divider = "\n\n" + "-" * 80 + "\n\n"
    decoded = model.to_string(out)
    print(divider.join(decoded))

format_output(outputs)



  1%|▏         | 2/150 [00:00<00:09, 16.31it/s]

steering taking place!!!


  3%|▎         | 5/150 [00:00<00:07, 19.36it/s]

100%|██████████| 150/150 [00:07<00:00, 20.10it/s]

<|begin_of_text|><|start_header_id|>user<|end_header_id|>

I just learned about quantum computing—can you tell me more about it in an exciting way??<|eot_id|><|start_header_id|>assistant<|end_header_id|>

WELCOME TO THE FUTURE OF COMPUTING!

Quantum computing is like a superpower for computers, allowing them to solve problems that were previously unsolvable or would take an eternity to solve on classical computers. It's like having a magic calculator that can perform calculations faster and more efficiently than anything we've ever seen before!

Imagine being able to:

* Crack complex codes in seconds
* Simulate the behavior of molecules and materials with unprecedented accuracy
* Optimize complex systems, like traffic flow or supply chains, in real-time
* Break through the limits of machine learning and artificial intelligence

Quantum computers use "qubits" (quantum bits) instead of classical bits (0s and 1s). Qubits are special because they

-----------------------------------------




In [25]:


# Inspect how the chat template tokenizes a minimal one-word user message
conv = [dict(role='user', content='cheese')]

tokens = model.tokenizer.apply_chat_template(
    conv, add_generation_prompt=True, padding=True, return_tensors="pt"
).to(device)

# Print each decoded token on its own line
for token_id in tokens[0]:
    print(model.tokenizer.decode(token_id))

<|begin_of_text|>
<|start_header_id|>
user
<|end_header_id|>



che
ese
<|eot_id|>
<|start_header_id|>
assistant
<|end_header_id|>





In [39]:
# Scratch check: ten evenly spaced values from 0 to 1, both endpoints included
torch.linspace(0, 1, 10)

tensor([0.0000, 0.1111, 0.2222, 0.3333, 0.4444, 0.5556, 0.6667, 0.7778, 0.8889,
        1.0000])

In [40]:
# Scratch check of negative indexing: v[-4] is the fourth element counted from the end
v = torch.linspace(0, 1, 10)

v[-4]

tensor(0.6667)

In [None]:
# TODO: compare this notebook with the empathic machines notebook and repo