# SAE IMPORT

In [None]:
from dataclasses import dataclass
from typing import Any, Callable, Literal
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import einops
from tqdm import tqdm
# import dataloader from torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from typing import Any, Literal
from torch import Tensor
from jaxtyping import Float
import pickle
import gzip
from torch.utils.data import DataLoader


# SAE SRC

In [None]:
@dataclass
class SAEConfig:
    n_inst: int
    d_in: int
    d_sae: int
    sparsity_coeff: float = 0.2
    weight_normalize_eps: float = 1e-8
    tied_weights: bool = False
    architecture: Literal["standard",] = "standard"
    ste_epsilon: float = 0.01
    dataset: t.utils.data.Dataset = None
    batch_size: int = 2



In [44]:
class SAE(nn.Module):
    def __init__(self, cfg: SAEConfig, dataloader: DataLoader) -> None:
        """
        ARENA-based SAE implementation.
        Args:
            cfg: Configuration for the SAE.
            model: An instance of ToyModel (or similar) whose hidden activations we want to reconstruct.
                   It is assumed that model.cfg.d_hidden matches cfg.d_in.
        """
        super().__init__()

        self.cfg = cfg
        self.dataloader = DataLoader(cfg.dataset, batch_size=cfg.batch_size, shuffle=False)

        # Initialize encoder weights (W_enc) and biases (b_enc)
        self.W_enc = nn.Parameter(
            t.nn.init.kaiming_uniform_(t.empty((cfg.n_inst, cfg.d_in, cfg.d_sae)))
        )
        # Decoder weights: either learned separately or tied to encoder weights.
        self._W_dec = (
            nn.Parameter(t.nn.init.kaiming_uniform_(t.empty((cfg.n_inst, cfg.d_sae, cfg.d_in))))
            if not cfg.tied_weights
            else None
        )
        self.b_enc = nn.Parameter(t.zeros((cfg.n_inst, cfg.d_sae)))
        self.b_dec = nn.Parameter(t.zeros((cfg.n_inst, cfg.d_in)))

        self.to(t.device("cuda" if t.cuda.is_available() else "cpu"))

    @property
    def W_dec(self) -> t.Tensor:
        # If tied weights, return the transpose of the encoder weights.
        return self._W_dec if self._W_dec is not None else self.W_enc.transpose(-1, -2)

    @property
    def W_dec_normalized(self) -> t.Tensor:
        """
        Returns the decoder weights normalized over the autoencoder input dimension.
        """
        norm = t.norm(self.W_dec, dim=-1, keepdim=True) + self.cfg.weight_normalize_eps
        return self.W_dec / norm

    def generate_batch(self) -> t.Tensor: 
        """
        Generates a batch of hidden activations using the underlying model.
        """
        return next(iter(self.dataloader))

    def forward(
        self, h: t.Tensor
    ) -> tuple[
        dict[str, t.Tensor],
        t.Tensor,
        t.Tensor,
        t.Tensor,
    ]:
        """
        Forward pass of the autoencoder.
        
        Args:
            h: Hidden layer activations of shape (batch, inst, d_in).
            
        Returns:
            loss_dict: Dictionary of individual loss terms.
            total_loss: Combined loss (reconstruction + sparsity).
            sae_latents: Latent activations (after ReLU).
            h_reconstructed: Reconstructed autoencoder input.
        """
        # Compute latent logits and apply nonlinearity.
        sae_logits = einops.einsum(
            self.W_enc, h - self.b_dec, "inst d_in d_sae, batch inst d_in -> batch inst d_sae"
        ) + self.b_enc
        sae_latents = F.relu(sae_logits)
        # Reconstruct the hidden activations.
        h_reconstructed = einops.einsum(
            self.W_dec_normalized, sae_latents, "inst d_sae d_in, batch inst d_sae -> batch inst d_in"
        ) + self.b_dec

        # Losses: mean squared reconstruction loss and L1 sparsity penalty.
        reconstructed_loss = (h_reconstructed - h).pow(2).mean(-1)
        sparsity_loss = t.norm(sae_latents, p=1, dim=-1)
        loss_dict = {"L_reconstruction": reconstructed_loss, "L_sparsity": sparsity_loss}
        total_loss = reconstructed_loss + self.cfg.sparsity_coeff * sparsity_loss

        return loss_dict, total_loss, sae_latents, h_reconstructed

    @t.no_grad()
    def resample_simple(
        self,
        frac_active_in_window: t.Tensor,
        resample_scale: float,
    ) -> None:
        """
        Resamples dead neurons by replacing weights and resetting biases.
        """
        # Identify dead neurons (across the resampling window).
        dead_neurons = (frac_active_in_window < 1e-8).all(dim=0)
        n_dead = int(dead_neurons.int().sum().item())
        # Generate new weight vectors for dead neurons.
        sampled_neurons = t.randn((n_dead, self.cfg.d_in), device=self.W_enc.device)
        normalized_sampled_neurons = sampled_neurons / (
            sampled_neurons.norm(dim=-1, keepdim=True) + self.cfg.weight_normalize_eps
        )
        # Update decoder weights and adjust encoder weights accordingly.
        self.W_dec.data[dead_neurons] = normalized_sampled_neurons
        self.W_enc.data.transpose(-1, -2)[dead_neurons] = resample_scale * normalized_sampled_neurons
        self.b_enc.data[dead_neurons] = 0.0

    @t.no_grad()
    def resample_advanced(
        self,
        frac_active_in_window: t.Tensor,
        resample_scale: float,
        batch_size: int,
    ) -> None:
        """
        Resamples dead neurons using reconstruction loss to weight the resampling.
        """
        h = self.generate_batch()
        loss_dict, _, _, _ = self.forward(h)
        l2_loss = loss_dict["L_reconstruction"]

        for instance in range(self.cfg.n_inst):
            is_dead = (frac_active_in_window[:, instance] < 1e-8).all(dim=0)
            dead_latents = t.nonzero(is_dead).squeeze(-1)
            n_dead = dead_latents.numel()
            if n_dead == 0:
                continue

            l2_loss_instance = l2_loss[:, instance]  # Shape: [batch_size]
            if l2_loss_instance.max() < 1e-6:
                continue

            # Sample replacement indices with probabilities proportional to the squared loss.
            distn = Categorical(probs=l2_loss_instance.pow(2) / l2_loss_instance.pow(2).sum())
            replacement_indices = distn.sample((n_dead,))
            replacement_values = (h - self.b_dec)[replacement_indices, instance]
            replacement_values_normalized = replacement_values / (
                replacement_values.norm(dim=-1, keepdim=True) + self.cfg.weight_normalize_eps
            )

            # Compute scaling based on alive neurons.
            W_enc_norm_alive_mean = (
                self.W_enc[instance, :, ~is_dead].norm(dim=0).mean().item() if (~is_dead).any() else 1.0
            )
            self.W_dec.data[instance, dead_latents, :] = replacement_values_normalized
            self.W_enc.data[instance, :, dead_latents] = (
                replacement_values_normalized.T * W_enc_norm_alive_mean * resample_scale
            )
            self.b_enc.data[instance, dead_latents] = 0.0


