## Set up dependencies.

In [1]:
!pip install transformers torch torchinfo tqdm



In [2]:
from pathlib import Path
from tqdm import tqdm
import os
from datasets import load_dataset, Dataset, IterableDataset

## Download the Dolma v1.6_sample dataset (~16GB)

In [3]:
data = load_dataset('allenai/dolma', name='v1_6-sample', trust_remote_code=True, streaming=True, split="train")

## Load phi-2 and set up a method that gets the middle layer's activations.

In [4]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)

phi = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype=torch.float16, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)

2024-07-25 17:36:31.821871: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-07-25 17:36:31.859638: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
# Suppress warnings about tokenising long passages. 
# We're going to handle splitting the text ourselves.
import sys
tokenizer.model_max_length = sys.maxsize

In [6]:
def tokenize(text):
    return tokenizer(text, return_tensors="pt", return_attention_mask=False)['input_ids']

In [7]:
@torch.no_grad()
def get_activations(tokens):
    """
    Takes a tensor of shape e.g. (B, T) containing a batch of token strings,
    and outputs a tensor of shape (B, T, 2560) containing the activations
    of the middle layer (hidden layer 15) at each token.
    """
    out = phi(tokens, output_hidden_states=True, return_dict=True)
    hidden_states = out.hidden_states
    middle_layer = hidden_states[16]
    return middle_layer.float()

In [8]:
def get_token_blocks(text, block_size):
    """
    Given a text, tokenizes it and returns a tensor of shape 
    (N, BS) containing the text as a series of blocks of 
    tokens of size BS.
    N is the number of blocks contained within the text.   
    """    
    block_size = 100
    tokens = tokenize(text).view(-1)
    truncated_length = (len(tokens) // block_size) * block_size
    tokens = tokens[:truncated_length]
    block_tensor = tokens.view(-1, block_size)
    blocks = torch.split(block_tensor, 1, dim=0)    
    return [b.view(-1) for b in blocks]

In [9]:
def all_blocks(split, block_size=100):    
    for i, example in enumerate(data):
        is_test_example = i % 10 == 0
        if split == 'train' and is_test_example or split != 'train' and not is_test_example:
            continue
        for block in get_token_blocks(example['text'], block_size=block_size):                        
            yield block

In [10]:
def batch_iterable(iterable, batch_size):
    current_batch = []
    for item in iterable:
        current_batch.append(item)
        if len(current_batch) == batch_size:
            yield current_batch
            current_batch = []


def all_activations(split, block_size=100, minibatch_size=5, start_block=0, parallelism=10, workahead_cache_size=50):
    """
    Generator yielding minibatches of training examples.
    Each example is returned as a tuple (tokens, activations).
     - tokens is a tensor of dimention (MB, BS) where MB 
       is the specified minibatch size, and BS is the block size.
     - activations is a tensor of dimension (MB, BS, 2560)
    """    
    global phi
    assert workahead_cache_size % parallelism == 0
    assert workahead_cache_size % minibatch_size == 0
    assert parallelism % minibatch_size == 0
    blocks = islice(all_blocks(split, block_size=block_size), start_block, None)
    for block_megabatch in batch_iterable(blocks, workahead_cache_size):  
        phi = phi.to("cuda")
        minibatches = []        
        for inputs in torch.stack(block_megabatch).split(parallelism):            
            activations = get_activations(inputs)      
            minibatch_inputs = inputs.split(minibatch_size)
            minibatch_activations = activations.split(minibatch_size)
            minibatches.extend(zip(minibatch_inputs, minibatch_activations))
        phi = phi.to("cpu")
        
        for minibatch_input, minibatch_activation in minibatches:
            yield minibatch_input, minibatch_activation

## Now define our autoencoder

In [11]:
class SparseAutoencoder(torch.nn.Module):
    def __init__(self, activation_dimension=2560, inner_dimension=100_000):
        super().__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(activation_dimension, inner_dimension),
            torch.nn.ReLU()
        )
        self.decoder = torch.nn.Linear(inner_dimension, activation_dimension)

    def forward(self, activations):
        encoded = self.encoder(activations)
        decoded = self.decoder(encoded)
        return decoded, encoded  # Return encoded too, as it's used in the loss fn.

In [12]:
if Path('sae.pt').exists():
    sae = torch.load('sae.pt').to(device)
    is_trained = True
