## Load Common Methods

In [1]:
%run Common\ Code.ipynb



2024-08-03 16:21:54.282196: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-08-03 16:21:54.319653: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


## Now define our autoencoder

In [1]:
# SparseAutoencoder is defined in 'Common Code'

In [3]:
if Path('sae.pt').exists():
    sae = torch.load('sae.pt').to(device)
    is_trained = True
else:
    sae = SparseAutoencoder()
    is_trained = False

In [4]:
from torch.nn import functional as F
def calculate_loss(sae_model, prediction, target, feature_activations, lamb=5):
    weight_norms = torch.norm(sae_model.decoder.weight, dim=0, p=2).view(-1)
    feature_sizes = torch.abs(feature_activations)
    sparsity_loss = lamb * torch.sum(weight_norms * feature_sizes)

    prediction_loss = F.mse_loss(prediction, target, reduction='sum')

    total_loss = prediction_loss + sparsity_loss
    
    return total_loss, prediction_loss, sparsity_loss    

In [5]:
t = torch.cuda.get_device_properties(0).total_memory
r = torch.cuda.memory_reserved(0)
a = torch.cuda.memory_allocated(0)
f = r-a  # free inside reserved
print(t, r, a, f)

15655829504 7822376960 7616183296 206193664


## Train our autoencoder on 1M tokens' worth of activations.

In [6]:
from itertools import islice

@torch.no_grad()
def get_eval_loss(sae_model, minibatches, tokens_per_minibatch):
    losses = []
    sparsities = []
    accuracies = []
    for tokens, activations in minibatches:
        decoded, encoded = sae_model(activations)
        loss, accuracy, sparsity = calculate_loss(sae_model, decoded, activations, encoded)
        losses.append(loss.item() / tokens_per_minibatch)
        accuracies.append(accuracy.item() / tokens_per_minibatch)        
        sparsities.append(sparsity.item() / tokens_per_minibatch)        
    return sum(losses) / len(losses), sum(accuracies) / len(accuracies), sum(sparsities) / len(sparsities)

def train(sae_model, num_tokens=1000000, minibatch_size=5, block_size=100, loss_history=100, adam_beta1=0.9, adam_beta2=0.999, lr=5E-5, eval_tokens=1000, show_eval_token_interval=25000, start_block=0):            
    optimizer = torch.optim.AdamW(sae_model.parameters(), lr=lr, betas=(adam_beta1, adam_beta2))
    tokens_per_minibatch = minibatch_size * block_size
    num_minibatches = num_tokens // tokens_per_minibatch
    eval_minibatches = eval_tokens // tokens_per_minibatch
    losses = []
    sparsities = []
    activations_gen = PhiProbeCommons.all_activations('train', minibatch_size=minibatch_size, block_size=block_size + start_block)
    test_activations_gen = PhiProbeCommons.all_activations('test', minibatch_size=minibatch_size, block_size=block_size, start_block=1000+start_block)    
    
    training_data = islice(activations_gen, num_minibatches)         
    show_eval_minibatch_interval = show_eval_token_interval // tokens_per_minibatch

    pbar = tqdm(training_data, total=num_minibatches, smoothing=0)
    for idx, (tokens, activations) in enumerate(pbar):        
        decoded, encoded = sae_model(activations)        
        loss, accuracy, sparsity = calculate_loss(sae_model, decoded, activations, encoded)        
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()        
        losses.append(loss.item() / tokens_per_minibatch)
        sparsities.append(sparsity.item() / tokens_per_minibatch)
        if len(losses) > loss_history:
            losses = losses[1:]
            sparsities = sparsities[1:]
        avg_loss = sum(losses) / len(losses)
        avg_sparsity = sum(sparsities) / len(sparsities)
        tokens_seen = (idx + 1) * tokens_per_minibatch
        pbar.set_description(f'Tokens seen: {tokens_seen}; Loss: {avg_loss:.2f}; Sparsity-specific loss: {avg_sparsity:.2f}')
        if (idx + 1) % show_eval_minibatch_interval == 0:
            minibatches = islice(test_activations_gen, eval_minibatches)
            eval_loss, _, eval_sparsity = get_eval_loss(sae_model, minibatches, tokens_per_minibatch)
            print(f'Seen {tokens_seen} tokens')
            print(f'\t - Training data: loss is {avg_loss:.2f}; sparsity loss is {avg_sparsity:.2f}')
            print(f'\t - Eval data: loss is {eval_loss:.2f}; sparsity loss is {eval_sparsity:.2f}')            
            torch.save(sae, 'sae.pt')

if not is_trained:
    train(sae)
    is_trained = True
    torch.save(sae, 'sae.pt')

Tokens seen: 25000; Loss: 21140.35; Sparsity-specific loss: 10790.55:   2%|▏         | 49/2000 [00:57<38:25,  1.18s/it] 

Seen 25000 tokens
	 - Training data: loss is 21140.35; sparsity loss is 10790.55
	 - Eval data: loss is 10860.99; sparsity loss is 7113.66


Tokens seen: 50000; Loss: 15341.17; Sparsity-specific loss: 8403.92:   5%|▍         | 99/2000 [01:57<37:35,  1.19s/it] 

Seen 50000 tokens
	 - Training data: loss is 15341.17; sparsity loss is 8403.92
	 - Eval data: loss is 8358.58; sparsity loss is 4661.49


Tokens seen: 50000; Loss: 15341.17; Sparsity-specific loss: 8403.92:   5%|▍         | 99/2000 [02:00<38:27,  1.21s/it]


KeyboardInterrupt: 

## Now let's see how our accuracy and sparsity do compared to a random Autoencoder.

In [7]:
sae2 = SparseAutoencoder()

In [9]:
tokens, activations = next(PhiProbeCommons.all_activations('test', minibatch_size=5, block_size=100, start_block=10000))

In [11]:
with torch.no_grad():
    print(f'Testing on minibatch including:\n---\n{PhiProbeCommons.tokenizer.decode(tokens[0])}\n---\n')
    trained_decoded, trained_encoded = sae(activations)    
    untrained_decoded, untrained_encoded  = sae2(activations)

Testing on minibatch including:
---
 and such conditions. Now let
us suppose a teacher of genius to obtain the post. He not only teaches
admirably, but he institutes school gardens for the children; he takes
long walks with the boys, and gives them the rudiments of geology. He
is in himself an uplifting moral influence, and introduces the children
into a whole new world of idea and of feeling. The parents are pleased.
I will not say that they are grateful; but they
---



In [12]:
with torch.no_grad():
    trained_total, trained_accuracy, trained_sparsity = calculate_loss(sae, trained_decoded, activations, trained_encoded)
    untrained_total, untrained_accuracy, untrained_sparsity = calculate_loss(sae2, untrained_decoded, activations, trained_encoded)

print('Untrained model:')
print(f'\t- Accuracy loss: {untrained_accuracy / 2500:.1f}')
print(f'\t- Sparsity loss: {untrained_sparsity / 2500:.1f}')
print(f'\t- Total loss: {untrained_total / 2500:.1f}')
print('')
print('Trained model:')
print(f'\t- Accuracy loss: {trained_accuracy / 2500:.1f}')
print(f'\t- Sparsity loss: {trained_sparsity / 2500:.1f}')
print(f'\t- Total loss: {trained_total / 2500:.1f}')

Untrained model:
	- Accuracy loss: 6086.3
	- Sparsity loss: 902.1
	- Total loss: 6988.4

Trained model:
	- Accuracy loss: 673.8
	- Sparsity loss: 871.9
	- Total loss: 1545.7
