# Importing libraries

In [1]:
import os
import json
import random
import time
from datetime import datetime
import numpy as np
from dataclasses import dataclass
from tqdm import tqdm
from typing import Dict, Optional, Tuple, Type, Literal
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.profiler import profile, record_function, ProfilerActivity
from transformers import get_cosine_schedule_with_warmup
import wandb

# Configuration

In [2]:
@dataclass
class TokenizerConfig:
    vocab_path: str = "char_vocab.json"

@dataclass
class ModelConfig:
    vocab_size: int = -1
    max_seq_len: int = 256
    d_embed: int = 128
    n_layers: int = 4
    norm_eps: float = 1e-5
    dropout: float = 0.1

    # Attention
    attn_type: Literal["mha", "gqa", "mla"] = "mha"
    n_heads: int = 4
    d_head: int = d_embed // n_heads
    attn_bias: bool = False
    n_kv_heads: Optional[int] = None
    d_latent: Optional[int] = None
    ## Mixture of Attention Heads
    moh: bool = False
    n_activated_heads: Optional[int] = None
    n_shared_heads: Optional[int] = None

    # FeedForward
    d_ff: int = d_embed * 4
    mlp_bias: bool = False
    activation: Type[nn.Module] = nn.GELU
    d_ff_multiplier: Optional[float] = None
    d_ff_multiple_of: int = 256
    ## Mixture of Experts
    moe: bool = True
    n_experts: Optional[int] = 4
    n_activated_experts: Optional[int] = 1
    n_shared_experts: Optional[int] = None

@dataclass
class DatasetConfig:
    local_dir: str = f"datasets/Shakespeare/shakespeare.txt"
    val_size: float = 0.05

@dataclass
class TrainConfig:
    debug: bool = False
    wandb_project: str = "nanoGPT"
    run_name = f"nanoGPT-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
    output_dir: str = f"checkpoints/nanoGPT"
    num_workers: int = 4

    # Training
    per_device_train_batch_size: int = 512
    per_device_eval_batch_size: int = 1024
    num_train_epochs: int = 1
    learning_rate: float = 2e-3
    weight_decay: float = 0.1
    optim: torch.optim.Optimizer = torch.optim.AdamW
    betas: tuple[float, float] = (0.9, 0.95)
    eps: float = 1e-8
    warmup_ratio: float = 0.1
    max_grad_norm: float = 1.0
    gradient_accumulation_steps: int = 1024 // per_device_train_batch_size
    eval_steps: int = 100
    seed: int = 101
    ## Precision
    mixed_precision: bool = True
    matmul_precision: Literal["highest", "high", "medium"] = "high"

@dataclass
class GenerationConfig:
    max_new_tokens: int = 200
    temperature: float = 1.0
    top_k: int = 50

tokenizer_config = TokenizerConfig()
model_config = ModelConfig()
dataset_config = DatasetConfig()
train_config = TrainConfig()
generation_config = GenerationConfig()

## Weights & Biases

In [3]:
wandb.login(key=os.environ.get("WANDB_API_KEY"))