else:
    sae = SparseAutoencoder()
    is_trained = False

In [13]:
from torch.nn import functional as F
def calculate_loss(sae_model, prediction, target, feature_activations, lamb=5):
    weight_norms = torch.norm(sae_model.decoder.weight, dim=0, p=2).view(-1)
    feature_sizes = torch.abs(feature_activations)
    sparsity_loss = lamb * torch.sum(weight_norms * feature_sizes)

    prediction_loss = F.mse_loss(prediction, target, reduction='sum')

    total_loss = prediction_loss + sparsity_loss
    
    return total_loss, prediction_loss, sparsity_loss    

In [14]:
t = torch.cuda.get_device_properties(0).total_memory
r = torch.cuda.memory_reserved(0)
a = torch.cuda.memory_allocated(0)
f = r-a  # free inside reserved
print(t, r, a, f)

15655829504 7822376960 7616183296 206193664


## Train our autoencoder on 1M tokens' worth of activations.

In [15]:
from itertools import islice

@torch.no_grad()
def get_eval_loss(sae_model, minibatches, tokens_per_minibatch):
    losses = []
    sparsities = []
    accuracies = []
    for tokens, activations in minibatches:
        decoded, encoded = sae_model(activations)
        loss, accuracy, sparsity = calculate_loss(sae_model, decoded, activations, encoded)
        losses.append(loss.item() / tokens_per_minibatch)
        accuracies.append(accuracy.item() / tokens_per_minibatch)        
        sparsities.append(sparsity.item() / tokens_per_minibatch)        
    return sum(losses) / len(losses), sum(accuracies) / len(accuracies), sum(sparsities) / len(sparsities)

def train(sae_model, num_tokens=1000000, minibatch_size=5, block_size=100, loss_history=100, adam_beta1=0.9, adam_beta2=0.999, lr=5E-5, eval_tokens=1000, show_eval_token_interval=25000, start_block=0):            
    optimizer = torch.optim.AdamW(sae_model.parameters(), lr=lr, betas=(adam_beta1, adam_beta2))
    tokens_per_minibatch = minibatch_size * block_size
    num_minibatches = num_tokens // tokens_per_minibatch
    eval_minibatches = eval_tokens // tokens_per_minibatch
    losses = []
    sparsities = []
    activations_gen = all_activations('train', minibatch_size=minibatch_size, block_size=block_size + start_block)
    test_activations_gen = all_activations('test', minibatch_size=minibatch_size, block_size=block_size, start_block=1000+start_block)    
    
    training_data = islice(activations_gen, num_minibatches)         
    show_eval_minibatch_interval = show_eval_token_interval // tokens_per_minibatch

    pbar = tqdm(training_data, total=num_minibatches, smoothing=0)
    for idx, (tokens, activations) in enumerate(pbar):        
        decoded, encoded = sae_model(activations)        
        loss, accuracy, sparsity = calculate_loss(sae_model, decoded, activations, encoded)        
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()        
        losses.append(loss.item() / tokens_per_minibatch)
        sparsities.append(sparsity.item() / tokens_per_minibatch)
        if len(losses) > loss_history:
            losses = losses[1:]
            sparsities = sparsities[1:]
        avg_loss = sum(losses) / len(losses)
        avg_sparsity = sum(sparsities) / len(sparsities)
        tokens_seen = (idx + 1) * tokens_per_minibatch
        pbar.set_description(f'Tokens seen: {tokens_seen}; Loss: {avg_loss:.2f}; Sparsity-specific loss: {avg_sparsity:.2f}')
        if (idx + 1) % show_eval_minibatch_interval == 0:
            minibatches = islice(test_activations_gen, eval_minibatches)
            eval_loss, _, eval_sparsity = get_eval_loss(sae_model, minibatches, tokens_per_minibatch)
            print(f'Seen {tokens_seen} tokens')
            print(f'\t - Training data: loss is {avg_loss:.2f}; sparsity loss is {avg_sparsity:.2f}')
            print(f'\t - Eval data: loss is {eval_loss:.2f}; sparsity loss is {eval_sparsity:.2f}')            
            torch.save(sae, 'sae.pt')

if not is_trained:
    train(sae)
    is_trained = True
    torch.save(sae, 'sae.pt')

Tokens seen: 25000; Loss: 20933.04; Sparsity-specific loss: 10754.98:   2%|▏         | 49/2000 [00:57<38:03,  1.17s/it] 

