In [1]:
import torch
from dataclasses import dataclass
from sae import TopKSAE, TopKSAEConfig
from generate_residuals import EmbeddingGeneratorConfig
import os
from tqdm import tqdm

device= 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)
assert device == 'cuda', "This notebook is not optimized for CPU"


In [2]:

embdconfig = torch.load("./small_residuals/0.pt")['config']

In [3]:
embdconfig

EmbeddingGeneratorConfig(batch_size=512, block_size=512, n_embd=384, ratio_tokens_saved=0.07, residual_layer=8, mb_per_save=2000, save_dir='./small_residuals')

In [7]:
embdconfig

saeconfig = TopKSAEConfig(
    embedding_size=embdconfig.n_embd,
    n_features=32768//2,
    topk=24,
    lr = 1e-3,
    batch_size=4096,
    latent_bias=False,
)

In [8]:
sae = TopKSAE(saeconfig)

In [9]:
residuals_dir = embdconfig.save_dir
ratio_train=0.9

residuals_files = [os.path.join(residuals_dir, f) for f in os.listdir(residuals_dir)]

train_files = residuals_files[:int(len(residuals_files)*ratio_train)]
test_files = residuals_files[int(len(residuals_files)*ratio_train):]

print(f"Training on {len(train_files)} files")
print(f"Testing on {len(test_files)} files")

Training on 11 files
Testing on 2 files


In [10]:
optimizer = torch.optim.Adam(sae.parameters(), lr=sae.config.lr)

# Run your forward pass
for f in train_files:
    print(f"Loading {f}")
    data = torch.load(f)['residuals']
    data = data[torch.randperm(data.shape[0])]
    data = data.to(torch.float32)
    # b a t c h data
    print(data.shape)
    data = data[:data.shape[0]//sae.config.batch_size*sae.config.batch_size] # cut off the last (incomplete) batch
    data = data.view(-1, sae.config.batch_size, sae.config.embedding_size)
    for batch in tqdm(data):
        sae_out = sae(batch)
        loss = sae_out['mse']
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    print(loss.item())



Loading ./small_residuals/12.pt
torch.Size([1394600, 384])


100%|██████████| 340/340 [00:03<00:00, 106.94it/s]


0.5932797193527222
Loading ./small_residuals/11.pt
torch.Size([2605700, 384])


100%|██████████| 636/636 [00:05<00:00, 113.03it/s]


0.40341418981552124
Loading ./small_residuals/10.pt
torch.Size([2605700, 384])


100%|██████████| 636/636 [00:05<00:00, 114.07it/s]


0.36649155616760254
Loading ./small_residuals/9.pt
torch.Size([2605700, 384])


100%|██████████| 636/636 [00:05<00:00, 113.25it/s]


0.35675835609436035
Loading ./small_residuals/8.pt
torch.Size([2605700, 384])


100%|██████████| 636/636 [00:05<00:00, 113.24it/s]


0.34561464190483093
Loading ./small_residuals/7.pt
torch.Size([2605700, 384])


100%|██████████| 636/636 [00:05<00:00, 113.48it/s]


0.34066280722618103
Loading ./small_residuals/6.pt
torch.Size([2605700, 384])


100%|██████████| 636/636 [00:05<00:00, 113.58it/s]


0.33527040481567383
Loading ./small_residuals/5.pt
torch.Size([2605700, 384])


100%|██████████| 636/636 [00:05<00:00, 113.46it/s]


0.32978227734565735
Loading ./small_residuals/4.pt
torch.Size([2605700, 384])


100%|██████████| 636/636 [00:05<00:00, 112.24it/s]


0.3294694423675537
Loading ./small_residuals/3.pt
torch.Size([2605700, 384])


100%|██████████| 636/636 [00:05<00:00, 113.52it/s]


0.3307689428329468
Loading ./small_residuals/2.pt
torch.Size([2605700, 384])


100%|██████████| 636/636 [00:05<00:00, 113.34it/s]

0.3219376504421234





In [None]:
torch.save({'model': sae.state_dict(),
            'config': sae.config}, 
            'saes/sae_small_16k_features.pt')
           