# SAE Trainer 

In [45]:
# A simple constant learning rate schedule function.
def constant_lr(step: int, total_steps: int) -> float:
    return 1.0

@dataclass
class SAETrainerConfig:
    sae: SAE
    lr: float = 1e-3
    lr_scale: Callable[[int, int], float] = constant_lr
    resample_method: Literal["simple", "advanced", None] = None
    resample_freq: int = 2500
    resample_window: int = 500
    resample_scale: float = 0.5


In [51]:

class SAETrainer:
    def __init__(
        self,
        cfg: SAETrainerConfig,
    ):
        """
        Initializes the trainer.
        
        Args:
            sae: The SAE model instance to train.
            lr: Initial learning rate.
            lr_scale: A function to scale the learning rate over training.
            resample_method: Resampling method to apply on dead neurons ('simple' or 'advanced').
            resample_freq: Frequency (in steps) at which resampling is performed.
            resample_window: Window size to determine dead neurons.
            resample_scale: Scaling factor applied during resampling.
        """
        self.sae = cfg.sae
        self.lr = cfg.lr
        self.lr_scale = cfg.lr_scale
        self.resample_method = cfg.resample_method
        self.resample_freq = cfg.resample_freq
        self.resample_window = cfg.resample_window
        self.resample_scale = cfg.resample_scale
        self.optimizer = t.optim.Adam(list(self.sae.parameters()), lr=self.lr)
        self.device = next(self.sae.parameters()).device

    def train(
        self,
        steps: int = 10,
        batch_size: int = 2,
        log_freq: int = 2,
    ) -> list[dict[str, Any]]:
        """
        Runs the training loop for the SAE model.
        
        Args:
            steps: Total number of training steps.
            batch_size: Batch size for training.
            log_freq: Frequency (in steps) at which training statistics are logged.
            hidden_sample_size: Sample size used for logging hidden activations.
        
        Returns:
            A list of log dictionaries containing training progress data.
        """
        frac_active_list = []
        data_log = []
        progress_bar = tqdm(range(steps))
        for step in progress_bar:
            # Resample dead neurons if requested.
            if (self.resample_method is not None) and ((step + 1) % self.resample_freq == 0):
                frac_active_in_window = t.stack(frac_active_list[-self.resample_window:], dim=0)
                if self.resample_method == "simple":
                    self.sae.resample_simple(frac_active_in_window, self.resample_scale)
                elif self.resample_method == "advanced":
                    self.sae.resample_advanced(frac_active_in_window, self.resample_scale, batch_size)

            # Update learning rate.
            step_lr = self.lr * self.lr_scale(step, steps)
            for group in self.optimizer.param_groups:
                group["lr"] = step_lr

            # Generate a batch of hidden activations.
            with t.inference_mode():
                h = self.sae.generate_batch()

            # Forward pass, loss computation, and backpropagation.
            loss_dict, loss, acts, _ = self.sae(h)
            loss.mean(0).sum().backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

            # For untied weights, re-normalize the decoder weights.
            if not self.sae.cfg.tied_weights:
                self.sae.W_dec.data = self.sae.W_dec_normalized.data

            # Track the fraction of active neurons.
            frac_active = (acts.abs() > 1e-8).float().mean(0)
            frac_active_list.append(frac_active)

            # Logging.
            if step % log_freq == 0 or (step + 1 == steps):
                progress_bar.set_postfix(
                    lr=step_lr,
                    loss=loss.mean(0).sum().item(),
                    frac_active=frac_active.mean().item(),
                    **{k: v.mean(0).sum().item() for k, v in loss_dict.items()},
                )
                with t.inference_mode():
                    h_sample = self.sae.generate_batch()
                    loss_dict_sample, loss_sample, acts_sample, h_r_sample = self.sae(h_sample)
                data_log.append(
                    {
                        "steps": step,
                        "frac_active": (acts_sample.abs() > 1e-8).float().mean(0).detach().cpu(),
                        "loss": loss_sample.detach().cpu(),
                        "h": h_sample.detach().cpu(),
                        "h_r": h_r_sample.detach().cpu(),
                        **{name: param.detach().cpu() for name, param in self.sae.named_parameters()},
                        **{name: loss_term.detach().cpu() for name, loss_term in loss_dict_sample.items()},
                    }
                )
        return data_log

