This notebook contains the code accompanying [01: Attention Is All You Need](https://yyzhang2025.github.io/100-AI-Papers/posts/01-attention.html).


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import random
import math

import einops

In [None]:
def get_device():
    """Return the preferred torch device: CUDA first, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)


# Device every tensor and the model are moved to in the training cells below.
# Auto-detection is switched off for now; re-enable the next line to use an accelerator.
# DEVICE = get_device()
DEVICE = torch.device("cpu")  # Force CPU for debugging
print("Using device:", DEVICE)

In [None]:
import os

# Keep the Hugging Face `datasets` cache in a project-local directory.
# NOTE(review): presumably this has to run before `datasets` is imported — confirm.
os.environ["HF_DATASETS_CACHE"] = "./tmp/hf_cache"

In [None]:
from dataclasses import dataclass


@dataclass
class ModelConfig:
    """Hyperparameters read by the Transformer modules defined in this file."""

    src_vocab_size: int = 16000  # rows of the source-side token embedding table
    tgt_vocab_size: int = 16000  # rows of the target-side embedding table and width of the output logits
    max_seq: int = 128  # number of positions precomputed by PositionalEmbedding

    d_model: int = 512  # embedding / hidden width used by every sub-layer
    d_ff: int = 2048  # inner width of the position-wise feed-forward network
    num_heads: int = 8  # attention heads; the attention modules assert that it divides d_model
    num_layers: int = 6  # number of encoder blocks and of decoder blocks
    dropout: float = 0.1  # NOTE(review): not read by any module visible in this file

    eps: float = 1e-6  # for Layer Normalization

## Transformer Model Implementation


In [None]:
class WordEmbedding(nn.Module):
    """Learned lookup table mapping token ids of one vocabulary to d_model vectors."""

    def __init__(self, config: ModelConfig, is_tgt: bool = False):
        super().__init__()

        # The table is sized for the target vocabulary when is_tgt is set,
        # otherwise for the source vocabulary.
        vocab_size = config.tgt_vocab_size if is_tgt else config.src_vocab_size
        self.embedding = nn.Embedding(vocab_size, config.d_model)

    def forward(self, x):
        """
        x: (batch_size, seq_len) integer token ids
        returns: (batch_size, seq_len, d_model)
        """
        token_vectors = self.embedding(x)
        return token_vectors


class PositionalEmbedding(nn.Module):
    """
    Fixed sinusoidal position table.

    Even feature columns hold sin(pos / 10000^(2i / d_model)) and odd columns
    hold the matching cosine. The table is stored as the buffer ``pe`` of shape
    (1, max_seq, d_model), so it follows the module across devices and is saved
    in the state dict, but is never trained.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()

        positions = torch.arange(config.max_seq, dtype=torch.float32).unsqueeze(1)  # (max_seq, 1)

        # One frequency per sin/cos pair: ceil(d_model / 2) entries.
        div_term = torch.exp(
            torch.arange(0, config.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / config.d_model)
        )
        angles = positions * div_term  # (max_seq, ceil(d_model / 2))

        pe = torch.zeros(config.max_seq, config.d_model)  # (max_seq, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine columns,
        # so drop the last frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : config.d_model // 2])

        # Buffers never receive gradients, so no explicit requires_grad flag is needed.
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_seq, d_model)

    def forward(self, x):
        """
        x: (batch_size, seq_len, d_model) — only seq_len is read
        returns: (1, seq_len, d_model), broadcastable over the batch

        Raises:
            ValueError: if seq_len exceeds the number of precomputed positions.
        """
        seq_len = x.size(1)
        max_seq = self.pe.size(1)
        if seq_len > max_seq:
            raise ValueError(
                f"sequence length {seq_len} exceeds the positional table size {max_seq}"
            )
        return self.pe[:, :seq_len, :]  # (1, seq_len, d_model)


class Embedding(nn.Module):
    """Input layer: token embedding plus positional encoding."""

    def __init__(self, config: ModelConfig, is_tgt: bool = False):
        super().__init__()
        self.word_embedding = WordEmbedding(config, is_tgt)
        self.positional_embedding = PositionalEmbedding(config)

    def forward(self, x):
        """
        x: (batch_size, seq_len) token ids
        returns: (batch_size, seq_len, d_model)
        """
        tokens = self.word_embedding(x)
        # The positional module only inspects the sequence length of its argument;
        # its (1, seq_len, d_model) result broadcasts over the batch.
        positions = self.positional_embedding(tokens)
        return tokens + positions

