# Import the necessary libraries

In [None]:
import os
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from einops import rearrange, repeat, reduce

from typing import Optional, Tuple, Union, List
from jaxtyping import Float, Bool

from boring_utils.utils import get_device, cprint, tprint

# Resolve the compute device once, at import time, via the project helper.
# NOTE(review): presumably returns a torch device (GPU when available,
# otherwise CPU) - confirm against boring_utils.utils.get_device.
device = get_device()

In [None]:
def add_to_class(Class):
    """Return a decorator that attaches the decorated object to ``Class``.

    The object is stored on ``Class`` under its own ``__name__``, which lets a
    function defined in a later cell be registered as a method of a class
    that already exists.

    Args:
        Class: The class that receives the new attribute.

    Returns:
        A decorator that sets ``Class.<obj.__name__> = obj`` and returns
        ``obj`` unchanged.
    """
    def wrapper(obj):
        setattr(Class, obj.__name__, obj)
        # Hand the object back so the decorated name stays bound to it
        # instead of being rebound to None by the decorator syntax.
        return obj
    return wrapper

In [None]:
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention where each position attends only to itself and earlier positions.

    Queries, keys and values come from one fused linear projection, are split
    into ``num_heads`` heads of size ``embedding_dim // num_heads``, and the
    per-head outputs are concatenated and passed through ``out_proj``.

    Args:
        num_heads: Number of attention heads; must divide ``embedding_dim``.
        embedding_dim: Size of the input and output feature dimension.
        max_seq_len: Longest sequence the pre-built causal mask supports.
    """

    def __init__(self, num_heads: int, embedding_dim: int, max_seq_len: int = 1024):
        super().__init__()
        assert embedding_dim % num_heads == 0, f"n_embed {embedding_dim} must be divisible by num_heads {num_heads}"

        self.num_heads = num_heads
        self.embedding_dim = embedding_dim
        self.head_size = embedding_dim // num_heads
        # Scaled dot-product attention divides by sqrt(d_k), where d_k is the
        # per-head key dimension, not the full embedding dimension.
        self.scale = self.head_size ** -0.5

        self.qkv_proj = nn.Linear(embedding_dim, 3 * embedding_dim, bias=False)
        self.out_proj = nn.Linear(embedding_dim, embedding_dim, bias=False)

        # Lower-triangular matrix of ones, shaped [1, 1, max_seq_len, max_seq_len]
        # so it broadcasts over the batch and head dimensions.
        self.register_buffer(
                "mask",
                torch.tril(torch.ones(max_seq_len, max_seq_len))
                    .view(1, 1, max_seq_len, max_seq_len))

    def forward(
            self,
            x: Float[torch.Tensor, "batch seq_len embedding_dim"]
        ) -> Float[torch.Tensor, "batch seq_len embedding_dim"]:
        """Apply causal multi-head self-attention to ``x``.

        Args:
            x: Input of shape ``[batch, seq_len, embedding_dim]``.

        Returns:
            Tensor of shape ``[batch, seq_len, embedding_dim]``.

        Raises:
            ValueError: If ``seq_len`` is larger than the causal mask.
        """
        _, seq_len, _ = x.shape
        max_seq_len = self.mask.size(-1)
        if seq_len > max_seq_len:
            raise ValueError(
                f"seq_len {seq_len} exceeds max_seq_len {max_seq_len}")

        # [batch, seq_len, embedding_dim] -> [batch, seq_len, 3 * embedding_dim]
        qkv = self.qkv_proj(x)
        q, k, v = qkv.split(self.embedding_dim, dim=-1)  # split at the last dim

        # embedding_dim -> num_heads * head_dim, with heads moved next to batch
        # so the matmuls below run over (seq_len, head_dim) per head.
        q, k, v = (
            rearrange(
                t,
                'batch seq_len (num_heads head_dim) -> batch num_heads seq_len head_dim',
                num_heads=self.num_heads)
            for t in (q, k, v)
        )

        # attn: [batch, num_heads, seq_len, seq_len]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Positions above the diagonal are future tokens; -inf becomes zero
        # weight after the softmax.
        attn = attn.masked_fill(self.mask[:, :, :seq_len, :seq_len] == 0, float('-inf'))
        attn = F.softmax(attn, dim=-1)

        # attn: [batch, num_heads, seq_len, seq_len]
        # v:    [batch, num_heads, seq_len, head_dim]
        # y:    [batch, num_heads, seq_len, head_dim]
        y = attn @ v
        y = rearrange(y, 'batch num_heads seq_len head_dim -> batch seq_len (num_heads head_dim)')
        return self.out_proj(y)  # [batch, seq_len, embedding_dim]


In [None]:
class FFN(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU (tanh approximation) -> Linear.

    The hidden layer is four times as wide as ``embedding_dim``; the output
    has the same shape as the input.

    Args:
        embedding_dim: Size of the input and output feature dimension.
    """

    def __init__(self, embedding_dim: int):
        super().__init__()
        hidden_dim = embedding_dim * 4
        self.fc_1 = nn.Linear(embedding_dim, hidden_dim)
        self.activation = nn.GELU(approximate='tanh')
        self.fc_2 = nn.Linear(hidden_dim, embedding_dim)

    def forward(self, x: Float[torch.Tensor, "batch seq_len embedding_dim"]) -> Float[torch.Tensor, "batch seq_len embedding_dim"]:
        """Expand, activate and project back; returns a tensor shaped like ``x``."""
        # No skip connection here: the residual add is left to the caller.
        return self.fc_2(self.activation(self.fc_1(x)))