# Dataset

In [52]:
class ToyDataset(Dataset):

    def __init__(self, batch_path: str, layer: Literal["MLP", "RESIDUAL"] = "RESIDUAL", depth: Literal["1","2","3"] = "3") -> None:
        """
        Initializes the ToyDataset class.

        Args:
            batch_path: Path to the batch data file.
            layer: The layer type to extract data from (MLP or RESIDUAL).
            depth: The depth of the layer to extract data from.
        """
        self.layer = layer
        self.depth = depth
        self.data_path = batch_path
        self.raw_data = ToyDataset.load_data(self.data_path)
        self.dataset = self.dataloader()

    

    def dataloader(self) -> Float[Tensor, "batch inst d_in"]:
        """
        Returns a dataloader for the dataset.
        """

        type = "residual_stream" if self.layer == "RESIDUAL" else "mlp_output"
        layer_depth = f"encoder.outer.residual{self.depth}" if self.layer == "RESIDUAL" else f"encoder.ffns.{self.depth}.output"
        data = [self.raw_data[i][type][layer_depth][(1,100,256)].squeeze() for i in range(len(self.raw_data))]
        data = t.stack(data, dim=0)
        return data
    
    def __len__(self) -> int:
        return len(self.raw_data)
    
    def __getitem__(self, idx: int) -> Float[Tensor, "inst d_in"]:
        return self.dataset[idx]
        

    @staticmethod
    def load_data(file_path: str) -> Any:
        """
        Loads a pickle file (.pkl.gz) that is compressed with gzip.
        
        :param file_path: Path to the .pkl.gz file
        :return: The unpickled data
        """
        try:
            with gzip.open(file_path, "rb") as file:
                data = pickle.load(file)
            return data
        except Exception as e:
            print(f"Error loading {file_path}: {e}")
            return None

