# Fine-tune on a High-Quality Dataset

This notebook demonstrates how to:
1.  Process a new raw MIDI dataset using the existing tokenizer.
2.  Load a pretrained model.
3.  Fine-tune the model on the new dataset.

In [None]:
import os
import torch
from pathlib import Path
from miditok import REMI
from trainer import DecoderOnlyTransformer, Trainer, get_optimizer, get_scheduler, get_dataloaders

# Ensure we are in the right directory
# os.chdir('c:\\Users\\mayda\\Desktop\\music-autocomplete') # Uncomment if needed

In [None]:
# Configuration

class Config:
    """Plain attribute container holding every setting of this fine-tuning run.

    All values are class attributes, so an instance created with ``Config()``
    exposes them directly (``config.learning_rate`` and so on).
    """

    # ----- Dataset locations -----
    raw_data_dir = "hq_dataset_raw"  # where the new .mid files live
    processed_data_dir = "hq_dataset_processed"  # where tokenized .json files are written
    tokenizer_file = "lmd_matched_tokenizer.json"  # same tokenizer the pretrained model used

    # ----- Fine-tuning run -----
    pretrained_model_path = "trained_models/full_run_v2/model.pt"
    run_name = "finetune_hq_v1"

    # ----- Training hyperparameters -----
    learning_rate = 1e-5  # fine-tuning usually uses a lower learning rate
    max_epochs = 10
    batch_size = 16

    # ----- Model architecture (must match the pretrained model) -----
    vocab_size = 5000
    block_size = 1024
    n_embed = 512
    n_head = 8
    n_blocks = 8
    dropout = 0.1

    # ----- Everything else -----
    num_songs = 1000  # number of songs to use from the new dataset
    val_split = 0.1
    optimizer_type = 'adamw'
    weight_decay = 0.01
    adam_beta1 = 0.9
    adam_beta2 = 0.95
    momentum = 0.9
    grad_clip = 1.0
    scheduler = 'cosine'
    min_learning_rate = 1e-6
    lr_step_size = 10
    lr_gamma = 0.1
    lr_patience = 3
    # Built once, at class-definition time, from the run_name defined above.
    checkpoint_path = 'trained_models/' + run_name + '/model.pt'
    load_checkpoint = False  # pretrained weights are loaded manually elsewhere
    compile = False

    @property
    def tokenized_dir(self):
        """Alias of ``processed_data_dir``, kept for get_dataloaders compatibility."""
        return self.processed_data_dir

config = Config()

# Make sure the tokenized-data folder and the checkpoint's parent folder exist
# before anything tries to write into them.
for _required_dir in (
    config.processed_data_dir,
    os.path.dirname(config.checkpoint_path),
):
    os.makedirs(_required_dir, exist_ok=True)

In [None]:
# 1. Process Data (Tokenization)

def process_data(config):
    """Tokenize the raw MIDI files found under ``config.raw_data_dir``.

    The raw directory is searched recursively for regular files whose
    extension is ``.mid`` or ``.midi`` (compared case-insensitively). The
    files are handed, in sorted order, to a REMI tokenizer built from
    ``config.tokenizer_file``, which writes its output to
    ``config.processed_data_dir``.

    Args:
        config: Object exposing ``raw_data_dir``, ``processed_data_dir`` and
            ``tokenizer_file`` attributes.

    Returns:
        True when tokenization was run. False when the raw directory does not
        exist, when it contains no MIDI files, or when the tokenizer file is
        missing; a message explaining the reason is printed in each case.
    """
    print(f"Checking for new data in {config.raw_data_dir}...")
    raw_path = Path(config.raw_data_dir)

    if not raw_path.exists():
        print(f"WARNING: Raw data directory '{config.raw_data_dir}' does not exist.")
        print("Please create it and put your .mid files there.")
        return False

    # Compare suffixes case-insensitively so '.midi', '.MID' and '.MIDI' are
    # picked up on case-sensitive filesystems too, skip directories that
    # merely look like MIDI files, and sort so the order is reproducible.
    midi_suffixes = {".mid", ".midi"}
    midi_files = sorted(
        path
        for path in raw_path.rglob("*")
        if path.suffix.lower() in midi_suffixes and path.is_file()
    )
    if not midi_files:
        print(f"WARNING: No .mid files found in '{config.raw_data_dir}'.")
        return False

    print(f"Found {len(midi_files)} MIDI files.")

    # Load Tokenizer
    if not os.path.exists(config.tokenizer_file):
        print(f"ERROR: Tokenizer file '{config.tokenizer_file}' not found.")
        return False

    print(f"Loading tokenizer from {config.tokenizer_file}...")
    tokenizer = REMI(params=config.tokenizer_file)

    # Tokenize
    print(f"Tokenizing files to {config.processed_data_dir}...")
    tokenizer.tokenize_dataset(
        files_paths=midi_files,
        out_dir=config.processed_data_dir
    )
    print("Tokenization complete.")
    return True

