In [1]:
import sys
from itertools import islice
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Local task dataloaders
# The parent directory is appended to the import search path before importing
# `tasks`, so this must stay ahead of the import below.
sys.path.append('../')
from tasks import task_dataset

# load repe module
# Called only for its side effect; presumably this registers the "rep-reading"
# pipeline task requested later via pipeline(...) — TODO confirm against repe.
from repe import repe_pipeline_registry
repe_pipeline_registry()


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checkpoint whose hidden representations are read in the cells below.
model_name_or_path = "NousResearch/Llama-2-7b-hf"

# Load the weights in bfloat16 with automatic device placement, then put the
# module in eval mode (eval() returns the module itself).
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=True,
)
model = model.eval()

# The fast tokenizer is only used when the checkpoint does not report a
# LlamaForCausalLM architecture.
use_fast_tokenizer = "LlamaForCausalLM" not in model.config.architectures
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    use_fast=use_fast_tokenizer,
    padding_side="left",
    legacy=False,
    token=True,
)

# Keep whatever pad id the tokenizer already has; otherwise fall back to id 0.
existing_pad_id = tokenizer.pad_token_id
tokenizer.pad_token_id = existing_pad_id if existing_pad_id is not None else 0
tokenizer.bos_token_id = 1

# "rep-reading" is a custom pipeline task, not a built-in transformers one.
rep_pipeline = pipeline("rep-reading", model=model, tokenizer=tokenizer)


Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.99s/it]


In [3]:
# Inference settings shared by the direction-fitting and evaluation cells.
batch_size = 2
max_length = 2048

# Representation-reading settings.
rep_token = -1
# Negative layer indices -1, -2, ..., -(num_hidden_layers - 1).
hidden_layers = [-depth for depth in range(1, model.config.num_hidden_layers)]
n_difference = 1
direction_method = 'pca'


In [4]:
# Define tasks and #shots here: choose `task` by name, the matching number of
# few-shot examples is looked up from the table.
shots_per_task = {
    'obqa': 5,
    'csqa': 7,
    'arc_challenge': 25,
    'race': 3,
}

task = 'arc_challenge'
ntrain = shots_per_task[task]

dataset = task_dataset(task)(ntrain=ntrain)


In [5]:
# Build an unsupervised LAT PCA representation: fit one reading direction per
# requested layer from the training split, keeping a single component.
direction_finder_kwargs = dict(n_components=1)
train_split = dataset['train']

rep_reader = rep_pipeline.get_directions(
    train_split['data'],
    rep_token=rep_token,
    hidden_layers=hidden_layers,
    n_difference=n_difference,
    train_labels=train_split['labels'],
    direction_method=direction_method,
    direction_finder_kwargs=direction_finder_kwargs,
    batch_size=batch_size,
    max_length=max_length,
    padding="longest",
)


