



# Character-Level Transformer Training & Generation 🤖✍️

This notebook implements a character-level Transformer model, trains it on the "Tiny Shakespeare" dataset, and then uses it to generate new text.

### Key Features:
* **Model**: A simplified implementation of a modern Transformer architecture, including features like RMS Normalization and Rotary Position Embeddings (RoPE).
* **Training**: The model is trained from scratch to predict the next character in a sequence.
* **Generation**: After training, the model can generate new text starting from a custom prompt, using a KV cache for efficient, token-by-token generation.

## 1. Setup: Imports and Mock Functions

First, we import all the necessary libraries.

Since the original model code was designed for custom CUDA kernels for FP8 quantization, we'll create **mock functions**. These functions use standard PyTorch operations to ensure the code runs correctly in a standard environment like Colab, although without the performance benefits of the custom kernels.


In [1]:
import math
import os
from dataclasses import dataclass
from typing import Tuple, Optional, Literal
from tqdm import tqdm

import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.amp import autocast

# This is a placeholder for the custom CUDA kernels.
# Since we're running in a standard Python environment without compiling custom kernels,
# we will mock these functions to use standard PyTorch operations.
# NOTE: The performance will not match the original intent, but it will be functionally correct.
def mock_act_quant(x, block_size):
    return x, None # Pass-through for non-quantized training

def mock_weight_dequant(weight, scale):
    return weight # Pass-through

def mock_fp8_gemm(x, scale_x, weight, scale_w):
    # This simulates the GEMM operation using standard torch.matmul
    return F.linear(x, weight)

# Mock the kernel functions
act_quant = mock_act_quant
weight_dequant = mock_weight_dequant
fp8_gemm = mock_fp8_gemm

print("Imports and mock functions are ready.")

Imports and mock functions are ready.



## 2. Configuration

Here, we define the global hyperparameters for training and model execution. You can modify these values to experiment with different settings.

In [2]:
# Training Hyperparameters
BATCH_SIZE = 64        # How many independent sequences will we process in parallel?
BLOCK_SIZE = 256       # What is the maximum context length for predictions?
MAX_ITERS = 500        # Total training iterations
EVAL_INTERVAL = 100    # How often to evaluate and print loss
LEARNING_RATE = 3e-4   # Optimizer learning rate
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
EVAL_ITERS = 50

# Globals for model definition (will be set in the Transformer __init__)
world_size = 1
rank = 0
# For training, we'll stick to bf16. fp8 training is more complex.
gemm_impl: Literal["bf16", "fp8"] = "bf16"
# 'absorb' is the more memory-efficient attention implementation
attn_impl: Literal["naive", "absorb"] = "absorb"
# block_size for quantization (not critical for bf16 training)
block_size = 128

print(f"Using device: {DEVICE}")

Using device: cuda




## 3. Model Architecture

This section contains the complete definition of our Transformer model. It's composed of several building blocks:

* **`ModelArgs`**: A dataclass to hold all model hyperparameters.
* **`RMSNorm`**: A modern normalization layer.
* **`ParallelEmbedding`, `Linear`, etc.**: Custom linear and embedding layers.
* **`MLA` (Multi-Query LoRA Attention)**: The attention mechanism.
* **`MLP`**: The feed-forward network (part of the Transformer block).
* **`Block`**: A single Transformer block combining attention and a feed-forward network.
* **`Transformer`**: The main class that stacks the blocks together.

A key optimization for generation is in the `Transformer.forward` method. During inference (`targets=None`), it calculates logits for **only the last token** (`h[:, [-1], :]`), which is much more efficient since we only need to predict the very next word.




In [3]:
@dataclass
class ModelArgs:
    """Data class for defining model arguments and hyperparameters."""
    max_batch_size: int = 8
    max_seq_len: int = 4096 * 4
    dtype: Literal["bf16", "fp8"] = "bf16"
    vocab_size: int = 102400
    dim: int = 2048
    inter_dim: int = 10944
    moe_inter_dim: int = 1408
    n_layers: int = 27
    n_dense_layers: int = 1
    n_heads: int = 16
    n_routed_experts: int = 64
    n_shared_experts: int = 2
    n_activated_experts: int = 6
    n_expert_groups: int = 1
    n_limited_groups: int = 1
    score_func: Literal["softmax", "sigmoid"] = "softmax"
    route_scale: float = 1.
    q_lora_rank: int = 0
    kv_lora_rank: int = 512
    qk_nope_head_dim: int = 128
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128
    original_seq_len: int = 4096
    rope_theta: float = 10000.0
    rope_factor: float = 40
    beta_fast: int = 32
    beta_slow: int = 1
    mscale: float = 1.


