In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

from transformer import TransformerConfig
from lm import LM

from data import Dataset
from misc import print_colore

In [2]:
class AutoEncoder(nn.Module):
    """Sparse autoencoder with one ReLU hidden layer and an L1 sparsity penalty.

    An input activation vector of width ``act_size`` is encoded into
    ``num_features`` non-negative feature activations and decoded back to
    width ``act_size``. The rows of the decoder matrix are scaled to unit
    L2 norm at construction time.

    Args:
        act_size: Width of the input (and reconstructed) vectors.
        num_features: Number of hidden features.
        l1_coeff: Multiplier applied to the L1 penalty on the feature activations.
    """

    def __init__(self, act_size, num_features, l1_coeff):
        super().__init__()

        self.l1_coeff = l1_coeff
        self.num_features = num_features

        self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(act_size, num_features)))
        self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(num_features, act_size)))
        self.b_enc = nn.Parameter(torch.zeros(num_features))
        self.b_dec = nn.Parameter(torch.zeros(act_size))

        # Scale every decoder row to unit L2 norm. Done in place under no_grad
        # so the initialisation is never recorded by autograd.
        with torch.no_grad():
            self.W_dec.div_(self.W_dec.norm(dim=-1, keepdim=True))

    def forward(self, x):
        """Encode ``x``, reconstruct it and compute the training losses.

        Args:
            x: Tensor whose last dimension has size ``act_size``.

        Returns:
            Tuple ``(loss, x_reconstruct, acts, l2_loss, l1_loss)``:
            ``x_reconstruct`` has the shape of ``x``; ``acts`` are the
            post-ReLU feature activations (last dimension ``num_features``);
            ``l2_loss`` is the squared error summed over the last dimension
            and averaged over dimension 0 only (any dimensions in between are
            kept); ``l1_loss`` is ``l1_coeff`` times the sum of ``|acts|``
            over *all* elements; ``loss`` is their sum.
        """
        # The decoder bias is subtracted before encoding and added back after decoding.
        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc + self.b_enc)
        x_reconstruct = acts @ self.W_dec + self.b_dec
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean(0)
        l1_loss = self.l1_coeff * (acts.float().abs().sum())
        loss = l2_loss + l1_loss
        return loss, x_reconstruct, acts, l2_loss, l1_loss

    @torch.no_grad()
    def make_decoder_weights_and_grad_unit_norm(self):
        """Renormalise decoder rows and strip the gradient component along them.

        The part of ``W_dec.grad`` parallel to each (normalised) decoder row is
        removed, then ``W_dec`` is replaced by its row-normalised version. If
        no gradient has been computed yet (``W_dec.grad is None``), only the
        renormalisation is performed.
        """
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        if self.W_dec.grad is not None:
            # Projection of each gradient row onto the corresponding unit row.
            W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed
            self.W_dec.grad -= W_dec_grad_proj
        # Bugfix(?) for ensuring W_dec retains unit norm, this was not there when I trained my original autoencoders.
        self.W_dec.data = W_dec_normed

In [3]:
dataset = Dataset()

config = TransformerConfig(d_model=128, n_layers=1, n_heads=4, max_len=dataset.max_len, dropout=0.)
model = LM(config, vocab_size=len(dataset.vocabulaire))
# map_location="cpu" lets the checkpoint load on a machine without the device
# it was saved from; the rest of this notebook only creates CPU tensors.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load("transformer.pth", map_location="cpu"))

# The SAE reads the transformer's d_model-wide activations and uses twice as
# many features as d_model. l1_coeff=3e-4 works well.
sae = AutoEncoder(act_size=config.d_model, num_features=2*config.d_model, l1_coeff=3e-4)
sae.load_state_dict(torch.load('sae.pth', map_location="cpu"))

<All keys matched successfully>

# interprétation neurones & features

In [12]:
# Number of strongest examples kept per neuron / feature (second dimension of
# the top_values / top_indices tables built below).
top_k = 20
# Number of training rows sent through the model per forward pass.
batch_size = 64

In [13]:
def update_top_k(top_values, top_indices, new_values, new_indices, k=20):
    """Merge new candidates into a running top-k list.

    Args:
        top_values: 1-D tensor with the current best values.
        top_indices: 1-D tensor with the identifier attached to each entry of
            ``top_values``.
        new_values: 1-D tensor of candidate values.
        new_indices: 1-D tensor; ``new_indices[j]`` is the identifier of
            ``new_values[j]``.
        k: Maximum number of entries to keep.

    Returns:
        ``(values, indices)``: the ``k`` largest values among the current and
        new ones, in descending order, with their identifiers. If fewer than
        ``k`` candidates exist in total, all of them are returned.
    """
    combined_values = torch.cat([top_values, new_values])
    combined_indices = torch.cat([top_indices, new_indices])

    # torch.topk raises if k exceeds the number of available elements.
    k = min(k, combined_values.shape[0])
    new_top_values, topk_indices = torch.topk(combined_values, k)
    new_top_indices = combined_indices[topk_indices]

    return new_top_values, new_top_indices

