In [16]:
from generate_activations import *
import torch.nn as nn
import torch.nn.functional as F
import torch
import einops

In [6]:
buffer = Buffer(cfg)

In [7]:
data = buffer.next()

In [8]:
data.shape

torch.Size([4096, 512])

In [37]:
class GatedAutoEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        # d_hidden = cfg["dict_size"]
        d_hidden = 5000

        l1_coeff = cfg["l1_coeff"]
        dtype = DTYPES[cfg["enc_dtype"]]
        torch.manual_seed(cfg["seed"])

        self.W_gate = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(cfg["act_size"], d_hidden, dtype=dtype)
            )
        )
        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(
                torch.empty(d_hidden, cfg["act_size"], dtype=dtype)
            )
        )
        self.b_dec = nn.Parameter(torch.zeros(cfg["act_size"], dtype=dtype))

        self.b_enc_gate = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
        self.b_dec_gate = nn.Parameter(torch.zeros(cfg["act_size"], dtype=dtype))

        self.r_mag = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
        self.b_mag = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))

        self.b_gate = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))

        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)

        self.d_hidden = d_hidden
        self.l1_coeff = l1_coeff

        self.to(cfg["device"])


    def gated_sae(self, x):
        preactivations_hidden = einops.einsum(x, self.W_gate, "... input_dim, input_dim hidden_dim -> ... hidden_dim")

        pre_mag_hidden = preactivations_hidden * torch.exp(self.r_mag) + self.b_mag
        post_mag_hidden = torch.relu(pre_mag_hidden)

        pre_gate_hidden = preactivations_hidden + self.b_gate
        post_gate_hidden = (torch.sign(pre_gate_hidden) + 1) / 2

        postactivations_hidden = post_mag_hidden * post_gate_hidden


        reconstruction =  einops.einsum(postactivations_hidden, self.W_dec, "... hidden_dim, hidden_dim output_dim -> ... output_dim") + self.b_dec

        return reconstruction, pre_gate_hidden
        


    # def gated_sae(self, x): # W_gate, b_gate, W_mag, b_mag, W_dec, b_dec
    #     x_center = x - self.b_dec

    #     feature_mags = F.relu(x_center @ self.W_mag + self.b_enc)

    #     active_features = torch.zeros_like(feature_mags)

    #     active_features[(x_center @ self.W_enc_gate + self.b_enc_gate) > 0] = 1

    #     return (active_features * feature_mags) @ self.W_dec + self.b_dec

    def forward(self, x):
        reconstruction, pre_gate_hidden = self.gated_sae(x)

        # Reconstruction Loss
        # gated_sae_loss = (reconstruction - x).pow(2).sum()
        gated_sae_loss = F.mse_loss(reconstruction, x, reduction='mean')
        

        # L1 loss
        gate_magnitude = F.relu(pre_gate_hidden)
        gated_sae_loss += self.l1_coeff * gate_magnitude.sum()

        # Auxiliary loss
        gate_reconstruction = einops.einsum(gate_magnitude, self.W_dec.detach(), "... hidden_dim, hidden_dim output_dim -> ... output_dim") + self.b_dec.detach()
        auxiliary_loss = F.mse_loss(gate_reconstruction, x, reduction='mean')

        gated_sae_loss += auxiliary_loss

        return gated_sae_loss


    # def forward(self, x):
        # x_cent = x - self.b_dec
        # acts = F.relu(x_cent @ self.W_gate + self.b_enc)


        # x_reconstruct = acts @ self.W_dec + self.b_dec
        # l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean(0)
        # l1_loss = self.l1_coeff * (acts.float().abs().sum())
        # loss = l2_loss + l1_loss
        # return loss, x_reconstruct, acts, l2_loss, l1_loss

    @torch.no_grad()
    def make_decoder_weights_and_grad_unit_norm(self):
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(
            -1, keepdim=True
        ) * W_dec_normed
        self.W_dec.grad -= W_dec_grad_proj
        # Bugfix(?) for ensuring W_dec retains unit norm, this was not there when I trained my original autoencoders.


Foward pass:

