# Story Generation using Transformers

In [None]:
# DIRECTORY STRUCTURE
# Creates experiments/experiment_1 ... experiments/experiment_5, each with a model/
# and a train_logs/ subfolder, relative to the current working directory.
# NOTE(review): the config cell sets PATH = "/kaggle/working/experiments"; these
# folders only line up with PATH when the notebook runs from /kaggle/working.

!mkdir -p experiments

for i in range(1, 6):
   
    !mkdir -p experiments/experiment_{i}
    !mkdir -p experiments/experiment_{i}/model
    !mkdir -p experiments/experiment_{i}/train_logs

# Print the resulting tree so the layout can be checked.
!ls -R experiments

In [None]:
#INSTALLING LIBRARIES

!pip install evaluate
!pip install rouge_score

In [None]:
#IMPORTING LIBRARIES

import os
import random
import numpy as np
import torch
from transformers import AutoTokenizer
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
import json
import math
import csv
import time
from torch.utils.data import DataLoader
from typing import Optional
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import  evaluate as evaluate_model
from datasets import load_from_disk
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

In [None]:
#SETTING DEVICE TO GPU
# Disable TF32 for CUDA matmuls so they run in full FP32 precision, then use
# the GPU when one is visible and fall back to the CPU otherwise.
torch.backends.cuda.matmul.allow_tf32 = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
# MODEL ARCHITECTURE
N_HEAD = 4            # attention heads per block
N_LAYER = 4           # number of transformer blocks
N_EMBD = 256          # embedding / hidden size
VOCAB_SIZE = 50258    # GPT-Neo tokenizer (50257 ids) + the added [PAD] token
SEQ_LENGTH = 384      # context length in tokens

## DATA AND TRAINING PARAMETERS
BATCH_SIZE = 32
DATA_PCT = 0.7        # fraction of the training split to use
MAX_LR = 5e-5         # learning rate

## EPOCH LEVEL PARAMETERS
EPOCHS = 10
SAVE_EVERY = 1        # save a checkpoint every N epochs
GENERATE_EVERY = 1    # sample text from the model every N epochs

## STEP LEVEL PARAMETERS
COMPUTE_PER_EPOCH = 10  # approx. how many times training statistics are printed per epoch

# MODEL LOADING
#   resume from a checkpoint -> CHECKPOINT = True,  LOAD_EPOCH = <epoch to load>
#   train from scratch       -> CHECKPOINT = False, LOAD_EPOCH = None
CHECKPOINT = False
LOAD_EPOCH = None
START_EPOCH = 0 if LOAD_EPOCH is None else LOAD_EPOCH

PATH = "/kaggle/working/experiments"

MODEL_NAME = f"bt_{N_LAYER}_LAYERs_{int(DATA_PCT*100)}_DATA_PCT_{N_EMBD}_EMBD_DIM"


In [None]:
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every RNG.

    Returns:
        The seed that was applied. This lets callers forward it to APIs that take an
        explicit seed, e.g. ``Dataset.shuffle(seed=set_seed())``. Previously the
        function returned None, which made such a shuffle unseeded.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Only affects interpreters started after this point; the running process's
    # hash randomisation is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # every visible GPU, not just the current one
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    return seed


### Create shifted inputs (teacher forcing)

# With teacher forcing, the model is fed the ground-truth token as the input for the
# next timestep, regardless of what it predicted at the current one. Because every
# input is known in advance, the model does not have to be run sequentially: the
# predictions for all positions of a sequence are computed in parallel.
# shift_tokens_right / collate_wrapper below build those inputs by shifting the
# target tokens one position to the right and putting a start (BOS) token first.
set_seed()


def load_tokenizer(device, model_name="EleutherAI/gpt-neo-1.3B"):
    """Load a Hugging Face tokenizer and make sure it has PAD (and, if possible, BOS) ids.

    Args:
        device: not used by this function; kept so existing ``load_tokenizer(device)``
            calls keep working.
        model_name: hub id or local path passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        The tokenizer, with these guarantees:
        - If it had no pad token, a ``[PAD]`` special token is added. That grows
          ``len(tokenizer)`` by one.
        - If it has an EOS token but no BOS token, BOS is set to EOS.
          ``collate_wrapper`` needs a BOS id to build shifted inputs.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        print(f"Added [PAD] token with ID: {tokenizer.pad_token_id}")
    else:
        print(f"Using existing pad token {tokenizer.pad_token} with ID: {tokenizer.pad_token_id}")

    if tokenizer.bos_token is None and tokenizer.eos_token is not None:
        print("Setting BOS token to EOS token.")
        tokenizer.bos_token = tokenizer.eos_token

    if tokenizer.bos_token is None:
        print("Warning: Tokenizer does not have a BOS token. Generation might be affected.")
    if tokenizer.eos_token is None:
        print("Warning: Tokenizer does not have an EOS token. Generation might be affected.")

    return tokenizer


def shift_tokens_right(input_ids: torch.Tensor, decoder_start_token_id: int):
    """Build teacher-forcing inputs from a 2-D batch of target ids.

    Returns a new tensor with the same shape and dtype as ``input_ids``. Column 0 holds
    ``decoder_start_token_id``; columns 1.. hold ``input_ids[:, :-1]``, so the last
    target token is dropped. ``input_ids`` itself is not modified.

    Raises:
        ValueError: if ``decoder_start_token_id`` is None.
    """
    if decoder_start_token_id is None:
        raise ValueError("decoder_start_token_id cannot be None.")

    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 0] = decoder_start_token_id      # start token in the first slot
    shifted[:, 1:] = input_ids[:, :-1]          # everything else moves one step right
    return shifted

# Module-level tokenizer; collate_wrapper below reads this global at call time.
tokenizer = load_tokenizer(device)

def collate_wrapper(batch):
    """Pad a list of examples and build (model_inputs, targets) for teacher forcing.

    Uses the module-level ``tokenizer``. Each example is a mapping with an
    ``"input_ids"`` sequence.

    Steps:
    1. Pad every row to the longest sequence in the batch.
    2. Write EOS at the first padding position of each row. A row with no padding
       gets EOS over its last token instead.
    3. Shift the targets right with BOS prepended to get the model inputs.

    Returns:
        (model_inputs, targets), both of shape (batch_size, padded_seq_len).

    Raises:
        ValueError: if the tokenizer has no BOS token id.
    """
    id_tensors = [torch.tensor(example["input_ids"], dtype=torch.long) for example in batch]
    padded = tokenizer.pad(
        {"input_ids": id_tensors},
        padding=True,          # pad to the longest sequence in the batch
        return_tensors="pt",
    )
    targets = padded["input_ids"]  # (batch_size, padded_seq_len)

    # Count of non-pad tokens per row == index just past the last real token
    # (assumes right padding). Rows with no padding would point one past the end,
    # so clamp them onto the final column.
    last_col = targets.shape[-1] - 1
    eos_positions = (targets != tokenizer.pad_token_id).sum(dim=-1, keepdim=True)
    eos_positions = eos_positions.clamp(max=last_col)

    eos_id = tokenizer.eos_token_id
    if eos_id is None:
        print("Warning: EOS token ID is not set. Cannot inject EOS in collate_wrapper.")
    else:
        targets.scatter_(dim=-1, index=eos_positions, value=eos_id)

    bos_id = tokenizer.bos_token_id
    if bos_id is None:
        raise ValueError("BOS token ID is not set. Cannot create shifted inputs.")

    return shift_tokens_right(targets, bos_id), targets


In [None]:

# Reuse the shared helper instead of repeating its PAD-token setup here.
# For GPT-Neo, <|endoftext|> (ID 50256) exists and is used for EOS/BOS by default;
# the added "[PAD]" token gets ID 50257.
tokenizer = load_tokenizer(device)

# Fallbacks for tokenizers that do not define EOS/BOS.
if tokenizer.eos_token_id is None and "<|endoftext|>" in tokenizer.vocab:
    print("Setting EOS token ID manually from vocab for <|endoftext|>.")
    tokenizer.eos_token_id = tokenizer.vocab["<|endoftext|>"]

# collate_wrapper needs tokenizer.bos_token_id; GPT-like models commonly reuse EOS.
if tokenizer.bos_token_id is None and tokenizer.eos_token_id is not None:
    print("Setting BOS token ID to EOS token ID.")
    tokenizer.bos_token_id = tokenizer.eos_token_id


In [None]:

# Reset all RNGs (Python, NumPy, PyTorch) at the start of the model-definition cell.
set_seed()

# MODEL ARCHITECTURE DEFINITION


class MLP(nn.Module):
    """Position-wise feed-forward network used inside each Transformer block.

    Layout (kept as ``self.net`` so state_dict keys stay ``net.0.*`` / ``net.3.*``):
    Linear(n_embd -> 4*n_embd) -> GELU -> Dropout -> Linear(4*n_embd -> n_embd) -> Dropout.

    Args:
        n_embd: input/output feature size.
        dropout: dropout probability for both Dropout layers.
    """

    def __init__(self, n_embd, dropout=0.1):
        super().__init__()
        hidden_size = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden_size),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_size, n_embd),
            nn.Dropout(p=dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the feed-forward stack over the last dimension of ``x``."""
        out = self.net(x)
        return out


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The Q/K/V projections and the output projection are bias-free
    ``n_embd x n_embd`` linear layers. The embedding is split evenly across
    ``n_head`` heads, and dropout is applied to the attention weights.

    Args:
        n_embd: embedding dimension (must be divisible by ``n_head``).
        n_head: number of attention heads.
        seq_length: stored on the module; not used in the computation below.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, n_embd, n_head, seq_length, dropout=0.1):
        super().__init__()

        self.n_embd = n_embd
        self.n_head = n_head
        self.head_dim = n_embd // n_head  # size of each head's query/key/value
        assert self.head_dim * n_head == self.n_embd, "n_embd must be divisible by n_head"

        self.seq_length = seq_length
        self.drop = nn.Dropout(p=dropout)

        # Creation order (query, key, value, out) fixes both the seeded init and
        # the state_dict key order.
        self.query = nn.Linear(n_embd, n_embd, bias=False)
        self.key = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd, n_embd, bias=False)
        self.out = nn.Linear(n_embd, n_embd, bias=False)  # merges the heads back together

    def split_heads(self, x):
        """Reshape [B, S, n_embd] -> [B, n_head, S, head_dim]."""
        batch, seq, _ = x.size()
        per_head = x.view(batch, seq, self.n_head, self.head_dim)
        return per_head.transpose(1, 2)

    def combine_heads(self, x):
        """Inverse of ``split_heads``: [B, n_head, S, head_dim] -> [B, S, n_embd]."""
        batch, _, seq, _ = x.size()
        merged = x.transpose(1, 2).contiguous()
        return merged.view(batch, seq, self.n_embd)

    def scaled_dot_product(self, q, k, v, dropout, mask=None):
        """softmax(q k^T / sqrt(head_dim)) v for [B, n_head, S, head_dim] inputs.

        ``mask`` (broadcastable to [B, n_head, S, S]) is True where attention is
        blocked; those scores are set to -1e4 before the softmax. ``dropout`` is a
        callable applied to the attention weights.
        """
        scale = torch.sqrt(torch.tensor(self.head_dim, dtype=q.dtype))
        scores = (q @ k.transpose(-2, -1)) / scale  # [B, n_head, S, S]
        if mask is not None:
            scores = scores.masked_fill(mask, -1e4)
        weights = dropout(F.softmax(scores, dim=-1))
        return weights @ v

    def forward(self, x, mask=None):
        """x: (B, S, n_embd) -> (B, S, n_embd)."""
        q, k, v = (self.split_heads(proj(x)) for proj in (self.query, self.key, self.value))
        attended = self.scaled_dot_product(q, k, v, self.drop, mask)
        return self.out(self.combine_heads(attended))


