In [None]:

"""
My First Sparse Autoencoder
"""

import torch
import torch.nn as nn
from torch.optim import Adam
import numpy


from transformers import GPT2Tokenizer, GPT2Model

from datasets import load_dataset

# For progress bar
from tqdm import tqdm


# === SETTINGS ===
# Pretrained checkpoint whose activations are fed to the autoencoder.
model_name = "stanford-crfm/caprica-gpt2-small-x81"
# Index of the transformer block to hook (author's note: 6 is the middle layer).
layer_index = 6

# Autoencoder widths: size of the vectors to reconstruct, and the (larger)
# size of the hidden code.
input_size, hidden_size = 3072, 8192

# Sequences per batch -- might need to reduce if it crashes.
batch_size_num = 32


# ********** PART 1: LOAD THE MODEL **********
print("Loading model...")

# Tokenizer: the end-of-sequence token is reused as the padding token so
# that fixed-length padding can be requested later.
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Language model, switched to evaluation mode; it is only used as a source
# of activations.
gpt_model = GPT2Model.from_pretrained(model_name)
gpt_model.eval()

# Move the model to the GPU when one is present, and report the choice.
if not torch.cuda.is_available():
    print("Using CPU :(")
else:
    gpt_model.cuda()
    print("Using GPU!")


# ~~~~~ Hook to get activations ~~~~~
# Maps a caller-chosen name to the most recent output captured by the hook
# registered under that name (overwritten on every forward pass).
activations_storage = {}


def save_activations(name):
    """Build a forward hook that stores a module's output under ``name``.

    Args:
        name: Key under which the captured tensor is written into
            ``activations_storage``.

    Returns:
        A callable with the ``(module, input, output)`` forward-hook
        signature. On each call it stores a detached copy of the module
        output, keeping the full batch dimension.
    """
    def hook(model, input, output):
        # Some modules return a tuple whose first element is the tensor of
        # interest; others return the tensor directly. Indexing a plain
        # tensor with [0] would silently drop every sequence in the batch
        # except the first, so only unwrap real tuples/lists.
        tensor = output[0] if isinstance(output, (tuple, list)) else output
        activations_storage[name] = tensor.detach()
    return hook

# Attach the capture hook to the mlp.c_fc submodule of the selected block;
# its output is stored under the 'middle_layer' key on every forward pass.
hooked_module = gpt_model.h[layer_index].mlp.c_fc
hooked_module.register_forward_hook(save_activations('middle_layer'))


