In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformer import TransformerConfig
from lm import LM

from data import Dataset

In [13]:
class AutoEncoder(nn.Module):
    """Sparse autoencoder with a ReLU encoder and a unit-norm linear decoder.

    Parameters
    ----------
    act_size : int
        Size of the last dimension of the activations to reconstruct.
    num_features : int
        Number of dictionary features (width of the hidden layer).
    l1_coeff : float
        Weight applied to the L1 sparsity penalty in ``forward``.

    Learnable tensors
    -----------------
    W_enc : (act_size, num_features)
    W_dec : (num_features, act_size), each row scaled to unit L2 norm at init
    b_enc : (num_features,)
    b_dec : (act_size,), subtracted before encoding and added back after decoding
    """

    def __init__(self, act_size, num_features, l1_coeff):
        super().__init__()

        self.l1_coeff = l1_coeff
        self.num_features = num_features

        self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(act_size, num_features)))
        self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(num_features, act_size)))
        self.b_enc = nn.Parameter(torch.zeros(num_features))
        self.b_dec = nn.Parameter(torch.zeros(act_size))

        # Start with every decoder row (one per feature) at unit L2 norm.
        with torch.no_grad():
            self.W_dec.copy_(self.W_dec / self.W_dec.norm(dim=-1, keepdim=True))

    def forward(self, x):
        """Encode and reconstruct ``x`` and compute the training loss.

        Returns
        -------
        tuple
            ``(loss, x_reconstruct, acts, l2_loss, l1_loss)`` where

            * ``l2_loss`` is the squared error summed over the last dimension
              and averaged over dimension 0;
            * ``l1_loss`` is ``l1_coeff`` times the sum of the absolute feature
              activations over *all* elements (it is not averaged);
            * ``loss`` is ``l2_loss + l1_loss``.
        """
        acts = self.encode(x)
        x_reconstruct = self.decode(acts)
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean(0)
        l1_loss = self.l1_coeff * (acts.float().abs().sum())
        loss = l2_loss + l1_loss
        return loss, x_reconstruct, acts, l2_loss, l1_loss

    def encode(self, act):
        """Return the ReLU feature activations for ``act`` (last dim ``act_size``)."""
        act_cent = act - self.b_dec
        features = F.relu(act_cent @ self.W_enc + self.b_enc)
        return features

    def decode(self, features):
        """Map feature activations (last dim ``num_features``) back to activation space."""
        act_reconstruct = features @ self.W_dec + self.b_dec
        return act_reconstruct

    @torch.no_grad()
    def make_decoder_weights_and_grad_unit_norm(self):
        """Renormalise the decoder rows and project their gradient accordingly.

        The component of ``W_dec.grad`` that is parallel to each decoder row is
        removed, then every row of ``W_dec`` is rescaled to unit L2 norm.

        If ``W_dec.grad`` is ``None`` (no backward pass has run yet, or the
        gradients were cleared with ``set_to_none=True``) only the weight
        renormalisation is performed instead of raising.
        """
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        if self.W_dec.grad is not None:
            W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed
            self.W_dec.grad -= W_dec_grad_proj
        # Bugfix(?) for ensuring W_dec retains unit norm, this was not there when I trained my original autoencoders.
        self.W_dec.data = W_dec_normed

In [14]:
# Build the character-level dataset; its vocabulary and max_len size the language model.
dataset = Dataset()

# One-layer transformer LM, restored from a pretrained checkpoint onto the CPU.
config = TransformerConfig(d_model=64, n_layers=1, n_heads=8, max_len=dataset.max_len, dropout=0.)
model = LM(config, vocab_size=len(dataset.vocabulaire))
model.load_state_dict(torch.load("transformer_d_model64_n_heads8_1_57_1_68.pth", map_location=torch.device('cpu')))

# Sparse autoencoder over the d_model-sized activations, with a 4x wider feature dictionary.
sae = AutoEncoder(act_size=config.d_model, num_features=4*config.d_model, l1_coeff=3e-4) # 3e-4 works well
# map_location mirrors the LM checkpoint above so a GPU-saved SAE checkpoint also loads on a CPU-only machine.
sae.load_state_dict(torch.load('sae_d_model64_e4.pth', map_location=torch.device('cpu')))