class Block(nn.Module):
    """Pre-LayerNorm Transformer block: self-attention then MLP, each with a residual.

    ``x = x + drop(sa(ln1(x), mask))`` followed by ``x = x + drop(mlp(ln2(x)))``.

    Args:
        n_embd: embedding dimension.
        n_head: number of attention heads.
        seq_length: forwarded to ``MultiHeadAttention``.
        dropout: dropout probability for the sub-layers and the residual branches.
        check_nan: if True, print a message when the residual stream contains NaNs
            after each sub-layer. It is off by default. Evaluating
            ``torch.isnan(x).any()`` inside an ``if`` needs the value on the host,
            which on a GPU forces a device synchronisation on every forward pass.
    """

    # Class-level default so instances unpickled from older versions (which lack
    # the attribute) still work.
    check_nan = False

    def __init__(self, n_embd, n_head, seq_length, dropout=0.1, check_nan=False):
        super().__init__()
        self.sa = MultiHeadAttention(n_embd, n_head, seq_length, dropout)
        self.mlp = MLP(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

        # NOTE(review): MLP already ends in its own Dropout, so the MLP branch is
        # dropped out twice (inside MLP and again here). Kept to preserve training
        # behaviour; confirm whether that is intended.
        self.drop = nn.Dropout(p=dropout)
        self.check_nan = check_nan

    def forward(self, x, mask):
        """x: (B, S, n_embd) -> (B, S, n_embd); ``mask`` is True where attention is blocked."""
        x = x + self.drop(self.sa(self.ln1(x), mask))
        if self.check_nan and torch.isnan(x).any():
            print("NaN after SA!")
        x = x + self.drop(self.mlp(self.ln2(x)))
        if self.check_nan and torch.isnan(x).any():
            print("NaN after MLP!")
        return x


class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding added to the input embeddings.

    Formula taken from the original Transformer paper:
    PE(pos, 2i (even)) = sin(pos/(10000^{2i/d_model}))
    PE(pos, 2i+1 (odd)) = cos(pos/(10000^{2i/d_model}))

    ``forward`` returns ``x + PE`` (it does not return the encoding on its own).
    Odd ``d_model`` is supported: the final even column has no cosine partner.

    See reference for more details:
    https://kikaben.com/transformers-positional-encoding/
    """

    def __init__(self, d_model, max_len):
        super().__init__()

        position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
        # One frequency per even column: [ceil(d_model / 2)]
        divisor = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * divisor)
        # There are floor(d_model / 2) odd columns. For odd d_model that is one fewer
        # than the number of frequencies, so trim to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(position * divisor[: d_model // 2])
        # Stored as a buffer so it follows the module across devices and is saved in
        # the state_dict. Shape: [1, max_len, d_model]; the leading 1 broadcasts over the batch.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Add the encoding for the first ``x.size(1)`` positions to ``x`` of shape [B, S, d_model]."""
        pos_enc = self.pe[:, : x.size(1)].to(x.device)  # [1, S, d_model]
        return x + pos_enc


class BetterTransformer(nn.Module):
    """Decoder-only (GPT-style) Transformer language model.

    Token embeddings plus sinusoidal positional encodings feed a stack of
    ``Block``s; ``lm_head`` projects the result to vocabulary logits.

    Args:
        vocab_size: rows of the token embedding / outputs of ``lm_head``.
        seq_length: maximum context length (also the positional-encoding table size).
        n_embd: embedding dimension.
        n_head: attention heads per block.
        n_layer: number of blocks.
        pad_idx: padding token id. It is masked in attention and ignored by the loss.
        eos_idx: token id that ends a sequence in ``generate``. If None, generation
            always runs for ``max_new_tokens``.
        device: device the module is moved to and masks are created on.
        dropout: dropout probability used throughout the model.
    """

    def __init__(
        self,
        vocab_size,
        seq_length,
        n_embd,
        n_head,
        n_layer,
        pad_idx,
        eos_idx,
        device,
        dropout=0.1,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_layer = n_layer
        self.dropout_rate = dropout

        # Submodule creation order is kept unchanged so seeded initialisation and
        # state_dict keys match earlier checkpoints.
        self.token_embedding = nn.Embedding(vocab_size, n_embd, padding_idx=pad_idx)
        self.position_embedding = PositionalEncoding(n_embd, seq_length)
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head, seq_length, dropout) for _ in range(n_layer)]
        )
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.drop = nn.Dropout(dropout)
        self.seq_length = seq_length
        self.pad_idx = pad_idx
        self.eos_idx = eos_idx
        self.device = device
        self.to(device)
        self.init_params()

    def init_params(self, default_initialization=False):
        """Xavier-uniform initialise every parameter with more than one dimension.

        Pass ``default_initialization=True`` to keep PyTorch's default init instead.
        """
        if default_initialization:
            return
        for _, p in self.named_parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # xavier_uniform_ also overwrote the padding row, which
        # nn.Embedding(padding_idx=...) initialises to zeros (and never updates).
        if self.pad_idx is not None:
            with torch.no_grad():
                self.token_embedding.weight[self.pad_idx].zero_()

    def get_causal_mask(self, seq_len):
        """(seq_len, seq_len) bool mask, True above the diagonal (future positions)."""
        causal_mask_sq = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
        return causal_mask_sq.to(self.device)

    def get_pad_mask(self, x, pad_idx):
        """(B, 1, 1, S) bool mask, True where ``x`` equals ``pad_idx``."""
        return (x == pad_idx).unsqueeze(1).unsqueeze(-2).to(self.device)

    def forward(self, x, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            x: (B, S) token ids.
            targets: optional (B, S) ids, where targets[:, i] is the token to predict
                at position i. Positions whose target is ``pad_idx`` are ignored.

        Returns:
            (logits, loss): logits of shape (B, S, vocab_size); loss is a scalar,
            or None when ``targets`` is None.
        """
        x = x.to(torch.long)
        _, seq_len = x.shape

        # Attention mask, True where attention is blocked. It is the OR of:
        #   keys that are padding       (B, 1, 1, S)
        #   queries that are padding    (B, 1, S, 1)
        #   future positions (causal)   (1, 1, S, S)
        # and broadcasts to (B, 1, S, S), shared by every head.
        is_pad = x == self.pad_idx
        key_padding_mask = is_pad[:, None, None, :]
        query_padding_mask = is_pad[:, None, :, None]
        causal_mask = self.get_causal_mask(seq_len)[None, None]
        attention_mask = key_padding_mask | query_padding_mask | causal_mask

        tok_emb = self.token_embedding(x)  # (B, S, n_embd)
        # PositionalEncoding already returns tok_emb + PE. Adding tok_emb again here
        # (as before) doubled the token embedding.
        # NOTE: checkpoints trained with the doubled embedding will behave differently.
        h = self.drop(self.position_embedding(tok_emb))

        for block in self.blocks:
            h = block(h, attention_mask)

        logits = self.lm_head(h)  # (B, S, vocab_size)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, self.vocab_size),
                targets.reshape(-1),
                ignore_index=self.pad_idx,  # padded targets contribute no loss
                reduction="mean",
            )
        return logits, loss

    def _sample_next_token(self, logits_next_token, method, temp, p_nucleus, k):
        """Pick one token id per row from last-position logits; returns (B, 1) int64.

        ``temp`` is only applied for ``method == "temperature"``.
        """
        if method == "temperature" and temp is not None:
            if temp <= 0:
                raise ValueError("Temperature must be positive for sampling")
            logits_next_token = logits_next_token / temp

        probs = F.softmax(logits_next_token, dim=-1)  # (B, vocab_size)

        if method == "greedy":
            return probs.argmax(dim=-1, keepdim=True)

        if method in ("multinomial", "temperature"):
            return torch.multinomial(probs, num_samples=1)

        if method == "nucleus":  # top-p
            assert p_nucleus is not None and (0 < p_nucleus <= 1), "Nucleus sampling requires 0 < p_nucleus <= 1"
            sorted_probs, sorted_indices = probs.sort(dim=-1, descending=True)
            cumulative_probs = sorted_probs.cumsum(dim=-1)
            # Drop tokens once the cumulative mass exceeds p. Shifting right by one
            # keeps the token that crosses the threshold, and the top token is always kept.
            remove_sorted = cumulative_probs > p_nucleus
            remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
            remove_sorted[..., 0] = False
            sorted_probs = sorted_probs.masked_fill(remove_sorted, 0.0)
            # Put the filtered probabilities back in vocabulary order. Previously the
            # scatter source was all zeros, so sampling became uniform over the vocab.
            probs = torch.zeros_like(probs).scatter(-1, sorted_indices, sorted_probs)
            # multinomial normalises each row; the kept top token guarantees mass > 0.
            return torch.multinomial(probs, num_samples=1)

        if method == "top-k":
            assert k is not None and k > 0, "Top-k sampling requires k > 0"
            k = min(k, probs.size(-1))  # cannot keep more tokens than the vocabulary has
            top_k_probs, _ = torch.topk(probs, k=k, dim=-1)
            min_k_prob = top_k_probs[..., -1, None]  # (B, 1)
            probs = probs.masked_fill(probs < min_k_prob, 0.0)
            return torch.multinomial(probs, num_samples=1)

        raise ValueError(f"Unknown sampling method: {method}")

    def generate(
        self,
        input_ids,
        method="multinomial",
        max_new_tokens=1000,
        temp=1.0,
        p_nucleus=0.9,
        k=50,
    ):
        """Autoregressively extend ``input_ids`` of shape (B, S0).

        Args:
            input_ids: prompt token ids.
            method: "greedy", "multinomial", "temperature", "nucleus" or "top-k".
            max_new_tokens: maximum number of tokens appended.
            temp: temperature (only used by method="temperature").
            p_nucleus: cumulative-probability threshold for "nucleus".
            k: number of candidates for "top-k".

        Returns:
            (B, S0 + n) ids. Once a row emits ``eos_idx`` it repeats its last token
            until every row has finished or ``max_new_tokens`` is reached. Only the
            last ``seq_length`` tokens are fed to the model at each step.

        The module's train/eval mode is restored on exit.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                input_ids = input_ids.to(self.device)
                finished = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=self.device)

                for _ in range(max_new_tokens):
                    if finished.all():
                        break

                    # The positional table only covers seq_length positions.
                    logits, _ = self(input_ids[:, -self.seq_length:])
                    next_token_ids = self._sample_next_token(
                        logits[:, -1, :], method, temp, p_nucleus, k
                    )

                    # Finished rows repeat their last token so the batch stays rectangular.
                    next_token_ids = torch.where(
                        finished.unsqueeze(-1), input_ids[:, -1:], next_token_ids
                    )
                    input_ids = torch.cat((input_ids, next_token_ids), dim=-1)

                    if self.eos_idx is not None:
                        finished = finished | (next_token_ids.squeeze(-1) == self.eos_idx)
        finally:
            self.train(was_training)

        return input_ids

