In [1]:
!pip install einops transformer_lens

Defaulting to user installation because normal site-packages is not writeable


In [2]:
import torch
from torch import Tensor, tensor, arange, randn, randint, tril, where, full_like, ones, allclose, empty, zeros
from torch.nn import Module, Linear, GELU, ReLU, Parameter, Embedding, ModuleList, MSELoss
from torch.nn.functional import softmax, cross_entropy
from torch.optim import AdamW 
from torch.utils.data import Dataset, DataLoader, TensorDataset
from datasets import load_dataset, DatasetDict
from transformer_lens import HookedTransformer
from dataclasses import dataclass
from typing import Optional, List, Dict, Callable
from einops import einsum
from tqdm import tqdm
import matplotlib.pyplot as plt
from os.path import isfile
from math import sqrt
from typing import Dict

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Notebook-wide default device: CUDA when available, otherwise CPU.
# Later cells read this module-level name directly (e.g. when creating tensors).
device = "cuda" if torch.cuda.is_available() else "cpu"
print("using", device)

using cpu


In [4]:
@dataclass
class TransformerConfig:
    """Hyperparameters read by the Transformer, TransformerBlock, Attention and MLP modules."""
    vocab_size: int  # rows of the token embedding and output width of the unembedding
    ncontext: int    # rows of the positional embedding table
    dmodel: int      # width of the residual stream
    dhead: int       # per-head width of the query/key/value projections
    nhead: int       # number of attention heads
    dmlp : int       # hidden width of the MLP
    nlayers: int     # number of TransformerBlocks
    # Applied between the MLP's two linear layers. Transformer.__init__ also accepts the
    # strings "gelu" / "relu" here and replaces them with the matching module.
    # The default GELU() instance is created once, when this class is defined.
    activation_function: Callable = GELU()

def normalize(x, dim=-1, eps=1e-5):
    """Rescale `x` to (approximately) unit root-mean-square along `dim`.

    No mean is subtracted and no learned gain or bias is applied. `eps` is added
    to the mean square before the square root so the division stays finite.
    """
    mean_square = x.pow(2).mean(dim=dim, keepdim=True)
    rms = (mean_square + eps).sqrt()
    return x / rms

@dataclass
class BlockActivations:
    """Container for the intermediate tensors of one TransformerBlock forward pass.

    Any field may be None when that checkpoint was not recorded.
    """
    mid:       Optional[Tensor] = None  # residual stream after adding the attention output
    post:      Optional[Tensor] = None  # residual stream after adding the MLP output
    attention: Optional[Tensor] = None  # output of the attention sub-layer
    mlp:       Optional[Tensor] = None  # output of the MLP sub-layer

class BlockAutoencoders:
    """Per-block hooks applied to each activation checkpoint; every default is the identity.

    A hook can be replaced on an instance by assigning a callable taking one tensor,
    e.g. `hooks.mlp = my_autoencoder`.
    """

    def mid(self, x):
        """Hook for the residual stream after attention (identity by default)."""
        return x

    def post(self, x):
        """Hook for the residual stream after the MLP (identity by default)."""
        return x

    def attention(self, x):
        """Hook for the attention sub-layer output (identity by default)."""
        return x

    def mlp(self, x):
        """Hook for the MLP sub-layer output (identity by default)."""
        return x