Seen 25000 tokens
	 - Training data: loss is 20933.04; sparsity loss is 10754.98
	 - Eval data: loss is 10815.54; sparsity loss is 7070.77


Tokens seen: 50000; Loss: 15219.29; Sparsity-specific loss: 8368.90:   5%|▍         | 99/2000 [01:57<37:27,  1.18s/it] 

Seen 50000 tokens
	 - Training data: loss is 15219.29; sparsity loss is 8368.90
	 - Eval data: loss is 8329.01; sparsity loss is 4630.19


Tokens seen: 75000; Loss: 8548.77; Sparsity-specific loss: 5178.10:   7%|▋         | 149/2000 [02:48<34:55,  1.13s/it] 

Seen 75000 tokens
	 - Training data: loss is 8548.77; sparsity loss is 5178.10
	 - Eval data: loss is 7096.80; sparsity loss is 3716.87


Tokens seen: 100000; Loss: 6943.79; Sparsity-specific loss: 3883.54:  10%|▉         | 199/2000 [03:39<33:09,  1.10s/it]

Seen 100000 tokens
	 - Training data: loss is 6943.79; sparsity loss is 3883.54
	 - Eval data: loss is 6099.36; sparsity loss is 2952.13


Tokens seen: 125000; Loss: 5782.35; Sparsity-specific loss: 2989.93:  12%|█▏        | 249/2000 [04:31<31:46,  1.09s/it]

Seen 125000 tokens
	 - Training data: loss is 5782.35; sparsity loss is 2989.93
	 - Eval data: loss is 5183.12; sparsity loss is 2226.75


Tokens seen: 150000; Loss: 4877.06; Sparsity-specific loss: 2269.96:  15%|█▍        | 299/2000 [05:22<30:33,  1.08s/it]

Seen 150000 tokens
	 - Training data: loss is 4877.06; sparsity loss is 2269.96
	 - Eval data: loss is 4446.63; sparsity loss is 1690.72


Tokens seen: 175000; Loss: 4229.23; Sparsity-specific loss: 1713.22:  17%|█▋        | 349/2000 [06:17<29:46,  1.08s/it]

Seen 175000 tokens
	 - Training data: loss is 4229.23; sparsity loss is 1713.22
	 - Eval data: loss is 3987.32; sparsity loss is 1196.26


Tokens seen: 200000; Loss: 3912.51; Sparsity-specific loss: 1337.31:  20%|█▉        | 399/2000 [07:09<28:42,  1.08s/it]

Seen 200000 tokens
	 - Training data: loss is 3912.51; sparsity loss is 1337.31
	 - Eval data: loss is 3676.87; sparsity loss is 864.29


Tokens seen: 225000; Loss: 3726.95; Sparsity-specific loss: 1054.77:  22%|██▏       | 449/2000 [08:00<27:40,  1.07s/it]

Seen 225000 tokens
	 - Training data: loss is 3726.95; sparsity loss is 1054.77
	 - Eval data: loss is 3650.05; sparsity loss is 714.87


Tokens seen: 250000; Loss: 3474.29; Sparsity-specific loss: 817.90:  25%|██▍       | 499/2000 [08:52<26:40,  1.07s/it] 

Seen 250000 tokens
	 - Training data: loss is 3474.29; sparsity loss is 817.90
	 - Eval data: loss is 3384.58; sparsity loss is 667.71


Tokens seen: 275000; Loss: 3225.81; Sparsity-specific loss: 679.14:  27%|██▋       | 549/2000 [09:44<25:43,  1.06s/it]

Seen 275000 tokens
	 - Training data: loss is 3225.81; sparsity loss is 679.14
	 - Eval data: loss is 3298.96; sparsity loss is 618.56


Tokens seen: 300000; Loss: 2892.85; Sparsity-specific loss: 610.35:  30%|██▉       | 599/2000 [10:39<24:56,  1.07s/it]

Seen 300000 tokens
	 - Training data: loss is 2892.85; sparsity loss is 610.35
	 - Eval data: loss is 3112.89; sparsity loss is 562.26


Tokens seen: 325000; Loss: 2586.66; Sparsity-specific loss: 590.88:  32%|███▏      | 649/2000 [11:30<23:57,  1.06s/it]

Seen 325000 tokens
	 - Training data: loss is 2586.66; sparsity loss is 590.88
	 - Eval data: loss is 3102.28; sparsity loss is 598.11


