In [1]:
import torch as t
import pandas as pd
import os
from tqdm import tqdm
import plotly.express as px
import json

In [2]:
import sys
# Put the local geometry-of-truth checkout on the import path (same directory as ROOT).
# NOTE(review): no cell visible here imports anything from it — confirm it is still needed.
sys.path.append('../geometry_of_truth/geometry-of-truth/')

In [3]:
from transformers import LlamaForCausalLM, LlamaTokenizer

In [8]:
from glob import glob
import numpy as np

In [5]:
# Stride between the numeric suffixes of the saved activation files: collect_acts reads
# layer_{layer}_{i}.pt for i = 0, ACTS_BATCH_SIZE, 2 * ACTS_BATCH_SIZE, ...
ACTS_BATCH_SIZE = 25
# Root of the geometry-of-truth checkout; collect_acts reads ROOT/acts/... and the labels
# cell reads ROOT/datasets/cities.csv.
ROOT = '../geometry_of_truth/geometry-of-truth/'

In [6]:
def load_llama(device, model_path='/home/t-sgolechha/Desktop/llama2/llama/llama-2-7b_hf/'):
    """
    Load the Llama-2 tokenizer and causal-LM model and move the model to `device`.

    Args:
        device: torch device (or device string such as 'cuda:0') the model is moved to.
        model_path: directory passed to `from_pretrained` for both the tokenizer and the
            model. Defaults to the machine-specific path this notebook was written with;
            pass a different path when running elsewhere.

    Returns:
        (tokenizer, model) tuple.
    """
    print('Loading Llama2')
    tokenizer = LlamaTokenizer.from_pretrained(model_path)
    model = LlamaForCausalLM.from_pretrained(model_path)
    # set tokenizer to use bos token
    tokenizer.bos_token = '<s>'
    model.to(device)
    print('Loaded Llama2')
    return tokenizer, model

In [7]:
# Load the tokenizer and model onto the first CUDA device; `device` is reused by later cells.
device = 'cuda:0'
tokenizer, model = load_llama(device)

Loading Llama2


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded Llama2


In [10]:
# Layer indices 0 .. num_hidden_layers - 1, taken from the loaded model's config.
layers = range(model.config.num_hidden_layers)

# NOTE(review): train_datasets / val_dataset are not read by any cell visible here.
train_datasets = ['cities']
val_dataset = 'sp_en_trans'

# ProbeClass = MMProbe

# label tokens
# Each is the last token id the tokenizer produces for the label string.
# NOTE(review): t_tok / f_tok are not read by any cell visible here.
t_tok = tokenizer.encode('TRUE')[-1]
f_tok = tokenizer.encode('FALSE')[-1]

In [11]:
def collect_acts(dataset_name, model_size, layer, center=True, scale=False, device='cpu',
                 root=None, batch_size=None):
    """
    Collects activations from a dataset of statements, returns as a tensor of shape [n_activations, activation_dimension].

    Files are read from <root>/acts/<model_size>/<dataset_name>/layer_<layer>_<i>.pt for
    i = 0, batch_size, 2 * batch_size, ... (one index per file found on disk) and
    concatenated along dim 0 in that order.

    Args:
        dataset_name: sub-directory name of the dataset under acts/<model_size>/.
        model_size: sub-directory name of the model under acts/ (e.g. '7B').
        layer: layer index used in the file names.
        center: if True, subtract the mean over dim 0.
        scale: if True, divide by the standard deviation over dim 0.
        device: device the activations are loaded onto.
        root: base directory; defaults to the module-level ROOT.
        batch_size: stride between file-name indices; defaults to the module-level
            ACTS_BATCH_SIZE.

    Returns:
        Tensor of all activations for the layer, on `device`.

    Raises:
        FileNotFoundError: if no activation file matches this dataset / model / layer.
    """
    root = ROOT if root is None else root
    batch_size = ACTS_BATCH_SIZE if batch_size is None else batch_size

    directory = os.path.join(root, 'acts', model_size, dataset_name)
    n_files = len(glob(os.path.join(directory, f'layer_{layer}_*.pt')))
    if n_files == 0:
        # Without this guard t.cat([]) fails with an error that hides the real cause.
        raise FileNotFoundError(f'No activation files for layer {layer} found in {directory}')

    # map_location loads straight onto `device`, so files saved from a GPU can still be
    # read on a machine without one.
    acts = [
        t.load(os.path.join(directory, f'layer_{layer}_{i}.pt'), map_location=device)
        for i in range(0, batch_size * n_files, batch_size)
    ]
    acts = t.cat(acts, dim=0).to(device)
    if center:
        acts = acts - t.mean(acts, dim=0)
    if scale:
        acts = acts / t.std(acts, dim=0)
    return acts

