In [1]:
import torch
import torch.nn as nn
import tiktoken
tokenizer = tiktoken.get_encoding('gpt2')
import torch
from torch.utils.data import Dataset,DataLoader

In [None]:
class MultiHeadAttentio(nn.Module):
  """Causal multi-head self-attention.

  The input is projected to queries, keys and values of width ``d_out``,
  split into ``num_heads`` heads of ``d_out // num_heads`` features each,
  attended under an upper-triangular mask that hides later positions, and
  the concatenated head outputs are passed through a final linear layer.

  Args:
    d_in: Size of the last dimension of the input.
    d_out: Total output width across all heads; must be divisible by
      ``num_heads``.
    num_heads: Number of attention heads.
    dropout: Dropout probability applied to the attention weights.
    context_len: Side length of the causal mask, i.e. the longest sequence
      ``forward`` accepts.
    qvk_bias: Whether the query/key/value projections have a bias term.
  """

  def __init__(self, d_in,d_out,num_heads,dropout,context_len,qvk_bias=False):
    super().__init__()
    assert (d_out % num_heads == 0), \
      "d_out must be divisible by num_heads"

    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = d_out // num_heads

    self.W_query = nn.Linear(d_in, d_out, bias=qvk_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qvk_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qvk_bias)
    # Linear layer that mixes the concatenated head outputs.
    self.out_proj = nn.Linear(d_out, d_out)
    self.dropout = nn.Dropout(dropout)
    # Ones strictly above the diagonal mark the positions to hide. Stored as
    # a buffer so it follows the module across devices.
    self.register_buffer(
        "mask", torch.triu(torch.ones(context_len, context_len), diagonal=1)
    )

  def _split_heads(self, t, b, num_tokens):
    """Reshape (b, num_tokens, d_out) into (b, num_heads, num_tokens, head_dim)."""
    return t.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

  def forward(self, x):
    """Apply causal self-attention to ``x``.

    Args:
      x: Tensor of shape ``(batch, num_tokens, d_in)``.

    Returns:
      Tensor of shape ``(batch, num_tokens, d_out)``.

    Raises:
      ValueError: If ``num_tokens`` is larger than the causal mask.
    """
    b, num_tokens, _ = x.shape

    # Fail early with a clear message. Without this check an overlong input
    # either dies inside masked_fill with a broadcast error or, when the mask
    # is 1x1, broadcasts silently and disables causal masking.
    context_len = self.mask.shape[0]
    if num_tokens > context_len:
      raise ValueError(
          f"sequence length {num_tokens} exceeds context_len {context_len}"
      )

    keys = self._split_heads(self.W_key(x), b, num_tokens)
    values = self._split_heads(self.W_value(x), b, num_tokens)
    queries = self._split_heads(self.W_query(x), b, num_tokens)

    # Per-head dot products: (b, num_heads, num_tokens, num_tokens).
    attn_scores = queries @ keys.transpose(-2, -1)

    # -inf scores become zero weights after the softmax.
    mask_bool = self.mask[:num_tokens, :num_tokens].bool()
    attn_scores = attn_scores.masked_fill(mask_bool, -torch.inf)

    attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)
    attn_weights = self.dropout(attn_weights)

    # Back to (b, num_tokens, num_heads, head_dim).
    context_vec = (attn_weights @ values).transpose(1, 2)

    # view(...) lays the per-head results side by side - like collecting the
    # opinions of several experts (the attention heads) on one page.
    context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
    # out_proj(...) then weighs that combined report with learned weights.
    return self.out_proj(context_vec)


class LayerNorm(nn.Module):
  """Normalise the last dimension to zero mean and unit variance.

  The normalised values are multiplied by a learnable ``scale`` (initialised
  to ones) and offset by a learnable ``shift`` (initialised to zeros), both
  of length ``emb_dim``.
  """

  def __init__(self, emb_dim):
    super().__init__()
    # Added to the variance so the division never hits zero.
    self.eps = 1e-5
    self.scale = nn.Parameter(torch.ones(emb_dim))
    self.shift = nn.Parameter(torch.zeros(emb_dim))

  def forward(self, x):
    """Return ``x`` normalised over its last dimension, then scaled and shifted."""
    mu = torch.mean(x, dim=-1, keepdim=True)
    # Biased (population) variance: divide by N, not N - 1.
    sigma_sq = torch.var(x, dim=-1, keepdim=True, unbiased=False)
    centered = x - mu
    denom = torch.sqrt(sigma_sq + self.eps)
    normalized = centered / denom
    return normalized * self.scale + self.shift


class GELU(nn.Module):
  """GELU activation computed with the tanh approximation.

  Evaluates ``0.5 * x * (1 + tanh(c * (x + 0.044715 * x^3)))`` where
  ``c = sqrt(2 / pi)`` is kept as the buffer ``c``.
  """

  def __init__(self):
    super().__init__()
    two_over_pi = torch.tensor(2.0 / torch.pi)
    # A buffer (not a parameter): moves with the module but is never trained.
    self.register_buffer('c', torch.sqrt(two_over_pi))

  def forward(self, x):
    """Apply the activation element-wise to ``x``."""
    cubic_term = 0.044715 * torch.pow(x, 3)
    gate = torch.tanh(self.c * (x + cubic_term))
    half_x = 0.5 * x
    return half_x * (1 + gate)




class FeedForward(nn.Module):
  """Two linear layers with a GELU in between.

  The first layer widens the input from ``cfg["emb_dim"]`` to four times that
  size; the second brings it back to ``cfg["emb_dim"]``.
  """

  def __init__(self, cfg):
    super().__init__()
    width = cfg["emb_dim"]
    hidden = 4 * width
    expand = nn.Linear(width, hidden)
    activation = GELU()
    contract = nn.Linear(hidden, width)
    self.layers = nn.Sequential(expand, activation, contract)

  def forward(self, x):
    """Run ``x`` through the stacked layers and return the result."""
    return self.layers(x)



# MultiHeadAttentio.__init__ parameters: d_in, d_out, num_heads, dropout, context_len, qvk_bias=False
class TransformerBlock(nn.Module):
  """Attention and feed-forward sub-layers, each wrapped in a residual.

  Every sub-layer is preceded by its own LayerNorm and followed by dropout
  before its output is added back onto the sub-layer's input.

  Args:
    cfg: Mapping with the keys ``"emb_dim"``, ``"n_heads"``, ``"drop_rate"``,
      ``"context_length"`` and ``"qvk_bias"``.
  """

  def __init__(self, cfg):
    super().__init__()
    emb_dim = cfg["emb_dim"]
    attention_kwargs = dict(
        d_in=emb_dim,
        d_out=emb_dim,
        num_heads=cfg["n_heads"],
        dropout=cfg["drop_rate"],
        context_len=cfg["context_length"],
        qvk_bias=cfg["qvk_bias"],
    )
    self.att = MultiHeadAttentio(**attention_kwargs)
    self.ff = FeedForward(cfg)
    self.norm1 = LayerNorm(emb_dim)
    self.norm2 = LayerNorm(emb_dim)
    self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

  def forward(self, x):
    """Return ``x`` after the attention and feed-forward residual steps."""
    # Attention sub-layer; the attention output has the same shape as x.
    attended = self.att(self.norm1(x))
    x = x + self.drop_shortcut(attended)

    # Feed-forward sub-layer.
    transformed = self.ff(self.norm2(x))
    return self.drop_shortcut(transformed) + x
