In [None]:
import einops
import torch
import plotly.express as px
from datasets import load_dataset
from data import create_prompt
from copy import deepcopy

from nnsight import LanguageModel
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

In [None]:
# Wrap the GPT-2 checkpoint with nnsight's LanguageModel and place it on the first CUDA device.
model = LanguageModel("openai-community/gpt2", device_map="cuda:0")


def _keep_example(e):
    """Select examples with query_array_length 10, a single few-shot, and query equal to 9."""
    return (e['query_array_length'] == 10) and (e['num_few_shots'] == 1) and (e['query'] == 9)


# Load the train split, keep only the selected subset, then let `create_prompt`
# (imported from the local `data` module) process every remaining example.
ds = (
    load_dataset("nnheui/understanding-index-operation-v0.1")['train']
    .filter(_keep_example)
    .map(create_prompt)
)

In [None]:
# Token positions whose input embedding is corrupted, one generation pass per position.
# NOTE(review): these index the *padded* batch sequence, so every batch must be at
# least 28 tokens long and padding side affects which token is hit — confirm.
CORRUPT_POSITIONS = range(12, 28)
# Fixed seed so the Gaussian corruption noise (and thus the results) is reproducible.
SEED = 0

# Whole-dataset encoding; not used by the loop below but kept for later inspection.
X = model.tokenizer(ds['prompt'])

# Tokenize each prompt on its own and drop the raw columns so the collator only
# sees tokenizer outputs; padding is applied per batch by the collator.
tokenized_data = ds.map(lambda example: model.tokenizer(example["prompt"], truncation=True))
tokenized_data = tokenized_data.remove_columns(['array', 'query', 'target', 'few_shots', 'query_array_length', 'num_few_shots', 'task_prompt', 'prompt',])

data_collator = DataCollatorWithPadding(tokenizer=model.tokenizer)
dl = DataLoader(tokenized_data, batch_size=32, shuffle=False, collate_fn=data_collator)

# GPT-2 keeps its token-embedding table at `transformer.wte`; this must match the
# checkpoint loaded into `model` (Llama-style `model.embed_tokens` does not exist here).
embed_tokens = model.transformer.wte

# all_preds[k][j] is the token id generated for dataset example k when the embedding
# at CORRUPT_POSITIONS[j] was perturbed with standard Gaussian noise.
all_preds = []
torch.manual_seed(SEED)
with torch.no_grad():
    for sample in tqdm(dl):
        # One entry per corrupted position, each holding predictions for the whole batch.
        per_position_preds = []
        for i in CORRUPT_POSITIONS:
            with model.generate(sample, max_new_tokens=1):
                emb = embed_tokens.output.save()
                # Independent noise vector per batch element, added at position i only.
                noise = torch.randn([emb.shape[0], emb.shape[-1]], device=emb.device, dtype=emb.dtype)
                emb[:, i, :] += noise
                embed_tokens.output = emb
                output = model.generator.output.save()
            # Last token of every sequence in the batch is the single newly generated
            # token; keep all of them, not just the first batch element.
            per_position_preds.append(output[:, -1].tolist())
        # Transpose (positions x batch) -> (batch x positions) so each dataset example
        # contributes exactly one row of per-position predictions.
        all_preds.extend(list(row) for row in zip(*per_position_preds))
# torch.save(all_preds, "corrupted_l5_s0_i3_gpt2.pkl")