In [1]:
from model.config import GPTConfig,TSConfig
from model.modules.transformer import Transformer
import torch
import numpy as np
import os
import torch
import typing
from model.modules.embedding import Embedding
# Directory that holds the tokenised `train.bin` / `val.bin` files.
data_dir = 'data'

# Pick the best available accelerator instead of hard-coding Apple's MPS
# backend, so the notebook also runs on CUDA or CPU-only machines.
# MPS is checked first, so the device is still 'mps' wherever MPS is available.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Kept as a separate name because `get_batch` branches on it (pinned-memory
# transfers are only used when this is 'cuda').
device_type = device

In [2]:
def save_checkpoint(model, optimizer, iteration, out):
    """
    Serialize the training state (model weights, optimizer state, iteration) with torch.save.

    Args:
        model: torch.nn.Module - Model whose state_dict is stored
        optimizer: torch.optim.Optimizer - Optimizer whose state_dict is stored
        iteration: int - Iteration counter stored alongside the two state dicts
        out: str | os.PathLike | typing.BinaryIO | typing.IO[bytes] - Destination path or writable binary stream

    Returns:
        None
    """
    # One dict with three fixed keys; `load_checkpoint` reads the same keys back.
    torch.save(
        dict(
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
            iteration=iteration,
        ),
        out,
    )


