In [3]:
%load_ext autoreload
%autoreload 2

import torch
import transformers
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import matplotlib.pyplot as plt
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Checkpoint name passed to both AutoTokenizer.from_pretrained and
# load_gpt_model_and_tokenizer further down in this notebook.
model_name = 'gpt2'

# prompt management

In [5]:
import json

def load_json_dataset(json_path):
    """Read a JSON file and return its parsed content.

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        The Python object produced by ``json.load`` for the file's content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8; be explicit so the result does not depend on the
    # platform's default encoding (e.g. cp1252 on Windows).
    with open(json_path, encoding='utf-8') as file:
        return json.load(file)

# Load the antonym records and flatten each record into a tuple of its values.
dataset = [
    tuple(record.values())
    for record in load_json_dataset('../data/antonym.json')
]
print(f'dataset len: {len(dataset)}')

dataset len: 2398


# models

In [24]:
import sys
# Add the parent directory to the module search path. This has to run before
# the `from src...` imports below, which are resolved relative to it.
sys.path.append('..')
# Disable gradient tracking globally for every subsequent torch operation in this kernel.
torch.set_grad_enabled(False)

from src.utils.model_utils import load_gpt_model_and_tokenizer, set_seed, rsetattr, rgetattr
from src.extraction import split_activation, extract_activations, get_mean_activations
from src.utils.prompt_helper import build_prompt_txt, tokenize_from_template, tokenize_ICL, randomize_dataset, pad_input
from src.intervention import replace_heads_w_avg, compute_indirect_effect
# Project helper; presumably seeds the random number generators for
# reproducibility — confirm in src.utils.model_utils.
set_seed(32)

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Fall back to the EOS token for padding only when the tokenizer defines no pad
# token at all. The check is against None on purpose: a pad id of 0 is valid but
# falsy, and a truthiness test would silently overwrite it.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Project loader; returns the model and a config that later cells index by key
# (e.g. config['d_model'], config['n_heads']).
model, config = load_gpt_model_and_tokenizer(model_name)

In [None]:
# Number of ICL examples used to build each prompt
ICL_examples = 4

# TODO: split the dataset into train and test before building the prompts

tok_ret, ids_ret, labels = tokenize_ICL(
    tokenizer,
    ICL_examples=ICL_examples,
    dataset=dataset,
)

# Sanity check on the first prompt: decoded text, its label, a blank line,
# then the first entries of tok_ret and ids_ret as returned by tokenize_ICL.
print(tokenizer.decode(tok_ret[0]), labels[0], '', tok_ret[0], ids_ret[0], sep='\n')

___


In [30]:
# Compute the mean activations over the tokenized ICL prompts built above
# (tok_ret) using the positions in ids_ret.
# NOTE(review): the recorded output shape is (12, 12, 39, 64) — presumably
# (layers, heads, token positions, head dim); confirm in src.extraction.
mean_activations, outputs = get_mean_activations(
    tokenized_prompts=tok_ret,
    important_ids=ids_ret,
    model=model,
    config=config,
)
# Last expression of the cell: display the shape of the result.
mean_activations.shape

extracting activations:   0%|          | 0/479 [00:00<?, ?it/s]

extracting activations: 100%|██████████| 479/479 [02:33<00:00,  3.12it/s]


torch.Size([12, 12, 39, 64])

compute indirect effect

In [31]:
tokenized = tokenizer('The capital of Italy is', return_tensors='pt')['input_ids']
important_ids = np.arange(len(tokenized[0]))       # the whole prompt

# simulating batch 2
tokenized = torch.vstack([tokenized, tokenized])
important_ids = np.vstack([important_ids, important_ids])

# Target every head of layer 9. The head count comes from the model config
# rather than a hard-coded 12, so the cell keeps working for other model sizes.
# NOTE(review): tuples are assumed to be (layer, head), as the parameter name
# `layers_heads` suggests — confirm in src.intervention.
layers_heads = [(9, head) for head in range(config['n_heads'])]

b = replace_heads_w_avg(
    tokenized_prompt=tokenized,
    important_ids=[important_ids],
    layers_heads=layers_heads,
    # One all-zero replacement per targeted head (zero-out the layer); sized
    # from layers_heads so the two lists cannot get out of sync.
    avg_activations=[
        torch.zeros(size=(len(tokenized[0]), config['d_model'] // config['n_heads']))
        for _ in layers_heads
    ],
    model=model,
    config=config,
)
# first batch: compute the argmax once and reuse it for the id, the decoded
# token and the probability lookup
top_token_id = b[0].argmax()
print(f'B: Predicted token id {top_token_id}, wich corresponds to "{tokenizer.decode(top_token_id)}" [prob.: {b[0][top_token_id]:.3f}]')

B: Predicted token id 262, wich corresponds to " the" [prob.: 0.077]


In [32]:
# Reuse the notebook-level ICL_examples setting (the one used to build tok_ret,
# from which mean_activations was computed) instead of repeating the literal 4,
# so the prompt layout here cannot drift from the one the means were taken over.
cie, probs_original, probs_edited = compute_indirect_effect(
    model=model,
    tokenizer=tokenizer,
    config=config,
    dataset=dataset,
    mean_activations=mean_activations,
    ICL_examples=ICL_examples,
    batch_size=20,
)

total prompts: 479


Processing edited model (l: 0, h: 1):   4%|▍         | 1/24 [00:54<20:27, 53.37s/it]  

In [None]:
import plotly.express as px

# Average cie over its first dimension and show the result as a heatmap.
# NOTE(review): presumably dim 0 indexes prompts, leaving a (layer, head)
# grid — confirm against compute_indirect_effect.
fig = px.imshow(cie.mean(dim=0))
fig.show()