In [1]:
import einops
from dataclasses import dataclass
# from transformer_lens import HookedTransformer
# from transformer_lens.utils import gelu_new, tokenize_and_concatenate
import torch as t
from torch import Tensor
import torch.nn as nn
import numpy as np
from tqdm.notebook import tqdm
from jaxtyping import Float, Int
from pathlib import Path
from transformers import GPT2Model, GPT2Config
from typing import Tuple, List, Optional, Dict, Callable
from torch.cuda.amp import GradScaler, autocast

from tqdm import tqdm
import plot
import wandb

from transformer_lens import HookedTransformer
from transformer_lens.utils import gelu_new, tokenize_and_concatenate

device = t.device("cuda" if t.cuda.is_available() else "cpu")

In [2]:
class Config:
    d_model: int = 128
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 4
    init_range: float = 0.02
    n_ctx: int = 1024
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12

cfg = Config()

In [37]:
class LayerNorm(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(t.ones(cfg.d_model))
        self.b = nn.Parameter(t.zeros(cfg.d_model))

    def forward(self, residual: Float[Tensor, "batch posn d_model"]) -> Float[Tensor, "batch posn d_model"]:
        # SOLUTION
        residual_mean = residual.mean(dim=-1, keepdim=True)
        residual_std = (residual.var(dim=-1, keepdim=True, unbiased=False) + self.cfg.layer_norm_eps).sqrt()

        residual = (residual - residual_mean) / residual_std
        return residual * self.w + self.b

In [58]:
class Embed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(t.empty(cfg.d_vocab, cfg.d_model))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, vecs: Float[Tensor, "batch position d_vocab"]) -> Float[Tensor, "batch position d_model"]:
        # SOLUTION
        return einops.einsum(vecs, self.W_E, "batch position d_vocab, d_vocab d_model -> batch position d_model")

