In [1]:
import os

import numpy as np
import torch

from model.config import TSConfig
from model.modules.transformer import Transformer
from utils import *
# Directory that holds the binary dataset files read by later cells.
data_dir = 'data'

# Prefer Apple's MPS backend (the original hard-coded choice), but fall back
# to CUDA or CPU so the notebook still runs on machines without MPS instead
# of failing when the model is moved to the device.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Kept as a separate name because the original defines both; they always
# carry the same value here.
device_type = device

In [2]:
# Optimisation hyperparameters and the model/sequence configuration.
learning_rate = 1e-2
batch_size = 4
config = TSConfig()

# Open train.bin as a read-only memory map of uint16 values, so the file is
# not loaded into RAM up front.
train_data = np.memmap(
    os.path.join(data_dir, 'train.bin'),
    dtype=np.uint16,
    mode='r',
)

# Number of start offsets that still leave config.seq_len elements after
# them, and how many groups of batch_size such offsets there are.
# NOTE(review): how get_batch maps a batch index onto these offsets is not
# visible here - confirm that total_batches matches its indexing scheme.
total_positions = len(train_data) - config.seq_len
total_batches = total_positions // batch_size

In [3]:
# Instantiate the model from the shared config and move it to the target
# device. `dm` holds whatever `.to(device)` returns.
# NOTE(review): presumably Transformer is an nn.Module, in which case `dm`
# is the same object as `decoder_model` - confirm.
decoder_model = Transformer(config)
dm = decoder_model.to(device)

# AdamW over all model parameters, starting at learning_rate.
optimizer = torch.optim.AdamW(
    params=decoder_model.parameters(),
    lr=learning_rate,
)

# Cosine decay from learning_rate down to eta_min over total_batches
# scheduler steps (the training loop steps the scheduler once per batch).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer,
    T_max=total_batches,
    eta_min=1e-7,
)

In [4]:
# Checkpoint directory and save cadence (in batches).
checkpoint_dir = "checkpoints"
checkpoint_frequency = 80000000
os.makedirs(checkpoint_dir, exist_ok=True)

# Rolling "latest" checkpoint; when present, training resumes from it.
latest_checkpoint = os.path.join(checkpoint_dir, "latest_checkpoint.pt")

# Index of the first batch the training loop will process. Stays 0 unless a
# checkpoint is loaded AND the LR schedule is successfully realigned.
start_iter = 0
if os.path.exists(latest_checkpoint):
    try:
        resumed_iter = load_checkpoint(latest_checkpoint, decoder_model, optimizer)
        if resumed_iter > 0:
            # Realign the LR schedule without replaying every step.
            # Replaying optimizer.step()/scheduler.step() resumed_iter times
            # is O(resumed_iter), can apply stale gradients to the freshly
            # loaded weights, and - because CosineAnnealingLR's step-by-step
            # update is computed from the optimizer's *current* lr - can
            # decay the rate twice if load_checkpoint restored the
            # optimizer's already-decayed lr.
            # NOTE(review): load_checkpoint's handling of param_groups is
            # not visible here; pinning initial_lr makes the result correct
            # either way.
            for group, base_lr in zip(optimizer.param_groups, scheduler.base_lrs):
                group["initial_lr"] = base_lr
            # With last_epoch = k the constructor performs one initial step,
            # leaving the schedule at step k + 1 with the closed-form cosine
            # value. k = resumed_iter - 1 therefore matches resumed_iter
            # completed scheduler steps.
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=scheduler.T_max,
                eta_min=scheduler.eta_min,
                last_epoch=resumed_iter - 1,
            )
        start_iter = resumed_iter
        print(f"Resuming training from iteration {start_iter}")
    except Exception as e:
        # Best effort: report the problem and fall back to starting at 0.
        # NOTE(review): the model/optimizer may have been partially updated
        # by load_checkpoint before the failure - consider re-running the
        # model/optimizer cell if this message appears.
        print(f"Failed to load checkpoint: {e}")

In [5]:
# total_iter counts the batches started so far (including those from a
# resumed run); inside the loop it is always batch_idx + 1.
total_iter = start_iter
for total_iter, batch_idx in enumerate(range(start_iter, total_batches), start=start_iter + 1):
    # Fetch the training batch identified by batch_idx.
    xb, yb = get_batch('train', batch_size, config.seq_len, device, batch_idx)

    # Forward pass: the model returns both the logits and the loss.
    logits, loss = decoder_model(xb, yb)

    # Backward pass with gradient-norm clipping at 1.0, then advance the
    # optimizer and the LR schedule by one step each.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    norm = torch.nn.utils.clip_grad_norm_(decoder_model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()

    # Per-step progress line. The lr shown is read after scheduler.step(),
    # i.e. it is the rate the next batch will use.
    print(f"batch {batch_idx}/{total_batches}: loss {loss.item():.4f}, lr: {scheduler.get_last_lr()[0]:.6f}")

    # Periodic checkpointing (currently disabled):
    # if batch_idx % checkpoint_frequency == 0 or batch_idx == total_batches - 1:
    #     checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_{batch_idx}.pt")
    #     save_checkpoint(decoder_model, optimizer, total_iter, checkpoint_path)
    #     # Also save as latest checkpoint
    #     save_checkpoint(decoder_model, optimizer, total_iter, latest_checkpoint)
    #     print(f"Saved checkpoint at iteration {total_iter}")

    # Short-run stop: leave the loop once batch index 50 has been processed.
    # This only triggers when the loop actually passes through index 50,
    # i.e. when start_iter <= 50.
    if batch_idx == 50:
        break

batch 0/118116288: loss 75.0121, lr: 0.010000
batch 1/118116288: loss 72.6732, lr: 0.010000
batch 2/118116288: loss 69.8239, lr: 0.010000
batch 3/118116288: loss 65.9889, lr: 0.010000
batch 4/118116288: loss 61.7409, lr: 0.010000
batch 5/118116288: loss 58.7134, lr: 0.010000
batch 6/118116288: loss 56.6483, lr: 0.010000
batch 7/118116288: loss 56.2203, lr: 0.010000
batch 8/118116288: loss 54.1409, lr: 0.010000
batch 9/118116288: loss 53.2509, lr: 0.010000
batch 10/118116288: loss 49.7287, lr: 0.010000
batch 11/118116288: loss 47.4428, lr: 0.010000
batch 12/118116288: loss 46.3546, lr: 0.010000
batch 13/118116288: loss 44.1453, lr: 0.010000
batch 14/118116288: loss 41.9372, lr: 0.010000
batch 15/118116288: loss 40.1291, lr: 0.010000
batch 16/118116288: loss 37.4861, lr: 0.010000
batch 17/118116288: loss 36.1667, lr: 0.010000
batch 18/118116288: loss 35.6244, lr: 0.010000
batch 19/118116288: loss 33.8245, lr: 0.010000
batch 20/118116288: loss 32.8869, lr: 0.010000
batch 21/118116288: los