[34m[1mwandb[0m: Currently logged in as: [33mpathfinderkr[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

# Utils

## Reproducibility

In [4]:
def set_seed(seed: int):
    """
    Set the random seed for reproducibility.

    Args:
        seed (int): The seed value to set.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"Random seed set to {seed}")

set_seed(train_config.seed)

Random seed set to 101


## Device

In [5]:
device = torch.device("cuda")
torch.set_float32_matmul_precision(train_config.matmul_precision)  # Tensor Cores

# Dataset

In [8]:
def load_text(file_path: str, encoding: str = 'utf-8') -> str:
    """
    Load and read text data from a file.

    Args:
        file_path (str): Path to the text file.
        encoding (str, optional): File encoding. Defaults to 'utf-8'.

    Returns:
        str: The content of the text file.
    """
    if not os.path.isfile(file_path):
        print(f"File not found: {file_path}")
        raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, 'r', encoding=encoding) as f:
        text = f.read()

    print(f"Loaded text data from {file_path} (length: {len(text)} characters).")
    return text

shakespeare_text = load_text(f"../{dataset_config.local_dir}")

Loaded text data from ../datasets/Shakespeare/shakespeare.txt (length: 1115394 characters).


In [7]:
if train_config.debug:
    print(shakespeare_text[:1000])

# Tokenizer

In [8]:
class CharTokenizer:
    def __init__(self, vocab: Optional[Dict[str, int]] = None):
        """
        Initialize the character-level tokenizer.

        Args:
            vocab (dict, optional): A pre-defined vocabulary mapping. If None, it will be built from data.
        """
        if vocab is not None:
            self.char2idx = vocab
            self.idx2char = {idx: char for char, idx in vocab.items()}
            self.vocab_size = len(vocab)
        else:
            self.char2idx: Dict[str, int] = {}
            self.idx2char: Dict[int, str] = {}
            self.vocab_size: int = 0

    def build_vocab(self, text: str):
        """
        Build vocabulary from the provided text.

        Args:
            text (str): The text data to build the vocabulary from.
        """
        unique_chars = sorted(set(text))
        print(f"Unique characters: {len(unique_chars)}")
        self.char2idx = {char: idx for idx, char in enumerate(unique_chars)}
        self.idx2char = {idx: char for char, idx in self.char2idx.items()}
        self.vocab_size = len(self.char2idx)

    def encode(self, text: str) -> torch.Tensor:
        """
        Encode a string into a tensor of integer token IDs.

        Args:
            text (str): The text to encode.

        Returns:
            torch.Tensor: The encoded tensor.
        """
        ids = []
        for char in text:
            if char in self.char2idx:
                ids.append(self.char2idx[char])
            else:
                ids.append(self.char2idx["?"])
        return torch.tensor(ids, dtype=torch.long)

    def decode(self, tokens: torch.Tensor) -> str:
        """
        Decode a tensor of integer token IDs into a string.

        Args:
            tokens (torch.Tensor): The tensor of token IDs.

        Returns:
            str: The decoded string.
        """
        chars = []
        for idx in tokens:
            if idx in self.idx2char:
                chars.append(self.idx2char[idx])
            else:
                chars.append("?")
        return ''.join(chars)

    def save_vocab(self, file_path: str):
        """
        Save the vocabulary to a JSON file.

        Args:
            file_path (str): The path to save the vocabulary file.
        """
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(self.char2idx, f, ensure_ascii=False, indent=4)
        print(f"Vocabulary saved to {file_path}.")

    def load_vocab(self, file_path: str):
        """
        Load the vocabulary from a JSON file.

        Args:
            file_path (str): The path to the vocabulary file.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            self.char2idx = json.load(f)
        self.idx2char = {idx: char for char, idx in self.char2idx.items()}
        self.vocab_size = len(self.char2idx)
        print(f"Vocabulary loaded from {file_path}.")

In [9]:
char_tokenizer = CharTokenizer()
char_tokenizer.build_vocab(text=shakespeare_text)
char_tokenizer.save_vocab(file_path=tokenizer_config.vocab_path)
model_config.vocab_size = char_tokenizer.vocab_size

Unique characters: 65
Vocabulary saved to char_vocab.json.


In [10]:
if train_config.debug:
    print(f"Vocabulary size: {char_tokenizer.vocab_size}")
    print("Vocabulary:", char_tokenizer.char2idx)

# Preprocessing

In [11]:
def split_text(text: str, val_size: float) -> Tuple[str, str]:
    """
    Split text into training and validation sets.

    Args:
        text (str): The data to split.
        val_size (float): Size of the validation set.

    Returns:
        Tuple[str, str]: Training and validation data.
    """
    if val_size < 0 or val_size >= 1:
        raise ValueError(f"Invalid validation size: {val_size}")

    split_idx = int(len(text) * (1 - val_size))
    return text[:split_idx], text[split_idx:]

train_text, val_text = split_text(shakespeare_text, val_size=dataset_config.val_size)

In [12]:
class TextDataset(Dataset):
    def __init__(self, text: str, tokenizer: CharTokenizer, max_seq_len: int):
        self.encoded = tokenizer.encode(text)
        self.max_seq_len = max_seq_len

    def __len__(self) -> int:
        return len(self.encoded) - self.max_seq_len

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        input_ids = self.encoded[idx:idx + self.max_seq_len]
        target_ids = self.encoded[idx + 1:idx + self.max_seq_len + 1]
        return input_ids, target_ids

def collate_fn(batch):
    input_ids = torch.stack([item[0] for item in batch])
    target_ids = torch.stack([item[1] for item in batch])
    return {"input_ids": input_ids, "target_ids": target_ids}

train_dataset = TextDataset(train_text, char_tokenizer, model_config.max_seq_len)
val_dataset = TextDataset(val_text, char_tokenizer, model_config.max_seq_len)

train_loader = DataLoader(
    train_dataset,
    batch_size=train_config.per_device_eval_batch_size,
    shuffle=True,
    num_workers=train_config.num_workers,
    collate_fn=collate_fn
)
val_loader = DataLoader(
    val_dataset,
    batch_size=train_config.per_device_eval_batch_size,
    shuffle=False,
    num_workers=train_config.num_workers,
    collate_fn=collate_fn
)

In [13]:
if train_config.debug:
    sample_batch = next(iter(train_loader))
    print(f"Sample input IDs: {sample_batch['input_ids'][:5]}")
    print(f"Sample target IDs: {sample_batch['target_ids'][:5]}")

# Model

## Multi Head Self-Attention

In [14]:
class MultiHeadAttention(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        assert config.d_embed % config.n_heads == 0, "Embedding dimension must be divisible by number of heads"
        self.config = config
        self.qkv_proj = nn.Linear(config.d_embed, 3 * config.d_embed, bias=config.attn_bias)
        self.out_proj = nn.Linear(config.d_embed, config.d_embed, bias=config.attn_bias)
        self.dropout = nn.Dropout(config.dropout)

        self.flash = hasattr(F, "scaled_dot_product_attention")
        if not self.flash:
            print("Flash attention not available, using standard implementation.")
            self.scale = config.d_head ** -0.5
            self.attn_dropout = nn.Dropout(config.dropout)
            self.register_buffer(
                "mask",
                torch.tril(torch.ones(config.max_seq_len, config.max_seq_len)).view(1, 1, config.max_seq_len, config.max_seq_len)
            )

    def forward(self, x):
        batch_size, seq_len, _ = x.size()

        # Linear projection
        q, k, v = self.qkv_proj(x).split(self.config.d_embed, dim=2)  # [batch_size, seq_len, d_embed]
        q = q.view(batch_size, seq_len, self.config.n_heads, self.config.d_head).transpose(1, 2)  # [batch_size, n_heads, seq_len, d_head]
        k = k.view(batch_size, seq_len, self.config.n_heads, self.config.d_head).transpose(1, 2)  # [batch_size, n_heads, seq_len, d_head]
        v = v.view(batch_size, seq_len, self.config.n_heads, self.config.d_head).transpose(1, 2)  # [batch_size, n_heads, seq_len, d_head]

        # Casual self-attention
        if self.flash:
            attn = F.scaled_dot_product_attention(q, k, v, dropout_p=self.config.dropout if self.training else 0.0, is_causal=True)  # [batch_size, n_heads, seq_len, d_head]
        else:
            attn = (q @ k.transpose(-2, -1)) * self.scale  # [batch_size, n_heads, seq_len, seq_len]
            attn = attn.masked_fill(self.mask[:, :, :seq_len, :seq_len] == 0, float('-inf'))
            attn = F.softmax(attn, dim=-1)
            attn = self.attn_dropout(attn)
            attn = attn @ v  # [batch_size, n_heads, seq_len, d_head]
        attn = attn.transpose(1, 2).contiguous().view(batch_size, seq_len, self.config.d_embed)  # [batch_size, seq_len, d_embed]

        # Output projection
        attn = self.out_proj(attn)  # [batch_size, seq_len, d_embed]
        attn = self.dropout(attn)

        return attn

## Feed Forward

In [15]:
class FeedForward(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.d_embed, config.d_ff, bias=config.mlp_bias)
        self.activation = config.activation()
        self.fc2 = nn.Linear(config.d_ff, config.d_embed, bias=config.mlp_bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.fc1(x)  # [batch_size, seq_len, d_ff]
        x = self.activation(x)
        x = self.fc2(x)  # [batch_size, seq_len, d_embed]
        x = self.dropout(x)
        return x

## Mixture of Experts

In [16]:
#Expert module
class Expert(nn.Module):
    """ An MLP is a simple linear layer followed by a non-linearity i.e. each Expert """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd)
        )

    def forward(self, x):
        return self.net(x)

In [25]:
class TopkRouter(nn.Module):
    def __init__(self, d_embed = 32, n_experts = 4, top_k= 4):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(d_embed, n_experts)

    def forward(self, x):
        logits = self.gate(x)  # [batch_size, seq_len, n_experts]
        gating_probs = F.softmax(logits, dim=-1)
        top_probs, indices = torch.topk(gating_probs, self.top_k, dim=-1)  # [batch_size, seq_len, top_k], [batch_size, seq_len, top_k]
        print(f"Top-k probabilities shape: {top_probs.shape}, Indices shape: {indices.shape}")
        indices = indices.unsqueeze(-1).expand(-1, -1, -1, x.size(-1))  # [batch_size, seq_len, top_k, d_embed]
        gating_probs = top_probs.unsqueeze(-1).expand(-1, -1, -1, x.size(-1))  # [batch_size, seq_len, top_k, d_embed]
        return gating_probs, indices

sample_x = torch.randn(batch, seq_len, embed)
print(f"Sample input shape: {sample_x.shape}")

router = TopkRouter(d_embed=32, n_experts=4, top_k=4)
gating_output, indices = router(sample_x)

Sample input shape: torch.Size([1, 4, 32])
Top-k probabilities shape: torch.Size([1, 4, 4]), Indices shape: torch.Size([1, 4, 4])


In [23]:
class TopkRouter(nn.Module):
    def __init__(self, d_embed = 32, n_experts = 4, top_k= 4):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(d_embed, n_experts)

    def forward(self, x):
        logits = self.gate(x)  # [batch_size, seq_len, n_experts]
        gating_probs = F.softmax(logits, dim=-1)
        top_probs, indices = torch.topk(gating_probs, self.top_k, dim=-1)  # [batch_size, seq_len, top_k], [batch_size, seq_len, top_k]
        print(f"Top-k probabilities shape: {top_probs.shape}, Indices shape: {indices.shape}")
        indices = indices.unsqueeze(-1).expand(-1, -1, -1, x.size(-1))  # [batch_size, seq_len, top_k, d_embed]
        gating_probs = top_probs.unsqueeze(-1).expand(-1, -1, -1, x.size(-1))  # [batch_size, seq_len, top_k, d_embed]
        return gating_probs, indices

class SparseMoE(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(SparseMoE, self).__init__()
        self.router = TopkRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
        self.top_k = top_k

    def forward(self, x):
        gating_output, indices = self.router(x)
        final_output = torch.zeros_like(x)
        print(final_output)

        # Reshape inputs for batch processing
        flat_x = x.view(-1, x.size(-1))
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))

        # Process each expert in parallel
        for i, expert in enumerate(self.experts):
            # Create a mask for the inputs where the current expert is in top-k
            expert_mask = (indices == i).any(dim=-1)
            flat_mask = expert_mask.view(-1)

            if flat_mask.any():
                expert_input = flat_x[flat_mask]
                expert_output = expert(expert_input)

                # Extract and apply gating scores
                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)
                weighted_output = expert_output * gating_scores

                # Update final output additively by indexing and adding
                final_output[expert_mask] += weighted_output.squeeze(1)

        return final_output

sample_x = torch.randn(batch, seq_len, embed)
print(f"Sample input shape: {sample_x.shape}")

Sample input shape: torch.Size([1, 4, 32])


In [24]:
top_k_gate = TopkRouter(n_embed=embed, num_experts=experts, top_k=activated)
gating_output, indices = top_k_gate(sample_x)

sparse_moe = SparseMoE(n_embed=embed, num_experts=experts, top_k=activated)
final_output = sparse_moe(mh_output)
print("Shape of the final output:", final_output.shape)

sparse_logits shape: torch.Size([1, 4, 4]), indices shape: torch.Size([1, 4, 1])
sparse_logits shape: torch.Size([2, 4, 4]), indices shape: torch.Size([2, 4, 1])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0

In [16]:
class Expert(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.d_embed, config.d_ff, bias=config.mlp_bias)
        self.activation = config.activation()
        self.fc2 = nn.Linear(config.d_ff, config.d_embed, bias=config.mlp_bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.fc1(x)  # [batch_size, seq_len, d_ff]
        x = self.activation(x)
        x = self.fc2(x)  # [batch_size, seq_len, d_embed]
        x = self.dropout(x)
        return x


class Router(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.top_k = config.n_activated_experts
        self.gate = nn.Linear(config.d_embed, config.n_experts)

    def forward(self, x):
        logits = self.gate(x)  # [batch_size, seq_len, n_experts]
        gating_probs = F.softmax(logits, dim=-1)
        top_probs, indices = torch.topk(gating_probs, self.top_k, dim=-1)  # [batch_size, seq_len, top_k], [batch_size, seq_len, top_k]
        indices = indices.unsqueeze(-1).expand(-1, -1, -1, x.size(-1))  # [batch_size, seq_len, top_k, d_embed]
        gating_probs = top_probs.unsqueeze(-1).expand(-1, -1, -1, x.size(-1))  # [batch_size, seq_len, top_k, d_embed]
        return gating_probs, indices


class MoE(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.router = Router(config)
        self.experts = nn.ModuleList([Expert(config) for _ in range(config.n_experts)])

    def forward(self, x):
        gating_probs, indices = self.router(x)
        batch, seq_len, _ = x.size()

        final = torch.zeros_like(x)
        flat_x = x.view(-1, x.size(-1))
        flat_probs = gating_probs.view(-1, gating_probs.size(-1))
        flat_idx = indices.view(-1, indices.size(-1))

        for i, expert in enumerate(self.experts):
            mask = (flat_idx == i).any(dim=-1)
            if not mask.any():
                continue
            expert_in = flat_x[mask]
            expert_out = expert(expert_in)
            scores = flat_probs[mask, i].unsqueeze(1)
            weighted = expert_out * scores
            final.view(-1, final.size(-1))[mask] += weighted

        return final

## Block

In [17]:
class Block(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.norm1 = nn.LayerNorm(config.d_embed, eps=config.norm_eps)
        self.attn = MultiHeadAttention(config)
        self.norm2 = nn.LayerNorm(config.d_embed, eps=config.norm_eps)
        if config.moe:
            self.mlp = MoE(config)
        else:
            self.mlp = FeedForward(config)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

## GPT

In [18]:
class GPT(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.d_embed)
        self.positional_encoding = nn.Embedding(config.max_seq_len, config.d_embed)
        self.dropout = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layers)])
        self.norm = nn.LayerNorm(config.d_embed, eps=config.norm_eps)
        self.lm_head = nn.Linear(config.d_embed, config.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def num_params(self):
        unique = {p.data_ptr(): p for p in self.parameters()}
        return sum(p.numel() for p in unique.values())

    def forward(self, idx, targets=None):
        device = idx.device
        batch_size, seq_len = idx.size()

        # Embedding
        tok_embed = self.embedding(idx)  # [batch_size, seq_len, d_embed]
        pos = torch.arange(0, seq_len, dtype=torch.long, device=device)  # [seq_len]
        pos_embed = self.positional_encoding(pos)  # [seq_len, d_embed]
        x = tok_embed + pos_embed  # [batch_size, seq_len, d_embed]
        x = self.dropout(x)

        # Blocks
        for block in self.blocks:
            x = block(x)  # [batch_size, seq_len, d_embed]

        # Final normalization and linear layer
        x = self.norm(x)
        if targets is not None:
            logits = self.lm_head(x)  # [batch_size, seq_len, vocab_size]
            logits = logits.view(-1, self.config.vocab_size)  # [batch_size * seq_len, vocab_size]
            targets = targets.view(-1)  # [batch_size * seq_len]
            loss = F.cross_entropy(logits, targets, ignore_index=-1)
        else:
            logits = self.lm_head(x[:, [-1], :])  # [batch_size, 1, vocab_size]
            loss = None

        return logits, loss

    @torch.inference_mode()
    def generate(self, idx, tokenizer, max_new_tokens: int, temperature: float = 1.0, top_k: int = 50):
        if not (temperature > 0):
            raise ValueError("temperature must be positive")
        self.eval()

        # Generation loop
        for _ in range(max_new_tokens):
            # Truncate if necessary
            idx_cond = idx if idx.size(1) <= self.config.max_seq_len else idx[:, -self.config.max_seq_len:]
            logits, _ = self(idx_cond)  # [batch_size, 1, vocab_size]

            # Apply temperature and top-k filtering
            logits = logits[:, -1, :] / temperature  # [batch_size, vocab_size]
            if top_k is not None:
                k_logits, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < k_logits[:, [-1]]] = -float('Inf')

            # Sample from the distribution
            probs = F.softmax(logits, dim=-1)
            next_idx = torch.multinomial(probs, num_samples=1)  # [batch_size, 1]
            # Concatenate the next token to the input
            idx = torch.cat((idx, next_idx), dim=1)  # [batch_size, seq_len + 1]

            # Decode and print the next token
            text = tokenizer.decode([next_idx[0].item()])
            print(text, end='', flush=True)

In [19]:
# Initialize the model
model = GPT(model_config).to(device)
#model = torch.compile(model)
print(model)
print(f"Number of parameters: {model.num_params() / 1e6:.2f}M")

GPT(
  (embedding): Embedding(65, 128)
  (positional_encoding): Embedding(256, 128)
  (dropout): Dropout(p=0.1, inplace=False)
  (blocks): ModuleList(
    (0-3): 4 x Block(
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (attn): MultiHeadAttention(
        (qkv_proj): Linear(in_features=128, out_features=384, bias=False)
        (out_proj): Linear(in_features=128, out_features=128, bias=False)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (mlp): MoE(
        (router): Router(
          (gate): Linear(in_features=128, out_features=4, bias=True)
          (noise_linear): Linear(in_features=128, out_features=4, bias=True)
        )
        (experts): ModuleList(
          (0-3): 4 x Expert(
            (fc1): Linear(in_features=128, out_features=512, bias=False)
            (activation): GELU(approximate='none')
            (fc2): Linear(in_features=512, out_features=128, bias=

In [None]:
def testTemplate(customFunc, params, test_key):
    start = time.time()
    N, d, B, H = params
    #compute pytorch unfused softmax
    Q, K, V = createQKVSimple(N,d,B,H)
    QKV = badSoftmax(Q,K,V)
    end = time.time()
    pytorch_time = end - start

    with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof:
        with record_function("model_inference"):
            #compute with Naive Unfused
            start = time.time()
            QKS1 = customFunc()
            end = time.time()
            manual_time = end - start

    assert torch.allclose(QKV,QKS1, atol=1e-4), correctness_error_message
    print("manual attention == pytorch attention",torch.allclose(QKV,QKS1, atol=1e-4))
    #print("Pytorch Execution Time:", pytorch_time, "\n")
    print("Manual Execution Time: ", manual_time, "\n")
    print(prof.key_averages().table(sort_by="cpu_memory_usage", row_limit=10))
    r = prof.key_averages()
    for rr in r:
        if rr.key == test_key:
            key, cpu_time, mem_usage = rr.key, rr.cpu_time, rr.cpu_memory_usage
            print (test_key+ " statistics")
            print("cpu time: ", str(cpu_time / 1000.0) + "ms")
            print("mem usage: ", mem_usage, "bytes")

In [None]:
def part1Test(N, d, B, H):
    print("Running Part 1 Test: Naive Unfused Attention\n")
    Q,K,V = createQKVSimple(N,d,B,H)
    attentionModuleStudent = CustomAttention(Q,K,V, B, H, N, d)
    attentionModuleReference = CustomAttention(Q,K,V, B, H, N, d, True)
    params = (N, d, B, H)
    print("-----RUNNING REFERENCE IMPLEMENTATION-----\n")
    testTemplate(attentionModuleStudent.myUnfusedAttention, params, "REFERENCE - NAIVE ATTENTION")
    time.sleep(3)
    print("-----RUNNING STUDENT IMPLEMENTATION-----\n")
    testTemplate(attentionModuleReference.myUnfusedAttention, params, "STUDENT - NAIVE ATTENTION")


# Training

In [20]:
if train_config.debug:
    raise ValueError("Debug mode is enabled. Stopping execution for debugging.")

In [21]:
class Trainer:
    def __init__(
            self,
            train_config: TrainConfig,
            model: nn.Module,
            train_loader: DataLoader,
            val_loader: DataLoader,
            device: torch.device
    ):
        self.train_config = train_config
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device

    def train(self):
        wandb.init(
            project=self.train_config.wandb_project,
            name=self.train_config.run_name,
            config=self.train_config.__dict__
        )
        wandb.watch(self.model, log="all")

        total_steps = (len(self.train_loader) * self.train_config.num_train_epochs // self.train_config.gradient_accumulation_steps)
        warmup_steps = int(self.train_config.warmup_ratio * total_steps)

        optimizer = self.train_config.optim(
            self.model.parameters(),
            lr=self.train_config.learning_rate,
            weight_decay=self.train_config.weight_decay,
            betas=self.train_config.betas,
            eps=self.train_config.eps,
            fused=True
        )
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps
        )

        progress_bar = tqdm(total=total_steps, desc="Training")
        step = 0

        for epoch in range(self.train_config.num_train_epochs):
            for batch_idx, batch in enumerate(self.train_loader):
                self.model.train()
                input_ids = batch["input_ids"].to(self.device)
                target_ids = batch["target_ids"].to(self.device)

                if self.train_config.mixed_precision:
                    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                        outputs, loss = self.model(input_ids, target_ids)
                else:
                    outputs, loss = self.model(input_ids, target_ids)
                loss = loss / self.train_config.gradient_accumulation_steps
                loss.backward()

                if (batch_idx + 1) % self.train_config.gradient_accumulation_steps == 0:
                    grad_norm = clip_grad_norm_(self.model.parameters(), self.train_config.max_grad_norm)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    step += 1

                    wandb.log({
                        "Train Loss": loss.item() * self.train_config.gradient_accumulation_steps,
                        "Learning Rate": scheduler.get_last_lr()[0],
                        "Grad Norm": grad_norm,
                        "Epoch": epoch + 1
                    })
                    progress_bar.set_postfix(
                        loss=f"{loss.item() * self.train_config.gradient_accumulation_steps:.4f}",
                        lr=f"{scheduler.get_last_lr()[0]:.6f}",
                        grad_norm=f"{grad_norm:.4f}",
                        epoch=epoch + 1
                    )
                    progress_bar.update(1)

                    if step % self.train_config.eval_steps == 0:
                        self.validate()

        self.model.eval()
        self.validate()  # Final validation
        progress_bar.close()
        wandb.finish()

    @torch.no_grad()
    def validate(self):
        self.model.eval()
        total_val_loss = 0.0
        total_samples = 0

        for val_batch in self.val_loader:
            val_input_ids = val_batch["input_ids"].to(self.device)
            val_target_ids = val_batch["target_ids"].to(self.device)
            if self.train_config.mixed_precision:
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    _, val_loss = self.model(val_input_ids, val_target_ids)
            else:
                _, val_loss = self.model(val_input_ids, val_target_ids)
            total_val_loss += val_loss.item() * val_input_ids.size(0)
            total_samples += val_input_ids.size(0)

        wandb.log({"Val Loss": total_val_loss / total_samples})

In [22]:
trainer = Trainer(
        train_config=train_config,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device
)
trainer.train()

Training: 100%|██████████| 517/517 [03:43<00:00,  2.31it/s, epoch=1, grad_norm=0.0786, loss=1.9626, lr=0.000000]


0,1
Epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Grad Norm,█▇▄▃▂▃▂▁▄▁▂▂▂▃▂▂▂▂▂▄▂▃▂▂▆▂▂▂▂▃▂▁▁▁▁▁▁▁▁▁
Learning Rate,▁▃▄▅▇███████▇▇▇▇▇▆▆▆▅▅▅▄▄▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁
Train Loss,█▇▆▅▅▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Val Loss,█▆▃▁▁▁

0,1
Epoch,1.0
Grad Norm,0.07861
Learning Rate,0.0
Train Loss,1.96258
Val Loss,1.94662


## Save

In [23]:
if not train_config.debug:
    os.makedirs(train_config.output_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(train_config.output_dir, f"{train_config.run_name}.pt"))
    print(f"Model saved to {os.path.join(train_config.output_dir, f'{train_config.run_name}.pt')}")

Model saved to checkpoints/nanoGPT/nanoGPT-2025-05-29_04-51-39.pt


# Inference

In [29]:
print("=" * 50)
print("User prompt: ")
sample_prompt = "To be, or not to be, that is the question"
print(sample_prompt)
input_ids = char_tokenizer.encode(sample_prompt).to(device).unsqueeze(0)
print("-" * 50)

print("🤖 Model Response:")
model.generate(
    input_ids,
    tokenizer=char_tokenizer,
    max_new_tokens=generation_config.max_new_tokens,
    temperature=generation_config.temperature,
    top_k=generation_config.top_k
)
print()
print("=" * 50)

User prompt: 
To be, or not to be, that is the question
--------------------------------------------------
🤖 Model Response:

Best but garghered sway fard natim; it figh theall,
Comsuck; murrd it by you friestlen?

FROMEO:
I him nowst shall havee this sprit:--

ESers Vond:
And Rand his then tin mare be my
Murage tolly a gau