In [59]:
class PosEmbed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(t.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(self, vecs: Int[Tensor, "batch position d_vocab"]) -> Float[Tensor, "batch position d_model"]:
        # SOLUTION
        batch, seq_len, _ = vecs.shape
        return einops.repeat(self.W_pos[:seq_len], "seq d_model -> batch seq d_model", batch=batch)

In [60]:
class Attention(nn.Module):
    IGNORE: Float[Tensor, ""]

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_K = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_V = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_O = nn.Parameter(t.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        self.b_Q = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_K = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_V = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_O = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.register_buffer("IGNORE", t.tensor(float("-inf"), dtype=t.float32, device=device))

    def forward(
        self, normalized_resid_pre: Float[Tensor, "batch posn d_model"]
    ) -> Float[Tensor, "batch posn d_model"]:
        # SOLUTION
        # Calculate query, key and value vectors
        q = einops.einsum(
            normalized_resid_pre, self.W_Q,
            "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head",
        ) + self.b_Q
        k = einops.einsum(
            normalized_resid_pre, self.W_K,
            "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head",
        ) + self.b_K
        v = einops.einsum(
            normalized_resid_pre, self.W_V,
            "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head",
        ) + self.b_V

        # Calculate attention scores, then scale and mask, and apply softmax to get probabilities
        attn_scores = einops.einsum(
            q, k,
            "batch posn_Q nheads d_head, batch posn_K nheads d_head -> batch nheads posn_Q posn_K",
        )
        attn_scores_masked = self.apply_causal_mask(attn_scores / self.cfg.d_head ** 0.5)
        attn_pattern = attn_scores_masked.softmax(-1)

        # Take weighted sum of value vectors, according to attention probabilities
        z = einops.einsum(
            v, attn_pattern,
            "batch posn_K nheads d_head, batch nheads posn_Q posn_K -> batch posn_Q nheads d_head",
        )

        # Calculate output (by applying matrix W_O and summing over heads, then adding bias b_O)
        attn_out = einops.einsum(
            z, self.W_O,
            "batch posn_Q nheads d_head, nheads d_head d_model -> batch posn_Q d_model",
        ) + self.b_O

        return attn_out

    def apply_causal_mask(
        self, attn_scores: Float[Tensor, "batch n_heads query_pos key_pos"]
    ) -> Float[Tensor, "batch n_heads query_pos key_pos"]:
        '''
        Applies a causal mask to attention scores, and returns masked scores.
        '''
        # SOLUTION
        # Define a mask that is True for all positions we want to set probabilities to zero for
        all_ones = t.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device)
        mask = t.triu(all_ones, diagonal=1).bool()
        # Apply the mask to attention scores, then return the masked scores
        attn_scores.masked_fill_(mask, self.IGNORE)
        return attn_scores

In [61]:
class MLP(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(t.empty((cfg.d_model, cfg.d_mlp)))
        self.W_out = nn.Parameter(t.empty((cfg.d_mlp, cfg.d_model)))
        self.b_in = nn.Parameter(t.zeros((cfg.d_mlp)))
        self.b_out = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        nn.init.normal_(self.W_out, std=self.cfg.init_range)

    def forward(
        self, normalized_resid_mid: Float[Tensor, "batch posn d_model"]
    ) -> Float[Tensor, "batch posn d_model"]:
        # SOLUTION
        pre = einops.einsum(
            normalized_resid_mid, self.W_in,
            "batch position d_model, d_model d_mlp -> batch position d_mlp",
        ) + self.b_in
        post = gelu_new(pre)
        mlp_out = einops.einsum(
            post, self.W_out,
            "batch position d_mlp, d_mlp d_model -> batch position d_model",
        ) + self.b_out
        return mlp_out

In [62]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(
        self, resid_pre: Float[Tensor, "batch position d_model"]
    ) -> Float[Tensor, "batch position d_model"]:
        # SOLUTION
        resid_mid = self.attn(self.ln1(resid_pre)) + resid_pre
        resid_post = self.mlp(self.ln2(resid_mid)) + resid_mid
        return resid_post

In [63]:
class Unembed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(t.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        self.b_U = nn.Parameter(t.zeros((cfg.d_vocab), requires_grad=False))

    def forward(
        self, normalized_resid_final: Float[Tensor, "batch position d_model"]
    ) -> Float[Tensor, "batch position d_vocab"]:
        # SOLUTION
        return einops.einsum(
            normalized_resid_final, self.W_U,
            "batch posn d_model, d_model d_vocab -> batch posn d_vocab",
        ) + self.b_U
        # Or, could just do `normalized_resid_final @ self.W_U + self.b_U`

In [64]:
class DemoTransformer(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(self, tokens: Float[Tensor, "batch position d_vocab"]) -> Float[Tensor, "batch position d_vocab"]:
        # SOLUTION
        residual = self.embed(tokens) + self.pos_embed(tokens)
        for block in self.blocks:
            residual = block(residual)
        prediction = self.unembed(self.ln_final(residual))
        return prediction

In [65]:
@dataclass
class TransformerTrainingArgs():
	batch_size = 16
	lr = 1e-3
	weight_decay = 1e-2
	wandb_project: Optional[str] = "day1-demotransformer"
	wandb_name: Optional[str] = None


args = TransformerTrainingArgs()

In [66]:
def rand_range(low, high, shape):
  return t.rand(shape) * (high - low) + low

def generate_linear_recurrences(batch_size, vector_dim=4, compl=1, length=10, param_bds = (1, 1), return_type="both"):

    assert compl < length
    assert param_bds[0] <= param_bds[1]

    params = rand_range(param_bds[0], param_bds[1], (batch_size, compl)).to(device)
    consts = rand_range(0, 0, (batch_size, vector_dim)).to(device)

    recurrences = t.empty((batch_size, length, vector_dim)).to(device)

    recurrences[:, :compl] = rand_range(1, 3, (batch_size, 1, vector_dim))

    for j in range(compl, length):
        recurrences[:, j] = consts + einops.einsum(params, recurrences[:, j-compl:j], "batch compl, batch compl vector_dim -> batch vector_dim")

    #find max norm in each batch and divide each batch by that max norm
    max_norms, _ = t.max(t.norm(recurrences, dim=2, keepdim=True), dim=1, keepdim=True)
    caps = (t.rand((batch_size, 1, 1)) * (10 - 1) + 1).to(device)
    recurrences = recurrences / (max_norms / caps)

    max_norms = max_norms.squeeze(1)

    if return_type == 'seq':
        return recurrences
    elif return_type == 'both':
        return recurrences, (params, consts / (max_norms / caps))

    assert False

generate_linear_recurrences(1)

(tensor([[[4.0596, 1.5858, 2.3464, 3.5013],
          [4.0596, 1.5858, 2.3464, 3.5013],
          [4.0596, 1.5858, 2.3464, 3.5013],
          [4.0596, 1.5858, 2.3464, 3.5013],
          [4.0596, 1.5858, 2.3464, 3.5013],
          [4.0596, 1.5858, 2.3464, 3.5013],
          [4.0596, 1.5858, 2.3464, 3.5013],
          [4.0596, 1.5858, 2.3464, 3.5013],
          [4.0596, 1.5858, 2.3464, 3.5013],
          [4.0596, 1.5858, 2.3464, 3.5013]]]),
 (tensor([[1.]]), tensor([[[0., 0., 0., 0.]]])))

In [69]:
class TransformerTrainer:
    def __init__(self, args: TransformerTrainingArgs, model: DemoTransformer):
        super().__init__()
        self.model = model
        self.args = args
        self.optimizer = t.optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        self.loss = nn.MSELoss()
        self.step = 0


    def training_step(self, batch: Float[Tensor, "batch seq d_vocab"]) -> Float[Tensor, ""]:
        '''
        Calculates the loss on the tokens in the batch, performs a gradient update step, and logs the loss.

        Remember that `batch` is a dictionary with the single key 'tokens'.
        '''
        # SOLUTION
        pred = self.model(batch)
        loss = self.loss(pred, batch)
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.step += 1
        return loss


    def train(self, steps=100_000):
        '''
        Trains the model, for `self.args.epochs` epochs. Also handles wandb initialization, and early stopping
        for each epoch at `self.args.max_steps_per_epoch` steps.
        '''
        # Initialize Weights & Biases logging if needed
        # wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=self.args)

        with tqdm(total=steps, desc="Training Progress", ascii=True) as pbar:
            for i in range(steps):
                # Simulating the generation of batches with varying lengths
                batch = generate_linear_recurrences(self.args.batch_size, length=t.randint(4, 10, (1,)).item(), return_type='seq')
                loss = self.training_step(batch)

                # Update tqdm progress bar with loss information
                pbar.update(1)
                pbar.set_postfix_str(f"Step: {i+1}, Loss: {loss.item():.4f}")

                # Optionally log metrics to Weights & Biases
                # wandb.log({"loss": loss.item()}, step=i)

    # Clean up Weights & Biases session after training is complete
    # wandb.finish()



In [70]:
model = DemoTransformer(cfg).to(device)
args = TransformerTrainingArgs()
trainer = TransformerTrainer(args, model)
trainer.train()

Training Progress:   0%|          | 41/100000 [00:14<10:04:15,  2.76it/s, Step: 41, Loss: 2.3105]


KeyboardInterrupt: 