# ChordFormer Training on Google Colab

This notebook trains the ChordFormer chord recognition model using pre-extracted features from Google Drive.

## 1. Mount Google Drive

In [None]:
# Attach Google Drive at /content/drive; DRIVE_BASE in the next section
# is built on top of this mount point.
from google.colab import drive

drive.mount('/content/drive')

## 2. Set Paths

**Update these paths to match your Google Drive structure!**

In [None]:
# ---------------------------------------------------------------------------
# UPDATE THESE PATHS so they match where things live in your Google Drive.
# ---------------------------------------------------------------------------

# Root of "My Drive" once Drive is mounted at /content/drive.
DRIVE_BASE = '/content/drive/MyDrive'

# Pre-extracted features: song_*.npz files plus normalization.json.
FEATURES_DIR = f'{DRIVE_BASE}/features'

# Where training writes checkpoints; created by the next cell if it does not exist yet.
CHECKPOINTS_DIR = f'{DRIVE_BASE}/checkpoints'

# Project source (the `config`, `train` and `export_onnx` modules); added to sys.path later.
CODE_DIR = f'{DRIVE_BASE}/chord_recognition'

# Echo the resolved locations so typos are easy to spot.
print(f'Features dir: {FEATURES_DIR}')
print(f'Checkpoints dir: {CHECKPOINTS_DIR}')
print(f'Code dir: {CODE_DIR}')

## 3. Verify Paths and Check Data

In [None]:
import os

# Check features directory: it should hold the per-song .npz files and the
# normalization.json statistics used when building the datasets.
# (The listing is named `feature_dir_entries` rather than `files` so it does
# not shadow/clobber `google.colab.files`, which the download cell imports.)
if os.path.exists(FEATURES_DIR):
    feature_dir_entries = os.listdir(FEATURES_DIR)
    npz_files = [name for name in feature_dir_entries if name.endswith('.npz')]
    print(f'✓ Features directory found: {len(npz_files)} .npz files')
    if not npz_files:
        print('✗ WARNING: no .npz feature files found!')

    # Check for normalization.json
    if 'normalization.json' in feature_dir_entries:
        print('✓ normalization.json found')
    else:
        print('✗ WARNING: normalization.json not found!')
else:
    print(f'✗ ERROR: Features directory not found at {FEATURES_DIR}')
    print('Please update FEATURES_DIR path above')

# Check/create checkpoints directory (idempotent thanks to exist_ok=True)
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
print(f'✓ Checkpoints directory ready: {CHECKPOINTS_DIR}')

# Check code directory; train.py is required by the import cell below, so a
# missing file is reported instead of being silently skipped.
if os.path.exists(CODE_DIR):
    print('✓ Code directory found')
    if os.path.exists(os.path.join(CODE_DIR, 'train.py')):
        print('✓ train.py found')
    else:
        print(f'✗ ERROR: train.py not found in {CODE_DIR}')
else:
    print(f'✗ ERROR: Code directory not found at {CODE_DIR}')

## 4. Install Dependencies

In [None]:
!pip install torch numpy tqdm scipy mir_eval -q
print('✓ Dependencies installed')

## 5. Check GPU

In [None]:
import torch

# Report which accelerator this runtime provides; training on CPU works but is slow.
if not torch.cuda.is_available():
    print('✗ WARNING: No GPU available! Training will be slow.')
    print('Go to Runtime > Change runtime type > GPU')
else:
    gpu_name = torch.cuda.get_device_name(0)
    total_bytes = torch.cuda.get_device_properties(0).total_memory
    gpu_memory = total_bytes / 1e9  # decimal gigabytes
    print(f'✓ GPU available: {gpu_name} ({gpu_memory:.1f} GB)')

## 6. Add Code Directory to Path

In [None]:
import sys

# Make the project modules (config, train, export_onnx) importable. The guard
# keeps re-runs of this cell from piling duplicate entries onto sys.path.
if CODE_DIR not in sys.path:
    sys.path.insert(0, CODE_DIR)

# Test imports. Every later cell needs `config` and the train helpers, so an
# import failure is reported and then re-raised: swallowing it would only
# surface later as a confusing NameError on `config`.
try:
    import config
    from train import ChordDataset, create_chordformer_model, train_chordformer
except ImportError as e:
    print(f'✗ Import error: {e}')
    print(f'  Check that CODE_DIR ({CODE_DIR}) contains the project code')
    raise

print('✓ Imports successful')
print(f'  Model type: {config.MODEL_TYPE}')
print(f'  Batch size: {config.BATCH_SIZE}')
print(f'  Learning rate: {config.LEARNING_RATE}')

## 7. Load Dataset

In [None]:
import numpy as np
import json
from torch.utils.data import DataLoader

# Load the per-bin normalization statistics stored next to the features.
norm_path = os.path.join(FEATURES_DIR, 'normalization.json')
with open(norm_path, 'r') as f:
    norm_data = json.load(f)

mean = np.array(norm_data['mean'], dtype=np.float32)
std = np.array(norm_data['std'], dtype=np.float32)
# Fail fast on a malformed normalization file rather than deep inside the dataset.
if mean.shape != std.shape:
    raise ValueError(
        f'normalization.json mean/std shape mismatch: {mean.shape} vs {std.shape}'
    )
print(f'Normalization loaded: {len(mean)} bins')

