In [68]:
import torch as t
import pandas as pd
import os
from tqdm import tqdm
import plotly.express as px
import json
import einops

In [2]:
import sys
sys.path.append('../geometry_of_truth/geometry-of-truth/')

In [3]:
from transformers import LlamaForCausalLM, LlamaTokenizer

In [8]:
from glob import glob
import numpy as np

In [5]:
ACTS_BATCH_SIZE = 25
ROOT = '../geometry_of_truth/geometry-of-truth/'

In [6]:
def load_llama(device):
    print(f'Loading Llama2')
    llama_path = '/home/t-sgolechha/Desktop/llama2/llama/llama-2-7b_hf/'
    tokenizer = LlamaTokenizer.from_pretrained(llama_path)
    model = LlamaForCausalLM.from_pretrained(llama_path)
    # set tokenizer to use bos token
    tokenizer.bos_token = '<s>'
    model.to(device)
    print(f'Loaded Llama2')
    return tokenizer, model

In [7]:
device = 'cuda:0'
tokenizer, model = load_llama(device)

Loading Llama2


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded Llama2


In [10]:
layers = range(model.config.num_hidden_layers)

train_datasets = ['cities']
val_dataset = 'sp_en_trans'

# ProbeClass = MMProbe

# label tokens
t_tok = tokenizer.encode('TRUE')[-1]
f_tok = tokenizer.encode('FALSE')[-1]

In [11]:
def collect_acts(dataset_name, model_size, layer, center=True, scale=False, device='cpu'):
    """
    Collects activations from a dataset of statements, returns as a tensor of shape [n_activations, activation_dimension].
    """
    global ROOT, ACTS_BATCH_SIZE
    directory = os.path.join(ROOT, 'acts', model_size, dataset_name)
    activation_files = glob(os.path.join(directory, f'layer_{layer}_*.pt'))
    acts = [t.load(os.path.join(directory, f'layer_{layer}_{i}.pt')).to(device) for i in range(0, ACTS_BATCH_SIZE * len(activation_files), ACTS_BATCH_SIZE)]
    acts = t.cat(acts, dim=0).to(device)
    if center:
        acts = acts - t.mean(acts, dim=0)
    if scale:
        acts = acts / t.std(acts, dim=0)
    return acts

In [12]:
class MMProbe(t.nn.Module):
    def __init__(self, direction, covariance=None, inv=None, atol=1e-3):
        super().__init__()
        self.direction = t.nn.Parameter(direction, requires_grad=False)
        if inv is None:
            self.inv = t.nn.Parameter(t.linalg.pinv(covariance.cpu(), hermitian=True, atol=atol), requires_grad=False)
            self.inv.to(device)
        else:
            self.inv = t.nn.Parameter(inv, requires_grad=False)

    def forward(self, x, iid=False):
        if iid:
            return t.nn.Sigmoid()(x @ self.inv @ self.direction)
        else:
            return t.nn.Sigmoid()(x @ self.direction)

    def pred(self, x, iid=False):
        return self(x, iid=iid).round()

    def from_data(acts, labels, atol=1e-3, device='cpu'):
        acts, labels
        pos_acts, neg_acts = acts[labels==1], acts[labels==0]
        pos_mean, neg_mean = pos_acts.mean(0), neg_acts.mean(0)
        direction = pos_mean - neg_mean

        centered_data = t.cat([pos_acts - pos_mean, neg_acts - neg_mean], 0)
        covariance = centered_data.t() @ centered_data / acts.shape[0]
        
        probe = MMProbe(direction, covariance=covariance).to(device)

        return probe

In [13]:
layer_directions_t_path = '/home/t-sgolechha/Desktop/mats_research_sprint/directions/llama2_7b_mm_layer_directions_cities.pt'

In [14]:
layer_directions_t = t.load(layer_directions_t_path).to(device)

In [16]:
proj_layer = []
# project each layer activations to Truth direction for that layer
for layer in tqdm(layers):
    acts = collect_acts('cities', '7B', layer).to(device)
    proj = acts @ layer_directions_t[layer]
    proj_layer.append(proj)

