In [None]:
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F

class CausalSelfAttention(nn.Module):
    """Multi-head masked (causal) self-attention.

    Every position attends only to itself and to earlier positions, so a
    token's output never depends on tokens that come after it.
    """

    def __init__(self, config):
        super().__init__()
        # Heads split the embedding evenly, so n_embd must divide exactly.
        if config.n_embd % config.n_head != 0:
            raise ValueError(
                f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})"
            )
        # One projection produces query, key and value for all heads at once.
        self.c_attn = nn.Linear(config.n_embd, 3*config.n_embd)
        # Output projection applied after the heads are concatenated.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        """Apply causal multi-head self-attention.

        Args:
            x: tensor of shape (B, T, C) with C == n_embd.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.size()
        head_size = C // self.n_head
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # (B, T, C) -> (B, n_head, T, head_size): each head attends independently.
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)
        # Scaled dot-product scores of shape (B, n_head, T, T).
        att = (q @ k.transpose(-2, -1)) * head_size ** -0.5
        # Causal mask: row t keeps columns <= t. The diagonal is always kept,
        # so no row is fully -inf and softmax cannot produce NaN.
        causal = torch.ones(T, T, dtype=torch.bool, device=x.device).tril()
        att = att.masked_fill(~causal, float("-inf"))
        att = F.softmax(att, dim=-1)
        y = att @ v
        # Re-assemble the heads side by side: (B, n_head, T, hs) -> (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)

        
class MLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, apply GELU, project back."""

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4*config.n_embd)
        # GELU (Gaussian Error Linear Unit) provides the non-linearity. It must
        # be an instance: calling the bare class would pass the tensor to the
        # constructor instead of applying the activation.
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4*config.n_embd, config.n_embd)

    def forward(self, x):
        """Map (..., n_embd) -> (..., n_embd) through the 4x hidden layer."""
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x
    
    
class Block(nn.Module):
    """One pre-norm transformer layer: masked self-attention, then an MLP.

    Each sub-layer receives a LayerNorm-ed copy of its input, and its output
    is added back onto the un-normalised input (residual connection).
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.ln_1 = nn.LayerNorm(width)
        # causal (masked) self-attention sub-layer
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(width)
        # position-wise feed-forward sub-layer
        self.mlp = MLP(config)

    def forward(self, x):
        # Pre-norm: normalise what each sub-layer sees, not the residual sum.
        attended = self.attn(self.ln_1(x))
        x = x + attended
        transformed = self.mlp(self.ln_2(x))
        return x + transformed

@dataclass
class GPTConfig:
    """Plain container of model hyperparameters; holds no logic of its own.

    The default values correspond to the 124M-parameter GPT-2 sizes.
    """

    block_size: int = 1024  # maximum sequence length accepted by the model
    vocab_size: int = 50257  # number of distinct token ids
    n_embd: int = 768  # embedding / hidden width
    n_layer: int = 12  # depth of the model (number of transformer layers)
    n_head: int = 12  # number of attention heads



class GPT(nn.Module):
    def __init__(self, config):
        """Build the embeddings, the stack of Blocks and the language-model head.

        Args:
            config: model hyperparameters exposing ``vocab_size``,
                ``block_size``, ``n_embd`` and ``n_layer`` (e.g. a GPTConfig).
        """
        # nn.Module must be initialised before any submodule is assigned;
        # otherwise the ModuleDict assignment below raises AttributeError.
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            # token embedding: one vector per vocabulary id
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            # learned position embedding: one vector per position < block_size
            wpe=nn.Embedding(config.block_size, config.n_embd),
            # the stacked transformer layers
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            # final LayerNorm applied after the last Block (pre-norm design)
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        # projects final hidden states to per-token vocabulary logits
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)


    def forward(self, x, targets = None):
        B,T = x.size()
        assert T <= self.config.block_size
        pos = torch.arange(0, T, dtype=torch.long, device=x.device)