### Layer Normalization

$$
\text{LayerNorm}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta
$$


In [None]:
class LayerNormalization(nn.Module):
    """Layer normalization over the last (feature) dimension with learned scale and shift."""

    def __init__(self, config: ModelConfig):
        super().__init__()

        width = config.d_model
        self.eps = config.eps

        self.gamma = nn.Parameter(torch.ones(width))  # scale, (d_model,)
        self.beta = nn.Parameter(torch.zeros(width))  # shift, (d_model,)

    def _compute_mean_std(self, x):
        """
        Per-position statistics over the last dimension.

        x: (..., d_model)
        returns: (mean, std), each with the last dimension reduced to size 1.
        std is sqrt(biased variance + eps), so it is strictly positive.
        """
        mean = x.mean(dim=-1, keepdim=True)
        variance = x.var(dim=-1, keepdim=True, unbiased=False)
        return mean, (variance + self.eps).sqrt()

    def forward(self, x):
        mean, std = self._compute_mean_std(x)

        # Debug aid: report when the statistics themselves are already NaN.
        nan_in_std = torch.isnan(std).any()
        if nan_in_std:
            print("❌ NaN in LayerNorm std")
            print("std min:", std.min().item(), "max:", std.max().item())

        # eps is already folded into std, so this division is safe for constant rows.
        centered = x - mean
        return (centered / std) * self.gamma + self.beta  # same shape as x

### Feedforward Neural Network


In [None]:
class FFN(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear."""

    def __init__(self, config: ModelConfig):
        super().__init__()

        d_model, d_ff = config.d_model, config.d_ff
        # Despite the "ln" prefix these are linear layers, not layer norms.
        self.ln1 = nn.Linear(d_model, d_ff, bias=True)
        self.ln2 = nn.Linear(d_ff, d_model, bias=True)

    def forward(self, x):
        """
        x: (batch_size, seq_len, d_model)
        returns: (batch_size, seq_len, d_model)
        """
        hidden = self.ln1(x).relu()  # expand to d_ff and apply the non-linearity
        return self.ln2(hidden)  # project back to d_model

In [None]:
def scaled_dot_product_attention(q, k, v, mask=None):
    """
    Scaled dot-product attention: softmax(q @ k^T / sqrt(d_k)) @ v.

    q: (batch_size, num_heads, seq_len_q, d_k)
    k: (batch_size, num_heads, seq_len_k, d_k)
    v: (batch_size, num_heads, seq_len_k, d_v)
    mask: boolean tensor broadcastable to
        (batch_size, num_heads, seq_len_q, seq_len_k), or None.
        True marks positions that must NOT be attended to.

    returns: (batch_size, num_heads, seq_len_q, d_v)

    A query row whose keys are all masked yields an all-zero output row.
    A plain softmax over a row of -inf would produce NaN there, and that NaN
    would spread through the rest of the model and into the loss.
    """
    d_k = k.shape[-1]

    # (batch, heads, seq_len_q, seq_len_k)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

    fully_masked = None
    if mask is not None:
        # Rows with no visible key at all; trailing dim kept for broadcasting.
        fully_masked = mask.all(dim=-1, keepdim=True)
        scores = scores.masked_fill(mask, float("-inf"))
        # Give those rows finite scores so softmax and its gradient stay finite.
        scores = scores.masked_fill(fully_masked, 0.0)

    weights = F.softmax(scores, dim=-1)

    if fully_masked is not None:
        # Nothing may be attended to in these rows, so they contribute nothing.
        weights = weights.masked_fill(fully_masked, 0.0)

    return torch.matmul(weights, v)

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention using one fused projection for queries, keys and values."""

    def __init__(self, config: ModelConfig):
        super().__init__()

        d_model, num_heads = config.d_model, config.num_heads
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # width of a single head

        # One matmul produces q, k and v side by side along the feature axis.
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=True)
        self.out_proj = nn.Linear(d_model, d_model, bias=True)

    def _split_heads(self, t):
        """(batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)."""
        return einops.rearrange(
            t,
            "batch seq_len (heads d_k) -> batch heads seq_len d_k",
            heads=self.num_heads,
        )

    def forward(self, x, mask=None):
        """
        x: (batch_size, seq_len, d_model)
        mask: boolean, broadcastable to (batch_size, num_heads, seq_len, seq_len), or None
        returns: (batch_size, seq_len, d_model)
        """
        _batch, _seq_len, _ = x.size()  # unpacking fails early unless x is 3-D

        fused = self.qkv_proj(x)
        query, key, value = (self._split_heads(part) for part in fused.chunk(3, dim=-1))

        context = scaled_dot_product_attention(query, key, value, mask)

        # Concatenate the heads back along the feature axis.
        merged = einops.rearrange(
            context,
            "batch heads seq_len d_v -> batch seq_len (heads d_v)",
            heads=self.num_heads,
        )
        return self.out_proj(merged)