Tokens seen: 350000; Loss: 2404.09; Sparsity-specific loss: 592.08:  35%|███▍      | 699/2000 [12:20<22:59,  1.06s/it]

Seen 350000 tokens
	 - Training data: loss is 2404.09; sparsity loss is 592.08
	 - Eval data: loss is 3109.60; sparsity loss is 611.18


Tokens seen: 375000; Loss: 2285.37; Sparsity-specific loss: 595.19:  37%|███▋      | 749/2000 [13:11<22:02,  1.06s/it]

Seen 375000 tokens
	 - Training data: loss is 2285.37; sparsity loss is 595.19
	 - Eval data: loss is 3448.28; sparsity loss is 664.23


Tokens seen: 400000; Loss: 2255.17; Sparsity-specific loss: 598.25:  40%|███▉      | 799/2000 [14:01<21:05,  1.05s/it]

Seen 400000 tokens
	 - Training data: loss is 2255.17; sparsity loss is 598.25
	 - Eval data: loss is 3338.53; sparsity loss is 645.85


Tokens seen: 425000; Loss: 2247.91; Sparsity-specific loss: 601.14:  42%|████▏     | 849/2000 [14:56<20:15,  1.06s/it]

Seen 425000 tokens
	 - Training data: loss is 2247.91; sparsity loss is 601.14
	 - Eval data: loss is 3409.00; sparsity loss is 665.43


Tokens seen: 450000; Loss: 2226.60; Sparsity-specific loss: 602.64:  45%|████▍     | 899/2000 [15:47<19:20,  1.05s/it]

Seen 450000 tokens
	 - Training data: loss is 2226.60; sparsity loss is 602.64
	 - Eval data: loss is 2986.37; sparsity loss is 594.48


Tokens seen: 475000; Loss: 2174.69; Sparsity-specific loss: 605.12:  47%|████▋     | 949/2000 [16:37<18:24,  1.05s/it]

Seen 475000 tokens
	 - Training data: loss is 2174.69; sparsity loss is 605.12
	 - Eval data: loss is 2929.66; sparsity loss is 617.49


Tokens seen: 500000; Loss: 2136.36; Sparsity-specific loss: 610.42:  50%|████▉     | 999/2000 [17:28<17:30,  1.05s/it]

Seen 500000 tokens
	 - Training data: loss is 2136.36; sparsity loss is 610.42
	 - Eval data: loss is 2930.72; sparsity loss is 618.33


Tokens seen: 525000; Loss: 2099.28; Sparsity-specific loss: 611.47:  52%|█████▏    | 1049/2000 [18:18<16:35,  1.05s/it]

Seen 525000 tokens
	 - Training data: loss is 2099.28; sparsity loss is 611.47
	 - Eval data: loss is 2910.00; sparsity loss is 616.41


Tokens seen: 550000; Loss: 2325.50; Sparsity-specific loss: 621.15:  55%|█████▍    | 1099/2000 [19:12<15:45,  1.05s/it]

Seen 550000 tokens
	 - Training data: loss is 2325.50; sparsity loss is 621.15
	 - Eval data: loss is 2781.70; sparsity loss is 552.82


Tokens seen: 575000; Loss: 2607.27; Sparsity-specific loss: 615.67:  57%|█████▋    | 1149/2000 [20:02<14:50,  1.05s/it]

Seen 575000 tokens
	 - Training data: loss is 2607.27; sparsity loss is 615.67
	 - Eval data: loss is 2605.43; sparsity loss is 571.00


Tokens seen: 600000; Loss: 2571.68; Sparsity-specific loss: 604.75:  60%|█████▉    | 1199/2000 [20:53<13:57,  1.05s/it]

Seen 600000 tokens
	 - Training data: loss is 2571.68; sparsity loss is 604.75
	 - Eval data: loss is 2694.80; sparsity loss is 610.41


Tokens seen: 625000; Loss: 2492.02; Sparsity-specific loss: 615.01:  62%|██████▏   | 1249/2000 [21:43<13:03,  1.04s/it]

Seen 625000 tokens
	 - Training data: loss is 2492.02; sparsity loss is 615.01
	 - Eval data: loss is 2749.69; sparsity loss is 646.32


Tokens seen: 650000; Loss: 2482.45; Sparsity-specific loss: 627.44:  65%|██████▍   | 1299/2000 [22:33<12:10,  1.04s/it]

