In [1]:
from nnsight import LanguageModel
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
import bitsandbytes
import pandas as pd
from tqdm import tqdm
import numpy as np
import pickle
import sys

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = LanguageModel("mistralai/Mistral-7B-v0.1",  quantization_config=nf4_config, dispatch=True, device_map='auto', cache_dir="/mnt/align4_drive/arunas/") # Load the model

og = pd.read_csv('/mnt/align4_drive/arunas/broca/data-gen/ngs.csv')
og.columns

def get_prompt_from_df(filename):
    data = list(pd.read_csv(filename)['prompt'])
    data = [sentence.strip() for sentence in data]
    data = [sentence for sentence in data if not sentence == '']
    data = [sentence.replace('</s>', '\n') for sentence in data]
    golds = [sentence.strip().split("\n")[-1].strip().split('A:')[-1].strip() for sentence in data]
    data = [sentence[: -len(golds[idx])].strip() for idx, sentence in enumerate(data)]
    return data, golds

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:04<00:00,  2.36s/it]


In [2]:
types = [col for col in list(og.columns) if not 'ng' in col]
sType = types[0]

prompts, golds = get_prompt_from_df(f'/mnt/align4_drive/arunas/broca/mistral/experiments/mistral-classification-train-test-det-{sType}.csv')

In [3]:
# ## For means later
# mean = {}
# nz_mean = {}
# for i in range(len(types)):
#     sType=types[i]
#     with open(f'/mnt/align4_drive/arunas/broca/mistral/mistral-attr-patch-scripts/attn/new-{sType}-all-neurons.pkl', 'rb') as f:
#         x = pickle.load(f)
#         x = x.cpu()
#         df = pd.DataFrame(x, index=range(32), columns=range(4096))
#         mean = np.mean(df.to_numpy())
#         if mean != 0:
#             print(sType, np.mean([val for row in df.to_numpy() for val in row if val != 0]), mean)
#         else:
#             print(sType, mean)
        

In [96]:
attn_layer_activation_cache = {}
mlp_layer_activation_cache = {}
promptLogits = []
shapes = []
BATCH_SIZE = 32
def pad_my_tensors(tensors):
    devices = [t.device for t in tensors]
    npArray1 = tensors[0].cpu()
    npArray2 = tensors[1].cpu()
    max_dim2 = max(npArray1.shape[1], npArray2.shape[1])
    pad_width = [ (0,0),
                  (0, max_dim2 - npArray1.shape[1]),
                  (0,0)
                ]
    npArray1 = np.pad(npArray1, pad_width, mode='constant', constant_values=0)
    pad_width = [ (0,0),
                  (0, max_dim2 - npArray2.shape[1]),
                  (0,0)
                ]
    npArray2 = np.pad(npArray2, pad_width, mode='constant', constant_values=0)
    return [torch.tensor(npArray1).to(devices[0]), torch.tensor(npArray2).to(devices[1])]