In [None]:
class CrossAttention(nn.Module):
    """Multi-head attention whose queries and keys/values come from different sequences."""

    def __init__(self, config: ModelConfig):
        super().__init__()

        d_model, num_heads = config.d_model, config.num_heads
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_k = d_model // num_heads  # width of a single head
        self.num_heads = num_heads

        # Separate projections, because query and key/value inputs differ.
        self.q_proj = nn.Linear(d_model, d_model, bias=True)
        self.k_proj = nn.Linear(d_model, d_model, bias=True)
        self.v_proj = nn.Linear(d_model, d_model, bias=True)
        self.out_proj = nn.Linear(d_model, d_model, bias=True)

    def forward(self, query, key, value, mask=None):
        """
        query: (batch_size, seq_len_q, d_model)
        key:   (batch_size, seq_len_k, d_model)
        value: (batch_size, seq_len_k, d_model)
        mask: boolean, broadcastable to (batch_size, num_heads, seq_len_q, seq_len_k), or None
        returns: (batch_size, seq_len_q, d_model)
        """
        heads = self.num_heads

        projected_q = self.q_proj(query)
        projected_k = self.k_proj(key)
        projected_v = self.v_proj(value)

        q = einops.rearrange(
            projected_q,
            "batch seq_len_q (heads d_k) -> batch heads seq_len_q d_k",
            heads=heads,
        )
        k = einops.rearrange(
            projected_k,
            "batch seq_len_k (heads d_k) -> batch heads seq_len_k d_k",
            heads=heads,
        )
        v = einops.rearrange(
            projected_v,
            "batch seq_len_v (heads d_v) -> batch heads seq_len_v d_v",
            heads=heads,
        )

        context = scaled_dot_product_attention(q, k, v, mask)

        # Concatenate the heads back along the feature axis.
        merged = einops.rearrange(
            context,
            "batch heads seq_len_q d_v -> batch seq_len_q (heads d_v)",
            heads=heads,
        )
        return self.out_proj(merged)

### Encoder Block


In [None]:
class EncoderBlock(nn.Module):
    """
    One post-norm encoder layer:

        h   = LayerNorm(x + SelfAttention(x))
        out = LayerNorm(h + FFN(h))
    """

    def __init__(self, config: ModelConfig):
        super().__init__()

        self.self_attn = MultiHeadAttention(config)
        self.ffn = FFN(config)
        self.ln1 = LayerNormalization(config)
        self.ln2 = LayerNormalization(config)

    def forward(self, x, mask=None):
        """
        x: (batch_size, seq_len, d_model)
        mask: boolean attention mask passed to self-attention (True = blocked), or None
        returns: (batch_size, seq_len, d_model)
        """
        attn_out = self.self_attn(x, mask)
        hidden = self.ln1(x + attn_out)  # Add & Norm

        ffn_out = self.ffn(hidden)
        # The residual must be the input of THIS sub-layer (hidden). Adding the
        # block input x would bypass the attention sub-layer entirely.
        return self.ln2(hidden + ffn_out)  # Add & Norm

