Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved.
Portions of this notebook consist of AI-generated content.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

# Lab 10: Build a Tiny LLaMA from Scratch

Welcome! Our previous attention tutorial focused on the core attention math. Here we put those pieces together into a real model that can tokenize text, train on toy data, and generate text. In this hands‑on lab we’ll implement a minimal, fully working LLaMA‑style decoder‑only transformer from scratch, without depending on the `transformers` library for the model itself.
We emphasize educational clarity and correctness (masks, RoPE, RMSNorm, SwiGLU, weight tying, etc.), with a tiny config so everything runs on AMD GPUs.




## Learning Objectives
By the end, you will be able to:
1. Implement the **LLaMA block** (RMSNorm → RoPE → MHA → SwiGLU MLP → residuals).
2. Use **PyTorch SDPA** (`scaled_dot_product_attention`) for efficient, numerically-stable attention with **causal masking** and **padding masks**.
3. Apply **Rotary Position Embeddings (RoPE)** correctly to Q/K.
4. Implement **RMSNorm** and **SwiGLU** as in modern LLaMA variants.
5. Do **weight tying** between token embedding and LM head.
6. Train a toy model end-to-end and **generate** text with greedy decoding.


---
## 1. Environment Setup 


In [2]:
# Core libraries for custom architectures
import math
import os
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

warnings.filterwarnings("ignore")

# Configure device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Set random seeds for reproducibility
torch.manual_seed(42)
if device.type == "cuda":
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)

print("Environment ready for Tiny LLaMA implementation")

Using device: cuda
PyTorch version: 2.9.1+rocm7.10.0
GPU: Radeon 8060S Graphics
GPU Memory: 103.1 GB
Environment ready for Tiny LLaMA implementation


---
## 2. Tiny Byte‑Level Tokenizer (No External Files)
To keep this truly self-contained, we’ll use a **byte tokenizer**: each UTF‑8 byte (0–255) is a token. This is simple and works for any text.

**Trade‑off:** This is not BPE, but it avoids any external dependencies. It’s perfect for a small demo.


In [3]:
class ByteTokenizer:
    def __init__(self, vocab_size=256):
        assert vocab_size == 256, "This simple tokenizer maps bytes 0..255."
        self.vocab_size = 256

    def encode(self, text: str):
        return list(text.encode("utf-8"))

    def decode(self, ids):
        return bytes([int(i) % 256 for i in ids]).decode("utf-8", errors="replace")


tokenizer = ByteTokenizer()
vocab_size = tokenizer.vocab_size
print("Tokenizer ready. Vocab size:", vocab_size)

Tokenizer ready. Vocab size: 256
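
To see what byte-level tokenization means for text outside ASCII, here is a minimal standalone sketch of the same encode/decode logic. The function names `encode_bytes` and `decode_bytes` are illustrative and not part of the lab code. A single accented character turns into two tokens, and the round trip is lossless.

```python
def encode_bytes(text: str) -> list[int]:
    # One token per UTF-8 byte, so every id falls in 0..255.
    return list(text.encode("utf-8"))


def decode_bytes(ids: list[int]) -> str:
    # Invalid byte sequences are replaced instead of raising an error.
    return bytes(ids).decode("utf-8", errors="replace")


sample = "café"
ids = encode_bytes(sample)
print(ids)                # 5 tokens for 4 characters: "é" takes two bytes
print(decode_bytes(ids))  # round-trips to the original string
```

This is why byte-level sequences are longer than BPE sequences for the same text: the vocabulary is tiny, but each token carries less information.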


---
## 3. LLaMA Building Blocks (RMSNorm, RoPE, SwiGLU)

We define the following components for later use:

- **RMSNorm** (no mean-centering) is simple and stable. 
- **RoPE** rotates Q/K to inject positions. 
- **SwiGLU** gates the MLP.
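
As a small worked example of the RMSNorm bullet, the sketch below normalizes a plain Python list. The helper name `rms_norm` is hypothetical and only mirrors the formula used in this lab (divide by the root-mean-square, then scale by a learned weight). It shows that the output has a root-mean-square close to 1 while the mean is left untouched.

```python
import math


def rms_norm(x, weight=None, eps=1e-5):
    # Scale by the reciprocal root-mean-square; no mean is subtracted.
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    weight = weight if weight is not None else [1.0] * len(x)
    return [w * v * inv_rms for w, v in zip(weight, x)]


x = [3.0, -4.0, 12.0, 0.0]
y = rms_norm(x)
rms_out = math.sqrt(sum(v * v for v in y) / len(y))
print(rms_out)          # close to 1.0
print(sum(y) / len(y))  # non-zero here: the mean is not removed
```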


In [4]:
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        norm = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return self.weight * x * norm


def rotate_half(x):  # x: [..., Dh], Dh must be even
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


def rope_apply(q, k, cos, sin):
    """
    q, k: [B, H, T, Dh]
    cos, sin: [1, 1, T, Dh] (already reshaped for broadcasting)
    """
    q_out = (q * cos) + (rotate_half(q) * sin)
    k_out = (k * cos) + (rotate_half(k) * sin)
    return q_out, k_out


class RotaryEmbedding(nn.Module):
    def __init__(self, dim: int, base: float = 10000.0, max_seq_len: int = 2048):
        super().__init__()
        assert dim % 2 == 0, "RoPE dim must be even."
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # [Dh/2]
        t = torch.arange(max_seq_len, dtype=torch.float32)  # [T]
        freqs = torch.einsum("i,j->ij", t, inv_freq)  # [T, Dh/2]
        emb = torch.cat([freqs, freqs], dim=-1)  # [T, Dh]
        self.register_buffer("cos_cached", emb.cos(), persistent=False)  # [T, Dh]
        self.register_buffer("sin_cached", emb.sin(), persistent=False)  # [T, Dh]

    def get_cos_sin(self, seq_len: int):
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]


class SwiGLU(nn.Module):
    def __init__(self, dim_in: int, dim_hidden: int):
        super().__init__()
        self.w1 = nn.Linear(dim_in, dim_hidden, bias=False)
        self.w2 = nn.Linear(dim_in, dim_hidden, bias=False)
        self.w3 = nn.Linear(dim_hidden, dim_in, bias=False)

    def forward(self, x):
        return self.w3(torch.nn.functional.silu(self.w1(x)) * self.w2(x))


print("Components defined.")

Components defined.
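
The reason RoPE is applied to Q and K (and not V) is that the resulting attention score depends only on the relative offset between the two positions. The following sketch checks this numerically for a single head vector. `rope_at` is a hypothetical helper written for this illustration; it uses the same `rotate_half` convention and frequency formula as the cell above.

```python
import torch


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


def rope_at(x, pos, base=10000.0):
    # Rotate one head vector x of shape [Dh] as if it sat at position `pos`.
    dh = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, dh, 2).double() / dh))
    angles = pos * inv_freq
    emb = torch.cat([angles, angles], dim=-1)
    return x * emb.cos() + rotate_half(x) * emb.sin()


torch.manual_seed(0)
q = torch.randn(8, dtype=torch.float64)
k = torch.randn(8, dtype=torch.float64)

score_a = torch.dot(rope_at(q, 5), rope_at(k, 2))      # offset 3
score_b = torch.dot(rope_at(q, 105), rope_at(k, 102))  # offset 3, shifted by 100
score_c = torch.dot(rope_at(q, 5), rope_at(k, 4))      # offset 1

print(torch.allclose(score_a, score_b))  # same offset, same score
print(torch.allclose(score_a, score_c))  # different offset, different score
```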


---
## 4. Multi-Head Attention (SDPA, Causal + Padding Masks)

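We define our multi-head attention block on top of PyTorch's built-in SDPA (`scaled_dot_product_attention`) for speed and numerical stability. `is_causal=True` applies causal masking. The optional `key_padding_mask` uses `True` = masked (padded) key. Note that SDPA's own boolean `attn_mask` uses the opposite convention (`True` = may attend) and that SDPA does not accept an explicit `attn_mask` together with `is_causal=True`, so a padding mask has to be inverted and merged with the causal mask before it is passed in.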


In [5]:
class MultiHeadAttention(nn.Module):
    def __init__(self, dim: int, n_heads: int, rope: RotaryEmbedding, dropout: float = 0.0):
        super().__init__()
        assert dim % n_heads == 0
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.out = nn.Linear(dim, dim, bias=False)
        self.rope = rope
        self.dropout_p = dropout

    def forward(self, x, key_padding_mask=None):
        B, T, D = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.split(D, dim=-1)

        def to_heads(t):
            return t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        qh, kh, vh = map(to_heads, (q, k, v))

        cos, sin = self.rope.get_cos_sin(T)
        cos = cos.to(qh.dtype).to(qh.device)[None, None, :, :]
        sin = sin.to(qh.dtype).to(qh.device)[None, None, :, :]
        qh, kh = rope_apply(qh, kh, cos, sin)

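        # SDPA boolean masks use True = "may attend", and an explicit attn_mask
        # cannot be combined with is_causal=True. When a padding mask is given
        # (True = padded key), invert it and fold it into an explicit causal mask.
        attn_mask = None
        if key_padding_mask is not None:
            causal = torch.ones(T, T, dtype=torch.bool, device=x.device).tril()
            attn_mask = causal[None, None, :, :] & ~key_padding_mask[:, None, None, :].bool()

        out = torch.nn.functional.scaled_dot_product_attention(
            qh,
            kh,
            vh,
            attn_mask=attn_mask,
            dropout_p=self.dropout_p if self.training else 0.0,
            is_causal=attn_mask is None,
        )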
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        return self.out(out)


print("attention head is defined.")

attention head is defined.
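
The masking conventions are easy to get backwards, so here is a small standalone check, written as a sketch under the assumption that SDPA boolean masks mean `True` = may attend. It verifies two things: `is_causal=True` matches an explicit lower-triangular mask, and a padded key has no influence on the output once the padding mask is inverted and merged with the causal mask. All tensor names are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, Dh = 1, 2, 4, 8
q, k, v = (torch.randn(B, H, T, Dh) for _ in range(3))

# 1) is_causal=True matches an explicit lower-triangular boolean mask
#    (True = this query may attend to this key).
causal = torch.ones(T, T, dtype=torch.bool).tril()
out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)
same_causal = torch.allclose(out_flag, out_mask, atol=1e-6)
print(same_causal)

# 2) Padding: mark the last key as padded (True = masked key), invert it to
#    SDPA's convention, and combine it with the causal mask.
key_padding_mask = torch.tensor([[False, False, False, True]])
allowed = causal[None, None] & ~key_padding_mask[:, None, None, :]
out_pad = F.scaled_dot_product_attention(q, k, v, attn_mask=allowed)

# Perturbing the padded key/value must not change any output row.
k2, v2 = k.clone(), v.clone()
k2[:, :, -1] += 10.0
v2[:, :, -1] += 10.0
out_pad2 = F.scaled_dot_product_attention(q, k2, v2, attn_mask=allowed)
pad_ignored = torch.allclose(out_pad, out_pad2, atol=1e-6)
print(pad_ignored)
```

With right-padded batches every query row still has at least one allowed key. A row with no allowed keys at all (possible with left padding) would produce NaNs and needs separate handling.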


---
## 5. LLaMA Block and Model

We now integrate the components together with the following structure:

LLaMA block = 
RMSNorm → MHA → residual → RMSNorm → SwiGLU MLP → residual. 

We also tie embedding ↔ LM head weights.


In [6]:
class LLaMABlock(nn.Module):
    def __init__(self, dim: int, n_heads: int, mlp_ratio: float, rope: RotaryEmbedding, dropout: float = 0.0):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = MultiHeadAttention(dim, n_heads, rope, dropout=dropout)
        self.norm2 = RMSNorm(dim)
        hidden = int(mlp_ratio * dim)
        self.mlp = SwiGLU(dim, hidden)

    def forward(self, x, key_padding_mask=None):
        x = x + self.attn(self.norm1(x), key_padding_mask=key_padding_mask)
        x = x + self.mlp(self.norm2(x))
        return x


class LLaMAFromScratch(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        dim: int = 256,
        n_layers: int = 4,
        n_heads: int = 8,
        mlp_ratio: float = 4.0,
        max_seq_len: int = 256,
        dropout: float = 0.0,
    ):
        super().__init__()
        assert dim % n_heads == 0
        self.vocab_size = vocab_size
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.embed = nn.Embedding(vocab_size, dim)
        self.rope = RotaryEmbedding(dim // n_heads, max_seq_len=max_seq_len)
        self.blocks = nn.ModuleList(
            [LLaMABlock(dim, n_heads, mlp_ratio, self.rope, dropout=dropout) for _ in range(n_layers)]
        )
        self.norm_f = RMSNorm(dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        self.lm_head.weight = self.embed.weight  # weight tying

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            fan_in = m.weight.shape[1]
            std = 0.02 / math.sqrt(max(1, fan_in))
            with torch.no_grad():
                nn.init.trunc_normal_(m.weight, mean=0.0, std=std, a=-2 * std, b=2 * std)
                if getattr(m, "bias", None) is not None:
                    nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            with torch.no_grad():
                nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, key_padding_mask=None):
        B, T = input_ids.shape
        if self.max_seq_len < T:
            raise ValueError(f"Sequence length {T} exceeds max_seq_len={self.max_seq_len}.")
        x = self.embed(input_ids)
        for blk in self.blocks:
            x = blk(x, key_padding_mask=key_padding_mask)
        x = self.norm_f(x)
        logits = self.lm_head(x)
        return logits

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=50):
        self.eval()
        out = input_ids.clone().to(next(self.parameters()).device)
        for _ in range(max_new_tokens):
            tokens = out[:, -self.max_seq_len :]
            logits = self(tokens)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            out = torch.cat([out, next_token], dim=1)
        return out


print("Attention computational block is defined.")

Attention computational block is defined.


---
## 6. Tiny Sanity Tests
Build a small model and check shapes/parameter counts.

In [None]:
cfg = {"vocab_size": 256, "dim": 192, "n_layers": 3, "n_heads": 6, "mlp_ratio": 4.0, "max_seq_len": 128, "dropout": 0.0}
model = LLaMAFromScratch(**cfg).to(device)
n_params = sum(p.numel() for p in model.parameters())
print(model)
print(f"Total parameters: {n_params:,}")

LLaMAFromScratch(
  (embed): Embedding(256, 192)
  (rope): RotaryEmbedding()
  (blocks): ModuleList(
    (0-2): 3 x LLaMABlock(
      (norm1): RMSNorm()
      (attn): MultiHeadAttention(
        (qkv): Linear(in_features=192, out_features=576, bias=False)
        (out): Linear(in_features=192, out_features=192, bias=False)
        (rope): RotaryEmbedding()
      )
      (norm2): RMSNorm()
      (mlp): SwiGLU(
        (w1): Linear(in_features=192, out_features=768, bias=False)
        (w2): Linear(in_features=192, out_features=768, bias=False)
        (w3): Linear(in_features=768, out_features=192, bias=False)
      )
    )
  )
  (norm_f): RMSNorm()
  (lm_head): Linear(in_features=192, out_features=256, bias=False)
)
Total parameters: 1,819,968
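
The reported total can be reproduced by hand from the config, which is a useful check that weight tying is in effect (the LM head adds no parameters of its own). The helper `count_params` below is a hypothetical name for this arithmetic; the RoPE tables are buffers, not parameters, and the head count does not change the total.

```python
def count_params(vocab_size, dim, n_layers, mlp_ratio):
    hidden = int(mlp_ratio * dim)
    embed = vocab_size * dim          # shared with lm_head (weight tying)
    attn = dim * 3 * dim + dim * dim  # fused qkv + output projection, no biases
    mlp = 3 * dim * hidden            # w1, w2, w3
    norms = 2 * dim                   # two RMSNorm weight vectors per block
    per_block = attn + mlp + norms
    return embed + n_layers * per_block + dim  # + final RMSNorm weight


total = count_params(vocab_size=256, dim=192, n_layers=3, mlp_ratio=4.0)
print(f"{total:,}")  # 1,819,968
```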


---
## 7. Toy Dataset & Training Loop
We train on a tiny corpus to verify the loss goes down (next‑token prediction).
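
Next-token prediction means the target sequence is simply the input sequence shifted by one position. The short sketch below illustrates this on the first line of the corpus with a window of 8 bytes; the variable names are chosen for this illustration only.

```python
text = "LLaMA from scratch.\n"
data = list(text.encode("utf-8"))

seq_len = 8
x = data[0:seq_len]      # inputs:  bytes 0..7
y = data[1:seq_len + 1]  # targets: bytes 1..8 (same window shifted by one)

# Each input byte is paired with the byte that follows it.
for inp, tgt in zip(x, y):
    print(f"{chr(inp)!r} -> {chr(tgt)!r}")
```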

In [None]:
text_corpus = [
    "LLaMA from scratch.\n",
    "We build a tiny decoder-only Transformer.\n",
    "This is educational and runs on CPU or GPU.\n",
    "Attention with RoPE and RMSNorm works well.\n",
    "SwiGLU MLP and weight tying included.\n",
]
data = tokenizer.encode("".join(text_corpus))
print("Toy corpus bytes:", len(data))


def make_batches(data, seq_len=64, batch_size=8):
    X, Y = [], []
    for i in range(0, len(data) - seq_len - 1, seq_len):
        x = data[i : i + seq_len]
        y = data[i + 1 : i + seq_len + 1]
        X.append(x)
        Y.append(y)
        if len(X) == batch_size:
            xb = torch.tensor(X, dtype=torch.long, device=device)
            yb = torch.tensor(Y, dtype=torch.long, device=device)
            yield xb, yb
            X, Y = [], []
    if X:
        xb = torch.tensor(X, dtype=torch.long, device=device)
        yb = torch.tensor(Y, dtype=torch.long, device=device)
        yield xb, yb


# Named `optimizer` so it does not shadow the `torch.optim as optim` import above.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=0.01)
# Loss scaling matters for FP16; with BF16 autocast it is harmless but not required.
scaler = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda"))
loss_fn = nn.CrossEntropyLoss()


def train_one_epoch(epoch, seq_len=64, batch_size=8):
    model.train()
    losses = []
    for xb, yb in make_batches(data, seq_len=seq_len, batch_size=batch_size):
        optimizer.zero_grad(set_to_none=True)
        # BF16 autocast on GPU; autocast is disabled (plain FP32) on CPU.
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=(device.type == "cuda")):
            logits = model(xb)
            # Flatten [B, T, V] logits and [B, T] targets for token-level cross-entropy.
            loss = loss_fn(logits.view(-1, logits.size(-1)), yb.view(-1))
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        losses.append(loss.item())
    print(f"Epoch {epoch} | loss={np.mean(losses):.4f}")


for epoch in range(1, 300):  # epochs 1..299
    train_one_epoch(epoch)

Toy corpus bytes: 188
Epoch 1 | loss=5.5427
Epoch 2 | loss=5.0205
Epoch 3 | loss=4.4570
Epoch 4 | loss=3.9181
Epoch 5 | loss=3.5630
Epoch 6 | loss=3.3616
Epoch 7 | loss=3.2656
Epoch 8 | loss=3.2190
Epoch 9 | loss=3.1997
Epoch 10 | loss=3.1917
Epoch 11 | loss=3.1818
Epoch 12 | loss=3.1709
Epoch 13 | loss=3.1669
Epoch 14 | loss=3.1697
Epoch 15 | loss=3.1618
Epoch 16 | loss=3.1471
Epoch 17 | loss=3.1425
Epoch 18 | loss=3.1394
Epoch 19 | loss=3.1327
Epoch 20 | loss=3.1274
Epoch 21 | loss=3.1226
Epoch 22 | loss=3.1186
Epoch 23 | loss=3.1074
Epoch 24 | loss=3.1034
Epoch 25 | loss=3.0971
Epoch 26 | loss=3.0794
Epoch 27 | loss=3.0611
Epoch 28 | loss=3.0173
Epoch 29 | loss=2.9718
Epoch 30 | loss=2.9669
Epoch 31 | loss=3.0640
Epoch 32 | loss=2.9091
Epoch 33 | loss=2.9654
Epoch 34 | loss=2.8309
Epoch 35 | loss=2.9067
Epoch 36 | loss=2.8540
Epoch 37 | loss=2.8073
Epoch 38 | loss=2.7759
Epoch 39 | loss=2.7588
Epoch 40 | loss=2.7415
Epoch 41 | loss=2.7279
Epoch 42 | loss=2.6770
Epoch 43 | loss=2.634
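
A useful reference point for reading the log above: a model that guesses uniformly over 256 byte values has a cross-entropy of ln(256), which is close to the first-epoch loss. The sketch below computes that baseline and converts a few of the logged losses to perplexity. It is plain arithmetic on values copied from the log, not additional measurements.

```python
import math

vocab_size = 256
uniform_loss = math.log(vocab_size)  # cross-entropy of a uniform guess, in nats
print(f"uniform baseline: {uniform_loss:.4f}")

# Perplexity is exp(loss): the effective number of equally likely choices.
for loss in (5.5427, 3.1917, 2.6770):  # epochs 1, 10, and 42 from the log
    print(f"loss={loss:.4f} -> perplexity={math.exp(loss):.1f}")
```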

---
## 8. Greedy Text Generation
Prompt and generate a few tokens.

In [9]:
prompt = "LLaMA"
inp = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=device)
out_ids = model.generate(inp, max_new_tokens=80)[0].tolist()
print(tokenizer.decode(out_ids))

LLaMA from scratch.
We build a tiny decoder-only Transformer.
Thi o andmuddn  rrr P.e
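
Greedy decoding always appends the highest-scoring next token and feeds the extended sequence back in. The following framework-free sketch shows the loop on a toy scoring function. Both `greedy_generate` and `toy_logits` are hypothetical names for this illustration; `toy_logits` stands in for the model's last-position logits.

```python
def greedy_generate(next_logits, prompt, max_new_tokens):
    # next_logits: callable mapping a token list to a list of scores.
    out = list(prompt)
    for _ in range(max_new_tokens):
        logits = next_logits(out)
        next_token = max(range(len(logits)), key=logits.__getitem__)  # argmax
        out.append(next_token)
    return out


def toy_logits(tokens, vocab_size=5):
    # Toy "model": strongly prefers (last token + 1) mod vocab_size.
    preferred = (tokens[-1] + 1) % vocab_size
    return [1.0 if i == preferred else 0.0 for i in range(vocab_size)]


result = greedy_generate(toy_logits, prompt=[3], max_new_tokens=4)
print(result)  # [3, 4, 0, 1, 2]
```

Because greedy decoding is deterministic, a model that has memorized part of a tiny corpus reproduces it verbatim and then drifts once it reaches text it has not fit well, which matches the sample above.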


## Lab Summary

### Technical Concepts Learned
- **RMSNorm Implementation**: Root Mean Square normalization without mean-centering for stable transformer training
- **Rotary Position Embeddings (RoPE)**: Applying rotary embeddings to Q/K tensors with correct broadcasting for relative position encoding
- **SwiGLU MLP**: Implementing gated linear units with SiLU activation for improved model expressivity
- **PyTorch SDPA**: Using `scaled_dot_product_attention` for efficient causal masking and numerically stable attention
- **Weight Tying**: Sharing parameters between token embedding and LM head to reduce model size

### Experiment Further
- Add KV cache for O(T·d) per-step inference instead of recomputing full attention
- Replace byte-level tokenizer with BPE (tiktoken/SentencePiece) for better sample efficiency
- Implement learning rate warmup and cosine decay schedule for training stability
- Add dropout to attention and MLP layers for regularization, and tune the RMSNorm epsilon for numerical stability
- Enable mixed precision (FP16/BF16) training with activation checkpointing for larger models
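
As a starting point for the warmup and cosine decay suggestion above, here is a minimal sketch of such a schedule as a pure function of the step index. The name `lr_at` and all default values (peak and floor learning rates, warmup length, total steps) are assumptions for illustration, not settings used in this lab.

```python
import math


def lr_at(step, max_lr=3e-3, min_lr=3e-4, warmup_steps=20, total_steps=300):
    # Linear warmup from max_lr / warmup_steps up to max_lr.
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    # Cosine decay from max_lr down to min_lr over the remaining steps.
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (max_lr - min_lr) * cosine


for step in (0, 10, 19, 20, 160, 300):
    print(step, f"{lr_at(step):.6f}")
```

One way to use it is to assign `group["lr"] = lr_at(step)` for every entry in the optimizer's `param_groups` before each optimizer step.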