In [None]:
# For Google Colab:
# upload the folder containing this notebook to Google Drive, then set `path` below.
import sys, os
# Only when running inside Colab: mount Google Drive and change into the project
# folder so the local modules (utils_gpt_, transformer_) and the data are found.
if 'google.colab' in sys.modules:

    # Mount Google Drive under /content/drive.
    from google.colab import drive
    drive.mount('/content/drive')

    # Folder holding this notebook and its helper modules -- edit this placeholder,
    # e.g. '/content/drive/MyDrive/<your-folder>'.
    path = 'xxxxx/xxxxx'
    print(path)
    if not os.path.isdir(path):
        raise FileNotFoundError(
            f"Project folder {path!r} not found; set `path` to your Drive folder "
            "(e.g. '/content/drive/MyDrive/<your-folder>')."
        )
    # os.chdir changes the current working directory.
    os.chdir(path)
    # Pure-Python equivalent of `!pwd` (no shell dependency).
    print(os.getcwd())

In [None]:
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils_gpt_ import *
from transformer_ import *
import random


In [None]:
def train_steps(model, train_loader, val_loader, optimizer, scheduler, device,
                total_steps=25000, log_interval=100, eval_interval=1000,
                save_path="trained_gpt_model.pt"):
    """Train ``model`` for a fixed number of optimizer steps and save the best checkpoint.

    Each step draws one batch index uniformly at random (with replacement) from
    ``train_loader``, takes one optimizer step with gradient-norm clipping at 1.0,
    and advances ``scheduler``. Every ``eval_interval`` steps (and on the final
    step) the model is scored on the full validation set via
    ``evaluate_on_fixed_subset``. The weights with the lowest validation loss are
    restored into ``model`` at the end and written to ``save_path``.

    Args:
        model: module called as ``model(input_ids, targets) -> (logits, loss)``
            that also exposes ``get_model_config()``.
        train_loader: supports ``len()`` and ``get_batch(idx) -> (input_ids, targets)``.
        val_loader: supports ``len()``; passed through to ``evaluate_on_fixed_subset``.
        optimizer: stepped once per training step.
        scheduler: stepped once per training step; ``get_last_lr()`` is logged.
        device: device the batches are moved to.
        total_steps: number of optimizer steps to run.
        log_interval: print running training statistics every this many steps.
        eval_interval: evaluate on the validation set every this many steps.
        save_path: file receiving ``{"model_config", "model_state_dict"}``.

    Raises:
        ValueError: if ``train_loader`` is empty but ``total_steps > 0``.
    """
    import random, math
    model.train()
    global_step = 0
    # Running sums since step 0 (not a sliding window).
    total_mean_loss, total_token_nll, total_tokens = 0.0, 0.0, 0

    all_idxs = list(range(len(train_loader)))
    if total_steps > 0 and not all_idxs:
        raise ValueError("train_loader is empty; cannot sample training batches")
    # The full validation set, identical at every evaluation.
    fixed_val_indices = list(range(len(val_loader)))
    best_val = float("inf")
    best_state = None

    while global_step < total_steps:
        step_idx = random.choice(all_idxs)
        input_ids, targets = train_loader.get_batch(step_idx)
        input_ids = input_ids.to(device)
        targets = targets.to(device)

        logits, loss = model(input_ids, targets)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        tokens_per_batch = targets.numel()
        loss_value = loss.item()
        total_mean_loss += loss_value
        # Weight the batch loss by its token count so the ratio below is a
        # per-token NLL (assumes `loss` is a mean over all target tokens -- TODO confirm).
        total_token_nll += loss_value * tokens_per_batch
        total_tokens += tokens_per_batch
        global_step += 1

        if global_step % log_interval == 0:
            avg_mean = total_mean_loss / max(1, global_step)
            avg_token_true = total_token_nll / max(1, total_tokens)
            try:
                ppl = math.exp(avg_token_true)
            except OverflowError:  # very large loss (e.g. divergence) -> report inf
                ppl = float("inf")
            print(f"Step {global_step}/{total_steps} | LR {scheduler.get_last_lr()[0]:.6f} | "
                  f"Avg Train Loss (mean): {avg_mean:.4f} | NLL/token: {avg_token_true:.4f} | PPL: {ppl:.2f}")

        if global_step % eval_interval == 0 or global_step == total_steps:
            model.eval()
            with torch.no_grad():
                avg_val_loss, val_ppl = evaluate_on_fixed_subset(model, val_loader, device, fixed_val_indices)
            print(f"[Eval] Step {global_step} | Val NLL/token: {avg_val_loss:.4f} | Val PPL: {val_ppl:.2f}")
            if avg_val_loss < best_val:
                best_val = avg_val_loss
                # Take a real CPU *copy*. On a CPU device `v.cpu()` returns the same
                # storage, so the "best" snapshot would keep tracking the live weights
                # and silently end up equal to the final weights.
                best_state = {k: v.detach().to("cpu", copy=True)
                              for k, v in model.state_dict().items()}
            model.train()

    if best_state is not None:
        model.load_state_dict(best_state)
    torch.save({
        "model_config": model.get_model_config(),
        "model_state_dict": model.state_dict()
    }, save_path)
    print(f"Saved best model to {save_path} (best Val NLL/token {best_val:.4f})")