In [6]:
# Eval validation: score every validation candidate with the fitted reader and
# report, per layer, how often the selected candidate is the one labelled 1.
results_val = {layer: {} for layer in hidden_layers}  # placeholders, replaced by accuracies below
labels = dataset['val']['labels']
H_tests = rep_pipeline(dataset['val']['data'],
                    rep_token=rep_token,
                    hidden_layers=hidden_layers,
                    rep_reader=rep_reader,
                    batch_size=max(1, batch_size // 2),  # halved, but never 0 when batch_size is 1
                    max_length=max_length,  # reuse the configured limit rather than a second literal
                    padding="longest")

# labels[i] has one entry per candidate of question i. This assumes H_tests is
# flat, with one entry per candidate in dataset order, so question i owns the
# scores in [bounds[i], bounds[i + 1]). The boundaries do not depend on the
# layer, so they are computed once here in linear time.
bounds = np.cumsum([0] + [len(choices) for choices in labels])

for layer in hidden_layers:
    H_test = [H[layer] for H in H_tests]
    unflattened_H_tests = [H_test[bounds[i]:bounds[i + 1]] for i in range(len(labels))]

    # A direction sign of -1 means the lowest score is selected; otherwise the highest.
    sign = rep_reader.direction_signs[layer]
    eval_func = np.argmin if sign == -1 else np.argmax
    cors = np.mean([labels[i].index(1) == eval_func(H) for i, H in enumerate(unflattened_H_tests)])

    results_val[layer] = cors

    print(f"{layer} : {cors}")
    print("=====")


-1 : 0.49498327759197325
=====
-2 : 0.4882943143812709
=====
-3 : 0.4782608695652174
=====
-4 : 0.4916387959866221
=====
-5 : 0.4983277591973244
=====
-6 : 0.5150501672240803
=====
-7 : 0.5150501672240803
=====
-8 : 0.5083612040133779
=====
-9 : 0.5016722408026756
=====
-10 : 0.4983277591973244
=====
-11 : 0.5150501672240803
=====
-12 : 0.5117056856187291
=====
-13 : 0.5150501672240803
=====
-14 : 0.5284280936454849
=====
-15 : 0.5016722408026756
=====
-16 : 0.5250836120401338
=====
-17 : 0.47157190635451507
=====
-18 : 0.4916387959866221
=====
-19 : 0.5150501672240803
=====
-20 : 0.24749163879598662
=====
-21 : 0.3277591973244147
=====
-22 : 0.2909698996655518
=====
-23 : 0.2709030100334448
=====
-24 : 0.2842809364548495
=====
-25 : 0.25418060200668896
=====
-26 : 0.25752508361204013
=====
-27 : 0.2408026755852843
=====
-28 : 0.23411371237458195
=====
-29 : 0.26421404682274247
=====
-30 : 0.2408026755852843
=====
-31 : 0.22742474916387959
=====


In [7]:
# Eval Test: score every test candidate with the fitted reader and report,
# per layer, how often the selected candidate is the one labelled 1.
results_test = {layer: {} for layer in hidden_layers}  # placeholders, replaced by accuracies below
labels = dataset['test']['labels']
H_tests = rep_pipeline(dataset['test']['data'],
                    rep_token=rep_token,
                    hidden_layers=hidden_layers,
                    rep_reader=rep_reader,
                    batch_size=max(1, batch_size // 2),  # halved, but never 0 when batch_size is 1
                    max_length=max_length,  # reuse the configured limit rather than a second literal
                    padding="longest")

# labels[i] has one entry per candidate of question i. This assumes H_tests is
# flat, with one entry per candidate in dataset order, so question i owns the
# scores in [bounds[i], bounds[i + 1]). The boundaries do not depend on the
# layer, so they are computed once here in linear time.
bounds = np.cumsum([0] + [len(choices) for choices in labels])

for layer in hidden_layers:
    H_test = [H[layer] for H in H_tests]
    unflattened_H_tests = [H_test[bounds[i]:bounds[i + 1]] for i in range(len(labels))]

    # A direction sign of -1 means the lowest score is selected; otherwise the highest.
    sign = rep_reader.direction_signs[layer]
    eval_func = np.argmin if sign == -1 else np.argmax
    cors = np.mean([labels[i].index(1) == eval_func(H) for i, H in enumerate(unflattened_H_tests)])

    results_test[layer] = cors

    print(f"{layer} : {cors}")
    print("=====")


KeyboardInterrupt: 

In [None]:
# Accuracy per layer for the validation ("Dev") and test splits on one axes.
x = list(results_val.keys())
y_val = [results_val[layer_idx] for layer_idx in hidden_layers]
y_test = [results_test[layer_idx] for layer_idx in hidden_layers]

fig, ax = plt.subplots()
ax.plot(x, y_val, label="Dev")
ax.plot(x, y_test, label="Test")

ax.set_title(f"{task} Acc by Layer")
ax.set_xlabel("Layer")
ax.set_ylabel("Acc")
ax.legend()
ax.grid(True)
plt.show()
