In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math
import random
import os
import re
import matplotlib.pyplot as plt


class CustomTokenizer:
    """Word-level tokenizer that lower-cases text and splits it with a regex.

    A token is either a maximal run of word characters or a single character
    that is neither a word character nor whitespace (e.g. punctuation).
    The vocabulary starts with the special tokens and grows through
    ``build_vocab``; ids are assigned in first-seen order.
    """

    # Compiled once and shared by build_vocab and encode so both split text
    # exactly the same way.
    _TOKEN_PATTERN = re.compile(r"\w+|[^\w\s]")

    def __init__(self, specials=("<PAD>", "<UNK>", "<BOS>", "<EOS>")):
        """
        Args:
            specials: tokens added to the vocabulary first, in order. The
                pad/unk/bos/eos attributes below always name "<PAD>", "<UNK>",
                "<BOS>" and "<EOS>", so encode() needs "<UNK>" present to map
                unknown words and "<BOS>"/"<EOS>" present when add_bos_eos=True.
        """
        # Copy into a fresh list: a list default would be shared by every
        # instance and could be mutated through ``self.specials``.
        self.specials = list(specials)
        self.token2idx = {}
        self.idx2token = []
        for sp in self.specials:
            self.add_token(sp)
        self.pad_token = "<PAD>"
        self.unk_token = "<UNK>"
        self.bos_token = "<BOS>"
        self.eos_token = "<EOS>"

    def _tokenize(self, text):
        """Lower-case ``text`` and return its word/punctuation tokens."""
        return self._TOKEN_PATTERN.findall(text.lower())

    def add_token(self, token):
        """Append ``token`` to the vocabulary unless it is already present."""
        if token not in self.token2idx:
            self.token2idx[token] = len(self.idx2token)
            self.idx2token.append(token)

    def build_vocab(self, text_lines):
        """Add every token found in ``text_lines`` (iterable of str) to the vocab."""
        for line in text_lines:
            for token in self._tokenize(line):
                self.add_token(token)

    def encode(self, text_line, add_bos_eos=False):
        """Convert a text line into token ids; unseen tokens map to "<UNK>".

        When ``add_bos_eos`` is True the ids are wrapped in "<BOS>" ... "<EOS>".
        """
        vocab = self.token2idx
        ids = [vocab[t] if t in vocab else vocab[self.unk_token]
               for t in self._tokenize(text_line)]
        if add_bos_eos:
            ids = [vocab[self.bos_token]] + ids + [vocab[self.eos_token]]
        return ids

    def decode(self, token_ids):
        """Map ids back to tokens and join them with single spaces."""
        return " ".join(self.idx2token[i] for i in token_ids)

    def __len__(self):
        """Vocabulary size, special tokens included."""
        return len(self.idx2token)

