## Imports and setup

In [None]:
# System imports
import sys
import os

# Make the project root importable so the local packages imported below
# (data, datagen, models, utils) resolve. The root is taken to be two
# directories above the current working directory ("../..").
parent_dir = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # two levels up, not one
if parent_dir not in sys.path:
    # Appended (not prepended), so modules already on sys.path keep precedence.
    sys.path.append(parent_dir)

# Data imports
from data.data_loader import MirDataProcessor, ChordDataProcessor
import data.youtube_download as youtube_download
from datagen.chordgen import generate_all_chords

# Model and local file imports
from models.CRNN import CRNNModel
from utils.model_utils import get_device

# Package imports
import torch
import torch.optim as optim
import torch.nn as nn
from pathlib import Path
from solver import Solver

# Paths (relative to the project root) and run configuration
SECRETS_DIR = "secrets"        # directory under the project root holding the chord data
JSON_FILE = "chord_ref.json"   # chord reference JSON inside SECRETS_DIR

parent_dir_path = Path(parent_dir)

# Seed torch's RNGs so weight initialisation and any torch-driven shuffling are
# repeatable across Restart & Run All.
# NOTE(review): the project loaders may also draw from numpy/random internally — seed those too if so.
SEED = 42
torch.manual_seed(SEED)

# Select device. Training is pinned to the CPU here; swap in get_device()
# (from utils.model_utils, presumably picks an available accelerator) to change that.
device = "cpu"
print(f"Device is {device}")

## Process billboard data

In [None]:
# Set this to False once the downloader has already been run.
download_mirdata = False

# Processor for the MIR Billboard dataset. The download flag is handed to the
# processor, and the billboard processing step (building usable train/test data)
# only runs when the flag is True.
billboard_data_processer = MirDataProcessor(
    download=download_mirdata,
    dataset_name="billboard",
    batch_size=64,
    seq_length=16,
)
if download_mirdata:
    billboard_data_processer.process_billboard_data()

In [None]:
# Build the billboard train/test loaders. nrows=10000 shrinks the dataset so the
# notebook runs quickly while testing — increase it for a fuller run.
billboard_train_loader, billboard_test_loader, billboard_num_classes = (
    billboard_data_processer.build_data_loaders(device=device, nrows=10000)
)
print(f"MIR Number of Classes: {billboard_num_classes}")

# Peek at one training batch to record the feature/label shapes that the chord
# data is later reshaped to match (element 0 = features, element 1 = labels).
billboard_train_data = next(iter(billboard_train_loader))
target_features_shape, target_labels_shape = (
    billboard_train_data[0].shape,
    billboard_train_data[1].shape,
)

for shape_name, shape_value in (
    ("target_features_shape", target_features_shape),
    ("target_labels_shape", target_labels_shape),
):
    print(f"{shape_name}: {shape_value}")

## Process chord data

In [None]:
# Generate the synthetic chord files into <project root>/secrets.
# NOTE(review): presumably this produces the chord_ref.json read by the chord loaders below — confirm.
# Generation only runs when generate_chords is True; once the chord files exist,
# leave it False so re-running the notebook skips this (slow) step.
generate_chords = False

# If your sf2 file is already downloaded and in <out_dir>/sf2/FluidR3_GM.sf2, set this to False
download_sf2 = False
# Same "<project root>/secrets" directory the chord loaders read from (two levels above cwd).
out_dir = str(parent_dir_path / SECRETS_DIR)

if generate_chords:
    generate_all_chords(out_dir=out_dir, download_sf2=download_sf2, inversions=True, duration=1.0, make_dir=True, n_jobs=4)

In [None]:
# Create chord loaders whose batches are reshaped to match the billboard data.
chord_data_processor = ChordDataProcessor(
    chord_json_path=parent_dir_path / SECRETS_DIR / JSON_FILE,
    batch_size=64,
    seq_length=16,
    device=device,
)

# target_features_shape / target_labels_shape were already taken from the first
# billboard training batch in the billboard-loader cell. Reuse them rather than
# pulling another batch: an extra iter() over the loader is wasted work and, if it
# is a shuffling torch DataLoader, consumes torch RNG state and shifts later batch orders.
chord_train_loader, chord_test_loader, chord_num_classes = chord_data_processor.process_all_and_build_loaders(
    target_features_shape=target_features_shape,
    target_labels_shape=target_labels_shape,
    test_size=0.2,
    random_state=42,
)