Seen 650000 tokens
	 - Training data: loss is 2482.45; sparsity loss is 627.44
	 - Eval data: loss is 2738.97; sparsity loss is 622.18


Tokens seen: 675000; Loss: 2511.90; Sparsity-specific loss: 630.11:  67%|██████▋   | 1349/2000 [23:28<11:19,  1.04s/it]

Seen 675000 tokens
	 - Training data: loss is 2511.90; sparsity loss is 630.11
	 - Eval data: loss is 2634.61; sparsity loss is 609.06


Tokens seen: 700000; Loss: 2569.25; Sparsity-specific loss: 632.83:  70%|██████▉   | 1399/2000 [24:19<10:26,  1.04s/it]

Seen 700000 tokens
	 - Training data: loss is 2569.25; sparsity loss is 632.83
	 - Eval data: loss is 2618.81; sparsity loss is 606.46


Tokens seen: 725000; Loss: 2529.63; Sparsity-specific loss: 639.86:  72%|███████▏  | 1449/2000 [25:09<09:34,  1.04s/it]

Seen 725000 tokens
	 - Training data: loss is 2529.63; sparsity loss is 639.86
	 - Eval data: loss is 2670.80; sparsity loss is 646.92


Tokens seen: 750000; Loss: 2399.98; Sparsity-specific loss: 640.15:  75%|███████▍  | 1499/2000 [25:59<08:41,  1.04s/it]

Seen 750000 tokens
	 - Training data: loss is 2399.98; sparsity loss is 640.15
	 - Eval data: loss is 2733.29; sparsity loss is 672.32


Tokens seen: 775000; Loss: 2335.51; Sparsity-specific loss: 645.79:  77%|███████▋  | 1549/2000 [26:50<07:48,  1.04s/it]

Seen 775000 tokens
	 - Training data: loss is 2335.51; sparsity loss is 645.79
	 - Eval data: loss is 2685.68; sparsity loss is 672.58


Tokens seen: 800000; Loss: 2307.24; Sparsity-specific loss: 650.73:  80%|███████▉  | 1599/2000 [27:45<06:57,  1.04s/it]

Seen 800000 tokens
	 - Training data: loss is 2307.24; sparsity loss is 650.73
	 - Eval data: loss is 2657.08; sparsity loss is 702.93


Tokens seen: 825000; Loss: 2358.24; Sparsity-specific loss: 650.78:  82%|████████▏ | 1649/2000 [28:35<06:05,  1.04s/it]

Seen 825000 tokens
	 - Training data: loss is 2358.24; sparsity loss is 650.78
	 - Eval data: loss is 2652.35; sparsity loss is 647.55


Tokens seen: 850000; Loss: 2385.71; Sparsity-specific loss: 646.72:  85%|████████▍ | 1699/2000 [29:25<05:12,  1.04s/it]

Seen 850000 tokens
	 - Training data: loss is 2385.71; sparsity loss is 646.72
	 - Eval data: loss is 2613.16; sparsity loss is 663.98


Tokens seen: 875000; Loss: 2328.33; Sparsity-specific loss: 644.54:  87%|████████▋ | 1749/2000 [30:15<04:20,  1.04s/it]

Seen 875000 tokens
	 - Training data: loss is 2328.33; sparsity loss is 644.54
	 - Eval data: loss is 2553.48; sparsity loss is 657.72


Tokens seen: 900000; Loss: 2284.78; Sparsity-specific loss: 651.81:  90%|████████▉ | 1799/2000 [31:06<03:28,  1.04s/it]

Seen 900000 tokens
	 - Training data: loss is 2284.78; sparsity loss is 651.81
	 - Eval data: loss is 2413.09; sparsity loss is 661.73


Tokens seen: 925000; Loss: 2279.90; Sparsity-specific loss: 656.46:  92%|█████████▏| 1849/2000 [32:01<02:36,  1.04s/it]

Seen 925000 tokens
	 - Training data: loss is 2279.90; sparsity loss is 656.46
	 - Eval data: loss is 2469.32; sparsity loss is 649.63


Tokens seen: 950000; Loss: 2252.06; Sparsity-specific loss: 653.83:  95%|█████████▍| 1899/2000 [32:51<01:44,  1.04s/it]

Seen 950000 tokens
	 - Training data: loss is 2252.06; sparsity loss is 653.83
	 - Eval data: loss is 2444.00; sparsity loss is 669.06