In [None]:

set_seed()

# DATA PREPARATION

data = load_dataset("roneneldan/TinyStories")

## Tokenizer
# load_tokenizer already adds a [PAD] token when the tokenizer has none, so a
# pad token always exists at this point and needs no second check.
tokenizer = load_tokenizer(device)
pad_token_id = tokenizer.pad_token_id

# vocab_size excludes added special tokens such as [PAD]; len(tokenizer) includes them.
print(f"Tokenizer loaded. Vocabulary size: {tokenizer.vocab_size}")
print(f"Tokenizer size including added tokens: {len(tokenizer)}")

if tokenizer.bos_token_id is None:
    print("Warning: Tokenizer does not have a BOS token ID. Model might need one prepended manually or add special token.")

if tokenizer.eos_token_id is None:
    print("Warning: Tokenizer does not have an EOS token ID. Model will not learn to end sequences unless handled.")

def preprocess(examples, tokenizer=tokenizer, max_length=SEQ_LENGTH):
    """
    Tokenize a batch of texts into fixed-length id lists that end in EOS and are right-padded.
    Intended for ``datasets.map(batched=True)``: ``examples["text"]`` is a list of strings.

    For each text:
    - Tokenize with truncation at ``max_length``.
    - If the result fills the window, keep only the first ``max_length - 1`` ids.
    - Append the EOS id, or the PAD id if the tokenizer has no EOS.
    - Right-pad with the PAD id to exactly ``max_length``.

    Returns:
        {"input_ids": [list of max_length ints, ...]}
    """
    encoded = tokenizer(
        examples["text"],
        padding=False,           # padding is done manually below
        truncation=True,
        max_length=max_length,
    )
    pad_id = tokenizer.pad_token_id

    def _finish(ids):
        # Leave room for the terminator when the tokenizer filled the whole window.
        seq = ids[:max_length - 1] if len(ids) >= max_length else ids

        if tokenizer.eos_token_id is not None:
            seq.append(tokenizer.eos_token_id)
        else:
            print("Error: EOS token ID is None when trying to add it.")
            seq.append(tokenizer.pad_token_id)  # fallback terminator

        missing = max_length - len(seq)
        if missing < 0:
            print(f"Error: Padding length is negative ({missing}) after processing. Seq len: {len(seq)}")
            missing = 0
        seq.extend([pad_id] * missing)

        # Defensive check; should be unreachable given the truncation above.
        if len(seq) != max_length:
            print(f"Error: Final processed sequence length {len(seq)} != max_length {max_length}")
            seq = seq[:max_length]
        return seq

    return {"input_ids": [_finish(ids) for ids in encoded["input_ids"]]}


# Tokenize every split in parallel; dropping the raw "text" column saves memory/disk.
print(f"Tokenizing dataset and processing sequences (max_length={SEQ_LENGTH})...")
map_options = dict(batched=True, num_proc=os.cpu_count(), remove_columns=["text"])
data_tokenized = data.map(preprocess, **map_options)
print("Tokenization and processing complete.")

# NOTE(review): this path is relative (no leading "/"), so it resolves against the
# current working directory, unlike PATH = "/kaggle/working/experiments". The "384"
# is also hard-coded rather than derived from SEQ_LENGTH. The loading cell reads the
# same string, so any change must be made in both places.
output_dir = "kaggle/working/tokenized_data_384.hf"
print(f"Saving tokenized dataset to {output_dir}...")
data_tokenized.save_to_disk(output_dir)
print("Tokenization complete and saved!")


In [None]:
# Sanity-check the tokenized DatasetDict (these run after it has been saved).
print(f"Type of data_tokenized before saving: {type(data_tokenized)}")
print(f"Structure of data_tokenized before saving: {data_tokenized}")
for split_name in ("train", "validation"):
    print(f"Number of examples in {split_name} split: {len(data_tokenized[split_name])}")

