In [1]:
import torch as t
import pandas as pd
import os
from tqdm import tqdm
import plotly.express as px
import json

In [2]:
import sys
# Put the local geometry-of-truth directory on the module search path.
sys.path.append('./geometry-of-truth/')

In [9]:
from transformers import LlamaForCausalLM, LlamaTokenizer

In [23]:
from glob import glob

In [26]:
# Step between the indices embedded in activation file names
# (layer_{layer}_{i}.pt with i = 0, 25, 50, ...); read by collect_acts.
ACTS_BATCH_SIZE = 25

In [14]:
# Base directory under which this notebook looks for 'acts/' and 'datasets/'.
ROOT = './geometry-of-truth/'

In [27]:
def collect_acts(dataset_name, model_size, layer, center=True, scale=False, device='cpu'):
    """
    Collects activations from a dataset of statements, returns as a tensor of shape [n_activations, activation_dimension].

    Files are read from ROOT/acts/<model_size>/<dataset_name>/layer_<layer>_<i>.pt
    with i = 0, ACTS_BATCH_SIZE, 2 * ACTS_BATCH_SIZE, ... (one index per file
    matching layer_<layer>_*.pt), and concatenated along dim 0 in that order.

    Args:
        dataset_name: name of the dataset sub-directory.
        model_size: name of the model-size sub-directory (e.g. '7B').
        layer: layer index used in the file names.
        center: if True, subtract the per-dimension mean over all activations.
        scale: if True, divide by the per-dimension standard deviation.
        device: device the activations are loaded onto.

    Raises:
        FileNotFoundError: if no activation file exists for this dataset/layer.
    """
    directory = os.path.join(ROOT, 'acts', model_size, dataset_name)
    n_files = len(glob(os.path.join(directory, f'layer_{layer}_*.pt')))
    if n_files == 0:
        # Fail with an actionable message instead of the opaque error t.cat([]) would give.
        raise FileNotFoundError(f'No activation files matching layer_{layer}_*.pt in {directory}')

    # map_location lets files saved from a GPU be loaded on a machine without one.
    acts = [
        t.load(os.path.join(directory, f'layer_{layer}_{i}.pt'), map_location=device)
        for i in range(0, ACTS_BATCH_SIZE * n_files, ACTS_BATCH_SIZE)
    ]
    acts = t.cat(acts, dim=0)
    if center:
        acts = acts - t.mean(acts, dim=0)
    if scale:
        acts = acts / t.std(acts, dim=0)
    return acts

In [10]:
def load_llama(device, llama_path='/home/t-sgolechha/Desktop/llama2/llama/llama-2-7b_hf/'):
    """
    Loads the Llama tokenizer and causal-LM model and moves the model to `device`.

    Args:
        device: device the model is moved to (e.g. 'cuda:0').
        llama_path: directory passed to `from_pretrained` for both the tokenizer
            and the model. Defaults to the path this notebook was written against;
            pass your own checkpoint directory on any other machine.

    Returns:
        (tokenizer, model)
    """
    print('Loading Llama2')
    tokenizer = LlamaTokenizer.from_pretrained(llama_path)
    model = LlamaForCausalLM.from_pretrained(llama_path)
    # set tokenizer to use bos token
    tokenizer.bos_token = '<s>'
    model.to(device)
    print('Loaded Llama2')
    return tokenizer, model

In [11]:
# Device used for the model and, via this global, by later cells for activations.
device = 'cuda:0'
tokenizer, model = load_llama(device)

Loading Llama2


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded Llama2


