## Imports and setup

In [None]:
# Create or update conda environment
# Each `!` line is handed to the system shell. Only the `env update` command is
# active; it applies ../../environment.yaml to the environment.
# The commented-out commands are alternatives kept for reference:
#   - `env create`: build the environment from the same file (first-time setup)
#   - `remove --name amadeus-ex-machina --all`: delete the environment entirely
# !conda env create -f ../../environment.yaml
!conda env update -f ../../environment.yaml
# !conda remove --name amadeus-ex-machina --all

In [None]:
# Verify that the kernel is running inside the project's conda environment.
#
# The environment cannot be activated from inside the notebook: every `!command`
# runs in its own short-lived subshell, so `!conda activate ...` never affects
# the already-running kernel, and `!conda init` only edits shell startup files.
# Activate the environment in a terminal *before* launching Jupyter instead:
#
#     conda init                          # one-time shell setup
#     conda activate amadeus-ex-machina
#     jupyter lab                         # or: jupyter notebook
#
# or pick a kernel that was registered from that environment.
import os
import sys

EXPECTED_CONDA_ENV = "amadeus-ex-machina"

# For a named conda environment the interpreter prefix ends with the
# environment's name (e.g. .../envs/amadeus-ex-machina).
active_env_name = os.path.basename(os.path.normpath(sys.prefix))

if active_env_name == EXPECTED_CONDA_ENV:
    print(f"Kernel is running in conda environment '{EXPECTED_CONDA_ENV}'")
else:
    # Warn rather than raise: the notebook may still work if the required
    # packages happen to be installed in the current interpreter.
    print(
        f"WARNING: kernel interpreter prefix is '{sys.prefix}', which does not "
        f"look like the '{EXPECTED_CONDA_ENV}' conda environment. "
        "Activate it in a terminal before starting Jupyter, or select the "
        "matching kernel."
    )

In [None]:
# System imports
import sys
import os

# Make the directory two levels above the current working directory importable,
# so the project-local imports that follow can be resolved from it.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
if sys.path.count(parent_dir) == 0:
    # Appended (not prepended), and only once, so re-running the cell is harmless
    sys.path.append(parent_dir)

# Data imports
from data.data_loader import MirDataProcessor, ChordDataProcessor
import data.youtube_download as youtube_download
from datagen.chordgen import generate_all_chords

# Model and local file imports
from models.CRNN import CRNNModel
from utils.model_utils import get_device

# Package imports
import torch
import torch.optim as optim
import torch.nn as nn
from pathlib import Path
from solver import Solver

# Name of the chord reference file and of the directory that holds it; both are
# later joined onto parent_dir_path
JSON_FILE = "chord_ref.json"
SECRETS_DIR = "secrets"

# pathlib view of parent_dir, so paths can be built with the `/` operator
parent_dir_path = Path(parent_dir)

# Ask the project's get_device helper which device to use, and report it
device = get_device()
print(f"Device is {device}")

## Process billboard data

In [None]:
# Leave this False once the downloader has already been run; set it to True to
# download and process the data.
download_mirdata = False

# Processor that downloads and builds usable train/test data out of the
# MIR Billboard dataset. Processing only happens when a download was requested.
billboard_data_processer = MirDataProcessor(
    download=download_mirdata,
    batch_size=64,
)
if download_mirdata:
    billboard_data_processer.process_billboard_data()

In [None]:
# Build the Billboard train/test loaders and get the number of classes.
# nrows is set to shrink the dataset for testing.
(
    billboard_train_loader,
    billboard_test_loader,
    billboard_num_classes,
) = billboard_data_processer.build_data_loaders(
    device=device,
    nrows=10000,
)

print(f"Billboard number of classes: {billboard_num_classes}")

## Process chord data

In [None]:
# Generate the chord data files under the project's "secrets" directory.
# Generation is skipped when the chord reference JSON already exists, so a full
# "Restart & Run All" does not redo the work on every run.

# Set to True to force regeneration even when the chord reference file exists
download_chordgen = False

