In [2]:
%cd /home/jnainani_umass_edu/codellm/MechInterpCodeLLMs

/home/jnainani_umass_edu/codellm/MechInterpCodeLLMs


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy.stats import linregress
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
import os
import time
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from torchtyping import TensorType as TT
from typing import List, Union, Optional, Callable
from typing_extensions import Literal
from functools import partial
import copy
import itertools
import json

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets

# THIS IS A LOCAL (MODIFIED) VERSION OF TRANSFORMER_LENS - UNINSTALL PIP/CONDA VERSION BEFORE USE!
import transformer_lens
import transformer_lens.utils as utils
import transformer_lens.patching as patching
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import (
    HookedTransformer,
    HookedTransformerConfig,
    FactoredMatrix,
    ActivationCache,
)

import os
from config import HF_TOKEN, HF_PATH

os.environ["HF_TOKEN"] = HF_TOKEN
os.environ["TRANSFORMERS_CACHE"] = HF_PATH
os.environ["HF_DATASETS_CACHE"] = HF_PATH
os.environ["HF_HOME"] = HF_PATH

# When using multiple GPUs we use GPU 0 as the primary and switch to the next when it is 90% full
num_gpus = torch.cuda.device_count()
device_id = 0
if num_gpus > 0:
    device = "cuda:0"
else:
    device = "cpu"
    
def check_gpu_memory(max_alloc=0.9):
    if not torch.cuda.is_available():
        return
    global device_id, device
    print("Primary device:", device)
    torch.cuda.empty_cache()
    max_alloc = 1 if max_alloc > 1 else max_alloc
    for gpu in range(num_gpus):
        memory_reserved = torch.cuda.memory_reserved(device=gpu)
        memory_allocated = torch.cuda.memory_allocated(device=gpu)
        total_memory = torch.cuda.get_device_properties(gpu).total_memory 
        print(f"GPU {gpu}: {total_memory / (1024**2):.2f} MB  Allocated: {memory_allocated / (1024**2):.2f} MB  Reserved: {memory_reserved / (1024**2):.2f} MB")
                
        # Check if the current GPU is getting too full, and if so we switch the primary device to the next GPU
        if memory_reserved > max_alloc * total_memory:
            if device_id < num_gpus - 1:
                device_id += 1
                device = f"cuda:{device_id}"
                print(f"Switching primary device to {device}")
            else:
                print("Cannot switch primary device, all GPUs are nearly full")

print("Number of GPUs:", num_gpus)
check_gpu_memory()

def timeit(func):
    """Decorator to measure the execution time of a function."""
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        print(f"Function {func.__name__!r} executed in {end_time - start_time:.4f} seconds")
        return result
    return wrapper

# Create transformer
model = HookedTransformer.from_pretrained("CodeLlama-7b-hf", n_devices=num_gpus)

# We need these so that individual attention heads and MLP inputs can be edited
model.set_use_attn_in(True)
model.set_use_hook_mlp_in(True)
model.set_use_attn_result(True) # Documentation says this easily burns through GPU memory

check_gpu_memory()

Number of GPUs: 1
Primary device: cuda:0
GPU 0: 81050.62 MB  Allocated: 0.00 MB  Reserved: 0.00 MB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model CodeLlama-7b-hf into HookedTransformer
Primary device: cuda:0
GPU 0: 81050.62 MB  Allocated: 26348.48 MB  Reserved: 26370.00 MB


In [4]:
with open('data/info_retrieval/instructed_trial3.json', 'r') as f:
    data = json.load(f)

# Take the first few dictionaries (e.g., first 3)
subset = data[:150]

# Initialize lists
prompts = []
answers = []

# Extract prompts and outputs
outputs = []
for ind, item in enumerate(subset):
    # if ind in [0, 1, 3, 4, 5, 6, 7, 9]:
    # print(()
    if model.to_tokens(item["prompt"]).shape[-1] == 40:
        prompts.append(item["prompt"])
        outputs.append(item["output"])

