In [11]:
import random
from pathlib import Path

import torch
from transformer_lens import HookedTransformer


from collections import defaultdict
from tqdm.auto import tqdm


import pandas as pd
from IPython.display import display
from functools import partial
from tqdm.auto import tqdm


# Obtaining user model vector

In [2]:
# Now, the script to load and split the data
def load_and_split_data(filepath, train_split=0.8, seed=42):
    """Loads a text file, shuffles, and splits it into training and test sets."""
    with open(filepath, 'r', encoding='utf-8') as f:
        lines = [line.strip() for line in f if line.strip()]
    
    random.seed(seed)
    random.shuffle(lines)
    
    split_index = int(len(lines) * train_split)
    train_lines = lines[:split_index]
    test_lines = lines[split_index:]
    
    return train_lines, test_lines


In [3]:
# Define the file paths
happy_filepath = '/workspace/MATS-research/data/emotion_user_prompts/happiness.txt'
sad_filepath = '/workspace/MATS-research/data/emotion_user_prompts/sadness.txt' # Assuming you have a sadness.txt file

# Load and split both datasets
happy_train, happy_test = load_and_split_data(happy_filepath)
sad_train, sad_test = load_and_split_data(sad_filepath)

print(f"Happy train samples: {len(happy_train)}")
print(f"Happy test samples: {len(happy_test)}")
print(f"Sad train samples: {len(sad_train)}")
print(f"Sad test samples: {len(sad_test)}")
print("\nFirst happy training sample:", happy_train[0])
print("First sad training sample:", sad_train[0])


Happy train samples: 404
Happy test samples: 101
Sad train samples: 425
Sad test samples: 107

First happy training sample: Describe a perfect picnic in a meadow.
First sad training sample: Help me plan a day of self-compassion and rest


In [6]:

# Set the device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the model and tokenizer
model = HookedTransformer.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    device=device,
    torch_dtype=torch.bfloat16 # Use bfloat16 to save memory
)

Using device: cuda


`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 75.98it/s]


Loaded pretrained model meta-llama/Meta-Llama-3-8B-Instruct into HookedTransformer


In [8]:

def get_last_token_activations(model, tokenizer, prompts, layers, batch_size=32):
    """
    Extracts residual stream activations at the last token position for specified layers.
    """
    # Ensure the tokenizer has a padding token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        
    activations = defaultdict(list)
    
    for i in tqdm(range(0, len(prompts), batch_size), desc="Processing Batches"):
        batch_prompts = prompts[i:i+batch_size]
        
        # Tokenize the batch with left padding
        tokenizer.padding_side = "left"
        tokens = tokenizer(batch_prompts, return_tensors="pt", padding=True)
        
        # We need the sequence lengths to find the last token of each prompt
        seq_lengths = (tokens.input_ids != tokenizer.pad_token_id).sum(dim=1) - 1
        
        # Define hook names
        hook_names = [f"blocks.{layer}.hook_resid_post" for layer in layers]
        
        with torch.no_grad():
            _, cache = model.run_with_cache(
                tokens.input_ids.to(model.cfg.device), 
                names_filter=lambda name: name in hook_names
            )

        # For each layer, extract the activation at the last token position
        for layer in layers:
            hook_name = f"blocks.{layer}.hook_resid_post"
            # Get activations for the current batch and move to CPU
            layer_activations = cache[hook_name].cpu()
            
            # For each prompt in the batch, get the activation of its last token
            for j, length in enumerate(seq_lengths):
                last_token_activation = layer_activations[j, length, :]
                activations[layer].append(last_token_activation)

    # Stack the lists of tensors for each layer
    for layer in activations:
        activations[layer] = torch.stack(activations[layer])
        
    return activations


In [9]:
# Specify the target layers
target_layers = list(range(15, 26))

# Extract activations for both sad and happy training sets
print("Extracting sad activations...")
sad_activations = get_last_token_activations(model, model.tokenizer, sad_train, target_layers)

print("\nExtracting happy activations...")
happy_activations = get_last_token_activations(model, model.tokenizer, happy_train, target_layers)

print(f"\nFinished. Example shape for layer 15 (sad): {sad_activations[15].shape}")
print(f"Finished. Example shape for layer 15 (happy): {happy_activations[15].shape}")

Extracting sad activations...


Processing Batches: 100%|██████████| 14/14 [00:03<00:00,  4.03it/s]



Extracting happy activations...


Processing Batches: 100%|██████████| 13/13 [00:02<00:00,  4.79it/s]


Finished. Example shape for layer 15 (sad): torch.Size([425, 4096])
Finished. Example shape for layer 15 (happy): torch.Size([404, 4096])





In [10]:
def compute_persona_vectors(sad_activations, happy_activations, layers):
    """
    Computes the persona vector for each layer by taking the difference
    of the mean activations (sad - happy).
    """
    persona_vectors = {}
    for layer in layers:
        mean_sad_vec = sad_activations[layer].mean(dim=0)
        mean_happy_vec = happy_activations[layer].mean(dim=0)
        persona_vectors[layer] = mean_sad_vec - mean_happy_vec
    return persona_vectors

# Compute the sadness vectors
sadness_vectors = compute_persona_vectors(sad_activations, happy_activations, target_layers)

print("Persona vectors computed for all target layers.")
print(f"Example vector shape for layer 15: {sadness_vectors[15].shape}")
print(f"Norm of layer 15 vector: {torch.linalg.norm(sadness_vectors[15]).item():.4f}")
print(f"Norm of layer 25 vector: {torch.linalg.norm(sadness_vectors[25]).item():.4f}")

Persona vectors computed for all target layers.
Example vector shape for layer 15: torch.Size([4096])
Norm of layer 15 vector: 13.8750
Norm of layer 25 vector: 14.1875


# Validating user model vector