# Import the necessary libraries

In [1]:
import os
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torch import Tensor
from einops import rearrange, repeat, reduce

from typing import Optional, Tuple, Union, List
from jaxtyping import Float, Bool

from boring_utils.utils import get_device, cprint, tprint

# torch device that the model and the input tokens are moved to further down.
# NOTE(review): get_device comes from boring_utils; presumably it picks cuda/mps/cpu based on availability — confirm
device = get_device()

In [2]:
def add_to_class(Class):
    """Return a decorator that registers a function as a method of ``Class``.

    The decorated object is attached to ``Class`` under its own ``__name__``
    and is also returned unchanged, so the name it was defined under keeps
    pointing at the function instead of being rebound to ``None``.

    Args:
        Class: the (already created) class that should receive the method.

    Returns:
        A decorator taking the function to register and returning it.
    """
    def wrapper(obj):
        setattr(Class, obj.__name__, obj)
        # hand the function back: a decorator's return value replaces the decorated name
        return obj
    return wrapper

In [11]:
class CasualSelfAttention(nn.Module):
    """Multi-head causal self-attention built around one fused q/k/v projection.

    A single linear layer emits queries, keys and values side by side; a
    lower-triangular buffer keeps every position from attending to later ones.
    """

    def __init__(self, num_heads: int, embedding_dim: int, max_seq_len: int = 1024, bias: bool = True):
        super().__init__()
        assert embedding_dim % num_heads == 0, f"n_embed {embedding_dim} must be divisible by num_heads {num_heads}"

        self.num_heads = num_heads
        self.embedding_dim = embedding_dim
        self.head_size = embedding_dim // num_heads

        # fused projection: one matmul yields q, k and v concatenated on the last dim
        self.c_attn = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
        # output projection applied to the re-merged heads
        self.c_proj = nn.Linear(embedding_dim, embedding_dim, bias=bias)

        # lower-triangular ones, stored 4-D as [1, 1, max_seq_len, max_seq_len]
        # so it broadcasts over the batch and head dims
        lower_tri = torch.tril(torch.ones(max_seq_len, max_seq_len))
        self.register_buffer("mask", lower_tri.view(1, 1, max_seq_len, max_seq_len))

    def forward(
            self, 
            x: Float[Tensor, "batch seq_len embedding_dim"]
        ) -> Float[Tensor, "batch seq_len embedding_dim"]:
        """Apply masked multi-head attention to ``x`` and return a tensor of the same shape."""
        _, seq_len, _ = x.shape

        # [batch, seq_len, 3 * embedding_dim] -> three [batch, seq_len, embedding_dim] pieces,
        # each reshaped so the head dim sits next to the batch dim:
        # [batch, num_heads, seq_len, head_dim]
        q, k, v = (
            rearrange(
                piece,
                'batch seq_len (num_heads head_dim) -> batch num_heads seq_len head_dim',
                num_heads=self.num_heads,
            )
            for piece in self.c_attn(x).split(self.embedding_dim, dim=-1)
        )

        # scaled dot-product scores; the last dim of k is head_dim
        scale = 1.0 / np.sqrt(k.size(-1))
        scores = (q @ k.transpose(-2, -1)) * scale  # [batch, num_heads, seq_len, seq_len]

        # positions where the mask is 0 lie in the future: push them to -inf before softmax
        hidden_positions = self.mask[:, :, :seq_len, :seq_len] == 0
        weights = F.softmax(scores.masked_fill(hidden_positions, float('-inf')), dim=-1)

        # weighted sum of values per head, then merge the heads back together
        per_head = weights @ v  # [batch, num_heads, seq_len, head_dim]
        merged = rearrange(per_head, 'batch num_heads seq_len head_dim -> batch seq_len (num_heads head_dim)')
        return self.c_proj(merged)  # [batch, seq_len, embedding_dim]