class ParallelEmbedding(nn.Module):
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
        self.part_vocab_size = (vocab_size // world_size)
        self.vocab_start_idx = rank * self.part_vocab_size
        self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
        self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if world_size > 1:
            mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
            x = x - self.vocab_start_idx
            x[mask] = 0
        y = F.embedding(x, self.weight)
        if world_size > 1:
            y[mask] = 0.
            dist.all_reduce(y)
        return y

def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    if weight.element_size() > 1 or gemm_impl == "bf16":
        w = weight_dequant(weight, getattr(weight, 'scale', None)) if hasattr(weight, 'scale') else weight
        if w.dtype != x.dtype:
            w = w.to(x.dtype)
        if bias is not None and bias.dtype != x.dtype:
            bias = bias.to(x.dtype)
        return F.linear(x, w, bias)
    else:
        x, scale = act_quant(x, block_size)
        y = fp8_gemm(x, scale, weight, weight.scale)
        if bias is not None:
            y += bias
        return y


class Linear(nn.Module):
    dtype = torch.bfloat16
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
        if self.weight.element_size() == 1:
            scale_out_features = (out_features + block_size - 1) // block_size
            scale_in_features = (in_features + block_size - 1) // block_size
            self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
        else:
            self.register_parameter("scale", None)
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, dtype=dtype or Linear.dtype))
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return linear(x, self.weight, self.bias)

class ColumnParallelLinear(Linear):
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype=None):
        assert out_features % world_size == 0
        self.part_out_features = out_features // world_size
        super().__init__(in_features, self.part_out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return linear(x, self.weight, self.bias)

class RowParallelLinear(Linear):
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype=None):
        assert in_features % world_size == 0
        self.part_in_features = in_features // world_size
        super().__init__(self.part_in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = linear(x, self.weight) # Bias is applied after all_reduce
        if world_size > 1:
            dist.all_reduce(y)
        if self.bias is not None:
            y += self.bias
        return y

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        output = self._norm(x.float())
        output = output * self.weight
        return output.to(input_dtype)

def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
    dim = args.qk_rope_head_dim
    seqlen = args.max_seq_len
    base = args.rope_theta
    # ... (rest of precompute_freqs_cis is kept the same)
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=DEVICE) / dim))
    t = torch.arange(seqlen, device=DEVICE)
    freqs = torch.outer(t, freqs).float()
    return torch.polar(torch.ones_like(freqs), freqs)

def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Handle empty tensors gracefully
    if x.numel() == 0:
        return x
    dtype = x.dtype
    x = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.unsqueeze(0).unsqueeze(2) # (1, seq, 1, dim/2)
    y = torch.view_as_real(x * freqs_cis).flatten(3)
    return y.to(dtype)