100%|██████████| 32/32 [00:01<00:00, 18.40it/s]


In [17]:
len(proj_layer), proj_layer[0].shape

(32, torch.Size([1496]))

In [18]:
proj_layer_t = t.stack(proj_layer, dim=0)

In [19]:
proj_layer_t.shape

torch.Size([32, 1496])

In [35]:
# normalize proj_layer_t for each layer divide by std
proj_layer_t_norm = proj_layer_t / t.std(proj_layer_t, dim=1, keepdim=True)

In [36]:
labels = t.Tensor(pd.read_csv(f'{ROOT}/datasets/cities.csv')['label'].tolist())

In [37]:
labels.shape

torch.Size([1496])

In [38]:
proj_layer_t.flatten().cpu().numpy().shape

(47872,)

In [49]:
df = pd.DataFrame()

In [50]:
df["projection"] = proj_layer_t_norm.flatten().cpu().numpy()
df["layer"] = np.repeat(np.arange(32), 1496)
df["label"] = np.tile(labels.cpu().numpy(), 32)

In [51]:
df["index"] = np.tile(np.arange(1496), 32)

In [52]:
fig = px.scatter(df, x="index", y="projection", animation_frame="layer", color="label", height=600, width=600)
fig.show()

In [53]:
fig.write_html("day_1_llama2_7b_cities_layer_proj_mmprobe.html")

In [46]:
# Figure out xlims and ylims for each frame of the plotly animation
xranges, yranges = [], []
pad_pct = 0.05
for frame in df["frame"].unique():
    subdf = df.query("frame == @frame")
    xpad = pad_pct * (subdf["x"].max() - subdf["x"].min())
    xmin = subdf["x"].min() - xpad
    xmax = subdf["x"].max() + xpad
    ypad = pad_pct * (subdf["y"].max() - subdf["y"].min())
    ymin = subdf["y"].min() - ypad
    ymax = subdf["y"].max() + ypad
    xranges.append([xmin, xmax])
    yranges.append([ymin, ymax])

# Create plotly figure with autoscaling frames
fig = px.scatter(df, x="x", y="y", animation_frame="frame", height=500, width=500)
for f, xlims, ylims in zip(fig.frames, xranges, yranges):
    f.layout.update(xaxis_range=xlims, yaxis_range=ylims)
# fig["layout"].pop("updatemenus");  # Remove the play/stop buttons

In [47]:
fig.show()

## Direction Exploration

In [57]:
sys.path.append('../../mech_interp/core/')

In [58]:
from mi_utils import load_hooked_llama2

In [59]:
device

'cuda:0'

In [60]:
model_tf, tokenizer_tf = load_hooked_llama2(version='7b', device=device)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model meta-llama/Llama-2-7b-hf into HookedTransformer
Moving model to device:  cuda:0
Llama 2 (7b) loaded successfully on cuda:0. Happy interpreting!


In [173]:
layer_directions_t = layer_directions_t / t.norm(layer_directions_t, dim=1, keepdim=True)

In [174]:
cities_df = pd.read_csv(f'{ROOT}/datasets/cities.csv')

In [175]:
cities_df.shape, cities_df.columns

((1496, 5),
 Index(['statement', 'label', 'city', 'country', 'correct_country'], dtype='object'))

In [176]:
prompt: str = cities_df['statement'].iloc[0]

In [177]:
from pprint import pprint as pp

In [178]:
pp(prompt)

'The city of Krasnodar is in Russia.'


In [179]:
patch_dir = layer_directions_t[31]

In [180]:
_, cache = model_tf.run_with_cache(prompt, return_type="loss")

In [181]:
comps_dec, comp_labels = cache.decompose_resid(return_labels=True)

In [182]:
comps_acc, comp_labels = cache.accumulated_resid(incl_mid=True, return_labels=True)

In [183]:
# components of comps in the directions of layer_directions_t

comp_proj_dec = einops.einsum(
    comps_dec[:, 0, -1, :], 
    layer_directions_t, 
    'comps d_model, layer d_model -> comps layer',
)

In [209]:
comp_proj_dec = comps_dec[:, 0, -1, :] @ layer_directions_t.T