In [None]:
class LayerNorm(nn.Module):
    """Hand-written layer normalization over the last dimension.

    Each feature vector is shifted to zero mean and scaled to unit
    (population) variance, then passed through a learned per-feature affine
    transform ``gamma * x_norm + beta``.

    Args:
        embedding_dim: Size of the last dimension being normalized.
        eps: Constant added to the variance before the square root so the
            division stays finite.
    """

    def __init__(self, embedding_dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Learned scale starts at one and learned offset at zero.
        self.gamma = nn.Parameter(torch.ones(embedding_dim))
        self.beta = nn.Parameter(torch.zeros(embedding_dim))

    def forward(self, x: Float[torch.Tensor, "batch seq_len embedding_dim"]) -> Float[torch.Tensor, "batch seq_len embedding_dim"]:
        """Normalize ``x`` along its last dimension and apply the affine transform."""
        # Per-position statistics, each [batch, seq_len, 1].
        mu = x.mean(dim=-1, keepdim=True)
        sigma_sq = x.var(dim=-1, keepdim=True, unbiased=False)

        std = torch.sqrt(sigma_sq + self.eps)
        centered = x - mu
        normalized = centered / std  # [batch, seq_len, embedding_dim]
        return self.gamma * normalized + self.beta

In [None]:
class TransformerBlock(nn.Module):
    """One pre-norm transformer block: causal self-attention, then a feed-forward network.

    Both sub-layers are wrapped in a residual connection, and layer
    normalization is applied to the sub-layer input rather than its output.

    Args:
        num_heads: Number of attention heads.
        embedding_dim: Size of the feature dimension.
        max_seq_len: Longest sequence the attention mask supports.
    """

    def __init__(self, num_heads: int, embedding_dim: int, max_seq_len: int = 1024):
        super().__init__()
        # The LayerNorm class defined in this file is used here in place of
        # nn.LayerNorm; it normalizes over the last dimension.
        self.ln_1 = LayerNorm(embedding_dim)
        self.ln_2 = LayerNorm(embedding_dim)
        self.attn = CausalSelfAttention(num_heads, embedding_dim, max_seq_len)
        self.ffn = FFN(embedding_dim)

    def forward(self, x: Float[torch.Tensor, "batch seq_len embedding_dim"]) -> Float[torch.Tensor, "batch seq_len embedding_dim"]:
        """Run attention and feed-forward sub-layers; returns a tensor shaped like ``x``."""
        # skip connection, pre-layer norm
        x = x + self.attn(self.ln_1(x))
        x = x + self.ffn(self.ln_2(x))
        return x

In [None]:
class GPT(nn.Module):
    def __init__(
            self, 
            vocab_size: int = 50257,
            num_heads: int = 12, 
            embedding_dim: int = 768, 
            max_seq_len: int = 1024, 
            num_layers: int = 12,
            dropout_rate: float = 0.0
        ):
        super().__init__()
        self.num_heads = num_heads
        self.embedding_dim = embedding_dim
        self.max_seq_len = max_seq_len
        self.num_layers = num_layers

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(vocab_size, embedding_dim),
            wpe = nn.Embedding(max_seq_len, embedding_dim),
            drop = nn.Dropout(dropout_rate),
            h = nn.ModuleList([TransformerBlock(num_heads, embedding_dim, max_seq_len) for _ in range(num_layers)]),
            ln_f = nn.LayerNorm(embedding_dim, bias=False)
        ))
        self.lm_head = nn.Linear(embedding_dim, embedding_dim, bias=False)

    def forward(self, x: Float[Tensor, "batch seq_len"]) -> Float[Tensor, "batch seq_len embedding_dim"]:
        x = self.transformer.wte(x) + self.transformer.wpe(x)
        x = self.transformer.drop(x)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        return self.lm_head(x)