In [None]:
from dataclasses import dataclass

@dataclass
class GPT2Config:
    block_size: int = 256
    vocab_size: int = 50257
    n_layer: int = 8
    n_head: int = 8
    d_model: int = 784
    dropout: float = 0.1
    d_ff: int | None = None
    activation_function: str = 'gelu'
    
    def __post_init__(self):
        if self.d_ff is None:
            self.d_ff = 4 * self.d_model
        if not isinstance(self.block_size, int) or self.block_size <= 0:
            raise ValueError("block_size must be an integer > 0")
        if not isinstance(self.vocab_size, int) or self.vocab_size <= 0:
            raise ValueError("vocab_size must be an integer > 0")
        if not isinstance(self.n_layer, int) or self.n_layer <= 0:
            raise ValueError("n_layer must be an integer > 0")
        if not isinstance(self.n_head, int) or self.n_head <= 0:
            raise ValueError("n_head must be an integer > 0")
        if not isinstance(self.d_model, int) or self.d_model <= 0:
            raise ValueError("d_model must be an integer > 0")
        if self.d_model % self.n_head != 0:
            raise ValueError("d_model must be divisible by n_head")
        if not isinstance(self.d_ff, int) or self.d_ff <= 0:
            raise ValueError("d_ff must be an integer > 0")
        if not isinstance(self.dropout, float) or not (0 <= self.dropout <= 1):
            raise ValueError("dropout must be a float between 0 and 1")
        if self.activation_function not in ('gelu', 'relu', 'tanh', 'sigmoid'):
            raise ValueError("activation_function must be one of: 'gelu', 'relu', 'tanh', 'sigmoid'")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are all projected from the same input; the
    per-head width is ``d_model // n_head``.
    """

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.load_config(config)

    def load_config(self, config: GPT2Config):
        """Store the config and build the projection layers and dropout."""
        self.config = config
        self.d_model = config.d_model
        self.n_head = config.n_head
        self.d_k = self.d_model // self.n_head

        # Query, key, value and output projections, created in that order.
        for layer_name in ('w_q', 'w_k', 'w_v', 'w_o'):
            setattr(self, layer_name, nn.Linear(self.d_model, self.d_model))
        self.dropout = nn.Dropout(config.dropout)

    def _split_heads(self, projected, n_batch, n_tokens):
        """Reshape (batch, seq, d_model) into (batch, n_head, seq, d_k)."""
        per_head = projected.view(n_batch, n_tokens, self.n_head, self.d_k)
        return per_head.transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (batch, seq, d_model).
            mask: Optional boolean tensor broadcastable to the
                (batch, n_head, seq, seq) score tensor; positions where it is
                True are excluded from attention.

        Returns:
            Tensor of shape (batch, seq, d_model).
        """
        n_batch, n_tokens, _ = x.size()

        queries = self._split_heads(self.w_q(x), n_batch, n_tokens)
        keys = self._split_heads(self.w_k(x), n_batch, n_tokens)
        values = self._split_heads(self.w_v(x), n_batch, n_tokens)

        # Scale by sqrt(d_k) before the softmax.
        logits = torch.matmul(queries, keys.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            logits = logits.masked_fill(mask, float('-inf'))

        weights = self.dropout(F.softmax(logits, dim=-1))

        # Merge the heads back into a single d_model-wide vector per token.
        context = torch.matmul(weights, values).transpose(1, 2).contiguous()
        merged = context.view(n_batch, n_tokens, self.d_model)
        return self.w_o(merged)


class FeedForward(nn.Module):
    """Position-wise two-layer MLP: linear -> activation -> dropout -> linear."""

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.load_config(config)

    def load_config(self, config: GPT2Config):
        """Store the config, build both linear layers and pick the activation.

        Raises:
            ValueError: If ``config.activation_function`` is not one of
                'gelu', 'relu', 'tanh' or 'sigmoid'.
        """
        self.config = config
        self.d_model = config.d_model
        self.d_ff = config.d_ff
        self.activation_function = config.activation_function

        self.fc1 = nn.Linear(self.d_model, self.d_ff) # type: ignore
        self.fc2 = nn.Linear(self.d_ff, self.d_model) # type: ignore
        self.dropout = nn.Dropout(config.dropout)

        # Name -> callable table, scanned in order for the configured name.
        supported = (
            ('gelu', F.gelu),
            ('relu', F.relu),
            ('tanh', torch.tanh),
            ('sigmoid', torch.sigmoid),
        )
        for label, fn in supported:
            if self.activation_function == label:
                self.activation = fn
                break
        else:
            raise ValueError("Unsupported activation function")

    def forward(self, x):
        """Expand to ``d_ff``, apply activation and dropout, project back to ``d_model``."""
        hidden = self.dropout(self.activation(self.fc1(x)))
        return self.fc2(hidden)
    

class TransformerBlock(nn.Module):
    """One transformer layer: self-attention then feed-forward.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``,
    i.e. normalisation is applied after the residual addition.
    """

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.load_config(config)

    def load_config(self, config: GPT2Config):
        """Store the config and build the sub-layers, norms and dropout."""
        self.config = config
        self.attention = MultiHeadAttention(config)
        self.feed_forward = FeedForward(config)
        self.ln1 = nn.LayerNorm(config.d_model)
        self.ln2 = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, mask=None):
        """Run the block on ``x``; ``mask`` is forwarded to the attention layer."""
        # Attention sub-layer with residual connection, then normalise.
        attended = self.ln1(x + self.dropout(self.attention(x, mask)))
        # Feed-forward sub-layer with residual connection, then normalise.
        return self.ln2(attended + self.dropout(self.feed_forward(attended)))

    
