In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
import numpy as np 
from tqdm.auto import tqdm 
from contextlib import nullcontext
import os 

In [None]:
class LayerNorm(nn.Module):
    """Layer normalization over the trailing dimension with an optional bias.

    Args:
        ndim: Number of elements in the normalized (last) dimension.
        bias: When truthy, a learnable additive bias initialised to zeros is
            created; otherwise ``self.bias`` is ``None`` and no bias is applied.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        # Must be the ``nn.Parameter`` class; ``nn.parameter`` is a submodule
        # and calling it raises ``TypeError``.
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, x):
        """Return ``x`` normalized over its last ``ndim`` elements (eps=1e-5)."""
        return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5)
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention where each position attends only to itself
    and earlier positions.

    Reads from ``config``: ``n_embed``, ``n_head``, ``bias``, ``dropout`` and,
    only when ``F.scaled_dot_product_attention`` is unavailable,
    ``block_size`` (side length of the pre-built lower-triangular mask).
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embed % config.n_head == 0
        width = config.n_embed
        # One linear layer emits query, key and value concatenated on the last axis.
        self.c_attn = nn.Linear(width, 3 * width, bias=config.bias)
        # Output projection applied after the heads are merged back together.
        self.c_proj = nn.Linear(width, width, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.residual_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embed = width
        # Use the fused kernel when this torch build exposes it.
        self.flash = hasattr(F, "scaled_dot_product_attention")
        if not self.flash:
            span = config.block_size
            lower_tri = torch.tril(torch.ones(span, span))
            # Stored as a buffer named "bias"; zeros mark positions to hide.
            self.register_buffer("bias", lower_tri.view(1, 1, span, span))

    def _split_heads(self, tensor, batch, steps, channels):
        """Reshape ``(batch, steps, channels)`` to ``(batch, n_head, steps, channels // n_head)``."""
        per_head = channels // self.n_head
        return tensor.view(batch, steps, self.n_head, per_head).transpose(1, 2)

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape ``(B, T, C)`` and
        return a tensor of the same shape."""
        batch, steps, channels = x.size()
        query, key, value = self.c_attn(x).split(self.n_embed, dim=2)
        query = self._split_heads(query, batch, steps, channels)
        key = self._split_heads(key, batch, steps, channels)
        value = self._split_heads(value, batch, steps, channels)

        if self.flash:
            # Attention dropout is only active in training mode.
            drop_p = self.attn_dropout.p if self.training else 0.0
            per_head_out = F.scaled_dot_product_attention(
                query, key, value, attn_mask=None, dropout_p=drop_p, is_causal=True
            )
        else:
            scale = 1.0 / math.sqrt(key.size(-1))
            scores = (query @ key.transpose(-2, -1)) * scale
            hidden = self.bias[:, :, :steps, :steps] == 0
            scores = scores.masked_fill(hidden, float("-inf"))
            weights = self.attn_dropout(F.softmax(scores, dim=-1))
            per_head_out = weights @ value

        # Merge the heads back into a single channel axis.
        merged = per_head_out.transpose(1, 2).contiguous().view(batch, steps, channels)
        return self.residual_dropout(self.c_proj(merged))

class MLP(nn.Module):
    """Position-wise feed-forward network: expand to 4x width, GELU, project
    back, then dropout.

    Reads ``n_embed``, ``bias`` and ``dropout`` from ``config``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embed
        hidden_width = 4 * width
        self.c_fc = nn.Linear(width, hidden_width, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden_width, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Run ``x`` through the feed-forward stack and return the result."""
        expanded = self.c_fc(x)
        activated = self.gelu(expanded)
        projected = self.c_proj(activated)
        return self.dropout(projected)
class Transformer_Block(nn.Module):
    """Residual block: layer norm + causal self-attention, then layer norm + MLP.

    Each sub-layer is applied to a normalized copy of the input and its output
    is added back onto the un-normalized input (residual connection).
    ``config`` is forwarded to the attention and MLP sub-modules; ``n_embed``
    and ``bias`` are read here for the two layer norms.
    """

    def __init__(self, config):
        super().__init__()
        self.ln1 = LayerNorm(config.n_embed, config.bias)
        self.attn = CausalSelfAttention(config)
        # Second norm needs its own attribute: forward() reads ``self.ln2``,
        # and reusing ``ln1`` would overwrite the first norm.
        self.ln2 = LayerNorm(config.n_embed, config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))

        return x

