In [None]:
import os

import torch
from tokenizers import Tokenizer, models, pre_tokenizers, decoders, trainers

import model
from chat import save_checkpoint, load_checkpoint
from model import GPTLanguageModel, learning_rate, max_iters, eval_interval, eval_iters, device, block_size, batch_size

In [None]:
# Load the serialized tokenizer from ./tokenizer.json (path is relative to the
# kernel's working directory). Its vocabulary size is used below to size the model.
tokenizer = Tokenizer.from_file("./tokenizer.json")

In [None]:
# Fraction of the encoded corpus used for training; the remainder is held out
# for validation.
train_size = 0.9

# NOTE(review): neither `text2tensor` nor `text` is defined or imported in any
# visible cell, so this line fails on a fresh kernel unless they come from
# elsewhere — confirm where they are supposed to be defined.
data = text2tensor(text)

# Override the `vocab_size` attribute of the `model` module with the tokenizer's
# vocabulary size. Presumably GPTLanguageModel reads it at construction time, so
# this must run before the model is instantiated — verify against model.py.
model.vocab_size = tokenizer.get_vocab_size()

# Contiguous split: the first `n` elements train, everything after validates.
n = int(train_size * len(data))
train_data, val_data = data[:n], data[n:]

In [None]:
# Instantiate the language model and move it to the configured device in one step.
m = GPTLanguageModel().to(device)

In [None]:
def get_batch(split):
    """Sample a random batch of input/target windows from one data split.

    Args:
        split: 'train' selects `train_data`; any other value selects `val_data`.

    Returns:
        Tuple `(x, y)` moved to `device`. Each stacks `batch_size` windows of
        length `block_size`; every `y` window is the matching `x` window
        shifted one position to the right.
    """
    source = train_data if split == 'train' else val_data
    # One random start offset per batch row; the upper bound leaves room for
    # the one-position-shifted target window.
    starts = torch.randint(len(source) - block_size, (batch_size,))

    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + block_size])
        targets.append(source[start + 1:start + block_size + 1])

    x = torch.stack(inputs)
    y = torch.stack(targets)
    return x.to(device), y.to(device)

@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss on the 'train' and 'val' splits.

    Puts the global model `m` into eval mode, draws `eval_iters` random batches
    per split via `get_batch`, and averages the per-batch losses. Gradient
    tracking is disabled for the whole call by the decorator.

    Returns:
        dict mapping 'train' and 'val' to a 0-dim tensor holding the mean loss
        over `eval_iters` batches.
    """
    out = {}
    m.eval()

    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)

            for k in range(eval_iters):
                X, Y = get_batch(split)
                _, loss = m(X, Y)
                losses[k] = loss.item()

            out[split] = losses.mean()
    finally:
        # Always switch back to training mode, even if evaluation raises or the
        # cell is interrupted; otherwise later training steps would silently
        # run with the model still in eval mode.
        m.train()

    return out

In [None]:
# AdamW over every model parameter, using the learning rate imported from model.py.
optimizer = torch.optim.AdamW(params=m.parameters(), lr=learning_rate)
# Bare expression so the notebook echoes the optimizer's configuration.
optimizer

In [None]:
# --- CONFIGURATION ---
resume_training = False  # Set to True if continuing from a previous run/dataset
checkpoint_path = "model_checkpoint.pth"
best_model_path = "best_model.pth"
# NOTE(review): this is reset on every run of the cell, including resumed runs,
# so the first evaluation after a resume always overwrites `best_model_path`
# even if the earlier best was lower — confirm whether that is intended.
best_val_loss = float('inf')
start_iter = 0

if resume_training:
    # Presumably restores model/optimizer state in place and returns the
    # iteration to continue from — verify against chat.load_checkpoint.
    start_iter = load_checkpoint(checkpoint_path, m, optimizer)

print(f"Starting training from iteration {start_iter}...")

# The loop variable is `step`, not `iter`: a top-level `iter` would shadow the
# builtin iter() for every later cell run in this kernel.
for step in range(start_iter, max_iters):

    # Evaluate on a fixed cadence and once more on the final iteration.
    if step % eval_interval == 0 or step == max_iters - 1:

        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        # Keep a separate copy of the checkpoint with the lowest validation loss so far.
        if losses['val'] < best_val_loss:
            best_val_loss = losses['val']
            save_checkpoint(step, m, optimizer, losses['val'], best_model_path)
            print("   (New best model saved!)")

        # Rolling checkpoint used by `resume_training`; written at every evaluation.
        save_checkpoint(step, m, optimizer, losses['val'], checkpoint_path)

    xb, yb = get_batch('train')

    # One optimization step: forward pass, clear stale gradients, backward pass, update.
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

In [None]:
# Persist the final weights (state_dict only, no optimizer state).
final_model_path = './models/final.pth'
# torch.save does not create missing parent directories; create the target
# directory first so a finished training run is not lost to a failed save.
os.makedirs(os.path.dirname(final_model_path), exist_ok=True)
state = m.state_dict()
torch.save(state, final_model_path)