# GPT

The first notebook was written from scratch; this one leverages existing libraries to experiment with actual training and inference.

I also made some improvements over the baseline model:
- Switched the attention computation to the optimized `F.scaled_dot_product_attention`
- Replaced `LayerNorm` with `RMSNorm`, which is now the standard
- Replaced `GELU` with `SwiGLU` in the MLP
- Switched positional encoding to `RoPE`, applied to `Q` and `K`
- Disabled biases in all linear layers
- Fused the `Q`, `K`, `V` projections into a single linear layer


In [48]:
# Standard library
import csv
import inspect
import math
import multiprocessing
import os
import random
import time
from pprint import pprint
from datetime import datetime

# Environment config
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Third-party
import numpy as np
from datasets import concatenate_datasets, load_dataset, Dataset as ds
from rotary_embedding_torch import RotaryEmbedding
from tokenizers import Tokenizer

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import LinearLR, SequentialLR

# Custom
from utils import clean_columns

# Torch runtime config
# Let float32 matmuls use reduced-precision internal math ("medium", see
# torch.set_float32_matmul_precision) and allow TF32 in cuDNN, then drop any
# cached-but-unused CUDA memory left over from a previous run of the kernel.
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("medium")
torch.cuda.empty_cache()

In [49]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention with a fused QKV projection, RoPE on Q/K and an optional K/V cache.

    The attention itself is computed with `F.scaled_dot_product_attention`.
    """
    def __init__(self, embed_dim: int, num_heads: int, causal: bool = True, dropout: float = 0.1):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads}).")
        
        self.causal = causal
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.dropout_p = dropout
        
        # Fused QKV projection: 3x the output size
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        
        self.rotary_emb = RotaryEmbedding(dim=self.head_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
    
    def forward(self, x, k_v_cache=None):
        """
        Args:
            x: (B, T, embed_dim) input.
            k_v_cache: None disables caching. Otherwise a dict that is updated in place with
                "K" / "V" tensors of shape (B, num_heads, T_seen, head_dim); cached keys are
                already rotated.
                - empty dict: first step, all of `x` is processed and the cache is filled;
                - populated dict: only the LAST position of `x` is new. It is projected,
                  rotated with an offset of T_seen, and attends to every cached position
                  plus itself.

        Returns:
            (out, k_v_cache): out is (B, T, embed_dim), or (B, 1, embed_dim) when the cache
            was already populated; k_v_cache is the object that was passed in.
        """
        B, T, _ = x.shape
        using_cache = k_v_cache is not None and "K" in k_v_cache
    
        # 1. Single fused projection (with a warm cache only the newest position is needed)
        x_new = x[:, -1:, :] if using_cache else x
        qkv = self.qkv_proj(x_new)  # (B, T_new, 3 x embed_dim)
        
        # 2. Split into Q, K, V
        Q, K, V = qkv.chunk(3, dim=-1)  # Each is (B, T_new, embed_dim)
        
        def split_heads(t):
            return t.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
        
        # 3. Split heads -> (B, H, T_new, D_head)
        Q = split_heads(Q)
        K = split_heads(K)
        V = split_heads(V)
    
        # 4. Apply RoPE (offset by the cached length so positions continue after the cache)
        if using_cache:
            past_len = k_v_cache["K"].shape[-2]
            Q = self.rotary_emb.rotate_queries_or_keys(Q, offset=past_len)
            K = self.rotary_emb.rotate_queries_or_keys(K, offset=past_len)
            
            K = torch.cat([k_v_cache["K"], K], dim=-2)
            V = torch.cat([k_v_cache["V"], V], dim=-2)
        else:
            Q = self.rotary_emb.rotate_queries_or_keys(Q)
            K = self.rotary_emb.rotate_queries_or_keys(K)
    
        # 5. Update cache
        if k_v_cache is not None:
            k_v_cache["K"] = K
            k_v_cache["V"] = V
    
        # 6. Attention
        # SDPA's `is_causal=True` builds a TOP-LEFT aligned mask (query i sees keys 0..i),
        # which is only correct when queries and keys cover the same positions. With a warm
        # cache there is a single query (the newest token) against T_seen + 1 keys: a
        # top-left mask would let it see key 0 only. That query is the last position, so
        # it may legitimately attend to every key -> no mask at all.
        out = F.scaled_dot_product_attention(
            query=Q,
            key=K,
            value=V,
            attn_mask=None,
            dropout_p=self.dropout_p if self.training else 0.0,
            is_causal=self.causal and not using_cache,
        )
        
        # 7. Merge heads
        out = out.transpose(1, 2).contiguous().view(B, -1, self.embed_dim)

        # 8. Linear projection
        return self.out_proj(out), k_v_cache

In [50]:
class MLP(nn.Module):
    """
    Position-wise feed-forward network: SwiGLU (default) or a classic GELU MLP.

    Args:
        embed_dim: input and output width.
        hidden_dim: inner width, defaults to 4 * embed_dim. With SwiGLU it is scaled by 2/3
            (truncated to int) so the three projections hold about as many parameters as
            the two projections of the GELU variant.
        dropout_prob: dropout applied to the block output.
        use_swiglu: SwiGLU if True, GELU otherwise.
    """

    def __init__(self, embed_dim, hidden_dim=None, dropout_prob=0.1, use_swiglu=True):
        super().__init__()
        inner_dim = 4 * embed_dim if hidden_dim is None else hidden_dim
        self.use_swiglu = use_swiglu

        if not use_swiglu:
            self.linear1 = nn.Linear(embed_dim, inner_dim, bias=False)
            self.act = nn.GELU()
            self.linear2 = nn.Linear(inner_dim, embed_dim, bias=False)
        else:
            # parameter-count matching with the 2-matrix GELU MLP
            inner_dim = int(2 * inner_dim / 3)
            self.gate_proj = nn.Linear(embed_dim, inner_dim, bias=False)
            self.up_proj = nn.Linear(embed_dim, inner_dim, bias=False)
            self.down_proj = nn.Linear(inner_dim, embed_dim, bias=False)

        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        if not self.use_swiglu:
            hidden = self.act(self.linear1(x))
            return self.dropout(self.linear2(hidden))
        # SwiGLU: SiLU-gated elementwise product of two projections, then project back
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(gated))

In [51]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio=4, dropout_prob=0.1, causal=True, use_swiglu=True):
        """
        Pre-norm transformer block.

        Data flow:
            h   = x + MultiHeadAttention(RMSNorm(x))
            out = h + MLP(RMSNorm(h))

        Args:
            embed_dim: model width.
            num_heads: number of attention heads.
            mlp_ratio: MLP hidden size as a multiple of embed_dim.
            dropout_prob: dropout forwarded to attention and MLP.
            causal: causal masking in attention (masks out future tokens).
            use_swiglu: SwiGLU MLP if True, GELU MLP otherwise.
        """
        super().__init__()
        self.norm1 = nn.RMSNorm(embed_dim)
        self.mha = MultiHeadAttention(embed_dim, num_heads, causal, dropout_prob)
        self.norm2 = nn.RMSNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_ratio * embed_dim, dropout_prob, use_swiglu)

    def forward(self, x, cache=None):
        """Return (block output, attention K/V cache); `cache` is only used during generation."""
        attn_out, cache = self.mha(self.norm1(x), cache)
        h = x + attn_out  # first residual path
        return h + self.mlp(self.norm2(h)), cache  # second residual path

In [52]:
class GPT(nn.Module):
    """
    Complete GPT (Generative Pre-trained Transformer) model.

    Token embedding -> dropout -> stack of pre-norm TransformerBlocks -> RMSNorm -> LM head.
    There is no learned positional embedding table: positions are injected with RoPE inside
    each attention layer. The LM head shares its weight matrix with the token embedding.
    """

    def __init__(self,
                 vocab_size,
                 embed_dim,
                 num_layers,
                 num_heads,
                 mlp_ratio=4,
                 dropout_prob=0.1,
                 is_causal=True,
                 use_swiglu=True,
                ):
        """
        Initialize complete GPT model.

        Args:
            vocab_size: number of token ids (rows of the embedding / LM-head matrix).
            embed_dim: model width.
            num_layers: number of TransformerBlocks.
            num_heads: attention heads per block (must divide embed_dim).
            mlp_ratio: MLP hidden size as a multiple of embed_dim.
            dropout_prob: dropout on embeddings, attention and MLP outputs.
            is_causal: whether attention is causally masked.
            use_swiglu: SwiGLU MLP if True, GELU MLP otherwise.
        """
        super().__init__()

        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio

        self.embedding = nn.Embedding(self.vocab_size, self.embed_dim)
        self.dropout = nn.Dropout(dropout_prob)
        self.blocks = nn.ModuleList([TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout_prob, is_causal, use_swiglu) for _ in range(num_layers)])
        self.norm = nn.RMSNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        # weight tying
        self.lm_head.weight = self.embedding.weight

        # below shamefully stolen from nano-gpt
        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            # dont forget swiglu case, oops
            if pn.endswith(("c_proj.weight", "out_proj.weight", "down_proj.weight")):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * self.num_layers))

    def _init_weights(self, module):
        """N(0, 0.02) init for Linear / Embedding weights, zeros for any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, tokens):
        """
        Args:
            tokens: (B, T) integer token ids.
        Returns:
            (B, T, vocab_size) logits.
        """
        embeddings = self.embedding(tokens)
        x = self.dropout(embeddings)
        for b in self.blocks:
            x, _ = b(x)  # iteratively refines features from initial embeddings
        features = self.norm(x)  # normalized to stabilize training
        return self.lm_head(features)

    @property
    def device(self):
        return next(self.parameters()).device

    def _sample_next_token(self, last_logits, temperature, use_top_k):
        """Pick one token id per row of `last_logits` (B, vocab); returns a (B, 1) tensor."""
        if temperature <= 0:
            # Greedy decoding if temp is 0 (prevents division by zero)
            return torch.argmax(last_logits, dim=-1, keepdim=True)

        scaled_logits = last_logits / temperature
        # Only sample from top k tokens to avoid garbage prediction derailing whole prediction.
        # We don't simply take max prob token to allow "creativity".
        if use_top_k:
            # heuristic that is ok for toy project: most of the probability mass is on a
            # small number of tokens. Capped at vocab_size so topk never over-asks.
            k = min(max(5, int(0.01 * self.vocab_size)), 100, self.vocab_size)
            values, indices = torch.topk(scaled_logits, k)
            scaled_logits = torch.full_like(scaled_logits, float('-inf')).scatter_(1, indices, values)
        probs = torch.softmax(scaled_logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate(self,
                 prompt_tokens,
                 max_new_tokens=50,
                 temperature=1.0,
                 use_cache=True,
                 use_top_k=False,
                 eos_token_id=None,
                ):
        """
        Autoregressively extend `prompt_tokens`.

        Args:
            prompt_tokens: (B, T) tensor of token ids; copied to the model's device.
            max_new_tokens: maximum number of tokens to generate.
            temperature: softmax temperature; <= 0 means greedy decoding.
            use_cache: reuse per-layer K/V caches so each step only feeds the newest token.
            use_top_k: restrict sampling to the k most likely tokens.
            eos_token_id: stop once every sequence samples this id (it is not appended).
                Defaults to a module-level `eot_id` if one is defined (previous behaviour);
                when neither is available generation never stops early.

        Returns:
            (B, T + n_generated) tensor of token ids on the model's device.

        The model's train/eval mode is restored before returning.
        """
        if eos_token_id is None:
            eos_token_id = globals().get("eot_id")

        was_training = self.training
        self.eval()
        try:
            tokens_out = prompt_tokens.clone().to(self.device)
            current_tokens = tokens_out
            cache = [{} if use_cache else None for _ in range(len(self.blocks))]

            for _ in range(max_new_tokens):
                x = self.embedding(current_tokens)
                for i, block in enumerate(self.blocks):
                    x, cache[i] = block(x, cache[i])

                # RMSNorm and the LM head act per position, so only the last one is needed
                last_logits = self.lm_head(self.norm(x[:, -1, :]))
                next_token = self._sample_next_token(last_logits, temperature, use_top_k)

                # Stop generating once every sequence thinks its "document" is finished
                if eos_token_id is not None and bool((next_token == eos_token_id).all()):
                    break

                tokens_out = torch.cat([tokens_out, next_token], dim=1)

                # If caching, we only need to feed the newest token next time, otherwise full sequence
                current_tokens = next_token if use_cache else tokens_out
        finally:
            self.train(was_training)

        return tokens_out

## Full Training

Note: we reuse the pretrained GPT-2 tokenizer (`Tokenizer.from_pretrained("GPT2")`, loaded in the data cell below).

Since I do not have unlimited compute, the datasets need to be as clean and high-signal as possible. C4 / FineWeb will overlap massively with OpenWebText (they are built from the same Common Crawl data); Wikipedia will probably be oversampled, which should be fine; BookCorpus is mostly self-published fiction, so it should overlap less.

I thought adding maths and code would be a good idea, but it introduces big issues with the tokenizer, and learning new semantics / syntax etc. might be too ambitious for a small(er)-scale model. arXiv is definitely a no; open-web-math might be fine since it is more informal.

In [53]:
#### CONFIG #####

# Basically GPT-2 Small
block_size = 1024  # Can also do 512 for faster convergence then 1024 to finish training
batch_size = 16
embed_dim = 768
num_layers = 12
num_heads = 12
dropout_prob = 0.1  # usually set as 0 but my model being relatively smaller, should help
mlp_ratio = 4  # standard 4x expansion


# Training
MAX_STEPS = 600000       # Total number of micro-batches to process
GRAD_ACCUM_STEPS = 40    # Accumulate gradients over 40 batches
LOG_INTERVAL = 500       # Log every 500 micro-batches
num_workers = 4          # For data loading
prefetch = 4
dtype = torch.bfloat16
device = "cuda"
model_path = f"gpt_model_maths_{block_size}_final.pt"  # where do we store trained model
log_file = "training_log_pretraining.csv"  # where do we store training logs
print("torch.cuda.is_bf16_supported()", torch.cuda.is_bf16_supported())

torch.cuda.is_bf16_supported() True


In [54]:
# from huggingface_hub import login
# login("your token")  # faster dl

# Setup Tokenizer
tokenizer = Tokenizer.from_pretrained("GPT2")
eot_id = tokenizer.token_to_id("<|endoftext|>")
assert eot_id is not None


# Clean and Concatenate
print("Loading & cleaning up datasets...")
cleaned_datasets = {
    # main dataset, high quality web crawl
    "fineweb-edu": load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT", split="train"),
    # might have some overlap with above but should be fine
    "wiki-train": load_dataset("wikitext", "wikitext-2-raw-v1", split="train"),
    "wiki-validation": load_dataset("wikitext", "wikitext-2-raw-v1", split="validation"),
    # for flavor / diversity
    "bookcorpus": load_dataset("bookcorpus", split="train[:10%]"),
    # pretty good Q/A maths dataset (tiny though, might not do much)
    "maths-qa": load_dataset("microsoft/orca-math-word-problems-200k", split="train")  
}
cleaned_datasets = {n: clean_columns(d) for n, d in cleaned_datasets.items()}
# NOTE(review): wikitext and bookcorpus look like one line / sentence per row (compare their
# row counts with fineweb-edu). process_batch appends an EOT after EVERY row, so these sources
# teach very short "documents". Consider joining consecutive rows first -- TODO confirm.

print(f"\n✓ Loaded {len(cleaned_datasets)} datasets")
# Loop variables are not called `ds`, which would shadow the `Dataset as ds` import.
for ds_name, ds_split in cleaned_datasets.items():
    print(f"  Dataset {ds_name}: {len(ds_split):,} examples")

print("Concatenating...")
# concatenate_datasets expects a list of datasets, not a dict view
train_ds = concatenate_datasets(list(cleaned_datasets.values()))

# Shuffle so the model doesn't train for hours on one source (e.g. wikipedia) then suddenly switch to maths etc.
train_ds = train_ds.shuffle(seed=42)  # usually bad idea but our data is small enough
print(f"Success! Final Train Size: {len(train_ds)} rows")


def process_batch(examples):
    """
    Tokenize a batch of documents and pack them into fixed-length training chunks.

    1. Tokenizes every non-empty text (None / empty / whitespace-only entries are skipped).
    2. Appends the EOT token to EVERY document (document boundary).
    3. Flattens everything into a single 1D token stream.
    4. Cuts the stream into chunks of block_size + 1 tokens (the extra token allows the
       input/target shift later). The trailing remainder of THIS batch is dropped.

    Args:
        examples: batched HF-datasets mapping with a "text" column (list of strings).

    Returns:
        {"chunk_ids": list of token-id lists, each exactly block_size + 1 long}
    """
    # We need chunks of length block_size + 1 (bcs of shifting -> block_size later)
    chunk_len = block_size + 1

    texts = [text for text in examples["text"] if text and text.strip()]

    all_token_ids = []
    # encode_batch yields the same ids as encode() on each text, in a single batched call
    encodings = tokenizer.encode_batch(texts) if texts else []
    for encoding in encodings:
        all_token_ids.extend(encoding.ids)
        # Append EOT (Crucial for GPT context separation)
        all_token_ids.append(eot_id)

    # Truncate remainder
    total_len = (len(all_token_ids) // chunk_len) * chunk_len
    chunks = [all_token_ids[i : i + chunk_len] for i in range(0, total_len, chunk_len)]

    # Return dict for HF Dataset
    return {"chunk_ids": chunks}


# Apply the processing: tokenize + pack in parallel worker processes. The raw columns
# (e.g. 'text') are removed from the result to free up RAM.
map_kwargs = dict(
    batched=True,
    batch_size=1000,
    num_proc=multiprocessing.cpu_count(),
    remove_columns=train_ds.column_names,
    desc="Processing Train",
)
print("Tokenizing and chunking (this may take a moment)...")
train_tokenized = train_ds.map(process_batch, **map_kwargs)

Loading & cleaning up datasets...


Resolving data files:   0%|          | 0/2410 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/98 [00:00<?, ?it/s]


✓ Loaded 5 datasets
  Dataset fineweb-edu: 9,672,101 examples
  Dataset wiki-train: 36,718 examples
  Dataset wiki-validation: 3,760 examples
  Dataset bookcorpus: 7,400,423 examples
  Dataset maths-qa: 200,035 examples
Concatenating...
Success! Final Train Size: 17313037 rows
Tokenizing and chunking (this may take a moment)...


In [55]:
class GPTDataset(torch.utils.data.Dataset):
    """
    Map-style dataset yielding (input, target) pairs for next-token prediction.

    Each row of `hf_dataset` is expected to hold a "chunk_ids" entry of block_size + 1 token
    ids. The entry may be a torch tensor, a numpy array or a plain list, so this works
    whatever format the HF dataset is set to.
    """
    def __init__(self, hf_dataset):
        self.ds = hf_dataset

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        # Normalize to int64, the dtype CrossEntropyLoss requires for class-index targets
        chunk = torch.as_tensor(self.ds[idx]["chunk_ids"], dtype=torch.long)
        # Shift target, the model needs to learn token_t -> token_t+1
        return chunk[:-1], chunk[1:]


# Alternatively can set format to numpy and use from_numpy() in the Dataset (zero-copy, much faster than tensor())
train_tokenized.set_format(type="torch", columns=["chunk_ids"])
train_dataset = GPTDataset(train_tokenized)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    prefetch_factor=prefetch,
    pin_memory=True,
    # Every batch keeps the exact (batch_size, block_size) shape: a ragged final batch
    # could trigger a recompilation of the torch.compile'd model and skews tokens/sec logging.
    drop_last=True,
)

# --- Verification ---
print("-" * 20)
print(f"Final Train Batches: {len(train_loader)}")
x, y = next(iter(train_loader))
print(f"Input x shape: {x.shape}")  # Should be [Batch, Block_Size]
print(f"Target y shape: {y.shape}") # Should be [Batch, Block_Size]

print("\nSanity Check (Shifting):")
print(f"x[0, -5:]: {x[0, -5:].tolist()}") # End of input
print(f"y[0, -5:]: {y[0, -5:].tolist()}") # End of target
del x
del y

--------------------
Final Train Batches: 617609
Input x shape: torch.Size([16, 1024])
Target y shape: torch.Size([16, 1024])

Sanity Check (Shifting):
x[0, -5:]: [198, 13295, 13019, 25, 314]
y[0, -5:]: [13295, 13019, 25, 314, 423]


In [56]:
# --- Model ---
# Architecture hyper-parameters come from the CONFIG cell; the vocabulary size is read
# from the tokenizer so embedding / LM-head shapes always match it.
model_config = dict(
    vocab_size=tokenizer.get_vocab_size(),
    embed_dim=embed_dim,
    num_layers=num_layers,
    num_heads=num_heads,
    mlp_ratio=mlp_ratio,
    dropout_prob=dropout_prob,
    is_causal=True,
    use_swiglu=True,
)

print("Initializing model with config : ")
pprint(model_config)

model = GPT(**model_config).to(device)

Initializing model with config : 
{'dropout_prob': 0.1,
 'embed_dim': 768,
 'is_causal': True,
 'mlp_ratio': 4,
 'num_heads': 12,
 'num_layers': 12,
 'use_swiglu': True,
 'vocab_size': 50257}


In [57]:
compiled = torch.compile(model)  # can take a while
compiled.train()  # training mode: dropout active
model = compiled
print("Model compiled !")

Model compiled !


In [11]:
# # TODO : implement resume from checkpoint (i.e optimizer state, step count etc.)
# ckpt_path = "ckpts/..."

# state_dict = torch.load(ckpt_path, map_location=device)
# model.load_state_dict(state_dict, strict=False)

<All keys matched successfully>

Using standard practice:
- linear warmup + cosine schedule
- relatively large weight decay (0.1, empirically beneficial)
- gradient accumulation to emulate a large batch size

NB: I have heard contradictory claims about label smoothing, so it is disabled for now. We also only decay the linear weights, unlike nanoGPT, which also decays the embeddings (apparently nanoGPT does not follow modern best practice here).

In [58]:
# --- Optimizer ---
# Two AdamW groups. No weight decay for (matched on lower-cased parameter names):
# - Norm parameters (scale/weight etc.)
# - Embedding table (which is tied to lm_head)
# - Any bias terms if present
_NO_DECAY_KEYWORDS = ('norm', 'bias', 'embed', 'embedding', 'lm_head')


def _skips_weight_decay(param_name):
    """True if the parameter name contains any of the no-decay keywords."""
    lowered = param_name.lower()
    return any(keyword in lowered for keyword in _NO_DECAY_KEYWORDS)


trainable_named = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
decay_params = [p for n, p in trainable_named if not _skips_weight_decay(n)]
no_decay_params = [p for n, p in trainable_named if _skips_weight_decay(n)]

print(f"\nDecay params: {len(decay_params)}, No decay params: {len(no_decay_params)}")

total_optim_steps = MAX_STEPS // GRAD_ACCUM_STEPS
print(f"Total Micro-batches: {MAX_STEPS}")
print(f"Gradient Accumulation: {GRAD_ACCUM_STEPS}")
print(f"Total Optimizer Updates: {total_optim_steps}")

# Fused AdamW runs the whole update in a single kernel, when this torch build supports it on CUDA
fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
use_fused = fused_available and device == 'cuda'
extra_args = {'fused': True} if use_fused else {}
print("Using fused AdamW : ", use_fused)

param_groups = [
    {'params': decay_params, 'weight_decay': 0.1},
    {'params': no_decay_params, 'weight_decay': 0.0},
]
optimizer = torch.optim.AdamW(param_groups, lr=3e-4, betas=(0.9, 0.95), **extra_args)

# LR schedule, stepped once per optimizer update: linear warmup over the first 5% of
# updates starting at 0.1 x base LR (3e-5), then cosine decay down to 1e-5.
warmup_steps = int(total_optim_steps * 0.05)
warmup_scheduler = LinearLR(optimizer, start_factor=0.1, total_iters=warmup_steps)
cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=total_optim_steps - warmup_steps, eta_min=1e-5
)
scheduler = SequentialLR(
    optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[warmup_steps]
)
loss_fn = torch.nn.CrossEntropyLoss()


