# Partial Model Loading and Training

This notebook demonstrates how to load the embeddings and the first $x$ decoder blocks of a pretrained model into a larger model, then train the remaining blocks (optionally keeping the loaded ones frozen) or fine-tune the whole network.

In [None]:
import os
import torch
import torch.nn as nn
from types import SimpleNamespace
from trainer import DecoderOnlyTransformer, Trainer, get_optimizer, get_scheduler, get_dataloaders

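# If running from elsewhere, change into the project root first so that the relative
# data/checkpoint paths below resolve. Note: to also affect `from trainer import ...`
# above, the chdir has to run before that import.
# os.chdir('c:\\Users\\mayda\\Desktop\\music-autocomplete') # Uncomment if needed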

In [None]:
def _extract_model_state_dict(checkpoint):
    """Return the model state dict held by a loaded checkpoint, or None if there is none.

    Accepts either a training checkpoint with a ``'model_state_dict'`` entry or a
    bare state dict. A wrapper prefix shared by *every* key (``'_orig_mod.'`` from a
    ``torch.compile``-d module, ``'module.'`` from ``DataParallel``/DDP) is stripped
    so the keys line up with the unwrapped target model.
    """
    if not isinstance(checkpoint, dict):
        return None
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    if not isinstance(state_dict, dict):
        return None
    for wrapper_prefix in ('_orig_mod.', 'module.'):
        # Only strip when all keys carry the prefix, so a real submodule named
        # e.g. `module` is never mangled.
        if state_dict and all(isinstance(k, str) and k.startswith(wrapper_prefix) for k in state_dict):
            state_dict = {k[len(wrapper_prefix):]: v for k, v in state_dict.items()}
    return state_dict


def _sub_state_dict(state_dict, prefix):
    """Return the entries of `state_dict` whose key starts with `prefix`, with the prefix removed."""
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


def _freeze(module):
    """Disable gradients for every parameter of `module`."""
    for param in module.parameters():
        param.requires_grad = False


def load_partial_weights(model, checkpoint_path, n_layers_to_load, freeze=False):
    """
    Loads weights from a checkpoint into the embeddings and the first `n_layers_to_load` blocks of the model.

    Loading is best-effort: every problem is reported with ``print`` and the
    affected part is skipped, so the notebook can continue.

    Args:
        model: The target model (DecoderOnlyTransformer). Must expose
            ``token_embedding``, ``position_embedding`` and an indexable ``blocks``.
        checkpoint_path: Path to the pretrained checkpoint. Either a dict with a
            ``'model_state_dict'`` entry or a bare state dict is accepted.
        n_layers_to_load: Number of blocks to load from the checkpoint.
        freeze: If True, freezes the parameters that were *successfully* loaded.
            Parts that failed to load stay trainable.

    Returns:
        int: Number of blocks successfully loaded (0 if the checkpoint could not be read).
    """
    if not os.path.exists(checkpoint_path):
        print(f"Checkpoint not found at {checkpoint_path}")
        return 0

    # A Git LFS pointer is a small text stub; a real checkpoint is far larger than 1 KiB.
    file_size = os.path.getsize(checkpoint_path)
    if file_size < 1024:
        print(f"WARNING: Checkpoint file {checkpoint_path} is very small ({file_size} bytes).")
        print("It might be a Git LFS pointer file. Please run 'git lfs pull' to download the actual model.")
        return 0

    print(f"Loading weights from {checkpoint_path}...")
    try:
        # Load to CPU first; load_state_dict later copies into the model's own parameters.
        # NOTE(review): since PyTorch 2.6 torch.load defaults to weights_only=True, so a
        # checkpoint that pickles non-tensor objects may fail here. Pass
        # weights_only=False only for files you trust.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
    except Exception as e:
        print(f"Error loading checkpoint (is it a valid PyTorch file?): {e}")
        return 0

    state_dict = _extract_model_state_dict(checkpoint)
    if state_dict is None:
        print(f"Error: no model state dict found in {checkpoint_path}.")
        return 0

    # 1. Embeddings (token & positional). Each table is loaded and frozen on its own,
    # so a failure (e.g. a vocab_size/block_size mismatch) never freezes a table that
    # still holds random initial weights.
    frozen_embeddings = 0
    for name in ('token_embedding', 'position_embedding'):
        try:
            embedding = getattr(model, name)
            embedding.load_state_dict(_sub_state_dict(state_dict, f'{name}.'))
        except Exception as e:
            print(f"Error loading embeddings ({name}): {e}")
            continue
        print(f"Loaded {name}.")
        if freeze:
            _freeze(embedding)
            frozen_embeddings += 1

    # 2. Decoder blocks. The blocks are in model.blocks (a Sequential), so checkpoint
    # keys look like 'blocks.0.sa.in_proj_weight', etc.
    loaded_blocks = 0
    for i in range(n_layers_to_load):
        if i >= len(model.blocks):
            print(f"Warning: Requested to load layer {i}, but model only has {len(model.blocks)} blocks. Stopping.")
            break

        block_state = _sub_state_dict(state_dict, f'blocks.{i}.')
        if not block_state:
            print(f"Warning: No weights found for block {i} in checkpoint.")
            continue

        try:
            model.blocks[i].load_state_dict(block_state)
        except Exception as e:
            print(f"Error loading block {i}: {e}")
            continue

        loaded_blocks += 1
        if freeze:
            _freeze(model.blocks[i])

    print(f"Successfully loaded {loaded_blocks} blocks.")
    if freeze:
        print(f"Frozen parameters for {frozen_embeddings} embedding table(s) and {loaded_blocks} loaded block(s).")
    return loaded_blocks

In [None]:
# Configuration
# Adjust these parameters as needed

class Config:
    """Settings for the partial-loading run, read as plain attributes (``config.<name>``).

    The same object is passed to the ``trainer`` helpers (dataloaders, optimizer,
    scheduler, Trainer); which attributes each helper reads is defined in trainer.py.
    """
    # Data
    tokenized_dir = "lmd_matched_processed"
    tokenizer_file = "lmd_matched_tokenizer.json"
    num_songs = 1000
    val_split = 0.1
    
    # Model Architecture (Target Model)
    vocab_size = 5000
    block_size = 1024
    n_embed = 512
    n_head = 16       # Target: 16 heads
    n_blocks = 16     # Target: 16 blocks
    dropout = 0.1
    
    # Training
    learning_rate = 3e-4
    max_epochs = 5
    batch_size = 32
    optimizer_type = 'adamw'
    weight_decay = 0.01
    adam_beta1 = 0.9
    adam_beta2 = 0.95
    momentum = 0.9
    grad_clip = 1.0
    
    # Scheduler
    scheduler = 'cosine'
    min_learning_rate = 3e-5
    # NOTE(review): presumably only used by non-cosine schedulers; confirm in get_scheduler.
    lr_step_size = 10
    lr_gamma = 0.1
    lr_patience = 3
    
    # Checkpointing
    checkpoint_path = 'trained_models/partial_run/model.pt'
    load_checkpoint = False # We handle loading manually
    run_name = 'partial_run_16x16'
    compile = False
    
    # Partial Loading Settings
    pretrained_path = 'trained_models/full_run_v2/model.pt' # Path to source model
    n_layers_to_load = 8
    freeze_loaded = True

config = Config()

# Ensure the checkpoint directory exists. os.path.dirname returns '' for a bare
# filename (e.g. 'model.pt'), and os.makedirs('') raises, so skip that case.
checkpoint_dir = os.path.dirname(config.checkpoint_path)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

In [None]:
# 1. Initialize the Model
# Architecture hyperparameters come straight from the Config above.
model_kwargs = dict(
    vocab_size=config.vocab_size,
    n_embed=config.n_embed,
    n_head=config.n_head,
    n_blocks=config.n_blocks,
    block_size=config.block_size,
    dropout=config.dropout,
)
model = DecoderOnlyTransformer(**model_kwargs)

print(f"Model created with {model_kwargs['n_blocks']} blocks.")

In [None]:
# 2. Load Partial Weights
# Copy the pretrained embeddings and the first `config.n_layers_to_load` blocks into
# `model`; with `config.freeze_loaded` set, the loaded parameters stop receiving gradients.
load_partial_weights(
    model=model,
    checkpoint_path=config.pretrained_path,
    n_layers_to_load=config.n_layers_to_load,
    freeze=config.freeze_loaded,
)

In [None]:
# 3. Setup Data, Optimizer, and Trainer
# The loaders come first because the scheduler is sized by the total step count
# (presumably one optimizer step per training batch — confirm in trainer.py).
train_loader, val_loader = get_dataloaders(config)
steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * config.max_epochs

# NOTE(review): get_optimizer receives the whole model, including frozen
# (requires_grad=False) parameters; confirm it skips or tolerates them.
optimizer = get_optimizer(model, config)
scheduler = get_scheduler(optimizer, config, total_steps)

trainer_kwargs = dict(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
)
trainer = Trainer(**trainer_kwargs)

In [None]:
# 4. Run Training
# Hand off to Trainer.train(), built in the previous cell. The loop itself (epochs,
# validation, checkpoint saving — presumably to config.checkpoint_path) is defined
# in trainer.py and is not visible here.
trainer.train()