# If your sf2 file is already downloaded and in outdir/sf2/FluidR3_GM.sf2, set this to False
download_sf2 = False
out_dir = f"{Path.cwd().parents[1]}{os.path.sep}secrets"

# NOTE(review): assumes generate_all_chords writes JSON_FILE into out_dir (the
# chord loader cell reads it from that location) — confirm. If it does not, the
# file never exists and generation simply runs every time, as it did before.
chord_ref_path = Path(out_dir) / JSON_FILE

if download_chordgen or not chord_ref_path.exists():
    generate_all_chords(
        out_dir=out_dir,
        download_sf2=download_sf2,
        inversions=True,
        duration=1.0,
        make_dir=True,
        n_jobs=4,
    )
else:
    print(
        f"Found {chord_ref_path}; skipping chord generation "
        "(set download_chordgen = True to regenerate)"
    )

In [None]:
# Processor for the generated chord data, pointed at the chord reference JSON
# inside the secrets directory
chord_data_processor = ChordDataProcessor(
    chord_json_path=parent_dir_path.joinpath(SECRETS_DIR, JSON_FILE),
    device=device,
    batch_size=64,
    seq_length=16,
    process_sequential=True,
)

# Run the whole pipeline: returns the train/test DataLoaders and the class count
(
    chord_train_loader,
    chord_test_loader,
    chord_num_classes,
) = chord_data_processor.build_data_loaders()

print(f"Chord number of classes: {chord_num_classes}")

In [None]:
# # Set billboard data path
# billboard_data_path = "../data/raw/McGill-Billboard"

# # Process lab files in the extracted directory
# youtube_download.process_lab_files(billboard_data_path)

## Model training

In [None]:
# CRNN for the chord dataset. num_classes must match the dataset being trained
# on, so it is taken from the chord loaders.
crnn_chord_model = CRNNModel(
    input_features=24,
    num_classes=chord_num_classes,
    hidden_size=128,
).to(device)

# Loss, optimizer and learning-rate schedule
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(crnn_chord_model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=3)

# Solver wiring the chord model to the chord train/test loaders
crnn_model_chord_solver = Solver(
    # what to train and how
    model=crnn_chord_model,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    # data
    train_dataloader=chord_train_loader,
    valid_dataloader=chord_test_loader,
    # run settings
    batch_size=32,
    epochs=10,
    device=device,
    early_stop_epochs=3,
    warmup_epochs=2,
    optuna_prune=False,
)

In [None]:
# Run training/evaluation on the chord data with plotting enabled, then report
# the value returned by the solver
chord_best_val_accuracy = crnn_model_chord_solver.train_and_evaluate(
    plot_results=True,
)
print(f"Chord data validation Accuracy: {chord_best_val_accuracy:.4f}")

In [None]:
# CRNN for the Billboard dataset. num_classes must match the dataset being
# trained on, so it is taken from the Billboard loaders.
crnn_billboard_model = CRNNModel(
    input_features=24,
    num_classes=billboard_num_classes,
    hidden_size=128,
).to(device)

# Loss, optimizer and learning-rate schedule (these names are rebound here, so
# they now refer to the Billboard model's objects)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(crnn_billboard_model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=3)

# Solver wiring the Billboard model to the Billboard train/test loaders
crnn_model_billboard_solver = Solver(
    # what to train and how
    model=crnn_billboard_model,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    # data
    train_dataloader=billboard_train_loader,
    valid_dataloader=billboard_test_loader,
    # run settings
    batch_size=32,
    epochs=10,
    device=device,
    early_stop_epochs=3,
    warmup_epochs=2,
    optuna_prune=False,
)

In [None]:
# Run training/evaluation on the Billboard data with plotting enabled, then
# report the value returned by the solver
billboard_best_val_accuracy = crnn_model_billboard_solver.train_and_evaluate(
    plot_results=True,
)
print(f"Billboard data validation Accuracy: {billboard_best_val_accuracy:.4f}")