In [210]:
comp_proj_acc = comps_acc[:, 0, -1, :] @ layer_directions_t.T

In [211]:
comp_proj_dec.shape, comp_proj_acc.shape

(torch.Size([65, 32]), torch.Size([65, 32]))

In [212]:
import torch
import plotly.express as px
import plotly.graph_objects as go

In [215]:
fig = px.imshow(
    comp_proj_dec.cpu().numpy().T,
    labels=dict(x="Component (Decomposed)", y="Layer", color="Projection"),
)
fig.update_layout(
    xaxis=dict(tickmode="array", tickvals=list(range(65)), ticktext=comp_labels),
)

fig.update_layout(
    autosize=False,
    width=1000,  # Adjust the width of the plot
    height=600,  # Adjust the height of the plot
    margin=dict(l=50, r=50, b=50, t=50),  # Adjust margins as needed
    # aspectratio=dict(x=1, y=1),  # Set the aspect ratio to 1 for square cells
)

# Set background color to white
fig.update_layout(plot_bgcolor='white')
fig.show()

In [216]:
fig = px.imshow(
    comp_proj_acc.cpu().numpy().T,
    labels=dict(x="Component (Accumulated)", y="Layer", color="Projection"),
)
fig.update_layout(
    xaxis=dict(tickmode="array", tickvals=list(range(65)), ticktext=comp_labels),
    yaxis=dict(tickmode="array", tickvals=list(range(32)), ticktext=list(range(32))),
)

fig.update_layout(
    autosize=False,
    width=1000,  # Adjust the width of the plot
    height=600,  # Adjust the height of the plot
    margin=dict(l=50, r=50, b=50, t=50),  # Adjust margins as needed
    # aspectratio=dict(x=1, y=1),  # Set the aspect ratio to 1 for square cells
)

# Set background color to white
fig.update_layout(plot_bgcolor='white')
fig.show()

In [192]:
sample_acts_from_hf = collect_acts('cities', '7B', 31).to(device)[0]

In [193]:
cache['blocks.31.hook_resid_post'][0, -1, :].shape

torch.Size([4096])

In [194]:
torch.allclose(sample_acts_from_hf, cache['blocks.31.hook_resid_post'][0, -1, :])

False

In [195]:
cache.keys()

