In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules
# (e.g. the local `models` package) are re-imported before each cell runs
# and edits to their source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import os
import torch
from  models.simple_sae import SAE 
# import tqdm
# import matplotlib.pyplot as plt
# from evaluate_feature import calculate_AUC_matrix

In [3]:
# Read the saved embedding array from disk; the recorded output shows a
# 3-D array whose last axis has 768 entries.
embeddings_path = "layer_11_embeddings_30subset.npy"
data_ = np.load(embeddings_path)
data_.shape

(1972, 1003, 768)

In [4]:
# Collapse every leading axis into one so each row is a single feature
# vector; the size of the last axis is kept as the column count.
n_features = data_.shape[-1]
data = data_.reshape(-1, n_features)
data.shape

(1977916, 768)

In [14]:
# Largest activation in the flattened matrix (the recorded value is above 1,
# so the inputs are not confined to [0, 1]).
np.max(data)

np.float32(6.0507812)

In [9]:
import wandb
from dotenv import load_dotenv

# Populate os.environ from a local .env file (if one exists) BEFORE reading
# the key below. Without this call the imported `load_dotenv` was unused and
# WANDB_API_KEY was only found when it had already been exported in the shell.
# Variables that are already set in the environment are not overridden.
load_dotenv()

# Authenticate with Weights & Biases using the key from the environment
# (never hard-code the key in the notebook), then start a run that the
# training loop logs to.
wandb.login(key=os.getenv("WANDB_API_KEY"))
wandb.init(
    project="ml4rg",
    entity='elizabeth-lochert-flx'
)



In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Train on GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# === Data ===
# `data` is the flattened 2-D embedding matrix built in the cells above
# (one feature vector per row). The values are NOT normalized to [0, 1]
# (data.max() is ~6.05), so the inputs are treated as real-valued.
X = torch.from_numpy(data)
dataset = TensorDataset(X)
dataloader = DataLoader(dataset, batch_size=1024*16, shuffle=True)

# === Model, Loss, Optimizer ===
# Take the input width from the data itself so the cell keeps working if a
# different embedding file (with another feature size) is loaded.
input_dim = X.shape[-1]
latent_dim = 8000
model = SAE(input_dim=input_dim, latent_space_dim=latent_dim).to(device)
# MSE reconstruction loss; BCELoss would be invalid here because the inputs
# are not restricted to [0, 1].
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# L1 sparsity strength (weight of the mean absolute latent activation)
l1_lambda = 1e-5

# === Training loop ===
n_epochs = 3

for epoch in range(n_epochs):
    model.train()
    running_loss = 0.0

    for batch in dataloader:
        # TensorDataset yields 1-tuples; the tensor is the first element.
        inputs = batch[0].to(device)

        optimizer.zero_grad()

        # Forward pass: the model returns the reconstruction and the latent code
        outputs, latent = model(inputs)

        # Losses: reconstruction error plus L1 penalty on the latent space
        recon_loss = criterion(outputs, inputs)
        l1_loss = l1_lambda * torch.mean(torch.abs(latent))

        loss = recon_loss + l1_loss

        # Backward + optimize
        loss.backward()
        optimizer.step()

        # Read the scalar once per step; each .item() call copies the value
        # back to the host, so reuse it for logging and the running total.
        loss_value = loss.item()
        wandb.log({"loss": loss_value})

        # Weight by batch size so the epoch average is per-sample even when
        # the final batch is smaller than the others.
        running_loss += loss_value * inputs.size(0)

    epoch_loss = running_loss / len(dataloader.dataset)
    print(f"Epoch {epoch + 1}/{n_epochs}, Loss: {epoch_loss:.6f}")

Epoch 1/3, Loss: 0.050917
Epoch 2/3, Loss: 0.021118
Epoch 3/3, Loss: 0.013892


In [11]:
# Persist only the learned parameters (state_dict). The epoch count in the
# file name is taken from `n_epochs` (set in the training cell) so the name
# cannot go stale when the number of epochs is changed; with n_epochs = 3
# this is the same "./model_e3.pt" path as before.
model_save_path = f"./model_e{n_epochs}.pt"
torch.save(model.state_dict(), model_save_path)