$$
\tilde{\mathbf{f}}(\mathbf{x}):=\underbrace{\mathbb{1}[\overbrace{\left(\mathbf{W}_{\text {gate }}\left(\mathbf{x}-\mathbf{b}_{\text {dec }}\right)+\mathbf{b}_{\text {gate }}\right)}^{\boldsymbol{\pi}_{\text {gate }}(\mathbf{x})}>\mathbf{0}]}_{\mathbf{f}_{\text {gate }}(\mathbf{x})} \odot \underbrace{\operatorname{ReLU}\left(\mathbf{W}_{\text {mag }}\left(\mathbf{x}-\mathbf{b}_{\text {dec }}\right)+\mathbf{b}_{\text {mag }}\right)}_{\mathbf{f}_{\text {mag }}(\mathbf{x})},
$$

Vector-valued rescaling parameter:

$$
\left(\mathbf{W}_{\text {mag }}\right)_{i j}:=\left(\exp \left(\mathbf{r}_{\text {mag }}\right)\right)_i \cdot\left(\mathbf{W}_{\text {gate }}\right)_{i j}
$$

Loss:

$$
\mathcal{L}_{\text {gated }}(\mathbf{x}):=\underbrace{\|\mathbf{x}-\hat{\mathbf{x}}(\tilde{\mathbf{f}}(\mathbf{x}))\|_2^2}_{\mathcal{L}_{\text {reconstruct }}}+\underbrace{\lambda\left\|\operatorname{ReLU}\left(\boldsymbol{\pi}_{\text {gate }}(\mathbf{x})\right)\right\|_1}_{\mathcal{L}_{\text {sparsity }}}+\underbrace{\left\|\mathbf{x}-\hat{\mathbf{x}}_{\text {frozen }}\left(\operatorname{ReLU}\left(\boldsymbol{\pi}_{\text {gate }}(\mathbf{x})\right)\right)\right\|_2^2}_{\mathcal{L}_{\text {aux }}}
$$

data

In [43]:
gated_encoder = GatedAutoEncoder(cfg)

In [44]:
model_dtype = DTYPES[cfg["enc_dtype"]]

In [45]:
gated_encoder(data.to(model_dtype))

tensor(1264.5316, device='cuda:0', grad_fn=<AddBackward0>)

In [46]:
num_batches = cfg["num_tokens"] // cfg["batch_size"]
encoder = gated_encoder

# model_num_batches = cfg["model_batch_size"] * num_batches
encoder_optim = torch.optim.Adam(
    encoder.parameters(), lr=cfg["lr"], betas=(cfg["beta1"], cfg["beta2"])
)
recons_scores = []
act_freq_scores_list = []
for i in range(num_batches):
    i = i % all_tokens.shape[0]

    acts = buffer.next().to(model_dtype)
    loss = encoder(acts)
    loss.backward()
    encoder.make_decoder_weights_and_grad_unit_norm()
    encoder_optim.step()
    encoder_optim.zero_grad()

    if (i) % 100 == 0:
        print(loss.item())
        # wandb.log(loss_dict)
        # print(loss_dict)

    del loss
    # if (i) % 1000 == 0:
    #     x = get_recons_loss(local_encoder=encoder)
    #     print("Reconstruction:", x)
    #     recons_scores.append(x[0])
    #     freqs = get_freqs(5, local_encoder=encoder)
    #     act_freq_scores_list.append(freqs)
    #     # histogram(freqs.log10(), marginal="box", histnorm="percent", title="Frequencies")
    #     wandb.log(
    #         {
    #             "recons_score": x[0],
    #             "dead": (freqs == 0).float().mean().item(),
    #             "below_1e-6": (freqs < 1e-6).float().mean().item(),
    #             "below_1e-5": (freqs < 1e-5).float().mean().item(),
    #         }
    #     )
    # if (i + 1) % 30000 == 0:
    #     encoder.save()
    #     # wandb.log({"reset_neurons": 0.0})
    #     freqs = get_freqs(50, local_encoder=encoder)
    #     to_be_reset = freqs < 10 ** (-5.5)
    #     print("Resetting neurons!", to_be_reset.sum())
    #     # re_init(to_be_reset, encoder)

  0%|          | 14/488281 [00:00<1:00:28, 134.56it/s]

1249.2452392578125


  0%|          | 119/488281 [00:02<2:56:48, 46.02it/s]

41.371585845947266


  0%|          | 119/488281 [00:03<3:38:20, 37.26it/s]


KeyboardInterrupt: 

In [30]:
gated_encoder.W_gate.shape

torch.Size([512, 16384])