In [1]:
!pip install einops transformers==4.35.2 more_itertools

Defaulting to user installation because normal site-packages is not writeable
Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting transformers==4.35.2
  Downloading transformers-4.35.2-py3-none-any.whl (7.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.9/7.9 MB[0m [31m102.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting huggingface-hub<1.0,>=0.16.4 (from transformers==4.35.2)
  Downloading huggingface_hub-0.19.4-py3-none-any.whl (311 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m311.7/311.7 kB[0m [31m47.2 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers<0.19,>=0.14 (from transformers==4.35.2)
  Downloading tokenizers-0.15.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31

In [2]:
import torch
from torch import Tensor, tensor, arange, randn, randint, tril, where, ones, allclose, empty, zeros, inference_mode, Storage, FloatTensor
from torch.nn import Module, Linear, GELU, ReLU, Parameter, Embedding, ModuleList, LayerNorm, MSELoss, KLDivLoss
from torch.nn.functional import softmax, cross_entropy
from torch.nn.init import zeros_
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader, TensorDataset, Subset
from datasets import load_dataset, DatasetDict
# from transformer_lens import HookedTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
import pickle
from dataclasses import dataclass
from copy import copy
from typing import Optional, Tuple, List, Dict, Callable, Iterable, Any
from einops import einsum
from tqdm import tqdm
import matplotlib.pyplot as plt
from os.path import isfile
from math import sqrt, pi, prod
from more_itertools import pairwise

In [3]:
# Pick the compute device once for the whole notebook: the GPU when PyTorch can see one, otherwise the CPU.
# Later cells read this module-level `device` name.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("using", device)

using cuda


In [4]:
@dataclass
class TransformerConfig:
    """Hyperparameters shared by every module of the transformer.

    Attributes:
        vocab_size: number of rows of the token embedding / size of the logits.
        ncontext: number of rows of the positional embedding (maximum sequence length).
        dmodel: width of the residual stream.
        dhead: width of each attention head.
        nhead: number of attention heads.
        dmlp: width of the MLP hidden layer.
        nlayers: number of transformer blocks.
        activation_function: callable applied between the two MLP linear layers.
        mask_value: value written into attention scores at masked (future) positions
            *before* the softmax. It has to be a very large negative number so that
            masked positions receive (numerically) zero attention weight.
        attention_scale: factor the attention scores are multiplied by. When left as
            ``None`` it is set to ``1 / sqrt(dhead)`` in ``__post_init__``.
    """
    vocab_size: int
    ncontext: int
    dmodel: int
    dhead: int
    nhead: int
    dmlp : int
    nlayers: int
    activation_function: Callable = GELU()
    # The previous default (1e-5) is a small *positive* score, which does not mask anything:
    # future positions would still receive roughly uniform attention weight.
    mask_value: float = torch.finfo(torch.float).min
    attention_scale: Optional[float] = None

    def __post_init__(self):
        # Fall back to the standard scaled-dot-product factor when no scale was given.
        if self.attention_scale is None:
            self.attention_scale = 1 / sqrt(self.dhead)

# def normalize(x, dim=-1, eps=1e-5):
#     return x / (x.pow(2).mean(dim=dim, keepdim=True) + eps).sqrt()

# copy pasted from https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py
class NewGELUActivation(Module):
    """Tanh-based approximation of GELU (the "gelu_new" activation).

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))`` element-wise.
    """

    def forward(self, input: Tensor) -> Tensor:
        cubed = torch.pow(input, 3.0)
        tanh_argument = sqrt(2.0 / pi) * (input + 0.044715 * cubed)
        return 0.5 * input * (1.0 + torch.tanh(tanh_argument))

class MLP(Module):
    """Two-layer feed-forward network: dmodel -> dmlp -> dmodel.

    The non-linearity between the two linear layers is read from
    ``cfg.activation_function`` at call time.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.up = Linear(cfg.dmodel, cfg.dmlp)
        self.down = Linear(cfg.dmlp, cfg.dmodel)

    def forward(self, x):
        hidden = self.cfg.activation_function(self.up(x))
        return self.down(hidden)

class Attention(Module):
    """Multi-head causal self-attention with per-head weight tensors.

    Weights are stored per head (``nhead`` leading axis) and contracted with
    einops ``einsum``. Positions may only attend to themselves and to earlier
    positions; later positions are overwritten with ``cfg.mask_value`` before
    the softmax.
    """

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg

        self.query_weight  = Parameter(randn(cfg.nhead, cfg.dmodel, cfg.dhead) / sqrt(cfg.dmodel))
        self.key_weight    = Parameter(randn(cfg.nhead, cfg.dmodel, cfg.dhead) / sqrt(cfg.dmodel))
        self.value_weight  = Parameter(randn(cfg.nhead, cfg.dmodel, cfg.dhead) / sqrt(cfg.dmodel))
        self.output_weight = Parameter(randn(cfg.nhead, cfg.dhead, cfg.dmodel) / sqrt(cfg.nhead * cfg.dhead))

        self.query_bias    = Parameter(randn(cfg.nhead, cfg.dhead) / sqrt(cfg.dmodel))
        self.key_bias      = Parameter(randn(cfg.nhead, cfg.dhead) / sqrt(cfg.dmodel))
        self.value_bias    = Parameter(randn(cfg.nhead, cfg.dhead) / sqrt(cfg.dmodel))
        self.output_bias   = Parameter(randn(cfg.dmodel)           / sqrt(cfg.nhead * cfg.dhead))

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor whose last two axes are (ncontext, dmodel); any leading
               axes are treated as batch axes.

        Returns:
            Tensor with the same leading axes and last two axes (ncontext, dmodel).
        """
        ncontext = x.size(-2)

        query = einsum(x, self.query_weight, "... ncontext dmodel, nhead dmodel dhead -> ... ncontext nhead dhead")
        key   = einsum(x, self.key_weight,   "... ncontext dmodel, nhead dmodel dhead -> ... ncontext nhead dhead")
        value = einsum(x, self.value_weight, "... ncontext dmodel, nhead dmodel dhead -> ... ncontext nhead dhead")
        query = query + self.query_bias
        key   = key   + self.key_bias
        value = value + self.value_bias

        attention = einsum(
            key,
            query,
            "... ncontext_key nhead dhead, ... ncontext_query nhead dhead -> ... nhead ncontext_query ncontext_key"
        )
        attention = self.cfg.attention_scale * attention

        # Build the mask and the fill value on the *input's* device rather than on the
        # notebook-level `device`, so the module works wherever its input lives.
        attention_mask = tril(ones((ncontext, ncontext), dtype=torch.bool, device=x.device))
        mask_value = tensor(self.cfg.mask_value, dtype=attention.dtype, device=x.device)
        # Lower triangle (key position <= query position) keeps its score; the rest is masked.
        attention = where(attention_mask, attention, mask_value)
        attention = softmax(attention, dim=-1)

        output = einsum(
            attention,
            value,
            "... nhead ncontext_query ncontext_key, ... ncontext_key nhead dhead -> ... ncontext_query nhead dhead"
        )
        # Project every head back to dmodel and sum the heads' contributions.
        result = einsum(output, self.output_weight, "... ncontext nhead dhead, nhead dhead dmodel -> ... ncontext dmodel")
        result = result + self.output_bias
        return result
    
@dataclass
class BlockOutput:
    """Result of one ``TransformerBlock.forward`` call."""
    # Residual stream after the block.
    output: Tensor
    # Intermediate activations keyed by checkpoint name, or None when not requested.
    activations: Optional[Dict[str, Tensor]] = None
    # Second value returned by each autoencoder, keyed by checkpoint name, or None when not requested.
    autoencoder_activations: Optional[Dict[str, Tensor]] = None

class TransformerBlock(Module):
    """One pre-LayerNorm transformer block: attention sub-layer followed by MLP sub-layer.

    Five points of the computation ("checkpoints") can be recorded and/or replaced
    by an autoencoder's reconstruction:

        pre        residual stream entering the block
        attention  output of the attention sub-layer
        mid        residual stream after adding the attention output
        mlp        output of the MLP sub-layer
        post       residual stream leaving the block
    """

    CHECKPOINTS = ("pre", "attention", "mid", "mlp", "post")

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.attention_layer_norm = LayerNorm(cfg.dmodel)
        self.mlp_layer_norm = LayerNorm(cfg.dmodel)

        self.attention = Attention(cfg)
        self.mlp = MLP(cfg)

    def forward(self, pre, return_activations=False, return_autoencoder_activations=False, autoencoders=None):
        """Run the block.

        Args:
            pre: residual stream entering the block.
            return_activations: when True, ``BlockOutput.activations`` holds the
                tensor at every checkpoint (after any autoencoder substitution).
            return_autoencoder_activations: when True,
                ``BlockOutput.autoencoder_activations`` holds, for every checkpoint
                that has an autoencoder, the second value that autoencoder returned.
            autoencoders: optional mapping checkpoint name -> callable. Each callable
                receives the activation at its checkpoint and must return a pair
                ``(replacement, hidden)``; ``replacement`` is used in place of the
                activation for the rest of the forward pass.

        Returns:
            BlockOutput
        """
        # A fresh dict per call instead of a shared mutable default argument.
        autoencoders = {} if autoencoders is None else autoencoders
        assert set(autoencoders.keys()) <= set(self.CHECKPOINTS)

        autoencoder_outputs = dict()

        def substitute(checkpoint, activation):
            # Replace `activation` by its autoencoder reconstruction, if one is registered.
            if checkpoint in autoencoders:
                activation, autoencoder_outputs[checkpoint] = autoencoders[checkpoint](activation)
            return activation

        pre = substitute("pre", pre)

        attention = self.attention(self.attention_layer_norm(pre))
        attention = substitute("attention", attention)

        mid = pre + attention
        mid = substitute("mid", mid)

        mlp = self.mlp(self.mlp_layer_norm(mid))
        mlp = substitute("mlp", mlp)

        post = mid + mlp
        post = substitute("post", post)

        activations = {"pre": pre, "attention": attention, "mid": mid, "mlp": mlp, "post": post}

        return BlockOutput(
            output=                  post,
            activations=             activations if return_activations else None,
            autoencoder_activations= autoencoder_outputs if return_autoencoder_activations else None
        )
    
@dataclass
class TransformerOutput:
    """Result of one ``Transformer.forward`` call."""
    # Output of the unembedding layer.
    logits: Tensor
    # Recorded activations keyed by (layer, checkpoint name), or None when not requested.
    activations: Optional[Dict[Tuple[int, str], Tensor]] = None
    # Recorded autoencoder activations keyed by (layer, checkpoint name), or None when not requested.
    autoencoder_activations: Optional[Dict[Tuple[int, str], Tensor]] = None

class Transformer(Module):
    """Decoder-only transformer with hooks for recording and replacing activations.

    The model is: token embedding + positional embedding -> ``cfg.nlayers``
    ``TransformerBlock``s -> final LayerNorm -> linear unembedding.
    """

    def __init__(self, cfg, tokenizer=None):
        super().__init__()
        # Allow the activation function to be given by name (as HuggingFace configs do).
        if type(cfg.activation_function) == str:
            cfg.activation_function = {"gelu": GELU(), "relu": ReLU(), "gelu_new": NewGELUActivation()}[cfg.activation_function]

        self.cfg = cfg
        self.tokenizer = tokenizer

        self.embedding = Embedding(cfg.vocab_size, cfg.dmodel)
        self.positional_embedding = Embedding(cfg.ncontext, cfg.dmodel)
        self.blocks = ModuleList([TransformerBlock(cfg) for _ in range(cfg.nlayers)])
        self.unembedding = Linear(cfg.dmodel, cfg.vocab_size)
        self.final_layer_norm = LayerNorm(cfg.dmodel)

    def forward( self,
                 x,
                 return_activations=False,
                 return_autoencoder_activations=False,
                 stop_at_layer=None,
                 autoencoders: Optional[Dict[Tuple[int, str], Callable]] = None ):
        """Compute logits, optionally recording or replacing intermediate activations.

        Args:
            x: tensor of token ids whose last axis is the sequence position, or a
                string that is tokenized with ``self.tokenizer``.
            return_activations: when True, ``TransformerOutput.activations`` maps
                ``(layer, checkpoint)`` to the activation recorded there.
            return_autoencoder_activations: when True,
                ``TransformerOutput.autoencoder_activations`` maps
                ``(layer, checkpoint)`` to the second value returned by the
                autoencoder registered at that checkpoint.
            stop_at_layer: when not None, only ``self.blocks[:stop_at_layer]`` are
                run. The final LayerNorm and unembedding are still applied to the
                result, so the logits are only meaningful for a full run.
            autoencoders: optional mapping ``(layer, checkpoint) -> callable``;
                each callable receives the activation at that checkpoint and
                returns ``(replacement, hidden)``.

        Returns:
            TransformerOutput
        """
        # A fresh dict per call instead of a shared mutable default argument.
        autoencoders = {} if autoencoders is None else autoencoders

        assert all( layer in range(self.cfg.nlayers) and checkpoint in ["pre", "mid", "post", "mlp", "attention"]
                    for layer, checkpoint in autoencoders.keys() )

        if isinstance(x, str):
            # The tokenizer returns a mapping, not a tensor; take the ids as a tensor
            # and place them next to the embedding weights.
            # NOTE(review): assumes a HuggingFace-style tokenizer accepting return_tensors="pt".
            x = self.tokenizer(x, return_tensors="pt")["input_ids"].to(self.embedding.weight.device)

        x = self.embedding(x)
        ncontext = x.size(-2)
        x = x + self.positional_embedding(arange(ncontext, device=x.device))

        activations = dict() if return_activations else None
        autoencoder_activations = dict() if return_autoencoder_activations else None

        blocks = self.blocks if stop_at_layer is None else self.blocks[:stop_at_layer]
        for layer, block in enumerate(blocks):
            block_input = x

            # Every checkpoint of this layer (including "pre") is handled inside the block,
            # so each autoencoder is applied exactly once.
            autoencoders_on_layer = { checkpoint: autoencoder
                                      for (layer_, checkpoint), autoencoder in autoencoders.items()
                                      if layer_ == layer }
            output = block( block_input,
                            return_activations=             return_activations,
                            return_autoencoder_activations= return_autoencoder_activations,
                            autoencoders=                   autoencoders_on_layer )

            # Feed this block's result into the next block.
            x = output.output

            if return_activations:
                for checkpoint, activation in output.activations.items():
                    activations[(layer, checkpoint)] = activation
                # Make sure the block's input is always available, whatever the block reported.
                activations.setdefault((layer, "pre"), block_input)

            if return_autoencoder_activations:
                for checkpoint, activation in output.autoencoder_activations.items():
                    autoencoder_activations[(layer, checkpoint)] = activation

        x = self.final_layer_norm(x)
        x = self.unembedding(x)

        return TransformerOutput(logits=x, activations=activations, autoencoder_activations=autoencoder_activations)

    @staticmethod
    def from_pretrained(pretrained_model_name, test=True, test_atol=1e-4):
        """Build a ``Transformer`` and copy into it the weights of a HuggingFace model.

        The attribute names read from the library model (``transformer.wte``,
        ``transformer.h[i].attn.attention.q_proj``, ``mlp.c_fc``, ...) and the
        configuration fields (``num_heads``, ``num_layers``, ...) are those of a
        GPT-Neo style checkpoint.

        NOTE(review): GPT-Neo checkpoints may configure local (windowed) attention on
        some layers, whereas this implementation always uses full causal attention —
        confirm the two agree for the sequence lengths you use.

        Args:
            pretrained_model_name: name passed to ``AutoModelForCausalLM.from_pretrained``.
            test: when True, compare our logits with the library model's on random tokens.
            test_atol: absolute tolerance of that comparison.

        Returns:
            The populated ``Transformer`` (moved to the notebook-level ``device``).

        Raises:
            AssertionError: if ``test`` is True and the logits differ by more than ``test_atol``.
        """
        theirs = AutoModelForCausalLM.from_pretrained(pretrained_model_name)  # stays on the CPU
        # The tokenizer is always the GPT-Neo one, independently of `pretrained_model_name`.
        tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")

        # Size the embedding from the model's own config so that the weight copy below
        # cannot fail on a tokenizer/model vocabulary-size mismatch.
        ours = Transformer(TransformerConfig( vocab_size=          theirs.config.vocab_size,
                                              ncontext=            theirs.config.max_position_embeddings,
                                              dmodel=              theirs.config.hidden_size,
                                              dhead=               theirs.config.hidden_size // theirs.config.num_heads,
                                              nhead=               theirs.config.num_heads,
                                              dmlp=                4 * theirs.config.hidden_size,
                                              nlayers=             theirs.config.num_layers,
                                              activation_function= theirs.config.activation_function,
                                              attention_scale=     1.0,
                                              mask_value=          torch.finfo(torch.float).min )).to(device)

        ours.tokenizer = tokenizer

        with inference_mode():
            ours.embedding.weight.copy_(theirs.transformer.wte.weight)
            ours.positional_embedding.weight.copy_(theirs.transformer.wpe.weight)
            # The unembedding reuses the token-embedding matrix and has no bias.
            ours.unembedding.weight.copy_(theirs.transformer.wte.weight)
            zeros_(ours.unembedding.bias)
            ours.final_layer_norm.weight.copy_(theirs.transformer.ln_f.weight)
            ours.final_layer_norm.bias.copy_(theirs.transformer.ln_f.bias)

            for layer in range(ours.cfg.nlayers):
                our_block = ours.blocks[layer]
                their_block = theirs.transformer.h[layer]
                their_attention = their_block.attn.attention
                nhead, dhead, dmodel = ours.cfg.nhead, ours.cfg.dhead, ours.cfg.dmodel

                our_block.attention_layer_norm.weight.copy_(their_block.ln_1.weight)
                our_block.attention_layer_norm.bias.copy_(their_block.ln_1.bias)
                our_block.mlp_layer_norm.weight.copy_(their_block.ln_2.weight)
                our_block.mlp_layer_norm.bias.copy_(their_block.ln_2.bias)

                # Library projections are (nhead * dhead, dmodel); ours are (nhead, dmodel, dhead).
                our_block.attention.query_weight.copy_(their_attention.q_proj.weight.reshape(nhead, dhead, dmodel).permute(0, 2, 1))
                our_block.attention.key_weight.copy_(their_attention.k_proj.weight.reshape(nhead, dhead, dmodel).permute(0, 2, 1))
                our_block.attention.value_weight.copy_(their_attention.v_proj.weight.reshape(nhead, dhead, dmodel).permute(0, 2, 1))
                # Library output projection is (dmodel, nhead * dhead); ours is (nhead, dhead, dmodel).
                our_block.attention.output_weight.copy_(their_attention.out_proj.weight.reshape(dmodel, nhead, dhead).permute(1, 2, 0))

                # No query/key/value biases are copied from the library model, so ours are zeroed.
                zeros_(our_block.attention.query_bias)
                zeros_(our_block.attention.key_bias)
                zeros_(our_block.attention.value_bias)
                our_block.attention.output_bias.copy_(their_attention.out_proj.bias)

                our_block.mlp.up.weight.copy_(their_block.mlp.c_fc.weight)
                our_block.mlp.down.weight.copy_(their_block.mlp.c_proj.weight)

                our_block.mlp.up.bias.copy_(their_block.mlp.c_fc.bias)
                our_block.mlp.down.bias.copy_(their_block.mlp.c_proj.bias)

        if test:
            with inference_mode():
                print("Testing that the model behaves the same as the library model... ", end="", flush=True)
                # `theirs` lives on the CPU and `ours` on `device`: give each model inputs on
                # its own device and compare the logits on the CPU.
                inputs = randint(0, ours.cfg.vocab_size, (64, 64))
                our_logits = ours(inputs.to(device)).logits.cpu()
                their_logits = theirs(inputs).logits
                assert allclose(our_logits, their_logits, atol=test_atol), "Tests failed!"
                print("Test passed!")

        return ours

ValueError: Asking to pad but the tokenizer does not have a padding token. Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`.

In [5]:
# Load the TinyStories-1M weights into our own Transformer implementation.
# test=False skips the logits comparison against the library model.
model = Transformer.from_pretrained("roneneldan/TinyStories-1M", test=False).to(device)
# Switch to evaluation mode; as the last expression this also displays the module tree.
model.eval()

config.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/48.6M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/560 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/357 [00:00<?, ?B/s]

Transformer(
  (embedding): Embedding(50257, 64)
  (positional_embedding): Embedding(2048, 64)
  (blocks): ModuleList(
    (0): TransformerBlock(
      (attention_layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mlp_layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (attention): Attention()
      (mlp): MLP(
        (up): Linear(in_features=64, out_features=256, bias=True)
        (down): Linear(in_features=256, out_features=64, bias=True)
      )
    )
    (1): TransformerBlock(
      (attention_layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mlp_layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (attention): Attention()
      (mlp): MLP(
        (up): Linear(in_features=64, out_features=256, bias=True)
        (down): Linear(in_features=256, out_features=64, bias=True)
      )
    )
    (2): TransformerBlock(
      (attention_layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
   

In [None]:
# def train_val_test_split(dataset, val_size=0.1, test_size=0.1):
#     dataset = dataset.train_test_split(test_size=val_size+test_size)
#     val_test_dataset = dataset["test"].train_test_split(test_size = val_size / (val_size + test_size))
#     return DatasetDict({ "train": dataset["train"],
#                          "val":   val_test_dataset["train"],
#                          "test":  val_test_dataset["test"] })

def make_tokens_dataset(text_dataset, tokenizer, ncontext, _tqdm=True, max_size=None, save_to=None):
    """Tokenize texts into fixed-length token sequences.

    Every text with at least ``ncontext`` tokens contributes its first
    ``ncontext`` tokens; shorter texts are skipped.

    Args:
        text_dataset: iterable of mappings with a ``"text"`` entry.
        tokenizer: callable returning a mapping with an ``"input_ids"`` list.
        ncontext: length of every returned sequence.
        _tqdm: show a progress bar over ``text_dataset``.
        max_size: stop once this many sequences were collected (None = no limit).
        save_to: optional cache file. If it exists the dataset is loaded from it
            and nothing is tokenized; otherwise the new dataset is saved there.

    Returns:
        TensorDataset holding one integer tensor of shape (num_sequences, ncontext).
    """
    if save_to is not None and isfile(save_to):
        print(f"Loading tokens dataset from file '{save_to}'.")
        return torch.load(save_to)

    if _tqdm and max_size is not None:
        print("WARNING: tqdm doesn't work properly when max_size is not None")

    token_seqs = []
    for x in tqdm(text_dataset) if _tqdm else text_dataset:
        tokens = tokenizer(x["text"])["input_ids"]
        # Too short to fill the context. A text of exactly `ncontext` tokens is long enough.
        if len(tokens) < ncontext:
            continue
        token_seqs.append(tokens[:ncontext])

        if max_size is not None and len(token_seqs) >= max_size:
            break

    # Explicit dtype and shape so that an empty result is still an integer (0, ncontext) tensor.
    dataset = TensorDataset(tensor(token_seqs, dtype=torch.long).reshape(-1, ncontext))

    if save_to is not None:
        torch.save(dataset, save_to)

    return dataset

def all_equal(xs):
    """Return True when every element of the iterable ``xs`` equals the first one.

    An empty iterable and a single-element iterable are both considered all-equal.
    ``xs`` may be any iterable, including a generator; it is consumed once.
    """
    iterator = iter(xs)
    try:
        first = next(iterator)
    except StopIteration:
        return True
    return all(element == first for element in iterator)

class DictTensorDataset(Dataset):
    """Dataset over several tensors that share their first dimension.

    Item ``i`` is a dict with the same keys as ``tensors`` whose values are the
    ``i``-th slice of each tensor.
    """

    def __init__(self, tensors: Dict[Any, Tensor]):
        # All tensors must agree on the number of datapoints (first dimension).
        sizes = {tensor.size(0) for tensor in tensors.values()}
        assert len(sizes) <= 1, "Size mismatch between tensors"
        self.tensors = tensors

    def __len__(self):
        if len(self.tensors) == 0:
            return 0
        return next(iter(self.tensors.values())).size(0)

    def __getitem__(self, index):
        return {key: tensor[index] for key, tensor in self.tensors.items()}

def dict_collate_fn(dicts: Iterable[Dict[Any, Tensor]]):
    """Collate dict samples into one dict of batched tensors.

    Args:
        dicts: samples that all have the same keys; for each key the tensors
            must have the same shape across samples.

    Returns:
        Dict with the keys of the first sample; each value stacks the samples'
        tensors along a new leading batch dimension, keeping their dtype and
        device. An empty input gives an empty dict.
    """
    samples = list(dicts)
    if len(samples) == 0:
        return {}

    return { key: torch.stack([sample[key] for sample in samples])
             for key in samples[0] }

def list_to_english(xs, comma=", ", and_=" and ", oxford_and=", and "):
    """Join strings the way a list is written in English.

    Examples:
        []              -> ""
        ["a"]           -> "a"
        ["a", "b"]      -> "a and b"
        ["a", "b", "c"] -> "a, b, and c"

    Args:
        xs: iterable of strings (a generator is accepted).
        comma: separator between all items but the last two.
        and_: separator used when there are exactly two items.
        oxford_and: separator before the last item when there are three or more.
    """
    xs = list(xs)  # accept generators, which have no len()
    if len(xs) == 0:
        return ""
    if len(xs) == 1:
        return xs[0]
    if len(xs) == 2:
        return xs[0] + and_ + xs[1]
    # Everything but the last item is comma-separated; the last one gets the "and".
    return comma.join(xs[:-1]) + oxford_and + xs[-1]

def make_activation_dataset(model, tokens_dataloader, checkpoints, on_disk=False, storage_directory=None, _tqdm=True):
    """Run ``model`` over a token dataset and collect activations at some checkpoints.

    Args:
        model: module whose call accepts ``return_activations`` / ``stop_at_layer``
            and returns an object with an ``activations`` dict keyed by
            ``(layer, checkpoint_name)``.
        tokens_dataloader: DataLoader whose batches are one-element sequences
            holding a token tensor (as produced from a ``TensorDataset``).
        checkpoints: iterable of ``(layer, checkpoint_name)`` pairs to record.
        on_disk: when True, activations live in memory-mapped files inside
            ``storage_directory`` and checkpoints recorded as finished by an
            earlier run are loaded instead of recomputed.
        storage_directory: directory of the activation files; required when
            ``on_disk`` is True.
        _tqdm: show a progress bar over the batches.

    Returns:
        DictTensorDataset mapping each checkpoint to a float tensor of shape
        (dataset_size, ncontext, dmodel), filled in the order the dataloader
        yields its batches.

    Raises:
        ValueError: if ``checkpoints`` is empty, or ``on_disk`` is True without a
            ``storage_directory``.
    """
    checkpoints = list(checkpoints)
    if len(checkpoints) == 0:
        raise ValueError("checkpoints must contain at least one (layer, checkpoint) pair")
    if on_disk and storage_directory is None:
        raise ValueError("storage_directory is required when on_disk=True")

    def storage_filename(checkpoint: Tuple[int, str]):
        assert storage_directory is not None
        return f"{storage_directory}/layer{checkpoint[0]}-{checkpoint[1]}.dat"
    already_computed_checkpoints_list_filename = f"{storage_directory}/already_computed_checkpoints_list.pickle"

    # Token batches are moved to wherever the model's parameters live.
    model_device = next(model.parameters()).device

    with inference_mode():
        dataset_size = len(tokens_dataloader.dataset)

        # One probe batch is enough to learn the activations' shape.
        probe_tokens, = next(iter(tokens_dataloader))
        probe = model(probe_tokens.to(model_device), return_activations=True).activations[checkpoints[0]]
        _, ncontext, dmodel = probe.shape
        shape = (dataset_size, ncontext, dmodel)

        already_computed_checkpoints = []
        if not on_disk:
            dataset_activations = {checkpoint: empty(shape) for checkpoint in checkpoints}
        else:
            # Same as the in-memory branch, but every tensor is backed by a file.

            if isfile(already_computed_checkpoints_list_filename):
                # The list is written by this function itself; do not point this at untrusted files
                # (unpickling can execute arbitrary code).
                with open(already_computed_checkpoints_list_filename, "rb") as file:
                    already_computed_checkpoints = list(pickle.load(file))

            if len(already_computed_checkpoints) > 0:
                print("Loading activations", list_to_english([f"'{checkpoint}'" for checkpoint in already_computed_checkpoints]), "from file.")

            checkpoints = [checkpoint for checkpoint in checkpoints if checkpoint not in already_computed_checkpoints]

            # shared=True makes writes go through to the file (and creates it if needed);
            # `size` is a number of elements, not bytes.
            dataset_activations = { checkpoint: torch.from_file( storage_filename(checkpoint),
                                                                 shared=True,
                                                                 size=prod(shape),
                                                                 dtype=torch.float32 ).reshape(shape)
                                    for checkpoint in checkpoints + already_computed_checkpoints }

            for checkpoint in copy(already_computed_checkpoints):
                if (dataset_activations[checkpoint][-1, :] == 0).all():
                    print( f"The activations file for '{checkpoint}' seems weird - the last datapoint is all zeros. " +
                           "The program was probably interrupted before it finished being computed. " +
                           "It will be recalculated from scratch." )
                    already_computed_checkpoints.remove(checkpoint)
                    checkpoints.append(checkpoint)

        if len(checkpoints) > 0:
            # No need to run the blocks after the deepest requested checkpoint.
            last_layer = max(layer for layer, _ in checkpoints)
            i = 0
            for tokens, in tqdm(tokens_dataloader) if _tqdm else tokens_dataloader:
                output = model(tokens.to(model_device), return_activations=True, stop_at_layer=1 + last_layer)
                batch_size = tokens.size(0)
                for checkpoint in checkpoints:
                    dataset_activations[checkpoint][i:i+batch_size] = output.activations[checkpoint].cpu()
                i += batch_size

        if on_disk:
            with open(already_computed_checkpoints_list_filename, "wb") as file:
                pickle.dump(list(set(checkpoints + already_computed_checkpoints)), file)

        return DictTensorDataset(dataset_activations)

In [9]:
# The published "validation" split is held out as our test set; the published "train" split is
# divided again into our own train and validation sets.
# NOTE(review): train_test_split() is called without a seed, so the split may differ between runs
# while the token datasets built from it are cached on disk — consider passing a fixed seed.
dataset           = load_dataset("roneneldan/TinyStories")
test_dataset      = dataset["validation"]
train_val_dataset = dataset["train"].train_test_split()
train_dataset     = train_val_dataset["train"]
val_dataset       = train_val_dataset["test"]

Using custom data configuration roneneldan--TinyStories-a62fc98e062666ca
Reusing dataset parquet (/home/paperspace/.cache/huggingface/datasets/roneneldan___parquet/roneneldan--TinyStories-a62fc98e062666ca/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/2 [00:00<?, ?it/s]

In [11]:
# Display the tokenizer attached to the model (shown as the cell output).
model.tokenizer

GPT2TokenizerFast(name_or_path='EleutherAI/gpt-neo-125M', vocab_size=50257, model_max_length=2048, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}

In [None]:
# Token datasets: fixed-length sequences of 250 tokens, cached on disk after the first run.
tokens_dataset = make_tokens_dataset(train_dataset, model.tokenizer, ncontext=250, save_to="/storage/tokens_dataset.pickle") # "/home/paperspace/tokens_dataset.pickle")
# tokens_dataset = Subset(tokens_dataset, range(100_000)) # not enough ram :( and waiting for access to a machine with a big disk
tokens_dataloader = DataLoader(tokens_dataset, batch_size=64, shuffle=True)

val_tokens_dataset = make_tokens_dataset(val_dataset, model.tokenizer, ncontext=250, save_to="/storage/val_tokens_dataset.pickle")# "/home/paperspace/val_tokens_dataset.pickle")
val_tokens_dataloader = DataLoader(val_tokens_dataset, batch_size=64)

# Activation datasets: checkpoints are (layer, checkpoint name) pairs; here the MLP output of layer 0.
activations_dataset = make_activation_dataset(model, tokens_dataloader, checkpoints=[(0, "mlp")], on_disk=False)# on_disk=True, storage_directory="/home/paperspace/activations/")
activations_dataloader = DataLoader(activations_dataset, batch_size=4096, collate_fn=dict_collate_fn)

# The validation loader gets its own name and wraps the *validation* activations.
val_activations_dataset = make_activation_dataset(model, val_tokens_dataloader, checkpoints=[(0, "mlp")], on_disk=False)
val_activations_dataloader = DataLoader(val_activations_dataset, batch_size=4096, collate_fn=dict_collate_fn)

In [None]:
class SparseAutoencoder(Module):
    """Single-hidden-layer autoencoder with a learned bias subtracted from the input.

    ``forward`` returns both the reconstruction and the hidden activations.
    """

    def __init__(self, d, dhidden, activation_function=ReLU()):
        super().__init__()
        self.pre_bias = Parameter(zeros(d))
        self.up = Linear(d, dhidden)
        self.activation_function = activation_function
        self.down = Linear(dhidden, d)

    def forward(self, x):
        centered = x - self.pre_bias
        hidden = self.activation_function(self.up(centered))
        reconstruction = self.down(hidden)
        return reconstruction, hidden

class L1Penalty(Module):
    """Sparsity penalty: the L1 norm along the last dimension, averaged over all other dimensions."""

    def forward(self, x):
        l1_per_vector = x.norm(p=1, dim=-1)
        return l1_per_vector.mean()

In [None]:
def train_sparse_autoencoder( sparse_autoencoder,
                              activations_dataloader,
                              checkpoint,
                              epochs,
                              sparsity_penalty_weight,
                              lr=1e-4,
                              reconstruction_loss_fn=MSELoss(),
                              sparsity_penalty_fn=L1Penalty(),
                              optimizer=None,
                              epoch_tqdm=True,
                              batch_tqdm=False,
                              plot=True ):
    """Train a sparse autoencoder to reconstruct recorded activations.

    The loss of each batch is
    ``reconstruction_loss + sparsity_penalty_weight * sparsity_penalty``.

    Args:
        sparse_autoencoder: module returning ``(reconstruction, hidden)``; trained in place.
        activations_dataloader: yields dicts mapping checkpoints to activation tensors.
        checkpoint: key of the activations to train on.
        epochs: number of passes over the dataloader.
        sparsity_penalty_weight: weight of the sparsity term.
        lr: learning rate of the default optimizer (ignored when ``optimizer`` is given).
        reconstruction_loss_fn: called as ``fn(reconstruction, activations)``.
        sparsity_penalty_fn: called as ``fn(hidden)``.
        optimizer: optimizer to use; when None an ``AdamW`` over the autoencoder's
            parameters is created.
        epoch_tqdm: progress bar over epochs.
        batch_tqdm: progress bar over the batches of every epoch.
        plot: plot the per-epoch mean of both loss terms when training ends.
    """
    # Honour a caller-supplied optimizer instead of always replacing it.
    if optimizer is None:
        optimizer = AdamW(sparse_autoencoder.parameters(), lr=lr)

    # Batches are moved to wherever the autoencoder's parameters live.
    autoencoder_device = next(sparse_autoencoder.parameters()).device
    # Guard the per-epoch averages against an empty dataloader.
    batches_per_epoch = max(len(activations_dataloader), 1)

    reconstruction_loss_history = []
    sparsity_penalty_history = []
    for epoch in tqdm(range(epochs)) if epoch_tqdm else range(epochs):
        epoch_reconstruction_loss = 0
        epoch_sparsity_penalty = 0
        for batch in tqdm(activations_dataloader) if batch_tqdm else activations_dataloader:
            activations = batch[checkpoint].detach().to(autoencoder_device)
            reconstructed, hidden = sparse_autoencoder(activations)
            reconstruction_loss = reconstruction_loss_fn(reconstructed, activations)
            sparsity_penalty = sparsity_penalty_fn(hidden)
            loss = reconstruction_loss + sparsity_penalty_weight * sparsity_penalty
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_reconstruction_loss += reconstruction_loss.item()
            epoch_sparsity_penalty += sparsity_penalty.item()

        reconstruction_loss_history.append(epoch_reconstruction_loss / batches_per_epoch)
        sparsity_penalty_history.append(epoch_sparsity_penalty / batches_per_epoch)

    if plot:
        plt.title("Sparse autoencoder training loss")
        plt.xlabel("epoch")
        plt.ylabel("loss")
        plt.yscale("log")
        plt.plot(reconstruction_loss_history, label="reconstruction loss")
        plt.plot(sparsity_penalty_history, label="sparsity penalty")
        plt.legend()
        plt.show()

def test_sparse_autoencoder(model, dataloader, autoencoder, checkpoint, reconstruction_loss_fn=KLDivLoss(), sparsity_penalty_fn=L1Penalty(), _tqdm=True):
    """Measure how much splicing an autoencoder into the model changes its predictions.

    For every batch the model is run twice: unchanged, and with ``autoencoder``
    replacing the activation at ``checkpoint``.

    Args:
        model: module returning an object with ``logits`` and, on request,
            ``autoencoder_activations``.
        dataloader: yields one-element sequences holding a token tensor.
        autoencoder: callable returning ``(replacement, hidden)``.
        checkpoint: ``(layer, checkpoint_name)`` where the autoencoder is spliced in.
        reconstruction_loss_fn: called as ``fn(input, target)`` where ``input`` is the
            log-softmax of the logits *with* the autoencoder and ``target`` is the
            softmax of the original logits — the form ``KLDivLoss`` expects.
            Note that ``KLDivLoss()`` with its default reduction averages over all
            elements; pass ``KLDivLoss(reduction="batchmean")`` for the
            mathematical KL divergence per batch element.
        sparsity_penalty_fn: called as ``fn(hidden)``.
        _tqdm: progress bar over the batches.

    Returns:
        Dict with the per-batch means ``"reconstruction_loss"`` and
        ``"sparsity_penalty"`` as Python floats.
    """
    autoencoders = {checkpoint: autoencoder}
    # Token batches are moved to wherever the model's parameters live.
    model_device = next(model.parameters()).device

    with inference_mode():
        reconstruction_loss = 0.0
        sparsity_penalty = 0.0
        for tokens, in tqdm(dataloader) if _tqdm else dataloader:
            tokens = tokens.to(model_device)
            result = model(tokens)
            result_with_autoencoders = model(tokens, autoencoders=autoencoders, return_autoencoder_activations=True)

            # KLDivLoss takes log-probabilities as input and probabilities as target; raw logits are neither.
            predicted_log_probs = torch.log_softmax(result_with_autoencoders.logits, dim=-1)
            reference_probs = softmax(result.logits, dim=-1)
            reconstruction_loss += reconstruction_loss_fn(predicted_log_probs, reference_probs).item()

            autoencoder_activations = result_with_autoencoders.autoencoder_activations[checkpoint]
            sparsity_penalty += sparsity_penalty_fn(autoencoder_activations).item()

    # Guard the averages against an empty dataloader.
    num_batches = max(len(dataloader), 1)
    return { "reconstruction_loss": reconstruction_loss / num_batches,
             "sparsity_penalty":    sparsity_penalty    / num_batches }

In [None]:
# Autoencoder over dmodel-sized activations with a hidden layer four times as wide,
# trained on the activations stored under the (0, "mlp") checkpoint.
sparse_autoencoder = SparseAutoencoder(model.cfg.dmodel, 4*model.cfg.dmodel).to(device)
train_sparse_autoencoder(sparse_autoencoder, activations_dataloader, checkpoint=(0, "mlp"), epochs=250, lr=1e-3, sparsity_penalty_weight=1e-3)

In [None]:
# Evaluate the trained autoencoder on the validation tokens, spliced in at the (0, "mlp") checkpoint.
test_sparse_autoencoder(model, val_tokens_dataloader, sparse_autoencoder, checkpoint=(0, "mlp"))

# Training the model from scratch (optional) - you can skip this section; we will only train the model from scratch if we have time

In [None]:
def next_token_logits(model, seq):
    """Return the logits the model predicts at the last position of ``seq``.

    ``model(seq)`` may return either an object with a ``logits`` attribute
    (such as ``TransformerOutput``) or the logits tensor itself. The
    second-to-last axis of the logits is taken to be the sequence position.
    """
    output = model(seq)
    # The model's output object is not itself indexable; index its logits.
    logits = getattr(output, "logits", output)
    return logits[..., -1, :]

In [None]:
def repetition_dataset(vocab_size, ncontext, size):
    """Build a dataset of sequences made of a random half followed by a copy of itself.

    Each of the ``size`` rows holds ``(ncontext + 1) // 2`` random token ids below
    ``vocab_size``, repeated twice, i.e. ``ncontext + 1`` tokens per row. The tensor is
    created on the notebook-level ``device``. ``ncontext`` must be odd.
    """
    assert ncontext % 2 == 1
    half_length = (ncontext + 1) // 2
    first_half = randint(vocab_size, (size, half_length), device=device)
    return TensorDataset(first_half.repeat(1, 2))

In [None]:
def transformer_cross_entropy_loss(pred, true):
    """Cross-entropy between predicted logits and target token ids.

    ``cross_entropy`` wants the class axis at position 1, so axis 1 and the last
    axis of both arguments are swapped before the call.
    """
    class_first_logits = pred.transpose(1, -1)
    targets = true.transpose(1, -1)
    return cross_entropy(class_first_logits, targets)

In [None]:
def train(model, dataloader, epochs, loss_fn=transformer_cross_entropy_loss, lr=1e-3, epoch_tqdm=True, batch_tqdm=False, plot_loss=True):
    """Train ``model`` on next-token prediction with AdamW.

    For every batch ``x`` the model sees ``x[..., :-1]`` and is asked to predict
    ``x[..., 1:]``.

    Args:
        model: module returning either a logits tensor or an object with a
            ``logits`` attribute (such as ``TransformerOutput``); trained in place.
        dataloader: yields one-element sequences holding a token tensor.
        epochs: number of passes over the dataloader.
        loss_fn: called as ``loss_fn(logits, targets)``.
        lr: learning rate.
        epoch_tqdm: progress bar over epochs.
        batch_tqdm: progress bar over the batches of every epoch.
        plot_loss: plot the loss of every training step when training ends.
    """
    model.train()
    optimizer = AdamW(model.parameters(), lr=lr)

    loss_history = []
    for epoch in tqdm(range(epochs)) if epoch_tqdm else range(epochs):
        for x, in tqdm(dataloader) if batch_tqdm else dataloader:
            optimizer.zero_grad()
            output = model(x[..., :-1])
            # The loss works on the logits tensor, not on the model's output object.
            logits = getattr(output, "logits", output)
            loss = loss_fn(logits, x[..., 1:])
            loss_history.append(loss.item())
            loss.backward()
            optimizer.step()

    if plot_loss:
        plt.title("training_loss")
        plt.xlabel("training iteration")
        plt.ylabel("loss")
        plt.yscale("log")
        plt.plot(loss_history)
        plt.show()

In [None]:
# Toy experiment: a small 2-layer transformer trained for one epoch on the repetition dataset.
# NOTE: `model` is rebound here, replacing the pretrained model loaded earlier in the notebook.
cfg = TransformerConfig(vocab_size=10, ncontext=17, dmodel=16, dhead=4, nhead=4, dmlp=32, nlayers=2)
train_dataloader = DataLoader(repetition_dataset(vocab_size=cfg.vocab_size, ncontext=cfg.ncontext, size=500_000), batch_size=64, shuffle=True)
model = Transformer(cfg).to(device)
train(model, train_dataloader, epochs=1, batch_tqdm=True, epoch_tqdm=False)

In [None]:
# Sanity-check the toy model on a fresh batch: compare its next-token predictions with the targets.
test_dataloader = DataLoader(repetition_dataset(vocab_size=cfg.vocab_size, ncontext=cfg.ncontext, size=1_000), batch_size=64, shuffle=True)
x, = next(iter(test_dataloader))
inputs, targets = x[..., :-1], x[..., 1:]
with inference_mode():
    # The model returns an output object; the predictions are in its `logits`.
    logits = model(inputs).logits
print(x.shape)
print(logits.argmax(-1)[0, ...])   # predicted next tokens for the first sequence
print(targets[0, ...])             # true next tokens
print(inputs[0, ...])              # tokens the model saw
print(transformer_cross_entropy_loss(logits, targets))