# Run processing and remember whether tokenization actually ran.
# Note: if the data has already been processed, re-running this is optional;
# running it again will overwrite/add to the existing files. Keep in mind that
# `data_ready` is only defined once this line has executed.
data_ready = process_data(config)

In [None]:
# 2. Initialize Model and Load Pretrained Weights

# Build the network with the architecture values from the config.
model = DecoderOnlyTransformer(
    vocab_size=config.vocab_size,
    n_embed=config.n_embed,
    n_head=config.n_head,
    n_blocks=config.n_blocks,
    block_size=config.block_size,
    dropout=config.dropout
)

print("Model initialized.")

if os.path.exists(config.pretrained_model_path):
    # A file under 1 KiB is treated as a probable Git LFS pointer stub rather
    # than real weights, so no load is attempted.
    if os.path.getsize(config.pretrained_model_path) < 1024:
        print("WARNING: Pretrained model file is very small. It might be a Git LFS pointer.")
        print("Please run 'git lfs pull' to download the actual weights.")
        # Say explicitly that the run is not starting from the checkpoint.
        print("WARNING: Pretrained weights were NOT loaded.")
    else:
        print(f"Loading pretrained weights from {config.pretrained_model_path}...")
        try:
            # NOTE(security): torch.load unpickles the file; only point
            # pretrained_model_path at checkpoints you trust.
            checkpoint = torch.load(config.pretrained_model_path, map_location='cpu')

            # Accept both a wrapped checkpoint ({'model_state_dict': ...}) and
            # a bare state dict saved directly from the model.
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
            else:
                state_dict = checkpoint

            # State dicts saved from a torch.compile()-wrapped model prefix
            # every key with '_orig_mod.'; strip it so the keys match this
            # uncompiled model. Checkpoints without the prefix are untouched.
            for _key in list(state_dict):
                if _key.startswith('_orig_mod.'):
                    state_dict[_key[len('_orig_mod.'):]] = state_dict.pop(_key)

            model.load_state_dict(state_dict)
        except Exception as e:
            # Notebook-level boundary: report the failure and make the
            # consequence explicit instead of continuing silently.
            print(f"Error loading weights: {e}")
            print("WARNING: Pretrained weights were NOT loaded.")
        else:
            print("Weights loaded successfully.")
else:
    print(f"WARNING: Pretrained model path {config.pretrained_model_path} does not exist.")
    print("Starting from scratch (random initialization).")

In [None]:
# 3. Setup Training

# The processing cell may have been skipped because the data was already
# tokenized; in that case `data_ready` was never assigned. Default it to False
# so the check below falls through to looking for existing .json files
# instead of raising NameError.
data_ready = globals().get("data_ready", False)

# Train when tokenization just ran, or when tokenized .json files are already
# present in the processed-data folder.
if data_ready or any(Path(config.processed_data_dir).glob("*.json")):
    print("Loading DataLoaders...")
    train_loader, val_loader = get_dataloaders(config)

    # One scheduler step per training batch across all epochs.
    total_steps = len(train_loader) * config.max_epochs
    optimizer = get_optimizer(model, config)
    scheduler = get_scheduler(optimizer, config, total_steps)

    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config
    )

    print("Starting Fine-tuning...")
    trainer.train()
else:
    print("No processed data found. Cannot start training.")