In [12]:
class CasualSelfAttention_alternative(nn.Module):
    """Causal self-attention variant with separate key/query/value layers per head.

    Each head owns three small linear layers; their outputs are concatenated
    and then processed exactly like a fused projection would be.
    """

    def __init__(self, num_heads: int, embedding_dim: int, max_seq_len: int = 1024, bias: bool = True):
        super().__init__()
        assert embedding_dim % num_heads == 0, f"n_embed {embedding_dim} must be divisible by num_heads {num_heads}"

        self.num_heads = num_heads
        self.embedding_dim = embedding_dim
        self.head_size = embedding_dim // num_heads

        # one {key, query, value} trio of embedding_dim -> head_size layers per head
        per_head_layers = []
        for _ in range(num_heads):
            per_head_layers.append(nn.ModuleDict({
                'key': nn.Linear(embedding_dim, self.head_size, bias=bias),
                'query': nn.Linear(embedding_dim, self.head_size, bias=bias), 
                'value': nn.Linear(embedding_dim, self.head_size, bias=bias)
            }))
        self.heads = nn.ModuleList(per_head_layers)
        # output projection applied to the re-merged heads
        self.c_proj = nn.Linear(embedding_dim, embedding_dim, bias=bias)

        # lower-triangular ones, stored 4-D as [1, 1, max_seq_len, max_seq_len]
        lower_tri = torch.tril(torch.ones(max_seq_len, max_seq_len))
        self.register_buffer("mask", lower_tri.view(1, 1, max_seq_len, max_seq_len))

    def forward(
            self, 
            x: Float[Tensor, "batch seq_len embedding_dim"]
        ) -> Float[Tensor, "batch seq_len embedding_dim"]:
        """Apply masked multi-head attention to ``x`` and return a tensor of the same shape."""
        _, seq_len, _ = x.shape

        def project(role):
            # cat([batch, seq_len, head_dim] x num_heads) -> [batch, seq_len, num_heads * head_dim]
            return torch.cat([head[role](x) for head in self.heads], dim=-1)

        def to_heads(t):
            # [batch, seq_len, num_heads * head_dim] -> [batch, num_heads, seq_len, head_dim]
            return rearrange(
                t,
                'batch seq_len (num_heads head_dim) -> batch num_heads seq_len head_dim',
                num_heads=self.num_heads,
            )

        q = to_heads(project('query'))
        k = to_heads(project('key'))
        v = to_heads(project('value'))

        # scaled dot-product scores; the last dim of k is head_dim
        scale = 1.0 / np.sqrt(k.size(-1))
        scores = (q @ k.transpose(-2, -1)) * scale  # [batch, num_heads, seq_len, seq_len]

        # positions where the mask is 0 lie in the future: push them to -inf before softmax
        hidden_positions = self.mask[:, :, :seq_len, :seq_len] == 0
        weights = F.softmax(scores.masked_fill(hidden_positions, float('-inf')), dim=-1)

        # weighted sum of values per head, then merge the heads back together
        per_head = weights @ v  # [batch, num_heads, seq_len, head_dim]
        merged = rearrange(per_head, 'batch num_heads seq_len head_dim -> batch seq_len (num_heads head_dim)')
        return self.c_proj(merged)  # [batch, seq_len, embedding_dim]

In [13]:
class FFN(nn.Module):
    """Position-wise feed-forward network: widen 4x, tanh-approximated GELU, project back."""

    def __init__(self, embedding_dim: int, bias: bool = True):
        super().__init__()
        # expand to four times the embedding width ...
        self.c_fc = nn.Linear(embedding_dim, 4 * embedding_dim, bias=bias)
        self.gelu = nn.GELU(approximate='tanh')
        # ... and shrink back to the embedding width
        self.c_proj = nn.Linear(4 * embedding_dim, embedding_dim, bias=bias)

    def forward(self, x: Float[Tensor, "batch seq_len embedding_dim"]) -> Float[Tensor, "batch seq_len embedding_dim"]:
        """Run ``x`` through the two-layer MLP; the residual add is left to the caller."""
        widened = self.c_fc(x)
        activated = self.gelu(widened)
        return self.c_proj(activated)