In [36]:
class MMProbe(t.nn.Module):
    """
    Linear probe defined by a fixed direction and a (pseudo-)inverse covariance matrix.

    Both tensors are stored as non-trainable parameters, so `.to(...)` on the
    probe moves them together.
    """

    def __init__(self, direction, covariance=None, inv=None, atol=1e-3):
        """
        Args:
            direction: 1-D probe direction.
            covariance: covariance matrix; its pseudo-inverse is computed when
                `inv` is not given.
            inv: precomputed inverse covariance; takes precedence over `covariance`.
            atol: absolute tolerance forwarded to `torch.linalg.pinv`.

        Raises:
            ValueError: if neither `covariance` nor `inv` is provided.
        """
        super().__init__()
        if inv is None and covariance is None:
            raise ValueError('MMProbe requires either `covariance` or `inv`')
        self.direction = t.nn.Parameter(direction, requires_grad=False)
        if inv is None:
            # The pseudo-inverse is computed on CPU, then placed next to `direction`
            # so the probe is usable without relying on any global device.
            inv = t.linalg.pinv(covariance.cpu(), hermitian=True, atol=atol).to(direction.device)
        self.inv = t.nn.Parameter(inv, requires_grad=False)

    def forward(self, x, iid=False):
        """
        Returns sigmoid(x @ direction), or sigmoid(x @ inv @ direction) when `iid` is True.
        """
        if iid:
            return t.sigmoid(x @ self.inv @ self.direction)
        return t.sigmoid(x @ self.direction)

    def pred(self, x, iid=False):
        """Returns the forward output rounded to 0/1."""
        return self(x, iid=iid).round()

    @staticmethod
    def from_data(acts, labels, atol=1e-3, device='cpu'):
        """
        Builds a probe from activations and binary labels.

        The direction is mean(acts[labels == 1]) - mean(acts[labels == 0]); the
        covariance is computed from the activations after subtracting each
        class's own mean, normalised by the total number of rows.

        Args:
            acts: [n, d] activations.
            labels: [n] tensor of 0/1 labels.
            atol: absolute tolerance for the covariance pseudo-inverse.
            device: device the returned probe is moved to.
        """
        pos_acts, neg_acts = acts[labels == 1], acts[labels == 0]
        pos_mean, neg_mean = pos_acts.mean(0), neg_acts.mean(0)
        direction = pos_mean - neg_mean

        centered_data = t.cat([pos_acts - pos_mean, neg_acts - neg_mean], 0)
        covariance = centered_data.t() @ centered_data / acts.shape[0]

        # atol is forwarded; previously the caller's value was silently dropped.
        return MMProbe(direction, covariance=covariance, atol=atol).to(device)

In [42]:
# Number of hidden layers reported by the loaded model's config.
model.config.num_hidden_layers

32

In [43]:
# One entry per hidden layer of the loaded model.
layers = range(model.config.num_hidden_layers)

# Datasets the probe is fit on, and the dataset read by the intervention loop.
train_datasets = ['cities']
val_dataset = 'sp_en_trans'

ProbeClass = MMProbe

# label tokens: the last token id produced when encoding each label word
t_tok = tokenizer.encode('TRUE')[-1]
f_tok = tokenizer.encode('FALSE')[-1]

In [44]:
# Fit one probe per layer and keep a rescaled copy of its direction.
layer_directions = []

for layer in tqdm(layers):
    # get probe
    # Compared by class name so the cell runs even when LRProbe is not defined in this kernel.
    if ProbeClass.__name__ in ('LRProbe', 'MMProbe'):
        acts, labels = [], []
        for dataset in train_datasets:
            acts.append(collect_acts(dataset, '7B', layer).to(device))
            labels.append(t.Tensor(pd.read_csv(f'{ROOT}/datasets/{dataset}.csv')['label'].tolist()).to(device))
        acts, labels = t.cat(acts), t.cat(labels)
        probe = ProbeClass.from_data(acts, labels, device=device)
    # get direction
    direction = probe.direction
    true_acts, false_acts = acts[labels==1], acts[labels==0]
    true_mean, false_mean = true_acts.mean(0), false_acts.mean(0)
    # Normalise to unit length, then rescale so the vector's length equals the
    # projection of (true_mean - false_mean) onto it.
    direction = direction / direction.norm()
    diff = (true_mean - false_mean) @ direction
    direction = diff * direction

    layer_directions.append(direction)

100%|██████████| 32/32 [00:36<00:00,  1.13s/it]


In [45]:
# Stack the per-layer directions into one tensor, with layer as the first dimension.
layer_directions_t = t.stack(layer_directions, dim=0)

In [46]:
# Sanity check: shape of the stacked directions.
layer_directions_t.shape

torch.Size([32, 4096])

In [47]:
# Output file for the stacked directions.
# NOTE(review): absolute, user-specific path; will not exist on other machines.
layer_directions_t_path = '/home/t-sgolechha/Desktop/mats_research_sprint/directions/llama2_7b_mm_layer_directions_cities.pt'

In [48]:
# Persist the stacked per-layer directions to disk.
t.save(layer_directions_t, layer_directions_t_path)

In [38]:
# get probe
# NOTE: `layer` is whatever value an earlier cell left in the kernel.
# Compared by class name so the cell runs even when LRProbe is not defined in this kernel.
if ProbeClass.__name__ in ('LRProbe', 'MMProbe'):
    acts, labels = [], []
    for dataset in train_datasets:
        acts.append(collect_acts(dataset, '7B', layer).to(device))
        labels.append(t.Tensor(pd.read_csv(f'{ROOT}/datasets/{dataset}.csv')['label'].tolist()).to(device))
    acts, labels = t.cat(acts), t.cat(labels)
    probe = ProbeClass.from_data(acts, labels, device=device)

In [39]:
# get direction
direction = probe.direction
true_acts, false_acts = acts[labels==1], acts[labels==0]
true_mean, false_mean = true_acts.mean(0), false_acts.mean(0)
# Normalise to unit length, then rescale so the vector's length equals the
# projection of (true_mean - false_mean) onto it.
direction = direction / direction.norm()
diff = (true_mean - false_mean) @ direction
direction = diff * direction