for i in range(0, len(prompts) // BATCH_SIZE):
    start = (i*BATCH_SIZE)
    end = min((i+1)*BATCH_SIZE, len(prompts))
    promptSet = []
    gs = []
    nGs = []
    logitDiffs = []
    for prompt,gold in tqdm(zip(prompts[start:end], golds[start:end])):
        try:
            if (gold == 'Yes'):
                predictionExample = prompt[prompt[:-2].rfind(':')+1:-2].strip()
                patch = og[og[sType] == predictionExample][f"ng-{sType}"].iloc[0]
                patchPrompt = prompt.replace(predictionExample, patch)
            else:
                patchPrompt = prompt
                patch = prompt[prompt[:-2].rfind(':')+1:-2].strip()
                predictionExample = og[og[f"ng-{sType}"] == patch][sType].iloc[0]
                prompt = patchPrompt.replace(patch, predictionExample)
                gold = "Yes"
    
            notGold = "No"
            promptSet.append(prompt)
            gold = model.tokenizer(gold)['input_ids']
            notGold = model.tokenizer(notGold)['input_ids']
            gs.append(gold)
            nGs.append(notGold)
        except:
            print(f"Error with stype: {sType} prompt: {prompt} gold: {gold}", traceback.format_exc())
            continue
    # print(len(promptSet), promptSet[0])
    with model.trace(promptSet, scan=False, validate=False) as tracer: 
        for layer in range(len(model.model.layers)):
            self_attn = model.model.layers[layer].self_attn.o_proj.output
            mlp = model.model.layers[layer].mlp.down_proj.output

            if (not layer in attn_layer_activation_cache):
                # print(self_attn.detach().save().shape, mlp.detach().save().shape)
                attn_layer_activation_cache[layer] = [self_attn.detach().save()]
                mlp_layer_activation_cache[layer] = [mlp.detach().save()]
            else:
                attn_layer_activation_cache[layer].append(self_attn.detach().save())
                mlp_layer_activation_cache[layer].append(mlp.detach().save())
        logits = (model.lm_head.output[:, -1, notGold] - model.lm_head.output[:, -1, gold]).detach().save()
        logitDiffs.append(logits)
        # shapes.append((model.lm_head.output[:, -1, notGold].detach().save(), model.lm_head.output[:, -1, gold].detach().save()))
    promptLogits.append(torch.cat([p[:, -1] for p in logitDiffs]))
    if (i > 0):
        for layer in range(len(model.model.layers)):
            attn_layer_activation_cache[layer] = pad_my_tensors(attn_layer_activation_cache[layer])
            attn_layer_activation_cache[layer] = [torch.sum(torch.stack(attn_layer_activation_cache[layer]), dim=0)]
            mlp_layer_activation_cache[layer] = pad_my_tensors(mlp_layer_activation_cache[layer])
            mlp_layer_activation_cache[layer] = [torch.sum(torch.stack(mlp_layer_activation_cache[layer]), dim=0)]

32it [00:00, 2835.91it/s]
32it [00:00, 3263.57it/s]
32it [00:00, 2918.92it/s]
32it [00:00, 2828.32it/s]
32it [00:00, 2982.15it/s]
32it [00:00, 2749.29it/s]
32it [00:00, 3182.25it/s]
32it [00:00, 2847.15it/s]
32it [00:00, 2813.26it/s]
32it [00:00, 3214.49it/s]
32it [00:00, 2754.31it/s]
32it [00:00, 2826.95it/s]


In [97]:
overall_attn_activation = torch.sum(torch.stack([attn_layer_activation_cache[l][0].detach().cpu().sum() for l in range(len(model.model.layers))]), dtype=torch.float64) / len(prompts)

In [98]:
overall_mlp_activation = torch.sum(torch.stack([mlp_layer_activation_cache[l][0].detach().cpu().sum() for l in range(len(model.model.layers))]), dtype=torch.float64)

In [102]:
[mlp_layer_activation_cache[l][0].detach().cpu().sum() for l in range(len(model.model.layers))]

[tensor(-12936., dtype=torch.float16),
 tensor(-inf, dtype=torch.float16),
 tensor(15280., dtype=torch.float16),
 tensor(13120., dtype=torch.float16),
 tensor(22896., dtype=torch.float16),
 tensor(6748., dtype=torch.float16),
 tensor(29952., dtype=torch.float16),
 tensor(inf, dtype=torch.float16),
 tensor(-19328., dtype=torch.float16),
 tensor(62976., dtype=torch.float16),
 tensor(-10816., dtype=torch.float16),
 tensor(-30256., dtype=torch.float16),
 tensor(48384., dtype=torch.float16),
 tensor(-53312., dtype=torch.float16),
 tensor(-1921., dtype=torch.float16),
 tensor(35488., dtype=torch.float16),
 tensor(inf, dtype=torch.float16),
 tensor(inf, dtype=torch.float16),
 tensor(43776., dtype=torch.float16),
 tensor(27488., dtype=torch.float16),
 tensor(inf, dtype=torch.float16),
 tensor(-inf, dtype=torch.float16),
 tensor(-inf, dtype=torch.float16),
 tensor(inf, dtype=torch.float16),
 tensor(-20032., dtype=torch.float16),
 tensor(inf, dtype=torch.float16),
 tensor(inf, dtype=torch.float1

In [61]:
promptLogits = torch.cat(promptLogits)
for layer in range(len(model.model.layers)):
    attn_layer_activation_cache[layer] = attn_layer_activation_cache[layer][0]
    attn_layer_activation_cache[layer] = torch.sum(attn_layer_activation_cache[layer], dim=0)
    mlp_layer_activation_cache[layer] = mlp_layer_activation_cache[layer][0]
    mlp_layer_activation_cache[layer] = torch.sum(mlp_layer_activation_cache[layer], dim=0)
    

overall_attn_activation = sum([attn_layer_activation_cache[layer].detach().cpu() for layer in range(len(model.model.layers))])
overall_mlp_activation = sum([mlp_layer_activation_cache[layer].detach().cpu() for layer in range(len(model.model.layers))])
    
overall_attn_activation /= len(prompts)
overall_mlp_activation /= len(prompts)

In [63]:
overall_attn_activation.shape, overall_mlp_activation.shape

(torch.Size([289, 4096]), torch.Size([289, 4096]))

In [None]:
* You have an average activation. 
* You apply this neuron by neuron.
    You run through your dataset of prompts, gather all the difference of the logit diffs and store the average difference of logit diffs pertaining to this neuron.
    Then you get the top 40 neurons based on this average.
* You get the logit diff, and compare it with the original logit diff and store the neuron if it's in the top 1% of neurons causing this diff.