## Imports and setup

In [None]:
# Create or update conda environment
# !conda env create -f ../environment.yaml
!conda env update -f ../environment.yaml
# !conda remove --name amadeus-ex-machina --all

In [None]:
# Verify which environment this kernel runs in.
# `!` commands execute in a throwaway subshell, so `!conda activate ...` cannot switch the
# environment of the already-running kernel, and `!conda init` only rewrites the user's
# shell startup files. Instead, select the "amadeus-ex-machina" kernel in Jupyter
# (Kernel > Change Kernel) and confirm it here.
import os
import sys

EXPECTED_ENV = "amadeus-ex-machina"
# For a conda env, sys.prefix is typically <conda root>/envs/<env name>.
active_env = os.path.basename(os.path.normpath(sys.prefix))
if active_env == EXPECTED_ENV:
    print(f"Kernel is running in the '{EXPECTED_ENV}' environment ({sys.prefix})")
else:
    print(
        f"Warning: kernel is running from '{sys.prefix}', not the '{EXPECTED_ENV}' conda "
        "environment. Switch kernels before running the cells below."
    )

In [1]:
# System imports
import os
import sys

# Package imports
import torch

# Make the project root (the parent of notebooks/) importable. It is inserted at the
# front of sys.path so the project's generically named packages (data, utils, solver)
# win over any same-named modules installed elsewhere on the path.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))  # assumes the notebook runs from notebooks/
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

# Class/model imports (resolved from parent_dir)
from data.data_loader import MirDataProcessor
from utils.model_utils import get_device
from solver import Solver
import data.youtube_download as youtube_download

# Seed torch so weight initialisation (and, presumably, DataLoader shuffling via torch's
# default generator) is reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Select device
device = get_device()
print(f"Device is {device}")

Device is mps


## Download and process data

In [None]:
# Download the MIR Billboard dataset and build usable train/test data out of it
data_processor = MirDataProcessor(download=True, batch_size=64)
data_processor.process_data()

# Backward-compatible alias for the previously misspelled variable name
data_processer = data_processor

# Create data loaders for the train and test sets
train_loader, test_loader, num_classes = data_processor.build_data_loaders()

print(f"Number of classes: {num_classes}")

In [None]:
# Walk the extracted McGill Billboard directories and, judging by the cell output,
# search YouTube for an MP3 matching each annotated track.
# NOTE(review): confirm whether process_lab_files skips tracks whose MP3 already
# exists; if not, re-running this cell re-downloads everything.
billboard_data_path = "../data/raw/McGill-Billboard"  # relative to notebooks/
youtube_download.process_lab_files(billboard_data_path)

Starting to process .lab files in base directory: ../data/raw/McGill-Billboard



Entering directory: ../data/raw/McGill-Billboard/1069
Processing .lab file: ../data/raw/McGill-Billboard/1069/salami_chords.txt
Attempting to parse .lab file: ../data/raw/McGill-Billboard/1069/salami_chords.txt
Expected output path for MP3: ../data/raw/McGill-Billboard/1069/salami_chords (The J. Geils Band - Must Of Got Lost).mp3
Initiating YouTube download for query: Must Of Got Lost The J. Geils Band
[youtube:search] Extracting URL: ytsearch:Must Of Got Lost The J. Geils Band
[download] Downloading playlist: Must Of Got Lost The J. Geils Band
[youtube:search] query "Must Of Got Lost The J. Geils Band": Downloading web client config
[youtube:search] query "Must Of Got Lost The J. Geils Band" page 1: Downloading API JSON
[youtube:search] Playlist Must Of Got Lost The J. Geils Band: Downloading 1 items of 1
[download] Downloading item 1 of 1
[youtube] Extracting URL: https://www.youtube.com/watch?v=avSyXCC

[download] Got error: 1310720 bytes read, 517747 more expected


[download] Got error: 1310720 bytes read, 517747 more expected



