## Imports and setup

In [None]:
# Create or update the conda environment described by ../environment.yaml.
# First-time setup: uncomment the `env create` line and run it instead of `env update`.
# !conda env create -f ../environment.yaml
!conda env update -f ../environment.yaml
# NOTE: the `!` line runs in a separate subshell. If it updates the environment this kernel
# is running in, restart the kernel so newly installed packages become importable.

In [1]:
# Verify the kernel is running inside the project's conda environment.
# `!conda activate` cannot do this from a notebook: every `!` line runs in its own throwaway
# subshell, so activation never reaches the running kernel (and `!conda init` only rewrites
# the user's shell profile). Instead, pick the `amadeus-ex-machina` kernel via
# Kernel -> Change Kernel, and this cell reports which interpreter is actually in use.
import os
import sys

EXPECTED_ENV = "amadeus-ex-machina"
# For a conda env, sys.prefix is .../envs/<env-name>, so its basename is the env name.
active_env = os.path.basename(sys.prefix)
if active_env != EXPECTED_ENV:
    print(f"WARNING: kernel is running in {active_env!r}, expected {EXPECTED_ENV!r}; switch kernels.")
print(f"Kernel Python: {sys.executable}")



In [2]:
# System imports
import os
import random
import sys

# Make the repository root (the parent of the 'notebooks' directory) importable.
# Assumes the kernel's working directory is the 'notebooks' directory.
# Inserted at the FRONT of sys.path so the project's generically named packages
# (`data`, `utils`, `solver`) win over any same-named package installed in the env.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))  # Move one level up
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

# Package imports
import torch

# Class/model imports
from data.data_loader import MirDataProcessor
from utils.model_utils import get_device
from solver import Solver
import data.data_loader as data_loader

# Seed the RNGs before any data loaders or models are built, so shuffling and
# weight initialisation are reproducible across Restart & Run All.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Select device
device = get_device()
print(f"Device is {device}")

Device is mps


## Download and process data

In [3]:
# Download (via mirdata) and build usable train/test data out of the MIR Billboard dataset.
# NOTE(review): download=True re-triggers the mirdata download (~250 MB for the chordino
# annotations); confirm whether MirDataProcessor skips already-downloaded files, and switch
# to download=False once data/raw is populated.
data_processor = MirDataProcessor(download=True, batch_size=64)
data_processor.map_annotations()

# Create data loaders for the train and test sets
train_loader, test_loader, num_classes = data_processor.build_data_loaders()

# Backward-compatible alias for the name this cell originally bound (misspelled).
data_processer = data_processor

print(f"Number of classes: {num_classes}")

INFO: Downloading ['metadata', 'annotation_salami', 'annotation_lab', 'annotation_mirex13', 'annotation_chordino', 'index']. Index is being stored in /opt/anaconda3/envs/amadeus-ex-machina/lib/python3.11/site-packages/mirdata/datasets/indexes, and the rest of files in /Users/dananabulsi/Desktop/CODING/GitHub-Repos/amadeus-ex-machina/data/raw
INFO: [metadata] downloading billboard-2.0-index.csv
64.0kB [00:02, 22.6kB/s]                            
INFO: [annotation_salami] downloading billboard-2.0-salami_chords.tar.gz
336kB [00:03, 91.4kB/s]                            
INFO: [annotation_lab] downloading billboard-2.0.1-lab.tar.gz
0.99MB [00:03, 288kB/s]                             
INFO: [annotation_mirex13] downloading billboard-2.0.1-mirex.tar.gz
1.24MB [00:17, 75.7kB/s]                            
INFO: [annotation_chordino] downloading billboard-2.0-chordino.tar.gz
  2%|▏         | 6.09M/251M [00:05<03:46, 1.14MB/s]   


KeyboardInterrupt: 

In [4]:
# Links and download paths
url = "https://www.dropbox.com/s/p4xtixbvt4hw5c6/billboard-2.0-salami_chords.tar.xz?dl=1"
billboard_salami_download_path = "../data/raw/billboard-2.0-salami_chords.tar.xz"
billboard_salami_path = "../data/raw/billboard-2.0-salami_chords"

# Download and extract the archive, but only when a previous run has not already done so.
# download_and_extract deletes the archive after extracting it, so a non-empty extraction
# directory is the marker that the data is present. Delete that directory to force a fresh download.
if os.path.isdir(billboard_salami_path) and os.listdir(billboard_salami_path):
    print(f"Found existing extraction at {billboard_salami_path}; skipping download.")
else:
    data_loader.download_and_extract(url, billboard_salami_download_path, billboard_salami_path)

Downloading from https://www.dropbox.com/s/p4xtixbvt4hw5c6/billboard-2.0-salami_chords.tar.xz?dl=1...
Downloaded file saved to ../data/raw/billboard-2.0-salami_chords.tar.xz.
Extracting ../data/raw/billboard-2.0-salami_chords.tar.xz to ../data/raw/billboard-2.0-salami_chords...
Extraction complete! Files extracted to ../data/raw/billboard-2.0-salami_chords
Removed the downloaded file: ../data/raw/billboard-2.0-salami_chords.tar.xz


