# Imports

In [19]:
import os
import json
import time
import torch
import random
import logging
import numpy as np

from torch import nn
from tqdm import tqdm
from torch.optim import AdamW
from transformers import get_scheduler
from torch.utils.data import DataLoader, Dataset
from tokenizers import Tokenizer, models, trainers, pre_tokenizers

# Setup

In [20]:
def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Python's ``random`` module, NumPy's global RNG and PyTorch's default
    generator are seeded in that order; when CUDA is available, every CUDA
    device is seeded as well.

    Args:
        seed: Value handed unchanged to each library's seeding call
            (defaults to 42).
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Nothing more to do on CPU-only machines
    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)

# Seed the Python, NumPy and PyTorch RNGs with set_seed's default seed (42).
set_seed()

# Send INFO-and-above records to training.log, appending to the existing file
# (filemode='a') instead of truncating it on each run.
# Note: logging.basicConfig does nothing when the root logger already has
# handlers, so re-running this cell in a live kernel will not reconfigure logging.
logging.basicConfig(
    
    filename="training.log",
    filemode='a',

    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO)

# Settings

In [None]:
# --- Model / tokenizer size ---
VOCAB_SIZE = 30000  # vocab size requested from the BPE trainer; also the model's embedding/output width
EMBED_DIM = 256  # d_model of every transformer layer
NUM_LAYERS = 8  # number of stacked nn.TransformerEncoderLayer blocks
NUM_HEADS = 8  # attention heads per layer

HIDDEN_DIM = 512  # passed as dim_feedforward to each transformer layer
BATCH_SIZE = 16  # DataLoader batch size
SEQ_LEN = 128  # tokens per training window; also the size of the position-embedding table

# --- Optimisation ---
EPOCHS = 100  # upper bound on epochs per stage (early stopping may end a stage sooner)
LEARNING_RATE = 5e-4  # lr handed to AdamW

# NOTE(review): GRAD_ACCUM_STEPS is not referenced by any code visible in this notebook.
GRAD_ACCUM_STEPS = 1
WARMUP_STEPS = 500  # num_warmup_steps for the "linear" LR scheduler
SAVE_EVERY = 10  # write a full checkpoint (model + optimizer + scheduler) every N epochs

# Prefer the GPU when one is available; otherwise run on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\nWorking on : {DEVICE}\n")

# Directory that receives saved models, settings files and checkpoints
MODEL_SAVE_DIR = "../Models"

# Training text per stage; the final cell iterates this dict in insertion order.
STAGE_DATA_PATHS = {

    "grammar": "../Data/grammar.txt",
    "easy": "../Data/easy.txt",

    "intermediate": "../Data/intermediate.txt",
    "advanced": "../Data/advanced.txt"
}


Working on : cuda



# Loading the Data