Tokens seen: 975000; Loss: 2205.94; Sparsity-specific loss: 655.70:  97%|█████████▋| 1949/2000 [33:42<00:52,  1.04s/it]

Seen 975000 tokens
	 - Training data: loss is 2205.94; sparsity loss is 655.70
	 - Eval data: loss is 2496.29; sparsity loss is 642.31


Tokens seen: 1000000; Loss: 2247.79; Sparsity-specific loss: 664.02: 100%|█████████▉| 1999/2000 [34:34<00:01,  1.04s/it]

Seen 1000000 tokens
	 - Training data: loss is 2247.79; sparsity loss is 664.02
	 - Eval data: loss is 2584.46; sparsity loss is 657.67


Tokens seen: 1000000; Loss: 2247.79; Sparsity-specific loss: 664.02: 100%|██████████| 2000/2000 [34:37<00:00,  1.04s/it]


In [16]:
from torchinfo import summary

print(summary(phi, input_shape=(5, 100)))

print(summary(sae, input_data=next(all_activations(split='train'))[1][0:1]))

# sae2 = sae.half()
# print(summary(sae2, input_data=next(all_activations(split='train'))[1][0:1].half()))

Layer (type:depth-idx)                                  Param #
PhiForCausalLM                                          --
├─PhiModel: 1-1                                         --
│    └─Embedding: 2-1                                   131,072,000
│    └─Dropout: 2-2                                     --
│    └─ModuleList: 2-3                                  --
│    │    └─PhiDecoderLayer: 3-1                        78,671,360
│    │    └─PhiDecoderLayer: 3-2                        78,671,360
│    │    └─PhiDecoderLayer: 3-3                        78,671,360
│    │    └─PhiDecoderLayer: 3-4                        78,671,360
│    │    └─PhiDecoderLayer: 3-5                        78,671,360
│    │    └─PhiDecoderLayer: 3-6                        78,671,360
│    │    └─PhiDecoderLayer: 3-7                        78,671,360
│    │    └─PhiDecoderLayer: 3-8                        78,671,360
│    │    └─PhiDecoderLayer: 3-9                        78,671,360
│    │    └─PhiDecoderLayer: 

In [17]:
should_train_more = False
if should_train_more:
    train(sae, start_block=500, num_tokens=100000)

## Now let's see how our accuracy and sparsity do compared to a random Autoencoder.

In [18]:
sae2 = SparseAutoencoder()

In [19]:
tokens, activations = next(all_activations('test', minibatch_size=5, block_size=100, start_block=10000))

In [20]:
with torch.no_grad():
    print(f'Testing on minibatch including:\n---\n{tokenizer.decode(tokens[0])}\n---\n')
    trained_decoded, trained_encoded = sae(activations)    
    untrained_decoded, untrained_encoded  = sae2(activations)

Testing on minibatch including:
---
 and such conditions. Now let
us suppose a teacher of genius to obtain the post. He not only teaches
admirably, but he institutes school gardens for the children; he takes
long walks with the boys, and gives them the rudiments of geology. He
is in himself an uplifting moral influence, and introduces the children
into a whole new world of idea and of feeling. The parents are pleased.
I will not say that they are grateful; but they
---



In [24]:
with torch.no_grad():
    trained_total, trained_accuracy, trained_sparsity = calculate_loss(sae, trained_decoded, activations, trained_encoded)
    untrained_total, untrained_accuracy, untrained_sparsity = calculate_loss(sae2, untrained_decoded, activations, trained_encoded)

print('Untrained model:')
print(f'\t- Accuracy loss: {untrained_accuracy / 2500:.1f}')
print(f'\t- Sparsity loss: {untrained_sparsity / 2500:.1f}')
print(f'\t- Total loss: {untrained_total / 2500:.1f}')
print('')
print('Trained model:')
print(f'\t- Accuracy loss: {trained_accuracy / 2500:.1f}')
print(f'\t- Sparsity loss: {trained_sparsity / 2500:.1f}')
print(f'\t- Total loss: {trained_total / 2500:.1f}')

Untrained model:
	- Accuracy loss: 6043.6
	- Sparsity loss: 87.7
	- Total loss: 6131.2

Trained model:
	- Accuracy loss: 358.1
	- Sparsity loss: 121.2
	- Total loss: 479.3
