In [None]:
from dataclasses import dataclass

# PyTorch
import torch
import torch.nn as nn

In [None]:
@dataclass
class GPTConfig:
    """Hyperparameters shared by every module of the GPT model in this notebook.

    Field order is part of the interface (positional construction), so it is
    kept exactly as declared.
    """

    # Number of rows in the token embedding table.
    vocab_size: int = 50_257
    # Number of rows in the positional embedding table.
    context_length: int = 1_024
    # Width of token/positional embeddings, LayerNorms and MLP input/output.
    embedding_size: int = 768
    # How many GPTLayer blocks are stacked.
    num_layers: int = 12
    # Attention head count; SelfAttention asserts it divides embedding_size.
    num_heads: int = 12
    # Width of the MLP's inner projection.
    hidden_size: int = 3_072
    # Dropout setting (presumably a probability in [0, 1) -- confirm at use site).
    dropout: float = 0.1

In [None]:
class SelfAttention(nn.Module):
    """Causal multi-head self-attention.

    Each position attends only to itself and to earlier positions. The output
    has the same shape as the input, so it can be used in a residual branch.

    Args:
        config: object exposing ``embedding_size``, ``num_heads``,
            ``context_length`` and ``dropout``.

    Raises:
        AssertionError: if ``embedding_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, config):
        super().__init__()

        assert config.embedding_size % config.num_heads == 0, (
            "embedding_size must be divisible by num_heads"
        )

        self.num_heads = config.num_heads
        self.head_size = config.embedding_size // config.num_heads

        # One fused projection produces queries, keys and values together.
        self.qkv = nn.Linear(config.embedding_size, 3 * config.embedding_size)
        self.proj = nn.Linear(config.embedding_size, config.embedding_size)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # Lower-triangular mask: True where a query may look at a key.
        # Non-persistent so it follows .to(device) without entering state_dict.
        allowed = torch.tril(
            torch.ones(config.context_length, config.context_length, dtype=torch.bool)
        )
        self.register_buffer("causal_mask", allowed, persistent=False)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape ``(batch, seq_len, embedding_size)`` with
                ``seq_len <= context_length``.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: if ``seq_len`` exceeds the configured context length.
        """
        batch, seq_len, width = x.shape
        if seq_len > self.causal_mask.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds context length {self.causal_mask.size(0)}"
            )

        q, k, v = self.qkv(x).split(width, dim=-1)

        # (batch, seq, width) -> (batch, heads, seq, head_size)
        def split_heads(t):
            return t.view(batch, seq_len, self.num_heads, self.head_size).transpose(1, 2)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        scores = (q @ k.transpose(-2, -1)) * (self.head_size ** -0.5)
        # Future positions get -inf so softmax assigns them zero weight.
        visible = self.causal_mask[:seq_len, :seq_len]
        scores = scores.masked_fill(~visible, float("-inf"))
        weights = self.attn_dropout(torch.softmax(scores, dim=-1))

        # Merge heads back: (batch, heads, seq, head_size) -> (batch, seq, width)
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.proj(merged))

In [None]:
class MLP(nn.Module):
    """Feed-forward sub-block: Linear -> GELU (tanh approximation) -> Linear.

    Projects from ``config.embedding_size`` up to ``config.hidden_size`` and
    back down, so the output's last dimension matches the input's.
    """

    def __init__(self, config):
        super().__init__()

        outer_width = config.embedding_size
        inner_width = config.hidden_size

        # Creation order (fc1, gelu, fc2) is kept so seeded initialisation and
        # the module ordering are unchanged.
        self.fc1 = nn.Linear(outer_width, inner_width)
        self.gelu = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(inner_width, outer_width)

    def forward(self, x):
        """Return ``fc2(gelu(fc1(x)))``."""
        return self.fc2(self.gelu(self.fc1(x)))

In [None]:
class GPTLayer(nn.Module):
    """One transformer block: LayerNorm -> attention and LayerNorm -> MLP,
    each added back onto its input as a residual.

    Args:
        config: object exposing ``embedding_size`` plus whatever
            ``SelfAttention`` and ``MLP`` read from it.
    """

    def __init__(self, config):
        super().__init__()

        self.layernorm1 = nn.LayerNorm(config.embedding_size)
        # ``torch.nn`` provides no ``SelfAttention`` class; use the module
        # defined in this notebook, whose constructor takes the whole config.
        self.attention = SelfAttention(config)
        self.layernorm2 = nn.LayerNorm(config.embedding_size)
        self.mlp = MLP(config)

    def forward(self, x):
        """Run the block on ``x`` and return a tensor of the same shape.

        Both sub-modules are applied to a normalised copy of ``x`` and their
        outputs are added to ``x``, so each must return something that can be
        added to its input (assumes ``self.attention`` preserves shape -- confirm).
        """
        x = x + self.attention(self.layernorm1(x))
        x = x + self.mlp(self.layernorm2(x))
        return x

In [None]:
class GPT(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        
        self.transformer = nn.ModuleDict(dict(
            embeddings=nn.Embedding(config.vocab_size, config.embedding_size),
            positional_encoding=nn.Embedding(config.context_length, config.embedding_size),
            layers=nn.ModuleList([GPTLayer(config) for _ in range(config.num_layers)]),
            layernorm=nn.LayerNorm(config.embedding_size)
        ))
        self.
            
        