In [12]:
class MMProbe(t.nn.Module):
    """
    Linear probe along the difference between class means (see `from_data`).

    Attributes:
        direction: probe direction, stored as a non-trainable parameter.
        inv: matrix applied to inputs before projecting when `iid=True`; either supplied
            directly or computed as the pseudo-inverse of `covariance`.
    """

    def __init__(self, direction, covariance=None, inv=None, atol=1e-3):
        """
        Args:
            direction: 1-D direction tensor.
            covariance: covariance matrix to pseudo-invert; used only when `inv` is None.
            inv: precomputed inverse; takes precedence over `covariance`.
            atol: absolute tolerance passed to `torch.linalg.pinv`.

        Raises:
            ValueError: if neither `covariance` nor `inv` is given.
        """
        super().__init__()
        if inv is None and covariance is None:
            raise ValueError('MMProbe requires either `covariance` or `inv`')
        self.direction = t.nn.Parameter(direction, requires_grad=False)
        if inv is None:
            # The pseudo-inverse is computed on CPU, then placed next to `direction` so
            # both parameters live on the same device.
            inv = t.linalg.pinv(covariance.cpu(), hermitian=True, atol=atol).to(direction.device)
        self.inv = t.nn.Parameter(inv, requires_grad=False)

    def forward(self, x, iid=False):
        """Return sigmoid(x @ direction), or sigmoid(x @ inv @ direction) when `iid` is True."""
        if iid:
            return t.sigmoid(x @ self.inv @ self.direction)
        return t.sigmoid(x @ self.direction)

    def pred(self, x, iid=False):
        """Return the probe output rounded to 0 or 1."""
        return self(x, iid=iid).round()

    @staticmethod
    def from_data(acts, labels, atol=1e-3, device='cpu'):
        """
        Build a probe from activations and binary labels.

        Args:
            acts: activations, one row per example.
            labels: tensor of 0/1 labels aligned with the rows of `acts`.
            atol: absolute tolerance for the covariance pseudo-inverse.
            device: device the returned probe is moved to.

        Returns:
            MMProbe whose direction is mean(acts[labels == 1]) - mean(acts[labels == 0]).
        """
        pos_acts, neg_acts = acts[labels == 1], acts[labels == 0]
        pos_mean, neg_mean = pos_acts.mean(0), neg_acts.mean(0)
        direction = pos_mean - neg_mean

        # Covariance of the data after removing each class's own mean.
        centered_data = t.cat([pos_acts - pos_mean, neg_acts - neg_mean], 0)
        covariance = centered_data.t() @ centered_data / acts.shape[0]

        return MMProbe(direction, covariance=covariance, atol=atol).to(device)

In [13]:
# Absolute, machine-specific path to the saved directions tensor that is loaded in the next cell.
# NOTE(review): the file name suggests per-layer MMProbe directions for 'cities' — confirm.
layer_directions_t_path = '/home/t-sgolechha/Desktop/mats_research_sprint/directions/llama2_7b_mm_layer_directions_cities.pt'

In [14]:
# Load the stored directions and move them to `device`; the projection loop indexes this
# tensor by layer.
layer_directions_t = t.load(layer_directions_t_path).to(device)

In [16]:
# Per-layer projections of the 'cities' activations, appended in layer order.
proj_layer = []
# Project each layer's activations onto the stored direction for that layer.
for layer in tqdm(layers):
    # collect_acts runs with its defaults here: activations are mean-centred (center=True)
    # and loaded on CPU, then moved to `device`.
    acts = collect_acts('cities', '7B', layer).to(device)
    # Matrix-vector product: one scalar projection per row of `acts`.
    proj = acts @ layer_directions_t[layer]
    proj_layer.append(proj)