class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    The true class gets probability ``1 - smoothing`` and every other class
    gets ``smoothing / (vocab_size - 1)``. Rows whose target equals
    ``ignore_index`` contribute nothing. The loss is averaged over the
    remaining rows only, which is the reduction ``nn.CrossEntropyLoss`` uses
    with ``ignore_index``.
    """

    def __init__(self, smoothing, vocab_size, ignore_index=None):
        super().__init__()
        self.smoothing = smoothing
        self.vocab_size = vocab_size
        self.ignore_index = ignore_index

    def forward(self, log_probs, target):
        """
        log_probs: (N, vocab_size) log-probabilities (already log_softmax-ed).
        target:    (N,) class indices.

        Returns a scalar tensor: mean smoothed NLL over non-ignored rows
        (0.0 if every row is ignored).
        """
        confidence = 1.0 - self.smoothing
        smoothing_value = self.smoothing / (self.vocab_size - 1)

        # The target distribution is a constant; no gradient flows into it.
        with torch.no_grad():
            true_dist = torch.full_like(log_probs, smoothing_value)
            true_dist.scatter_(1, target.unsqueeze(1), confidence)

            if self.ignore_index is not None:
                keep = target != self.ignore_index
                true_dist.masked_fill_((~keep).unsqueeze(1), 0.0)
                # Normalize by real tokens, not all rows, so padding does not
                # shrink the loss; clamp avoids 0/0 when everything is ignored.
                n_valid = keep.sum().clamp(min=1)
            else:
                n_valid = target.numel()

        per_row = -(true_dist * log_probs).sum(dim=1)
        return per_row.sum() / n_valid

def read_lines_from_file(filepath):
    """Read a UTF-8 text file and return its non-blank lines, each stripped of
    leading/trailing whitespace, in file order."""
    with open(filepath, "r", encoding="utf-8") as handle:
        stripped = (raw.strip() for raw in handle)
        return [text for text in stripped if text]

def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps):
    """
    LambdaLR schedule: linear warm-up from 0 to the base LR over
    ``warmup_steps`` steps, then cosine decay to 0 at ``total_steps``.

    After ``total_steps`` the multiplier stays at 0. Without clamping, the
    cosine would climb back up if the caller steps past the planned total.
    """
    decay_steps = max(1, total_steps - warmup_steps)

    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return current_step / max(1, warmup_steps)
        # Fraction of the decay phase completed, clamped to [0, 1].
        progress = min(1.0, (current_step - warmup_steps) / decay_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def prepare_dataset(filepath, tokenizer, seq_len=32):
    """Build next-token-prediction windows from a text file.

    Every non-empty line (see read_lines_from_file) is tokenized. Lines that
    yield at least one token are terminated with <EOS> and concatenated into
    one stream. Every length-``seq_len`` window of that stream (stride 1) is
    an input row, and the same window shifted one token right is its target.

    Returns:
        (inputs, targets): LongTensors of shape (num_windows, seq_len) with
        num_windows = max(0, stream_length - seq_len).
    """
    eos_id = tokenizer.token2idx[tokenizer.eos_token]
    stream = []
    for line in read_lines_from_file(filepath):
        token_ids = tokenizer.encode(line)
        if token_ids:  # token-less lines contribute nothing, not even <EOS>
            stream.extend(token_ids)
            stream.append(eos_id)

    ids = torch.tensor(stream, dtype=torch.long)
    if ids.numel() <= seq_len:
        # Not enough tokens for a single (input, target) pair.
        empty = torch.empty((0, seq_len), dtype=torch.long)
        return empty, empty.clone()

    # unfold yields all stride-1 windows as an overlapping view; contiguous()
    # materializes them so each row owns its storage (as the list-built
    # tensors did) while avoiding millions of Python-level list operations.
    inputs = ids[:-1].unfold(0, seq_len, 1).contiguous()
    targets = ids[1:].unfold(0, seq_len, 1).contiguous()
    return inputs, targets

def generate_causal_mask(seq_len):
    """
    Build an additive causal attention mask of shape (seq_len, seq_len).

    Entry (i, j) is 0.0 when j <= i (the current or an earlier position) and
    -inf when j > i (a future position). Added to attention scores before
    softmax, it gives future positions zero weight.
    """
    blocked = torch.full((seq_len, seq_len), float('-inf'))
    # Keep -inf strictly above the diagonal; triu zeroes everything else.
    return torch.triu(blocked, diagonal=1)

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input embeddings.

    Column 2i holds sin(pos * f_i) and column 2i+1 holds cos(pos * f_i), with
    f_i = 10000^(-2i / d_model). The table covers positions [0, max_len) and
    is stored as the buffer ``pe`` of shape (1, max_len, d_model).
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one more even column than odd columns, so
        # the cosine columns use one fewer frequency.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """x: (batch_size, seq_len, d_model). Returns x plus the first seq_len encodings."""
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len, :]

def scaled_dot_product_attention(q, k, v, mask=None):
    """
    Compute softmax(q @ k^T / sqrt(d_head) + mask) @ v.

    q, k, v: (batch_size, n_heads, seq_len, d_head)
    mask:    additive float mask broadcastable to the score matrix,
             e.g. (1, 1, seq_len, seq_len); 0 keeps a position and -inf
             blocks it.

    Returns (output, attention_weights) with shapes
    (batch_size, n_heads, seq_len, d_head) and (..., seq_len, seq_len).
    """
    scale = math.sqrt(q.size(-1))
    logits = q @ k.transpose(-2, -1) / scale
    if mask is not None:
        logits = logits + mask
    weights = logits.softmax(dim=-1)
    return weights @ v, weights

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with bias-free Q/K/V/output projections.

    Dropout is applied to the projected output, not to the attention weights.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        # Raise rather than assert so the check survives `python -O`.
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t, bsz, seq_len):
        """(bsz, seq_len, d_model) -> (bsz, num_heads, seq_len, d_head)."""
        return t.reshape(bsz, seq_len, self.num_heads, self.d_head).transpose(1, 2)

    def forward(self, x, mask=None):
        """
        x:    (batch_size, seq_len, d_model)
        mask: additive float mask, either (seq_len, seq_len),
              (batch_size, seq_len, seq_len), or already 4-D and
              broadcastable to (batch_size, num_heads, seq_len, seq_len).
        Returns (batch_size, seq_len, d_model).
        """
        bsz, seq_len, _ = x.size()
        q = self._split_heads(self.w_q(x), bsz, seq_len)
        k = self._split_heads(self.w_k(x), bsz, seq_len)
        v = self._split_heads(self.w_v(x), bsz, seq_len)

        # Lift lower-rank masks so they broadcast over the head dimension;
        # 4-D masks are passed through unchanged.
        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(0).unsqueeze(1)  # (1, 1, seq_len, seq_len)
            elif mask.dim() == 3:
                mask = mask.unsqueeze(1)  # (bsz, 1, seq_len, seq_len)

        attn_output, _ = scaled_dot_product_attention(q, k, v, mask=mask)
        # attn_output: (bsz, num_heads, seq_len, d_head) -> merge heads back.
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_len, self.d_model)
        return self.dropout(self.w_o(attn_output))

class PositionwiseFeedForward(nn.Module):
    """Two-layer MLP applied independently at each position.

    Computes dropout(fc2(dropout(relu(fc1(x))))); the same Dropout module is
    used after the activation and after the output projection.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.dropout(F.relu(self.fc1(x)))
        return self.dropout(self.fc2(hidden))