class MLA(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.n_local_heads = args.n_heads // world_size
        self.q_lora_rank = args.q_lora_rank
        self.kv_lora_rank = args.kv_lora_rank
        self.qk_nope_head_dim = args.qk_nope_head_dim
        self.qk_rope_head_dim = args.qk_rope_head_dim
        self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
        self.v_head_dim = args.v_head_dim

        if self.q_lora_rank == 0:
            self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim, bias=False)
        else:
            self.wq_a = Linear(self.dim, self.q_lora_rank, bias=False)
            self.q_norm = RMSNorm(self.q_lora_rank)
            self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim, bias=False)

        self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim, bias=False)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim), bias=False)
        self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim, bias=False)
        self.softmax_scale = self.qk_head_dim ** -0.5

        if attn_impl == "naive":
            self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim, dtype=torch.bfloat16), persistent=False)
            self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim, dtype=torch.bfloat16), persistent=False)
        else:
            self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank, dtype=torch.bfloat16), persistent=False)
            self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim, dtype=torch.bfloat16), persistent=False)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.shape
        end_pos = start_pos + seqlen

        if self.q_lora_rank == 0:
            q = self.wq(x)
        else:
            q = self.wq_b(self.q_norm(self.wq_a(x)))
        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        q_pe = apply_rotary_emb(q_pe, freqs_cis)
        kv = self.wkv_a(x)
        kv, k_pe_new = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe_new = apply_rotary_emb(k_pe_new.unsqueeze(2), freqs_cis).squeeze(2)

        is_causal = start_pos == 0

        if attn_impl == "naive":
            q = torch.cat([q_nope, q_pe], dim=-1)
            kv_b = self.wkv_b(self.kv_norm(kv))
            kv_b = kv_b.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(kv_b, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k_pe = k_pe_new.unsqueeze(2).expand(-1, -1, self.n_local_heads, -1)
            k = torch.cat([k_nope, k_pe], dim=-1)
            if not is_causal:
                self.k_cache[range(bsz), start_pos:end_pos] = k
                self.v_cache[range(bsz), start_pos:end_pos] = v
                k = self.k_cache[range(bsz), :end_pos]
                v = self.v_cache[range(bsz), :end_pos]
            scores = torch.einsum("bshd,bthd->bsht", q, k) * self.softmax_scale
        else: # absorb implementation
            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale)
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
            w = wkv_b[:, :self.qk_nope_head_dim].to(q_nope.dtype)
            q_nope_c = torch.einsum("bshd,hdc->bshc", q_nope, w)
            kv_normed = self.kv_norm(kv)
            if not is_causal:
                self.kv_cache[range(bsz), start_pos:end_pos] = kv_normed
                self.pe_cache[range(bsz), start_pos:end_pos] = k_pe_new
                kv_cache_data = self.kv_cache[range(bsz), :end_pos]
                pe_cache_data = self.pe_cache[range(bsz), :end_pos]
            else:
                kv_cache_data = kv_normed
                pe_cache_data = k_pe_new
            scores = (torch.einsum("bshc,btc->bsht", q_nope_c, kv_cache_data) +
                      torch.einsum("bshr,btr->bsht", q_pe, pe_cache_data)) * self.softmax_scale

        if mask is not None:
            scores += mask.unsqueeze(0).unsqueeze(2)
        scores = F.softmax(scores.float(), dim=-1).type_as(x)

        if attn_impl == "naive":
            output = torch.einsum("bsht,bthd->bshd", scores, v)
        else:
            output = torch.einsum("bsht,btc->bshc", scores, kv_cache_data)
            w = wkv_b[:, -self.v_head_dim:].to(output.dtype)
            output = torch.einsum("bshc,hdc->bshd", output, w)

        return self.wo(output.flatten(2))

