In [111]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Some configs

In [113]:
MODEL_NAME = 'gpt2'
# MODEL_NAME = 'microsoft/phi-2'
# MODEL_NAME = 'EleutherAI/pythia-1B'
# MODEL_NAME = 'TinyLlama/TinyLlama-1.1B-Chat-v1.0'

LOAD_IN_8BIT = False
RELATIVE_PATH = '../'

### dataset

In [114]:
import json

def load_json_dataset(json_path):
    with open(json_path) as file:
        dataset = json.load(file)
    return dataset

dataset = load_json_dataset(f'{RELATIVE_PATH}data/following.json')
dataset = list(map(lambda x: tuple(x.values()), dataset))
print(f'dataset len: {len(dataset)}')

dataset len: 15


In [115]:
import sys
sys.path.append('..')
torch.set_grad_enabled(False)

from src.utils.model_utils import load_gpt_model_and_tokenizer, set_seed
from src.extraction import get_mean_activations
from src.utils.prompt_helper import tokenize_ICL
from src.intervention import compute_indirect_effect
set_seed(32)

In [117]:
model, tokenizer, config, device = load_gpt_model_and_tokenizer(MODEL_NAME, LOAD_IN_8BIT)

In [119]:
# select number of ICL examples (query excluded)
ICL_examples = 4
tok_ret, ids_ret, correct_labels = tokenize_ICL(tokenizer, ICL_examples = ICL_examples, dataset = dataset)

## Get activations and measure head's importance

In [121]:
mean_activations = get_mean_activations(
    tokenized_prompts=tok_ret,
    important_ids=ids_ret,
    tokenizer=tokenizer,
    model=model,
    config=config,
    correct_labels=correct_labels,
    device='mps'
)
torch.save(mean_activations, f'{RELATIVE_PATH}output/mean_activations_{MODEL_NAME.replace("/", "-")}.pt')
mean_activations.shape

[x] Extracting activations:   0%|          | 0/3 [00:00<?, ?it/s]

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
[x] Extracting activations: 100%|██████████| 3/3 [00:02<00:00,  1.37it/s]

Model accuracy: 1.00, using 3 example to compute mean activations





torch.Size([12, 12, 39, 64])

In [105]:
from src.utils.model_utils import rgetattr
prompt = tok_ret[0].to(model.device)

with model.generate(max_new_tokens=1, pad_token_id=tokenizer.pad_token_id) as generator:
    # invoke works in a generation context, where operations on inputs and outputs are tracked
    with generator.invoke(prompt) as invoker:
        pass
        # layer_attn_activations = []
        # for layer_name in config['attn_hook_names']:
            # layer_attn_activations.append(rgetattr(model, layer_name).output.save())

ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`input_ids` in this case) have excessive nesting (inputs type `list` where type `int` is expected).

In [11]:
cie, probs_original, probs_edited  = compute_indirect_effect(
    model=model,
    tokenizer=tokenizer,
    config=config,
    dataset=dataset, 
    mean_activations=mean_activations,
    ICL_examples = ICL_examples,
    batch_size=15,
)
torch.save(cie, f'{RELATIVE_PATH}output/cie_{MODEL_NAME.replace("/", "-")}.pt')

total prompts: 479


Processing edited model (l: 11, h: 1):  62%|██████▎   | 20/32 [1:11:18<42:47, 213.92s/it]


KeyboardInterrupt: 

In [None]:
import plotly.express as px

fig = px.imshow(cie.mean(dim=0))
fig.show()

In [5]:
model = AutoModelForCausalLM.from_pretrained('EleutherAI/pythia-1b', trust_remote_code = True, device_map = 'mps')
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/pythia-1b')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [92]:
inp = tok_ret[1].unsqueeze(0).to('mps')
out = model.generate(inp, max_new_tokens=1)


print(tokenizer.decode(out.squeeze()))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


<|endoftext|>Q:f
A:g

Q:g
A:h

Q:h
A:i

Q:i
A:j

Q:j
A:k
