<font size=6>Installing required libraries</font>

In [None]:
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
!pip3 install torch torchsummary
!pip3 install -q pretty_midi
!pip3 install -q gdown

<font size=6>Imports</font>

In [None]:
import torch
import os
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
import zipfile
import requests
import numpy as np
import torch.utils.data as data
import time
import matplotlib.pyplot as plt
import random
import copy
import gdown
from torchsummary import summary
from collections import defaultdict
from torchvision import transforms
from glob import glob
import csv
import pretty_midi
from pretty_midi import PrettyMIDI

# KMP_DUPLICATE_LIB_OK is read by the Intel OpenMP runtime: 'True' lets the process
# continue instead of aborting when a second copy of that runtime is loaded.
# NOTE(review): this may need to be set before torch/numpy are imported to take effect — confirm.
os.environ.update({'KMP_DUPLICATE_LIB_OK': 'True'})

# Report which PyTorch build the kernel picked up.
print("Using torch", torch.__version__)

<font size=6>Auxiliary Data Structures</font>

In [None]:
class SongData:
    """A single song: artist, title, lyrics text and its MIDI object.

    Args:
        song_data_cell: A row of exactly three items ordered
            ``[artist, title, lyrics]``.
        midi_file: The loaded MIDI data for this song.

    Raises:
        ValueError: If ``song_data_cell`` does not contain exactly three items.
    """

    def __init__(self, song_data_cell: list[str], midi_file: PrettyMIDI):
        expected_field_count = 3
        if len(song_data_cell) == expected_field_count:
            # Positional fields, in CSV column order.
            self.artist, self.title, self.lyrics = (
                song_data_cell[0],
                song_data_cell[1],
                song_data_cell[2],
            )
        else:
            raise ValueError("song_data_cell must have exactly three elements: [artist, title, lyrics]")
        self.midi_data = midi_file

In [None]:
class SongDataset(data.Dataset):
    """Placeholder PyTorch dataset for songs.

    Not implemented yet: constructing an instance always raises
    ``NotImplementedError``.
    """

    def __init__(self):
        # Deliberately unusable until the real implementation is written.
        pending_message = "Implement this later"
        raise NotImplementedError(pending_message)

<font size=6>Constants</font>

In [None]:
# Root directory holding every input file; relocate the dataset by changing this one value.
DATA_DIR = os.path.join(os.getcwd(), 'data')

LYRIC_TRAIN_SET_CSV_PATH = os.path.join(DATA_DIR, 'lyrics_train_set.csv')
LYRIC_TEST_SET_CSV_PATH = os.path.join(DATA_DIR, 'lyrics_test_set.csv')
MIDI_FILE_PATH = os.path.join(DATA_DIR, 'midi_files')

<font size=6>Reading CSV files</font>

In [None]:
def _read_csv_rows(csv_path: str) -> list[list[str]]:
    """Read all non-blank rows of a UTF-8 CSV file.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        The parsed rows, each a list of cell strings. Blank lines (which
        ``csv.reader`` yields as ``[]``) are dropped.

    Raises:
        FileNotFoundError: If ``csv_path`` does not exist.
        ValueError: If the file contains no non-blank rows.
    """
    # newline='' is required by the csv module so that line breaks inside
    # quoted fields are kept as part of the field instead of being translated.
    with open(csv_path, mode='r', encoding='utf-8', newline='') as csv_file:
        rows = [row for row in csv.reader(csv_file) if row]
    if not rows:
        raise ValueError(f"CSV file {csv_path} is empty.")
    return rows


lyric_train_data = _read_csv_rows(LYRIC_TRAIN_SET_CSV_PATH)
lyric_test_data = _read_csv_rows(LYRIC_TEST_SET_CSV_PATH)

<font size=6>Reading MIDI files</font>

In [None]:
if not os.path.isdir(MIDI_FILE_PATH):
    raise ValueError(f"MIDI file path {MIDI_FILE_PATH} is not a valid directory.")

MIDI_EXTENSIONS = ('.mid', '.midi')
ARTIST_TITLE_SEPARATOR = '_-_'


def _normalize_name(raw_name: str) -> str:
    """Turn an underscore-separated file-name fragment into a lowercase, space-separated name."""
    return raw_name.replace('_', ' ').strip().lower()


def _parse_midi_file_name(file_name: str) -> tuple[str, str]:
    """Split a file name like ``Artist_Name_-_Song_Title.mid`` into ``(artist, title)``.

    Args:
        file_name: Bare file name (no directory), including its extension.

    Returns:
        ``(artist, title)``, both normalized with ``_normalize_name``.

    Raises:
        ValueError: If the name does not contain the ``_-_`` separator.
    """
    # splitext removes exactly the extension. str.strip('.mid') would instead strip
    # any of the characters '.', 'm', 'i', 'd' from BOTH ends of the name.
    stem, _ = os.path.splitext(file_name)
    if ARTIST_TITLE_SEPARATOR not in stem:
        raise ValueError(f"expected '<artist>{ARTIST_TITLE_SEPARATOR}<title>' in file name")
    # Split on the raw separator before replacing underscores; afterwards it no longer exists.
    raw_artist, raw_title = stem.split(ARTIST_TITLE_SEPARATOR, 1)
    return _normalize_name(raw_artist), _normalize_name(raw_title)


# artist -> title -> PrettyMIDI
loaded_midi_files: dict[str, dict[str, PrettyMIDI]] = defaultdict(dict)
failed_loads: list[str] = list()

# sorted() makes the load order (and thus the printed report) deterministic.
for file in sorted(os.listdir(MIDI_FILE_PATH)):
    if os.path.splitext(file)[1].lower() not in MIDI_EXTENSIONS:
        continue
    file_path = os.path.join(MIDI_FILE_PATH, file)
    try:
        # Parse the name first: it is cheap and avoids loading files we could not index.
        artist, title = _parse_midi_file_name(file)
        midi_data = pretty_midi.PrettyMIDI(file_path)
    except Exception as e:
        # Best-effort loading: record the failure (bad name or unreadable MIDI) and move on.
        print(f"Failed to load {file}: {e}")
        failed_loads.append(file)
        continue
    if title in loaded_midi_files[artist]:
        print(f"Duplicate MIDI for '{artist} - {title}': {file} replaces the earlier file")
    loaded_midi_files[artist][title] = midi_data

if failed_loads:
    print("Failed to load the following MIDI files:")
    for failed_file in failed_loads:
        print(f" - {failed_file}")

print(f"Successfully loaded {sum(len(songs) for songs in loaded_midi_files.values())} MIDI files.")