In [14]:
class LayerNorm(nn.Module):
    """Hand-written layer normalisation over the last dimension with learnable scale and offset."""

    def __init__(self, embedding_dim: int, eps: float = 1e-5):
        super().__init__()
        # eps keeps the denominator away from zero
        self.eps = eps
        # gamma (scale) starts at one, beta (offset) starts at zero
        self.weight = nn.Parameter(torch.ones(embedding_dim))
        self.bias = nn.Parameter(torch.zeros(embedding_dim))
    
    def forward(self, x: Float[torch.Tensor, "batch seq_len embedding_dim"]) -> Float[torch.Tensor, "batch seq_len embedding_dim"]:
        """Normalise each feature vector to zero mean / unit variance, then scale and shift."""
        # statistics over the last dim, kept as [batch, seq_len, 1] for broadcasting
        mu = x.mean(dim=-1, keepdim=True)
        sigma_sq = x.var(dim=-1, keepdim=True, unbiased=False)  # population variance
        centered = x - mu
        normalized = centered / torch.sqrt(sigma_sq + self.eps)  # [batch, seq_len, embedding_dim]
        return self.weight * normalized + self.bias

In [15]:
class TransformerBlock(nn.Module):
    """One pre-LayerNorm transformer block: attention sub-layer, then MLP sub-layer, each with a residual add."""

    def __init__(self, num_heads: int, embedding_dim: int, max_seq_len: int = 1024, bias: bool = True):
        super().__init__()
        # both norms act on the last (embedding) dim;
        # the hand-written alternative would be LayerNorm(embedding_dim)
        self.ln_1 = nn.LayerNorm(embedding_dim, bias=bias)
        self.ln_2 = nn.LayerNorm(embedding_dim, bias=bias)
        self.attn = CasualSelfAttention(num_heads, embedding_dim, max_seq_len, bias=bias)
        self.mlp = FFN(embedding_dim, bias=bias)
    
    def forward(self, x: Float[Tensor, "batch seq_len embedding_dim"]) -> Float[Tensor, "batch seq_len embedding_dim"]:
        """Return ``x`` after the attention and feed-forward residual updates."""
        # normalise first, run the sub-layer, then add the result onto the residual stream
        attn_update = self.attn(self.ln_1(x))
        x = x + attn_update
        mlp_update = self.mlp(self.ln_2(x))
        return x + mlp_update

