In [None]:
import os

from dotenv import load_dotenv

# Populate os.environ from a local .env file; later cells read HF_TOKEN via os.getenv.
load_dotenv()

In [None]:
import sys
from itertools import islice

import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Make the repository-local packages (task dataloaders, repe) importable from here.
sys.path.extend(['../', '../../'])
from tasks import task_dataset  # local task dataloaders

# Register the repe pipelines; presumably this is what makes
# pipeline("rep-reading", ...) resolve in a later cell.
from repe import repe_pipeline_registry
repe_pipeline_registry()


In [None]:
model_name_or_path = "meta-llama/Llama-2-7b-hf"

# Load the model in bfloat16 and put it on the GPU immediately, before the
# rep-reading pipeline is built in the next cell.
# NOTE(review): a transformers Pipeline takes its device from model.device when it
# is constructed and moves every input batch there. If the model were moved to CUDA
# only after the pipeline exists, the pipeline would keep feeding it CPU tensors.
# Verify against the installed transformers version.
model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                             torch_dtype=torch.bfloat16,
                                             token=os.getenv('HF_TOKEN')).to('cuda').eval()

In [None]:
# Llama checkpoints get the slow tokenizer; any other architecture uses the fast one.
use_fast_tokenizer = "LlamaForCausalLM" not in model.config.architectures
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    use_fast=use_fast_tokenizer,
    # Left padding keeps every sequence's final real token in the last position,
    # which is the position read when rep_token = -1.
    padding_side="left",
    legacy=True,  # TODO: check this; per the original author it does not work with False
    token=os.getenv('HF_TOKEN'),
)
# Fall back to id 0 when the checkpoint defines no pad token.
tokenizer.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
# NOTE(review): hard-coded BOS id, presumably Llama's `<s>` token; confirm for other models.
tokenizer.bos_token_id = 1

rep_pipeline = pipeline("rep-reading", model=model, tokenizer=tokenizer)

In [None]:
# Ensure the model's weights are on the GPU (the cell output shows the model repr).
# NOTE(review): transformers pipelines appear to choose their device when constructed,
# so rep_pipeline only sends inputs to CUDA if the model was already on CUDA when the
# pipeline was created; this call alone does not update an existing pipeline's device.
model.to('cuda')

In [None]:
# Batching / truncation used when extracting the reading directions.
batch_size = 32
max_length = 2048

# Token index whose hidden state is read (-1 = final position of the left-padded input).
rep_token = -1
# Negative hidden-state indices -1, -2, ..., -(num_hidden_layers - 1).
# Note: index -num_hidden_layers itself is not included.
hidden_layers = [-offset for offset in range(1, model.config.num_hidden_layers)]
n_difference = 1
direction_method = 'pca'

In [None]:
# Number of few-shot training examples (ntrain) used for each task this notebook
# has been configured for; change `task` to switch.
shots_per_task = {
    'obqa': 5,
    'csqa': 7,
    'arc_challenge': 25,
    'race': 3,
}
task = 'arc_challenge'
ntrain = shots_per_task[task]

dataset = task_dataset(task)(ntrain=ntrain)

In [None]:
# Fit an unsupervised LAT reading direction per layer on the training split
# (direction_method is 'pca'; one component per layer).
direction_finder_kwargs = {"n_components": 1}
train_split = dataset['train']
rep_reader = rep_pipeline.get_directions(
    train_split['data'],
    train_labels=train_split['labels'],  # presumably used to orient direction_signs
    rep_token=rep_token,
    hidden_layers=hidden_layers,
    n_difference=n_difference,
    direction_method=direction_method,
    direction_finder_kwargs=direction_finder_kwargs,
    batch_size=batch_size,
    max_length=max_length,
    padding="longest",
)

In [None]:
# Eval validation: project every answer choice onto each layer's reading direction and
# count a question as correct when the best-scoring choice (argmax, or argmin when the
# direction's sign is -1) is the labelled answer.
labels = dataset['val']['labels']
H_tests = rep_pipeline(dataset['val']['data'],
                       rep_token=rep_token,
                       hidden_layers=hidden_layers,
                       rep_reader=rep_reader,
                       batch_size=8,
                       max_length=max_length,
                       padding="longest")

# H_tests holds one entry per answer choice, flattened across questions; labels[i] is
# question i's 0/1 vector over its choices, so its length is that question's choice count.
choice_counts = [len(question_labels) for question_labels in labels]
# islice would silently truncate on a mismatch, so fail loudly instead.
assert len(H_tests) == sum(choice_counts), (
    f"got {len(H_tests)} scores for {sum(choice_counts)} answer choices")

results_val = {}
for layer in hidden_layers:
    # Walk the flat scores once, cutting consecutive runs of choice_counts[i] items
    # (O(n) instead of recomputing prefix sums for every question).
    layer_scores = iter([H[layer] for H in H_tests])
    unflattened_H_tests = [list(islice(layer_scores, n_choices)) for n_choices in choice_counts]

    sign = rep_reader.direction_signs[layer]
    eval_func = np.argmin if sign == -1 else np.argmax
    cors = np.mean([question_labels.index(1) == eval_func(question_scores)
                    for question_labels, question_scores in zip(labels, unflattened_H_tests)])

    results_val[layer] = cors

    print(f"{layer} : {cors}")
    print("=====")

In [None]:
# Eval Test: same procedure as validation, on the held-out test split. A question is
# correct when the best-scoring choice (argmax, or argmin when the direction's sign
# is -1) is the labelled answer.
labels = dataset['test']['labels']
H_tests = rep_pipeline(dataset['test']['data'],
                       rep_token=rep_token,
                       hidden_layers=hidden_layers,
                       rep_reader=rep_reader,
                       batch_size=8,
                       max_length=max_length,
                       padding="longest")

# One score per answer choice, flattened across questions; len(labels[i]) is the
# number of choices of question i.
choice_counts = [len(question_labels) for question_labels in labels]
# islice would silently truncate on a mismatch, so fail loudly instead.
assert len(H_tests) == sum(choice_counts), (
    f"got {len(H_tests)} scores for {sum(choice_counts)} answer choices")

results_test = {}
for layer in hidden_layers:
    # Single sequential pass that regroups the flat scores per question.
    layer_scores = iter([H[layer] for H in H_tests])
    unflattened_H_tests = [list(islice(layer_scores, n_choices)) for n_choices in choice_counts]

    sign = rep_reader.direction_signs[layer]
    eval_func = np.argmin if sign == -1 else np.argmax
    cors = np.mean([question_labels.index(1) == eval_func(question_scores)
                    for question_labels, question_scores in zip(labels, unflattened_H_tests)])

    results_test[layer] = cors

    print(f"{layer} : {cors}")
    print("=====")

In [None]:
# Per-layer accuracy on the validation ("Dev") and test splits. Both curves and the
# x axis are read in the same hidden_layers order, so the points always line up.
x = list(hidden_layers)
y_val = [results_val[layer] for layer in hidden_layers]
y_test = [results_test[layer] for layer in hidden_layers]

fig, ax = plt.subplots()
ax.plot(x, y_val, label="Dev")
ax.plot(x, y_test, label="Test")
ax.set(title=f"{task} Acc by Layer", xlabel="Layer", ylabel="Acc")
ax.legend()
ax.grid(True)
plt.show()