In [22]:
class TextDataset(Dataset):
    """Non-overlapping next-token-prediction windows over one tokenized text file.

    The whole file is read and tokenized once in ``__init__``. Item ``i`` is the
    pair ``(tokens[i*L : i*L + L], tokens[i*L + 1 : i*L + L + 1])`` with
    ``L = seq_len``, i.e. the target is the input shifted left by one token.

    Args:
        file_path: Path to a UTF-8 text file.
        tokenizer: Object whose ``encode(text).ids`` yields a list of token ids.
        seq_len: Number of tokens per window (must be positive).

    Raises:
        ValueError: If ``seq_len`` is not positive.
    """

    def __init__(self, file_path, tokenizer, seq_len=SEQ_LEN):

        if seq_len <= 0:
            raise ValueError(f"seq_len must be positive, got {seq_len}")

        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()

        self.tokens = tokenizer.encode(text).ids
        self.seq_len = seq_len

    def __len__(self):
        # Every window needs seq_len inputs plus ONE extra token so that the
        # last input position still has a real next-token target.
        return max(0, (len(self.tokens) - 1) // self.seq_len)

    def __getitem__(self, idx):
        """Return ``(input_ids, target_ids)`` as two int64 tensors of length ``seq_len``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``. This
                also lets plain ``for item in dataset`` iteration terminate.
        """
        count = len(self)
        if idx < 0:
            idx += count
        if not 0 <= idx < count:
            raise IndexError(f"index {idx} out of range for dataset of length {count}")

        start = idx * self.seq_len
        end = start + self.seq_len

        input_ids = self.tokens[start:end]
        # Shift by one using the real following token. Padding the last target
        # with id 0 would make the loss (ignore_index=0) skip the final position
        # of every window, so that position would never be trained.
        target_ids = self.tokens[start + 1:end + 1]

        return (torch.tensor(input_ids, dtype=torch.long),
                torch.tensor(target_ids, dtype=torch.long))

# Training the Tokeniser

In [23]:
def train_tokenizer(data_path, vocab_size):
    """Train a BPE tokenizer with whitespace pre-tokenization on one text file.

    Args:
        data_path: Path to a UTF-8 text file used as the training corpus.
        vocab_size: Target vocabulary size passed to the BPE trainer.

    Returns:
        The trained ``tokenizers.Tokenizer``.
    """
    with open(data_path, 'r', encoding='utf-8') as f:
        text = f.read()

    # Register "<unk>" with the model itself. Listing it only as a trainer
    # special token reserves a vocabulary slot, but the BPE model would never
    # emit it for symbols it has not seen.
    tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
    tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()

    # "<pad>" is listed first; the training loop's ignore_index=0 relies on it
    # receiving id 0 (special tokens are presumably assigned ids in list order;
    # confirm against the tokenizers documentation).
    trainer = trainers.BpeTrainer(vocab_size=vocab_size, special_tokens=["<pad>", "<unk>"])

    tokenizer.train_from_iterator([text], trainer=trainer)

    return tokenizer

# Building the Language Model

In [24]:
class SimpleTransformer(nn.Module):
    """Causally-masked stack of ``nn.TransformerEncoderLayer`` used as a language model.

    Token embeddings and learned position embeddings are summed, passed through
    ``num_layers`` encoder layers under an upper-triangular ``-inf`` attention
    mask, layer-normalised and projected to vocabulary logits.

    Args:
        vocab_size: Rows of the token embedding and width of the output layer.
        embed_dim: Model width (``d_model``).
        num_heads: Attention heads per layer.
        hidden_dim: Feed-forward width inside each layer (``dim_feedforward``).
        num_layers: Number of encoder layers.
        seq_len: Rows of the position-embedding table, i.e. the longest
            sequence ``forward`` can embed.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, hidden_dim, num_layers, seq_len):
        super().__init__()

        # Lookup tables: one for token ids, one for absolute positions
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(seq_len, embed_dim)

        layers = []
        for _ in range(num_layers):
            layers.append(
                nn.TransformerEncoderLayer(
                    d_model=embed_dim,
                    nhead=num_heads,
                    dim_feedforward=hidden_dim,
                )
            )
        self.transformer_blocks = nn.ModuleList(layers)

        self.norm = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Map token ids of shape (batch, time) to logits of shape (batch, time, vocab_size)."""
        steps = x.size(1)
        device = x.device

        position_ids = torch.arange(0, steps, device=device).unsqueeze(0)
        hidden = self.embedding(x) + self.position_embedding(position_ids)

        # The encoder layers are built with the default batch_first=False,
        # so they expect (time, batch, dim)
        hidden = hidden.transpose(0, 1)

        # -inf strictly above the diagonal: position t may not attend to t+1..end
        all_neg_inf = torch.ones(steps, steps, device=device) * float('-inf')
        causal_mask = torch.triu(all_neg_inf, diagonal=1)

        for layer in self.transformer_blocks:
            hidden = layer(hidden, src_mask=causal_mask)

        # Back to (batch, time, dim) before the final norm and projection
        normed = self.norm(hidden.transpose(0, 1))

        return self.fc(normed)

# Saving the Model

In [25]:
def save_checkpoint(model, optimizer, scheduler, epoch, loss, model_dir, name_prefix):
    """Write a resumable training checkpoint to ``model_dir``.

    The file is named ``{name_prefix}_checkpoint_epoch{epoch}.pth`` and holds a
    dict with the keys ``model_state_dict``, ``optimizer_state_dict``,
    ``scheduler_state_dict``, ``epoch`` and ``loss``.

    Args:
        model: Object exposing ``state_dict()`` (the network being trained).
        optimizer: Object exposing ``state_dict()``.
        scheduler: Object exposing ``state_dict()``.
        epoch: Epoch number stored in the checkpoint and used in the file name.
        loss: Loss value stored alongside the states.
        model_dir: Destination directory; created if it does not exist.
        name_prefix: Leading part of the checkpoint file name.
    """
    # torch.save does not create missing parent directories, so a fresh clone
    # without the output folder would otherwise fail at the first checkpoint.
    # An empty model_dir means "current directory" and needs no creation.
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    checkpoint = {

        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),

        'epoch': epoch,
        'loss': loss}

    filename = os.path.join(model_dir, f"{name_prefix}_checkpoint_epoch{epoch}.pth")

    torch.save(checkpoint, filename)
    logging.info(f"Checkpoint saved: {filename}")

