# Notebook 3: Model Development & Training

**Objective:** Design and train a deep learning model that predicts both drum onsets and velocities from mel spectrograms.

This notebook implements our drum transcription neural network architecture, loss functions, and training pipeline. The model has dual heads for both onset detection and velocity prediction, allowing it to capture not just when a drum is hit but also how hard it's played.

## 1. Imports and Setup

Loading required libraries for model development, training, and evaluation.

In [None]:
# Let PyTorch's CUDA caching allocator use expandable memory segments, which
# reduces out-of-memory failures caused by fragmentation. The allocator reads
# this variable when CUDA is first initialised, so it is set before torch is imported.
import os
os.environ.update({'PYTORCH_CUDA_ALLOC_CONF': 'expandable_segments:True'})

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
import random
import os
import time
import sys
from sklearn.metrics import precision_recall_fscore_support


# Reproducibility: seed Python, NumPy and PyTorch (and every CUDA device if present).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Filesystem layout, relative to the notebook's working directory.
DATA_DIR = Path("..", "data")
PROCESSED_DATA_DIR = DATA_DIR.joinpath("processed")
MODEL_SAVE_DIR = Path("..", "models")
MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)

# General MIDI percussion note number -> drum label (same as in previous notebooks).
GM_DRUM_MAPPING = {
    36: "Kick",
    38: "Snare",
    42: "HiHat",
    47: "Tom",
    49: "Crash",
    51: "Ride",
}

# Derived drum constants; list order follows the mapping's insertion order.
MAIN_DRUMS = [*GM_DRUM_MAPPING]
MAIN_DRUM_NAMES = [*GM_DRUM_MAPPING.values()]
N_DRUMS = len(GM_DRUM_MAPPING)

# Plot styling (applied in this order; seaborn's theme is set last).
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({"figure.figsize": (12, 8)})
sns.set_theme(style="whitegrid")

def save_checkpoint(model, optimizer, scheduler, epoch, train_losses, val_losses, f1_scores,
                    best_val_loss, patience_counter, filename):
    """Save a complete checkpoint that can be used to resume training.

    Args:
        model: Module whose ``state_dict`` is stored.
        optimizer: Optimizer whose ``state_dict`` is stored.
        scheduler: LR scheduler whose state is stored, or ``None`` (stored as ``None``).
        epoch: Index of the epoch that just finished (resuming starts at ``epoch + 1``).
        train_losses: Per-epoch training losses so far.
        val_losses: Per-epoch validation losses so far.
        f1_scores: Per-epoch validation F1 scores so far.
        best_val_loss: Best validation loss seen so far.
        patience_counter: Consecutive epochs without validation improvement.
        filename: Destination path passed to ``torch.save``.
    """
    # Metrics are stored as plain Python numbers. sklearn/NumPy produce np.float64,
    # which torch.load(weights_only=True) -- the default in recent PyTorch --
    # refuses to unpickle, making the checkpoint impossible to resume from.
    torch.save({
        'epoch': int(epoch),
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
        'train_losses': [float(x) for x in train_losses],
        'val_losses': [float(x) for x in val_losses],
        'f1_scores': [float(x) for x in f1_scores],
        'best_val_loss': float(best_val_loss),
        'patience_counter': int(patience_counter),
    }, filename)
    print(f"Checkpoint saved to {filename}")

def load_checkpoint(filename, model, optimizer, scheduler=None, map_location=None):
    """Load a checkpoint written by ``save_checkpoint`` to resume training.

    Args:
        filename: Path of the checkpoint file.
        model: Module restored in place from ``model_state_dict``.
        optimizer: Optimizer restored in place from ``optimizer_state_dict``.
        scheduler: Optional LR scheduler, restored in place when the checkpoint
            contains scheduler state (it is ``None`` when saved without one).
        map_location: Passed to ``torch.load``; defaults to the notebook-level ``device``.

    Returns:
        Tuple ``(start_epoch, train_losses, val_losses, f1_scores, best_val_loss,
        patience_counter)``. When ``filename`` does not exist, returns the fresh-start
        values ``(0, [], [], [], inf, 0)``.
    """
    if not Path(filename).exists():
        print(f"Checkpoint {filename} not found. Starting from scratch.")
        return 0, [], [], [], float('inf'), 0

    print(f"Loading checkpoint from {filename}")
    if map_location is None:
        map_location = device
    # NOTE: recent PyTorch defaults torch.load to weights_only=True, which rejects
    # NumPy scalars; older checkpoints whose metric lists hold np.float64 may fail here.
    checkpoint = torch.load(filename, map_location=map_location)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # save_checkpoint writes None when no scheduler was used; passing that to
    # load_state_dict would raise, so only restore real scheduler state.
    scheduler_state = checkpoint.get('scheduler_state_dict')
    if scheduler is not None and scheduler_state is not None:
        scheduler.load_state_dict(scheduler_state)

    start_epoch = checkpoint['epoch'] + 1  # Resume from the epoch after the saved one
    return (
        start_epoch,
        checkpoint['train_losses'],
        checkpoint['val_losses'],
        checkpoint['f1_scores'],
        checkpoint['best_val_loss'],
        checkpoint['patience_counter'],
    )

## 2. Create Dataset and DataLoader

Loading our preprocessed training examples (.npz files) and setting up the PyTorch data pipeline.

In [None]:
# Helper for quick pipeline smoke tests on a small, reproducible subset of a split.
def create_small_dataset(data_dir, split, max_files=20):
    """Return up to ``max_files`` randomly chosen ``.npz`` paths from ``data_dir/split``.

    Args:
        data_dir: Root directory containing one sub-directory per split.
        split: Split sub-directory name (e.g. ``'train'``).
        max_files: Maximum number of files to return.

    Returns:
        List of ``Path`` objects. When the split has ``max_files`` files or fewer,
        all of them are returned (sorted); otherwise a ``random.sample`` of them.
    """
    full_dir = Path(data_dir) / split
    # Sort first so the sample depends only on the RNG seed, not on the
    # filesystem-dependent order in which glob yields entries.
    all_files = sorted(full_dir.glob("*.npz"))

    print(f"Found {len(all_files)} total files for {split}")
    if len(all_files) <= max_files:
        return all_files

    sampled_files = random.sample(all_files, max_files)
    print(f"Sampled {len(sampled_files)} files for {split} toy dataset")
    return sampled_files

# PyTorch Dataset over preprocessed .npz training examples.
class DrumTranscriptionDataset(Dataset):
    """Dataset for loading preprocessed drum transcription examples.

    Each ``.npz`` file must contain ``mel_spec``, ``onset_target`` and
    ``velocity_target`` arrays; items are returned as float32 tensors.
    """

    def __init__(self, data_dir, split="train", toy_mode=False, max_files=20):
        """
        Initialize the dataset.

        Args:
            data_dir: Directory containing the processed data (one sub-directory per split)
            split: Which dataset split to use ('train', 'validation', or 'test')
            toy_mode: If True, use only a small random subset of files
            max_files: Maximum number of files to use in toy mode

        Raises:
            ValueError: If no ``.npz`` files are found for the split.
        """
        self.data_dir = Path(data_dir) / split

        # NOTE(review): not read anywhere in this class; kept for any external code using it.
        self.use_checkpoint = True

        if toy_mode:
            self.file_list = create_small_dataset(data_dir, split, max_files)
        else:
            # Sorted so example order is deterministic across machines and runs.
            self.file_list = sorted(self.data_dir.glob("*.npz"))

        if len(self.file_list) == 0:
            raise ValueError(f"No .npz files found in {self.data_dir}")

        print(f"Using {len(self.file_list)} examples in {split} set")

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        path = self.file_list[idx]

        # np.load on an .npz returns an NpzFile that keeps the archive open until
        # closed; the context manager releases it (astype copies the arrays out first).
        with np.load(path) as data:
            mel_spec = data["mel_spec"].astype(np.float32)  # [n_mels, n_frames]
            onset_target = data["onset_target"].astype(np.float32)  # [n_drums, n_frames]
            velocity_target = data["velocity_target"].astype(np.float32)  # [n_drums, n_frames]

        # Per-example standardisation; the epsilon guards constant spectrograms.
        mel_spec = (mel_spec - mel_spec.mean()) / (mel_spec.std() + 1e-8)

        return {
            "input": torch.from_numpy(mel_spec),
            "onset_target": torch.from_numpy(onset_target),
            "velocity_target": torch.from_numpy(velocity_target),
            "file_path": str(path)
        }

def collate_fn(batch):
    """
    Stack variable-length examples into zero-padded batch tensors.

    Each example is a dict with ``input`` [n_mels, T_i], ``onset_target`` and
    ``velocity_target`` [n_drums, T_i], and ``file_path``. Every tensor is
    right-padded with zeros along time to the longest T_i in the batch.

    Returns:
        Dict with ``input`` [B, n_mels, T_max], ``onset_target`` and
        ``velocity_target`` [B, n_drums, T_max] (float32), and ``file_paths``.
    """
    longest = max(example["input"].shape[1] for example in batch)
    first = batch[0]
    n_examples = len(batch)
    n_bins = first["input"].shape[0]
    n_classes = first["onset_target"].shape[0]

    padded = {
        "input": torch.zeros((n_examples, n_bins, longest)),
        "onset_target": torch.zeros((n_examples, n_classes, longest)),
        "velocity_target": torch.zeros((n_examples, n_classes, longest)),
    }

    # Copy each example into the left-aligned prefix of its row.
    for row, example in enumerate(batch):
        n_frames = example["input"].shape[1]
        for key, buffer in padded.items():
            buffer[row, :, :n_frames] = example[key]

    padded["file_paths"] = [example["file_path"] for example in batch]
    return padded


In [None]:
# One Dataset per split.
train_dataset = DrumTranscriptionDataset(PROCESSED_DATA_DIR, split="train")
val_dataset = DrumTranscriptionDataset(PROCESSED_DATA_DIR, split="validation")
test_dataset = DrumTranscriptionDataset(PROCESSED_DATA_DIR, split="test")

# Loader settings shared by every split; only shuffling differs between them.
# NOTE: with batch_size > 1, collate_fn zero-pads shorter examples and the loss
# does not mask those padded frames.
batch_size = 1
num_workers = 1
pin_memory = False
_loader_kwargs = {
    "batch_size": batch_size,
    "collate_fn": collate_fn,
    "num_workers": num_workers,
    "pin_memory": pin_memory,
}

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

### 2.1 Inspect a Batch

Let's examine a batch of data to understand the input and target formats.

In [None]:
# Take the first training batch, report tensor shapes and plot its first example.
for batch in train_loader:
    inputs = batch["input"]                        # [batch_size, n_mels, n_frames]
    onset_targets = batch["onset_target"]          # [batch_size, n_drums, n_frames]
    velocity_targets = batch["velocity_target"]    # [batch_size, n_drums, n_frames]

    for shape_label, shown_tensor in (("Input", inputs),
                                      ("Onset target", onset_targets),
                                      ("Velocity target", velocity_targets)):
        print(f"{shape_label} shape: {shown_tensor.shape}")

    idx = 0  # first example in the batch
    plt.figure(figsize=(15, 9))

    # (tensor, colormap, title) for each stacked panel, top to bottom.
    panel_specs = [
        (inputs, 'viridis', 'Mel Spectrogram'),
        (onset_targets, 'Reds', 'Drum Onsets'),
        (velocity_targets, 'Blues', 'Drum Velocities (Normalized 0-1)'),
    ]
    for panel_pos, (panel_data, panel_cmap, panel_title) in enumerate(panel_specs, start=1):
        plt.subplot(3, 1, panel_pos)
        plt.imshow(panel_data[idx].numpy(), aspect='auto', origin='lower', cmap=panel_cmap)
        plt.colorbar()
        plt.title(panel_title)
        if panel_pos == 1:
            plt.ylabel('Mel Frequency Bin')
        else:
            plt.yticks(np.arange(N_DRUMS), MAIN_DRUM_NAMES)
    plt.xlabel('Time Frames')  # applies to the last (bottom) panel

    plt.tight_layout()
    plt.show()

    # One batch is enough for inspection.
    break

## 3. Define Model Architecture

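Create our drum transcription model. A CNN front end, optionally followed by a bidirectional LSTM for temporal context, feeds two output heads: one predicts onsets and the other predicts velocities.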

In [None]:
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU.

    With the default 3x3 kernel, stride 1 and padding 1 the spatial size of the
    input is preserved; only the channel count changes.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class DrumTranscriptionModel(nn.Module):
    """CNN (+ optional bidirectional LSTM) drum transcriber with dual output heads.

    Input ``[B, n_mels, T]`` passes through four ConvBlocks, each followed by
    2x max-pooling over frequency only, so the time resolution T is preserved.
    Per-frame features then go through an optional BiLSTM and two per-frame MLP
    heads: onset logits and non-negative velocities.
    """

    def __init__(self, n_mels=229, n_drums=6, use_lstm=True, dropout=0.3):
        """
        Initialize the model.

        Args:
            n_mels: Number of mel frequency bands in input (at least 16)
            n_drums: Number of drum types to detect
            use_lstm: Whether to use a bidirectional LSTM (256 units per direction)
            dropout: Dropout probability after the LSTM and inside both heads

        Raises:
            ValueError: If ``n_mels`` < 16 (four 2x poolings would leave no frequency bins).
        """
        super(DrumTranscriptionModel, self).__init__()
        if n_mels < 16:
            raise ValueError(f"n_mels must be >= 16, got {n_mels}")
        self.n_mels = n_mels
        self.n_drums = n_drums
        self.use_lstm = use_lstm

        # Feature extraction: convolutional layers
        self.conv_stack = nn.Sequential(
            # Layer 1: [B, 1, n_mels, T] -> [B, 32, n_mels//2, T]
            ConvBlock(1, 32),
            nn.MaxPool2d(kernel_size=(2, 1)),  # Frequency pooling

            # Layer 2: [B, 32, n_mels//2, T] -> [B, 64, n_mels//4, T]
            ConvBlock(32, 64),
            nn.MaxPool2d(kernel_size=(2, 1)),  # Frequency pooling

            # Layer 3: [B, 64, n_mels//4, T] -> [B, 128, n_mels//8, T]
            ConvBlock(64, 128),
            nn.MaxPool2d(kernel_size=(2, 1)),  # Frequency pooling

            # Layer 4: [B, 128, n_mels//8, T] -> [B, 128, n_mels//16, T]
            ConvBlock(128, 128),
            nn.MaxPool2d(kernel_size=(2, 1))   # Frequency pooling
        )

        # Four successive floor-halvings of the frequency axis equal n_mels // 16.
        self.cnn_output_freq_dim = n_mels // 16
        self.cnn_output_channels = 128
        self.cnn_output_dim = self.cnn_output_channels * self.cnn_output_freq_dim

        # Optional: Bi-directional LSTM layer for temporal modeling
        if self.use_lstm:
            lstm_layers = 1
            self.lstm = nn.LSTM(
                input_size=self.cnn_output_dim,
                hidden_size=256, num_layers=lstm_layers, batch_first=True,
                bidirectional=True,
                # nn.LSTM only applies its dropout *between* stacked layers and
                # warns when dropout > 0 with a single layer, so pass it only for stacks.
                dropout=dropout if lstm_layers > 1 else 0.0,
            )
            # With a single layer, apply dropout to the LSTM output instead.
            self.lstm_dropout = nn.Dropout(dropout) if dropout > 0 and lstm_layers == 1 else nn.Identity()
            feature_dim = 512  # bidirectional LSTM output (2 * hidden_size)
        else:
            feature_dim = self.cnn_output_dim
            self.lstm = None
            self.lstm_dropout = nn.Identity()

        # Onset detection head: outputs raw logits (sigmoid is applied by the
        # BCE-with-logits loss during training and explicitly at inference).
        self.onset_head = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, n_drums),
        )

        # Velocity prediction head
        self.velocity_head = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, n_drums),
            # Final ReLU keeps predicted velocities non-negative.
            nn.ReLU()
        )

    def forward(self, x):
        """
        Forward pass through the network.

        Args:
            x: Input tensor of shape [batch_size, n_mels, n_frames]

        Returns:
            Tuple ``(onset_logits, velocity_pred)``, each [batch_size, n_drums, n_frames].
            ``onset_logits`` are pre-sigmoid; ``velocity_pred`` is >= 0.

        Raises:
            ValueError: If the mel dimension is incompatible with this model's ``n_mels``.
        """
        batch_size, n_mels, n_frames = x.shape
        x = x.unsqueeze(1)  # [B, 1, n_mels, T]

        # CNN forward
        x = self.conv_stack(x)  # [B, C, n_mels_reduced, T]
        if x.shape[2] != self.cnn_output_freq_dim:
            raise ValueError(
                f"Input has {n_mels} mel bins, incompatible with model built for n_mels={self.n_mels}"
            )

        # Reshape for RNN/Linear: combine channel and frequency dims per frame
        x = x.permute(0, 3, 1, 2)  # [B, T, C, n_mels_reduced]
        x = x.reshape(batch_size, n_frames, self.cnn_output_dim)  # [B, T, C * n_mels_reduced]

        # Optional LSTM forward
        if self.use_lstm:
            x, _ = self.lstm(x)  # [B, T, 2 * hidden_size]
            x = self.lstm_dropout(x)

        # Output heads forward
        onset_logits = self.onset_head(x)  # [B, T, n_drums] - logits
        velocity_pred = self.velocity_head(x)  # [B, T, n_drums] - velocities >= 0

        # Reshape to match target format [B, n_drums, T]
        onset_logits = onset_logits.transpose(1, 2)
        velocity_pred = velocity_pred.transpose(1, 2)

        return onset_logits, velocity_pred


## 4. Define Loss Functions

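We'll define a combined loss function with two terms. For onset detection, it uses binary cross-entropy on the onset logits, with onset frames up-weighted to offset their rarity. For velocity prediction, it uses mean squared error, computed only at frames that contain a ground-truth onset.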

In [None]:
def combined_loss_function(
    onset_pred,
    velocity_pred,
    onset_target,
    velocity_target,
    onset_weight=0.8,
    positive_weight=10.0
):
    """
    Combined loss for both onset detection and velocity prediction.

    Args:
        onset_pred: Onset logits (pre-sigmoid) [B, n_drums, T]
        velocity_pred: Velocity predictions [B, n_drums, T]
        onset_target: Onset targets [B, n_drums, T]; entries > 0 count as onsets
        velocity_target: Velocity targets [B, n_drums, T]
        onset_weight: Weight for onset loss (velocity_weight = 1 - onset_weight)
        positive_weight: Per-element BCE weight where onset_target > 0
            (all other elements get weight 1), countering the scarcity of onset frames

    Returns:
        Tuple of (combined_loss, onset_loss, velocity_loss). velocity_loss is 0
        when the targets contain no onsets.
    """
    onset_mask = onset_target > 0

    # Up-weight the onset frames in the (mean-reduced) BCE term.
    weights = torch.ones_like(onset_target).masked_fill(onset_mask, positive_weight)
    onset_loss = F.binary_cross_entropy_with_logits(onset_pred, onset_target, weight=weights)

    # Velocity is only supervised where a ground-truth onset exists.
    if onset_mask.any():
        velocity_loss = F.mse_loss(velocity_pred[onset_mask], velocity_target[onset_mask])
    else:
        velocity_loss = torch.tensor(0.0, device=onset_pred.device)

    velocity_weight = 1.0 - onset_weight
    combined = onset_weight * onset_loss + velocity_weight * velocity_loss

    return combined, onset_loss, velocity_loss

## 5. Training Loop

Define functions to train the model and evaluate it on validation data.

In [None]:
# Custom initialisation for fully-connected layers.
def init_weights(m):
    """Initialise an ``nn.Linear`` module in place; any other module is left untouched.

    Weights receive Xavier/Glorot-uniform initialisation and the bias (when
    present) is filled with 0.1. The single-module signature makes it usable
    with ``model.apply(init_weights)``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
    if m.bias is None:
        return
    nn.init.constant_(m.bias, 0.1)

In [None]:
def train_epoch(model, train_loader, optimizer, device, onset_weight=0.8, scaler=None, accum_steps=4,
                max_grad_norm=None):
    """Train the model for one epoch with gradient accumulation and optional mixed precision.

    Args:
        model: Model returning ``(onset_logits, velocity_pred)``.
        train_loader: Yields dicts with ``input``, ``onset_target`` and ``velocity_target``.
        optimizer: Stepped once every ``accum_steps`` batches and after the final batch.
        device: Device (or device string) the batches are moved to. Autocast is
            enabled only when this is a CUDA device.
        onset_weight: Onset/velocity trade-off passed to ``combined_loss_function``.
        scaler: Optional ``torch.amp.GradScaler``. When ``None``, plain
            ``loss.backward()`` / ``optimizer.step()`` are used.
        accum_steps: Number of batches whose gradients are accumulated per optimizer step.
        max_grad_norm: If set, gradients are clipped to this total L2 norm before each step.

    Returns:
        Dict with the mean ``loss``, ``onset_loss`` and ``velocity_loss`` over the
        epoch (0.0 each for an empty loader).
    """
    model.train()
    # Autocast only on CUDA; other devices run in full precision instead of
    # requesting CUDA autocast that cannot be honoured there.
    device_type = torch.device(device).type
    use_amp = device_type == 'cuda'

    epoch_loss = 0.0
    epoch_onset_loss = 0.0
    epoch_velocity_loss = 0.0
    batch_count = 0
    num_batches = len(train_loader)
    optimizer.zero_grad()  # Zero gradients once at the beginning

    progress_bar = tqdm(train_loader, desc="Training")
    for i, batch in enumerate(progress_bar):
        inputs = batch["input"].to(device)
        onset_target = batch["onset_target"].to(device)
        velocity_target = batch["velocity_target"].to(device)

        with torch.amp.autocast(device_type=device_type, enabled=use_amp):
            onset_pred, velocity_pred = model(inputs)
            loss, onset_loss, velocity_loss = combined_loss_function(
                onset_pred, velocity_pred, onset_target, velocity_target, onset_weight
            )
            # Divide so the accumulated gradient is an average over the window.
            loss = loss / accum_steps

        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        # Step after every accum_steps batches and on the final (possibly partial) window.
        if (i + 1) % accum_steps == 0 or (i + 1) == num_batches:
            if scaler is not None:
                # Gradients must be unscaled before clipping inspects them.
                scaler.unscale_(optimizer)
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            if scaler is not None:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()

        # Track the un-divided loss for reporting.
        actual_loss = loss.item() * accum_steps
        epoch_loss += actual_loss
        epoch_onset_loss += onset_loss.item()
        epoch_velocity_loss += velocity_loss.item()
        batch_count += 1

        progress_bar.set_postfix({
            'loss': f"{actual_loss:.4f}",
            'o_loss': f"{onset_loss.item():.4f}",
            'v_loss': f"{velocity_loss.item():.4f}"
        })

    denom = max(batch_count, 1)  # avoid ZeroDivisionError on an empty loader
    return {
        'loss': epoch_loss / denom,
        'onset_loss': epoch_onset_loss / denom,
        'velocity_loss': epoch_velocity_loss / denom
    }

def _threshold_metrics(preds, targets, threshold):
    """Return (precision, recall, f1) after binarising ``preds`` with ``preds > threshold``."""
    precision, recall, f1, _ = precision_recall_fscore_support(
        targets, (preds > threshold).astype(int), average='binary', zero_division=0
    )
    return precision, recall, f1


def validate(model, val_loader, device, onset_weight=0.8, positive_weight=10.0, fixed_threshold=None):
    """
    Evaluate the model on validation data with threshold optimization.

    Args:
        model: The model to evaluate
        val_loader: DataLoader for validation data
        device: Device to run on (CPU or GPU)
        onset_weight: Weight for onset loss
        positive_weight: Weight for positive examples in BCE loss
        fixed_threshold: If provided, use this threshold instead of optimizing

    Returns:
        Dictionary with mean losses, pooled precision/recall/F1 at the chosen
        threshold, the threshold itself, and per-drum metrics (keyed by drum name)
        at that threshold.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    model.eval()
    val_loss = 0.0
    val_onset_loss = 0.0
    val_velocity_loss = 0.0
    batch_count = 0

    # One [n_drums, B*T] array per batch, concatenated after the loop. Keeping
    # NumPy arrays (instead of extending Python lists element by element)
    # avoids materialising millions of boxed scalars.
    prob_chunks = []
    target_chunks = []

    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc="Validating")
        for batch in progress_bar:
            inputs = batch["input"].to(device)
            onset_target = batch["onset_target"].to(device)
            velocity_target = batch["velocity_target"].to(device)

            onset_pred, velocity_pred = model(inputs)

            loss, onset_loss, velocity_loss = combined_loss_function(
                onset_pred, velocity_pred, onset_target, velocity_target,
                onset_weight=onset_weight,
                positive_weight=positive_weight
            )

            val_loss += loss.item()
            val_onset_loss += onset_loss.item()
            val_velocity_loss += velocity_loss.item()
            batch_count += 1

            # [B, n_drums, T] -> [n_drums, B*T] (drum-major) for per-drum metrics.
            onset_probs_np = torch.sigmoid(onset_pred).cpu().numpy()
            onset_target_np = onset_target.cpu().numpy()
            n_drums = onset_probs_np.shape[1]
            prob_chunks.append(onset_probs_np.transpose(1, 0, 2).reshape(n_drums, -1))
            target_chunks.append(onset_target_np.transpose(1, 0, 2).reshape(n_drums, -1))

            progress_bar.set_postfix({
                'val_loss': f"{loss.item():.4f}",
                'o_loss': f"{onset_loss.item():.4f}",
                'v_loss': f"{velocity_loss.item():.4f}"
            })

    if batch_count == 0:
        raise ValueError("validate() received an empty val_loader; nothing to evaluate")

    drum_probs = np.concatenate(prob_chunks, axis=1)      # [n_drums, N]
    drum_targets = np.concatenate(target_chunks, axis=1)  # [n_drums, N]
    # Pooled (all drums) scores; precision/recall/F1 do not depend on element order.
    all_preds = drum_probs.reshape(-1)
    all_targets = drum_targets.reshape(-1)

    if fixed_threshold is not None:
        best_threshold = fixed_threshold
        best_precision, best_recall, best_f1 = _threshold_metrics(all_preds, all_targets, best_threshold)
        print(f"Using fixed threshold: {best_threshold:.2f}, F1: {best_f1:.4f}")
    else:
        best_f1 = 0
        best_threshold = 0.5
        best_precision = 0
        best_recall = 0

        # Explicit grid 0.20, 0.25, ..., 0.80. A float-step np.arange may or may
        # not include the endpoint depending on rounding; linspace is unambiguous.
        for threshold in np.linspace(0.2, 0.8, 13):
            precision, recall, f1 = _threshold_metrics(all_preds, all_targets, threshold)
            if f1 > best_f1:
                best_f1 = f1
                best_threshold = threshold
                best_precision = precision
                best_recall = recall

        print(f"Best threshold: {best_threshold:.2f} with F1: {best_f1:.4f}")

    # Per-drum metrics at the chosen threshold.
    drum_metrics = {}
    for i, drum_name in enumerate(MAIN_DRUM_NAMES):
        p, r, f1 = _threshold_metrics(drum_probs[i], drum_targets[i], best_threshold)
        drum_metrics[drum_name] = {'precision': p, 'recall': r, 'f1': f1}

    return {
        'loss': val_loss / batch_count,
        'onset_loss': val_onset_loss / batch_count,
        'velocity_loss': val_velocity_loss / batch_count,
        'precision': best_precision,
        'recall': best_recall,
        'f1': best_f1,
        'threshold': best_threshold,
        'drum_metrics': drum_metrics
    }

def _print_cuda_memory():
    """Print peak allocated and currently free/total CUDA memory; does nothing without CUDA."""
    if not torch.cuda.is_available():
        return
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    print(f"VRAM peak allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
    print(f"VRAM free: {free_bytes / 1e9:.2f} GB of {total_bytes / 1e9:.2f} GB")


def _save_model_summary(path, model, optimizer, epoch, train_loss, val_loss, f1_score):
    """Save model/optimizer state plus headline metrics (used for the best and final models)."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': float(train_loss),
        'val_loss': float(val_loss),
        'f1_score': float(f1_score),
    }, path)


def _download_colab_checkpoints(best_model_path, latest_path):
    """Best-effort browser download of the best model and latest checkpoint in Google Colab."""
    try:
        from google.colab import files
        print("\nDownloading checkpoint files to prevent data loss...")

        if best_model_path.exists():
            files.download(str(best_model_path))
            print("Downloaded best model")

        files.download(str(latest_path))
        print("Downloaded latest checkpoint")
    except Exception as e:
        # Deliberately broad: a failed download must never interrupt training.
        print(f"Could not download automatically: {e}")
        print("Consider manually running the download_checkpoints() function in a separate cell")


def _plot_training_history(train_losses, val_losses, f1_scores):
    """Plot train/validation loss and validation F1 per epoch."""
    plt.figure(figsize=(12, 8))

    plt.subplot(2, 1, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Loss History')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.subplot(2, 1, 2)
    plt.plot(f1_scores, label='F1 Score', color='green')
    plt.title('F1 Score History')
    plt.xlabel('Epoch')
    plt.ylabel('F1 Score')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.show()


def train_model(model, train_loader, val_loader, device,
                learning_rate=0.001, epochs=30, patience=5, onset_weight=0.8,
                resume_from=None, accum_steps=8):
    """Train with Adam, ReduceLROnPlateau, early stopping and per-epoch checkpointing.

    Args:
        model: Model to train (updated in place).
        train_loader: Training DataLoader.
        val_loader: Validation DataLoader.
        device: Device the batches are moved to; mixed precision is used only on CUDA.
        learning_rate: Initial Adam learning rate.
        epochs: Total number of epochs (including any already completed when resuming).
        patience: Epochs without validation-loss improvement before stopping early.
        onset_weight: Onset/velocity loss trade-off.
        resume_from: Optional checkpoint path to resume from.
        accum_steps: Gradient-accumulation steps per optimizer update.

    Side effects:
        Writes ``checkpoints/checkpoint_epoch_N.pt`` and ``checkpoints/checkpoint_latest.pt``
        each epoch, ``drum_transcription_best.pt`` on improvement and
        ``drum_transcription_final.pt`` at the end (all under ``MODEL_SAVE_DIR``),
        then plots the training history.

    Returns:
        Tuple ``(model, train_losses, val_losses, f1_scores)``.
    """
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3
    )

    # Mixed precision only on CUDA; a disabled scaler is a transparent pass-through.
    use_amp = torch.device(device).type == 'cuda'
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    best_val_loss = float('inf')
    patience_counter = 0
    train_losses = []
    val_losses = []
    f1_scores = []
    start_epoch = 0

    checkpoint_dir = MODEL_SAVE_DIR / 'checkpoints'
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    best_model_path = MODEL_SAVE_DIR / 'drum_transcription_best.pt'
    latest_path = checkpoint_dir / "checkpoint_latest.pt"

    if resume_from:
        start_epoch, train_losses, val_losses, f1_scores, best_val_loss, patience_counter = load_checkpoint(
            resume_from, model, optimizer, scheduler
        )
        print(f"Resuming training from epoch {start_epoch}")

    # Filled in by the loop; they stay None when nothing is left to train
    # (e.g. resuming from a checkpoint of the final epoch).
    last_epoch = None
    train_loss = val_loss = last_f1 = None

    for epoch in range(start_epoch, epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        start_time = time.time()
        _print_cuda_memory()

        train_metrics = train_epoch(model, train_loader, optimizer, device, onset_weight, scaler, accum_steps=accum_steps)
        train_loss = train_metrics['loss']
        train_losses.append(train_loss)

        val_metrics = validate(model, val_loader, device, onset_weight)
        val_loss = val_metrics['loss']
        last_f1 = float(val_metrics['f1'])
        val_losses.append(val_loss)
        f1_scores.append(last_f1)
        last_epoch = epoch

        scheduler.step(val_loss)

        epoch_time = time.time() - start_time
        print(f"Epoch {epoch+1}/{epochs} completed in {epoch_time:.2f}s")
        print(f"Train Loss: {train_loss:.4f} (Onset: {train_metrics['onset_loss']:.4f}, Velocity: {train_metrics['velocity_loss']:.4f})")
        print(f"Val Loss: {val_loss:.4f} (Onset: {val_metrics['onset_loss']:.4f}, Velocity: {val_metrics['velocity_loss']:.4f})")
        print(f"F1 Score: {val_metrics['f1']:.4f}, Precision: {val_metrics['precision']:.4f}, Recall: {val_metrics['recall']:.4f}")

        # Update early-stopping state and the best model BEFORE checkpointing, so a
        # resumed run (and any Colab download) reflects this epoch's result.
        improved = val_loss < best_val_loss
        if improved:
            best_val_loss = val_loss
            patience_counter = 0
            _save_model_summary(best_model_path, model, optimizer, epoch, train_loss, val_loss, last_f1)
            print(f"Saved best model with validation loss: {best_val_loss:.4f}")
        else:
            patience_counter += 1
            print(f"No improvement for {patience_counter} epochs")

        # Per-epoch checkpoint plus an always-overwritten "latest" one.
        for checkpoint_path in (checkpoint_dir / f"checkpoint_epoch_{epoch+1}.pt", latest_path):
            save_checkpoint(
                model, optimizer, scheduler, epoch, train_losses, val_losses,
                f1_scores, best_val_loss, patience_counter, checkpoint_path
            )

        if 'google.colab' in sys.modules and (epoch % 5 == 0 or epoch == epochs - 1):
            _download_colab_checkpoints(best_model_path, latest_path)

        if not improved and patience_counter >= patience:
            print(f"Early stopping after {epoch+1} epochs!")
            break

    if last_epoch is None:
        print("No epochs were run; skipping final model save.")
    else:
        _save_model_summary(
            MODEL_SAVE_DIR / 'drum_transcription_final.pt',
            model, optimizer, last_epoch, train_loss, val_loss, last_f1
        )

    _plot_training_history(train_losses, val_losses, f1_scores)

    return model, train_losses, val_losses, f1_scores

## 6. Visualization

In [None]:
def visualize_predictions(model, data_loader, device, threshold=0.5, num_samples=3):
    """Visualize model predictions on a few samples.

    For each of the first ``num_samples`` examples in ``data_loader`` one figure
    with four panels is drawn: the input mel spectrogram, the ground-truth
    onsets, the onset predictions binarised at ``threshold``, and the velocity
    predictions masked to positions where an onset was predicted. The example's
    file path is printed after each figure.

    Args:
        model: Network whose forward pass returns ``(onset_output, velocity_preds)``.
            A sigmoid is applied to the onset output here, so it is treated as logits.
        data_loader: Iterable of batch dicts with keys ``"input"``,
            ``"onset_target"`` and ``"file_paths"``.
        device: Device the inputs are moved to before the forward pass.
        threshold: Probability above which an onset counts as predicted.
        num_samples: Number of examples to plot. When it is <= 0, the function
            returns without running the model.
    """
    model.eval()
    if num_samples <= 0:
        # Nothing to plot: skip the otherwise wasted forward pass on a batch.
        return

    samples_seen = 0

    with torch.no_grad():
        for batch in data_loader:
            inputs = batch["input"].to(device)
            onset_targets = batch["onset_target"]
            file_paths = batch["file_paths"]

            # Forward pass; move everything back to CPU once per batch for plotting.
            onset_logits, velocity_preds = model(inputs)
            onset_probs = torch.sigmoid(onset_logits.cpu())
            velocity_preds = velocity_preds.cpu()
            inputs_cpu = inputs.cpu()

            n_take = min(inputs.size(0), num_samples - samples_seen)
            for i in range(n_take):
                # NOTE(review): per-sample tensors are assumed to be 2-D, with shapes
                # (n_mels, time) for the input and (n_drums, time) for onsets and
                # velocities, as the imshow/yticks calls below require. Confirm
                # against the dataset/collate_fn.
                input_spec = inputs_cpu[i]
                onset_target = onset_targets[i]

                # Binarise onset probabilities at the requested threshold.
                binary_onset = (onset_probs[i] > threshold).float()

                # Keep velocity predictions only where an onset was predicted.
                masked_velocity_pred = velocity_preds[i] * binary_onset

                # Visualize
                plt.figure(figsize=(15, 10))

                # Plot input spectrogram
                plt.subplot(4, 1, 1)
                plt.imshow(input_spec.numpy(), aspect='auto', origin='lower', cmap='viridis')
                plt.colorbar()
                plt.title('Mel Spectrogram')
                plt.ylabel('Mel Bin')

                # Plot ground truth onsets
                plt.subplot(4, 1, 2)
                plt.imshow(onset_target.numpy(), aspect='auto', origin='lower', cmap='Reds')
                plt.colorbar()
                plt.title('Ground Truth Onsets')
                plt.yticks(np.arange(N_DRUMS), MAIN_DRUM_NAMES)

                # Plot predicted onsets (binary)
                plt.subplot(4, 1, 3)
                plt.imshow(binary_onset.numpy(), aspect='auto', origin='lower', cmap='OrRd')
                plt.colorbar()
                plt.title(f'Predicted Onsets (threshold={threshold:.2f})')
                plt.yticks(np.arange(N_DRUMS), MAIN_DRUM_NAMES)

                # Plot predicted velocities (masked by onset predictions)
                plt.subplot(4, 1, 4)
                plt.imshow(masked_velocity_pred.numpy(), aspect='auto', origin='lower', cmap='Blues')
                plt.colorbar()
                plt.title('Predicted Velocities (only where onset predicted)')
                plt.yticks(np.arange(N_DRUMS), MAIN_DRUM_NAMES)
                plt.xlabel('Time Frame')

                plt.tight_layout()
                plt.show()

                # Show the file path of this example
                print(f"File: {file_paths[i]}")

                samples_seen += 1

            if samples_seen >= num_samples:
                break

## 6. Grid Search

In [None]:
def _grid_search_configs():
    """Return the list of hyperparameter configurations the grid search evaluates.

    ``use_lstm``, ``positive_weight`` and ``dropout`` are swept exhaustively
    (2 x 3 x 3 = 18 configs). ``learning_rate`` and ``onset_weight`` are pinned
    per architecture instead of swept, to keep the search practical.
    """
    import itertools

    searched_use_lstm = [True, False]
    searched_positive_weight = [10.0, 15.0, 20.0]
    searched_dropout = [0.2, 0.3, 0.5]

    # Fixed (not searched) settings, keyed by the value of use_lstm.
    fixed_by_arch = {
        True: {'learning_rate': 0.002, 'onset_weight': 0.8},   # LSTM variant
        False: {'learning_rate': 0.005, 'onset_weight': 0.8},  # CNN-only variant
    }

    # itertools.product keeps the original nested-loop order:
    # use_lstm outermost, dropout innermost.
    return [
        {
            'use_lstm': use_lstm,
            'positive_weight': pos_weight,
            'dropout': dropout,
            **fixed_by_arch[use_lstm],
        }
        for use_lstm, pos_weight, dropout in itertools.product(
            searched_use_lstm, searched_positive_weight, searched_dropout
        )
    ]


def grid_search_hyperparameters(train_loader, val_loader, device, max_epochs=10,
                                early_stop_patience=3):
    """
    Run a grid search over key hyperparameters and return the best configuration.

    Each configuration trains a freshly initialised ``DrumTranscriptionModel`` for
    up to ``max_epochs`` epochs. A run stops early once validation F1 has not
    improved for ``early_stop_patience`` consecutive epochs.

    Args:
        train_loader: DataLoader used to train each configuration.
        val_loader: DataLoader used for validation (loss, F1, threshold).
        device: Device to train on (``str`` or ``torch.device``).
        max_epochs: Maximum number of epochs per configuration.
        early_stop_patience: Epochs without F1 improvement before a run stops.

    Returns:
        ``(best_config, results)``:
            - ``best_config`` is a copy of the configuration with the highest
              validation F1 (the earliest one on ties).
            - ``results`` is a list of ``{'config', 'best_f1', 'best_epoch'}``
              dicts sorted by ``best_f1`` in descending order.
    """
    device_type = torch.device(device).type
    results = []

    print("Starting grid search...")

    configs = _grid_search_configs()
    print(f"Will evaluate {len(configs)} configurations")

    for i, config in enumerate(configs):
        print(f"\nTesting configuration {i+1}/{len(configs)}:")
        print(config)

        # Create and initialize model
        model = DrumTranscriptionModel(
            n_mels=229,
            n_drums=N_DRUMS,
            use_lstm=config['use_lstm'],
            dropout=config['dropout']
        )
        model.apply(init_weights)
        model = model.to(device)

        # Optimizer and LR schedule (halve the LR when validation loss plateaus)
        optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=2
        )

        # Mixed-precision loss scaling only applies on CUDA. Disabling it elsewhere
        # avoids GradScaler's "CUDA is not available" warning on CPU-only machines.
        scaler = torch.amp.GradScaler('cuda', enabled=(device_type == 'cuda'))

        # NOTE(review): train_epoch receives onset_weight but not
        # config['positive_weight'], so the positive_weight axis currently only
        # changes the validation loss, not training. Wire it through once
        # train_epoch's signature is confirmed.

        best_val_f1 = 0
        best_epoch = 0
        early_stop_counter = 0

        for epoch in range(max_epochs):
            # Train one epoch
            train_metrics = train_epoch(
                model, train_loader, optimizer, device,
                onset_weight=config['onset_weight'],
                scaler=scaler,
                accum_steps=1  # Smaller for grid search
            )

            # Validate with threshold optimization
            val_metrics = validate(
                model, val_loader, device,
                onset_weight=config['onset_weight'],
                positive_weight=config['positive_weight']
            )

            print(f"Epoch {epoch+1}: F1={val_metrics['f1']:.4f} (thresh={val_metrics['threshold']:.2f}), "
                  f"Train Loss={train_metrics['loss']:.4f}, Val Loss={val_metrics['loss']:.4f}")

            # Update learning rate scheduler
            scheduler.step(val_metrics['loss'])

            # Track best epoch by validation F1; stop after `early_stop_patience`
            # epochs without improvement.
            if val_metrics['f1'] > best_val_f1:
                best_val_f1 = val_metrics['f1']
                best_epoch = epoch
                early_stop_counter = 0
            else:
                early_stop_counter += 1
                if early_stop_counter >= early_stop_patience:
                    break

        # Store results
        results.append({
            'config': config,
            'best_f1': best_val_f1,
            'best_epoch': best_epoch + 1
        })

        # Drop every reference to this run's parameters and optimizer state.
        # The optimizer/scheduler also hold the parameters, so deleting only
        # `model` would keep them alive across empty_cache().
        del model, optimizer, scheduler, scaler
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Sort results (stable, so the earliest config wins ties)
    results.sort(key=lambda x: x['best_f1'], reverse=True)

    # Print top results
    print("\nTop configurations:")
    for rank, res in enumerate(results[:5], start=1):
        print(f"{rank}. F1: {res['best_f1']:.4f} at epoch {res['best_epoch']} with: {res['config']}")

    # Take the winner from the sorted list. Unlike tracking a running
    # "best_f1 > 0", this never leaves best_config empty when every run scores
    # F1 == 0; an empty dict would break the cells that index best_config.
    best_config = results[0]['config'].copy() if results else {}

    return best_config, results

In [None]:
# Toy-sized datasets for the grid search (file counts capped via max_files)
toy_train_dataset = DrumTranscriptionDataset(PROCESSED_DATA_DIR, split="train", toy_mode=True, max_files=100)
toy_val_dataset = DrumTranscriptionDataset(PROCESSED_DATA_DIR, split="validation", toy_mode=True, max_files=20)

# Loader settings shared by both toy loaders; only shuffling differs.
toy_batch_size = 3
toy_loader_kwargs = dict(
    batch_size=toy_batch_size,
    collate_fn=collate_fn,
    num_workers=1,
    pin_memory=False,
)
toy_train_loader = DataLoader(toy_train_dataset, shuffle=True, **toy_loader_kwargs)
toy_val_loader = DataLoader(toy_val_dataset, shuffle=False, **toy_loader_kwargs)

# Search: every configuration gets at most 15 epochs on the toy data.
best_config, all_results = grid_search_hyperparameters(
    toy_train_loader, toy_val_loader, device, max_epochs=15
)

# all_results is sorted best-first, so index 0 is the winning run.
print("\nBest configuration found:")
print(f"F1 score: {all_results[0]['best_f1']:.4f}")
print(best_config)


In [None]:
# --- Load and evaluate the best model ---
print("\nLoading the overall best model for evaluation...")

# Path of the best checkpoint written by the training loop (train_model).
best_model_path = MODEL_SAVE_DIR / 'drum_transcription_best.pt'

# NOTE(review): the final-training cell below runs *after* this one, so at this
# point the checkpoint may be missing or left over from an earlier run. Its
# architecture must match best_config for load_state_dict to succeed.
if best_model_path.exists():
    # Load the checkpoint
    checkpoint = torch.load(best_model_path, map_location=device)

    # Rebuild the architecture chosen by the grid search. Its configs store the
    # dropout under 'dropout', and the model is built with `dropout=` elsewhere
    # in this notebook.
    best_model_instance = DrumTranscriptionModel(
        n_mels=229,
        n_drums=N_DRUMS,
        use_lstm=best_config['use_lstm'],
        dropout=best_config['dropout'],
    )

    # Load the state dict
    best_model_instance.load_state_dict(checkpoint['model_state_dict'])
    best_model_instance = best_model_instance.to(device)
    best_model_instance.eval()

    # 'epoch' and 'val_loss' are read defensively: format them only when present
    # so a checkpoint without them doesn't crash the report.
    ckpt_epoch = checkpoint.get('epoch')
    ckpt_val_loss = checkpoint.get('val_loss')
    epoch_str = 'N/A' if ckpt_epoch is None else str(ckpt_epoch + 1)
    val_loss_str = 'N/A' if ckpt_val_loss is None else f"{ckpt_val_loss:.4f}"
    print(f"Loaded best model from epoch {epoch_str} with validation loss: {val_loss_str}")

    # Evaluate on the test set with the grid search's loss weighting, so the loss
    # is comparable with how configurations were ranked.
    # NOTE(review): validate() also optimises the onset threshold on the loader
    # it is given, so the F1 below uses a threshold chosen on the test set itself.
    print("Validating the loaded best model on the test set...")
    test_metrics_loaded = validate(
        best_model_instance,
        test_loader,
        device,
        onset_weight=best_config['onset_weight'],
        positive_weight=best_config['positive_weight']
    )

    # Print detailed metrics
    print(f"\nTest Set Evaluation (Loaded Best Model, threshold={test_metrics_loaded['threshold']:.2f}):")
    print(f"  Loss: {test_metrics_loaded['loss']:.4f}")
    print(f"  F1 Score: {test_metrics_loaded['f1']:.4f}")
    print(f"  Precision: {test_metrics_loaded['precision']:.4f}")
    print(f"  Recall: {test_metrics_loaded['recall']:.4f}")
    print("\n  Per-drum performance:")
    for drum_name, metrics in test_metrics_loaded['drum_metrics'].items():
        print(f"    {drum_name}: F1={metrics['f1']:.4f}, P={metrics['precision']:.4f}, R={metrics['recall']:.4f}")

    # Visualize predictions at the threshold validate() selected above
    print("\nVisualizing predictions from the loaded best model...")
    visualize_predictions(
        best_model_instance,
        test_loader,
        device,
        threshold=test_metrics_loaded['threshold'],
        num_samples=3
    )

else:
    print(f"Best model checkpoint not found at {best_model_path}. Skipping evaluation.")
    # Leave explicit placeholders so later cells can check for a missing model.
    best_model_instance = None
    test_metrics_loaded = None


### 6.2. Training the Final Model


In [None]:
# Report GPU memory usage. Free/total come from the CUDA driver for the current
# device instead of assuming a 12 GB card; "used" is PyTorch's peak allocation.
if torch.cuda.is_available():
    peak_allocated_gb = torch.cuda.max_memory_allocated() / 1e9
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    print(f"VRAM used: {peak_allocated_gb:.2f} GB (peak allocated by PyTorch)")
    print(f"VRAM free: {free_bytes / 1e9:.2f} GB of {total_bytes / 1e9:.2f} GB")
else:
    print("CUDA not available - no VRAM statistics to report.")

In [None]:

# Train final model with best config
print("\nTraining final model with best configuration...")

# Architecture hyperparameters selected by the grid search.
final_model = DrumTranscriptionModel(
    n_mels=229,
    n_drums=N_DRUMS,
    use_lstm=best_config['use_lstm'],
    dropout=best_config['dropout'],
)
final_model.apply(init_weights)
final_model = final_model.to(device)

# Full-data training run. Optimisation settings come from the grid search; the
# schedule is longer (30 epochs) and more patient (7) than during the search.
final_train_settings = {
    'learning_rate': best_config['learning_rate'],
    'epochs': 30,
    'patience': 7,                               # More patience for final model
    'onset_weight': best_config['onset_weight'],
    'accum_steps': 4,                            # Can adjust based on memory
    'resume_from': None,                         # Start fresh
}
# NOTE(review): best_config['positive_weight'] is not forwarded to train_model;
# confirm whether train_model accepts it so the tuned value is actually used.
final_model, _, _, _ = train_model(
    model=final_model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    **final_train_settings,
)

## 7. Evaluate on Test Set

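Evaluate the best model on the held-out test set. Note that `validate` selects the onset threshold on the data it is given. When it runs on the test set, the reported F1 is therefore somewhat optimistic rather than a fully unbiased estimate.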

In [None]:
# Load best model
best_model_path = MODEL_SAVE_DIR / 'drum_transcription_best.pt'
checkpoint = torch.load(best_model_path, map_location=device)

# Recreate the architecture the final model was trained with (from best_config)
# instead of assuming the LSTM variant, so the state dict matches.
best_model = DrumTranscriptionModel(
    n_mels=229,
    n_drums=N_DRUMS,
    use_lstm=best_config['use_lstm'],
    dropout=best_config['dropout'],
)
best_model.load_state_dict(checkpoint['model_state_dict'])
best_model = best_model.to(device)
best_model.eval()  # disable dropout for evaluation

print(f"Loaded best model from epoch {checkpoint['epoch']+1} with validation loss: {checkpoint['val_loss']:.4f}")

# Evaluate on the test set with the same loss weighting the final model was
# trained/tuned with.
# NOTE(review): validate() optimises the onset threshold on the loader it is
# given, so this F1 is measured at a threshold tuned on the test set.
test_metrics = validate(
    best_model,
    test_loader,
    device,
    onset_weight=best_config['onset_weight'],
    positive_weight=best_config['positive_weight'],
)

print("\nTest Set Evaluation:")
loss_line = f"Test Loss: {test_metrics['loss']:.4f}"
# Print the onset/velocity breakdown only if validate() reports it.
if 'onset_loss' in test_metrics and 'velocity_loss' in test_metrics:
    loss_line += (f" (Onset: {test_metrics['onset_loss']:.4f}, "
                  f"Velocity: {test_metrics['velocity_loss']:.4f})")
print(loss_line)
print(f"F1 Score: {test_metrics['f1']:.4f}, Precision: {test_metrics['precision']:.4f}, Recall: {test_metrics['recall']:.4f}")

## 8. Visualize Predictions

Let's visualize some predictions from the test set to qualitatively assess model performance.

In [None]:
# Visualize some predictions at the onset threshold validate() reported for the
# test set above, so the plots match the reported F1 (0.5 if none was returned).
visualize_predictions(
    best_model,
    test_loader,
    device,
    threshold=test_metrics.get('threshold', 0.5),
    num_samples=3,
)

## 9. Conclusion

In this notebook, we've successfully built and trained a deep learning model for drum transcription. The model has dual outputs: one for detecting drum onsets (when a drum is hit) and another for predicting the velocity (how hard it's hit).

**Key accomplishments:**

1. Created a PyTorch dataset for loading processed training examples
2. Designed a CNN/CRNN architecture with dual output heads
3. Implemented a combined loss function for multi-task learning
4. Tuned hyperparameters with a grid search, then trained the final model with early stopping and learning rate scheduling
5. Evaluated onset detection with precision, recall, and F1 score, both overall and per drum
6. Visualized predictions alongside ground truth

The F1 score on the test set gives us an idea of how well the model is detecting drum hits. In our next notebook, we'll explore how to convert these predictions into MIDI files and evaluate the full transcription system.