# 6.4610 Research Project

## Overview
In this file, we implement a causal (decoder-only) transformer language model and train it on the OpenWebText dataset.

## Imports

In [None]:
import torch
import torch.nn as nn
from typing import Tuple, Union, Optional, List, Dict
import math
import numpy as np
from dataclasses import dataclass
from transformers import AutoTokenizer
import datasets
from tqdm import tqdm
import os
import json
from itertools import islice

## Preliminaries

In [None]:
# Number of streamed documents used for training; the training split is taken
# right after the first `test_dataset_size` documents of the stream.
train_dataset_size = 200000
# Number of leading documents of the stream held out for perplexity evaluation.
test_dataset_size = 10000
# Number of epochs already completed. A value > 0 makes the training cell try to
# resume from the "checkpoint-epoch-<n>" directory under the output directory.
last_trained_epoch = 0

@dataclass
class TransformerConfig:
    """Architecture hyperparameters consumed by the transformer model classes.

    Attributes are documented inline. All fields have defaults, so
    ``TransformerConfig()`` yields a 12-layer, 768-wide configuration.
    """
    # Size of the token vocabulary (rows of the token embedding / LM head outputs).
    vocab_size: int = 50257
    # Width of the residual stream.
    hidden_size: int = 768
    # Number of attention heads; must divide hidden_size evenly.
    num_attention_heads: int = 12
    # Number of stacked transformer blocks.
    num_hidden_layers: int = 12
    # Hidden width of the position-wise feed-forward network.
    intermediate_size: int = 3072
    # Number of learned position embeddings, i.e. the longest supported sequence.
    max_position_embeddings: int = 512
    # When True the model builds a lower-triangular mask if none is supplied.
    use_causal_mask: bool = True
    # Annotated so that it is a real dataclass field (settable through the
    # constructor and included in __dict__/repr/eq). Not read by any code in
    # this file.
    number_diffusion_kernels: int = 4

@dataclass
class TrainingConfig:
    """Hyperparameters and paths used by the training cells and by ``Trainer``."""
    # Model hyperparameters (copied into TransformerConfig when the model is built)
    vocab_size: int = 50257
    hidden_size: int = 768
    num_attention_heads: int = 12
    num_hidden_layers: int = 12
    intermediate_size: int = 3072
    max_position_embeddings: int = 512
    use_causal_mask: bool = True

    # Training hyperparameters
    batch_size: int = 4  # samples stacked into one optimizer step
    learning_rate: float = 5e-4  # peak AdamW learning rate (scaled by the LR schedule)
    weight_decay: float = 0.01  # AdamW weight decay
    num_epochs: int = 3
    # NOTE: despite the name, Trainer divides this by batch_size to get the
    # number of batches per epoch, so it acts as a per-epoch sample count.
    steps_per_epoch: int = 200000
    warmup_steps: int = 1000  # linear warmup length passed to get_lr_scheduler
    max_grad_norm: float = 1.0  # gradient-norm clipping threshold
    save_steps: int = 10000  # not read by any code in this file
    eval_steps: int = 5000  # Trainer samples text and checkpoints every eval_steps steps
    train_dataset_size: int = 200000  # documents taken from the stream for training
    test_dataset_size: int = 10000  # leading documents skipped (held out for evaluation)

    # Paths (absolute, rooted at the filesystem root)
    output_dir: str = "/transformer"  # checkpoints and the final model
    log_dir: str = "/logs_transformer"  # holds training_log.jsonl

In [None]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        """
        Build the two projections and the (fixed) GELU activation.

        Args:
            hidden_size: Model dimension (input and output width)
            intermediate_size: Hidden dimension of the FFN
        """
        super().__init__()

        self.activation = nn.GELU()
        # Expansion followed by contraction back to the model width.
        self.linear1 = nn.Linear(hidden_size, intermediate_size)
        self.linear2 = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the FFN to the last dimension of ``x``; the shape is preserved."""
        expanded = self.activation(self.linear1(x))
        return self.linear2(expanded)

def count_parameters(model):
    """Return the total number of elements across parameters with requires_grad=True."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