In [None]:
## Setup
# Reset all RNGs (Python, NumPy, PyTorch) before building the data pipeline.
set_seed()

def collate_wrapper_processed_data(batch, tokenizer):
    """
    DataLoader collate function for data that is already preprocessed to fixed length
    (tokenized, EOS appended, right-padded to SEQ_LENGTH).

    The per-example ``input_ids`` are stacked into the targets. The teacher-forcing
    inputs are those targets shifted one position right. Column 0 holds a start token:
    BOS, or PAD if the tokenizer has no BOS.

    Args:
        batch (list): samples from the processed dataset. Each is a dict whose
            ``"input_ids"`` is a 1-D tensor (requires ``set_format("torch")``).
        tokenizer: supplies ``bos_token_id`` and, as a fallback, ``pad_token_id``.

    Returns:
        tuple: (model_inputs, targets), both int64 tensors of shape (batch_size, seq_len).

    Raises:
        ValueError: if neither a BOS nor a PAD token id is available.
    """
    try:
        rows = [sample["input_ids"].clone().detach().long() for sample in batch]
        targets = torch.stack(rows)
    except Exception as e:
        print(f"Error stacking batch: {e}")
        print(f"Sample batch item input_ids length: {[len(i['input_ids']) for i in batch[: min(5, len(batch))]]}")
        raise e

    start_id = tokenizer.bos_token_id
    if start_id is None:
        # Fall back to PAD as the start token when the tokenizer lacks BOS.
        print("Warning: BOS token ID is None. Using PAD token ID as decoder_start_token_id.")
        start_id = tokenizer.pad_token_id
        if start_id is None:
            raise ValueError("Neither BOS nor PAD token ID is set. Cannot create shifted inputs.")

    # Right shift (same result as shift_tokens_right): start token first, then
    # targets[:, :-1].
    model_inputs = targets.new_zeros(targets.shape)
    model_inputs[:, 1:] = targets[:, :-1]
    model_inputs[:, 0] = start_id
    return model_inputs, targets

tokenized_data_path = "kaggle/working/tokenized_data_384.hf"

try:
    # Reads the DatasetDict written by save_to_disk() in the tokenization cell.
    data = load_from_disk(tokenized_data_path)
except Exception as e:
    print(f"Error loading dataset: {e}")
    # Report the path that was actually tried.
    print(f"Attempted to load from {tokenized_data_path}")
    print("Please ensure the path is correct and points to a dataset saved with save_to_disk().")
    raise  # bare raise keeps the original traceback


print("Setting dataset format to torch...")

data.set_format("torch")

print("Shuffling and selecting data subsets...")
# Pass an explicit integer seed: Dataset.shuffle(seed=None) draws a fresh random
# seed, so the train/val subsets below would differ between runs.
SHUFFLE_SEED = 42
set_seed(SHUFFLE_SEED)
data = data.shuffle(seed=SHUFFLE_SEED)

# Use DATA_PCT of the training split and a fixed-size slice of the validation split.
train_split_name = 'train'
val_split_name = 'validation'
VAL_SAMPLES = 500  # validation examples kept (capped at the split size)

if DATA_PCT < 1.0:
    # Calculate the number of samples to select
    num_train_samples = int(DATA_PCT * len(data[train_split_name]))
    print(f"Selecting {num_train_samples} samples ({DATA_PCT*100}%) for training...")
    train_data = data[train_split_name].select(range(num_train_samples))
else:
    print(f"Using full training data ({len(data[train_split_name])} samples)...")
    train_data = data[train_split_name]

num_val_samples = min(VAL_SAMPLES, len(data[val_split_name]))
val_data = data[val_split_name].select(range(num_val_samples))

print(f"Final Train size: {len(train_data)}, Val size: {len(val_data)}")

