In [1]:
import os
import torch
import numpy as np
from tqdm import tqdm
import sys
import argparse

# Make the parent directory importable (presumably for the local `interveners` module).
sys.path += ["../"]

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM

In [3]:
import pyvene as pv

In [4]:
model_name = "mistralai/Mistral-7B-Instruct-v0.1"
cache_dir = "./hf_cache"  # shared download cache for tokenizer and weights

tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)

# `dtype` replaces the deprecated `torch_dtype` keyword in the installed transformers.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype=torch.float16,
    cache_dir=cache_dir,
    low_cpu_mem_usage=True,
    device_map="auto",  # let accelerate place the weights on the available devices
)

`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
# Send inputs to the device holding the model's first parameters. With
# device_map="auto" this follows wherever the weights were placed, and it also
# works on machines without CUDA (a hard-coded "cuda" would fail there).
device = model.device

In [6]:
def format_truthfulqa(question, choice):
    """Render one (question, answer choice) pair as a single ``"Q: ... A: ..."`` prompt string."""
    template = "Q: {} A: {}"
    return template.format(question, choice)

In [7]:
import json 
def load_task_dataset(dsname, limit=20):
    """Load a TruthfulQA-style multiple-choice JSON file and flatten it into labelled prompts.

    Reads ``f"{dsname}.json"``, which must contain a list of items. For each of the first
    ``limit`` items, every key of ``item["mc2_targets"]`` is treated as an answer choice and
    its value as that choice's label, yielding one record per choice.

    Args:
        dsname: path of the dataset file without the ``.json`` suffix.
        limit: maximum number of questions to read; ``None`` reads all of them.
            Defaults to 20, the previously hard-coded slice.

    Returns:
        list of dicts ``{"text": format_truthfulqa(question, choice), "label": int(value)}``.
        Items without ``"mc2_targets"`` contribute nothing; a missing ``"question"`` is "".
    """
    filepath = f"{dsname}.json"
    # JSON text is UTF-8; don't depend on the platform's default locale encoding.
    with open(filepath, "r", encoding="utf-8") as f:
        data = json.load(f)

    processed_data = []
    for item in data[:limit]:
        question = item.get("question", "")
        # mc2_targets maps answer text -> label value
        targets = item.get("mc2_targets", {})
        for choice, label in targets.items():
            processed_data.append(
                {
                    "text": format_truthfulqa(question, choice),
                    "label": int(label),
                }
            )
    return processed_data

In [8]:
dataset = load_task_dataset("mc_task")

In [9]:
# Tokenise every flattened example individually (no padding, batch of one) and
# keep its label at the same position so prompts[k] pairs with labels[k].
prompts, labels = [], []
for example in dataset:
    prompts.append(tokenizer(example["text"], return_tensors='pt').input_ids)
    labels.append(example["label"])

In [10]:
from interveners import wrapper, Collector

# Attach one Collector to the input of every decoder layer's attention output
# projection (o_proj), then wrap the model so a forward pass fills the collectors.
collectors = []
pv_config = []
for layer_idx in range(model.config.num_hidden_layers):
    # head=-1 collects all head activations; the multiplier does not matter here
    layer_collector = Collector(multiplier=0, head=-1)
    collectors.append(layer_collector)
    pv_config.append(dict(
        component=f"model.layers[{layer_idx}].self_attn.o_proj.input",
        intervention=wrapper(layer_collector),
    ))
collected_model = pv.IntervenableModel(pv_config, model)

In [11]:
# Per-prompt activation accumulators (two distinct, empty lists).
all_layer_wise_activations, all_head_wise_activations = [], []

In [12]:
def get_llama_activations_pyvene(collected_model, collectors, prompt, device):
    """Run one forward pass and return residual-stream and collector (head-wise) activations.

    Args:
        collected_model: pyvene ``IntervenableModel``; called with a dict of model inputs,
            and element ``[1]`` of its result must expose ``.hidden_states``.
        collectors: the collector objects registered in ``collected_model``. Each must have
            ``collect_state``, a ``states`` list of tensors, and ``reset()``. They are
            cleared before the pass and again after their states are read.
        prompt: token-id tensor for the prompt (e.g. tokenizer output with ``return_tensors='pt'``).
        device: device the token ids are moved to before the forward pass.

    Returns:
        (hidden_states, head_wise_hidden_states, mlp_wise_hidden_states):
        - hidden_states: numpy array of the stacked ``output.hidden_states`` with the batch
          axis removed, i.e. (num_layers + 1, seq_len, hidden) for a single prompt.
        - head_wise_hidden_states: numpy array with one row per collector along axis 0
          (each collector's stacked states), with singleton axes squeezed.
        - mlp_wise_hidden_states: always an empty list (placeholder, not collected).

    Raises:
        ValueError: if any collector has ``collect_state`` false, since its layer would
            otherwise be missing from the stacked head-wise array.
    """
    # Drop anything left over from an earlier call that failed mid-way, so the states
    # read below belong to this forward pass only.
    for collector in collectors:
        collector.reset()

    with torch.no_grad():
        input_ids = prompt.to(device)
        output = collected_model({"input_ids": input_ids, "output_hidden_states": True})[1]

    # Stacked shape is (num_layers + 1, batch, seq_len, hidden). Remove only the batch
    # axis: a bare .squeeze() would also remove seq_len for one-token prompts.
    hidden_states = torch.stack(output.hidden_states, dim=0).squeeze(1)
    hidden_states = hidden_states.detach().cpu().numpy()

    per_collector_states = []
    for collector in collectors:
        if collector.collect_state:
            # Stay in torch; no numpy -> tensor round trip needed before the final stack.
            per_collector_states.append(torch.stack(collector.states, dim=0).cpu())
        collector.reset()
    if len(per_collector_states) != len(collectors):
        raise ValueError(
            "every collector must have collect_state=True to build head-wise activations"
        )

    # NOTE(review): the per-collector state shape depends on interveners.Collector, which is
    # not visible here, so the generic .squeeze() is kept for this array.
    head_wise_hidden_states = torch.stack(per_collector_states, dim=0).squeeze().numpy()
    mlp_wise_hidden_states = []
    return hidden_states, head_wise_hidden_states, mlp_wise_hidden_states

In [13]:
# Which slices of each forward pass to keep.
# output.hidden_states[0] is the embedding output, so index 14 is the output of model.layers[13].
LAYER_IDX = 14
# Axis 0 of the head-wise array has one row per collector, i.e. per model.layers[i],
# so -1 selects the LAST decoder layer (not the last token).
# NOTE(review): this is a different layer from LAYER_IDX above — confirm it is intended.
HEAD_LAYER_IDX = -1

# Re-initialise here so re-running this cell neither appends duplicates nor fails
# because a later cell converted these lists to numpy arrays.
all_layer_wise_activations = []
all_head_wise_activations = []

print("Getting activations")
for prompt in tqdm(prompts):
    layer_wise_activations, head_wise_activations, _ = get_llama_activations_pyvene(
        collected_model, collectors, prompt, device
    )
    # Residual stream at LAYER_IDX, last token only -> (D,)
    all_layer_wise_activations.append(layer_wise_activations[LAYER_IDX, -1, :].copy())
    # o_proj inputs (all heads) at HEAD_LAYER_IDX -> (H*d_head,); assumes each collector keeps
    # only the final token's state — TODO confirm against interveners.Collector
    all_head_wise_activations.append(head_wise_activations[HEAD_LAYER_IDX, :].copy())

Getting activations


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 208/208 [00:14<00:00, 14.70it/s]


In [14]:
# Convert the per-prompt lists into dense arrays for saving.
all_layer_wise_activations = np.array(all_layer_wise_activations, dtype=np.float32)  # (N, D)
all_head_wise_activations = np.array(all_head_wise_activations, dtype=np.float32)  # (N, H*d_head)
labels = np.array(labels, dtype=np.int32)  # (N,)

# Every saved feature matrix must be row-aligned with the labels.
assert len(labels) == len(all_layer_wise_activations) == len(all_head_wise_activations), (
    "activation and label counts differ"
)

FEATURES_DIR = '../features'
os.makedirs(FEATURES_DIR, exist_ok=True)

print("Saving labels")
np.save(os.path.join(FEATURES_DIR, 'mistral_tqa_labels.npy'), labels)

print("Saving layer wise activations")
np.save(os.path.join(FEATURES_DIR, 'mistral_tqa_layer_wise.npy'), all_layer_wise_activations)

print("Saving head wise activations")
np.save(os.path.join(FEATURES_DIR, 'mistral_tqa_head_wise.npy'), all_head_wise_activations)

Saving labels
Saving layer wise activations
Saving head wise activations


In [15]:
# Sanity check: one label per flattened prompt.
np.shape(labels)

(208,)

In [16]:
# Sanity check: (num_prompts, hidden_size) for the saved layer-wise features.
np.shape(all_layer_wise_activations)

(208, 4096)