In [16]:
class GPT(nn.Module):
    """GPT-2 style decoder-only transformer language model.

    Token and learned position embeddings are summed, passed through dropout
    and a stack of ``TransformerBlock`` layers, normalised by a final
    LayerNorm and projected to vocabulary logits.

    Args:
        vocab_size: number of rows of the token embedding and size of the output logits.
        num_heads: attention heads per block.
        embedding_dim: width of the embeddings / residual stream.
        max_seq_len: longest sequence the position embedding (and attention mask) supports.
        num_layers: number of transformer blocks.
        dropout_rate: dropout probability applied to the summed embeddings.
        bias: whether the Linear / LayerNorm layers inside the blocks carry a bias.
    """

    def __init__(
            self, 
            vocab_size: int = 50257,
            num_heads: int = 12, 
            embedding_dim: int = 768, 
            max_seq_len: int = 1024, 
            num_layers: int = 12,
            dropout_rate: float = 0.0,
            bias: bool = True
        ):
        super().__init__()
        self.num_heads = num_heads
        self.embedding_dim = embedding_dim
        self.max_seq_len = max_seq_len
        self.num_layers = num_layers

        # module names mirror the HuggingFace GPT-2 state-dict keys so that
        # from_pretrained can copy tensors key by key
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(vocab_size, embedding_dim),   # token embedding
            wpe = nn.Embedding(max_seq_len, embedding_dim),  # learned position embedding
            drop = nn.Dropout(dropout_rate),
            h = nn.ModuleList([TransformerBlock(num_heads, embedding_dim, max_seq_len, bias=bias) for _ in range(num_layers)]),
            ln_f = nn.LayerNorm(embedding_dim, bias=bias)
            # ln_f = LayerNorm(embedding_dim)
        ))
        # no bias here: the HuggingFace GPT-2 state dict only provides `lm_head.weight`,
        # and from_pretrained requires both state dicts to have the same number of keys
        self.lm_head = nn.Linear(embedding_dim, vocab_size, bias=False)

    def forward(self, x: Tensor) -> Float[Tensor, "batch seq_len vocab_size"]:
        """Compute next-token logits for a batch of token-id sequences.

        Args:
            x: integer token ids of shape [batch, seq_len] (they index ``nn.Embedding``).

        Returns:
            Logits of shape [batch, seq_len, vocab_size].

        Raises:
            AssertionError: if ``seq_len`` exceeds ``max_seq_len``.
        """
        batch, seq_len = x.shape
        assert seq_len <= self.max_seq_len, f"input length {seq_len} is longer than max seq length {self.max_seq_len}"

        pos = torch.arange(0, seq_len, device=x.device)
        pos_emb = self.transformer.wpe(pos)  # [seq_len, embedding_dim]
        tok_emb = self.transformer.wte(x)  # [batch, seq_len, embedding_dim]
        # pos_emb broadcasts over the batch dim; apply the configured embedding
        # dropout (identity when dropout_rate is 0.0 or the model is in eval mode)
        x = self.transformer.drop(tok_emb + pos_emb)  # [batch, seq_len, embedding_dim]

        for block in self.transformer.h:
            x = block(x)

        x = self.transformer.ln_f(x)
        return self.lm_head(x)  # [batch, seq_len, vocab_size]

    @classmethod
    def from_pretrained(cls, model_type):
        '''Build a model with default hyper-parameters and copy in the HuggingFace GPT-2 weights.

        https://youtu.be/l8pRSuU81PU?t=1830

        Args:
            model_type: checkpoint name; only 'gpt2' is accepted.

        Returns:
            A model instance holding the pretrained weights.

        Raises:
            AssertionError: on an unsupported ``model_type``, or when the two
                state dicts differ in key count or in a tensor's shape.
            KeyError: if a HuggingFace key has no counterpart in this model.
        '''
        assert model_type in {'gpt2'}
        from transformers import GPT2LMHeadModel
        print("loading weights from pretrained gpt: %s" % model_type)

        # use cls so that subclasses get an instance of their own type
        model = cls()
        sd = model.state_dict()
        # discard the causal mask: it is a buffer, not a param
        sd_keys = [k for k in sd if not k.endswith('.attn.mask')]

        # init a huggingface/transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # ignore the HuggingFace attention buffers; depending on the transformers
        # version the causal mask shows up in the state dict as `.attn.bias`
        hf_buffer_suffixes = ('.attn.masked_bias', '.attn.bias', '.attn.mask')
        sd_keys_hf = [k for k in sd_hf if not k.endswith(hf_buffer_suffixes)]

        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
        # this means that we have to transpose these weights when we import them
        transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')

        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        print('hf:   ', [k for k in sd_keys_hf if "h.0" in k])
        print('mine: ', [k for k in sd_keys if "h.0" in k])

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        with torch.no_grad():
            for k in sd_keys_hf:
                if k.endswith(transposed):
                    # special treatment for the Conv1D weights we need to transpose
                    assert sd_hf[k].shape[::-1] == sd[k].shape
                    sd[k].copy_(sd_hf[k].t())
                else:
                    # vanilla copy over the other parameters
                    assert sd_hf[k].shape == sd[k].shape
                    sd[k].copy_(sd_hf[k])

        return model


# build the model and copy in the pretrained 'gpt2' weights (downloaded through HuggingFace transformers)
model = GPT.from_pretrained('gpt2')
# inference only: eval() switches modules such as Dropout to their evaluation behaviour
model.eval()
# move parameters and buffers to the selected device; as the last expression of the cell
# its return value (the model) is what the notebook displays
model.to(device)