class TrainingMetrics:
    """Accumulate per-step loss and learning-rate history."""

    def __init__(self):
        self.losses = []
        self.learning_rates = []
        self.step = 0

    def update(self, loss: float, lr: float):
        """Record one step's loss and learning rate and advance the step counter."""
        self.step += 1
        self.losses.append(loss)
        self.learning_rates.append(lr)

    def get_avg_loss(self, last_n: int = 100):
        """Return the mean of the most recent ``last_n`` losses (0.0 if none recorded)."""
        if not self.losses:
            return 0.0
        recent = self.losses[-last_n:]
        return np.mean(recent)


# Custom dataset class for on-the-fly tokenization
class Dataset:
    """Iterable wrapper that tokenizes text records lazily.

    Args:
        dataset: Iterable of records, each exposing a ``'text'`` entry.
        tokenizer: Callable tokenizer; invoked with ``truncation=True``,
            ``padding='max_length'``, ``max_length`` and ``return_tensors='pt'``.
        max_length: Fixed length every sample is truncated/padded to.
        max_samples: Upper bound on the number of samples yielded
            (``None`` means no limit, ``0`` yields nothing).
    """

    def __init__(self, dataset, tokenizer, max_length=512, max_samples=None):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_samples = max_samples

    def __iter__(self):
        """Yield ``{'input_ids', 'labels'}`` dicts of 1-D tensors.

        ``labels`` is a copy of ``input_ids`` in which padded positions are
        replaced by -100, the default ``ignore_index`` of
        ``nn.CrossEntropyLoss``, so padding does not contribute to the loss.
        The attention mask is used rather than the pad token id because the
        pad token may be shared with a real token such as EOS.
        """
        # islice(..., None) iterates everything; islice(..., n) pulls exactly n records.
        for record in islice(self.dataset, self.max_samples):
            # Tokenize on the fly
            encoded = self.tokenizer(
                record['text'],
                truncation=True,
                padding='max_length',
                max_length=self.max_length,
                return_tensors='pt'  # Return tensors for direct use
            )

            input_ids = encoded['input_ids'].squeeze(0)  # Remove batch dimension

            # Labels equal the inputs for causal language modeling (the model
            # performs the one-token shift itself).
            labels = input_ids.clone()
            attention_mask = encoded.get('attention_mask')
            if attention_mask is not None:
                labels[attention_mask.squeeze(0) == 0] = -100

            yield {
                'input_ids': input_ids,
                'labels': labels
            }


def evaluate_model(model, tokenizer, test_prompts: List[str], temperature: float = 0.7):
    """Print one sampled continuation per prompt.

    Puts the model in eval mode (and leaves it there), encodes each prompt,
    asks ``model.generate`` for up to 150 new tokens at the given temperature
    and prints the decoded text. Returns nothing.
    """
    model.eval()

    # Inputs must live on the same device as the model weights.
    device = next(model.parameters()).device

    print("Generating samples from trained model:")
    print("=" * 60)

    for number, prompt in enumerate(test_prompts, start=1):
        print(f"\nPrompt {number}: '{prompt}'")
        print("-" * 40)

        prompt_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

        with torch.no_grad():
            sampled_ids = model.generate(
                prompt_ids,
                max_new_tokens=150,
                temperature=temperature
            )

        # Only the first sequence of the batch is decoded.
        generated_text = tokenizer.decode(sampled_ids[0], skip_special_tokens=True)
        print(f"Temperature {temperature}: {generated_text}")
        print()