Decay params: 60, No decay params: 26
Total Micro-batches: 600000
Gradient Accumulation: 40
Total Optimizer Updates: 15000
Using fused AdamW :  True


In [None]:
# --- CSV Logger ---
# The header is only written when the file is new, so successive runs append to one log.
file_exists = os.path.isfile(log_file)
with open(log_file, "a", newline="") as f:
    writer = csv.writer(f)
    if not file_exists:
        writer.writerow(["micro_step", "optim_step", "loss", "lr", "tokens_seen", "tokens_per_sec", "timestamp"])

# --- Training Loop ---
if len(train_loader) == 0:
    # Otherwise the `while` loop below would spin forever without reaching MAX_STEPS.
    raise RuntimeError("train_loader is empty: nothing to train on")

micro_step = 0      # Counts every batch seen
optim_step = 0      # Counts every weight update
tokens_seen = 0
# Sum of unscaled micro-batch losses since the last log. It becomes a GPU tensor after the
# first addition, so no host<->device sync (`loss.item()`) is forced on every micro-step.
running_loss = 0.0
start_time = time.time()
start_training = time.time()

# Initialize gradients once before starting
optimizer.zero_grad(set_to_none=True)

# Train until MAX_STEPS micro-batches have been processed (re-iterating the loader if needed)
while micro_step < MAX_STEPS:
    for x, y in train_loader:

        # The loader pins memory, so host->device copies can be asynchronous
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        B, T = x.shape
        tokens_seen += B * T

        # 1. Forward
        with torch.autocast(device_type="cuda", dtype=dtype):
            logits = model(x)
            loss = loss_fn(logits.view(-1, logits.size(-1)), y.view(-1))

        # 2. Backward on the loss scaled for gradient accumulation
        (loss / GRAD_ACCUM_STEPS).backward()

        # 3. Step (only every GRAD_ACCUM_STEPS micro-steps)
        if (micro_step + 1) % GRAD_ACCUM_STEPS == 0:
            # avoids exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()
            optim_step += 1

        # 4. Bookkeeping (unscaled loss, kept on device)
        running_loss += loss.detach()
        micro_step += 1

        # 5. Logging
        if micro_step % LOG_INTERVAL == 0:
            # Sync on the loss first so the elapsed time covers all queued GPU work
            avg_loss = float(running_loss) / LOG_INTERVAL
            elapsed = time.time() - start_time
            tokens_per_sec = (B * T * LOG_INTERVAL) / elapsed
            current_lr = optimizer.param_groups[0]['lr']
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            print(
                f"step {micro_step:06d} | "
                f"opt_step {optim_step:04d} | "
                f"loss {avg_loss:.3f} | "
                f"lr {current_lr:.2e} | "
                f"{tokens_per_sec:,.0f} tok/s"
            )

            try:
                with open(log_file, "a", newline="") as f:
                    writer = csv.writer(f)
                    writer.writerow([micro_step, optim_step, f"{avg_loss:.4f}", f"{current_lr:.2e}", tokens_seen, int(tokens_per_sec), timestamp])
            except Exception as e:
                # best-effort logging: never kill a long training run over a CSV write
                print(f"CSV Error: {e}")

            running_loss = 0.0
            start_time = time.time()

        if micro_step % 50000 == 0:
            mid_model_path = model_path.replace(".pt", f"_{micro_step}.pt")
            print(f"Saving intermediate model in {mid_model_path}")
            torch.save(model.state_dict(), mid_model_path)
        
        if micro_step >= MAX_STEPS:
            elapsed = int(time.time() - start_training)
            h = elapsed // 3600
            m = (elapsed % 3600) // 60
            s = elapsed % 60
            print(f"\nProcessed {tokens_seen:,} tokens in {h:02d}:{m:02d}:{s:02d}")
            print(f"Saving final model in {model_path}")
            torch.save(model.state_dict(), model_path)
            break