class MLP(nn.Module):
    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        self.w1 = ColumnParallelLinear(dim, inter_dim, bias=False)
        self.w2 = RowParallelLinear(inter_dim, dim, bias=False)
        self.w3 = ColumnParallelLinear(dim, inter_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class Block(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.attn = MLA(args)
        self.ffn = MLP(args.dim, args.inter_dim)
        self.attn_norm = RMSNorm(args.dim)
        self.ffn_norm = RMSNorm(args.dim)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
        h = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
        out = h + self.ffn(self.ffn_norm(h))
        return out

class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        global world_size, rank
        world_size = 1
        rank = 0
        self.max_seq_len = args.max_seq_len
        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
        self.layers = torch.nn.ModuleList([Block(i, args) for i in range(args.n_layers)])
        self.norm = RMSNorm(args.dim)
        self.head = ColumnParallelLinear(args.dim, args.vocab_size, bias=False)
        self.register_buffer("freqs_cis", precompute_freqs_cis(args).to(torch.bfloat16), persistent=False)
        self.args = args
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (Linear, ColumnParallelLinear, RowParallelLinear, ParallelEmbedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if hasattr(module, 'bias') and module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def forward(self, tokens: torch.Tensor, start_pos: int = 0, targets: Optional[torch.Tensor] = None):
        seqlen = tokens.size(1)
        h = self.embed(tokens)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
        mask = None
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
            mask = torch.triu(mask, diagonal=1)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)

        if targets is not None:
            # Training: compute logits for all tokens and calculate loss
            logits = self.head(h)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        else:
            # Inference: efficiently compute logits for the last token only
            logits = self.head(h[:, [-1], :])
            loss = None

        return logits, loss

print("Model architecture defined.")

Model architecture defined.



## 4. Data Preparation

We'll download the "Tiny Shakespeare" dataset, which is a small text file containing several of Shakespeare's works.

We then perform character-level tokenization:
1.  Find all unique characters in the text to create our vocabulary.
2.  Create mappings from characters to integers (`stoi`) and integers back to characters (`itos`).
3.  Define a helper function `get_batch` to randomly sample chunks of text for training and validation.


In [4]:
# Download the Tiny Shakespeare dataset if it doesn't exist
if not os.path.exists('input.txt'):
    print("Downloading Tiny Shakespeare dataset...")
    !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Character-level tokenization
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

# Split data into training and validation sets
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

def get_batch(split):
    data_source = train_data if split == 'train' else val_data
    ix = torch.randint(len(data_source) - BLOCK_SIZE, (BATCH_SIZE,))
    x = torch.stack([data_source[i:i+BLOCK_SIZE] for i in ix])
    y = torch.stack([data_source[i+1:i+BLOCK_SIZE+1] for i in ix])
    x, y = x.to(DEVICE), y.to(DEVICE)
    return x, y

print(f"Dataset loaded. Vocabulary size: {vocab_size}")

Downloading Tiny Shakespeare dataset...
--2025-08-17 10:56:39--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-08-17 10:56:39 (22.4 MB/s) - ‘input.txt’ saved [1115394/1115394]

Dataset loaded. Vocabulary size: 65



## 5. Model Configuration for Training

The default `ModelArgs` are for a large model. For this demo, we'll create a much smaller `TrainingModelArgs` configuration that can be trained quickly on a free Colab GPU.



In [5]:
@dataclass
class TrainingModelArgs(ModelArgs):
    # Drastically reduce parameters to fit on a free Colab GPU
    max_batch_size: int = BATCH_SIZE
    max_seq_len: int = BLOCK_SIZE
    vocab_size: int = vocab_size
    dim: int = 384
    inter_dim: int = 768  # 2x dim
    moe_inter_dim: int = 256
    n_layers: int = 6
    n_heads: int = 6
    # Keep other params simple for this training run
    n_routed_experts: int = 8
    n_dense_layers: int = 6 # Use only MLP layers by setting this high
    qk_nope_head_dim: int = 32
    qk_rope_head_dim: int = 32
    v_head_dim: int = 32
    kv_lora_rank: int = 128
    original_seq_len: int = BLOCK_SIZE

print("Training-specific model arguments are set.")

Training-specific model arguments are set.




## 6. Training Setup

This section defines the helper functions for our training loop:
* **`estimate_loss`**: Evaluates the model's performance on the training and validation sets without updating weights.
* **`save_checkpoint` / `load_checkpoint`**: Functions to save and resume training progress. This is useful in case the Colab instance disconnects.
* **`train`**: The main training function that orchestrates the entire process.


-----

### **Cell 13: Code**

In [6]:
@torch.no_grad()
def estimate_loss(model):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(EVAL_ITERS)
        for k in range(EVAL_ITERS):
            X, Y = get_batch(split)
            _, loss = model(X, targets=Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

CHECKPOINT_PATH = "shakespeare_checkpoint.pth"

def save_checkpoint(model, optimizer, iter_num, args):
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'iter_num': iter_num,
        'args': args,
    }, CHECKPOINT_PATH)

def load_checkpoint():
    if not torch.cuda.is_available():
        map_location = torch.device("cpu")
    else:
        map_location = DEVICE

    if os.path.exists(CHECKPOINT_PATH):
        print("Resuming from checkpoint...")
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=map_location)
        return checkpoint
    return None

def train():
    torch.manual_seed(1337)

    args = TrainingModelArgs()
    model = Transformer(args)
    model.to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    # Try to load checkpoint
    start_iter = 0
    checkpoint = load_checkpoint()
    if checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_iter = checkpoint['iter_num'] + 1
        print(f"Checkpoint loaded. Resuming from iteration {start_iter}.")

    print(f"Model on {DEVICE}. Total params: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")

    pbar = tqdm(range(start_iter, MAX_ITERS), desc="Training")
    for iter_num in pbar:
        if iter_num % EVAL_INTERVAL == 0 or iter_num == MAX_ITERS - 1:
            losses = estimate_loss(model)
            pbar.set_postfix({
                "train_loss": f"{losses['train']:.4f}",
                "val_loss": f"{losses['val']:.4f}"
            })
            save_checkpoint(model, optimizer, iter_num, args)

        xb, yb = get_batch('train')
        with autocast(device_type=DEVICE, dtype=torch.bfloat16, enabled=(DEVICE == 'cuda')):
            logits, loss = model(xb, targets=yb)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    print("\nTraining complete.")
    final_losses = estimate_loss(model)
    print(f"Final train loss: {final_losses['train']:.4f}, val loss: {final_losses['val']:.4f}")

    print("Saving final model...")
    torch.save(model.state_dict(), 'shakespeare_model.pth')
    return args