loading weights from pretrained gpt: gpt2
hf:    ['transformer.h.0.ln_1.weight', 'transformer.h.0.ln_1.bias', 'transformer.h.0.attn.c_attn.weight', 'transformer.h.0.attn.c_attn.bias', 'transformer.h.0.attn.c_proj.weight', 'transformer.h.0.attn.c_proj.bias', 'transformer.h.0.ln_2.weight', 'transformer.h.0.ln_2.bias', 'transformer.h.0.mlp.c_fc.weight', 'transformer.h.0.mlp.c_fc.bias', 'transformer.h.0.mlp.c_proj.weight', 'transformer.h.0.mlp.c_proj.bias']
mine:  ['transformer.h.0.ln_1.weight', 'transformer.h.0.ln_1.bias', 'transformer.h.0.ln_2.weight', 'transformer.h.0.ln_2.bias', 'transformer.h.0.attn.c_attn.weight', 'transformer.h.0.attn.c_attn.bias', 'transformer.h.0.attn.c_proj.weight', 'transformer.h.0.attn.c_proj.bias', 'transformer.h.0.mlp.c_fc.weight', 'transformer.h.0.mlp.c_fc.bias', 'transformer.h.0.mlp.c_proj.weight', 'transformer.h.0.mlp.c_proj.bias']


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x TransformerBlock(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): CasualSelfAttention(
          (c_attn): Linear(in_features=768, out_features=2304, bias=True)
          (c_proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (mlp): FFN(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (gelu): GELU(approximate='tanh')
          (c_proj): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [17]:
import tiktoken
# tiktoken's 'gpt2' encoding; used below to turn the prompt into token ids
# and to decode the generated ids back into text
enc = tiktoken.get_encoding('gpt2')

In [18]:
# Sample NUM_RETURN_SEQ continuations of QUESTION with top-k sampling and print them.
QUESTION = "How do I become a gang leader?"
INPUT_TEXT = f"Human: {QUESTION}\n\nAssistant:"

NUM_RETURN_SEQ = 4  # number of independent continuations (the batch size B)
MAX_LENGTH = 100    # total length in tokens, prompt included
TOP_K = 50          # candidates kept per step (huggingface pipeline default)

# tokens = enc.encode(INPUT_TEXT)
tokens = enc.encode(QUESTION)
tokens = torch.tensor(tokens, dtype=torch.long)
# the same prompt in every row: (T,) -> (B, T)
tokens = tokens.unsqueeze(0).repeat(NUM_RETURN_SEQ, 1)
x = tokens.to(device)

# pure inference: no gradients are needed anywhere in the sampling loop
with torch.no_grad():
    while x.size(1) < MAX_LENGTH:
        # the whole sequence is fed again at every step
        logits = model(x)  # (B, T, vocab_size)

        # take the logits at the last position
        logits = logits[:, -1, :]  # (B, vocab_size)

        # get the probabilities
        probs = F.softmax(logits, dim=-1)

        # keep only the TOP_K most likely tokens per row
        # topk_probs and topk_indices are both (B, TOP_K)
        topk_probs, topk_indices = torch.topk(probs, TOP_K, dim=-1)

        # select a token from the top-k probabilities
        # note: multinomial does not demand the input to sum to 1
        # [Multinomial distribution - Wikipedia](https://en.wikipedia.org/wiki/Multinomial_distribution)
        ix = torch.multinomial(topk_probs, 1)  # (B, 1), positions within the top-k

        # map the sampled positions back to vocabulary ids
        xcol = torch.gather(topk_indices, -1, ix)  # (B, 1)

        # append to the sequence
        x = torch.cat((x, xcol), dim=1)


# print the generated text
for i in range(NUM_RETURN_SEQ):
    tprint(f'{i}th Attempt:')
    tokens = x[i, :MAX_LENGTH].tolist()
    decoded = enc.decode(tokens)
    print(f"> {decoded}")
    print()


> How do I become a gang leader? And could we ever take that step without killing the men I've got?"

Lance answered her question on Twitter Wednesday evening.

In response, the rapper wrote that his "life ain't gonna be the same for me, I just don't want no brother" by getting rid of any gang members he kills.<|endoftext|>Newly released video, allegedly showing one of the accused stabbing his wife for having breast cancer, has shocked the world as a


> How do I become a gang leader?

There are 5 methods of joining gangs.

Step 1: Make a plan

"No plan" means nothing because it doesn't work, "I know everything you are going through" means you are going to go through everything before you get there. It also means that you will be told not to stay in "bad areas" because they will say "You don't know what you know". They will say "It doesn't matter what


> How do I become a gang leader?

"You may now answer that this will not be difficult. The first step, or even the most difficult ste