In [None]:
# Process the lab (salami_chords.txt) files found in the extracted SALAMI chords directory.
# Judging by the logged output, this also appears to download audio per track (the
# `[download]` lines fetch from googlevideo.com). That makes it slow and network-bound.
# Read timeouts are logged and processing moves on to the next file.
data_loader.process_lab_files(billboard_salami_path)

Processing lab file: ../data/raw/billboard-2.0-salami_chords/McGill-Billboard/1069/salami_chords.txt
[download]  61.7% of    2.75MiB at   18.55KiB/s ETA 00:58

[download] Got error: HTTPSConnectionPool(host='rr2---sn-f5u5opt5-3fp6.googlevideo.com', port=443): Read timed out.


[download] Got error: HTTPSConnectionPool(host='rr2---sn-f5u5opt5-3fp6.googlevideo.com', port=443): Read timed out.
Processing lab file: ../data/raw/billboard-2.0-salami_chords/McGill-Billboard/0987/salami_chords.txt
[download]  45.7% of    2.42MiB at   16.96KiB/s ETA 01:19

## Build and train models

In [None]:
# MLPChordClassifier: 24 input features per example, one output per chord class.
mlp_model_kwargs = {"input_size": 24, "num_classes": num_classes}

# Optimisation settings handed to Solver.
# NOTE(review): train_loader/test_loader were built with batch_size=64 in the data-processing
# cell; confirm whether Solver's batch_size=128 is actually used anywhere.
mlp_training_config = dict(device=device, batch_size=128, learning_rate=0.001, epochs=20)

mlp_chord_classifier_solver = Solver(
    model_type="MLPChordClassifier",
    model_kwargs=mlp_model_kwargs,
    **mlp_training_config,
)

# Fit on the training split, score on the test split, and plot the results.
mlp_chord_classifier_solver.train_and_evaluate(train_loader, test_loader, plot_results=True)

In [None]:
# CRNNModel: 24 input features, a 128-unit hidden size, one output per chord class.
crnn_model_kwargs = {"input_features": 24, "num_classes": num_classes, "hidden_size": 128}

# Optimisation settings handed to Solver.
# NOTE(review): the data loaders were built with batch_size=64; confirm whether 128 here is used.
crnn_training_config = dict(device=device, batch_size=128, learning_rate=0.001, epochs=20)

crnn_model_solver = Solver(
    model_type="CRNNModel",
    model_kwargs=crnn_model_kwargs,
    **crnn_training_config,
)

# Fit on the training split, score on the test split, and plot the results.
crnn_model_solver.train_and_evaluate(train_loader, test_loader, plot_results=True)

In [None]:
# CNNModel: the 24 features are passed as input channels, one output per chord class.
cnn_model_kwargs = {"input_channels": 24, "num_classes": num_classes}

# Optimisation settings handed to Solver.
# NOTE(review): the data loaders were built with batch_size=64; confirm whether 128 here is used.
cnn_training_config = dict(device=device, batch_size=128, learning_rate=0.001, epochs=20)

cnn_model_solver = Solver(
    model_type="CNNModel",
    model_kwargs=cnn_model_kwargs,
    **cnn_training_config,
)

# Fit on the training split, score on the test split, and plot the results.
cnn_model_solver.train_and_evaluate(train_loader, test_loader, plot_results=True)

In [None]:
# RNNModel: 24 input features, a 128-unit hidden size, one output per chord class.
rnn_model_kwargs = {"input_size": 24, "hidden_size": 128, "output_size": num_classes}

# Optimisation settings handed to Solver.
# NOTE(review): the data loaders were built with batch_size=64; confirm whether 128 here is used.
rnn_training_config = dict(device=device, batch_size=128, learning_rate=0.001, epochs=20)

rnn_model_solver = Solver(
    model_type="RNNModel",
    model_kwargs=rnn_model_kwargs,
    **rnn_training_config,
)

# Fit on the training split, score on the test split, and plot the results.
rnn_model_solver.train_and_evaluate(train_loader, test_loader, plot_results=True)

## Run inference

In [None]:
# Choose which trained solver to run inference with; any of the solvers trained above works.
inference_solver = mlp_chord_classifier_solver

# Reuse the preprocessing objects held by the data processor (bound as `data_processer` in the
# data-processing cell), presumably fitted on the training data — confirm in MirDataProcessor.
scaler = data_processer.scaler
label_encoder = data_processer.label_encoder

# Chroma features for one Billboard track, relative to this notebook like the other data paths.
chroma_path = "../data/raw/McGill-Billboard/0003/bothchroma.csv"
if not os.path.isfile(chroma_path):
    raise FileNotFoundError(
        f"{chroma_path} not found; re-run the mirdata download in the data-processing cell."
    )

# Run inference using the trained model
inference_solver.run_inference(
    chroma_path,
    scaler,
    label_encoder,
)