class MLP(Module):
    """Two-layer feed-forward network: dmodel -> dmlp -> dmodel."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # Creation order (up, then down) is kept so seeded initialisation is unchanged.
        self.up = Linear(cfg.dmodel, cfg.dmlp)
        self.down = Linear(cfg.dmlp, cfg.dmodel)

    def forward(self, x):
        """Project up, apply the configured activation, project back down."""
        hidden = self.cfg.activation_function(self.up(x))
        return self.down(hidden)

class Attention(Module):
    """Multi-head causal self-attention with explicit per-head parameter tensors.

    Query/key/value weights have shape (nhead, dmodel, dhead); the output weight
    has shape (nhead, dhead, dmodel). Biases are stored per head for q/k/v and
    per model dimension for the output.
    """

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg

        self.query_weight  = Parameter(randn(cfg.nhead, cfg.dmodel, cfg.dhead) / sqrt(cfg.dmodel))
        self.key_weight    = Parameter(randn(cfg.nhead, cfg.dmodel, cfg.dhead) / sqrt(cfg.dmodel))
        self.value_weight  = Parameter(randn(cfg.nhead, cfg.dmodel, cfg.dhead) / sqrt(cfg.dmodel))
        self.output_weight = Parameter(randn(cfg.nhead, cfg.dhead, cfg.dmodel) / sqrt(cfg.nhead * cfg.dhead))

        self.query_bias    = Parameter(randn(cfg.nhead, cfg.dhead) / sqrt(cfg.dmodel))
        self.key_bias      = Parameter(randn(cfg.nhead, cfg.dhead) / sqrt(cfg.dmodel))
        self.value_bias    = Parameter(randn(cfg.nhead, cfg.dhead) / sqrt(cfg.dmodel))
        self.output_bias   = Parameter(randn(cfg.dmodel)           / sqrt(cfg.nhead * cfg.dhead))

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor whose last two dimensions are (ncontext, dmodel); any leading
               dimensions are treated as batch dimensions.

        Returns:
            Tensor with the same shape as `x`.
        """
        ncontext = x.size(-2)

        query = einsum(x, self.query_weight, "... ncontext dmodel, nhead dmodel dhead -> ... ncontext nhead dhead")
        key   = einsum(x, self.key_weight,   "... ncontext dmodel, nhead dmodel dhead -> ... ncontext nhead dhead")
        value = einsum(x, self.value_weight, "... ncontext dmodel, nhead dmodel dhead -> ... ncontext nhead dhead")
        query = query + self.query_bias
        key   = key   + self.key_bias
        value = value + self.value_bias

        attention = einsum(
            key,
            query,
            "... ncontext_key nhead dhead, ... ncontext_query nhead dhead -> ... nhead ncontext_query ncontext_key"
        )
        attention = attention / sqrt(self.cfg.dhead)
        # Build the causal mask on the input's own device. Relying on the module-level
        # `device` global breaks whenever the model/input live on a different device.
        attention_mask = tril(ones((ncontext, ncontext), dtype=torch.bool, device=x.device))
        # Masked (future) positions get a large negative score so softmax sends them to ~0.
        attention = where(attention_mask, attention, full_like(attention, -1e5))
        attention = softmax(attention, dim=-1)

        output = einsum(
            attention,
            value,
            "... nhead ncontext_query ncontext_key, ... ncontext_key nhead dhead -> ... ncontext_query nhead dhead"
        )
        # Sum the per-head outputs back into the model dimension.
        result = einsum(output, self.output_weight, "... ncontext nhead dhead, nhead dhead dmodel -> ... ncontext dmodel")
        result = result + self.output_bias
        return result
    
class TransformerBlock(Module):
    """One pre-norm transformer block: attention then MLP, each added to the residual stream.

    Every intermediate tensor is passed through the matching hook of `autoencoders`
    before it is used further, so a hook's output replaces that activation downstream.
    """

    def __init__(self, cfg, autoencoders=None):
        """
        Args:
            cfg: configuration object forwarded to Attention and MLP.
            autoencoders: object exposing `attention`, `mid`, `mlp` and `post` callables.
                When None, a fresh BlockAutoencoders (identity hooks) is created for this
                block, so customising one block's hooks never leaks into another block.
        """
        super().__init__()
        self.cfg = cfg
        # A default of `BlockAutoencoders()` in the signature would be evaluated once
        # and shared by every block constructed without an explicit argument.
        self.autoencoders = BlockAutoencoders() if autoencoders is None else autoencoders

        self.attention = Attention(cfg)
        self.mlp = MLP(cfg)

    def forward(self, pre, return_activations=False):
        """Run the block on the residual stream `pre`.

        Returns:
            `post` when `return_activations` is False, otherwise the tuple
            `(post, BlockActivations)` holding the (hooked) intermediate tensors.
        """
        attention = self.attention(normalize(pre))
        attention = self.autoencoders.attention(attention)
        mid = pre + attention
        mid = self.autoencoders.mid(mid)
        mlp = self.mlp(normalize(mid))
        mlp = self.autoencoders.mlp(mlp)
        post = mid + mlp
        post = self.autoencoders.post(post)

        if return_activations:
            return post, BlockActivations(mid=mid, post=post, attention=attention, mlp=mlp)
        return post

