In [215]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import matplotlib.pyplot as plt
from tqdm import tqdm

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
model_name = 'gpt2'
# model_name = 'microsoft/phi-2'

# prompt management

In [5]:
import json

def load_json_dataset(json_path):
    with open(json_path) as file:
        dataset = json.load(file)
    return dataset

dataset = load_json_dataset('../data/antonym.json')
dataset = list(map(lambda x: tuple(x.values()), dataset))
print(f'dataset len: {len(dataset)}')

dataset len: 2398


In [264]:
mm = AutoModelForCausalLM.from_pretrained('microsoft/phi-2', trust_remote_code = True)

A new version of the following files was downloaded from https://huggingface.co/microsoft/phi-2:
- configuration_phi.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/microsoft/phi-2:
- modeling_phi.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
Downloading shards: 100%|██████████| 2/2 [00:00<00:00,  2.17it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:13<00:00,  6.80s/it]


In [273]:
model.transformer.h[0].attn

GPT2AttentionAltered(
  (c_attn): Conv1D()
  (c_proj): Conv1D()
  (attn_dropout): Dropout(p=0.1, inplace=False)
  (resid_dropout): Dropout(p=0.1, inplace=False)
  (query): WrapperModule()
  (key): WrapperModule()
  (value): WrapperModule()
)

In [274]:
mm.transformer.h[0].mixer.out_proj

Linear(in_features=2560, out_features=2560, bias=True)

# models

In [213]:
import sys
sys.path.append('..')
torch.set_grad_enabled(False)

from src.utils.model_utils import load_gpt_model_and_tokenizer, set_seed
from src.extraction import get_mean_activations
from src.utils.prompt_helper import tokenize_ICL
from src.intervention import replace_heads_w_avg, compute_indirect_effect
set_seed(32)

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
if not tokenizer.pad_token_id:
    tokenizer.pad_token_id = tokenizer.eos_token_id

model, config = load_gpt_model_and_tokenizer(model_name)

In [260]:
# select number of ICL examples
ICL_examples = 6

###### TODO: REMEMBER TO SPLIT IN TRAIN and TEST####
tok_ret, ids_ret, correct_labels = tokenize_ICL(tokenizer, ICL_examples = ICL_examples, dataset = dataset[:200])

___


In [261]:
mean_activations = get_mean_activations(
    tokenized_prompts=tok_ret,
    important_ids=ids_ret,
    tokenizer=tokenizer,
    model=model,
    config=config,
    correct_labels=correct_labels,
)
mean_activations.shape

extracting activations: 100%|██████████| 28/28 [00:11<00:00,  2.39it/s]

Model accuracy is 0, mean_activations cannot be computed





AttributeError: 'NoneType' object has no attribute 'shape'

compute indirect effect

In [31]:
tokenized = tokenizer('The capital of Italy is', return_tensors='pt')['input_ids']
important_ids = np.arange(len(tokenized[0]))       # tutto il prompt

# simulating batch 2
tokenized = torch.vstack([tokenized, tokenized])
important_ids = np.vstack([important_ids, important_ids])

b = replace_heads_w_avg(
    tokenized_prompt=tokenized,
    important_ids=[important_ids],
    layers_heads=[(9,x) for x in range(0,12)],       # 9th head of every layer
    avg_activations=[
        torch.zeros(size=(len(tokenized[0]), config['d_model'] // config['n_heads'])) 
        for _ in range(0, 12)
    ],      # zero-out layer
    model=model,
    config=config,
)
# first batch
print(f'B: Predicted token id {b[0].argmax()}, wich corresponds to "{tokenizer.decode(b[0].argmax())}" [prob.: {b[0][b[0].argmax()]:.3f}]')

B: Predicted token id 262, wich corresponds to " the" [prob.: 0.077]


In [32]:
cie, probs_original, probs_edited  = compute_indirect_effect(
    model=model,
    tokenizer=tokenizer,
    config=config,
    dataset=dataset, 
    mean_activations=mean_activations,
    ICL_examples = ICL_examples,
    batch_size=20,
)

total prompts: 479


Processing edited model (l: 11, h: 11): 100%|██████████| 24/24 [20:58<00:00, 52.43s/it]


In [35]:
torch.save(cie, './../output/cie.pt')

In [36]:
import plotly.express as px

fig = px.imshow(cie.mean(dim=0))
fig.show()