In [1]:
import sys
sys.path.append("../")

import torch 

from transformers import AutoModelForCausalLM, AutoTokenizer

from white_box.model_wrapper import ModelWrapper
from white_box.chat_model_utils import load_model_and_tokenizer, get_template, MODEL_CONFIGS

%load_ext autoreload
%autoreload 2

In [2]:
model_name = 'llama2_7b'
model_config = MODEL_CONFIGS[model_name]

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", 
    torch_dtype=torch.float16, 
    device_map="auto")

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", padding_side = "right")
tokenizer.pad_token = tokenizer.eos_token

template = get_template(model_name, chat_template=model_config.get('chat_template', None))['prompt']
mw = ModelWrapper(model, tokenizer, template = template)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Found Instruction template for llama2_7b
{'description': 'Template used by Llama2 Chat', 'prompt': '[INST] {instruction} [/INST] '}


In [29]:

test_tensor = mw.batch_hiddens(["Hello, how are you?", "I am doing well, thank you."], tok_idxs = [-1, -2])['resid']

100%|██████████| 1/1 [00:00<00:00,  2.12it/s]

tensor([[0, 0],
        [1, 1]])
tensor([[13, 14],
        [15, 16]])





In [30]:
prompts = ["Hello, how are you?", "I am doing well, thank you."]

formatted_prompts = [template.format(instruction=s) for s in prompts]
inputs = tokenizer(formatted_prompts, return_tensors="pt", padding=True, max_length=2048, truncation=True)

res = model(**inputs, output_hidden_states = True)['hidden_states']
hidden_states = torch.stack([r for r in res], dim=1)[:, 1:].float()

In [40]:
first_prompt_state = hidden_states[0, :, tok_idxs - 2]
sec_prompt_state = hidden_states[1, :, tok_idxs]

assert torch.allclose(test_tensor[0], first_prompt_state)
assert torch.allclose(test_tensor[1], sec_prompt_state)

In [34]:
test_tensor[0].shape

torch.Size([32, 2, 4096])

In [22]:
acts = torch.stack([r for r in res], dim=1)

torch.allclose(acts[1, 1:, -1].float(), test_res['resid'][1, :, -1])

True

In [28]:
torch.allclose(acts[0, 1:, -3].float(), test_res['resid'][0, :, -1])

True

In [16]:
from pytests.test_mw import _test_batch_hiddens
prompts = ["Hello, how are you?", "I am doing well, thank you."]

_test_batch_hiddens(model, model_name, prompts, padding_side = "left", tok_idxs = [-1, -2, -3])

Found Instruction template for llama2_7b
{'description': 'Template used by Llama2 Chat', 'prompt': '[INST] {instruction} [/INST] '}


100%|██████████| 1/1 [00:00<00:00,  2.87it/s]


In [6]:
import sys
sys.path.append("../")
from tests.test_mw import _test_batch_hiddens

ModuleNotFoundError: No module named 'tests.test_mw'

In [49]:
import tests

In [50]:
tests

<module 'tests' from '/home/ubuntu/anaconda3/envs/white-box/lib/python3.11/site-packages/tests/__init__.py'>

In [21]:
import torch

# Example data
matrix = torch.tensor([[10, 20, 30],
                       [40, 50, 60],
                       [70, 80, 90]])

# Index tensor
indices = torch.tensor([[0, 1], [1, 2], [0, 2]])

# Create batch indices
batch_indices = torch.arange(matrix.shape[0]).unsqueeze(1).expand(-1, indices.shape[1])

# Use advanced indexing to select the elements
selected_elements = matrix[batch_indices, indices]

print("Original matrix:")
print(matrix)
print("\nIndices:")
print(indices)
print("\nSelected elements:")
print(selected_elements)


Original matrix:
tensor([[10, 20, 30],
        [40, 50, 60],
        [70, 80, 90]])

Indices:
tensor([[0, 1],
        [1, 2],
        [0, 2]])

Selected elements:
tensor([[10, 20],
        [50, 60],
        [70, 90]])