In [None]:
class DecoderBlock(nn.Module):
    """
    One post-norm decoder layer:

        h1  = LayerNorm(x  + SelfAttention(x))
        h2  = LayerNorm(h1 + CrossAttention(h1, enc_output))
        out = LayerNorm(h2 + FFN(h2))
    """

    def __init__(self, config: ModelConfig):
        super().__init__()

        self.self_attn = MultiHeadAttention(config)
        self.cross_attn = CrossAttention(config)
        self.ffn = FFN(config)
        self.ln1 = LayerNormalization(config)
        self.ln2 = LayerNormalization(config)
        self.ln3 = LayerNormalization(config)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """
        x: (batch_size, tgt_seq_len, d_model)
        enc_output: (batch_size, src_seq_len, d_model)
        src_mask: mask over encoder positions used by cross-attention, or None
        tgt_mask: mask used by decoder self-attention, or None
        returns: (batch_size, tgt_seq_len, d_model)
        """
        # Every residual connection wraps exactly one sub-layer, so each "Add"
        # uses the tensor that was fed into that sub-layer, not the block input.
        self_out = self.self_attn(x, tgt_mask)
        hidden = self.ln1(x + self_out)  # Add & Norm

        cross_out = self.cross_attn(hidden, enc_output, enc_output, src_mask)
        hidden = self.ln2(hidden + cross_out)  # Add & Norm

        ffn_out = self.ffn(hidden)
        return self.ln3(hidden + ffn_out)  # Add & Norm

In [None]:
class Encoder(nn.Module):
    """Source-side stack: embedding, num_layers encoder blocks, then a final layer norm."""

    def __init__(self, config: ModelConfig):
        super().__init__()

        self.embedding = Embedding(config)
        blocks = [EncoderBlock(config) for _ in range(config.num_layers)]
        self.layers = nn.ModuleList(blocks)
        self.ln = LayerNormalization(config)

    def forward(self, x, mask=None):
        """
        x: (batch_size, seq_len) source token ids
        mask: attention mask handed unchanged to every block, or None
        returns: (batch_size, seq_len, d_model)
        """
        hidden = self.embedding(x)
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.ln(hidden)  # final layer normalization