class Transformer(Module):
    """Decoder-only transformer: token + positional embedding, `nlayers` blocks, unembedding."""

    def __init__(self, cfg, autoencoders=None, tokenizer=None):
        """
        Args:
            cfg: TransformerConfig. If `cfg.activation_function` is the string "gelu" or
                "relu" it is replaced in place by the corresponding module.
            autoencoders: optional list with one hook object per layer; defaults to
                identity hooks (BlockAutoencoders) for every layer.
            tokenizer: optional tokenizer used when `forward` receives a string.
        """
        super().__init__()
        if isinstance(cfg.activation_function, str):
            cfg.activation_function = {"gelu": GELU(), "relu": ReLU()}[cfg.activation_function]
        if autoencoders is None:
            autoencoders = [BlockAutoencoders() for _ in range(cfg.nlayers)]
        assert len(autoencoders) == cfg.nlayers

        self.cfg = cfg
        self.tokenizer = tokenizer

        self.embedding = Embedding(cfg.vocab_size, cfg.dmodel)
        self.positional_embedding = Embedding(cfg.ncontext, cfg.dmodel)
        self.blocks = ModuleList([TransformerBlock(cfg, autoencoders[i]) for i in range(cfg.nlayers)])
        self.unembedding = Linear(cfg.dmodel, cfg.vocab_size)

    def forward(self, x, return_activations=False, stop_at_layer=None):
        """Compute logits for a batch of token ids (or for a raw string).

        Args:
            x: integer token tensor whose last dimension is the sequence, or a `str`
                that is tokenized with `self.tokenizer`.
            return_activations: when True, also return one BlockActivations per executed layer.
            stop_at_layer: index of the last layer to execute (inclusive). Later layers
                are skipped; normalisation and unembedding are still applied to the
                residual stream reached at that point.

        Returns:
            The logits, or `(logits, activations)` when `return_activations` is True.
        """
        if isinstance(x, str):
            if self.tokenizer is None:
                raise ValueError("a tokenizer is required to run the model on a string")
            # The tokenizer returns a mapping; the embedding needs the ids as a tensor.
            x = tensor(self.tokenizer(x)["input_ids"], device=self.embedding.weight.device)

        x = self.embedding(x)
        ncontext = x.size(-2)
        # Create the position indices on the activations' device, not on a module-level global.
        x = x + self.positional_embedding(arange(ncontext, device=x.device))

        activations = []
        for layer, block in enumerate(self.blocks):
            if return_activations:
                x, block_activations = block(x, return_activations=True)
                activations.append(block_activations)
            else:
                x = block(x)

            # `>=` stops right after the requested layer; `>` would run one layer too many.
            if stop_at_layer is not None and layer >= stop_at_layer:
                break

        x = normalize(x)
        x = self.unembedding(x)

        if return_activations:
            return x, activations
        return x

    @staticmethod
    def from_pretrained(pretrained_model_name):
        """Build a Transformer carrying the weights of a pretrained HookedTransformer.

        The HookedTransformer's tokenizer is attached to the returned model.
        """
        theirs = HookedTransformer.from_pretrained(pretrained_model_name)

        # Size the vocabulary from the embedding matrix that is copied below, so the
        # copy cannot fail when the tokenizer reports a different vocabulary size.
        ours = Transformer(TransformerConfig( vocab_size=theirs.embed.W_E.size(0),
                                      ncontext=theirs.cfg.n_ctx,
                                      dmodel=theirs.cfg.d_model,
                                      dhead=theirs.cfg.d_head,
                                      nhead=theirs.cfg.n_heads,
                                      dmlp=theirs.cfg.d_mlp,
                                      nlayers=theirs.cfg.n_layers,
                                      activation_function=theirs.cfg.act_fn ))

        ours.tokenizer = theirs.tokenizer

        with torch.no_grad():
            ours.embedding.weight.copy_(theirs.embed.W_E)
            ours.positional_embedding.weight.copy_(theirs.pos_embed.W_pos)
            # torch Linear stores weights as (out, in), hence the transposes below.
            ours.unembedding.weight.copy_(theirs.unembed.W_U.transpose(0, 1))
            ours.unembedding.bias.copy_(theirs.unembed.b_U)

            for layer in range(ours.cfg.nlayers):
                our_block = ours.blocks[layer]
                their_block = theirs.blocks[layer]

                our_block.attention.query_weight.copy_(their_block.attn.W_Q)
                our_block.attention.key_weight.copy_(their_block.attn.W_K)
                our_block.attention.value_weight.copy_(their_block.attn.W_V)
                our_block.attention.output_weight.copy_(their_block.attn.W_O)

                our_block.attention.query_bias.copy_(their_block.attn.b_Q)
                our_block.attention.key_bias.copy_(their_block.attn.b_K)
                our_block.attention.value_bias.copy_(their_block.attn.b_V)
                our_block.attention.output_bias.copy_(their_block.attn.b_O)

                our_block.mlp.up.weight.copy_(their_block.mlp.W_in.transpose(0, 1))
                our_block.mlp.down.weight.copy_(their_block.mlp.W_out.transpose(0, 1))

                our_block.mlp.up.bias.copy_(their_block.mlp.b_in)
                our_block.mlp.down.bias.copy_(their_block.mlp.b_out)

        return ours