class TransformerBlock(nn.Module):
    """Post-norm Transformer layer.

    h = LN1(x + SelfAttn(x)); out = LN2(h + FFN(h)).
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # Registration order kept as-is (it determines state_dict layout and
        # the order in which default initializers draw random numbers).
        self.attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ln1 = nn.LayerNorm(d_model)
        self.ff = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """x: (batch_size, seq_len, d_model); mask is forwarded to self-attention."""
        h = self.ln1(x + self.attn(x, mask=mask))
        return self.ln2(h + self.ff(h))

class TransformerLanguageModel(nn.Module):
    """Decoder-only Transformer language model.

    token embedding * sqrt(d_model) + sinusoidal positions -> dropout ->
    num_layers post-norm TransformerBlocks (causally masked) -> LayerNorm ->
    linear projection to vocabulary logits.

    ``pad_idx`` is stored on the instance but not otherwise used here (the
    embedding is created without ``padding_idx``).
    """

    def __init__(self, vocab_size, d_model=256, num_heads=4, d_ff=1024, num_layers=4, dropout=0.1, pad_idx=0):
        super().__init__()
        self.pad_idx = pad_idx
        self.d_model = d_model

        # Submodules are registered in this order; it fixes the state_dict
        # layout and the order in which _reset_parameters visits parameters.
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        self.layers = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.ln = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform re-init of every parameter with more than one dimension
        (weight matrices and the embedding table); 1-D parameters keep their
        default initialization."""
        matrices = [param for param in self.parameters() if param.dim() > 1]
        for weight in matrices:
            nn.init.xavier_uniform_(weight)

    def forward(self, x):
        """
        x: (batch_size, seq_len) token ids.
        Returns logits of shape (batch_size, seq_len, vocab_size).
        """
        _, seq_len = x.size()
        causal = generate_causal_mask(seq_len).to(x.device)  # (seq_len, seq_len)

        h = self.token_emb(x) * math.sqrt(self.d_model)
        h = self.dropout(self.pos_enc(h))
        for block in self.layers:
            h = block(h, mask=causal)

        return self.fc_out(self.ln(h))