def save_model_settings(model, model_save_dir, settings, stage_name, epoch):
    """Save model weights and the settings that produced them side by side.

    Writes ``{stage_name}_epoch_{epoch}.pth`` (the model's ``state_dict``) and
    ``{stage_name}_epoch_{epoch}.json`` (``settings`` dumped with indent 2).

    Args:
        model: Object exposing ``state_dict()``.
        model_save_dir: Destination directory; created if it does not exist.
        settings: JSON-serialisable object describing the run.
        stage_name: Leading part of both file names.
        epoch: Epoch number embedded in both file names.
    """
    # Neither torch.save nor open() creates missing directories.
    # An empty model_save_dir means "current directory" and needs no creation.
    if model_save_dir:
        os.makedirs(model_save_dir, exist_ok=True)

    model_filename = f"{stage_name}_epoch_{epoch}.pth"
    settings_filename = f"{stage_name}_epoch_{epoch}.json"

    torch.save(model.state_dict(), os.path.join(model_save_dir, model_filename))

    # Explicit encoding keeps the file identical across platforms/locales
    with open(os.path.join(model_save_dir, settings_filename), "w", encoding="utf-8") as f:
        json.dump(settings, f, indent=2)

    logging.info(f"Model saved: {model_filename}")

# Training the Model

