In [3]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer

### Some configs

In [4]:
# Checkpoint to analyse. Only one assignment is kept so the active choice is
# unambiguous; other checkpoints previously listed here: 'gpt2', 'microsoft/phi-2'.
MODEL_NAME = 'EleutherAI/pythia-1B'

# Forwarded to load_gpt_model_and_tokenizer as its second argument.
LOAD_IN_8BIT = True
# Prefix used to build the data/ and output/ paths in this notebook.
RELATIVE_PATH = '../'

# Base name of the JSON file read from f'{RELATIVE_PATH}data/'.
dataset_name = 'country-capital'
# select number of ICL examples (query excluded)
ICL_examples = 4
# When True, the dataset cell keeps only the first 50 records.
debug = True

### Dataset

In [None]:
import json

def load_json_dataset(json_path):
    """Parse the UTF-8 encoded JSON file at ``json_path`` and return its content.

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file (the decoded JSON value).
    """
    with open(json_path, encoding='utf-8') as handle:
        return json.load(handle)

# Load the records and flatten each one into a tuple of its values,
# in the key order stored in the JSON file.
dataset = [
    tuple(record.values())
    for record in load_json_dataset(f'{RELATIVE_PATH}data/{dataset_name}.json')
]

# Debug runs only keep the first 50 records.
if debug:
    dataset = dataset[:50]

print(f'dataset len: {len(dataset)}')

In [None]:
import sys
# Make the parent directory importable so the `src.*` imports below resolve.
# This must run before those imports.
sys.path.append('..')
# Turn off autograd globally for every cell that follows.
torch.set_grad_enabled(False)

from src.utils.model_utils import load_gpt_model_and_tokenizer, set_seed
from src.extraction import get_mean_activations
from src.utils.prompt_helper import tokenize_ICL
from src.intervention import compute_indirect_effect
# Fixed seed passed to the project's helper — presumably seeds the RNGs used
# downstream; see src.utils.model_utils.set_seed to confirm which ones.
set_seed(32)

In [None]:
# Load the model selected in the config cell; the loader also returns the
# tokenizer, a config object and the device.
model, tokenizer, config, device = load_gpt_model_and_tokenizer(
    MODEL_NAME,
    LOAD_IN_8BIT,
)

# Build the ICL prompts from the dataset with `ICL_examples` demonstrations each.
tok_ret, ids_ret, correct_labels = tokenize_ICL(
    tokenizer,
    ICL_examples=ICL_examples,
    dataset=dataset,
)

## Get activations and measure each head's importance

In [None]:
# Compute the mean activations over the tokenized ICL prompts with the project helper.
mean_activations = get_mean_activations(
    tokenized_prompts=tok_ret,
    important_ids=ids_ret,
    tokenizer=tokenizer,
    model=model,
    config=config,
    correct_labels=correct_labels,
    # Use the device returned by load_gpt_model_and_tokenizer rather than a
    # hard-coded 'cuda', so this cell follows wherever the model was loaded.
    # NOTE(review): assumes the helper accepts that value as-is — confirm.
    device=device,
)
# torch.save(mean_activations, f'{RELATIVE_PATH}output/{dataset_name}_mean_activations_{MODEL_NAME.replace("/", "-")}.pt')
# Shown as the cell output to sanity-check the shape of the result.
mean_activations.shape

In [None]:
# Run the project's indirect-effect computation on the dataset, using the
# mean activations from the previous cell. Prompts are processed 15 per batch.
cie, probs_original, probs_edited = compute_indirect_effect(
    model=model,
    tokenizer=tokenizer,
    config=config,
    dataset=dataset,
    mean_activations=mean_activations,
    ICL_examples=ICL_examples,
    batch_size=15,
)
# torch.save(cie, f'{RELATIVE_PATH}output/{dataset_name}_cie_{MODEL_NAME.replace("/", "-")}.pt')

In [6]:
import plotly.express as px

# NOTE(review): this path is hard-coded to a previously saved result for the
# 'sentiment' dataset / EleutherAI-pythia-1B. It replaces any `cie` computed
# earlier in the notebook and does not follow `dataset_name` / `MODEL_NAME`.
# map_location='cpu' makes the load work regardless of the device the tensor
# was saved from, and keeps the data on CPU for plotting.
cie = torch.load(
    f'{RELATIVE_PATH}output/sentiment_cie_EleutherAI-pythia-1B.pt',
    map_location='cpu',
)

# Heatmap of the loaded tensor averaged over its first dimension.
fig = px.imshow(
    cie.mean(dim=0),
    title='Mean CIE (sentiment, EleutherAI-pythia-1B)',
)
fig.show()