<All keys matched successfully>

In [238]:
def sample(self, prompt = "", g = torch.Generator(), feature_idx = 187, feature_value = 1):
    """Sample one string from the language model while steering one SAE feature.

    At every step the model activations are encoded by the module-level ``sae``,
    the chosen feature is overwritten at every position, and the steered
    reconstruction (plus the SAE's own reconstruction error, so that only the
    steered feature changes the activations) is fed through ``self.out_norm``
    and ``self.lm_head`` to obtain next-token probabilities.

    Relies on the notebook globals ``dataset``, ``sae`` and ``config``.

    Parameters
    ----------
    self : the language model; must be callable as ``self(idx, act=True)`` and
        expose ``out_norm`` and ``lm_head``.
    prompt : str
        Characters the sample must start with (each must be a key of
        ``dataset.char_to_int``).
    g : torch.Generator
        Random generator used for sampling. The default is created once, when
        the function is defined, and is therefore shared by all calls that do
        not pass ``g``; this is what makes successive default calls differ.
    feature_idx : int or None
        Index of the SAE feature to overwrite. ``None`` disables steering.
    feature_value : float
        Value written into that feature at every position.

    Returns
    -------
    str
        The prompt followed by the sampled characters, without ``<SOS>`` and
        without the terminating ``<EOS>``. When generation stops because the
        sequence grew past ``config.max_len`` there is no ``<EOS>`` to strip,
        so every sampled character is kept.
    """
    sos_id = dataset.char_to_int['<SOS>']
    eos_id = dataset.char_to_int['<EOS>']

    idx = torch.tensor([[sos_id] + [dataset.char_to_int[c] for c in prompt]], dtype=torch.long)
    next_id = -1

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        while next_id != eos_id:
            act = self(idx, act=True) # (1, l, d_model)

            # SAE
            features = sae.encode(act) # (1, l, num_features)
            act_reconstruct_1 = sae.decode(features) # (1, l, d_model) # unmodified reconstruction

            # Steer on a copy so the unmodified features stay intact.
            steered = features.clone()
            if feature_idx is not None:
                steered[:, :, feature_idx] = feature_value
            act_reconstruct_2 = sae.decode(steered) # reconstruction with the steered feature

            # Add the SAE reconstruction error back in.
            error = act - act_reconstruct_1
            final_act = act_reconstruct_2 + error

            x = self.out_norm(final_act)
            logits = self.lm_head(x)

            probs = F.softmax(logits[:, -1, :], dim=-1)
            next_id = torch.multinomial(probs, num_samples=1, generator=g).item()
            idx = torch.cat([idx, torch.tensor(next_id).view(1, 1)], dim=1)

            if idx.shape[1] > config.max_len:
                break

    # Drop <SOS>; drop the last token only when it really is <EOS>.
    tokens = idx[0, 1:].tolist()
    if tokens and tokens[-1] == eos_id:
        tokens.pop()

    return "".join(dataset.int_to_char[t] for t in tokens)

In [241]:
# Draw and print ten samples from the steered model, each starting from an empty prompt.
for sample_no in range(10):
    generated = sample(model, prompt="")
    print(generated)

cavigny
courceiry-enginny
saint-cirgigni
candriville
attigny
flincy
mirigny-lès-saint-giigoumin
la guimic
troin
chigicorce-den-lècrecotte


In [None]:
# Interesting concept manipulations (SAE feature index: value set -> observed effect):

# 142 (fires after a "saint"): 0.5 -> town names containing "-"
#                              3 -> nothing but hyphens / "e"

# 187 (no clear meaning, fires after an "a"): 2 -> generates endings in "y" (courcy, vabigny, ...)

# 203 ("-" after an "es"): 1 or more -> lots of hyphens
# 214 (first letter, at the start or after a "-"): 2 -> no hyphens at all