step 000500 | opt_step 0012 | loss 9.782 | lr 3.43e-05 | 78,354 tok/s


In [26]:
# NOTE(review): the trailing space is likely encoded as a standalone token by the GPT-2 BPE,
# which may hurt generation quality -- consider stripping it.
prompt = "Test prompt I need to sample somewhere "
prompt_ids = tokenizer.encode(prompt).ids
x = torch.tensor(prompt_ids)

In [None]:
model.to("cuda")

# generate() expects a (batch, seq) tensor, so add a batch dimension of 1
prompt_batch = x.unsqueeze(0).to("cuda")
generation_kwargs = dict(
    max_new_tokens=250,
    temperature=0.9,
    use_cache=True,
    use_top_k=True,
)
out = model.generate(prompt_batch, **generation_kwargs)

generated_ids = out[0].tolist()
print("\nOutput : ", tokenizer.decode(generated_ids))

## TODO

- [x] RoPE for Q and K
- [ ] RoPE with the K/V cache (position offset): should be correct but needs double-checking
- [x] Top-k sampling / temperature
- [x] K/V cache
- [x] Add stop token / EOS handling
- [x] Train on a real problem to see how far we can push the current model
- [x] Clean up / revisit markdown / maths
- [ ] Explore hyper-connections and manifold-constrained hyper-connections (mHC)
- [x] Check newer architectures / design choices (https://github.com/lucidrains is a gold mine)
- [ ] Muon optimizer?