100%|██████████| 32/32 [00:01<00:00, 18.40it/s]


In [17]:
# Sanity check: number of layers projected and the shape of one layer's projections.
len(proj_layer), proj_layer[0].shape

(32, torch.Size([1496]))

In [18]:
# Stack the per-layer projection vectors into one tensor with layers along the first axis
# (torch.stack inserts the new axis at position 0 by default).
proj_layer_t = t.stack(proj_layer)

In [19]:
# Sanity check: shape of the stacked projections (layers first).
proj_layer_t.shape

torch.Size([32, 1496])

In [35]:
# Rescale each layer's projections (each row) by that layer's standard deviation.
layer_std = proj_layer_t.std(dim=1, keepdim=True)
proj_layer_t_norm = proj_layer_t / layer_std

In [36]:
# Read the 'label' column of the cities dataset and turn it into a tensor.
label_column = pd.read_csv(f'{ROOT}/datasets/cities.csv')['label']
labels = t.Tensor(label_column.tolist())

In [37]:
# Sanity check: the number of labels should match the number of projections per layer.
labels.shape

torch.Size([1496])

In [38]:
# Sanity check: length of the flattened projections (layers x statements) as a NumPy array.
proj_layer_t.flatten().cpu().numpy().shape

(47872,)

In [49]:
# Start from an empty frame; its columns are filled in by the following cells.
df = pd.DataFrame()

In [50]:
# Long-format table with one row per (layer, statement) pair. Sizes come from the tensor
# itself rather than hard-coded literals, so the cell still works for another model or dataset.
n_layers, n_statements = proj_layer_t_norm.shape
assert labels.shape[0] == n_statements, "labels and projections disagree on the number of statements"

# flatten() is row-major: all statements of layer 0 come first, then layer 1, and so on.
# `layer` is therefore repeated per statement and `label` is tiled once per layer.
df["projection"] = proj_layer_t_norm.flatten().cpu().numpy()
df["layer"] = np.repeat(np.arange(n_layers), n_statements)
df["label"] = np.tile(labels.cpu().numpy(), n_layers)

In [51]:
# Position of each statement within its layer (0 .. n_statements - 1), tiled once per layer.
# Sizes are read from the projection tensor instead of being hard-coded.
df["index"] = np.tile(np.arange(proj_layer_t_norm.shape[1]), proj_layer_t_norm.shape[0])

In [52]:
# Animated scatter with one frame per layer: statement index on x, normalised projection
# on y, points coloured by label.
fig = px.scatter(df, x="index", y="projection", animation_frame="layer", color="label", height=600, width=600)
fig.show()

In [53]:
# Export the figure to an HTML file; the path is relative, so it is written to the current
# working directory.
fig.write_html("day_1_llama2_7b_cities_layer_proj_mmprobe.html")

In [46]:
# Figure out xlims and ylims for each frame (layer) of the plotly animation.
# The columns are the ones `df` actually has ("layer", "index", "projection"); the earlier
# "frame" / "x" / "y" names do not exist in this frame and raised KeyError on a fresh run.
xranges, yranges = [], []
pad_pct = 0.05  # padding added on each side, as a fraction of the data range
for layer_id in df["layer"].unique():
    layer_df = df[df["layer"] == layer_id]
    xmin, xmax = layer_df["index"].min(), layer_df["index"].max()
    ymin, ymax = layer_df["projection"].min(), layer_df["projection"].max()
    xpad = pad_pct * (xmax - xmin)
    ypad = pad_pct * (ymax - ymin)
    # Plain floats keep NumPy scalar types out of the figure layout.
    xranges.append([float(xmin - xpad), float(xmax + xpad)])
    yranges.append([float(ymin - ypad), float(ymax + ypad)])

# Create plotly figure with autoscaling frames: each frame gets its own axis ranges.
# Frames and ranges are both built in order of first appearance of each layer, so zip aligns them.
fig = px.scatter(df, x="index", y="projection", animation_frame="layer", height=500, width=500)
for f, xlims, ylims in zip(fig.frames, xranges, yranges):
    f.layout.update(xaxis_range=xlims, yaxis_range=ylims)
# fig["layout"].pop("updatemenus");  # Remove the play/stop buttons

In [47]:
# Display the figure built in the previous cell.
fig.show()