In [5]:
# Sanity check: the re-implementation must reproduce the reference model's logits.
model = Transformer.from_pretrained("gelu-1l")
model.eval()

# from_pretrained already loads a HookedTransformer internally; a second copy is
# loaded here to serve as the reference.
their_model = HookedTransformer.from_pretrained("gelu-1l")
their_model.eval()
# Random token ids, batch of 64 sequences of length 64.
# NOTE(review): `input` shadows the Python builtin of the same name from here on.
input = randint(their_model.tokenizer.vocab_size, (64, 64))
assert(allclose(their_model(input), model(input), atol=1e-4))

Loaded pretrained model gelu-1l into HookedTransformer
Loaded pretrained model gelu-1l into HookedTransformer


In [6]:
def train_val_test_split(dataset, val_size=0.1, test_size=0.1):
    """Split a dataset dict holding only a "train" split into train/val/test.

    Args:
        dataset: mapping whose only key is "train"; its value must provide
            `train_test_split(test_size=...)`.
        val_size: share of the data assigned to the "val" split.
        test_size: share of the data assigned to the "test" split.

    Returns:
        DatasetDict with the keys "train", "val" and "test".

    Raises:
        ValueError: if `val_size` or `test_size` is not positive.
    """
    assert set(dataset.keys()) == {"train"}
    if val_size <= 0 or test_size <= 0:
        raise ValueError("val_size and test_size must both be positive")

    holdout_size = val_size + test_size
    train_holdout = dataset["train"].train_test_split(test_size=holdout_size)
    # The "test" side of this second split becomes the final test set, so it must get
    # test_size's share of the holdout. Using val_size here would silently swap the
    # sizes of "val" and "test" whenever the two differ.
    val_test = train_holdout["test"].train_test_split(test_size=test_size / holdout_size)
    return DatasetDict({ "train": train_holdout["train"],
                         "val":   val_test["train"],
                         "test":  val_test["test"] })