In [26]:
def train_on_stage(stage_name, data_path, model_settings):
    """Train a freshly initialised SimpleTransformer on one stage's text file.

    Steps:
      1. Train a BPE tokenizer on ``data_path`` and save it to MODEL_SAVE_DIR
         as ``{stage_name}_tokenizer.json``.
      2. Tokenize the same file into SEQ_LEN-token windows.
      3. Train a new model with AdamW, the "linear" scheduler (WARMUP_STEPS
         warmup steps) and gradient-norm clipping at 1.0, for at most EPOCHS
         epochs.
      4. Save weights + settings whenever the epoch's mean training loss
         improves; stop after 5 consecutive epochs without improvement.
      5. Write a full checkpoint every SAVE_EVERY epochs.

    Args:
        stage_name: Label used in progress output and saved file names.
        data_path: Path to the UTF-8 training text for this stage.
        model_settings: JSON-serialisable settings written next to each saved model.

    Raises:
        ValueError: If ``data_path`` does not yield a single training window.

    NOTE(review): every call builds its own tokenizer and model, so nothing
    learned in one stage is carried into the next. Early stopping monitors the
    training loss (there is no validation split), and GRAD_ACCUM_STEPS is not
    applied here.
    """
    print(f"\n--- Training Stage: {stage_name.upper()} ---")

    # All artefacts below are written straight into MODEL_SAVE_DIR
    os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

    tokenizer = train_tokenizer(data_path, VOCAB_SIZE)

    # Persist the vocabulary: the saved weights cannot be used for inference
    # without the exact token <-> id mapping they were trained with.
    tokenizer_path = os.path.join(MODEL_SAVE_DIR, f"{stage_name}_tokenizer.json")
    tokenizer.save(tokenizer_path)
    logging.info(f"Tokenizer saved: {tokenizer_path}")

    dataset = TextDataset(data_path, tokenizer)

    # Fail early and clearly rather than with a sampler error or a division
    # by zero when averaging the epoch loss.
    if len(dataset) == 0:
        raise ValueError(
            f"No training windows could be built from {data_path!r}: "
            f"the file holds fewer than one window of {SEQ_LEN} tokens.")

    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    model = SimpleTransformer(

        vocab_size=VOCAB_SIZE,
        embed_dim=EMBED_DIM,
        num_heads=NUM_HEADS,

        hidden_dim=HIDDEN_DIM,
        num_layers=NUM_LAYERS,

        seq_len=SEQ_LEN).to(DEVICE)

    # Targets equal to 0 (the "<pad>" id) do not contribute to the loss
    loss_fn = nn.CrossEntropyLoss(ignore_index=0)
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

    # The scheduler is stepped once per batch, so its horizon is batches * epochs
    num_training_steps = len(dataloader) * EPOCHS
    lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=WARMUP_STEPS, num_training_steps=num_training_steps)

    max_grad_norm = 1.0

    # Early-stopping state
    patience = 5
    trigger_times = 0
    best_loss = float('inf')

    for epoch in range(EPOCHS):

        model.train()
        epoch_loss = 0
        # The description is constant for the whole epoch, so set it once here
        loop = tqdm(dataloader, leave=True, desc=f"[{stage_name}] Epoch {epoch}")

        for batch in loop:

            input_ids, target_ids = batch
            input_ids = input_ids.to(DEVICE)
            target_ids = target_ids.to(DEVICE)

            outputs = model(input_ids)
            # reshape (not view) so this keeps working if the logits are not contiguous
            loss = loss_fn(outputs.reshape(-1, VOCAB_SIZE), target_ids.reshape(-1))

            optimizer.zero_grad()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            optimizer.step()
            lr_scheduler.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss

            loop.set_postfix(loss=batch_loss)

        avg_loss = epoch_loss / len(dataloader)
        print(f"Stage: {stage_name} - Epoch {epoch} - Average loss: {avg_loss:.4f}")

        if avg_loss < best_loss:

            best_loss = avg_loss
            trigger_times = 0

            save_model_settings(model, MODEL_SAVE_DIR, model_settings, stage_name, epoch)
            print("New Best loss, model saved\n")

        else:
            trigger_times += 1
            # Reports how much of the patience budget has been consumed
            print(f"No improvement. Patience used: {trigger_times}/{patience}")

            if trigger_times >= patience:
                print("Early stopping triggered.")
                break

        if (epoch + 1) % SAVE_EVERY == 0:
            save_checkpoint(model, optimizer, lr_scheduler, epoch + 1, avg_loss, MODEL_SAVE_DIR, stage_name)

# Start

In [27]:
# Snapshot of the hyperparameters from the Settings cell. train_on_stage hands
# this dict to save_model_settings, which dumps it as JSON beside each saved model.
model_settings = {
    
    "VOCAB_SIZE": VOCAB_SIZE,
    "EMBED_DIM": EMBED_DIM,

    "NUM_LAYERS": NUM_LAYERS,
    "NUM_HEADS": NUM_HEADS,
    "HIDDEN_DIM": HIDDEN_DIM,

    "BATCH_SIZE": BATCH_SIZE,
    "SEQ_LEN": SEQ_LEN,
    "EPOCHS": EPOCHS,

    "LEARNING_RATE": LEARNING_RATE}

# Run the stages in the insertion order of STAGE_DATA_PATHS
# (grammar, easy, intermediate, advanced). Each call trains its own tokenizer
# and its own model on that stage's file.
for stage, path in STAGE_DATA_PATHS.items():
    train_on_stage(stage, path, model_settings)


--- Training Stage: GRAMMAR ---


[grammar] Epoch 0:  25%|██▍       | 209/840 [00:24<01:15,  8.37it/s, loss=6.51]


KeyboardInterrupt: 