def run_step_training(config):
    """Build data, model, optimizer and LR schedule from ``config``, then train by steps.

    Required ``config`` keys: ``data_path``, ``vocab_save_path``, ``batch_size``,
    ``seq_len``, ``d_model``, ``num_heads``, ``num_layers``, ``d_ff``,
    ``max_seq_length``, ``dropout``, ``learning_rate``, ``weight_decay``.

    Optional keys (defaults in parentheses):
    - ``total_steps`` (25000)
    - ``warmup_ratio`` (0.05)
    - ``min_lr`` (1e-6)
    - ``scheduler_type`` ('cosine'; 'linear' selects the linear schedule)
    - ``cosine_cycles`` (0.5)
    - ``log_interval`` (100)
    - ``eval_interval`` (1000)
    - ``model_save_path`` ('trained_gpt_model.pt')
    - ``seed`` (None -> Python/torch RNGs are left unseeded)

    Returns:
        The model after ``train_steps`` has run on it.
    """
    # Optional reproducibility: seed before model construction (weight init) and
    # before train_steps draws random batch indices with `random.choice`.
    seed = config.get('seed')
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)

    device = get_device()

    # Data
    corpus_common = CorpusCommon(path=config['data_path'], add_bos=False, add_eos=True)
    dictionary = corpus_common.dictionary
    save_vocabulary_dict(dictionary, config['vocab_save_path'])

    train_loader = GPTDataLoader.from_tensor(corpus_common.train, dictionary, config['batch_size'], config['seq_len'])
    val_loader   = GPTDataLoader.from_tensor(corpus_common.valid, dictionary, config['batch_size'], config['seq_len'])
    print(f"Train stats: {train_loader.get_stats()}")
    print(f"Valid stats: {val_loader.get_stats()}")

    # Model
    model = GPT(
        vocab_size=len(dictionary),
        d_model=config['d_model'],
        num_heads=config['num_heads'],
        num_layers=config['num_layers'],
        d_ff=config['d_ff'],
        max_seq_length=config['max_seq_length'],
        dropout=config['dropout']
    ).to(device)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Optimizer: weight decay is applied to every parameter.
    # TODO: consider param groups that exclude biases/LayerNorm weights from decay.
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=config['learning_rate'],
                                  weight_decay=config['weight_decay'])

    # Scheduler: based on total_steps.
    # 25000 is the fallback used only if 'total_steps' is absent from config.
    total_steps = config.get('total_steps', 25000)
    warmup_steps = max(1, int(config.get('warmup_ratio', 0.05) * total_steps))
    base_lr = config['learning_rate']
    min_lr = config.get('min_lr', 1e-6)

    if config.get('scheduler_type', 'cosine').lower() == 'linear':
        lambda_fn = get_linear_warmup_decay_lambda_with_floor(total_steps, warmup_steps, min_lr, base_lr)
    else:
        lambda_fn = get_cosine_warmup_decay_lambda_with_floor(total_steps, warmup_steps,
                                                              config.get('cosine_cycles', 0.5), min_lr, base_lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_fn)

    # Train by steps
    train_steps(model, train_loader, val_loader, optimizer, scheduler, device,
                total_steps=total_steps,
                log_interval=config.get('log_interval', 100),
                eval_interval=config.get('eval_interval', 1000),
                save_path=config.get('model_save_path', 'trained_gpt_model.pt'))

    return model

In [None]:
# Hyperparameters for training a small GPT on the 'ptb' corpus.
config = {
    'data_path': 'ptb',
    'vocab_save_path': 'vocabulary.txt',
    'model_save_path': 'trained_gpt_model.pt',
    'batch_size': 32,
    'seq_len': 256,
    'd_model': 256,
    'num_heads': 4,
    'num_layers': 4,
    'd_ff': 2048,
    'max_seq_length': 1024,
    # Fraction of activations zeroed (assuming GPT uses standard nn.Dropout).
    # The previous value of 0.95 discarded nearly all activations during training.
    'dropout': 0.1,
    'learning_rate': 1e-4,        # base LR given to AdamW and to the warmup/decay schedule
    'weight_decay': 0.01,
    'scheduler_type': 'cosine',
    'warmup_ratio': 0.05,
    'cosine_cycles': 0.5,
    'min_lr': 1e-6,
    'total_steps': 20000,
    'log_interval': 1000,
    'eval_interval': 1000,
}
# Keep the trained model for later cells instead of discarding the return value.
trained_model = run_step_training(config)