In [10]:
import json
from itertools import islice

import numpy as np
import torch as t
import webdataset as wds
import datasets

from thought_transfer.sparse_recoder import SparseRecoder

In [5]:
# Checkpoint ids (file stems under ../outputs/models) for the two SAEs and the
# translation model; all three share the same timestamp suffix.
sae1_id, sae2_id, trans_id = (
    'sae1_20240831_015411',
    'sae2_20240831_015411',
    'translate_20240831_015411',
)

In [33]:
def load_model(model_id, device='cuda', models_dir='../outputs/models'):
    """Build a ``SparseRecoder`` from its saved config and weights.

    Reads ``{models_dir}/{model_id}.json`` for the constructor keyword
    arguments and ``{models_dir}/{model_id}.pth`` for the state dict.

    Args:
        model_id: Checkpoint file stem, e.g. ``'sae1_20240831_015411'``.
        device: Device the loaded model is moved to (default ``'cuda'``).
        models_dir: Directory holding the ``.json`` / ``.pth`` pair.

    Returns:
        The ``SparseRecoder`` with loaded weights, placed on ``device``.
    """
    # `with` closes the config file; json.load(open(...)) leaked the handle.
    with open(f'{models_dir}/{model_id}.json', 'r') as config_file:
        config = json.load(config_file)
    model = SparseRecoder(**config)
    # Deserialize onto CPU so loading does not depend on the device the
    # checkpoint was saved from; .to(device) below does the final placement.
    state_dict = t.load(f'{models_dir}/{model_id}.pth', map_location='cpu')
    model.load_state_dict(state_dict)
    return model.to(device)

In [34]:
# Restore the three checkpoints, in the same order as their ids are listed.
sae1, sae2, trans = (load_model(model_id) for model_id in (sae1_id, sae2_id, trans_id))

In [35]:
# Display the first SAE's module repr to sanity-check its layer sizes.
sae1

SparseRecoder(
  (encoder): Linear(in_features=768, out_features=6144, bias=True)
  (decoder): Linear(in_features=6144, out_features=768, bias=True)
)

In [36]:
from thought_transfer.train import unpack, get_input_output

In [37]:
def explained_variance(pred, target, dim=0, eps=1e-8):
    """Mean per-feature fraction of ``target``'s variance explained by ``pred``.

    For every feature (each index not along ``dim``) this computes
    ``1 - Var(pred - target) / (Var(target) + eps)``, with variances taken
    along ``dim``, and then averages the result over all features.

    Args:
        pred: Prediction / reconstruction tensor, same shape as ``target``.
        target: Reference tensor.
        dim: Axis holding the samples over which variance is measured.
            Defaults to 0, i.e. the rows of a ``(samples, features)`` matrix.
        eps: Added to the target variance so constant features do not
            divide by zero.

    Returns:
        0-d tensor. 1 means a perfect fit (up to a constant offset); values
        below 0 mean the residual varies more than the target itself.

    Note:
        ``t.var`` uses the unbiased (n - 1) estimator, so fewer than two
        samples along ``dim`` produces NaN.
    """
    residual_var = t.var(pred - target, dim=dim)
    target_var = t.var(target, dim=dim)
    return t.mean(1 - residual_var / (target_var + eps))

In [38]:
from tqdm.notebook import tqdm

In [42]:
# Split indices for the wiki data. They are loaded but not used: the
# evaluation below streams the WebDataset shard directly.
inds = np.load("../data/wiki_split.npz")
data = wds.WebDataset("../data/wiki_sentences_100000_200000_20240830_225314.tar")

# Translation model translates OPT -> GPT2
# NOTE: there is a bug in naming the dataset columns, so the "opt"/"gpt2"
# labels returned by `unpack` may be misleading — TODO confirm which is which.
num_samples = 1000  # maximum number of sentences to evaluate

sum_explained_variance_1 = 0.0
sum_explained_variance_2 = 0.0
sum_explained_variance_trans = 0.0
num_evaluated = 0

# Evaluation only: under no_grad the forward passes build no autograd graph.
# Without it, every var_* tensor carries a grad_fn, and the running sums keep
# all of those graphs alive, so GPU memory grows with each sample.
with t.no_grad():
    for sample in tqdm(islice(data, num_samples), total=num_samples):
        text, resid_opt, resid_gpt2 = unpack(sample)
        # first dim is 1; drop it
        resid_opt = t.tensor(resid_opt[0], device='cuda')
        resid_gpt2 = t.tensor(resid_gpt2[0], device='cuda')

        # Element [1] of each model's output is scored against the target;
        # presumably the reconstruction / translation — TODO confirm in SparseRecoder.
        var_1 = explained_variance(sae1(resid_opt)[1], resid_opt)
        var_2 = explained_variance(sae2(resid_gpt2)[1], resid_gpt2)
        var_trans = explained_variance(trans(resid_opt)[1], resid_gpt2)

        # Accumulate plain Python floats so the sums hold no GPU tensors.
        sum_explained_variance_1 += var_1.item()
        sum_explained_variance_2 += var_2.item()
        sum_explained_variance_trans += var_trans.item()
        num_evaluated += 1

# Average over the samples actually read; the shard may hold fewer than num_samples.
if num_evaluated == 0:
    raise RuntimeError("No samples were read from the dataset.")

print(f"Explained variance for OPT: {sum_explained_variance_1 / num_evaluated}")
print(f"Explained variance for GPT2: {sum_explained_variance_2 / num_evaluated}")
print(f"Explained variance for OPT -> GPT2: {sum_explained_variance_trans / num_evaluated}")

  0%|          | 0/1000 [00:00<?, ?it/s]

Explained variance for OPT: -0.36211150884628296
Explained variance for GPT2: 0.4416324496269226
Explained variance for OPT -> GPT2: 0.46607792377471924
