In [1]:
import json
from pathlib import Path

import numpy as np
import torch as t
import webdataset as wds

from thought_transfer.sparse_recoder import SparseRecoder   

In [2]:
# Load a trained SparseRecoder: constructor kwargs come from the JSON sidecar,
# weights from the .pth state dict with the same id.
MODEL_DIR = Path("../outputs/models")
model_id = "v1_sae1_100000_20240830231731"
config = json.loads((MODEL_DIR / f"{model_id}.json").read_text())
model = SparseRecoder(**config)
# weights_only=True restricts unpickling to tensors/plain containers (no arbitrary
# code execution) and silences the FutureWarning newer torch emits without it.
# map_location="cpu" avoids first restoring tensors onto whatever device they were
# saved from; the model itself is still on CPU at this point anyway.
state_dict = t.load(MODEL_DIR / f"{model_id}.pth", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
# This notebook only runs inference, so put the module in eval mode (disables any
# train-only behaviour such as dropout, if SparseRecoder has any — not visible here).
model = model.cuda().eval()

In [14]:
# Show the keyword arguments the SparseRecoder above was constructed with.
config

{'d_in': 768, 'd_out': 768, 'd_hidden': 6144, 'k': 5, 'add_topk': True}

In [3]:
# Baseline: reconstruct uniform-random [0, 1) inputs. These are presumably far
# from the activations the model was trained on, so a large MSE is expected here.
rng = np.random.default_rng(0)  # fixed seed so the printed numbers are reproducible
values = t.tensor(rng.random((5, config["d_in"])), dtype=t.float32).cuda()
with t.no_grad():  # evaluation only; no autograd graph needed
    enc, recons = model(values)
print(enc.shape, recons.shape)
print(t.mean((values - recons) ** 2))

torch.Size([5, 6144]) torch.Size([5, 768])
tensor(19.8609, device='cuda:0', grad_fn=<MeanBackward0>)


In [9]:
from thought_transfer.train import unpack, get_input_output

In [21]:
# Pull the first record from the WebDataset shard and build the "sae1" model input.
DATASET_PATH = "../data/wiki_sentences_10000_60000_20240830_202435.tar"
dataset_iter = iter(wds.WebDataset(DATASET_PATH))

text, resid1, resid2 = unpack(next(dataset_iter))
raw_in, data_out = get_input_output(resid1, resid2, "sae1")
# Take item 0 on the host *before* moving to the GPU. The previous
# `t.tensor(raw_in).cuda()[0]` uploaded the whole array and kept all of it alive
# on the GPU, because the indexed result is a view into it.
# Assumes axis 0 of raw_in is a batch axis — TODO confirm against get_input_output.
data_in = t.as_tensor(raw_in[0], dtype=t.float32, device="cuda")
data_in.shape

torch.Size([134, 768])

In [23]:
# Reconstruct the real activations and report the overall reconstruction MSE.
with t.no_grad():  # evaluation only; avoids building/retaining an autograd graph
    enc, recons = model(data_in)
print(enc.shape, recons.shape)
print(t.mean((data_in - recons) ** 2))

torch.Size([134, 6144]) torch.Size([134, 768])
tensor(0.0066, device='cuda:0', grad_fn=<MeanBackward0>)


In [26]:
# Reconstruction MSE for each row separately (mean over the feature axis, dim 1).
(data_in - recons).square().mean(dim=1)

tensor([0.0002, 0.0004, 0.0015, 0.0025, 0.0047, 0.0091, 0.0106, 0.0055, 0.0041,
        0.0074, 0.0036, 0.0102, 0.0091, 0.0046, 0.0053, 0.0030, 0.0047, 0.0027,
        0.0069, 0.0072, 0.0036, 0.0012, 0.0021, 0.0036, 0.0027, 0.0024, 0.0024,
        0.0029, 0.0055, 0.0061, 0.0060, 0.0107, 0.0052, 0.0029, 0.0101, 0.0114,
        0.0048, 0.0074, 0.0076, 0.0096, 0.0100, 0.0117, 0.0097, 0.0042, 0.0043,
        0.0044, 0.0112, 0.0036, 0.0097, 0.0108, 0.0088, 0.0041, 0.0062, 0.0096,
        0.0104, 0.0079, 0.0040, 0.0012, 0.0051, 0.0039, 0.0122, 0.0127, 0.0050,
        0.0043, 0.0046, 0.0084, 0.0043, 0.0035, 0.0095, 0.0119, 0.0059, 0.0043,
        0.0063, 0.0036, 0.0100, 0.0087, 0.0093, 0.0060, 0.0099, 0.0075, 0.0044,
        0.0079, 0.0056, 0.0114, 0.0110, 0.0037, 0.0020, 0.0026, 0.0094, 0.0088,
        0.0119, 0.0059, 0.0048, 0.0055, 0.0022, 0.0009, 0.0120, 0.0112, 0.0149,
        0.0143, 0.0085, 0.0040, 0.0069, 0.0048, 0.0096, 0.0063, 0.0049, 0.0044,
        0.0067, 0.0050, 0.0029, 0.0028, 

In [27]:
# Count non-zero latents per row; with the top-k encoder (config k=5) this is
# expected to be at most 5 for every row.
(enc != 0).sum(dim=1)

tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5], device='cuda:0')

In [29]:
# Decode only the first row's sparse code; the slice keeps a leading batch dim of 1.
model.decoder(enc[:1])

tensor([[ 3.4579e-01, -8.5966e-01,  7.9997e-01,  5.7606e-01,  5.1375e-01,
          7.1244e-01,  3.1541e-01,  7.1562e-01,  6.6337e-01,  4.0874e-01,
          5.3442e-01,  2.3515e-01,  3.3904e-01,  2.5411e-01,  5.7630e-01,
          4.4105e-01,  2.4718e-01,  5.5607e-01,  4.9332e-01,  3.4102e-01,
          7.2688e-01,  3.2475e-01,  4.4396e-01,  4.8632e-01,  6.0079e-01,
          6.8202e-01,  5.6782e-01,  6.5708e-01,  3.6430e-01,  2.5520e-01,
          3.9471e-01,  6.6204e-01,  8.1521e-01,  6.2110e-01,  4.5968e-01,
          7.0670e-01,  4.4246e-01,  6.2241e-01,  2.7467e-01,  4.7341e-01,
          5.6311e-01,  6.7837e-01,  4.5122e-01,  2.7583e-01,  4.7569e-01,
          6.9155e-01,  3.7528e-01, -7.7473e-02,  5.0323e-01,  4.0237e-01,
          5.4091e-01,  4.6960e-01,  4.0502e-01,  6.0138e-01,  3.5814e-01,
          6.3176e-01,  6.2696e-01,  5.0988e-01,  6.8141e-01,  2.7431e-01,
          2.5397e-01,  2.2945e-01,  4.5011e-01,  1.8967e-01,  4.4073e-01,
          6.3364e-01,  2.9243e-01,  5.

In [30]:
# The original first row, shaped (1, d_in), for comparison with the decoded row.
data_in[:1]

tensor([[ 3.4905e-01, -8.6757e-01,  7.9859e-01,  5.2501e-01,  5.3448e-01,
          6.9719e-01,  3.1645e-01,  7.0870e-01,  6.5303e-01,  4.1587e-01,
          5.3241e-01,  2.3489e-01,  3.4152e-01,  2.5828e-01,  5.8796e-01,
          4.4585e-01,  2.5482e-01,  5.7225e-01,  4.9722e-01,  3.5390e-01,
          7.3329e-01,  3.3206e-01,  4.3404e-01,  4.7741e-01,  6.1378e-01,
          6.8703e-01,  5.6901e-01,  6.5461e-01,  3.6051e-01,  2.5420e-01,
          4.0583e-01,  6.7008e-01,  8.1667e-01,  6.2020e-01,  4.7169e-01,
          7.0074e-01,  4.4300e-01,  6.2317e-01,  2.7435e-01,  4.5343e-01,
          5.7283e-01,  6.6590e-01,  4.5660e-01,  2.7317e-01,  4.7113e-01,
          6.9130e-01,  3.7925e-01, -7.2070e-02,  5.0837e-01,  4.0746e-01,
          5.3928e-01,  4.8130e-01,  4.0935e-01,  6.0876e-01,  3.5426e-01,
          6.3735e-01,  6.1800e-01,  5.3834e-01,  6.7435e-01,  2.9438e-01,
          2.4602e-01,  2.3112e-01,  4.4872e-01,  2.0129e-01,  4.3110e-01,
          6.3825e-01,  2.8465e-01,  6.

In [31]:
import matplotlib.pyplot as plt

In [34]:
# Compute the singular values of data_in.
# torch.svd is deprecated. torch.linalg.svd with full_matrices=False is the same
# reduced SVD, but it returns V^H instead of V, so convert back to keep `v`'s meaning.
u, s, vh = t.linalg.svd(data_in, full_matrices=False)
v = vh.mH
print(s)

tensor([376.1247,  20.0699,   9.1911,   8.3967,   8.3428,   7.4547,   7.1551,
          6.6116,   6.3364,   6.1646,   5.8336,   5.6752,   5.6347,   5.5108,
          5.3415,   5.2723,   5.0635,   4.8574,   4.7790,   4.7002,   4.5621,
          4.5043,   4.4368,   4.3857,   4.3322,   4.2206,   4.1496,   4.0738,
          3.9974,   3.9320,   3.9055,   3.7816,   3.7497,   3.6504,   3.5981,
          3.5780,   3.5608,   3.4645,   3.4079,   3.3928,   3.3253,   3.2673,
          3.2560,   3.1897,   3.1302,   3.0906,   3.0568,   2.9792,   2.9650,
          2.9250,   2.8848,   2.8147,   2.8031,   2.7763,   2.7605,   2.7350,
          2.6820,   2.6312,   2.6063,   2.5770,   2.5603,   2.5293,   2.5168,
          2.4336,   2.4114,   2.3899,   2.3791,   2.3695,   2.3541,   2.3295,
          2.2995,   2.2688,   2.2496,   2.2162,   2.1915,   2.1300,   2.0999,
          2.0933,   2.0602,   2.0362,   1.9842,   1.9588,   1.9362,   1.8889,
          1.8756,   1.8500,   1.8291,   1.7684,   1.7583,   1.72

The singular-value spectrum is very non-uniform: the top singular value (~376) is about 19× the next one (~20). This is probably why the reconstruction is so close.

In [40]:
# Recompute the SVD without row 0 to check whether that single row is what
# produces the dominant singular value in the full spectrum.
# (torch.linalg.svd replaces the deprecated torch.svd; V = Vh^H keeps `v` unchanged.)
u, s, vh = t.linalg.svd(data_in[1:], full_matrices=False)
v = vh.mH

In [41]:
# Singular values of data_in with row 0 excluded, in descending order.
s

tensor([20.3406,  9.1911,  8.3981,  8.3428,  7.4550,  7.1599,  6.6124,  6.3366,
         6.1646,  5.8353,  5.6764,  5.6369,  5.5108,  5.3415,  5.2723,  5.0635,
         4.8575,  4.7816,  4.7004,  4.5624,  4.5056,  4.4369,  4.3859,  4.3330,
         4.2206,  4.1499,  4.0750,  3.9979,  3.9322,  3.9055,  3.7847,  3.7498,
         3.6505,  3.5984,  3.5780,  3.5614,  3.4645,  3.4085,  3.3928,  3.3255,
         3.2678,  3.2560,  3.1903,  3.1303,  3.0907,  3.0568,  2.9792,  2.9671,
         2.9265,  2.8853,  2.8150,  2.8031,  2.7763,  2.7607,  2.7350,  2.6824,
         2.6312,  2.6065,  2.5772,  2.5608,  2.5307,  2.5169,  2.4336,  2.4121,
         2.3906,  2.3791,  2.3698,  2.3541,  2.3296,  2.2996,  2.2689,  2.2500,
         2.2201,  2.1915,  2.1300,  2.0999,  2.0933,  2.0602,  2.0363,  1.9843,
         1.9592,  1.9367,  1.8892,  1.8759,  1.8500,  1.8296,  1.7693,  1.7583,
         1.7274,  1.7156,  1.6686,  1.6547,  1.6397,  1.6212,  1.6035,  1.5895,
         1.5574,  1.5212,  1.4969,  1.46

In [None]:
# Plot the current singular-value spectrum `s`, using explicit axes and labels
# so the figure stands on its own.
fig, ax = plt.subplots()
ax.plot(s.detach().cpu().numpy(), marker=".")
ax.set(
    title="Singular values of data_in (row 0 excluded)",
    xlabel="index (descending order)",
    ylabel="singular value",
)
plt.show()