def calculate_loss_and_perplexity(logits, targets, pad_idx=0):
    """
    Plain (unsmoothed) cross-entropy and perplexity for one batch.

    logits:  (..., vocab_size) unnormalized scores, e.g. (batch_size, seq_len, vocab_size)
    targets: matching token ids, e.g. (batch_size, seq_len)
    pad_idx: target id excluded from both the loss and the perplexity.

    Returns (loss, ppl): loss is a scalar tensor, the mean negative
    log-likelihood over non-pad targets; ppl is the float exp(loss).
    Both are NaN if every target is padding.
    """
    vocab_size = logits.size(-1)
    # reshape (not view) also accepts non-contiguous logits.
    flat_logits = logits.reshape(-1, vocab_size)
    flat_targets = targets.reshape(-1)

    loss = F.cross_entropy(flat_logits, flat_targets, ignore_index=pad_idx)
    # Cross-entropy with ignore_index is exactly the mean NLL of the non-pad
    # targets, so perplexity follows directly without recomputing it.
    ppl = math.exp(loss.item())
    return loss, ppl

def _masked_nll_sum(log_probs, flat_targets, pad_idx):
    """Return (summed negative log-likelihood, token count) over non-pad targets.

    log_probs: (N, vocab_size) log-probabilities; flat_targets: (N,) token ids.
    """
    keep = flat_targets != pad_idx
    picked = log_probs[keep].gather(1, flat_targets[keep].unsqueeze(1))
    return -picked.sum().item(), int(keep.sum().item())


def _corpus_perplexity(nll_sum, n_tokens):
    """exp(total NLL / total tokens); NaN when no tokens were scored."""
    return math.exp(nll_sum / n_tokens) if n_tokens else float("nan")


