In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os, re, json
import torch, numpy as np

import sys
sys.path.append('..')
torch.set_grad_enabled(False)

from src.utils.extract_utils import get_mean_head_activations, compute_universal_function_vector
from src.utils.intervention_utils import fv_intervention_natural_text, function_vector_intervention
from src.utils.model_utils import load_gpt_model_and_tokenizer
from src.utils.prompt_utils import load_dataset, word_pairs_to_prompt_data, create_prompt
from src.utils.eval_utils import decode_to_vocab, sentence_eval

In [3]:
dataset = load_dataset('fv_unsafe_Q_safe_A', seed=0)
dataset

{'train': ICLDataset({
 	features: ['input', 'output'],
 	num_rows: 221
 }),
 'valid': ICLDataset({
 	features: ['input', 'output'],
 	num_rows: 29
 }),
 'test': ICLDataset({
 	features: ['input', 'output'],
 	num_rows: 66
 })}

## Load model & tokenizer

In [4]:
model_name = 'meta-llama/Llama-2-7b-chat-hf'
# model_name = "mistralai/Mistral-7B-Instruct-v0.2"
# model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name)
#EDIT_LAYER = 9

Loading:  meta-llama/Llama-2-7b-chat-hf


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

## Load dataset and Compute task-conditioned mean activations

In [5]:
#dataset = load_dataset('fv_unsafe_Q_safe_A', seed=0)
mean_activations = get_mean_head_activations(dataset, model, model_config, tokenizer)

In [6]:
model.state_dict()

OrderedDict([('model.embed_tokens.weight',
              tensor([[ 1.1921e-06, -1.7881e-06, -4.2915e-06,  ...,  8.3447e-07,
                       -6.4373e-06,  8.9407e-07],
                      [ 1.8387e-03, -3.8147e-03,  9.6130e-04,  ..., -9.0332e-03,
                        2.6550e-03, -3.7537e-03],
                      [ 1.0193e-02,  9.7656e-03, -5.2795e-03,  ...,  2.9297e-03,
                        4.0817e-04, -5.0964e-03],
                      ...,
                      [-1.3550e-02, -3.5095e-03, -1.8921e-02,  ..., -9.3384e-03,
                        8.7891e-03, -1.2741e-03],
                      [-1.0681e-02,  8.9722e-03,  1.2573e-02,  ..., -3.3691e-02,
                       -1.6235e-02,  3.0212e-03],
                      [-9.0942e-03, -1.8082e-03, -6.9809e-04,  ...,  3.8452e-03,
                       -1.2085e-02,  7.2861e-04]], device='cuda:0')),
             ('model.layers.0.self_attn.q_proj.weight',
              tensor([[-0.0060, -0.0146, -0.0021,  ...,  0.0042,  0.

## Compute function vector (FV)

In [7]:
FV, top_heads = compute_universal_function_vector(mean_activations, model, model_config, n_top_heads=10)
FV

tensor([[ 0.1596,  0.0555,  0.1561,  ..., -0.0193, -0.0630, -0.1086]],
       device='cuda:0')

In [8]:
import numpy as np
np.save('../outputs/FV_vector_top_head_10_llama_7b.npy', FV.detach().cpu().numpy()) 