class Decoder(nn.Module):
    """Target-side stack: embedding, num_layers decoder blocks, then a final layer norm."""

    def __init__(self, config: ModelConfig):
        super().__init__()

        self.embedding = Embedding(config, is_tgt=True)
        blocks = [DecoderBlock(config) for _ in range(config.num_layers)]
        self.layers = nn.ModuleList(blocks)
        self.ln = LayerNormalization(config)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """
        x: (batch_size, tgt_seq_len) target token ids
        enc_output: (batch_size, src_seq_len, d_model)
        src_mask / tgt_mask: attention masks handed unchanged to every block, or None
        returns: (batch_size, tgt_seq_len, d_model)
        """
        hidden = self.embedding(x)
        for block in self.layers:
            hidden = block(hidden, enc_output, src_mask=src_mask, tgt_mask=tgt_mask)
        return self.ln(hidden)  # final layer normalization

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer producing target-vocabulary logits."""

    def __init__(self, config: ModelConfig):
        super().__init__()

        self.encoder = Encoder(config)
        self.decoder = Decoder(config)
        self.output_layer = nn.Linear(config.d_model, config.tgt_vocab_size)

        self._tie_weight()

    def _tie_weight(self):
        """
        Share one parameter between the output projection and the decoder's
        token embedding table, so both use a single
        (tgt_vocab_size, d_model) matrix.
        """
        shared = self.decoder.embedding.word_embedding.embedding.weight
        self.output_layer.weight = shared

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """
        src: (batch_size, src_seq_len) source token ids
        tgt: (batch_size, tgt_seq_len) decoder input token ids
        src_mask: mask over source positions (used by the encoder and by cross-attention), or None
        tgt_mask: mask for decoder self-attention, or None
        returns: (batch_size, tgt_seq_len, tgt_vocab_size) logits
        """
        memory = self.encoder(src, src_mask)  # (batch_size, src_seq_len, d_model)
        decoded = self.decoder(tgt, memory, src_mask, tgt_mask)  # (batch_size, tgt_seq_len, d_model)
        return self.output_layer(decoded)

In [None]:
def create_causal_mask(seq_len_q, seq_len_k):
    """
    Build a boolean look-ahead mask.

    seq_len_q: number of query positions
    seq_len_k: number of key positions
    returns: (1, seq_len_q, seq_len_k) bool tensor, True where key index > query
    index, i.e. at the positions a query must not attend to.
    """
    above_diagonal = torch.ones(seq_len_q, seq_len_k).triu(diagonal=1)
    return above_diagonal.bool()[None]  # leading axis broadcasts over batch/heads


def create_padding_mask(x, padding_idx=0):
    """
    Build a boolean key-padding mask (True = padding, must not be attended to).

    x: token ids, either a single sequence (seq_len,) or a batch (batch_size, seq_len)
    padding_idx: id used for padding (default 0)

    returns:
        (1, 1, seq_len) for a single sequence — after a DataLoader stacks the
        samples this becomes (batch_size, 1, 1, seq_len);
        (batch_size, 1, 1, seq_len) for an already-batched input.
    Both forms broadcast against attention scores of shape
    (batch_size, num_heads, seq_len_q, seq_len_k).
    """
    is_pad = x == padding_idx
    if is_pad.dim() <= 1:
        return is_pad.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len)
    # Insert the head and query axes just before the key axis so the batch
    # dimension stays in front.
    return is_pad.unsqueeze(-2).unsqueeze(-2)  # (batch_size, 1, 1, seq_len)

In [None]:
# Smoke test: run one forward pass of a default-sized model on random token ids.
transformer = Transformer(ModelConfig())

src = torch.randint(
    0, ModelConfig().src_vocab_size, (32, 10)
)  # (batch_size, src_seq_len)
tgt = torch.randint(
    0, ModelConfig().tgt_vocab_size, (32, 15)
)  # (batch_size, tgt_seq_len)

# Only the look-ahead mask is supplied; no padding mask is used here.
causal_mask = create_causal_mask(tgt.size(1), tgt.size(1))
output = transformer(src, tgt, tgt_mask=causal_mask)

## Dataset Preparation


In [None]:
# from datasets import load_dataset
# import os


# dataset = load_dataset(
#     "iwslt2017",
#     "iwslt2017-en-zh",
#     download_mode="force_redownload",
#     trust_remote_code=True,
# )

In [None]:
# Display the training split.
# NOTE(review): `dataset` is only assigned in the commented-out `load_dataset(...)`
# cell above; that cell has to be run first, otherwise this raises NameError.
dataset["train"]

In [None]:
from datasets import load_dataset
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, decoders, processors
from tokenizers.normalizers import Sequence, NFKC
from tokenizers.pre_tokenizers import Whitespace

### Train BPE Tokenizer


In [None]:
# en_texts = []
# zh_texts = []

# for example in dataset["train"]:
#     en_texts.append(example["translation"]["en"])
#     zh_texts.append(example["translation"]["zh"])

# with open("en_corpus.txt", "w", encoding="utf-8") as f:
#     for line in en_texts:
#         f.write(line.strip() + "\n")

# with open("zh_corpus.txt", "w", encoding="utf-8") as f:
#     for line in zh_texts:
#         f.write(line.strip() + "\n")


def load_or_train_bpe_tokenizer(corpus_file, vocab_size, save_path):
    """
    Return a BPE tokenizer, loading it from ``save_path`` when that file exists
    and otherwise training it on ``corpus_file`` and saving it there.

    corpus_file: path of a plain-text training corpus
    vocab_size: target vocabulary size handed to the BPE trainer
    save_path: JSON file the tokenizer is loaded from / saved to

    Raises:
        FileNotFoundError: if training is required but ``corpus_file`` is missing.
    """
    if os.path.exists(save_path):
        print(f"Loading tokenizer from {save_path}")
        return Tokenizer.from_file(save_path)

    if not os.path.exists(corpus_file):
        raise FileNotFoundError(
            f"Cannot train tokenizer: corpus file {corpus_file!r} does not exist"
        )

    # Name the unknown token on the model itself. Reserving "<unk>" as a special
    # token alone does not make the BPE model emit it for unseen symbols.
    tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
    tokenizer.normalizer = NFKC()
    tokenizer.pre_tokenizer = Whitespace()

    # Special tokens receive the first ids in the order listed here.
    trainer = trainers.BpeTrainer(
        vocab_size=vocab_size, special_tokens=["<pad>", "<unk>", "<s>", "</s>"]
    )

    tokenizer.train([corpus_file], trainer)
    tokenizer.save(save_path)
    print(f"Saved tokenizer to {save_path}")

    return tokenizer


# One tokenizer per language. Each is loaded from its JSON file when present;
# otherwise it is trained on the corresponding corpus text file.
# The 16000 here matches the ModelConfig vocabulary-size defaults.
en_tokenizer = load_or_train_bpe_tokenizer(
    "en_corpus.txt", vocab_size=16000, save_path="en_bpe.json"
)
zh_tokenizer = load_or_train_bpe_tokenizer(
    "zh_corpus.txt", vocab_size=16000, save_path="zh_bpe.json"
)

In [None]:
import torch
from torch.utils.data import Dataset


class TranslationDataset(Dataset):
    """
    Map-style dataset that turns parallel sentence pairs into fixed-length tensors.

    raw_dataset: indexable collection of dicts mapping language code -> sentence
    tokenizer_src / tokenizer_tgt: tokenizers providing ``encode(text).ids``
        and ``token_to_id(token)``
    src_lang / tgt_lang: dict keys of the source and target sentences
    seq_len: length every returned sequence is truncated or padded to

    NOTE(review): the default ``tgt_lang="ch"`` is kept for backward
    compatibility; the call in this file passes the language keys explicitly.
    """

    def __init__(
        self,
        raw_dataset,
        tokenizer_src,
        tokenizer_tgt,
        src_lang="en",
        tgt_lang="ch",
        seq_len=128,
    ):
        self.ds = raw_dataset
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.seq_len = seq_len

        self.pad_id_src = self._token_id(tokenizer_src, "<pad>", 0)
        self.pad_id_tgt = self._token_id(tokenizer_tgt, "<pad>", 0)
        self.sos_id = self._token_id(tokenizer_tgt, "<s>", 1)
        self.eos_id = self._token_id(tokenizer_tgt, "</s>", 2)

    @staticmethod
    def _token_id(tokenizer, token, default):
        """Return the id of ``token``, or ``default`` only when the tokenizer lacks it."""
        # An explicit None check is required: id 0 is valid, and `or` would discard it.
        token_id = tokenizer.token_to_id(token)
        return default if token_id is None else token_id

    def _pad(self, ids, pad_id):
        """Right-pad ``ids`` (already at most seq_len long) to exactly seq_len."""
        return ids + [pad_id] * (self.seq_len - len(ids))

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        """
        Returns a dict of tensors:
            input_ids:         (seq_len,)  source ids, right-padded
            labels:            (seq_len,)  target ids followed by </s>, right-padded
            decoder_input_ids: (seq_len,)  <s> followed by target ids, right-padded
            encoder_mask:      (1, 1, seq_len) bool, True at source padding
            decoder_mask:      (1, seq_len, seq_len) bool, True where attention is blocked
        """
        item = self.ds[idx]
        src_text = item[self.src_lang]
        tgt_text = item[self.tgt_lang]

        src_tokens = self.tokenizer_src.encode(src_text).ids[: self.seq_len]

        # Tokenize the target once and leave one slot for <s> / </s>.
        tgt_tokens = self.tokenizer_tgt.encode(tgt_text).ids[: self.seq_len - 1]

        src_ids = torch.tensor(self._pad(src_tokens, self.pad_id_src), dtype=torch.int64)
        tgt_ids = torch.tensor(
            self._pad(tgt_tokens + [self.eos_id], self.pad_id_tgt), dtype=torch.int64
        )
        decoder_ids = torch.tensor(
            self._pad([self.sos_id] + tgt_tokens, self.pad_id_tgt), dtype=torch.int64
        )

        # True means "blocked" in both masks, so a position has to be hidden when
        # it is in the future OR it is padding. Combining them with AND would
        # leave every real future token visible to the decoder.
        decoder_mask = create_causal_mask(
            decoder_ids.size(0), decoder_ids.size(0)
        ) | create_padding_mask(decoder_ids, self.pad_id_tgt)

        return {
            "input_ids": src_ids,
            "labels": tgt_ids,
            "decoder_input_ids": decoder_ids,
            "encoder_mask": create_padding_mask(src_ids, self.pad_id_src),
            "decoder_mask": decoder_mask,
        }

In [None]:
# Preview the "translation" field of the first 1000 training rows, the same
# slice that is passed to TranslationDataset below.
dataset["train"][:1000]["translation"]

In [None]:
# zh_tokenizer = Tokenizer.from_file("zh_bpe.json")
# en_tokenizer = Tokenizer.from_file("en_bpe.json")


# Chinese -> English dataset built from the first 1000 training pairs only.
translation_dataset = TranslationDataset(
    # raw_dataset=dataset["train"],
    raw_dataset=dataset["train"][:1000]["translation"],
    tokenizer_src=zh_tokenizer,
    tokenizer_tgt=en_tokenizer,
    src_lang="zh",
    tgt_lang="en",
    seq_len=128,
)

In [None]:
# Pull one example and unpack its id tensors for the decoding checks below.
sample = translation_dataset[0]

input_ids = sample["input_ids"]
decoder_input_ids = sample["decoder_input_ids"]
labels = sample["labels"]

In [None]:
# Decode the source ids back to text with the source (Chinese) tokenizer.
zh_tokenizer.decode(input_ids.tolist())

In [None]:
# Decode the decoder input ids with the target (English) tokenizer.
en_tokenizer.decode(decoder_input_ids.tolist())

In [None]:
# Decode the label ids with the target (English) tokenizer.
en_tokenizer.decode(labels.tolist())

In [None]:
# Batches of 32 shuffled examples. Default collation stacks each dict entry along
# a new leading batch axis, including the per-sample masks.
train_dataloader = torch.utils.data.DataLoader(
    translation_dataset,
    batch_size=32,
    shuffle=True,
)

## Training Procedure


### Loss Function with Label Smoothing


In [None]:
class LabelSmoothing(nn.Module):
    """
    Cross-entropy against a smoothed target distribution.

    The correct class receives probability ``1 - smoothing`` and the remaining
    ``smoothing`` mass is spread evenly over the other ``vocab_size - 1``
    classes (``eps`` each).
    """

    def __init__(self, config: ModelConfig, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing
        self.vocab_size = config.tgt_vocab_size
        self.eps = smoothing / (self.vocab_size - 1)  # mass of each non-target class

    def forward(self, logits, target):
        """
        logits: (..., vocab_size)
        target: (...) integer class ids, same leading shape as logits
        returns: scalar loss averaged over all positions
        """
        log_probs = F.log_softmax(logits, dim=-1)  # (..., vocab_size)
        nll_loss = -log_probs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)

        # eps is a per-class weight, so it multiplies the SUM of the non-target
        # negative log-probabilities. Using the vocabulary mean here would
        # shrink the smoothing term by roughly a factor of vocab_size.
        off_target_loss = -log_probs.sum(dim=-1) - nll_loss
        loss = (1.0 - self.smoothing) * nll_loss + self.eps * off_target_loss

        return loss.mean()  # Average over batch and sequence length

In [None]:
class CrossEntropyLossWithLabelSmoothing(nn.Module):
    """Criterion-style wrapper that delegates to a LabelSmoothing module."""

    def __init__(self, config: ModelConfig, smoothing=0.1):
        super().__init__()
        self.label_smoothing = LabelSmoothing(config, smoothing)

    def forward(self, logits, target):
        """
        logits: (batch_size, seq_len, vocab_size)
        target: (batch_size, seq_len)
        returns: the loss computed by the wrapped LabelSmoothing module
        """
        smoothed_loss = self.label_smoothing(logits, target)
        return smoothed_loss

### Adam Optimizer


$$\theta_{t+1,i} = \theta_{t,i} - \eta \cdot \frac{\hat{m}_{t,i}}{\sqrt{\hat{v}_{t,i}}+\epsilon}$$

where:

\begin{align*}
\hat{m}_{t,i} &= \frac{m_{t,i}}{1-\beta_1^t} \\
\hat{v}_{t,i} &= \frac{v_{t,i}}{1-\beta_2^t}
\end{align*}


In [None]:
class Adam:
    """Minimal Adam optimizer with bias-corrected first and second moment estimates."""

    def __init__(self, model_params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        params = list(model_params)  # materialize: generators can be consumed only once
        self.model_params = params
        self.lr = lr
        self.beta_1, self.beta_2 = betas
        self.eps = eps
        # Running averages of the gradient and of the squared gradient, one per parameter.
        self.avg_grads = [torch.zeros_like(param) for param in params]
        self.avg_squares = [torch.zeros_like(param) for param in params]
        self.n_steps = 0

    def zero_grad(self):
        """Drop all stored gradients (set them to None)."""
        for param in self.model_params:
            param.grad = None

    @torch.no_grad()
    def step(self):
        """Apply one Adam update to every parameter that currently has a gradient."""
        self.n_steps += 1  # one global step counter shared by all parameters

        state = zip(self.model_params, self.avg_grads, self.avg_squares)
        for param, first_moment, second_moment in state:
            grad = param.grad
            if grad is None:
                continue

            # Exponential moving averages, updated in place.
            first_moment.mul_(self.beta_1).add_(grad, alpha=1 - self.beta_1)
            second_moment.mul_(self.beta_2).addcmul_(grad, grad, value=1 - self.beta_2)

            # Undo the bias towards zero caused by the zero initialization.
            corrected_first = first_moment / (1 - self.beta_1**self.n_steps)
            corrected_second = second_moment / (1 - self.beta_2**self.n_steps)

            denominator = corrected_second.sqrt() + self.eps
            param.addcdiv_(corrected_first, denominator, value=-self.lr)

## Training Loop


In [None]:
from tqdm.autonotebook import tqdm

In [None]:
def train_step(model, optimizer, criterion, data):
    """
    Run one optimization step on a single batch and return its loss as a float.

    data: dict with the keys "input_ids", "decoder_input_ids", "encoder_mask",
    "decoder_mask" and "labels"; every entry is moved to DEVICE.
    """
    model.train()
    optimizer.zero_grad()

    batch_keys = (
        "input_ids",
        "decoder_input_ids",
        "encoder_mask",
        "decoder_mask",
        "labels",
    )
    src_ids, decoder_in, src_mask, tgt_mask, labels = (
        data[key].to(DEVICE) for key in batch_keys
    )

    # Forward pass
    logits = model(src_ids, decoder_in, src_mask=src_mask, tgt_mask=tgt_mask)
    # Debug output: flags whether any logit is NaN / infinite.
    print(torch.isnan(logits).any(), torch.isinf(logits).any())

    # Flatten to (batch * seq_len, vocab) against (batch * seq_len,) for the criterion.
    vocab_size = logits.size(-1)
    loss = criterion(logits.view(-1, vocab_size), labels.view(-1))

    # Backward pass, with autograd anomaly detection enabled while it runs.
    with torch.autograd.set_detect_anomaly(True):
        loss.backward()

    optimizer.step()

    return loss.item()


def train(model, dataset, optimizer, criterion, num_epochs=10, batch_size=32):
    """
    Train ``model`` for ``num_epochs`` passes over ``dataset`` and print the
    mean batch loss after each epoch.

    dataset: iterable yielding batches in the format ``train_step`` expects
        (the caller in this file passes a DataLoader)
    batch_size: unused; kept only so existing calls keep working. Batching is
        decided by ``dataset`` itself.
    """
    model.train()

    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0
        for batch in tqdm(dataset, desc=f"Epoch {epoch + 1}/{num_epochs}"):
            total_loss += train_step(model, optimizer, criterion, batch)
            num_batches += 1

        # Average over the batches actually processed. len(dataset) already
        # counts batches for a DataLoader, so dividing it by batch_size again
        # gives a wrong denominator, and zero when there are fewer batches
        # than batch_size.
        avg_loss = total_loss / max(num_batches, 1)
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}")

In [None]:
# Initialize model, optimizer, and loss function
model = Transformer(ModelConfig()).to(DEVICE)
optimizer = Adam(model.parameters(), lr=1e-5)  # the hand-written Adam class from this file
# criterion = CrossEntropyLossWithLabelSmoothing(ModelConfig(), smoothing=0.1)
# ignore_index=0 skips label positions equal to 0, the id "<pad>" gets as the
# first special token of the trained tokenizer.
# NOTE(review): confirm this for tokenizers loaded from an existing JSON file.
criterion = nn.CrossEntropyLoss(ignore_index=0).to(
    DEVICE
)  # Use standard CrossEntropyLoss with padding index

# Train the model
train(model, train_dataloader, optimizer, criterion, num_epochs=10, batch_size=32)
# Initialize model, optimizer, and loss function