def get_dataloaders(train_data, val_data, tokenizer, batch_size=BATCH_SIZE, num_workers=4):
    """
    Build the training and validation DataLoaders.

    Args:
        train_data: dataset for training; shuffled every epoch.
        val_data: dataset for validation; iterated in order.
        tokenizer: forwarded to ``collate_wrapper_processed_data`` (BOS/PAD ids).
        batch_size: examples per batch for both loaders.
        num_workers: loader worker processes, used by BOTH loaders. The validation
            loader previously ignored this and always used 4.

    Returns:
        (train_dataloader, val_dataloader). Each yields (model_inputs, targets).
    """
    def collate_fn(batch):
        return collate_wrapper_processed_data(batch, tokenizer)

    # Pinned host memory only helps (and only avoids a warning) when a GPU is present.
    pin_memory = torch.cuda.is_available()

    train_dataloader = DataLoader(
        train_data,
        shuffle=True,
        batch_size=batch_size,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    val_dataloader = DataLoader(
        val_data,
        shuffle=False,
        batch_size=batch_size,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return train_dataloader, val_dataloader



def prep_train(tokenizer, vocab_size=VOCAB_SIZE,
               seq_length=SEQ_LENGTH,
               n_embd=N_EMBD, n_head=N_HEAD, n_layer=N_LAYER, max_lr=MAX_LR,
               device=device):
    """
    Build a freshly initialized BetterTransformer and its AdamW optimizer.

    The embedding size is derived from the tokenizer as one more than the largest
    id among the last base-vocabulary id and the PAD/EOS/BOS ids. This means a PAD
    token added outside the base vocabulary still gets an embedding row.

    Args:
        tokenizer: Tokenizer; must define ``pad_token_id``.
        vocab_size (int): Expected vocabulary size. Used only as a sanity check:
            a warning is printed when it differs from the tokenizer-derived size,
            which is what the model is actually built with.
        seq_length (int): Context length passed to the model.
        n_embd (int): Embedding dimension.
        n_head (int): Number of attention heads.
        n_layer (int): Number of transformer layers.
        max_lr (float): AdamW learning rate.
        device (torch.device): Device the model is moved to.

    Returns:
        tuple: (model, optimizer).

    Raises:
        ValueError: If the tokenizer has no PAD token id.
    """
    print("\n--- Initializing Model ---")

    pad_idx = tokenizer.pad_token_id
    eos_idx = tokenizer.eos_token_id  # may be None
    bos_idx = tokenizer.bos_token_id  # may be None

    if pad_idx is None:
        raise ValueError("PAD token ID is required but not set on the tokenizer.")

    # Largest id the embedding must cover: the last base-vocab id or any special id beyond it.
    known_ids = [tokenizer.vocab_size - 1, pad_idx]
    known_ids.extend(idx for idx in (eos_idx, bos_idx) if idx is not None)
    max_known_id = max(known_ids)

    # The model's vocabulary size must be 1 greater than the maximum token id used.
    model_vocab_size = max_known_id + 1
    print(f"Calculated model_vocab_size based on max identified token ID ({max_known_id}): {model_vocab_size}")
    if vocab_size is not None and vocab_size != model_vocab_size:
        print(f"Warning: requested vocab_size={vocab_size} differs from the tokenizer-derived "
              f"model_vocab_size={model_vocab_size}; using {model_vocab_size}.")

    # Without an EOS id, generation only stops at max_new_tokens.
    if eos_idx is None:
        print("Warning: Tokenizer does not have a defined EOS token ID. Model generation stopping will rely on max_new_tokens.")

    model = BetterTransformer(
        vocab_size=model_vocab_size,
        seq_length=seq_length,
        n_embd=n_embd,
        n_head=n_head,
        n_layer=n_layer,
        pad_idx=pad_idx,
        eos_idx=eos_idx,
        device=device,
        dropout=0.1,
    )
    model.to(device)

    num_params = sum(p.numel() for p in model.parameters())
    print(f"Model initialized with {num_params:,} parameters.")

    optimizer = optim.AdamW(model.parameters(), lr=max_lr)

    return model, optimizer


def _load_loss_history(checkpoint, key, json_path, label):
    """
    Return the loss list in ``checkpoint[key]``, else the list in ``json_path``, else [].

    ``train`` saves its loss histories inside the checkpoint dict, so that source
    is preferred. The JSON file is only a fallback for checkpoints without the key.
    """
    losses = checkpoint.get(key)
    if losses is not None:
        print(f"{label} losses loaded from checkpoint.")
        return list(losses)
    if os.path.exists(json_path):
        with open(json_path, "r") as f:
            losses = json.load(f)
        print(f"{label} losses loaded.")
        return losses
    return []


# Load the model from checkpoint
def load_checkpoint(model, optimizer, scheduler=None, path=PATH, model_name=MODEL_NAME, load_epoch=LOAD_EPOCH):
    """
    Load model, optimizer and (optionally) scheduler states from a checkpoint file.

    The loss history is restored from the checkpoint itself (keys 'train_losses'
    and 'val_losses'). When those keys are absent it falls back to
    ``{path}/train_logs/{model_name}_{train,val}_losses.json``, and to empty lists
    when neither exists.

    Args:
        model (nn.Module): The model object to load state into.
        optimizer (optim.Optimizer): The optimizer object to load state into.
        scheduler (Optional[object]): The scheduler object to load state into (e.g., StepLR).
                                     Pass None if no scheduler is used/saved.
        path (str): Base path to the checkpoint directory.
        model_name (str): Base name of the model file.
        load_epoch (int): The epoch number of the checkpoint to load.

    Returns:
        tuple: (model, optimizer, scheduler, train_losses, val_losses, loaded_epoch).
               loaded_epoch is the epoch number saved in the checkpoint file.

    Raises:
        ValueError: If ``load_epoch`` is None.
        FileNotFoundError: If the checkpoint file does not exist.
        Exception: Any other loading error (e.g. a state-dict mismatch) is re-raised.
    """
    if load_epoch is None:
        raise ValueError("LOAD_EPOCH must be specified when calling load_checkpoint.")

    checkpoint_file = f"{path}/model/{model_name}_epoch_{load_epoch}.pt"
    print(f"Loading checkpoint from {checkpoint_file}...")

    try:
        checkpoint = torch.load(checkpoint_file, map_location=device)

        model.load_state_dict(checkpoint["model_state_dict"])
        print("Model state loaded.")

        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        print("Optimizer state loaded.")

        if scheduler is not None:
            # `train` stores None under this key when it ran without a scheduler.
            scheduler_state = checkpoint.get("scheduler_state_dict")
            if scheduler_state is None:
                print("Warning: Scheduler object provided, but no scheduler state found in checkpoint.")
            else:
                try:
                    scheduler.load_state_dict(scheduler_state)
                    print("Scheduler state loaded.")
                except Exception as e:
                    print(f"Warning: Could not load scheduler state dict: {e}")

        # NOTE: the GradScaler state that `train` also saves is not restored here;
        # `train` creates a fresh scaler.

        loaded_epoch = checkpoint.get('epoch', load_epoch)
        print(f"Checkpoint epoch: {loaded_epoch}")

        train_losses = _load_loss_history(
            checkpoint, "train_losses", f"{path}/train_logs/{model_name}_train_losses.json", "Train"
        )
        val_losses = _load_loss_history(
            checkpoint, "val_losses", f"{path}/train_logs/{model_name}_val_losses.json", "Val"
        )

        print("Checkpoint loaded successfully.")
        return model, optimizer, scheduler, train_losses, val_losses, loaded_epoch

    except FileNotFoundError:
        print(f"Error: Checkpoint file not found at {checkpoint_file}")
        raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_file}")
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        raise


def evaluate(model, val_dataloader, device, val_losses, epoch, tokenizer, model_pad_idx, metrics_csv_path="metrics_log.csv"):
    """
    Evaluate the model on the validation set and log the results.

    Computes the mean per-batch validation loss and appends it in place to
    ``val_losses``. Then computes BLEU/ROUGE/METEOR/perplexity via
    ``compute_metrics``, prints a one-line summary, and appends a row to
    ``metrics_csv_path``. A header row is written when the file is missing or empty.

    Args:
        model (nn.Module): Model whose ``model(inputs, targets)`` returns ``(logits, loss)``.
        val_dataloader (DataLoader): Yields batches with inputs at ``batch[0]``
            and targets at ``batch[1]``.
        device (torch.device): Device batches are moved to.
        val_losses (list): Epoch-level validation loss history; mutated in place.
        epoch (int): Zero-based epoch index (shown/logged as ``epoch + 1``).
        tokenizer: Tokenizer forwarded to ``compute_metrics``.
        model_pad_idx (int): Padding token id forwarded to ``compute_metrics``
            (used there to count non-padding tokens for perplexity).
        metrics_csv_path (str): CSV file the metrics row is appended to.

    Returns:
        dict: ``val_loss`` and ``perplexity`` plus the raw ``bleu``, ``rouge`` and
        ``meteor`` result objects returned by ``compute_metrics``.
    """
    print(f"Evaluating Epoch {epoch+1}...")
    start_time = time.perf_counter()

    model.eval()

    # --- Validation loss (one value per batch, as returned by the model) ---
    batch_losses = []
    with torch.no_grad():
        for batch in val_dataloader:
            dec_input = batch[0].to(device)  # model inputs
            targets = batch[1].to(device)    # labels for the loss
            _, loss = model(dec_input, targets)
            batch_losses.append(loss.item())

    # Mean of the per-batch losses; NaN (instead of ZeroDivisionError) for an empty loader.
    avg_val_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float("nan")
    val_losses.append(avg_val_loss)

    # --- Generation metrics ---
    bleu_results, rouge_results, meteor_results, perplexity_score = compute_metrics(
        model, tokenizer, val_dataloader, device, model_pad_idx=model_pad_idx
    )

    def _metric(results, key):
        # Missing results or keys (e.g. metrics failed to load) are reported as 0.0.
        return results.get(key, 0.0) if results else 0.0

    bleu_score = _metric(bleu_results, 'bleu')
    rouge1 = _metric(rouge_results, 'rouge1')
    rouge2 = _metric(rouge_results, 'rouge2')
    rougeL = _metric(rouge_results, 'rougeL')
    rougeLsum = _metric(rouge_results, 'rougeLsum')
    meteor_score = _metric(meteor_results, 'meteor')

    # --- Print/Log Evaluation Results ---
    eval_time = time.perf_counter() - start_time

    print(f"Epoch: {epoch+1}/{EPOCHS} | Full Val Loss: {avg_val_loss:.5f} | Perplexity: {perplexity_score:.2f} |"
          f" BLEU: {bleu_score:.4f} | ROUGE1: {rouge1:.4f} | ROUGE2: {rouge2:.4f} | ROUGE_L: {rougeL:.4f} | METEOR: {meteor_score:.4f} |"
          f" Evaluation Time: {eval_time:.3f}s")

    # --- Save to CSV (column order = insertion order of this dict) ---
    row = {
        "epoch": epoch + 1,
        "val_loss": avg_val_loss,
        "perplexity": perplexity_score,
        "bleu": bleu_score,
        "rouge1": rouge1,
        "rouge2": rouge2,
        "rougeL": rougeL,
        "rougeLsum": rougeLsum,
        "meteor": meteor_score,
        "eval_time_sec": eval_time,
    }

    # Write the header only when starting a new (or empty) file.
    file_exists = os.path.exists(metrics_csv_path) and os.path.getsize(metrics_csv_path) > 0

    with open(metrics_csv_path, mode="a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=list(row.keys()))
        if not file_exists:
            writer.writeheader()
        writer.writerow(row)

    return {
        "val_loss": avg_val_loss,
        "perplexity": perplexity_score,
        "bleu": bleu_results,
        "rouge": rouge_results,
        "meteor": meteor_results,
    }

    

def generate_train(
    model, tokenizer, generation_file_path, cond_prompts, epoch, device, num_uncond_samples
):
    """
    Generate unconditional and prompt-conditioned samples, print them, and append them to a file.

    Unconditional samples start from a single BOS token and use several decoding
    strategies. Conditional samples continue each prompt in ``cond_prompts`` with
    top-k (k=5) sampling for up to 250 new tokens. The model is put in eval mode
    for generation and back in train mode afterwards.

    Args:
        model (nn.Module): Model exposing ``generate(input_ids, method=..., max_new_tokens=..., ...)``.
        tokenizer: Tokenizer used to encode prompts and decode samples.
        generation_file_path (str): File the formatted output is appended to.
        cond_prompts (list[str] | None): Prompts for conditional generation;
            conditional generation is skipped when None or empty.
        epoch (int): Zero-based epoch index (logged as ``epoch + 1``).
        device (torch.device): Device for the input tensors.
        num_uncond_samples (int): Samples generated per unconditional strategy.
    """
    print(f"\nGenerating text samples for Epoch {epoch+1}...")
    start_time = time.perf_counter()

    model.eval()

    # The header (model name + timing) is prepended at the end, once the duration is known.
    body = "\nUNCONDITIONAL GENERATION:\n\n"

    # --- Unconditional Generation ---
    bos_id = tokenizer.bos_token_id
    if bos_id is None:
        print("Warning: BOS token ID is None. Cannot perform standard unconditional generation.")
        body += "Skipped unconditional generation: BOS token ID is None.\n"
    else:
        # Built only after the None check: torch.full cannot use None as a fill value.
        uncond_start_tokens = torch.full(
            (num_uncond_samples, 1), bos_id, dtype=torch.long, device=device
        )
        # (section label, generate() kwargs), in sampling and output order.
        uncond_configs = [
            ("Top-k (5) (150 max_tokens)", dict(method="top-k", k=5, max_new_tokens=150)),
            ("Greedy (150 max_tokens)", dict(method="greedy", max_new_tokens=150)),
            ("Nucleus (0.9) (150 max_tokens)", dict(method="nucleus", p_nucleus=0.9, max_new_tokens=150)),
            ("Multinomial (temp=1) (150 max_tokens)", dict(method="multinomial", max_new_tokens=150, temp=1.0)),
            ("Top-k (5) (250 max_tokens)", dict(method="top-k", k=5, max_new_tokens=250)),
            ("Nucleus (0.9) (250 max_tokens)", dict(method="nucleus", p_nucleus=0.9, max_new_tokens=250)),
        ]
        with torch.no_grad():
            uncond_outputs = [
                (label, model.generate(uncond_start_tokens, **gen_kwargs))
                for label, gen_kwargs in uncond_configs
            ]
        for label, samples in uncond_outputs:
            decoded = tokenizer.batch_decode(samples, skip_special_tokens=True)
            body += f"{label}:\n" + "\n---\n".join(decoded) + "\n\n"

    # --- Conditional Generation ---
    body += "\n#####################################################\n"
    body += "CONDITIONAL GENERATION (Top-k (5), 250 max_tokens):\n"
    body += "-----------------------------------------------------\n"

    if not cond_prompts:
        body += "Skipped conditional generation: no prompts provided.\n"
    else:
        if tokenizer.eos_token_id is None:
            print("Warning: EOS token ID is None. Generated conditional text might not stop correctly.")

        formatted_cond_results = []
        with torch.no_grad():
            for prompt in cond_prompts:
                # Each prompt is encoded on its own, WITHOUT padding. Padding every
                # prompt to SEQ_LENGTH put pad tokens into the conditioning context
                # and made len(tokenizer.encode(prompt)) the wrong split point between
                # prompt and continuation.
                prompt_ids = tokenizer(
                    prompt,
                    truncation=True,
                    max_length=SEQ_LENGTH,
                    return_tensors="pt",
                ).input_ids.to(device)

                generated = model.generate(
                    prompt_ids,
                    method="top-k",
                    k=5,
                    max_new_tokens=250,
                    p_nucleus=0.9,
                    temp=1.0,
                )
                # Assumes generate() returns prompt + continuation; keep only the new tokens.
                continuation_ids = generated[0, prompt_ids.shape[1]:].tolist()
                continuation = tokenizer.decode(continuation_ids, skip_special_tokens=True)
                formatted_cond_results.append(f"{prompt} || {continuation}")

        body += "\n\n".join(formatted_cond_results)

    # --- Final Output ---
    generation_duration = time.perf_counter() - start_time
    generation_text = (
        f"{MODEL_NAME} Output @Epoch {epoch+1}\n"
        f"Generation Time @Epoch {epoch+1}: {generation_duration:.3f}s\n"
        + body
    )

    with open(generation_file_path, "a") as file:
        file.write(generation_text + "\n\n")

    print(generation_text)

    model.train()  # Set model back to train mode after generation


In [None]:
def train(
    model: nn.Module,
    tokenizer: AutoTokenizer,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    device: torch.device,
    optimizer: optim.Optimizer,
    scheduler: Optional[object],
    train_loss_list: Optional[list] = None,
    val_loss_list: Optional[list] = None,
    START_EPOCH: int = 0,
    EPOCHS: int = 10,
    SAVE_EVERY: int = 1,
    GENERATE_EVERY: int = 1,
    PATH: str = ".",
    MODEL_NAME: str = "default_model",
    COMPUTE_PER_EPOCH: int = 10,
    num_uncond_samples: int = 4,
    cond_prompts: Optional[list[str]] = None,
):
    """
    Main training loop for the BetterTransformer model.

    Each epoch runs these steps in order:
      1. ``evaluate``. This runs BEFORE the epoch's training, so metrics labelled
         epoch k describe the model after k-1 epochs.
      2. Train with 4-batch gradient accumulation and AMP on CUDA.
      3. Step the scheduler once.
      4. Optionally save a checkpoint and generate text samples.

    Args:
        model: The model to train.
        tokenizer: The tokenizer.
        train_dataloader: DataLoader for training data.
        val_dataloader: DataLoader for validation data.
        device: Device (cuda/cpu).
        optimizer: Optimizer.
        scheduler (Optional): Epoch-level LR scheduler (e.g. StepLR), stepped once
            per epoch with no arguments. Pass None for a fixed LR.
        train_loss_list (Optional): Existing list of training losses for resuming.
        val_loss_list (Optional): Existing list of validation losses for resuming.
        START_EPOCH (int): Epoch to start from (for resuming).
        EPOCHS (int): Total number of epochs to train for.
        SAVE_EVERY (int): Save a checkpoint every this many epochs.
        GENERATE_EVERY (int): Run text generation every this many epochs.
        PATH (str): Base path for saving checkpoints and logs.
        MODEL_NAME (str): Base name for saved files.
        COMPUTE_PER_EPOCH (int): Approx number of times to log training stats per epoch.
        num_uncond_samples (int): Number of samples for unconditional generation logging.
        cond_prompts (Optional[list[str]]): List of strings for conditional generation logging.

    Returns:
        tuple: (train_losses, val_losses) -- the (possibly resumed) loss histories.
    """
    # CUDA kernel settings: flash SDPA on, math SDPA fallback off, TF32 matmuls allowed.
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_math_sdp(False)
    torch.backends.cuda.matmul.allow_tf32 = True
    use_amp = (device.type == "cuda")
    print(f"Training on {device}  →  AMP enabled? {use_amp}")

    # Ensure output directories exist
    os.makedirs(f"{PATH}/model", exist_ok=True)
    os.makedirs(f"{PATH}/train_logs", exist_ok=True)

    # The scaler is only needed with AMP; disabling it otherwise avoids the no-CUDA warning.
    scaler = GradScaler(enabled=use_amp)
    accum_steps = 4  # Number of batches to accumulate gradients over
    num_batches = len(train_dataloader)
    # Log training stats roughly COMPUTE_PER_EPOCH times per epoch (at least every batch).
    compute_every = max(1, num_batches // COMPUTE_PER_EPOCH)
    save_every_epochs = SAVE_EVERY
    generate_every_epochs = GENERATE_EVERY
    empty_cache_every_steps = 1000  # How often to clear CUDA cache (in training steps)

    # ----- LOGGING AND TRACKING -----
    # Reuse the lists passed in so a resumed run keeps its history.
    train_losses = train_loss_list if train_loss_list is not None else []
    val_losses = val_loss_list if val_loss_list is not None else []

    generation_file_path = f"{PATH}/train_logs/OUTPUT_{MODEL_NAME}.txt"
    metrics_csv_path = f"{PATH}/train_logs/Metrics_{MODEL_NAME}.csv"

    # ----- TRAINING LOOP -------
    print(f"Starting training from Epoch {START_EPOCH+1} to {EPOCHS}...")
    for epoch in range(START_EPOCH, EPOCHS):
        print(f"\n--- Epoch {epoch+1}/{EPOCHS} ---")
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.3g}")

        # --- Evaluate Model (before this epoch's training) ---
        evaluate(model, val_dataloader, device, val_losses, epoch, tokenizer, model.pad_idx, metrics_csv_path)

        model.train()

        # --- Per-Epoch Training Initialization ---
        optimizer.zero_grad()
        running_loss = 0.0  # Sum of unscaled per-batch losses this epoch
        batch_times = []    # Seconds spent per batch
        batch_sizes = []    # Samples per batch (last batch may be smaller)

        progress_bar = tqdm(enumerate(train_dataloader), total=num_batches,
                            desc=f"Epoch {epoch+1} (Train)", dynamic_ncols=True)

        # --- Training Batches Loop ---
        for step, batch in progress_bar:
            start_time = time.perf_counter()

            inputs, targets = batch[0].to(device), batch[1].to(device)

            if use_amp:
                with autocast():
                    _, loss = model(inputs, targets)
                    # Divide so the accumulated gradient is the mean over accum_steps batches.
                    loss = loss / accum_steps

                # Stop with diagnostics before a non-finite loss corrupts the weights.
                if torch.isnan(loss) or torch.isinf(loss):
                    print("NaN or Inf detected in loss!")
                    print(f"Loss: {loss}")
                    print("Checking logits and targets...")

                    with torch.no_grad():
                        logits, _ = model(inputs)
                        print("Max logit:", logits.max().item())
                        print("Min logit:", logits.min().item())
                        print("Any NaNs in logits?", torch.isnan(logits).any().item())
                        print("Any Infs in logits?", torch.isinf(logits).any().item())
                        print("Target max:", targets.max().item(), "Target min:", targets.min().item())
                    raise RuntimeError("Loss is NaN or Inf — stopping.")

                scaler.scale(loss).backward()
            else:
                _, loss = model(inputs, targets)
                loss = loss / accum_steps
                if torch.isnan(loss) or torch.isinf(loss):
                    raise ValueError(f"Loss is bad! {loss}")
                loss.backward()

            # --- Optimizer step every accum_steps batches (and on the last batch) ---
            if (step + 1) % accum_steps == 0 or (step + 1) == num_batches:
                if use_amp:
                    # Unscale first so clipping sees the true gradient norm.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                optimizer.zero_grad(set_to_none=True)

            # --- Clear CUDA Cache (Optional) ---
            if (step + 1) % empty_cache_every_steps == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()

            # --- Logging (inside batch loop) ---
            batch_times.append(time.perf_counter() - start_time)
            batch_sizes.append(inputs.size(0))

            running_loss += loss.item() * accum_steps  # undo the accumulation scaling

            # Throughput over the last (up to) 10 batches, from the real batch sizes.
            recent_speed = sum(batch_sizes[-10:]) / sum(batch_times[-10:])
            progress_bar.set_postfix({
                'loss': f"{running_loss / (step + 1):.4f}",
                'speed': f"{recent_speed:.1f} samples/s",
                'mem': f"{torch.cuda.memory_allocated() / 1e9:.2f}GB" if torch.cuda.is_available() else "N/A",
            })

            if (step + 1) % compute_every == 0:
                # Running average of unscaled batch loss
                train_losses.append(running_loss / (step + 1))

        # --- Scheduler Step ---
        # Once per epoch, and before saving, so a resumed run continues with the right LR.
        if scheduler is not None:
            scheduler.step()

        # --- Save Checkpoint (includes loss history and scaler state for resumption) ---
        if (epoch + 1) % save_every_epochs == 0:
            checkpoint_path = f"{PATH}/model/{MODEL_NAME}_epoch_{epoch+1}.pt"
            print(f"\nSaving checkpoint to {checkpoint_path}...")
            try:
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
                    'train_losses': train_losses,  # step-wise training loss history
                    'val_losses': val_losses,      # epoch-wise validation loss history
                    'scaler_state_dict': scaler.state_dict() if scaler is not None else None,
                }, checkpoint_path)
                print("Checkpoint saved.")
            except Exception as e:
                print(f"Error saving checkpoint {checkpoint_path}: {e}")

        # --- Generate Text Samples ---
        if (epoch + 1) % generate_every_epochs == 0:
            generate_train(model, tokenizer, generation_file_path, cond_prompts, epoch, device, num_uncond_samples)

        # --- Print End of Epoch Summary (guarded against an empty loader) ---
        final_avg_train_loss = running_loss / max(1, num_batches)
        print(f"Epoch {epoch+1} complete. Final Avg Train Loss: {final_avg_train_loss:.4f}")

    print("Training complete.")

    return train_losses, val_losses