dict_keys(['hook_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_rot_q', 'blocks.0.attn.hook_rot_k', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_pre_linear', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_rot_q', 'blocks.1.attn.hook_rot_k', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.m

In [196]:
cache['blocks.31.hook_resid_post'][0, -1, :]

tensor([ 0.3416,  0.9180,  4.0099,  ..., -1.6804,  0.4642,  1.5637],
       device='cuda:0')

In [197]:
t.manual_seed(32)
x = t.rand((1, 5, 4096)).to(device)

In [198]:
no_lens = model.model.layers[0].forward(x)

In [199]:
lens = model_tf.blocks[0].forward(x)

In [200]:
no_lens = no_lens[0]

In [201]:
no_lens

tensor([[[0.8668, 0.2813, 0.4026,  ..., 0.9547, 0.8108, 0.3769],
         [0.0397, 0.7383, 0.3898,  ..., 0.7659, 0.6448, 0.7672],
         [0.4514, 0.7654, 0.4833,  ..., 0.3922, 0.3931, 0.1100],
         [0.5426, 0.1316, 0.1745,  ..., 0.0745, 0.1604, 0.2483],
         [0.0576, 0.5389, 0.0285,  ..., 0.6003, 0.7046, 0.1649]]],
       device='cuda:0')

In [202]:
lens

tensor([[[0.8695, 0.2830, 0.4036,  ..., 0.9583, 0.8089, 0.3814],
         [0.0398, 0.7404, 0.3906,  ..., 0.7684, 0.6461, 0.7661],
         [0.4496, 0.7662, 0.4832,  ..., 0.3926, 0.3949, 0.1098],
         [0.5422, 0.1312, 0.1750,  ..., 0.0756, 0.1609, 0.2492],
         [0.0576, 0.5389, 0.0285,  ..., 0.6003, 0.7046, 0.1649]]],
       device='cuda:0')

In [203]:
def load_statements(dataset_name):
    """
    Load statements from csv file, return list of strings.
    """
    dataset = pd.read_csv(f"{ROOT}datasets/{dataset_name}.csv")
    statements = dataset['statement'].tolist()
    return statements

In [204]:
def get_acts(statements, tokenizer, model, layers, device):
    """
    Get given layer activations for the statements. 
    Return dictionary of stacked activations.
    """
    # attach hooks
    hooks, handles = [], []
    for layer in layers:
        hook = Hook()
        handle = model.model.layers[layer].register_forward_hook(hook)
        hooks.append(hook), handles.append(handle)
    
    # get activations
    acts = {layer : [] for layer in layers}
    for statement in tqdm(statements):
        input_ids = tokenizer.encode(statement, return_tensors="pt").to(device)
        model(input_ids)
        for layer, hook in zip(layers, hooks):
            acts[layer].append(hook.out[0, -1])
    
    for layer, act in acts.items():
        acts[layer] = t.stack(act).float()
    
    # remove hooks
    for handle in handles:
        handle.remove()
    
    return acts

In [205]:
train_datasets

['cities']

In [206]:
layers

range(0, 32)

In [207]:
acts_path = '/home/t-sgolechha/Desktop/mats_research_sprint/geometry_of_truth/geometry-of-truth/acts/7B_lens/'

In [208]:
for dataset in train_datasets:
    statements = load_statements(dataset)
    layers = [int(layer) for layer in layers]
    if layers == [-1]:
        layers = list(range(len(model.model.layers)))
    save_dir = f"{acts_path}{dataset}/"
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    for idx in range(0, len(statements), 25):
        acts = get_acts(statements[idx:idx + 25], tokenizer, model, layers, device)
        for layer, act in acts.items():
                t.save(act, f"{save_dir}/layer_{layer}_{idx}.pt")

NameError: name 'Hook' is not defined

In [None]:
per_head_residual, head_labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_residual = einops.rearrange(
    per_head_residual,
    "(layer head) ... -> layer head ...",
    layer=model_tf.cfg.n_layers
)[:, :, 0, :]  # (layer head d_model)

In [217]:
per_head_residual.shape

torch.Size([32, 32, 4096])

In [218]:
head_dvaas = per_head_residual @ layer_directions_t.T  

In [219]:
head_dvaas.shape

torch.Size([32, 32, 32])

In [220]:
head_dvaas_np = head_dvaas.cpu().numpy()

In [221]:
fig = px.imshow(
    head_dvaas_np,
    animation_frame=0,
    labels=dict(x="Direction", y="Head", color="Value"),
)
fig.update_layout(
    xaxis=dict(tickmode="array", tickvals=list(range(32)), ticktext=[str(i) for i in range(32)]),
    yaxis=dict(tickmode="array", tickvals=list(range(32)), ticktext=[str(i) for i in range(32)]),
)
fig.update_layout(
    autosize=False,
    width=700,  # Adjust the width of the plot
    height=700,  # Adjust the height of the plot
    margin=dict(l=50, r=50, b=50, t=50),  # Adjust margins as needed
    # aspectratio=dict(x=1, y=1),  # Set the aspect ratio to 1 for square cells
)

#### Interesting Heads

* **Good**: L6H2, L5H1, L12H22, L15H10, L19H29, L21H25 (interesting), L26H18 (interesting), L27H7, L27H9, L28H28, L31H31, L29H10, 

* **Bad**: L8H21, L4H7, L9H4, L10H1, L18H17, L30H31, L30H13?, L30H24

In [None]:
head_dvaas[25, 12, 19]

tensor(0.0119, device='cuda:0')

In [167]:
head_dvaas.shape # (layer, head, direction)

torch.Size([32, 32, 32])

In [171]:
layer_directions_t_norm = layer_directions_t / t.norm(layer_directions_t, dim=1, keepdim=True)

In [172]:
px.imshow((layer_directions_t_norm @ layer_directions_t_norm.T).cpu().numpy())