### neurones interprétables ?

In [6]:
# Running tables: for each of the d_model neurons, the top_k largest
# per-sequence peak activations seen so far and the training-set row index
# each one came from (-inf / -1 until filled).
top_values = torch.full((config.d_model, top_k), -float('inf'))
top_indices = torch.full((config.d_model, top_k), -1, dtype=torch.long)

# Inference only. Without no_grad the values written into top_values carry
# autograd history, so every batch's graph would stay reachable and memory
# would grow over the whole scan.
with torch.no_grad():
    for i in range(0, dataset.X_train.shape[0], batch_size):
        X = dataset.X_train[i:i+batch_size]
        act = model(X, act=True) # (B, L, d_model)
        max_act = act.max(dim=1).values # (B, d_model): peak over sequence positions

        # Row indices of this batch in X_train. Sized from X itself because the
        # final batch can hold fewer than batch_size rows.
        dim_indices = i + torch.arange(X.shape[0])

        for dim in range(config.d_model):
            top_values[dim], top_indices[dim] = update_top_k(
                top_values[dim], top_indices[dim], max_act[:, dim], dim_indices, k=top_k
            )

In [10]:
# Print the top examples of one neuron, coloured by its per-position activation.
neurone = 1
for i in top_indices[neurone]:
    sample = dataset.X_train[i.item()]

    # Decode the row to text, skipping token ids 0, 1 and 2
    # (presumably padding and start/end markers -- TODO confirm against Dataset).
    tokens = [p.item() for p in sample]
    ville = "".join(dataset.int_to_char[k] for k in tokens if k not in (0, 1, 2))

    act = model(sample.unsqueeze(0), act=True) # (1, L, d_model)
    neuron_trace = act[0, :, neurone].tolist()

    print_colore(ville, neuron_trace[:len(ville)])

# Author's notes (neuron index = pattern it responds to):
# 21 = "morville"
# 58 = "saint"
# 56 = first letter after a "-"
# 55 = "x" at the end of a word
# 1 = -vX <-> letter after "-v"

### features interprétables ?

In [14]:
# Running tables: for each SAE feature, the top_k largest per-sequence peak
# activations seen so far and the training-set row index each one came from
# (-inf / -1 until filled).
top_values = torch.full((sae.num_features, top_k), -float('inf'))
top_indices = torch.full((sae.num_features, top_k), -1, dtype=torch.long)

# Inference only. Without no_grad the values written into top_values carry
# autograd history, so every batch's graph would stay reachable and memory
# would grow over the whole scan.
with torch.no_grad():
    for i in range(0, dataset.X_train.shape[0], batch_size):
        X = dataset.X_train[i:i+batch_size]
        act = model(X, act=True) # (B, L, d_model)
        _, _, features, _, _ = sae(act) # (B, L, num_features)
        max_features = features.max(dim=1).values # (B, num_features): peak over sequence positions

        # Row indices of this batch in X_train. Sized from X itself because the
        # final batch can hold fewer than batch_size rows.
        dim_indices = i + torch.arange(X.shape[0])

        for dim in range(sae.num_features):
            top_values[dim], top_indices[dim] = update_top_k(
                top_values[dim], top_indices[dim], max_features[:, dim], dim_indices, k=top_k
            )

In [35]:
# Print the top examples of one SAE feature, coloured by its per-position activation.
feature = 150
for i in top_indices[feature]:
    sample = dataset.X_train[i.item()]

    # Decode the row to text, skipping token ids 0, 1 and 2
    # (presumably padding and start/end markers -- TODO confirm against Dataset).
    tokens = [p.item() for p in sample]
    ville = "".join(dataset.int_to_char[k] for k in tokens if k not in (0, 1, 2))

    act = model(sample.unsqueeze(0), act=True) # (1, L, d_model)
    _, _, features, _, _ = sae(act)
    feature_trace = features[0, :, feature].tolist()

    print_colore(ville, feature_trace[:len(ville)])

# Author's notes (feature index = pattern it responds to):
# 1 = first letter after a "-"
# 2 = "a" after "gr"
# 150 = letter after "cha"