# Trying the Code

In [53]:
path = "/Users/tommasomencattini/Desktop/GitHub/MIVLDE/data/batch_1249.pkl.gz"
dataset = ToyDataset(path)

In [54]:
config = SAEConfig(n_inst=1, d_in=256, d_sae=256, dataset=dataset)
sae = SAE(config, dataset)

In [55]:
sae(dataset[0].unsqueeze(0))

({'L_reconstruction': tensor([[0.8486, 0.6958, 0.8232, 0.7027, 0.8738, 0.6923, 0.6943, 0.6958, 0.6994,
           0.7803, 0.8534, 0.8458, 0.7001, 0.7933, 0.6996, 0.7019, 0.6961, 0.7021,
           0.6966, 0.6977, 0.6979, 0.6985, 0.8508, 0.7010, 0.7021, 0.6987, 0.6952,
           0.8718, 0.6986, 0.7035, 0.7041, 0.6974, 0.8603, 0.7008, 0.6984, 0.6965,
           0.7008, 0.8416, 0.8129, 0.6996, 0.6942, 0.8737, 0.8660, 0.8591, 0.6925,
           0.7030, 0.6985, 0.6975, 0.7026, 0.8695, 0.6939, 0.7003, 0.6950, 0.8375,
           0.7030, 0.8541, 0.7036, 0.8122, 0.7045, 0.8638, 0.6912, 0.6959, 0.6993,
           0.7028, 0.7030, 0.6978, 0.6936, 0.6942, 0.8242, 0.6966, 0.7385, 0.8374,
           0.8596, 0.8220, 0.7011, 0.7039, 0.7038, 0.6944, 0.7014, 0.7025, 0.6981,
           0.6927, 0.6916, 0.6995, 0.7042, 0.6933, 0.8109, 0.7001, 0.8465, 0.6947,
           0.6990, 0.8299, 0.6975, 0.6929, 0.6969, 0.7020, 0.7019, 0.6965, 0.8535,
           0.7042]], grad_fn=<MeanBackward1>),
  'L_sparsity': tens

In [56]:
cfgtrainer = SAETrainerConfig(sae=sae)
trainer = SAETrainer(cfgtrainer)
trainer.train(10)


100%|██████████| 10/10 [00:00<00:00, 196.94it/s, L_reconstruction=72.8, L_sparsity=0.136, frac_active=0.000371, loss=72.8, lr=0.001]


[{'steps': 0,
  'frac_active': tensor([[0.0000, 1.0000, 1.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 1.0000, 1.0000,  ..., 0.5000, 0.5000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.5000, 0.5000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.5000, 0.5000,  ..., 0.0000, 0.0000, 0.0000]]),
  'loss': tensor([[1.5612, 0.7123, 2.1774, 0.7218, 1.5781, 0.7132, 0.7129, 0.7125, 0.7143,
           1.5777, 1.5400, 1.5579, 0.7203, 1.5111, 0.7199, 0.7217, 0.7158, 0.7195,
           0.7162, 0.7157, 0.7116, 0.7207, 1.5627, 0.7204, 0.7209, 0.7194, 0.7131,
           1.5732, 0.7177, 0.7216, 0.7218, 0.7155, 1.5481, 0.7207, 0.7176, 0.7142,
           0.7204, 1.5658, 1.4662, 0.7194, 0.7130, 1.5890, 1.5777, 1.5316, 0.7142,
           0.7220, 0.7193, 0.7185, 0.7215, 1.5829, 0.7139, 0.7203, 0.7146, 1.5282,
           0.7220, 1.5311, 0.7216, 1.4470, 0.7224, 1