class GPT2Model(nn.Module):
    """Decoder-only transformer language model with tied input/output embeddings.

    The model sums learned token and position embeddings, runs them through
    ``config.n_layer`` ``TransformerBlock`` layers under a causal mask, applies
    a final LayerNorm and projects to vocabulary logits. The output projection
    shares its weight matrix with the token embedding.
    """

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.load_config(config)

    def load_config(self, config: GPT2Config):
        """Store the config, build all sub-modules and initialise the weights."""
        self.config = config
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.position_embedding = nn.Embedding(config.block_size, config.d_model)
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layer)])
        self.norm = nn.LayerNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.dropout = nn.Dropout(config.dropout)

        # Weight tying: the output projection reuses the token embedding matrix.
        self.head.weight = self.token_embedding.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one sub-module (called for every module via ``apply``).

        Linear and Embedding weights get N(0, 0.02); Linear biases are zeroed;
        LayerNorm is reset to weight 1 / bias 0.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, input_ids, labels=None):
        """Compute logits and, optionally, the next-token cross-entropy loss.

        Args:
            input_ids: Long tensor of shape (batch, seq) with
                ``seq <= config.block_size``.
            labels: Optional long tensor of shape (batch, seq) aligned with
                ``input_ids`` (NOT pre-shifted). The shift by one position is
                done here: the logits at position ``t`` are scored against
                ``labels[:, t + 1]``.

        Returns:
            dict with ``'logits'`` of shape (batch, seq, vocab_size) and
            ``'loss'`` (a scalar tensor, or ``None`` when ``labels`` is None).

        Raises:
            ValueError: If the sequence is longer than ``config.block_size``.
        """
        batch_size, seq_len = input_ids.size()
        if seq_len > self.config.block_size:
            # Fail early with a clear message instead of an out-of-range
            # lookup in the position embedding table.
            raise ValueError(
                f"sequence length {seq_len} exceeds block_size {self.config.block_size}"
            )

        token_embeds = self.token_embedding(input_ids)
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        pos_embeds = self.position_embedding(positions)
        # Embedding dropout: this layer was constructed but previously never
        # applied. It is a no-op in eval mode.
        x = self.dropout(token_embeds + pos_embeds)

        # Causal mask: True strictly above the diagonal marks future positions
        # that attention must not see. Shaped (1, 1, seq, seq) for broadcasting.
        mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=input_ids.device),
            diagonal=1,
        )
        mask = mask.unsqueeze(0).unsqueeze(0)

        for layer in self.layers:
            x = layer(x, mask)
        x = self.norm(x)
        logits = self.head(x)

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return {'logits': logits, 'loss': loss}

    def generate(self, input_ids, max_length=50, temperature=1.0, top_k=50, eos_token_id=50256):
        """Autoregressively sample up to ``max_length`` new tokens.

        Puts the module in eval mode (it is not switched back afterwards) and
        runs without gradient tracking.

        Args:
            input_ids: Long tensor of shape (batch, seq) holding the prompt.
            max_length: Maximum number of tokens to append.
            temperature: Positive divisor applied to the logits before sampling.
            top_k: If > 0, sample only among the ``top_k`` highest logits;
                ``0`` or ``None`` disables the filter.
            eos_token_id: Token id that ends generation, or ``None`` to always
                generate ``max_length`` tokens. Generation stops once every
                sequence in the batch has produced it; sequences that finished
                earlier keep receiving sampled tokens until then.

        Returns:
            Long tensor of shape (batch, seq + n_generated): prompt plus samples.

        Raises:
            ValueError: If ``temperature`` is not > 0.
        """
        if temperature <= 0:
            raise ValueError("temperature must be > 0")

        block_size = self.config.block_size
        self.eval()
        with torch.no_grad():
            finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=input_ids.device)
            for _ in range(max_length):
                # The position table only covers block_size positions, so
                # condition on at most the most recent block_size tokens.
                context = input_ids[:, -block_size:]
                next_token_logits = self(context)['logits'][:, -1, :] / temperature

                if top_k is not None and top_k > 0:
                    k = min(top_k, next_token_logits.size(-1))
                    kth_best = torch.topk(next_token_logits, k)[0][..., -1, None]
                    next_token_logits = next_token_logits.masked_fill(
                        next_token_logits < kth_best, float('-inf')
                    )

                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

                input_ids = torch.cat([input_ids, next_token], dim=-1)

                if eos_token_id is not None:
                    finished |= next_token.squeeze(-1) == eos_token_id
                    if bool(finished.all()):
                        break

        return input_ids

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from tqdm import tqdm
class CharDataset(Dataset):
    """Sliding-window dataset over a character-encoded text.

    Item ``i`` is a pair ``(x, y)`` of ``block_size``-long id tensors where
    ``y`` is ``x`` shifted one position to the right in the encoded text.

    Args:
        text: The raw text to encode.
        block_size: Number of ids in each returned window.
        stoi: Mapping from character to integer id. Characters that are not in
            the mapping are dropped from the encoded text.
    """

    def __init__(self, text, block_size, stoi):
        self.block_size = block_size
        self.stoi = stoi
        self.data = [self.stoi[c] for c in text if c in self.stoi]

    def __len__(self):
        # Clamp at zero: a text shorter than block_size yields no windows, and
        # a negative value would make len() itself raise.
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` window pair starting at ``idx``.

        Negative indices count from the end.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if idx < 0:
            idx += size
        if not 0 <= idx < size:
            # Without this check a slice past the end silently returns
            # truncated tensors of mismatched length.
            raise IndexError("CharDataset index out of range")
        stop = idx + self.block_size
        x = torch.tensor(self.data[idx:stop], dtype=torch.long)
        y = torch.tensor(self.data[idx + 1:stop + 1], dtype=torch.long)
        return x, y
def train_gpt2_tiny_shakespeare(config: GPT2Config, device='cuda', epochs=1, batch_size=32, lr=3e-4, fp16=True):
    """Train a character-level ``GPT2Model`` on the ``tiny_shakespeare`` dataset.

    Side effect: ``config.vocab_size`` is overwritten with the number of
    distinct characters found in the training text.

    Args:
        config: Model configuration; its ``block_size`` sets the window length.
        device: Device to train on (string or ``torch.device``).
        epochs: Number of passes over the data.
        batch_size: Mini-batch size; the last incomplete batch is dropped.
        lr: AdamW learning rate.
        fp16: Use CUDA automatic mixed precision. Ignored (treated as False)
            when ``device`` is not a CUDA device.

    Returns:
        ``(model, stoi, itos)``: the trained model (wrapped in
        ``torch.nn.DataParallel`` when training on CUDA with more than one
        GPU), the char -> id mapping and the id -> char mapping.
    """
    dataset = load_dataset('tiny_shakespeare')['train'] #type: ignore
    text = dataset[0]['text']
    vocab = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(vocab)}
    itos = {i: ch for ch, i in stoi.items()}
    config.vocab_size = len(vocab)

    ds = CharDataset(text, config.block_size, stoi)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=True)

    # CUDA AMP (autocast + GradScaler) and DataParallel only make sense when
    # the model actually lives on a CUDA device.
    on_cuda = torch.device(device).type == 'cuda'
    use_amp = bool(fp16) and on_cuda

    model = GPT2Model(config).to(device)
    if on_cuda and torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs for training!")
        model = torch.nn.DataParallel(model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
    model.train()
    for epoch in range(epochs):
        pbar = tqdm(dl, desc=f"Epoch {epoch+1}/{epochs}")
        for x, _ in pbar:
            x = x.to(device)
            optimizer.zero_grad()
            with torch.amp.autocast('cuda', enabled=use_amp):
                # The model shifts labels by one position internally, so it
                # must receive the UNshifted window. Passing the dataset's
                # already-shifted target would train it to predict two
                # characters ahead.
                out = model(x, x)
                loss = out['loss']
            if loss is not None:
                # DataParallel gathers one loss per replica; reduce to a scalar.
                if loss.dim() > 0:
                    loss = loss.mean()
                # Skip the update on a NaN loss instead of corrupting weights.
                if not torch.isnan(loss):
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
            pbar.set_postfix({"loss": f"{loss.item():.4f}" if loss is not None else 'None'})

    return model, stoi, itos


In [None]:
import torch
from config import GPT2Config
from train import train_gpt2_tiny_shakespeare

if __name__ == "__main__":
    # Train a small character-level model, save its weights and print a sample.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    config = GPT2Config(block_size=128, n_layer=4, n_head=4, d_model=128, dropout=0.1)
    print("Training GPT-2 on tiny Shakespeare (FP16)...")
    model, stoi, itos = train_gpt2_tiny_shakespeare(config, device=device, epochs=1, batch_size=32, lr=3e-4, fp16=True)

    # Work with the underlying module from here on. Saving the DataParallel
    # wrapper's state dict would prefix every key with "module.", making the
    # checkpoint format depend on how many GPUs happened to be present.
    if isinstance(model, torch.nn.DataParallel):
        print("Model is wrapped in DataParallel.")
        net = model.module
    else:
        net = model

    torch.save(net.state_dict(), "gpt2_tiny_shakespeare.pt")
    print("Model saved to gpt2_tiny_shakespeare.pt")

    prompt = "ROMEO: "
    input_ids = torch.tensor([[stoi[c] for c in prompt]], dtype=torch.long).to(device)
    # Only enable CUDA autocast when generation actually runs on the GPU.
    with torch.amp.autocast('cuda', enabled=(device == 'cuda')):
        out_ids = net.generate(input_ids, max_length=100)
    out_text = ''.join(itos[i] for i in out_ids[0].tolist())
    print("\nSample generated text:\n")
    print(out_text)