# ++++++++++ AUTOENCODER CLASS ++++++++++
class MyAE(nn.Module):
    """Single-hidden-layer autoencoder: Linear -> ReLU -> Linear.

    Args:
        input_dim: Width of the vectors to reconstruct.
        hidden_dim: Width of the hidden code.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Attribute names double as state_dict key prefixes, so they are
        # kept stable for anything that loads a saved checkpoint.
        self.encoder_layer = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.relu = nn.ReLU()
        self.decoder_layer = nn.Linear(in_features=hidden_dim, out_features=input_dim)

    def forward(self, x):
        """Encode ``x`` and decode it again.

        Returns:
            A ``(reconstruction, code)`` pair, where ``code`` is the
            non-negative hidden activation after the ReLU.
        """
        pre_activation = self.encoder_layer(x)
        code = self.relu(pre_activation)
        reconstruction = self.decoder_layer(code)
        return reconstruction, code

# Build the autoencoder from the configured widths and, when a GPU is
# present, move its parameters there (Module.cuda() returns the module).
my_autoencoder = MyAE(input_dim=input_size, hidden_dim=hidden_size)
if torch.cuda.is_available():
    my_autoencoder = my_autoencoder.cuda()


print("\nPreparing dataset...")

# Training split of the raw WikiText-2 configuration.
dataset = load_dataset(
    path="wikitext",
    name="wikitext-2-raw-v1",
    split="train",
)

def tokenize_data(examples):
    """Tokenize the ``"text"`` field of a batch of examples.

    Every sequence is truncated or padded to exactly 128 tokens (fixed
    number from the tutorial) and the result is requested as PyTorch
    tensors.

    Args:
        examples: Mapping with a ``"text"`` entry holding the raw strings.

    Returns:
        Whatever the module-level ``tokenizer`` returns for these settings.
    """
    tokenizer_options = {
        "max_length": 128,
        "padding": "max_length",
        "truncation": True,
        "return_tensors": "pt",
    }
    return tokenizer(examples["text"], **tokenizer_options)


# Tokenize in batches, then expose only the two model inputs as tensors.
tensor_columns = ["input_ids", "attention_mask"]
dataset = dataset.map(tokenize_data, batched=True)
dataset.set_format(type="torch", columns=tensor_columns)

# Shuffled mini-batches for the training loop.
train_loader = torch.utils.data.DataLoader(
    dataset, shuffle=True, batch_size=batch_size_num
)

print("\nStarting training...")

num_epochs = 10         # TODO: find a better number of epochs
sparsity_weight = 0.01  # L1 penalty coefficient (check with universality paper)

# Inputs have to live on the same device as the GPT-2 model; look it up once
# instead of re-querying CUDA availability for every tensor.
device = next(gpt_model.parameters()).device

optimizer = Adam(my_autoencoder.parameters(), lr=0.0001)
loss_function = nn.MSELoss()

# Per-batch total loss, accumulated across all epochs
losses = []

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    # Losses of the batches actually trained on during this epoch
    epoch_losses = []

    progress_bar = tqdm(train_loader, desc="Processing batches")

    for batch in progress_bar:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # The forward pass is run only for its side effect: the registered
        # hook writes the hooked layer's output into activations_storage.
        with torch.no_grad():
            gpt_model(input_ids=input_ids, attention_mask=attention_mask)

        # Expected shape: (batch, seq_len, features)
        acts = activations_storage['middle_layer']

        # Keep only positions holding real tokens. Sequences are padded to a
        # fixed length, so without this the autoencoder would spend much of
        # its capacity reconstructing padding activations. Boolean indexing
        # also flattens the result to (num_real_tokens, features).
        flattened_acts = acts[attention_mask.bool()]
        if flattened_acts.shape[0] == 0:
            # A batch made entirely of padding has nothing to learn from,
            # and a mean over zero elements would produce NaN.
            continue

        # Autoencoder step
        reconstructed, hidden = my_autoencoder(flattened_acts)

        # Loss = reconstruction error + L1 sparsity penalty on the code
        reconstruction_loss = loss_function(reconstructed, flattened_acts)
        sparsity_loss = sparsity_weight * torch.mean(torch.abs(hidden))
        total_loss = reconstruction_loss + sparsity_loss

        # Backprop
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # .item() forces a device sync, so read the value once and reuse it
        loss_value = total_loss.item()
        losses.append(loss_value)
        epoch_losses.append(loss_value)
        progress_bar.set_postfix(loss=loss_value)

    # Average over the batches that were actually trained on; NaN signals
    # that the epoch contained no usable batch.
    avg_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float("nan")
    print(f"Epoch {epoch+1} - Loss: {avg_loss:.4f}")

torch.save(my_autoencoder.state_dict(), "my_sparse_ae.pth")
print("\nTraining complete! Saved model.")



Epoch 1: 100%|██████████| 1148/1148 [05:24<00:00,  3.53it/s]


Epoch 1 - Loss: 0.0953, Avg Activation: 0.2030


Epoch 2: 100%|██████████| 1148/1148 [05:24<00:00,  3.54it/s]


Epoch 2 - Loss: 0.0482, Avg Activation: 0.1678


Epoch 3: 100%|██████████| 1148/1148 [05:24<00:00,  3.54it/s]


Epoch 3 - Loss: 0.0334, Avg Activation: 0.1422


Epoch 4: 100%|██████████| 1148/1148 [05:24<00:00,  3.54it/s]


Epoch 4 - Loss: 0.0272, Avg Activation: 0.1274


Epoch 5: 100%|██████████| 1148/1148 [05:23<00:00,  3.54it/s]


Epoch 5 - Loss: 0.0232, Avg Activation: 0.1125


Epoch 6: 100%|██████████| 1148/1148 [05:24<00:00,  3.54it/s]


Epoch 6 - Loss: 0.0205, Avg Activation: 0.1066


Epoch 7: 100%|██████████| 1148/1148 [05:23<00:00,  3.55it/s]


Epoch 7 - Loss: 0.0184, Avg Activation: 0.1016


Epoch 8: 100%|██████████| 1148/1148 [05:23<00:00,  3.54it/s]


Epoch 8 - Loss: 0.0169, Avg Activation: 0.0980


Epoch 9: 100%|██████████| 1148/1148 [05:24<00:00,  3.54it/s]


Epoch 9 - Loss: 0.0158, Avg Activation: 0.0939


Epoch 10: 100%|██████████| 1148/1148 [05:24<00:00,  3.54it/s]


Epoch 10 - Loss: 0.0151, Avg Activation: 0.0926