torch.Size([4096])

In [None]:
# Few-shot prefix: five labelled example statements (four TRUE, one FALSE).
# The intervention loop appends a new statement after this text.
prompt = """\
The Spanish word 'jirafa' means 'giraffe'. This statement is: TRUE
The Spanish word 'escribir' means 'to write'. This statement is: TRUE
The Spanish word 'diccionario' means 'dictionary'. This statement is: TRUE
The Spanish word 'gato' means 'cat'. This statement is: TRUE
The Spanish word 'aire' means 'silver'. This statement is: FALSE"""

In [None]:
# make sure everything is clean going in:
# remove every forward hook currently registered on the model's layers
for module in model.model.layers:
    module._forward_hooks.clear()

# Per-alpha results; the intervention loop appends one entry to each list per alpha.
df_out = {'alpha' : [], 'diff' : [], 'tot' : []}

In [None]:
# keep increasing alpha until things get worse
# Each round adds alpha * direction at two token positions of one layer's output
# and records, averaged over the selected statements:
#   diff = P(TRUE) - P(FALSE)   and   tot = P(TRUE) + P(FALSE)
# for the next-token distribution. The sweep stops once diff stops increasing
# or tot drops to .95 or below.
last_diff = -2
diff = -1
tot = 1
alpha = -1

# Loop-invariant setup, done once instead of once per alpha.
# The dataset is read from under ROOT, like every other dataset read in this notebook.
df = pd.read_csv(os.path.join(ROOT, 'datasets', f'{val_dataset}.csv'))
period_tok = tokenizer.encode("'test'.")[-1]
eval_inputs = []
for _, row in df.iterrows():
    # only label-0 statements that are not already part of the few-shot prompt
    if row['label'] == 0 and row['statement'] not in prompt:
        input_ids = tokenizer(prompt + '\n' +  row['statement'] + ' This statement is:', return_tensors='pt').input_ids.to(device)
        period_idxs = (input_ids == period_tok).nonzero(as_tuple=True)[1]
        # Sixth occurrence of the period token; the prompt has five example lines, so this
        # presumably is the period closing the appended statement — TODO confirm.
        eval_inputs.append((input_ids, period_idxs[5]))
if not eval_inputs:
    raise ValueError(f'No label-0 statements outside the prompt found in {val_dataset}')

while diff > last_diff and tot > .95:
    last_diff = diff
    alpha += 1
    # get probs
    diffs, tots = [], []
    for input_ids, intervention_idx in tqdm(eval_inputs):
        # intervened prob
        def hook(module, input, output):
            output[0][:,intervention_idx - 1, :] += direction * alpha
            output[0][:, intervention_idx, :] += direction * alpha
            return output
        # `layer` and `direction` are taken from earlier cells.
        handle = model.model.layers[layer-1].register_forward_hook(hook)
        try:
            # Inference only: skip building the autograd graph.
            with t.no_grad():
                probs = model(input_ids).logits[0,-1,:].softmax(-1)
        finally:
            # Always detach the hook so a failed forward pass cannot leave it registered.
            handle.remove()

        diffs.append(probs[t_tok].item() - probs[f_tok].item())
        tots.append(probs[t_tok].item() + probs[f_tok].item())
    diff = sum(diffs) / len(diffs)
    tot = sum(tots) / len(tots)
    df_out['alpha'].append(alpha)
    df_out['diff'].append(diff)
    df_out['tot'].append(tot)

In [None]:
# save results
# Append this run's settings and per-alpha results to the shared JSON log.
log = {
    'train_datasets' : train_datasets,
    'val_dataset' : val_dataset,
    'layer' : layer,
    'probe class' : ProbeClass.__name__,
    'prompt' : prompt,
    'results' : df_out,
    'experiment' : 'false to true'
}

results_path = 'experimental_outputs/label_change_intervention_results.json'
if os.path.exists(results_path):
    with open(results_path, 'r') as f:
        data = json.load(f)
else:
    # First run: start a new log instead of failing on the missing file.
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    data = []
data.append(log)
with open(results_path, 'w') as f:
    json.dump(data, f, indent=4)

In [None]:
# Plot the recorded 'diff' and 'tot' values against alpha.
px.line(pd.DataFrame(df_out), x='alpha', y=['diff', 'tot'])

In [49]:
# Sanity check: shape of the activations currently held in `acts`.
acts.shape

torch.Size([1496, 4096])