def make_tokens_dataset(text_dataset, tokenizer, ncontext, _tqdm=True, max_size=None, save_to=None):
    """Tokenize texts into a TensorDataset of fixed-length token sequences.

    Args:
        text_dataset: iterable of records with a "text" entry.
        tokenizer: callable returning a mapping with an "input_ids" list for a text.
        ncontext: length every stored sequence is truncated to.
        _tqdm: show a progress bar over `text_dataset`.
        max_size: stop after collecting this many sequences (None = no limit).
        save_to: optional cache path; loaded from if it exists, written to otherwise.

    Returns:
        TensorDataset wrapping one int64 tensor of shape (num_sequences, ncontext).
    """
    if save_to is not None and isfile(save_to):
        print(f"Loading tokens dataset from file '{save_to}'.")
        # NOTE(review): torch.load unpickles the file and can execute arbitrary code;
        # only point `save_to` at cache files this notebook wrote itself.
        return torch.load(save_to)

    if max_size is not None:
        print("WARNING: tqdm doesn't work properly when max_size is not None") # we don't care because max_size is only temporary

    token_seqs = []
    for x in tqdm(text_dataset) if _tqdm else text_dataset:
        tokens = tokenizer(x["text"])["input_ids"]
        # Texts with at most `ncontext` tokens are skipped.
        # NOTE(review): this also drops texts of exactly `ncontext` tokens — confirm intended.
        if len(tokens) <= ncontext:
            continue
        token_seqs.append(tokens[:ncontext])

        if max_size is not None and len(token_seqs) >= max_size:
            break

    # Explicit dtype and shape keep the result a (0, ncontext) int64 tensor even when
    # no text was long enough (a bare tensor([]) would be 1-D float32).
    tokens_tensor = tensor(token_seqs, dtype=torch.long).reshape(len(token_seqs), ncontext)
    dataset = TensorDataset(tokens_tensor)

    if save_to is not None:
        torch.save(dataset, save_to)

    return dataset

class BlockActivationsDataset(Dataset):
    """Dataset over one block's stored activations, indexed along dimension 0."""

    def __init__(self, activations: BlockActivations):
        self.activations = activations

    def __len__(self):
        # The length comes from the first checkpoint (in field order) that was recorded.
        stored_tensors = (getattr(self.activations, name) for name in ["mid", "post", "attention", "mlp"])
        for stored in stored_tensors:
            if stored is not None:
                return stored.size(0)
        assert False

    def __getitem__(self, index):
        # Checkpoints that were not recorded stay None in the returned item.
        item = BlockActivations(None, None, None, None)
        for name in ["mid", "post", "attention", "mlp"]:
            stored = getattr(self.activations, name)
            if stored is None:
                continue
            setattr(item, name, stored[index])
        return item

class ActivationsDataset(Dataset):
    """Dataset of per-layer activations; each item maps layer index -> BlockActivations."""

    def __init__(self, activations: Dict[int, BlockActivations]):
        per_layer = {}
        for layer, layer_activations in activations.items():
            per_layer[layer] = BlockActivationsDataset(layer_activations)
        self.activations = per_layer

    def __len__(self):
        # The first layer's length is used as the dataset length.
        first_layer_dataset = next(iter(self.activations.values()))
        return len(first_layer_dataset)

    def __getitem__(self, index):
        item = {}
        for layer, layer_dataset in self.activations.items():
            item[layer] = layer_dataset[index]
        return item

def block_activations_collate_fn(activations: List[BlockActivations]):
    """Collate per-sample BlockActivations into one batched BlockActivations.

    Which checkpoints are present is decided from the first sample. Each present
    checkpoint is stacked along a new leading batch dimension; absent ones stay None.

    Stacking per checkpoint (instead of pre-allocating every buffer with the first
    checkpoint's shape and the default dtype/device) keeps each checkpoint's own
    shape, dtype and device.
    """
    checkpoints = [checkpoint for checkpoint in ["mid", "post", "attention", "mlp"] if getattr(activations[0], checkpoint) is not None]
    return BlockActivations(**{
        checkpoint: torch.stack([getattr(acts, checkpoint) for acts in activations])
        for checkpoint in checkpoints
    })

def activations_collate_fn(activations: List[Dict[int, BlockActivations]]):
    """Collate a list of {layer: BlockActivations} samples into one {layer: batched BlockActivations}.

    The set of layers is taken from the first sample.
    """
    batch = {}
    for layer in activations[0].keys():
        per_sample = [sample[layer] for sample in activations]
        batch[layer] = block_activations_collate_fn(per_sample)
    return batch

