In [1]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer

### Configuration

In [3]:
# Checkpoint to analyse. Exactly one MODEL_NAME assignment should be active;
# the alternatives stay commented out so switching models is a one-line change.
# MODEL_NAME = 'gpt2'
MODEL_NAME = 'microsoft/phi-2'
# MODEL_NAME = 'EleutherAI/pythia-1B'
# MODEL_NAME = 'TinyLlama/TinyLlama-1.1B-Chat-v1.0'

# Forwarded to the model loader as its second positional argument.
LOAD_IN_8BIT = False
# Prefix prepended to the data/ and output/ paths built in later cells.
RELATIVE_PATH = '../'
# Base name of the dataset file (data/<dataset_name>.json); also embedded in output file names.
dataset_name = 'following'

### Dataset

In [4]:
import json

def load_json_dataset(json_path):
    """Read a JSON file and return its parsed content.

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        The object produced by ``json.load`` for the file.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8 by specification. Without an explicit encoding,
    # open() falls back to the locale default, which can mis-decode
    # non-ASCII data on some platforms.
    with open(json_path, encoding='utf-8') as file:
        return json.load(file)

# Load the records and flatten each one into a tuple of its values
# (in the order the keys appear in the JSON object).
dataset = [
    tuple(record.values())
    for record in load_json_dataset(f'{RELATIVE_PATH}data/{dataset_name}.json')
]
print(f'dataset len: {len(dataset)}')

dataset len: 25


In [5]:
import sys
# Put the parent directory on the import path so the `src` package imported
# below can be resolved. This must run before the `from src...` lines.
sys.path.append('..')
# Disable gradient tracking. Called as a plain function (not a context
# manager), so the setting stays in effect after this cell.
torch.set_grad_enabled(False)

from src.utils.model_utils import load_gpt_model_and_tokenizer, set_seed
from src.extraction import get_mean_activations
from src.utils.prompt_helper import tokenize_ICL
from src.intervention import compute_indirect_effect
# Fix the seed through the project helper.
# NOTE(review): presumably seeds every RNG the pipeline uses — confirm in src.utils.model_utils.
set_seed(32)

In [6]:
# Load the checkpoint selected in the config cell. The loader hands back the
# model together with its tokenizer, a config object and the device.
model, tokenizer, config, device = load_gpt_model_and_tokenizer(MODEL_NAME, LOAD_IN_8BIT)

# Number of in-context examples per prompt (the query itself is not counted).
ICL_examples = 4
tok_ret, ids_ret, correct_labels = tokenize_ICL(
    tokenizer,
    dataset=dataset,
    ICL_examples=ICL_examples,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


PhiForCausalLM(
  (model): PhiModel(
    (embed_tokens): Embedding(51200, 2560)
    (embed_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-31): 32 x PhiDecoderLayer(
        (self_attn): PhiAttention(
          (q_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (k_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (v_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (dense): Linear(in_features=2560, out_features=2560, bias=True)
          (rotary_emb): PhiRotaryEmbedding()
        )
        (mlp): PhiMLP(
          (activation_fn): NewGELUActivation()
          (fc1): Linear(in_features=2560, out_features=10240, bias=True)
          (fc2): Linear(in_features=10240, out_features=2560, bias=True)
        )
        (input_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (final_layernorm): LayerNorm((2560,),

## Get mean activations and measure each head's importance

In [19]:
# Compute the mean activations over the tokenized ICL prompts.
# The device comes from load_gpt_model_and_tokenizer instead of a hard-coded
# 'cuda', so this cell is not tied to a CUDA machine and always matches the
# device reported by the loader.
# NOTE(review): the recorded run of this cell stopped with
# "AttributeError: 'NoneType' object has no attribute 'to_legacy_cache'",
# raised inside the call; the cause is not visible from this notebook.
mean_activations = get_mean_activations(
    tokenized_prompts=tok_ret,
    important_ids=ids_ret,
    tokenizer=tokenizer,
    model=model,
    config=config,
    correct_labels=correct_labels,
    device=device,
)
# A '/' in the model name would be read as a directory separator, so it is replaced by '-'.
torch.save(mean_activations, f'{RELATIVE_PATH}output/{dataset_name}_mean_activations_{MODEL_NAME.replace("/", "-")}.pt')
mean_activations.shape

[x] Extracting activations:   0%|          | 0/5 [00:00<?, ?it/s]


AttributeError: 'NoneType' object has no attribute 'to_legacy_cache'

In [None]:
# Run the indirect-effect computation on the dataset using the mean
# activations from the previous step, then persist the `cie` result.
cie, probs_original, probs_edited = compute_indirect_effect(
    model=model,
    tokenizer=tokenizer,
    config=config,
    mean_activations=mean_activations,
    dataset=dataset,
    ICL_examples=ICL_examples,
    batch_size=15,
)
# A '/' in the model name would be read as a directory separator, so it is replaced by '-'.
torch.save(cie, f'{RELATIVE_PATH}output/{dataset_name}_cie_{MODEL_NAME.replace("/", "-")}.pt')

In [None]:
import plotly.express as px

# Average `cie` over its first dimension and render the result as a heatmap.
fig = px.imshow(img=cie.mean(dim=0))
fig.show()

PhiAttention(
  (q_proj): Linear(in_features=2560, out_features=2560, bias=True)
  (k_proj): Linear(in_features=2560, out_features=2560, bias=True)
  (v_proj): Linear(in_features=2560, out_features=2560, bias=True)
  (dense): Linear(in_features=2560, out_features=2560, bias=True)
  (rotary_emb): PhiRotaryEmbedding()
)

In [16]:
# Look up the 'd_model' entry of the config returned by the loader. The
# recorded value (2560) equals the in_features of the attention projections
# printed for the model.
config['d_model']

2560

In [14]:
# Scratch check of what the first decoder layer's self-attention module returns.
# NOTE(review): `generate(...)` / `invoke(...)` used as context managers plus
# `.output.save()` looks like an nnsight-style tracing API — confirm against
# load_gpt_model_and_tokenizer.
with model.generate(max_new_tokens = 1) as generator:
    with generator.invoke('My name is') as invoker:
        # Keep a handle on layer 0's self_attn output so it can be read after the contexts exit.
        a = model.model.layers[0].self_attn.output.save()

# In the recorded run this printed a 3-item tuple: torch.Size([1, 3, 2560]),
# None and a transformers DynamicCache object.
print(a.shape)

(torch.Size([1, 3, 2560]), None, <transformers.cache_utils.DynamicCache object at 0x7f7ef02bdd80>)