In [None]:
# --- Main Training Orchestration Function ---

def train_model(
    n_head: int,
    n_layer: int,
    n_embd: int,
    vocab_size: int,
    seq_length: int,
    batch_size: int,
    data_pct: float,
    path: str,
    checkpoint: bool,
    load_epoch: Optional[int] = None,
    epochs: int = 10,
    max_lr: float = 0.0001,
    save_every: int = 1,
    generate_every: int = 1,
    num_workers: int = 4,
    compute_per_epoch: int = 10,
    num_uncond_samples: int = 4,
    cond_prompts: Optional[list[str]] = None,
    detect_anomaly: bool = False,
):
    """
    Orchestrate training of a BetterTransformer model.

    Builds the dataloaders from the global ``data`` DatasetDict: the first
    ``data_pct`` of 'train' and at most 500 'validation' examples. It then
    initializes the model and optimizer, optionally resumes from a checkpoint
    (falling back to a fresh model if loading fails), and runs ``train``.

    Args:
        n_head, n_layer, n_embd, vocab_size, seq_length: Architecture settings
            forwarded to ``prep_train``.
        batch_size (int): Batch size for both dataloaders.
        data_pct (float): Fraction of the training split to use.
        path (str): Output directory for checkpoints and logs.
        checkpoint (bool): Resume from ``load_epoch`` when True.
        load_epoch (Optional[int]): Checkpoint epoch to resume from.
        epochs (int): Total number of epochs.
        max_lr (float): AdamW learning rate.
        save_every (int): Checkpoint frequency in epochs.
        generate_every (int): Text-generation frequency in epochs.
        num_workers (int): DataLoader worker processes.
        compute_per_epoch (int): Approx. training-loss log points per epoch.
        num_uncond_samples (int): Unconditional samples per decoding strategy.
        cond_prompts (Optional[list[str]]): Prompts for conditional generation;
            a built-in set of story openings is used when None.
        detect_anomaly (bool): Enable ``torch.autograd`` anomaly detection.
            It is very slow, so it is intended for debugging only.

    Raises:
        RuntimeError: If the global ``data`` has not been loaded.
    """
    set_seed()

    current_model_name = f"bt_{n_layer}_LAYERs_{int(data_pct*100)}_DATA_PCT_{n_embd}_EMBD_DIM"

    print(f"Training Configuration:")
    print(f"  Model: {current_model_name}")
    print(f"  Epochs: {epochs}, Start Epoch: {load_epoch if checkpoint else 0}")
    print(f"  Batch Size: {batch_size}, Data Pct: {data_pct*100}%")
    print(f"  Max LR: {max_lr}")
    print(f"  Save Every: {save_every} epochs, Generate Every: {generate_every} epochs")
    print(f"  Checkpointing Enabled: {checkpoint}, Load Epoch: {load_epoch}")
    print(f"  Output Path: {path}")
    print(f"Using device: {device}")

    os.makedirs(f"{path}/model", exist_ok=True)
    os.makedirs(f"{path}/train_logs", exist_ok=True)

    # ====== DATA LOADING AND DATALOADER SETUP ======
    if 'data' not in globals():
        raise RuntimeError("Global 'data' variable not found. Load data before calling train_model.")

    num_train_samples = int(data_pct * len(data["train"]))
    # Cap at 500 but never select past the end of a smaller validation split.
    num_val_samples = min(500, len(data["validation"]))

    train_dataloader, val_dataloader = get_dataloaders(
        train_data=data["train"].select(range(num_train_samples)),
        val_data=data["validation"].select(range(num_val_samples)),
        tokenizer=tokenizer,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    # ====== MODEL, OPTIMIZER, AND SCHEDULER SETUP ======
    # Anomaly detection slows every backward pass considerably; keep it opt-in.
    torch.autograd.set_detect_anomaly(detect_anomaly)

    def _fresh_model_and_optimizer():
        # prep_train returns exactly (model, optimizer).
        return prep_train(
            tokenizer=tokenizer,
            vocab_size=vocab_size,
            seq_length=seq_length,
            n_embd=n_embd,
            n_head=n_head,
            n_layer=n_layer,
            max_lr=max_lr,
            device=device,
        )

    model, optimizer = _fresh_model_and_optimizer()

    # Fixed LR: no scheduler. To use one, build it here from `optimizer`, e.g.
    # scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)
    # (and rebuild it in the fallback branch below, since that creates a new optimizer).
    scheduler = None

    # ====== CHECKPOINT LOADING OR STARTING FROM SCRATCH ======
    train_losses = []
    val_losses = []
    start_epoch = 0

    if checkpoint:
        try:
            model, optimizer, scheduler, train_losses, val_losses, loaded_epoch = load_checkpoint(
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                path=path,
                model_name=current_model_name,
                load_epoch=load_epoch,
            )
            start_epoch = loaded_epoch
            print(f"Resuming training from Epoch {start_epoch + 1}")

        except (FileNotFoundError, RuntimeError) as e:
            print(f"Could not load checkpoint: {e}")
            print("Starting training from scratch.")

            start_epoch = 0

            # A failed load_state_dict can leave the model partially overwritten, so rebuild it.
            print("Re-initializing model, optimizer, scheduler for starting from scratch...")
            model, optimizer = _fresh_model_and_optimizer()
            scheduler = None
            train_losses = []
            val_losses = []

    # ====== CONDITIONAL PROMPTS (used only when the caller supplies none) ======
    if cond_prompts is None:
        cond_prompts = [
            "Once there was a strong girl named Alyssa. She loved to lift weights. She",
            "One day, Casey was driving his car. He wanted to race with the police. He",
            "Lily wanted to get either a cat or a dog. Her mother didn't let her get a dog so instead she",
            "Once upon a time, there was a cat who got lost in the forest. One day,",
            "Terry saw a big red dog in the alley. He",
            "Poppy was extremely tired. Her mom told her to wash the dishes, but she just wanted to",
            "One day, Daniel went to the beach. He brought a",
            "Once there was a tiny cat named Bob. He wanted to eat the cookies on the counter, but",
        ]

    # ====== CALL MAIN TRAINING LOOP ======
    print("Starting main training loop...")
    train(
        model=model,
        tokenizer=tokenizer,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        device=device,
        optimizer=optimizer,
        scheduler=scheduler,
        train_loss_list=train_losses,
        val_loss_list=val_losses,
        START_EPOCH=start_epoch,
        EPOCHS=epochs,
        SAVE_EVERY=save_every,
        GENERATE_EVERY=generate_every,
        PATH=path,
        MODEL_NAME=current_model_name,
        COMPUTE_PER_EPOCH=compute_per_epoch,
        num_uncond_samples=num_uncond_samples,
        cond_prompts=cond_prompts,
    )

    print("train_model function finished.")

In [None]:
# Load the BLEU / ROUGE / METEOR calculators once from the `evaluate` library.
# If any of them fails to load, all three handles are reset to None and a
# message reports that these metrics will not be computed.
try:
    bleu_metric, rouge_metric, meteor_metric = (
        evaluate_model.load(metric_id) for metric_id in ("bleu", "rouge", "meteor")
    )
except Exception as e:
    print(f"Error loading generation metrics: {e}")
    print("BLEU, ROUGE, METEOR metrics will not be computed.")
    bleu_metric = rouge_metric = meteor_metric = None
else:
    print("Generation metrics loaded.")

In [None]:
def compute_metrics(model, tokenizer, val_dataloader, device, model_pad_idx):
    """
    Computes perplexity and generation metrics (BLEU, ROUGE, METEOR) on the validation set.

    For every validation batch the model is first run on (inputs, targets) to accumulate a
    token-weighted NLL for perplexity. Text is then generated from the batch inputs with
    three decoding strategies: greedy, top-k (k=5) and nucleus (p=0.9). The generations of
    all strategies are pooled, and each one is scored against the decoded target of the
    same sample.

    Args:
        model (nn.Module): The model to evaluate. Must support `model(inputs, targets)`
            returning `(logits, loss)` and `model.generate(inputs, method=..., ...)`.
        tokenizer: The tokenizer, used for `batch_decode`.
        val_dataloader (DataLoader): DataLoader yielding `(inputs, targets)` batches.
        device (torch.device): Device (cuda/cpu).
        model_pad_idx (int): The padding token ID used by the model and data.

    Returns:
        tuple: (bleu_results, rouge_results, meteor_results, perplexity_score).
            The three result dicts are whatever the corresponding `evaluate` metric
            returns. A metric that is not loaded, or that fails to compute, gets a
            zero-filled default dict instead.
            perplexity_score is exp(token-weighted mean loss). It is inf if no
            non-padding tokens were processed or the exponent overflows.

    Notes:
        Relies on the module-level `bleu_metric`, `rouge_metric` and `meteor_metric`
        objects (any of which may be None). Batches that raise are reported and skipped.
        The model's train/eval mode is restored on exit.
    """
    print("\n--- Computing Validation Metrics ---")
    start_time = time.perf_counter()

    # Remember the caller's mode so evaluation does not silently leave the model in eval mode.
    was_training = model.training
    model.eval()

    total_nll = 0.0
    total_tokens = 0

    # Invariant: all_generated_texts[i] is scored against all_reference_texts[i].
    # Both lists are extended together, once per successfully processed batch.
    all_generated_texts = []
    all_reference_texts = []

    # (method name passed to model.generate, extra strategy-specific generate kwargs)
    generation_strategies = (
        ("greedy", {}),
        ("top-k", {"k": 5}),
        ("nucleus", {"p_nucleus": 0.9}),
    )

    print("Processing validation batches for metrics...")
    try:
        with torch.no_grad():
            for batch_idx, batch in enumerate(tqdm(val_dataloader, desc="Metric Batches")):
                try:
                    # inputs are the shifted sequences (what the model receives)
                    # targets are the true sequences (what the model should predict)
                    inputs, targets = batch[0].to(device), batch[1].to(device)

                    # --- 1. Perplexity Calculation ---
                    # Weight the batch loss by its non-padding token count so the final
                    # average is per token, not per batch. This assumes the model's loss
                    # averages over non-padding tokens only; TODO confirm in model.forward.
                    _, loss = model(inputs, targets)
                    if loss is not None:
                        non_padding_tokens_in_batch = (targets != model_pad_idx).sum().item()
                        if non_padding_tokens_in_batch > 0:
                            total_nll += loss.item() * non_padding_tokens_in_batch
                            total_tokens += non_padding_tokens_in_batch

                    # --- 2. Generation Metrics for different sampling strategies ---
                    # When inputs/targets have equal length (one-token shift), this is 50.
                    max_new_tokens = targets.shape[1] - inputs.shape[1] + 50
                    decoded_references = tokenizer.batch_decode(targets, skip_special_tokens=True)

                    # Collect per batch first, so a failure midway through this batch
                    # cannot leave predictions and references with different lengths.
                    batch_predictions = []
                    batch_references = []
                    for method, gen_kwargs in generation_strategies:
                        generated_tokens = model.generate(
                            inputs, method=method, max_new_tokens=max_new_tokens, **gen_kwargs
                        )
                        batch_predictions.extend(
                            tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
                        )
                        # The decoded generations are in sample order, so pair each with
                        # that sample's reference in the same order (one ref per prediction).
                        batch_references.extend([ref] for ref in decoded_references)

                    all_generated_texts.extend(batch_predictions)
                    all_reference_texts.extend(batch_references)

                except Exception as e:
                    print(f"Error processing batch {batch_idx} for metrics: {e}")
                    continue
    finally:
        model.train(was_training)

    # --- 3. Final Metric Calculation ---
    print("Calculating final metrics...")

    if total_tokens > 0:
        try:
            perplexity_score = math.exp(total_nll / total_tokens)
        except OverflowError:
            perplexity_score = float('inf')
    else:
        perplexity_score = float('inf')

    default_bleu = {'bleu': 0.0}
    default_rouge = {'rouge1': 0.0, 'rouge2': 0.0, 'rougeL': 0.0, 'rougeLsum': 0.0}
    default_meteor = {'meteor': 0.0}

    def _safe_compute(metric, metric_label, references, default):
        """Compute one metric; return `default` if it is unavailable or fails."""
        if metric is None or not all_generated_texts:
            return default
        try:
            return metric.compute(predictions=all_generated_texts, references=references)
        except Exception as e:
            print(f"Error during {metric_label} computation: {e}")
            print(f"{metric_label} result falls back to its default value.")
            return default

    bleu_results = _safe_compute(bleu_metric, "BLEU", all_reference_texts, default_bleu)
    rouge_results = _safe_compute(rouge_metric, "ROUGE", all_reference_texts, default_rouge)
    # METEOR is given a single reference string per prediction, not a list.
    meteor_results = _safe_compute(
        meteor_metric, "METEOR", [refs[0] for refs in all_reference_texts], default_meteor
    )

    eval_time = time.perf_counter() - start_time
    print(f"Metric computation finished in {eval_time:.3f}s.")

    # Return the calculated metrics
    return bleu_results, rouge_results, meteor_results, perplexity_score


In [None]:

set_seed()

# Output directory for this specific run.
# NOTE: this rebinds the notebook-level PATH from the config cell.
PATH = "/kaggle/working/experiments/experiment_1"
print(f"Training output path: {PATH}")

# Prompts used for conditional sample generation during training.
# NOTE(review): train_model rebinds `cond_prompts` to its own hard-coded list before
# calling train(), so this argument appears to be ignored. Confirm, then keep one copy.
COND_PROMPTS = [
    "Once there was a strong girl named Alyssa. She loved to lift weights. She",
    "One day, Casey was driving his car. He wanted to race with the police. He",
    "Lily wanted to get either a cat or a dog. Her mother didn't let her get a dog so instead she",
    "Once upon a time, there was a cat who got lost in the forest. One day,",
    "Terry saw a big red dog in the alley. He",
    "Poppy was extremely tired. Her mom told her to wash the dishes, but she just wanted to",
    "One day, Daniel went to the beach. He brought a",
    "Once there was a tiny cat named Bob. He wanted to eat the cookies on the counter, but",
]

# --- Actual call to start the training run ---
print("Starting training run...")
train_model(
    n_head=N_HEAD,
    n_layer=N_LAYER,
    n_embd=N_EMBD,
    vocab_size=VOCAB_SIZE,
    seq_length=SEQ_LENGTH,
    batch_size=BATCH_SIZE,
    data_pct=DATA_PCT,
    path=PATH,
    checkpoint=CHECKPOINT,
    load_epoch=LOAD_EPOCH,
    epochs=EPOCHS,
    max_lr=MAX_LR,
    save_every=SAVE_EVERY,
    generate_every=GENERATE_EVERY,
    num_workers=4,
    # Use the config-cell constant so changing COMPUTE_PER_EPOCH actually takes effect.
    compute_per_epoch=COMPUTE_PER_EPOCH,
    num_uncond_samples=4,
    cond_prompts=COND_PROMPTS,
)
print("train_model call finished.")