In [52]:
def get_pcs(X, k=2, offset=0):
    """
    Performs Principal Component Analysis (PCA) on the n x d data matrix X.

    Args:
        X: n x d data matrix; it is mean-centered internally (the input is not modified).
        k: number of principal components to return.
        offset: number of leading components to skip; components are ranked by
            decreasing eigenvalue.

    Returns:
        A d x k tensor on X's device whose columns are the selected principal
        components. Only the eigenvectors are returned.
    """

    # Subtract the mean to center the data
    X = X - t.mean(X, dim=0)

    # Compute the covariance matrix
    cov_mat = t.mm(X.t(), X) / (X.size(0) - 1)

    # Perform eigen decomposition on CPU, then move the results back next to the
    # data rather than to a notebook-level global device.
    eigenvalues, eigenvectors = t.linalg.eigh(cov_mat.cpu())
    eigenvalues = eigenvalues.to(X.device)
    eigenvectors = eigenvectors.to(X.device)

    # eigh returns eigenvalues in ascending order; reorder so the largest come first
    sorted_indices = t.argsort(eigenvalues, descending=True)
    eigenvectors = eigenvectors[:, sorted_indices]

    # Select the pcs
    return eigenvectors[:, offset:offset+k]

In [53]:
# First two principal components of the activations currently held in `acts`.
pcs = get_pcs(acts, k=2, offset=0)

In [60]:
# Per-layer principal components, appended in the order of `layers`.
pcs_layer = []
# compute pcs for each layer
for layer in tqdm(layers):
    # collect_acts centers the activations by default; get_pcs centers them again.
    acts = collect_acts('cities', '7B', layer).to(device)
    pcs = get_pcs(acts, k=2, offset=0)
    pcs_layer.append(pcs)

100%|██████████| 32/32 [00:31<00:00,  1.02it/s]


In [61]:
# Per-layer projections of the 'cities' activations onto that layer's two components.
proj_layer = []
# project each layer activations to pcs
for position, layer in enumerate(tqdm(layers)):
    acts = collect_acts('cities', '7B', layer).to(device)
    # pcs_layer was filled in iteration order, so index it by position rather than
    # by layer id (the two only coincide when `layers` starts at 0 with step 1).
    proj = acts @ pcs_layer[position]
    proj_layer.append(proj)

100%|██████████| 32/32 [00:01<00:00, 27.30it/s]


In [64]:
# Stack the per-layer projections, with layer as the first dimension.
proj_layer_t = t.stack(proj_layer, dim=0)

In [76]:
# Sanity check: shape of the stacked projections.
proj_layer_t.shape

torch.Size([32, 1496, 2])

In [82]:
# Start from an empty frame.
df = pd.DataFrame()

In [71]:
import numpy as np

In [72]:
# NOTE(review): `ntensor_to_long` is not defined or imported in the visible part of
# this notebook, so this cell raises NameError on a fresh kernel unless it is
# provided elsewhere. It also rebinds `df`; confirm whether this cell is still needed.
df = ntensor_to_long(proj_layer_t, dim_names=['layer', 'datapoint', 'pc'])

In [83]:
# One row per (layer, datapoint): flattening iterates datapoints fastest, so each
# layer id is repeated once per datapoint. Sizes come from the tensor itself
# instead of being hard-coded.
n_layers, n_points, _ = proj_layer_t.shape
df["pc1"] = proj_layer_t[:, :, 0].flatten().cpu().numpy()
df["pc2"] = proj_layer_t[:, :, 1].flatten().cpu().numpy()
df["layer"] = np.repeat(np.arange(n_layers), n_points)

In [84]:
# Sanity check: frame dimensions and column names.
df.shape, df.columns

((47872, 3), Index(['pc1', 'pc2', 'layer'], dtype='object'))

In [86]:
# Labels for the 'cities' statements, the dataset whose activations were projected.
# The dataset is named explicitly instead of relying on a loop variable left over
# from an earlier cell.
labels = t.Tensor(pd.read_csv(f'{ROOT}/datasets/cities.csv')['label'].tolist())

In [92]:
# Repeat the per-datapoint labels once per layer so they line up with the
# layer-major row order; the layer count is taken from the projection tensor.
df['label'] = np.tile(labels.cpu().numpy(), proj_layer_t.shape[0])

In [93]:
# Scratch check of np.tile's ordering: the whole sequence 0..4 is repeated 32 times.
np.tile(np.arange(5), 32)

array([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1,
       2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3,
       4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0,
       1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2,
       3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4,
       0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1,
       2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3,
       4, 0, 1, 2, 3, 4])

In [99]:
# Scatter of pc1 against pc2, one animation frame per layer, coloured by label.
fig = px.scatter(df, x="pc1", y="pc2", animation_frame="layer", color="label", height=600, width=600)
# Write the figure to an HTML file in the current working directory.
fig.write_html("pc2_llama2-7b-cities.html")

In [96]:
# Sanity check: shape of the stacked projections.
proj_layer_t.shape

torch.Size([32, 1496, 2])

In [98]:
# Sanity check: shape of the stacked per-layer directions.
layer_directions_t.shape

torch.Size([32, 4096])