Entering directory: ../data/raw/McGill-Billboard/0389
Processing .lab file: ../data/raw/McGill-Billboard/0389/salami_chords.txt
Attempting to parse .lab file: ../data/raw/McGill-Billboard/0389/salami_chords.txt
Expected output path for MP3: ../data/raw/McGill-Billboard/0389/salami_chords (General Public - Tenderness).mp3
Initiating YouTube download for query: Tenderness General Public
[youtube:search] Extracting URL: ytsearch:Tenderness General Public
[download] Downloading playlist: Tenderness General Public
[youtube:search] query "Tenderness General Public": Downloading web client config
[youtube:search] query "Tenderness General Public" page 1: Downloading API JSON
[youtube:search] Playlist Tenderness General Public: Downloading 1 items of 1
[download] Downloading item 1 of 1
[youtube] Extracting URL: https://www.youtube.com/watch?v=bKyVQid8Ch4
[youtube] bKyVQid8Ch4: Downloading webpage
[youtube] bKyVQid8Ch4: Download

## Build and train models

In [None]:
# MLPChordClassifier: configure a Solver and train/evaluate it on the loaders built above.
# NOTE(review): train_loader/test_loader were created with batch_size=64; confirm whether
# Solver uses its own batch_size=128 or simply iterates the loaders it is given.
mlp_chord_classifier_solver = Solver(
    model_type="MLPChordClassifier",
    model_kwargs=dict(input_size=24, num_classes=num_classes),
    device=device,
    batch_size=128,
    learning_rate=1e-3,
    epochs=20,
)

mlp_chord_classifier_solver.train_and_evaluate(
    train_loader,
    test_loader,
    plot_results=True,
)

In [None]:
# CRNNModel: same training setup as the other models, plus a 128-unit hidden size.
# The 24 input features presumably match the chroma feature width — TODO confirm.
crnn_model_solver = Solver(
    model_type="CRNNModel",
    model_kwargs=dict(input_features=24, num_classes=num_classes, hidden_size=128),
    device=device,
    batch_size=128,
    learning_rate=1e-3,
    epochs=20,
)

crnn_model_solver.train_and_evaluate(
    train_loader,
    test_loader,
    plot_results=True,
)

In [None]:
# CNNModel: the 24 input features are passed to the model as input channels.
cnn_model_solver = Solver(
    model_type="CNNModel",
    model_kwargs=dict(input_channels=24, num_classes=num_classes),
    device=device,
    batch_size=128,
    learning_rate=1e-3,
    epochs=20,
)

cnn_model_solver.train_and_evaluate(
    train_loader,
    test_loader,
    plot_results=True,
)

In [None]:
# RNNModel: note this model names its class-count argument `output_size`
# rather than `num_classes`.
rnn_model_solver = Solver(
    model_type="RNNModel",
    model_kwargs=dict(input_size=24, hidden_size=128, output_size=num_classes),
    device=device,
    batch_size=128,
    learning_rate=1e-3,
    epochs=20,
)

rnn_model_solver.train_and_evaluate(
    train_loader,
    test_loader,
    plot_results=True,
)

## Run inference

In [None]:
# Reuse the preprocessing objects produced during data processing so inference inputs
# are transformed like the training data (assumes process_data() populated these
# attributes — TODO confirm).
scaler = data_processor.scaler
label_encoder = data_processor.label_encoder

# Chroma features for one Billboard track. The path is built from the project root
# rather than a machine-specific absolute path, so the notebook works on any checkout.
chroma_path = os.path.join(
    parent_dir, "data", "raw", "McGill-Billboard", "0003", "bothchroma.csv"
)
if not os.path.isfile(chroma_path):
    raise FileNotFoundError(f"Chroma file not found: {chroma_path}")

# `solver` was never defined (`from solver import Solver` binds only the class), so pick
# one of the trained solvers explicitly; swap in crnn_model_solver, cnn_model_solver or
# rnn_model_solver to compare models.
inference_solver = mlp_chord_classifier_solver

# Run inference using the trained model
inference_solver.run_inference(
    chroma_path,
    scaler,
    label_encoder,
)