## Model training

In [None]:
# Step 1: build the CRNN for the synthetic chord dataset.
# num_classes follows the dataset being trained on; the output layer is resized
# for the billboard labels in Step 3.
crnn_model = CRNNModel(
    input_features=24,  # NOTE(review): presumably must equal target_features_shape[-1] — confirm
    num_classes=chord_num_classes,
    hidden_size=128,
).to(device)

# Loss, optimiser and LR schedule. ReduceLROnPlateau (mode="min", patience=3) lowers
# the LR once whatever Solver passes to scheduler.step() stops improving.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(crnn_model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=3)

# Solver: train on the chord train split, validate on the chord test split.
# NOTE(review): batch_size=32 differs from the loaders' batch_size=64 — confirm what Solver uses it for.
chord_solver_config = dict(
    train_dataloader=chord_train_loader,
    valid_dataloader=chord_test_loader,
    batch_size=32,
    epochs=10,
    early_stop_epochs=3,
    warmup_epochs=2,
    optuna_prune=False,
)
crnn_model_chord_solver = Solver(
    model=crnn_model,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    device=device,
    **chord_solver_config,
)

In [None]:
# Step 2: train/evaluate on the chord data, then checkpoint the weights so that
# Step 3 can reload everything except the output layer.
print("Training on chord data...")
crnn_model_chord_solver.train_and_evaluate(plot_results=True)

chord_weights_file = "chord_model_weights.pth"  # written relative to the current working directory
torch.save(crnn_model.state_dict(), chord_weights_file)

In [None]:
# Step 3: adapt the chord-pretrained model to the billboard label set.
print("Adapting model for billboard data...")
# Resize the output layer for the billboard classes.
crnn_model.update_output_layer(num_classes=billboard_num_classes)

# Reload the chord-pretrained weights for everything except the output layer
# (its shape may now differ). strict=False tolerates the dropped keys being missing.
# map_location keeps loading working if the checkpoint was saved on another device.
# NOTE(review): the "fc" substring filter drops every key containing "fc", not just
# the output layer's "fc.*" — confirm that is intended for this model.
state_dict = torch.load("chord_model_weights.pth", map_location=device)
state_dict = {k: v for k, v in state_dict.items() if "fc" not in k}
crnn_model.load_state_dict(state_dict, strict=False)

# Start the new output layer from a fresh initialisation.
crnn_model.fc.reset_parameters()

# Fresh optimizer for the billboard stage...
optimizer = optim.Adam(crnn_model.parameters(), lr=0.001)
# ...and a scheduler bound to it. The Step 1 scheduler still references the old
# optimizer, so reusing it would adjust an optimizer that is no longer stepped and
# LR-on-plateau would silently never apply to billboard training.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=3)

# Initialize solver for CRNNModel with billboard data
crnn_model_billboard_solver = Solver(
    model=crnn_model,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    train_dataloader=billboard_train_loader,
    valid_dataloader=billboard_test_loader,
    batch_size=32,
    epochs=10,
    device=device,
    early_stop_epochs=3,
    warmup_epochs=2,
    optuna_prune=False,
)

In [None]:
# Step 4: first billboard pass with the feature extractor frozen, so only the
# remaining trainable layers (presumably including the new output layer) adapt.
freeze_msg = "Freezing feature extractor and training on billboard data..."
print(freeze_msg)
crnn_model.freeze_feature_extractor()
crnn_model_billboard_solver.train_and_evaluate(plot_results=True)

In [None]:
# Step 5: unfreeze the feature extractor and keep training on billboard data,
# reusing the Step 4 solver (and therefore its optimizer and scheduler).
# NOTE(review): fine-tuning an unfrozen backbone often uses a lower LR than the
# frozen pass — consider a fresh, smaller-LR optimizer here.
unfreeze_msg = "Unfreezing feature extractor and training on billboard data..."
print(unfreeze_msg)
crnn_model.unfreeze_feature_extractor()
crnn_model_billboard_solver.train_and_evaluate(plot_results=True)