def train_model(train_x, train_y, dev_x, dev_y, model, tokenizer,
                epochs=10, batch_size=32, lr=1e-4, pad_idx=0,
                smoothing=0.1, warmup_ratio=0.1, weight_decay=1e-4):
    """Train ``model`` with a label-smoothed loss, evaluating on dev every epoch.

    Args:
        train_x, train_y, dev_x, dev_y: LongTensors of input/target token ids,
            one window per row (as produced by prepare_dataset).
        model: maps (B, seq) ids to (B, seq, vocab_size) logits.
        tokenizer: supplies the vocabulary size and is used for the per-epoch
            sample generation.
        epochs, batch_size, lr: usual training hyper-parameters.
        pad_idx: target id excluded from the loss and the perplexity.
        smoothing: label-smoothing factor for LabelSmoothingLoss.
        warmup_ratio: fraction of all optimizer steps spent in linear warm-up.
        weight_decay: passed to Adam, i.e. an L2 term added to the gradients
            (not AdamW-style decoupled decay).

    Returns:
        (train_losses, dev_losses, train_ppls, dev_ppls), one float per epoch.
        Losses are means of the per-batch label-smoothed losses. Perplexities
        are exp(mean NLL over all non-pad target tokens of the split); train
        perplexity comes from the dropout-enabled training forward passes.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    # get_batches yields a final partial batch, so each epoch takes
    # ceil(N / batch_size) optimizer steps. Floor division made the scheduler
    # run past total_steps.
    steps_per_epoch = math.ceil(train_x.size(0) / batch_size)
    total_steps = epochs * steps_per_epoch
    warmup_steps = int(warmup_ratio * total_steps)
    scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    vocab_size = len(tokenizer)
    ls_criterion = LabelSmoothingLoss(smoothing, vocab_size, ignore_index=pad_idx)

    train_losses, dev_losses = [], []
    train_ppls, dev_ppls = [], []

    def get_batches(x_data, y_data, batch_size):
        for i in range(0, x_data.size(0), batch_size):
            yield x_data[i:i+batch_size], y_data[i:i+batch_size]

    for epoch in range(1, epochs + 1):
        # ---- training ----
        model.train()
        loss_sum, n_batches = 0.0, 0
        nll_sum, n_tokens = 0.0, 0

        # Shuffle training windows each epoch.
        perm = torch.randperm(train_x.size(0))
        train_x, train_y = train_x[perm], train_y[perm]

        for bx, by in get_batches(train_x, train_y, batch_size):
            bx = bx.to(device)
            flat_y = by.to(device).reshape(-1)
            optimizer.zero_grad()

            logits = model(bx)  # (B, seq, vocab)
            log_probs = F.log_softmax(logits.reshape(-1, vocab_size), dim=-1)
            loss = ls_criterion(log_probs, flat_y)

            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()  # per-iteration LR schedule

            loss_sum += loss.item()
            n_batches += 1
            with torch.no_grad():
                batch_nll, batch_tokens = _masked_nll_sum(log_probs, flat_y, pad_idx)
            nll_sum += batch_nll
            n_tokens += batch_tokens

        train_losses.append(loss_sum / n_batches if n_batches else float("nan"))
        train_ppls.append(_corpus_perplexity(nll_sum, n_tokens))

        # ---- validation ----
        model.eval()
        dev_loss_sum, dev_batches = 0.0, 0
        dev_nll_sum, dev_tokens = 0.0, 0
        with torch.no_grad():
            for bx, by in get_batches(dev_x, dev_y, batch_size):
                bx = bx.to(device)
                flat_y = by.to(device).reshape(-1)
                logits = model(bx)
                log_probs = F.log_softmax(logits.reshape(-1, vocab_size), dim=-1)
                dev_loss_sum += ls_criterion(log_probs, flat_y).item()
                dev_batches += 1

                batch_nll, batch_tokens = _masked_nll_sum(log_probs, flat_y, pad_idx)
                dev_nll_sum += batch_nll
                dev_tokens += batch_tokens

        dev_losses.append(dev_loss_sum / dev_batches if dev_batches else float("nan"))
        dev_ppls.append(_corpus_perplexity(dev_nll_sum, dev_tokens))

        print(f"[Epoch {epoch}] train_loss={train_losses[-1]:.4f}, train_ppl={train_ppls[-1]:.2f}, "
              f"dev_loss={dev_losses[-1]:.4f}, dev_ppl={dev_ppls[-1]:.2f}")

        # Simple generation test
        sample_context = "First Citizen :"
        generated_text = generate_text(model, sample_context, tokenizer, max_new_tokens=20)
        print("Sample generation:", generated_text)

    return train_losses, dev_losses, train_ppls, dev_ppls


def generate_text(model, context_text, tokenizer, max_new_tokens=20, temperature=1.0):
    """Autoregressively extend ``context_text`` by ``max_new_tokens`` tokens.

    Args:
        model: maps (1, seq) token ids to (1, seq, vocab_size) logits; put in
            eval mode by this function.
        context_text: prompt, encoded with ``tokenizer.encode``. If it yields
            no tokens, generation starts from the <EOS> token.
        tokenizer: provides encode/decode and the <EOS> id.
        max_new_tokens: number of tokens to append.
        temperature: softmax temperature for sampling; values <= 0 switch to
            greedy (argmax) decoding instead of dividing by zero.

    Returns:
        ``tokenizer.decode`` of the prompt ids followed by the generated ids.
    """
    model.eval()
    device = next(model.parameters()).device

    context_ids = tokenizer.encode(context_text)
    if not context_ids:
        # A length-0 sequence has no last position to predict from.
        # prepare_dataset separates lines with <EOS>, so <EOS> is the
        # in-distribution "start of a new line" context.
        context_ids = [tokenizer.token2idx[tokenizer.eos_token]]
    context = torch.tensor([context_ids], dtype=torch.long, device=device)  # (1, len_context)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            last_logits = model(context)[:, -1, :]  # (1, vocab_size)
            if temperature <= 0:
                next_token_id = last_logits.argmax(dim=-1, keepdim=True)
            else:
                probs = F.softmax(last_logits / temperature, dim=-1)
                next_token_id = torch.multinomial(probs, num_samples=1)
            context = torch.cat([context, next_token_id], dim=1)

    return tokenizer.decode(context[0].tolist())

def _plot_metric(train_values, dev_values, metric_name):
    """Plot per-epoch train/dev curves for one metric in a new matplotlib figure."""
    plt.figure()
    plt.plot(train_values, label=f"Train {metric_name}")
    plt.plot(dev_values, label=f"Dev {metric_name}")
    plt.title(f"{metric_name} vs. Epoch")
    plt.xlabel("Epoch")
    plt.ylabel(metric_name)
    plt.legend()
    plt.show()


def main():
    """Train, plot, save and (optionally) test the Shakespeare language model.

    Reads the Kaggle train/dev splits and builds the vocabulary from both.
    Trains a TransformerLanguageModel, plots per-epoch loss and perplexity,
    and saves the weights to ``transformer_shakespeare.pt``. If ``test.txt``
    exists, it reports test loss/perplexity and prints a generation for
    every test line.
    """
    # === 1) Read data ===
    train_file = "/kaggle/input/shakespear/shakespear_train.txt"
    dev_file = "/kaggle/input/shakespear/shakespear_dev.txt"
    test_file = "test.txt"  # provided at demo
    seq_len = 32  # window length shared by train, dev and test

    train_lines = read_lines_from_file(train_file)
    dev_lines = read_lines_from_file(dev_file)

    # === 2) Build tokenizer ===
    tokenizer = CustomTokenizer()
    tokenizer.build_vocab(train_lines + dev_lines)
    vocab_size = len(tokenizer)
    print(f"Vocab size = {vocab_size}")

    # === 3) Prepare the datasets ===
    train_x, train_y = prepare_dataset(train_file, tokenizer, seq_len=seq_len)
    dev_x, dev_y = prepare_dataset(dev_file, tokenizer, seq_len=seq_len)

    print("train_x shape:", train_x.shape, "train_y shape:", train_y.shape)
    print("dev_x shape:", dev_x.shape, "dev_y shape:", dev_y.shape)

    # === 4) Create the model ===
    # Any code that reloads the saved weights must rebuild the model with
    # these same sizes.
    d_model = 128
    num_heads = 8
    d_ff = 512
    num_layers = 2
    dropout = 0.2
    pad_idx = tokenizer.token2idx[tokenizer.pad_token]
    model = TransformerLanguageModel(
        vocab_size=vocab_size,
        d_model=d_model,
        num_heads=num_heads,
        d_ff=d_ff,
        num_layers=num_layers,
        dropout=dropout,
        pad_idx=pad_idx
    )

    # === 5) Train the model ===
    epochs = 10
    batch_size = 32
    lr = 1e-3  # you can start a bit higher since we have warm-up
    train_losses, dev_losses, train_ppls, dev_ppls = train_model(
        train_x, train_y, dev_x, dev_y,
        model,
        tokenizer,
        epochs=epochs,
        batch_size=batch_size,
        lr=lr,
        pad_idx=pad_idx,
        smoothing=0.1,         # label smoothing
        warmup_ratio=0.1,      # 10% of total steps are warm-up
        weight_decay=1e-4      # L2 regularization
    )

    # === 6) Plot training & validation losses and perplexities ===
    _plot_metric(train_losses, dev_losses, "Loss")
    _plot_metric(train_ppls, dev_ppls, "Perplexity")

    # === 7) Save the model ===
    model_save_path = "transformer_shakespeare.pt"
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")

    # === 8) Evaluate on the test set (demo) ===
    if not os.path.exists(test_file):
        return

    test_x, test_y = prepare_dataset(test_file, tokenizer, seq_len=seq_len)
    if test_x.size(0) > 0:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model.eval()
        loss_sum, n_batches = 0.0, 0
        nll_sum, n_tokens = 0.0, 0
        with torch.no_grad():
            for i in range(0, test_x.size(0), batch_size):
                bx = test_x[i:i + batch_size].to(device)
                by = test_y[i:i + batch_size].to(device)
                logits = model(bx)
                loss, _ = calculate_loss_and_perplexity(logits, by, pad_idx=pad_idx)
                loss_sum += loss.item()
                n_batches += 1
                # Accumulate token-level NLL so perplexity is exp(mean NLL over
                # the whole test set), not an average of per-batch exps.
                nll_sum += F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)), by.reshape(-1),
                    ignore_index=pad_idx, reduction="sum",
                ).item()
                n_tokens += int((by != pad_idx).sum().item())
        test_loss = loss_sum / n_batches
        test_ppl = math.exp(nll_sum / n_tokens) if n_tokens else float("nan")
        print(f"Test Loss: {test_loss:.4f}, Test PPL: {test_ppl:.2f}")

    # Also do generation from each line in test.txt
    for line in read_lines_from_file(test_file):
        prompt = line.strip()
        gen_text = generate_text(model, prompt, tokenizer, max_new_tokens=20)
        print(f"Prompt: {prompt}")
        print(f"Generated: {gen_text}")
        print("----")


if __name__ == "__main__":
    main()

In [None]:
def calculate_loss_and_perplexity(logits, targets, criterion=None, pad_idx=0):
    """
    Loss and perplexity for one batch.

    logits:    (..., vocab_size) unnormalized scores, e.g. (batch_size, seq_len, vocab_size)
    targets:   matching token ids, e.g. (batch_size, seq_len)
    criterion: callable taking (log_probs (N, vocab_size), targets (N,)) and
               returning a scalar loss, e.g. LabelSmoothingLoss. When None,
               plain NLL ignoring ``pad_idx`` is used. This keeps the older
               ``calculate_loss_and_perplexity(logits, targets, pad_idx=...)``
               call form working after this definition replaces it.
    pad_idx:   target id excluded from the perplexity (and from the default loss).

    Returns: (loss, ppl) for this batch. ppl is exp(mean NLL over non-pad
    targets), independent of any smoothing in ``criterion``; NaN if every
    target is padding.
    """
    vocab_size = logits.size(-1)
    # reshape (not view) also accepts non-contiguous logits.
    log_probs = F.log_softmax(logits.reshape(-1, vocab_size), dim=-1)
    flat_targets = targets.reshape(-1)

    if criterion is None:
        loss = F.nll_loss(log_probs, flat_targets, ignore_index=pad_idx)
    else:
        loss = criterion(log_probs, flat_targets)

    # Log-probability of each correct token, padding excluded.
    keep = flat_targets != pad_idx
    target_log_probs = log_probs[keep].gather(1, flat_targets[keep].unsqueeze(1))
    nll = -target_log_probs.mean().item()
    ppl = math.exp(nll)
    return loss, ppl

def inference(model_path, test_file, tokenizer, pad_idx=0, max_new_tokens=20, batch_size=32,
              d_model=None, num_heads=8, d_ff=None, num_layers=None, dropout=0.2):
    """
    1) Loads the pretrained model from 'model_path'
    2) Reads and preprocesses 'test_file' into seq_len=32 windows
    3) Generates 'max_new_tokens' tokens for each line in test_file
    4) Reports the label-smoothed test loss (mean over batches) and the
       token-level test perplexity

    Architecture: ``d_model``, ``d_ff`` and ``num_layers`` are read from the
    checkpoint's tensor shapes when left as None. The previously hard-coded
    256/1024/4 could not load the 128/512/2 checkpoint written by the
    training main(). ``num_heads`` cannot be recovered from the weights,
    because every head count gives the same parameter shapes. It must match
    training; the default of 8 matches the training main(). ``dropout`` has
    no effect in eval mode.

    Returns:
        (test_loss, test_ppl), or None if the test file is missing or yields
        no windows.
    """
    # 1) Load the checkpoint and rebuild a matching architecture.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    state_dict = torch.load(model_path, map_location=device)

    if d_model is None:
        d_model = state_dict["token_emb.weight"].size(1)
    if d_ff is None:
        d_ff = state_dict["layers.0.ff.fc1.weight"].size(0)
    if num_layers is None:
        # Keys look like "layers.<i>.<...>"; count the distinct block indices.
        num_layers = len({key.split(".")[1] for key in state_dict if key.startswith("layers.")})

    model = TransformerLanguageModel(
        vocab_size=len(tokenizer),
        d_model=d_model,
        num_heads=num_heads,
        d_ff=d_ff,
        num_layers=num_layers,
        dropout=dropout,
        pad_idx=pad_idx
    )
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    # 2) Read and preprocess the test file
    if not os.path.exists(test_file):
        print(f"Test file '{test_file}' not found.")
        return None

    test_lines = read_lines_from_file(test_file)
    test_x, test_y = prepare_dataset(test_file, tokenizer, seq_len=32)
    if test_x.size(0) == 0:
        print("No data found in test file (after tokenization).")
        return None

    # 3) Generate tokens for each line
    print("\n--- Generation on each test line ---")
    for line in test_lines:
        prompt = line.strip()
        gen_text = generate_text(model, prompt, tokenizer, max_new_tokens=max_new_tokens)
        print(f"Prompt: {prompt}")
        print(f"Generated: {gen_text}")
        print("----")

    # 4) Loss and perplexity on the entire test set
    criterion = LabelSmoothingLoss(smoothing=0.1, vocab_size=len(tokenizer), ignore_index=pad_idx)

    loss_sum, n_batches = 0.0, 0
    nll_sum, n_tokens = 0.0, 0
    with torch.no_grad():
        for start in range(0, test_x.size(0), batch_size):
            bx = test_x[start:start + batch_size].to(device)
            by = test_y[start:start + batch_size].to(device)
            logits = model(bx)  # (batch, seq, vocab)
            loss, _ = calculate_loss_and_perplexity(logits, by, criterion, pad_idx=pad_idx)
            loss_sum += loss.item()
            n_batches += 1
            # Token-level NLL so perplexity is exp(mean NLL over the whole
            # set) rather than an average of per-batch exps.
            nll_sum += F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), by.reshape(-1),
                ignore_index=pad_idx, reduction="sum",
            ).item()
            n_tokens += int((by != pad_idx).sum().item())

    test_loss = loss_sum / n_batches
    test_ppl = math.exp(nll_sum / n_tokens) if n_tokens else float("nan")
    print("\n--- Test Metrics ---")
    print(f"Test Loss: {test_loss:.4f}, Test Perplexity: {test_ppl:.2f}\n")
    return test_loss, test_ppl

def main():
    """Rebuild the training vocabulary and run inference with the saved checkpoint.

    The vocabulary is rebuilt from the train lines followed by the dev lines,
    so token ids line up with the checkpoint's embedding rows. The dev split
    doubles as the test file.
    """
    train_file = "/kaggle/input/shakespear/shakespear_train.txt"
    dev_file = "/kaggle/input/shakespear/shakespear_dev.txt"

    tokenizer = CustomTokenizer()
    tokenizer.build_vocab(read_lines_from_file(train_file) + read_lines_from_file(dev_file))

    inference(
        "/kaggle/working/transformer_shakespeare.pt",
        "/kaggle/input/shakespear/shakespear_dev.txt",
        tokenizer,
        pad_idx=tokenizer.token2idx[tokenizer.pad_token],
    )


if __name__ == "__main__":
    main()