In [3]:
def load_checkpoint(src, model, optimizer, map_location='cpu'):
    """
    Load model and optimizer state from a checkpoint written by `torch.save`.

    The checkpoint is expected to be a dict with the keys
    'model_state_dict', 'optimizer_state_dict' and 'iteration'.

    Args:
        src: str | os.PathLike | typing.BinaryIO | typing.IO[bytes] - Input file or path
        model: torch.nn.Module - The model to load state into
        optimizer: torch.optim.Optimizer - The optimizer to load state into
        map_location: Passed straight to `torch.load`. Defaults to 'cpu' so a
            checkpoint saved from an 'mps' or 'cuda' model can still be read on
            a machine that lacks that backend. Pass None to restore tensors to
            the devices they were saved from.

    Returns:
        int: The iteration number saved in the checkpoint

    Raises:
        KeyError: If the checkpoint is missing one of the expected keys.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(src, map_location=map_location)
    # Both load_state_dict calls copy the loaded values into the existing
    # parameters / optimizer state, so loading to CPU first does not change
    # which device the model and optimizer live on.
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['iteration']

In [4]:
# --- Training hyperparameters and model configuration ---
batch_size = 2
learning_rate = 6e-4
config = TSConfig()

# Memory-map the uint16 training data read-only instead of reading it into RAM.
train_data = np.memmap(
    os.path.join(data_dir, 'train.bin'),
    dtype=np.uint16,
    mode='r',
)

# Start offsets that still leave `config.seq_len` elements after them, and the
# number of whole batches those offsets form (any remainder is dropped).
total_positions = train_data.shape[0] - config.seq_len
total_batches = total_positions // batch_size

In [6]:
# Display the number of whole batches computed from the training data above.
total_batches

236232576

In [5]:
# Build the model from `config` and move it to the target device before the
# optimizer is created, so the optimizer is handed the on-device parameters.
decoder_model=Transformer(config)
# Assuming Transformer is a torch.nn.Module, `.to()` returns the module itself,
# so `dm` is another name for `decoder_model` — TODO confirm whether `dm` is used anywhere.
dm=decoder_model.to(device)

# AdamW over every model parameter, starting at `learning_rate`.
optimizer = torch.optim.AdamW(decoder_model.parameters(), lr=learning_rate)
# Cosine annealing with T_max = `total_batches` scheduler steps and a floor of 1e-6.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_batches, eta_min=1e-6)

In [6]:
# Define checkpoint paths and frequency
checkpoint_dir = "checkpoints"
checkpoint_frequency = 80000000
os.makedirs(checkpoint_dir, exist_ok=True)

# Try to load the latest checkpoint if exists
latest_checkpoint = os.path.join(checkpoint_dir, "latest_checkpoint.pt")

start_iter = 0
if os.path.exists(latest_checkpoint):
    # Best-effort resume: a failure is reported and training proceeds with
    # whatever state is currently in memory.
    try:
        start_iter = load_checkpoint(latest_checkpoint, decoder_model, optimizer)
        # Fast-forward the LR schedule in O(1) by rebuilding the scheduler at
        # the saved step. Replaying optimizer.step()/scheduler.step()
        # `start_iter` times would (a) take one Python iteration per past
        # step, (b) apply real parameter updates if any gradients are present,
        # and (c) anneal again from the already-decayed LR restored from the
        # checkpoint.
        for group in optimizer.param_groups:
            # A scheduler built with last_epoch != -1 requires 'initial_lr' in
            # every param group; it is absent if the checkpoint's optimizer
            # never had a scheduler attached.
            group.setdefault('initial_lr', learning_rate)
        # T_max / eta_min must match the scheduler created alongside the optimizer.
        # The constructor performs one step, leaving last_epoch == start_iter.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=total_batches,
            eta_min=1e-6,
            last_epoch=start_iter - 1,
        )
        print(f"Resuming training from iteration {start_iter}")
    except Exception as e:
        print(f"Failed to load checkpoint: {e}")


In [7]:
def get_batch(split, batch_size, seq_len, batch_idx=None):
    """
    Build one (input, target) batch of windows from the memory-mapped data file.

    Args:
        split: str - 'train' reads train.bin; any other value reads val.bin
        batch_size: int - Number of windows in the batch
        seq_len: int - Length of every window
        batch_idx: int | None - When None, window start offsets are drawn at
            random. Otherwise the batch covers the consecutive start offsets
            beginning at batch_idx * batch_size; if fewer than batch_size
            offsets remain, the missing slots use offset 0.

    Returns:
        (x, y): two int64 tensors of shape (batch_size, seq_len) on `device`,
        where y holds the same windows as x shifted one element to the right.
    """
    bin_name = 'train.bin' if split == 'train' else 'val.bin'
    # The memmap is re-opened on every call rather than cached.
    data = np.memmap(os.path.join(data_dir, bin_name), dtype=np.uint16, mode='r')
    total_positions = len(data) - seq_len

    if batch_idx is not None:
        first = batch_idx * batch_size
        last = min(first + batch_size, total_positions)
        ix = torch.arange(first, last)
        shortfall = batch_size - len(ix)
        if shortfall > 0:
            # Fill the tail of a short final batch with start offset 0.
            ix = torch.cat([ix, torch.zeros(shortfall, dtype=torch.long)])
    else:
        ix = torch.randint(total_positions, (batch_size,))

    starts = ix.tolist()

    def stack_windows(shift):
        # One int64 row per start offset, read from `data` at start + shift.
        rows = [
            torch.from_numpy(data[s + shift:s + shift + seq_len].astype(np.int64))
            for s in starts
        ]
        return torch.stack(rows)

    x = stack_windows(0)
    y = stack_windows(1)

    if device_type != 'cuda':
        return x.to(device), y.to(device)
    # On CUDA, pin host memory so the copies can be issued asynchronously.
    return (
        x.pin_memory().to(device, non_blocking=True),
        y.pin_memory().to(device, non_blocking=True),
    )

In [8]:
# total_iter = 0
# for batch_idx in range(total_batches):
#     total_iter += 1
#     # Get sequential batch
#     xb, yb = get_batch('train', batch_size, config.seq_len, batch_idx)
    
#     # evaluate the loss
#     logits, loss = decoder_model(xb, yb)
#     optimizer.zero_grad(set_to_none=True)
#     loss.backward()
#     norm = torch.nn.utils.clip_grad_norm_(decoder_model.parameters(), 1.0)
#     optimizer.step()
#     scheduler.step()
#     # Print loss after every step
#     print(f"batch {batch_idx}/{total_batches}: loss {loss.item():.4f}")
#     break

In [9]:
# Training loop over sequential batches, starting at `start_iter` (0 unless a
# checkpoint was loaded). `total_iter` counts completed iterations and is the
# value stored in each checkpoint.
total_iter = start_iter
for batch_idx in range(start_iter, total_batches):
    total_iter += 1
    # Get sequential batch
    xb, yb = get_batch('train', batch_size, config.seq_len, batch_idx)
    # evaluate the loss
    logits, loss = decoder_model(xb, yb)
    # Clear old gradients, backpropagate, clip the global grad norm to 1.0,
    # then take one optimizer step and one scheduler step.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    # NOTE(review): `norm` is not read anywhere else in this cell.
    norm = torch.nn.utils.clip_grad_norm_(decoder_model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()
    
    # Print loss after every step
    print(f"batch {batch_idx}/{total_batches}: loss {loss.item():.4f}, lr: {scheduler.get_last_lr()[0]:.6f}")
    
    # Save checkpoint at regular intervals
    # The condition also holds for batch_idx == 0, so the very first batch
    # writes a checkpoint, as does the final batch.
    if batch_idx % checkpoint_frequency == 0 or batch_idx == total_batches - 1:
        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_{batch_idx}.pt")
        save_checkpoint(decoder_model, optimizer, total_iter, checkpoint_path)
        # Also save as latest checkpoint
        save_checkpoint(decoder_model, optimizer, total_iter, latest_checkpoint)
        print(f"Saved checkpoint at iteration {total_iter}")
    # NOTE(review): this unconditional break ends the loop after one iteration,
    # so the cell is a single-step smoke test. Remove it to run a full training pass.
    break

torch.Size([2, 256, 512])
torch.Size([2, 256, 512])
torch.Size([2, 256, 512])
torch.Size([2, 256, 512])
batch 0/236232576: loss 78.6195, lr: 0.000600
Saved checkpoint at iteration 1


In [10]:
# Total element count over every parameter that has requires_grad set.
sum(p.numel() for p in decoder_model.parameters() if p.requires_grad)

22694400

In [12]:
def count_parameters(model, skip_embedding=True):
    """
    Count the trainable parameters of a PyTorch model, optionally leaving out
    those owned by custom Embedding layers.

    A parameter is treated as an embedding parameter when it belongs to an
    `Embedding` module (or any of its submodules). Ownership is decided by
    parameter identity, not by name prefix, so a sibling module whose name
    merely starts with the embedding's name (e.g. `emb` vs `emb_proj`) is
    still counted.

    Args:
        model: torch.nn.Module - The model to analyze
        skip_embedding: bool - Whether to exclude custom Embedding layers

    Returns:
        int: Total number of trainable parameters

    Example:
        model = Transformer(config)
        print(f"Non-embedding parameters: {count_parameters(model):,}")
        print(f"All parameters: {count_parameters(model, skip_embedding=False):,}")
    """
    # model.parameters() yields each shared parameter only once.
    trainable = [p for p in model.parameters() if p.requires_grad]
    if not skip_embedding:
        return sum(p.numel() for p in trainable)

    # Collect the identities of all parameters that live inside an Embedding.
    # A parameter shared between an Embedding and another layer is excluded too.
    embedding_param_ids = {
        id(p)
        for module in model.modules()
        if isinstance(module, Embedding)
        for p in module.parameters()
    }
    return sum(p.numel() for p in trainable if id(p) not in embedding_param_ids)

In [14]:
# With skip_embedding=False no layer is excluded from the count.
count_parameters(decoder_model,skip_embedding=False)

22694400