def calculate_perplexity(model, tokenizer, test_prompts_iter: iter, batch_size: int = 16, max_length: int = 512):
    """
    Compute the token-level perplexity of ``model`` over a set of text records.

    Every record's ``'text'`` is collected and tokenized in a single call
    (truncated to ``max_length``, padded to the longest sequence); the model is
    then run in mini-batches of ``batch_size``. When the tokenizer supplies an
    attention mask, only non-padded target positions contribute to the loss.

    Returns:
        exp(total loss / number of counted target tokens) as a Python float.
    """
    model.eval()
    device = next(model.parameters()).device

    print("Calculating perplexity of model with test prompts (batched):")
    print("=" * 60)

    texts = [record['text'] for record in test_prompts_iter]

    # One tokenizer call for the whole evaluation set.
    encodings = tokenizer(
        texts,
        max_length=max_length,
        truncation=True,
        padding='longest',
        return_tensors='pt'
    )
    input_ids = encodings['input_ids'].to(device)
    attention_mask = encodings.get('attention_mask', None)
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    num_samples = input_ids.size(0)
    # Per-token losses so that padding can be zeroed out before summing.
    token_loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
    loss_sum = 0.0
    token_count = 0

    with torch.no_grad():
        for start in tqdm(range(0, num_samples, batch_size)):
            stop = min(start + batch_size, num_samples)
            batch_ids = input_ids[start:stop]

            logits = model(batch_ids)
            if isinstance(logits, tuple):
                logits = logits[0]

            # Position t predicts token t+1.
            predictions = logits[..., :-1, :].contiguous()
            targets = batch_ids[..., 1:].contiguous()

            per_token = token_loss_fn(
                predictions.view(-1, predictions.size(-1)),
                targets.view(-1)
            )  # (batch * (seq_len - 1),)

            if attention_mask is None:
                token_count += targets.numel()
            else:
                valid = attention_mask[start:stop][..., 1:].contiguous()
                per_token = per_token * valid.view(-1).float()
                token_count += valid.sum().item()

            loss_sum += per_token.sum().item()

    avg_loss = loss_sum / max(1, token_count)
    avg_perplexity = float(torch.exp(torch.tensor(avg_loss)))
    print(f"Average Perplexity: {avg_perplexity:.4f}")
    return avg_perplexity