# Find all feature files. Sorting makes the seeded shuffle below independent
# of os.listdir ordering, so the split is reproducible.
feature_files = sorted(
    os.path.join(FEATURES_DIR, name)
    for name in os.listdir(FEATURES_DIR)
    if name.endswith('.npz')
)
if not feature_files:
    raise FileNotFoundError(f'No .npz feature files found in {FEATURES_DIR}')
print(f'Found {len(feature_files)} songs')

# Split into train/val/test; the test split receives whatever remains after
# the train and validation fractions are taken.
n_songs = len(feature_files)
n_train = int(n_songs * config.TRAIN_RATIO)
n_val = int(n_songs * config.VAL_RATIO)

# Shuffle with fixed seed for reproducibility. Note this seeds NumPy's global
# RNG, so later np.random use in this kernel is affected as well.
np.random.seed(42)
indices = np.random.permutation(n_songs)

train_files = [feature_files[i] for i in indices[:n_train]]
val_files = [feature_files[i] for i in indices[n_train:n_train + n_val]]
test_files = [feature_files[i] for i in indices[n_train + n_val:]]

if not train_files:
    raise ValueError(
        f'Training split is empty ({n_songs} songs, TRAIN_RATIO={config.TRAIN_RATIO})'
    )
if not val_files:
    print('✗ WARNING: validation split is empty!')

print(f'Train: {len(train_files)}, Val: {len(val_files)}, Test: {len(test_files)}')

## 8. Create Data Loaders

In [None]:
# Create datasets; only the training set is augmented.
train_dataset = ChordDataset(
    train_files,
    mean, std,
    sequence_length=config.SEQUENCE_LENGTH,
    augment=True
)

val_dataset = ChordDataset(
    val_files,
    mean, std,
    sequence_length=config.SEQUENCE_LENGTH,
    augment=False
)

# Settings shared by both loaders. Pinned host memory only speeds up
# host->GPU copies; on a CPU-only runtime PyTorch warns and ignores it, so it
# is enabled only when CUDA is available.
NUM_WORKERS = 2
loader_kwargs = dict(
    batch_size=config.BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
)

# Create data loaders (shuffle only the training data)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print(f'Train batches: {len(train_loader)}')
print(f'Val batches: {len(val_loader)}')

# Test one batch (assumes ChordDataset items are dicts with a 'features' key)
batch = next(iter(train_loader))
print(f'Batch shape: {batch["features"].shape}')

## 9. Create Model

In [None]:
# Run on the GPU when the runtime has one, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

# Build the ChordFormer and move its weights to the chosen device
# (nn.Module.to returns the module itself).
model = create_chordformer_model().to(device)

# Report the number of trainable parameters (frozen tensors are excluded).
n_params = sum(
    param.numel()
    for param in model.parameters()
    if param.requires_grad
)
print(f'Model parameters: {n_params:,}')

## 10. Train Model

In [None]:
# Number of epochs; defaults to the project config. Replace with a literal
# (e.g. 50 or 100) for a custom run length.
EPOCHS = config.NUM_EPOCHS

print(f'Starting training for {EPOCHS} epochs...')
print(f'Checkpoints will be saved to: {CHECKPOINTS_DIR}')
print('-' * 50)

# NOTE(review): `model` is passed in with its current weights; re-running this
# cell without re-running "Create Model" presumably continues from the
# already-trained weights rather than starting fresh — confirm in train.py.
training_args = dict(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    num_epochs=EPOCHS,
    checkpoint_dir=CHECKPOINTS_DIR,
    early_stopping_patience=config.EARLY_STOPPING_PATIENCE,
)
best_model_path = train_chordformer(**training_args)

print(f'\n✓ Training complete!')
print(f'Best model saved to: {best_model_path}')

## 11. Export to ONNX (Optional)

In [None]:
# Restore the best checkpoint from training, then export it to ONNX.
from export_onnx import export_chordformer_onnx

onnx_path = os.path.join(CHECKPOINTS_DIR, 'chord_model.onnx')

# The checkpoint is a dict whose model weights live under 'model_state_dict'.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True; if the
# checkpoint stores non-tensor objects, loading may fail — confirm with train.py.
checkpoint = torch.load(best_model_path, map_location=device)
best_state = checkpoint['model_state_dict']
model.load_state_dict(best_state)
model.eval()  # switch to inference mode before exporting

export_chordformer_onnx(model, onnx_path)
print(f'✓ ONNX model exported to: {onnx_path}')

## 12. Download Model

Run this cell to download the best training checkpoint to your local machine. The ONNX model is also downloaded if you ran the optional export in step 11.

In [None]:
from google.colab import files

# Resolve the artifact paths defensively:
# - `best_model_path` exists only if the training cell ran in this session;
# - `onnx_path` exists only if the optional ONNX export cell ran, so fall back
#   to the same default location that cell writes to.
ckpt_path = globals().get('best_model_path')
export_path = globals().get('onnx_path') or os.path.join(CHECKPOINTS_DIR, 'chord_model.onnx')

# Download the best checkpoint
if ckpt_path and os.path.exists(ckpt_path):
    files.download(ckpt_path)
    print(f'Downloading: {ckpt_path}')
else:
    print('✗ No trained checkpoint found; run the training cell first.')

# Download ONNX model if it exists
if os.path.exists(export_path):
    files.download(export_path)
    print(f'Downloading: {export_path}')