def make_activation_dataset(model, dataloader, layers, checkpoints, _tqdm=True, save_to=None):
    """Run `model` over `dataloader` and store the requested activations.

    Args:
        model: callable invoked as `model(data, return_activations=True, stop_at_layer=...)`
            that returns `(output, activations)`, with `activations` indexable by layer.
        dataloader: yields 1-tuples `(data,)`; `len(dataloader.dataset)` sizes the storage.
        layers: layer indices to record.
        checkpoints: BlockActivations field names to record (e.g. "mlp").
        _tqdm: show a progress bar over the batches.
        save_to: optional cache path; loaded from if it exists, written to otherwise.

    Returns:
        ActivationsDataset holding one row per input sample, in iteration order.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    if save_to is not None and isfile(save_to):
        print(f"Loading activations dataset from file '{save_to}'.")
        # NOTE(review): torch.load unpickles the file and can execute arbitrary code;
        # only point `save_to` at cache files this notebook wrote itself.
        return torch.load(save_to)

    all_activations = None
    i = 0
    # Without no_grad the in-place writes below would attach the storage tensors to the
    # autograd graph and keep every batch's graph alive for the whole pass.
    with torch.no_grad():
        for data, in tqdm(dataloader) if _tqdm else dataloader:
            _, activations = model(data, return_activations=True, stop_at_layer=max(layers))

            if all_activations is None:
                total = len(dataloader.dataset)
                # Size each buffer from its own checkpoint rather than always from `mid`.
                all_activations = {
                    layer: BlockActivations(**{
                        checkpoint: empty((total, *getattr(activations[layer], checkpoint).shape[1:]))
                        for checkpoint in checkpoints
                    })
                    for layer in layers
                }

            batch_size = activations[0].mid.size(0)
            for layer in layers:
                for checkpoint in checkpoints:
                    getattr(all_activations[layer], checkpoint)[i:i+batch_size] = getattr(activations[layer], checkpoint)

            i += batch_size

    if all_activations is None:
        raise ValueError("dataloader yielded no batches; cannot build an activations dataset")

    dataset = ActivationsDataset(all_activations)
    # Every storage row must have been written (fails e.g. if the loader drops samples).
    assert i == len(dataset)

    if save_to is not None:
        torch.save(dataset, save_to)

    return dataset

In [7]:
# Load the text corpus and carve validation/test splits out of its single "train" split.
dataset = load_dataset("maxtli/OpenWebText-2M")
# NOTE: `dataset` is rebound here from the raw load to the three-way split.
# The split is not seeded in this cell, so it may differ between runs — TODO confirm.
dataset = train_val_test_split(dataset)
train_dataset = dataset["train"]
val_dataset   = dataset["val"]
test_dataset  = dataset["test"]

In [8]:
# Tokenize up to 1,000 training texts into 250-token sequences (cached on disk).
tokens_dataset = make_tokens_dataset(train_dataset, model.tokenizer, ncontext=250, max_size=1_000, save_to="tokens_dataset.pickle")
# NOTE(review): with shuffle=True the activations below are stored in shuffled batch
# order, so row k of activations_dataset need not correspond to row k of tokens_dataset.
tokens_dataloader = DataLoader(tokens_dataset, batch_size=64, shuffle=True)
# Record the MLP output of layer 0 for every token sequence (cached on disk).
activations_dataset = make_activation_dataset(model, tokens_dataloader, layers=[0], checkpoints=["mlp"], save_to="activations_dataset.pickle")
# The custom collate_fn is needed because each item is a dict of BlockActivations, not a tensor.
activations_dataloader = DataLoader(activations_dataset, batch_size=64, collate_fn=activations_collate_fn)

Loading tokens dataset from file tokens_dataset.pickle.
Loading activations dataset from file activations_dataset.pickle.


In [None]:
def next_token_logits(model, seq):
    """Return the model output at the last sequence position (second-to-last dimension)."""
    logits = model(seq)
    return logits[..., -1, :]

In [None]:
def repetition_dataset(vocab_size, ncontext, size):
    """Build `size` random sequences, each consisting of a random half repeated twice.

    `ncontext` must be odd; every sequence has `ncontext + 1` tokens drawn from
    `[0, vocab_size)`, created on the notebook-level `device`.
    """
    assert ncontext % 2 == 1
    half_length = (ncontext + 1) // 2
    first_half = randint(vocab_size, (size, half_length), device=device)
    # Repeating along dim 1 makes the second half an exact copy of the first.
    return TensorDataset(first_half.repeat(1, 2))

In [None]:
def transformer_cross_entropy_loss(pred, true):
    """Cross-entropy between per-position logits and target token ids.

    `cross_entropy` expects the class dimension at index 1, so dimension 1 and the
    last dimension are swapped on both arguments before the call.
    """
    class_first_pred = pred.transpose(1, -1)
    class_first_true = true.transpose(1, -1)
    return cross_entropy(class_first_pred, class_first_true)

In [None]:
def train(model, dataloader, epochs, loss_fn=transformer_cross_entropy_loss, lr=1e-3, epoch_tqdm=True, batch_tqdm=False, plot_loss=True):
    """Train `model` for next-token prediction with AdamW.

    Each batch is a 1-tuple of token sequences: the model sees every token but the
    last and is scored against the same sequence shifted by one position.

    Args:
        model: module to optimise; it is put into training mode.
        dataloader: yields 1-tuples `(tokens,)`.
        epochs: number of passes over `dataloader`.
        loss_fn: called as `loss_fn(predictions, targets)`.
        lr: AdamW learning rate.
        epoch_tqdm / batch_tqdm: show a progress bar over epochs / batches.
        plot_loss: plot the per-step loss on a log scale when training ends.
    """
    model.train()
    optimizer = AdamW(model.parameters(), lr=lr)

    loss_history = []
    epoch_iterator = tqdm(range(epochs)) if epoch_tqdm else range(epochs)
    for _ in epoch_iterator:
        batch_iterator = tqdm(dataloader) if batch_tqdm else dataloader
        for (tokens,) in batch_iterator:
            inputs, targets = tokens[..., :-1], tokens[..., 1:]
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss_history.append(loss.item())
            loss.backward()
            optimizer.step()

    if not plot_loss:
        return

    plt.title("training_loss")
    plt.xlabel("training iteration")
    plt.ylabel("loss")
    plt.yscale("log")
    plt.plot(loss_history)
    plt.show()

In [None]:
# Small two-layer model trained on the synthetic repetition task.
# ncontext=17 is odd as repetition_dataset requires; each sample has 18 tokens, and
# train() feeds the first 17 as input and uses the last 17 as targets.
cfg = TransformerConfig(vocab_size=10, ncontext=17, dmodel=16, dhead=4, nhead=4, dmlp=32, nlayers=2)
train_dataloader = DataLoader(repetition_dataset(vocab_size=cfg.vocab_size, ncontext=cfg.ncontext, size=500_000), batch_size=64, shuffle=True)
# NOTE: this rebinds `model`, replacing the pretrained model loaded in an earlier cell.
model = Transformer(cfg).to(device)
train(model, train_dataloader, epochs=1, batch_tqdm=True, epoch_tqdm=False)

  0%|          | 6/7813 [00:00<02:24, 54.04it/s]

  7%|▋         | 552/7813 [00:05<01:18, 92.57it/s] 


KeyboardInterrupt: 

In [None]:
# Qualitative check of the trained model on a fresh batch of repetition sequences.
test_dataloader = DataLoader(repetition_dataset(vocab_size=cfg.vocab_size, ncontext=cfg.ncontext, size=1_000), batch_size=64, shuffle=True)
x, = next(iter(test_dataloader))
# NOTE(review): the result of this call is discarded — looks like a leftover smoke test.
model(x[0, :2])
print(x.shape)
# For the first sequence of the batch: predicted next tokens, then targets, then inputs.
print(model(x[..., :-1]).argmax(-1)[0, ...])
print(x[..., 1:][0, ...])
print(x[..., :-1][0, ...])
# Loss over the whole batch.
print(transformer_cross_entropy_loss(model(x[..., :-1]), x[..., 1:]))

torch.Size([64, 18])
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 8, 2, 5, 2, 1, 3, 1, 6])
tensor([8, 2, 5, 2, 1, 3, 1, 6, 1, 8, 2, 5, 2, 1, 3, 1, 6])
tensor([1, 8, 2, 5, 2, 1, 3, 1, 6, 1, 8, 2, 5, 2, 1, 3, 1])
tensor(1.0842, grad_fn=<NllLoss2DBackward0>)