def save_model(model, tokenizer, optimizer, scheduler, save_path: str):
    """Write a checkpoint directory.

    Creates ``save_path`` if needed and stores:
      * ``model.pt``, ``optimizer.pt``, ``scheduler.pt`` - the respective
        ``state_dict()`` objects via ``torch.save``;
      * ``config.json`` - the attributes of ``model.config`` as real JSON
        (only when the model has a ``config`` attribute);
      * whatever ``tokenizer.save_pretrained(save_path)`` writes.

    Raises:
        TypeError: if ``model.config`` holds values that are not JSON-serializable.
    """
    os.makedirs(save_path, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_path, "model.pt"))
    torch.save(optimizer.state_dict(), os.path.join(save_path, "optimizer.pt"))
    torch.save(scheduler.state_dict(), os.path.join(save_path, "scheduler.pt"))
    if hasattr(model, 'config'):
        # The file is named .json, so write text JSON rather than a torch pickle.
        config_path = os.path.join(save_path, "config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(vars(model.config), f, indent=2)
    tokenizer.save_pretrained(save_path)
    print(f"Model saved to {save_path}")

def load_model(model, load_path: str, map_location="cpu"):
    """Load weights from ``<load_path>/model.pt`` into ``model`` in place.

    Args:
        model: Module whose ``load_state_dict`` receives the checkpoint.
        load_path: Directory containing ``model.pt``.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint written on a GPU can be read on a machine without one;
            ``load_state_dict`` then copies the values into the model's
            existing parameters.
    """
    state_dict = torch.load(os.path.join(load_path, "model.pt"), map_location=map_location)
    model.load_state_dict(state_dict)
    print(f"Model loaded from {load_path}")

## RMSNorm

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned gain."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        """
        Args:
            hidden_size: Size of the normalized (last) dimension
            eps: Constant added to the mean square before the square root
        """
        super().__init__()
        # Per-feature gain, initialized to ones.
        self.parameter = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Divide ``hidden_states`` by sqrt(mean(x^2) + eps) along the last
        dimension, then scale by the learned gain.

        Args:
            hidden_states: Tensor whose last dimension is hidden_size

        Returns:
            Tensor with the same shape as the input
        """
        mean_square = hidden_states.square().mean(dim=-1, keepdim=True)
        scale = (mean_square + self.eps).sqrt()
        return (hidden_states / scale) * self.parameter


## Attention Mechanisms



### Single Attention Head

In [None]:
class AttentionHead(nn.Module):
    """One scaled dot-product self-attention head with bias-free projections."""

    def __init__(self, hidden_size: int, head_dim: int):
        """
        Args:
            hidden_size: Input dimension
            head_dim: Output dimension of the query/key/value projections
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.WQ = nn.Linear(hidden_size, head_dim, bias=False)
        self.WK = nn.Linear(hidden_size, head_dim, bias=False)
        self.WV = nn.Linear(hidden_size, head_dim, bias=False)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute softmax(Q K^T / sqrt(head_dim)) V.

        Args:
            x: Input tensor (batch_size, seq_len, hidden_size)
            attn_mask: Optional mask broadcastable to (batch_size, seq_len, seq_len);
                entries equal to 0 are excluded from attention

        Returns:
            attention_output: (batch_size, seq_len, head_dim)
            attention_weights: (batch_size, seq_len, seq_len)
        """
        queries, keys, values = self.WQ(x), self.WK(x), self.WV(x)
        scores = (queries @ keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if attn_mask is not None:
            # -inf scores become exactly zero weight after the softmax.
            scores = scores.masked_fill(attn_mask == 0, float("-inf"))
        weights = scores.softmax(dim=-1)
        return weights @ values, weights

### Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention built from independent AttentionHead modules."""

    def __init__(self, hidden_size: int, num_heads: int):
        """
        Args:
            hidden_size: Model dimension
            num_heads: Number of attention heads (must divide hidden_size)
        """
        super().__init__()
        assert hidden_size % num_heads == 0, f"The hidden size {hidden_size} is not divisible by the number of heads {num_heads}."
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.heads = nn.ModuleList(
            [AttentionHead(hidden_size, self.head_dim) for _ in range(num_heads)]
        )
        # Output projection applied to the concatenated head outputs.
        self.linear = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run every head on the same input and merge the results.

        Args:
            hidden_states: Input tensor (batch_size, seq_len, hidden_size)
            attention_mask: Optional mask passed unchanged to every head

        Returns:
            attention_output: (batch_size, seq_len, hidden_size)
            attention_weights: (batch_size, num_heads, seq_len, seq_len)
        """
        per_head = [head(hidden_states, attention_mask) for head in self.heads]
        head_outputs, head_weights = zip(*per_head)
        # Head h occupies columns [h * head_dim, (h + 1) * head_dim) of the merged tensor.
        merged = torch.cat(head_outputs, dim=-1)
        attention_weights = torch.stack(head_weights, dim=1)
        return self.linear(merged), attention_weights

## Transformer Block

In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual."""

    def __init__(self, hidden_size: int, num_heads: int, intermediate_size: int):
        """
        Args:
            hidden_size: Model dimension
            num_heads: Number of attention heads
            intermediate_size: FFN hidden dimension
        """
        super().__init__()
        self.rms_att = RMSNorm(hidden_size=hidden_size)
        self.rms_ffn = RMSNorm(hidden_size=hidden_size)
        self.mha = MultiHeadAttention(hidden_size=hidden_size, num_heads=num_heads)
        self.ffn = FeedForward(hidden_size=hidden_size, intermediate_size=intermediate_size)

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Apply x + MHA(norm(x)) followed by y + FFN(norm(y)).

        Args:
            hidden_states: Input tensor (batch_size, seq_len, hidden_size)
            attention_mask: Optional attention mask forwarded to the attention layer

        Returns:
            Output tensor (batch_size, seq_len, hidden_size)
        """
        # The attention weights (second return value) are not needed here.
        attended, _ = self.mha(self.rms_att(hidden_states), attention_mask)
        after_attention = attended + hidden_states
        return after_attention + self.ffn(self.rms_ffn(after_attention))

## Complete Transformer Model



### `create_causal_mask`

In [None]:
def create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """Build a lower-triangular mask so position i can attend to positions <= i.

    Args:
        seq_len: Sequence length
        device: Device on which the mask is allocated

    Returns:
        Float tensor of ones and zeros with shape (1, seq_len, seq_len); the
        leading dimension broadcasts over the batch
    """
    all_ones = torch.ones((seq_len, seq_len), device=device)
    return all_ones.tril()[None, :, :]

### TransformerModel


In [None]:
class TransformerModel(nn.Module):
    """Transformer backbone: token + learned position embeddings, a stack of
    TransformerBlock layers and a final RMSNorm."""

    def __init__(self, config: TransformerConfig):
        """Build the embeddings, the block stack and the output norm from ``config``."""
        super().__init__()
        self.config = config
        # Records CUDA availability at construction time; forward() takes the
        # device from its inputs instead.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        cfg = self.config
        self.embeddings = nn.Embedding(num_embeddings=cfg.vocab_size, embedding_dim=cfg.hidden_size)
        self.pos_embeddings = nn.Embedding(num_embeddings=cfg.max_position_embeddings, embedding_dim=cfg.hidden_size)
        blocks = []
        for _ in range(cfg.num_hidden_layers):
            blocks.append(
                TransformerBlock(
                    hidden_size=cfg.hidden_size,
                    num_heads=cfg.num_attention_heads,
                    intermediate_size=cfg.intermediate_size,
                )
            )
        self.transformer = nn.ModuleList(blocks)
        self.norm = RMSNorm(hidden_size=cfg.hidden_size)

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Embed the tokens and run them through every block.

        Args:
            input_ids: Token IDs (batch_size, seq_len)
            attention_mask: Optional mask handed to every block; when omitted and
                ``config.use_causal_mask`` is set, a causal mask is created

        Returns:
            Final normalized hidden states (batch_size, seq_len, hidden_size)
        """
        batch_size = input_ids.size(0)
        seq_len = input_ids.shape[1]
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(batch_size, -1)
        hidden = self.embeddings(input_ids) + self.pos_embeddings(positions)
        if attention_mask is None and self.config.use_causal_mask:
            attention_mask = create_causal_mask(seq_len, hidden.device)
        for block in self.transformer:
            hidden = block(hidden, attention_mask=attention_mask)
        return self.norm(hidden)

### CausalLanguageModel

In [None]:
class CausalLanguageModel(nn.Module):
    """Transformer backbone plus a linear language-modeling head."""

    def __init__(self, config: "TransformerConfig"):
        """Causal language model with transformer backbone"""
        super().__init__()
        self.config = config
        self.transformer = TransformerModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Default CrossEntropyLoss: mean reduction, targets equal to -100 are ignored.
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass for language model

        Args:
            input_ids: Token IDs (batch_size, seq_len)
            labels: Target labels for loss computation (batch_size, seq_len);
                pass the unshifted tokens, the shift happens here

        Returns:
            If labels provided: (loss, logits)
            Else: logits only, shape (batch_size, seq_len, vocab_size)
        """
        hidden_states = self.transformer(input_ids)
        logits = self.lm_head(hidden_states)
        if labels is not None:
            # The logits at position t are scored against the token at t + 1.
            logits_flat = logits[:, :-1, :].flatten(0, 1)
            labels_flat = labels[:, 1:].flatten(0, 1)
            return self.criterion(logits_flat, labels_flat), logits
        return logits

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 100, temperature: float = 1.0) -> torch.Tensor:
        """
        Generate text autoregressively, one token per step.

        Only the most recent ``config.max_position_embeddings`` tokens are fed
        to the model at each step, so generation keeps working after the
        running sequence grows past the position-embedding table.

        Args:
            input_ids: Starting token IDs (batch_size, seq_len)
            max_new_tokens: Number of tokens to append
            temperature: Sampling temperature; 0 selects the most likely token
                (greedy decoding) instead of dividing by zero

        Returns:
            Generated token IDs (batch_size, seq_len + max_new_tokens)
        """
        max_context = self.config.max_position_embeddings
        for _ in range(max_new_tokens):
            context = input_ids[:, -max_context:]
            logits = self(context)[:, -1, :]
            if temperature == 0:
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                probs = nn.functional.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, 1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
        return input_ids


## Transformer Training

In [None]:
# Open the OpenWebText train split in streaming mode (records are read lazily
# while iterating). trust_remote_code=True permits running the dataset's
# loading script.
print("Loading OpenWebText dataset...")
dataset = datasets.load_dataset("openwebtext", split="train", streaming=True, trust_remote_code=True)

In [None]:
# Load a pre-trained tokenizer (GPT-2 tokenizer works well for English text)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# GPT-2 has no dedicated padding token, so the end-of-sequence token is reused.
tokenizer.pad_token = tokenizer.eos_token  # Set padding token

print(f"Tokenizer loaded with vocab size: {tokenizer.vocab_size}")
print(f"Special tokens: PAD={tokenizer.pad_token_id}, EOS={tokenizer.eos_token_id}")

# The first `test_dataset_size` documents are held out for evaluation; the
# following `train_dataset_size` documents form the training split.
test_iter = islice(dataset, 0, test_dataset_size)
train_iter = islice(dataset, test_dataset_size, train_dataset_size+test_dataset_size)

# Samples must not be longer than the model's position-embedding table
# (TrainingConfig.max_position_embeddings, 512 by default); the previous value
# of 1024 would index past it.
train_dataset = Dataset(
    train_iter,
    tokenizer,
    max_length=TrainingConfig.max_position_embeddings,
    max_samples=train_dataset_size,
)

In [None]:
# Training configuration; the vocabulary size follows the loaded tokenizer.
training_config = TrainingConfig(vocab_size=tokenizer.vocab_size)
# Create model config and initialize model. Every architecture field of the
# training config is forwarded, including use_causal_mask, so that changing it
# on the training config actually reaches the model.
model_config = TransformerConfig(
    vocab_size=training_config.vocab_size,
    hidden_size=training_config.hidden_size,
    num_attention_heads=training_config.num_attention_heads,
    num_hidden_layers=training_config.num_hidden_layers,
    intermediate_size=training_config.intermediate_size,
    max_position_embeddings=training_config.max_position_embeddings,
    use_causal_mask=training_config.use_causal_mask,
)


# Initialize model
model = CausalLanguageModel(model_config)
print(f"Model initialized with {count_parameters(model):,} parameters")

def get_lr_scheduler(optimizer, warmup_steps: int, total_steps: int):
    """Return a LambdaLR with linear warmup followed by linear decay.

    The multiplier rises from 0 to 1 over ``warmup_steps`` steps, then falls
    linearly so that it reaches 0 at ``total_steps`` (and stays at 0 afterwards).
    """
    # Both spans are clamped to at least 1 to avoid division by zero.
    warmup_span = float(max(1, warmup_steps))
    decay_span = float(max(1, total_steps - warmup_steps))

    def lr_lambda(current_step):
        if current_step >= warmup_steps:
            remaining = float(total_steps - current_step)
            return max(0.0, remaining / decay_span)
        return float(current_step) / warmup_span

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

## Training Loop Implementation

In [None]:
class Trainer:
    """Training loop for a causal language model on streamed OpenWebText.

    Owns the optimizer (AdamW), the warmup/decay LR schedule, the metric
    history and a JSON-lines training log.
    """

    def __init__(self, model, train_dataset, tokenizer, config: TrainingConfig):
        """
        Args:
            model: Module returning ``(loss, logits)`` when called with
                ``(input_ids, labels)`` and exposing ``generate``.
            train_dataset: Stored on the instance only; ``train()`` opens a
                fresh stream every epoch because a consumed stream cannot be
                rewound.
            tokenizer: Tokenizer used for training samples, sample generation
                and checkpoints.
            config: Training hyperparameters and output paths.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.config = config

        # Setup device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        print(f"Using device: {self.device}")

        self.train_dataset = train_dataset
        self.batch_size = config.batch_size

        # Setup optimizer and scheduler
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )

        # config.steps_per_epoch counts samples; dividing by the batch size
        # gives optimizer steps per epoch.
        self.total_steps = (self.config.steps_per_epoch // self.batch_size) * config.num_epochs
        self.scheduler = get_lr_scheduler(
            self.optimizer,
            config.warmup_steps,
            self.total_steps
        )

        # Metrics
        self.metrics = TrainingMetrics()
        self.global_step = 0
        os.makedirs(config.output_dir, exist_ok=True)
        os.makedirs(config.log_dir, exist_ok=True)

        # Opened with "w": an existing log from a previous run is overwritten.
        self.log_file = os.path.join(config.log_dir, "training_log.jsonl")
        with open(self.log_file, "w") as f:
            f.write(json.dumps({"event": "training_start"}) + "\n")

    def train_step(self, batch) -> float:
        """
        Single optimization step: forward, backward, gradient clipping,
        optimizer update and scheduler update.

        Args:
            batch: Dict with ``"input_ids"`` and ``"labels"`` tensors

        Returns:
            loss: Training loss for this step as a Python float
        """
        input_ids = batch["input_ids"].to(self.device)
        labels = batch["labels"].to(self.device)
        self.optimizer.zero_grad()
        loss, _ = self.model(input_ids, labels)
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
        self.optimizer.step()
        self.scheduler.step()
        return loss.item()

    def evaluate_step(self, batch) -> float:
        """Return the loss on ``batch`` without gradients.

        Switches the model to eval mode and leaves it there; callers that
        continue training must call ``model.train()`` again.
        """
        self.model.eval()

        with torch.no_grad():
            input_ids = batch['input_ids'].to(self.device)
            labels = batch['labels'].to(self.device)

            loss, logits = self.model(input_ids, labels)
            return loss.item()

    def _open_train_stream(self):
        """Open a fresh iterator over the tokenized training slice of OpenWebText."""
        dataset = datasets.load_dataset("openwebtext", split="train", streaming=True, trust_remote_code=True)
        first = self.config.test_dataset_size
        train_slice = islice(dataset, first, first + self.config.train_dataset_size)
        train_dataset = Dataset(
            train_slice,
            self.tokenizer,
            # Samples may not exceed the model's position-embedding table.
            max_length=self.config.max_position_embeddings,
            max_samples=self.config.train_dataset_size,
        )
        return iter(train_dataset)

    def _next_batch(self, train_iter):
        """Stack the next ``batch_size`` samples, or return None if the stream ran short."""
        samples = list(islice(train_iter, self.batch_size))
        if len(samples) < self.batch_size:
            return None
        return {
            'input_ids': torch.stack([sample['input_ids'] for sample in samples]),
            'labels': torch.stack([sample['labels'] for sample in samples]),
        }

    def _append_log(self, record):
        """Append one JSON record as a line to the training log."""
        with open(self.log_file, "a") as f:
            f.write(json.dumps(record) + "\n")

    def _sample_and_checkpoint(self):
        """Print sample generations, save a step checkpoint and return to train mode."""
        print(f"\nEvaluating model at step {self.global_step}:")
        print("-" * 50)
        evaluate_model(self.model, self.tokenizer, ["Once upon a time", "The little girl"])
        print("-" * 50)
        checkpoint_path = os.path.join(
            self.config.output_dir,
            f"checkpoint-{self.global_step}"
        )
        save_model(self.model, self.tokenizer, self.optimizer, self.scheduler, checkpoint_path)
        # evaluate_model switched the model to eval mode.
        self.model.train()

    def train(self, start_epoch = 0):
        """Main training loop.

        Runs epochs ``start_epoch`` .. ``config.num_epochs - 1``. Each epoch
        re-opens the data stream, trains for up to
        ``steps_per_epoch // batch_size`` batches (ending early if the stream
        is exhausted), logs every step, and writes a checkpoint at the end of
        the epoch. A final checkpoint is written to ``config.output_dir`` and
        the model is left in eval mode.
        """
        print(f"Starting training for {self.config.num_epochs} epochs")
        print(f"Total steps: {self.total_steps}")
        print(f"Warmup steps: {self.config.warmup_steps}")
        self.model.train()
        num_batches = self.config.steps_per_epoch // self.batch_size
        for epoch in range(start_epoch, self.config.num_epochs):
            print(f"\nEpoch {epoch + 1}/{self.config.num_epochs}")

            epoch_loss = 0
            batches_done = 0
            train_iter = self._open_train_stream()

            progress_bar = tqdm(range(num_batches), desc=f"Epoch {epoch + 1}")

            for batch_idx in progress_bar:
                batch = self._next_batch(train_iter)
                if batch is None:
                    # Fewer samples than expected: finish the epoch cleanly
                    # instead of leaking StopIteration out of the loop.
                    print(f"\nTraining data exhausted after {batches_done} batches; ending epoch {epoch + 1} early.")
                    break

                loss = self.train_step(batch)
                epoch_loss += loss
                batches_done += 1

                current_lr = self.scheduler.get_last_lr()[0]
                self.metrics.update(loss, current_lr)
                self.global_step += 1

                progress_bar.set_postfix({
                    'loss': f'{loss:.4f}',
                    'avg_loss': f'{self.metrics.get_avg_loss():.4f}',
                    'lr': f'{current_lr:.2e}'
                })

                self._append_log({
                    "step": self.global_step,
                    "epoch": epoch + 1,
                    "batch_idx": batch_idx,
                    "loss": float(loss),
                    "avg_loss": float(self.metrics.get_avg_loss()),
                    "learning_rate": float(current_lr),
                })

                # Evaluate model
                if self.global_step % self.config.eval_steps == 0:
                    self._sample_and_checkpoint()

            # Average over the batches actually trained (equals num_batches
            # unless the stream ended early).
            avg_epoch_loss = epoch_loss / max(1, batches_done)
            print(f"Epoch {epoch + 1} completed. Average loss: {avg_epoch_loss:.4f}")
            self._append_log({
                "event": "epoch_end",
                "epoch": epoch + 1,
                "avg_epoch_loss": float(avg_epoch_loss)
            })
            checkpoint_path = os.path.join(
                self.config.output_dir,
                f"checkpoint-epoch-{epoch+1}"
            )
            save_model(self.model, self.tokenizer, self.optimizer, self.scheduler, checkpoint_path)
        save_model(self.model, self.tokenizer, self.optimizer, self.scheduler, self.config.output_dir)
        print("Training completed!")
        self.model.eval()

## Training Execution

In [None]:
# Initialize trainer
trainer = Trainer(model, train_dataset, tokenizer, training_config)

# Print model info
print(f"Model has {count_parameters(model):,} trainable parameters")

# Epoch index training starts from; only advanced when a checkpoint was
# actually restored, so a missing checkpoint really restarts from scratch.
start_epoch = 0

if last_trained_epoch > 0:
    CHECKPOINT_PATH = os.path.join(
        training_config.output_dir,
        f"checkpoint-epoch-{last_trained_epoch}"
    )

    model_file = os.path.join(CHECKPOINT_PATH, "model.pt")
    optimizer_file = os.path.join(CHECKPOINT_PATH, "optimizer.pt")
    scheduler_file = os.path.join(CHECKPOINT_PATH, "scheduler.pt")

    if os.path.exists(model_file):
        trainer.model.load_state_dict(torch.load(model_file, map_location=trainer.device))
        trainer.optimizer.load_state_dict(torch.load(optimizer_file, map_location=trainer.device))
        trainer.scheduler.load_state_dict(torch.load(scheduler_file, map_location=trainer.device))

        # Same per-epoch step count that Trainer uses (steps_per_epoch // batch_size),
        # so the restored global step lines up with the step-based checkpoints.
        steps_per_epoch = training_config.steps_per_epoch // training_config.batch_size
        start_global_step = last_trained_epoch * steps_per_epoch
        trainer.global_step = start_global_step
        start_epoch = last_trained_epoch
    else:
        print(f"\n Checkpoint not found at {CHECKPOINT_PATH}. Starting training from scratch.")

# Start training
trainer.train(start_epoch)

## Model Evaluation and Generation

In [None]:
# To evaluate a previously saved model without retraining, uncomment the three
# lines below so that the weights are loaded from the output directory first.
#trainer = Trainer(model, train_dataset, tokenizer, training_config)
#model_dir = "/transformer"
#load_model(model, model_dir)

# Prompts for qualitative sampling (only used by the commented-out
# evaluate_model call below).
test_prompts = [
    "Once upon a time",
    "The little girl",
    "In a magical forest",
    "Every morning"
]

# Evaluate the trained model
# evaluate_model(trainer.model, tokenizer, test_prompts)

# Calculate perplexity on the held-out split. `test_iter` is a one-shot
# iterator, so this cell can only consume it once.
print("\n" + "="*60)
print("PERPLEXITY EVALUATION")
print("="*60)
avg_perplexity = calculate_perplexity(model, tokenizer, test_iter)
print(f"\nFinal Average Perplexity: {avg_perplexity:.4f}")