print("Training functions are ready.")

Training functions are ready.




### **Cell 14: Markdown**


## 7. Text Generation Function

The `generate` function takes the trained model and a starting prompt to produce new text. It uses several techniques for high-quality sampling:
* **Temperature Scaling**: Controls the randomness of predictions.
* **Top-k Sampling**: Limits the sampling pool to the `k` most likely next tokens.
* **Top-p (Nucleus) Sampling**: Selects from the smallest set of tokens whose cumulative probability exceeds `p`.


-----

### **Cell 15: Code**

In [7]:
@torch.no_grad()
def generate(model_args, prompt: str = "\n", max_new_tokens: int = 500, temperature: float = 0.8, top_k: int = 200, top_p: Optional[float] = None):
    """Generates text from a trained Transformer model using KV-cache."""

    print("\nLoading model for generation...")
    torch.set_default_dtype(torch.bfloat16)

    # Cap max_new_tokens to avoid running past the model's configured sequence length
    if max_new_tokens + len(prompt) > model_args.max_seq_len:
        new_len = model_args.max_seq_len - len(prompt)
        print(f"Warning: Prompt length + max_new_tokens exceeds model's max_seq_len ({model_args.max_seq_len}).")
        print(f"Capping generation at {new_len} tokens.")
        max_new_tokens = new_len

    model = Transformer(model_args)
    model.load_state_dict(torch.load('shakespeare_model.pth', map_location=DEVICE))
    model.to(DEVICE)
    model.eval()

    print("Generating text...")
    start_ids = encode(prompt)
    idx = torch.tensor(start_ids, dtype=torch.long, device=DEVICE)[None, ...]

    # Print the starting prompt first
    print(prompt, end='', flush=True)

    for _ in range(max_new_tokens):
        # crop idx to the last block_size tokens
        idx_cond = idx[:, -BLOCK_SIZE:]
        # get the predictions
        logits, _ = model(idx_cond)
        # focus only on the last time step
        logits = logits[:, -1, :] / temperature
        # apply top-k filtering
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')
        # apply top-p (nucleus) filtering
        if top_p is not None and 0.0 < top_p < 1.0:
            probs_sort, indices_sort = torch.sort(logits, descending=True)
            probs_sum = torch.cumsum(F.softmax(probs_sort, dim=-1), dim=-1)
            mask = probs_sum > top_p
            mask[:, 1:] = mask[:, :-1].clone()
            mask[:, 0] = 0
            indices_to_remove = mask.scatter(1, indices_sort, mask)
            logits[indices_to_remove] = -float('Inf')

        # sample from the distribution
        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        # append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)

        print(decode(idx_next[0].tolist()), end='', flush=True)

    print("\n--- Generation Complete ---")

print("Generation function is ready.")

Generation function is ready.



## 8. Run Training and Generation! 🚀

This is the final step. We'll call the `train()` function to start training the model. Once training is complete, the `generate()` function is called with your custom prompt.



In [8]:
if __name__ == "__main__":
    trained_model_args = train()
    generate(trained_model_args, prompt="hello thee", max_new_tokens=500)

  self.register_buffer("freqs_cis", precompute_freqs_cis(args).to(torch.bfloat16), persistent=False)


Model on cuda. Total params: 7.35M


Training: 100%|██████████| 500/500 [05:51<00:00,  1.42it/s, train_loss=1.3823, val_loss=1.5896]



Training complete.
Final train loss: 1.3792, val loss: 1.5903
Saving final model...

Loading model for generation...
Capping generation at 246 tokens.
Generating text...
hello thee wind affection.

STANLEY:
He made sometimes encounter my counterfellow
First Squestion? why lady, end you my father,
And I make the mooft the both all dill,
Ses it ender commfort: there by the commend.

First Soldier:
The earthere my could speak
--- Generation Complete ---