# Group pairs of outputs and reverse the tuples alternately
for i in range(0, len(outputs) - 1, 2):
    answers.append((outputs[i], outputs[i + 1]))
    answers.append((outputs[i + 1], outputs[i]))

# Display the results
print("Prompts:", prompts)
print("Answers:", answers)
print("Prompts:", len(prompts))
print("Answers:", len(answers))
 
no_of_prompts = len(prompts)
clean_tokens = model.to_tokens(prompts)
# Swap each adjacent pair, with a hacky list comprehension
corrupted_tokens = clean_tokens[
    [(i+1 if i%2==0 else i-1) for i in range(len(clean_tokens)) ]
    ]
print("Clean string 0", model.to_string(clean_tokens[0]))
print("Corrupted string 0", model.to_string(corrupted_tokens[0]))
check_gpu_memory()

answer_token_indices = torch.tensor([[model.to_single_token(answers[i][j]) for j in range(2)] for i in range(len(answers))], device=device)
print("Answer token indices", answer_token_indices)
check_gpu_memory()

Prompts: [" D = 'rose'  J = 'document'  A = 'ring'  E = 'cup'  The name of the key that has the value 'rose' is ", " D = 'rose'  J = 'document'  A = 'ring'  E = 'cup'  The name of the key that has the value 'document' is ", " D = 'rose'  J = 'document'  A = 'ring'  E = 'cup'  The name of the key that has the value 'ring' is ", " D = 'rose'  J = 'document'  A = 'ring'  E = 'cup'  The name of the key that has the value 'cup' is ", " H = 'clock'  G = 'disk'  J = 'hat'  E = 'key'  The name of the key that has the value 'clock' is ", " H = 'clock'  G = 'disk'  J = 'hat'  E = 'key'  The name of the key that has the value 'disk' is ", " H = 'clock'  G = 'disk'  J = 'hat'  E = 'key'  The name of the key that has the value 'hat' is ", " H = 'clock'  G = 'disk'  J = 'hat'  E = 'key'  The name of the key that has the value 'key' is ", " K = 'sheet'  B = 'bell'  F = 'game'  D = 'boot'  The name of the key that has the value 'sheet' is ", " K = 'sheet'  B = 'bell'  F = 'game'  D = 'boot'  The name 

In [5]:
_, clean_cache = model.to("cpu").run_with_cache(clean_tokens)
# clean_cache = clean_cache.to("cpu")
check_gpu_memory()

# # Process the first prompt
# answer_token_indices_first = answer_token_indices[0:1]
# clean_value, clean_cache_first, clean_grad_cache_first = get_cache_fwd_and_bwd(model, clean_tokens[0:1], ioi_metric, answer_token_indices_first)
# clean_cache_first = clean_cache_first.to('cpu')
# clean_grad_cache_first = clean_grad_cache_first.to('cpu')
# check_gpu_memory()

# delete_variable('clean_value')
# get_memory_usage()
# # delete_variable('clean_cache_first')
# # delete_variable('clean_grad_cache_first')

# corrupted_value, corrupted_cache_first, corrupted_grad_cache_first = get_cache_fwd_and_bwd(model, corrupted_tokens[0:1], ioi_metric, answer_token_indices_first)
# corrupted_cache_first = corrupted_cache_first.to('cpu')
# corrupted_grad_cache_first = corrupted_grad_cache_first.to('cpu')
# check_gpu_memory()

# delete_variable('corrupted_value')
# # delete_variable('corrupted_cache_first')a
# # delete_variable('corrupted_grad_cache_first')

# for i in range(1, len(clean_tokens)):
#     single_clean_tokens = clean_tokens[i:i+1]


#     clean_value, clean_cache, clean_grad_cache = model.run_with_cah(model, single_clean_tokens, ioi_metric, single_answer_token_indices)
#     clean_cache = clean_cache.to('cpu')
#     clean_cache_first = clean_cache_first.concatenate(clean_cache)
#     check_gpu_memory()

#     delete_variable('clean_cache')
#     get_memory_usage()

#     print("CURRENT INDEX: ", i)

# torch.cuda.empty_cache()

# # Test if shapes worked
# print("cache shape: ,", clean_cache_first["hook_embed"].shape)
# check_gpu_memory()

Moving model to device:  cpu
Primary device: cuda:0
GPU 0: 81050.62 MB  Allocated: 0.02 MB  Reserved: 2.00 MB


In [6]:
# Specify the file name
file_name = 'circuits/codellama/infoRet_30prompts_data3.json'

# Read the dictionary from the JSON file
with open(file_name, 'r') as json_file:
    loaded_circuit_dictionary = json.load(json_file)

# Print the loaded dictionary to verify its contents
print(loaded_circuit_dictionary)

keys_comps = []
for i in range(len(list(loaded_circuit_dictionary['labels']))):
    ind = list(loaded_circuit_dictionary['indices'])[i][1]
    layer_ind = ind//32
    head_ind = ind%32
    # key = f'blocks.{layer_ind}.attn.hook_result'
    keys_comps.append((layer_ind, head_ind))

{'indices': [[1200, 358], [1200, 366], [1211, 366], [1224, 358], [1225, 358], [1225, 366], [1226, 366], [1416, 358], [1416, 366], [1416, 400], [1416, 432], [1580, 485], [1632, 432], [1632, 485], [1632, 517], [1632, 532], [1671, 485], [1671, 517], [1698, 485], [1758, 432], [1758, 485], [1758, 517], [1818, 432], [1818, 448], [1818, 472], [1818, 475], [1818, 482], [1818, 485], [1818, 490], [1818, 494], [1818, 515], [1818, 517], [1818, 526], [1818, 532], [1818, 543], [1818, 544], [1818, 552], [1818, 557], [1818, 562], [1818, 572], [1869, 606], [1950, 485], [1950, 517]], 'labels': [['L12H16Q', 'L11H6'], ['L12H16Q', 'L11H14'], ['L12H19V', 'L11H14'], ['L12H24Q', 'L11H6'], ['L12H24K', 'L11H6'], ['L12H24K', 'L11H14'], ['L12H24V', 'L11H14'], ['L14H24Q', 'L11H6'], ['L14H24Q', 'L11H14'], ['L14H24Q', 'L12H16'], ['L14H24Q', 'L13H16'], ['L16H14V', 'L15H5'], ['L17H0Q', 'L13H16'], ['L17H0Q', 'L15H5'], ['L17H0Q', 'L16H5'], ['L17H0Q', 'L16H20'], ['L17H13Q', 'L15H5'], ['L17H13Q', 'L16H5'], ['L17H22Q', 'L1

In [7]:
# get mean activations 
mean_activations = {}
for l in range(model.cfg.n_layers):
    key = f'blocks.{l}.attn.hook_result'
    mean_activations[l] = torch.sum(clean_cache[key], dim=0).to("cpu")/no_of_prompts #.shape
check_gpu_memory()

Primary device: cuda:0
GPU 0: 81050.62 MB  Allocated: 0.02 MB  Reserved: 2.00 MB


In [16]:
def ablate_setter(corrupted_activation, layer_ind, clean_activation, mean_activation=mean_activations,  keys_comps=keys_comps):
    for hed in range(32):
        if (layer_ind, hed) not in keys_comps:
            mean_activation_broadcasted = mean_activation[layer_ind][:, head_ind, :].unsqueeze(0).expand(no_of_prompts, -1, -1)
            # Replace the values at the 3rd dimension of "hed" in new_cache[key]
            clean_activation[:, :, hed, :] = mean_activation_broadcasted
    return clean_activation

def ablating_hook(corrupted_activation, layer_ind, hook, clean_activation, mean_activation=mean_activations,  keys_comps=keys_comps):
    return ablate_setter(corrupted_activation, layer_ind, clean_activation, mean_activation=mean_activations,  keys_comps=keys_comps)

hooks_tuple = []

for l in range(model.cfg.n_layers):
    key = f'blocks.{l}.attn.hook_result'
    
    # Make a partial copy of the original cache for the current key
    new_cache_temp = clean_cache[key].clone()
    
    current_hook = partial(
    ablating_hook,
    layer_ind=l,
    clean_activation = new_cache_temp
    )   
    hooks_tuple.append((key, current_hook))

print(type(hooks_tuple))
check_gpu_memory()

patched_logits = model.to("cpu").run_with_hooks(
             clean_tokens, fwd_hooks=hooks_tuple, bwd_hooks=None)
print(patched_logits)
check_gpu_memory()

<class 'list'>
Primary device: cuda:0
GPU 0: 81050.62 MB  Allocated: 0.02 MB  Reserved: 2.00 MB
Moving model to device:  cpu
tensor([[[ 1.4768e+00,  5.1630e+00, -5.5182e-01,  ...,  2.6307e+00,
           2.3592e+00,  2.3185e+00],
         [ 1.3779e+00,  6.1556e-01,  2.1970e+00,  ...,  1.7838e+00,
           1.0380e+00,  1.4624e+00],
         [ 2.2206e-01,  1.1108e+00,  1.2202e+00,  ...,  1.7956e+00,
          -6.2309e-01, -3.4731e-01],
         ...,
         [-6.3867e-01, -2.5225e+00, -1.2109e+00,  ...,  1.2411e+00,
          -2.5563e-01,  9.3845e-01],
         [-1.3879e+00,  2.6864e-01, -6.1564e-01,  ...,  1.0273e-01,
          -7.0560e-01,  7.6431e-01],
         [-4.4885e-01,  9.2017e-03, -6.6585e-01,  ..., -2.4200e-02,
          -4.9813e-01, -3.7237e-01]],

        [[ 1.4768e+00,  5.1630e+00, -5.5182e-01,  ...,  2.6307e+00,
           2.3592e+00,  2.3185e+00],
         [ 1.3779e+00,  6.1556e-01,  2.1970e+00,  ...,  1.7838e+00,
           1.0380e+00,  1.4624e+00],
         [ 2.2206e-

In [17]:
patched_logits.shape

torch.Size([24, 40, 32016])

In [18]:
model_answers = torch.argmax(patched_logits[:, -1, :], dim=-1)
for some_ind in range(len(model_answers)):
    print("Predicted: ", model.to_string(model_answers[some_ind]))
    print("Act answer: ", answers[some_ind][0])

Predicted:  permit
Act answer:  D
Predicted:  permit
Act answer:  J
Predicted:  permit
Act answer:  A
Predicted:  permit
Act answer:  E
Predicted:  permit
Act answer:  H
Predicted:  permit
Act answer:  G
Predicted:  permit
Act answer:  J
Predicted:  permit
Act answer:  E
Predicted:  permit
Act answer:  K
Predicted:  permit
Act answer:  B
Predicted:  permit
Act answer:  F
Predicted:  permit
Act answer:  D
Predicted:  permit
Act answer:  B
Predicted:  permit
Act answer:  H
Predicted:  permit
Act answer:  G
Predicted:  permit
Act answer:  C
Predicted:  permit
Act answer:  C
Predicted:  permit
Act answer:  B
Predicted:  permit
Act answer:  A
Predicted:  permit
Act answer:  E
Predicted:  permit
Act answer:  I
Predicted:  permit
Act answer:  H
Predicted:  permit
Act answer:  C
Predicted:  permit
Act answer:  A


In [11]:
answers

[('D', 'J'),
 ('J', 'D'),
 ('A', 'E'),
 ('E', 'A'),
 ('H', 'G'),
 ('G', 'H'),
 ('J', 'E'),
 ('E', 'J'),
 ('K', 'B'),
 ('B', 'K'),
 ('F', 'D'),
 ('D', 'F'),
 ('B', 'H'),
 ('H', 'B'),
 ('G', 'C'),
 ('C', 'G'),
 ('C', 'B'),
 ('B', 'C'),
 ('A', 'E'),
 ('E', 'A'),
 ('I', 'H'),
 ('H', 'I'),
 ('C', 'A'),
 ('A', 'C')]

1
1
1
0
