In [5]:
# Import stuff
import torch
import torch.nn as nn
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px

from jaxtyping import Float
from functools import partial

# import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix


from typing import Dict, Union, List

In [6]:
# Pick the torch device through TransformerLens' helper; this value is reused below
# when loading the model and when moving token batches before the forward pass.
device = utils.get_device()

In [4]:
# Hugging Face identifier of the checkpoint wrapped by HookedTransformer.
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

# Load the weights in bfloat16 via HookedTransformer's own `dtype` argument instead of
# the legacy `torch_dtype` keyword (the recorded run reported `torch_dtype` as deprecated).
model = HookedTransformer.from_pretrained(
    model_name,
    device=device,
    dtype=torch.bfloat16,
)

# Switch to evaluation mode; as the last expression this also displays the module tree.
model.eval()

`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 71.94it/s]


Loaded pretrained model meta-llama/Meta-Llama-3-8B-Instruct into HookedTransformer


HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-31): 32 x TransformerBlock(
      (ln1): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): GroupedQueryAttention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): GatedMLP(
        (hook_pre): HookPoint()
        (hook_pre_linear): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_att

In [8]:
# Locations of the per-emotion prompt files (one user prompt per line).
# NOTE(review): absolute, machine-specific paths; consider a configurable data directory.
happines_path = "/workspace/MATS-research/data/emotion_user_prompts/happiness.txt"
sadness_path = "/workspace/MATS-research/data/emotion_user_prompts/sadness.txt"


def _read_prompts(path):
    """Return the non-empty, whitespace-stripped lines of the text file at `path`.

    The file is decoded as UTF-8 explicitly so the result does not depend on the
    platform's default encoding. Lines that are empty after stripping are dropped
    so that blank lines never turn into empty prompts.
    """
    with open(path, "r", encoding="utf-8") as f:
        stripped_lines = (line.strip() for line in f)
        return [prompt for prompt in stripped_lines if prompt]


happiness_prompts = _read_prompts(happines_path)
sadness_prompts = _read_prompts(sadness_path)

In [9]:
from sklearn.model_selection import train_test_split

# Keep at most the first 500 prompts of each emotion.
happiness_prompts, sadness_prompts = happiness_prompts[:500], sadness_prompts[:500]

# 80/20 train/test split; both emotions use the same seed and shuffling settings.
split_options = dict(test_size=0.2, random_state=42, shuffle=True)

happiness_prompts_train, happiness_prompts_test = train_test_split(happiness_prompts, **split_options)
sadness_prompts_train, sadness_prompts_test = train_test_split(sadness_prompts, **split_options)



In [10]:
def _as_user_conversations(prompts):
    """Wrap every prompt string as a one-message chat whose only turn has role "user"."""
    return [[{"role": "user", "content": text}] for text in prompts]


# Training split
happiness_conversations = _as_user_conversations(happiness_prompts_train)
sadness_conversations = _as_user_conversations(sadness_prompts_train)

# Held-out split
happiness_conversations_test = _as_user_conversations(happiness_prompts_test)
sadness_conversations_test = _as_user_conversations(sadness_prompts_test)

In [11]:
# # Function to process conversations in batches
# def process_conversations_in_batches(conversations, batch_size=32):
#     """Process conversations in batches to avoid GPU memory issues"""
#     model.tokenizer.padding_side = 'left'
#     all_tokens = []
    
#     # Process in batches
#     for i in tqdm.tqdm(range(0, len(conversations), batch_size), desc="Tokenizing batches"):
#         batch_conversations = conversations[i:i+batch_size]
        
#         # Tokenize batch
#         batch_tokens = model.tokenizer.apply_chat_template(
#             batch_conversations,
#             add_generation_prompt=False,
#             padding=True,
#             return_tensors="pt"
#         )
        
#         all_tokens.append(batch_tokens)
    
#     # Concatenate all batches
#     return torch.cat(all_tokens, dim=0)

# # Process happiness and sadness conversations in batches
# print("Processing happiness conversations...")
# happiness_tokens = process_conversations_in_batches(happiness_conversations, batch_size=32)

# print("Processing sadness conversations...")  
# sadness_tokens = process_conversations_in_batches(sadness_conversations, batch_size=32)

# Tokenize the chats. Padding goes on the left so that the last position of every
# row is a real token; no generation prompt is appended.
model.tokenizer.padding_side = 'left'

# Arguments shared by both tokenization calls.
chat_template_options = dict(
    add_generation_prompt=False,
    padding=True,
    return_tensors="pt",
)

# Each emotion is padded on its own, so the two results may differ in sequence length.
happiness_tokens = model.tokenizer.apply_chat_template(happiness_conversations, **chat_template_options)
sadness_tokens = model.tokenizer.apply_chat_template(sadness_conversations, **chat_template_options)

In [12]:
# Sanity check on the tokenized training prompts; presumably
# (number of conversations, padded sequence length). The recorded run shows [400, 29].
happiness_tokens.shape

torch.Size([400, 29])

In [13]:
from functools import partial

def final_token_resid_hook(activation, hook, cache_dict):
    """Append the last-position slice of `activation` to `cache_dict[hook.name]`.

    Args:
        activation: tensor indexed as [batch, seq_len, d_model]; only position -1
            along the sequence axis is kept, giving [batch, d_model].
        hook: object whose `.name` is used as the cache key.
        cache_dict: dict mutated in place. The first call for a given name stores
            the slice; later calls concatenate new rows along dim 0.

    Returns:
        None. Nothing is returned to the caller, so the activation is not replaced.

    Note: taking position -1 as "the final token" assumes left-padded batches.
    """
    key = hook.name
    # Detached copy so the cache holds no reference to the full activation tensor.
    last_position = activation[:, -1, :].clone().detach()

    if key not in cache_dict:
        cache_dict[key] = last_position
    else:
        cache_dict[key] = torch.cat([cache_dict[key], last_position], dim=0)

    return None

def process_tokens_in_batches(tokens, cache_dict, batch_size=8):
    """Run `tokens` through the model in batches, caching final-token residuals.

    For every hook point whose name ends with "hook_resid_pre", the activation at
    the last sequence position of each row is appended to `cache_dict` (keyed by
    hook name) by `final_token_resid_hook`.

    Args:
        tokens: batch of token ids, sliced along its first dimension; assumed to be
            a left-padded [n_prompts, seq_len] tensor (TODO confirm with caller).
        cache_dict: dict that receives the cached activations; mutated in place.
        batch_size: number of rows per forward pass. Must be positive.

    Returns:
        None. Results are delivered through `cache_dict`.

    Raises:
        ValueError: if `batch_size` is not positive (a negative step would
            otherwise produce an empty range and silently process nothing).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Bind the destination cache so the hook matches the (activation, hook) signature.
    batch_final_token_resid_hook = partial(final_token_resid_hook, cache_dict=cache_dict)

    def _is_resid_pre(name):
        return name.endswith("hook_resid_pre")

    fwd_hooks = [(_is_resid_pre, batch_final_token_resid_hook)]

    # Inference only: without no_grad the autograd graph keeps every intermediate
    # activation alive, which is a large, avoidable GPU memory cost.
    # no_grad (not inference_mode) keeps the cached tensors usable in later autograd code.
    with torch.no_grad():
        # tqdm already reports progress, so no per-batch print is needed.
        for start in tqdm.tqdm(range(0, tokens.shape[0], batch_size)):
            batch_tokens = tokens[start:start + batch_size].to(device)

            # Hooks are attached only for this forward pass. return_type=None skips
            # computing logits, which are not used here.
            with model.hooks(fwd_hooks=fwd_hooks):
                model(batch_tokens, return_type=None)

            # Drop the device copy of this batch before allocating the next one.
            del batch_tokens

            # Return cached allocator blocks to the driver; only meaningful with CUDA.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

In [15]:
# Process happiness tokens in batches.
# The cache starts empty and is filled in place, keyed by hook name.
# NOTE(review): the recorded run of this cell failed with CUDA OutOfMemoryError on the
# first batch at batch_size=32 (GPU with ~24 GiB); confirm the forward pass runs without
# gradient tracking and/or use a smaller batch_size before re-running.
happiness_cache_dict = {}
print("Processing happiness tokens through model...")
process_tokens_in_batches(happiness_tokens, happiness_cache_dict, batch_size=32)


Processing happiness tokens through model...


  0%|          | 0/13 [00:00<?, ?it/s]

batch 0





OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 23.69 GiB of which 16.81 MiB is free. Process 2562277 has 23.67 GiB memory in use. Of the allocated memory 22.90 GiB is allocated by PyTorch, and 473.73 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)