In [1]:
import torch
from dataclasses import dataclass
from sae import TopKSAE, TopKSAEConfig
from generate_residuals import EmbeddingGeneratorConfig
import os
from tqdm import tqdm

device= 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)
assert device == 'cuda', "This notebook is not optimized for CPU"


block_size: 512
n_layer: 14
n_head: 16
n_embd: 512
feed_forward_factor: 2.5
vocab_size: 8192
data_dir: dataset
expt_name: restart_good_3hr_search
batch_size: 128
max_lr: 0.002
min_lr: 0.0001
beta_1: 0.9
beta_2: 0.99
warmup_steps: 50
max_steps: 60000
max_runtime_seconds: 10800
weight_decay: 0.12
need_epoch_reshuffle: True
matmul_precision: high
smoke_test: False


In [2]:

embdconfig = torch.load("./residuals/0.pt")['config']

In [3]:
embdconfig

EmbeddingGeneratorConfig(batch_size=512, block_size=512, n_embd=512, ratio_tokens_saved=0.07, residual_layer=10, mb_per_save=2000, save_dir='./residuals/')

In [4]:
embdconfig

saeconfig = TopKSAEConfig(
    embedding_size=embdconfig.n_embd,
    n_features=32768,
    topk=24,
    lr = 1e-3,
    batch_size=4096,
)

In [5]:
sae = TopKSAE(saeconfig)

In [6]:
residuals_dir = embdconfig.save_dir
ratio_train=0.9

residuals_files = [os.path.join(residuals_dir, f) for f in os.listdir(residuals_dir)]

train_files = residuals_files[:int(len(residuals_files)*ratio_train)]
test_files = residuals_files[int(len(residuals_files)*ratio_train):]

print(f"Training on {len(train_files)} files")
print(f"Testing on {len(test_files)} files")

Training on 14 files
Testing on 2 files


In [7]:
optimizer = torch.optim.Adam(sae.parameters(), lr=sae.config.lr)

# Run your forward pass
for f in train_files:
    print(f"Loading {f}")
    data = torch.load(f)['residuals']
    data = data[torch.randperm(data.shape[0])]
    data = data.to(torch.float32)
    # b a t c h data
    print(data.shape)
    data = data[:data.shape[0]//sae.config.batch_size*sae.config.batch_size] # cut off the last (incomplete) batch
    data = data.view(-1, sae.config.batch_size, sae.config.embedding_size)
    for batch in tqdm(data):
        sae_out = sae(batch)
        loss = sae_out['mse']
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    print(loss.item())



Loading ./residuals/15.pt


torch.Size([1963450, 512])


100%|██████████| 479/479 [00:13<00:00, 35.75it/s]


0.3685932159423828
Loading ./residuals/14.pt
torch.Size([1963450, 512])


100%|██████████| 479/479 [00:13<00:00, 35.63it/s]


0.3354792892932892
Loading ./residuals/13.pt
torch.Size([1963450, 512])


100%|██████████| 479/479 [00:13<00:00, 35.34it/s]


0.32165712118148804
Loading ./residuals/12.pt
torch.Size([1963450, 512])


100%|██████████| 479/479 [00:13<00:00, 35.12it/s]


0.31672143936157227
Loading ./residuals/11.pt
torch.Size([1963450, 512])


100%|██████████| 479/479 [00:13<00:00, 34.93it/s]


0.3071986138820648
Loading ./residuals/10.pt
torch.Size([1963450, 512])


100%|██████████| 479/479 [00:13<00:00, 34.73it/s]


0.3015247583389282
Loading ./residuals/9.pt
torch.Size([1963450, 512])


100%|██████████| 479/479 [00:13<00:00, 34.45it/s]


0.30244624614715576
Loading ./residuals/8.pt
torch.Size([1963450, 512])


100%|██████████| 479/479 [00:14<00:00, 34.04it/s]


0.30095475912094116
Loading ./residuals/7.pt
torch.Size([1963450, 512])


100%|██████████| 479/479 [00:13<00:00, 34.44it/s]


0.2984688878059387
Loading ./residuals/6.pt
torch.Size([1963450, 512])


100%|██████████| 479/479 [00:13<00:00, 34.41it/s]


0.2971307039260864
Loading ./residuals/5.pt
torch.Size([1963450, 512])


100%|██████████| 479/479 [00:13<00:00, 34.39it/s]


0.2946024239063263
Loading ./residuals/4.pt
torch.Size([1963450, 512])


100%|██████████| 479/479 [00:13<00:00, 34.30it/s]


0.29326707124710083
Loading ./residuals/3.pt
torch.Size([1963450, 512])


100%|██████████| 479/479 [00:13<00:00, 34.25it/s]


0.293413907289505
Loading ./residuals/2.pt
torch.Size([1963450, 512])


100%|██████████| 479/479 [00:14<00:00, 34.19it/s]

0.2851043939590454





In [8]:
torch.save({'model': sae.state_dict(),
            'config': sae.config}, 
            'saes/sae.pt')
           