<font size=6>Downloading required libraries</font>

In [None]:
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
!pip3 install torch torchsummary
!pip3 install -q pretty_midi
!pip3 install -q gdown

<font size=6>Imports</font>

In [None]:
import torch
import os
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
import zipfile
import requests
import numpy as np
import torch.utils.data as data
import time
import matplotlib.pyplot as plt
import random
import copy
import gdown
from torchsummary import summary
from collections import defaultdict
from torchvision import transforms
from glob import glob
import csv
import pretty_midi
from pretty_midi import PrettyMIDI

# NOTE(review): KMP_DUPLICATE_LIB_OK is an Intel OpenMP runtime switch, presumably set to avoid the
# "duplicate OpenMP runtime" abort when several libraries bundle their own copy. It is set after
# `import torch` here -- confirm it still takes effect in this order and is still needed.
os.environ.update({'KMP_DUPLICATE_LIB_OK': 'True'})

# Report which PyTorch build this notebook is running against.
print("Using torch", torch.__version__)

<font size=6>Auxiliary Data Structures</font>

In [None]:
class SongData:
    """One song: artist, title and lyrics from a cleaned CSV row, plus its parsed MIDI file.

    Args:
        song_data_cell: exactly three items, in order ``[artist, title, lyrics]``.
        midi_file: the ``PrettyMIDI`` object loaded for this song.

    Raises:
        ValueError: if ``song_data_cell`` does not contain exactly three elements.
    """

    def __init__(self, song_data_cell: list[str], midi_file: PrettyMIDI):
        if len(song_data_cell) == 3:
            artist_text = song_data_cell[0]
            title_text = song_data_cell[1]
            lyrics_text = song_data_cell[2]
        else:
            raise ValueError("song_data_cell must have exactly three elements: [artist, title, lyrics]")
        self.artist, self.title, self.lyrics = artist_text, title_text, lyrics_text
        self.midi_data = midi_file

In [None]:
class SongDataset(data.Dataset):
    """Placeholder for the torch ``Dataset`` over songs; not implemented yet.

    Constructing an instance always raises ``NotImplementedError``.
    """

    def __init__(self):
        # TODO: implement dataset construction.
        pending_message = "Implement this later"
        raise NotImplementedError(pending_message)

<font size=6>Constants</font>

In [None]:
# All input data lives under ./data relative to the directory the notebook is run from.
DATA_DIR = os.path.join(os.getcwd(), 'data')
LYRIC_TRAIN_SET_CSV_PATH = os.path.join(DATA_DIR, 'lyrics_train_set.csv')
LYRIC_TEST_SET_CSV_PATH = os.path.join(DATA_DIR, 'lyrics_test_set.csv')
MIDI_FILE_PATH = os.path.join(DATA_DIR, 'midi_files')

<font size=6>Reading CSV files</font>

In [None]:
def _read_csv_rows(csv_path: str) -> list[list[str]]:
    """Read every row of a UTF-8 CSV file as a list of string fields.

    Args:
        csv_path: path of the CSV file to read.

    Returns:
        All rows produced by ``csv.reader``.

    Raises:
        FileNotFoundError: (from ``open``) if ``csv_path`` does not exist.
        ValueError: if the file contains no rows.
    """
    # newline='' lets the csv module handle line endings and quoted newlines itself,
    # as the csv module documentation recommends.
    with open(csv_path, mode='r', encoding='utf-8', newline='') as csv_file:
        rows = list(csv.reader(csv_file))
    if not rows:
        raise ValueError(f"CSV file is empty: {csv_path}")
    return rows


lyric_train_data = _read_csv_rows(LYRIC_TRAIN_SET_CSV_PATH)
lyric_test_data = _read_csv_rows(LYRIC_TEST_SET_CSV_PATH)

<font size=6>Parsing CSV files</font>

In [None]:
def clean_csv_data(raw_csv_data: list[list[str]]) -> list[list[str]]:
    """Flatten raw lyric CSV rows into ``[artist, title, lyrics]`` triples.

    Each raw row is ``artist, title1, lyrics1, title2, lyrics2, ...``. A row may hold several
    (title, lyrics) pairs, and a trailing unpaired field is ignored. For each pair:
      * the title is stripped and a trailing ``'-2'`` suffix is removed;
      * the lyrics are stripped and every ``'&'`` is replaced by a newline;
      * pairs whose title or lyrics end up empty are dropped.
    Blank rows are skipped.

    Args:
        raw_csv_data: rows as produced by ``csv.reader``.

    Returns:
        A new list of ``[artist, title, lyrics]`` lists.
    """
    cleaned_rows: list[list[str]] = []
    for row in raw_csv_data:
        if not row:  # csv.reader yields [] for blank lines; row[0] would raise IndexError.
            continue
        artist = row[0].strip()
        # Titles sit at odd indices (1, 3, ...) and their lyrics at the following even index.
        for raw_title, raw_lyrics in zip(row[1::2], row[2::2]):
            # Remove the '-2' suffix (relevant in 1 case), then re-strip whitespace it may expose.
            title = raw_title.strip().removesuffix('-2').strip()
            lyrics = raw_lyrics.strip().replace('&', '\n')
            if title and lyrics:
                cleaned_rows.append([artist, title, lyrics])
    return cleaned_rows

# Clean both splits with the same rules (train first, then test).
cleaned_lyric_train_data, cleaned_lyric_test_data = (
    clean_csv_data(raw_rows) for raw_rows in (lyric_train_data, lyric_test_data)
)

<font size=6>Reading MIDI files</font>

In [None]:
if not os.path.isdir(MIDI_FILE_PATH):
    raise ValueError(f"MIDI file path {MIDI_FILE_PATH} is not a valid directory.")

# Load every MIDI file with pretty_midi, keyed by artist and title parsed from
# "<Artist>_-_<Title>.mid" style file names.
loaded_midi_files: dict[str, dict[str, PrettyMIDI]] = defaultdict(dict)  # artist -> title -> PrettyMIDI
failed_loads: dict[str, set[str]] = defaultdict(set)  # artist -> titles pretty_midi failed to load

# os.listdir order is arbitrary; sorting makes the load order and log output reproducible.
for file_name in sorted(os.listdir(MIDI_FILE_PATH)):
    stem, extension = os.path.splitext(file_name)
    # Accept both extensions, case-insensitively (e.g. ".MID").
    if extension.lower() not in ('.mid', '.midi'):
        continue
    file_path = os.path.join(MIDI_FILE_PATH, file_name)

    splitted_artist_and_title = stem.split('_-_')
    if len(splitted_artist_and_title) < 2:
        # Without the separator there is no title to key on; skip instead of crashing the loop.
        print(f"Warning: file {file_name} has no '_-_' separator between artist and title, skipping it.")
        continue
    if len(splitted_artist_and_title) > 2:
        print(f"Warning: file {stem} has more than one '_-_' separator, ignoring the rest after second \"_-_\".")

    # Normalize both parts the same way: underscores -> spaces, trimmed, lowercase.
    artist = splitted_artist_and_title[0].replace('_', ' ').strip().lower()
    title = splitted_artist_and_title[1].replace('_', ' ').strip().lower()
    try:
        midi_data = pretty_midi.PrettyMIDI(file_path)
    except Exception as e:  # best-effort: record the failure and keep loading the rest
        print(f"Failed to load {file_name}: {e}")
        failed_loads[artist].add(title)
    else:
        loaded_midi_files[artist][title] = midi_data

if failed_loads:
    print("Failed to load the following artist and title MIDI files:")
    for artist, titles in failed_loads.items():
        print(f"{artist} - [{', '.join(sorted(titles))}]")

print(f"Successfully loaded {sum(len(songs) for songs in loaded_midi_files.values())} MIDI files.")

<font size=6>Mapping CSV data to MIDI files</font>

In [None]:
def csv_data_to_songdata_list(csv_data: list[list[str]], 
                              failed_midi_load: dict[str, set[str]], 
                              midi_files_dict: dict[str, dict[str, PrettyMIDI]]) -> list[SongData]:
    """Pair each cleaned CSV row with its loaded MIDI file.

    Artist and title are normalized before lookup the same way MIDI file names are normalized
    when loaded (underscores -> spaces, trimmed, lowercase). Capitalization or spacing
    differences between the CSV and the file names therefore do not produce spurious "missing"
    reports. The created SongData objects keep the original CSV text.

    Args:
        csv_data: rows of ``[artist, title, lyrics]``, e.g. the output of ``clean_csv_data``.
        failed_midi_load: artist -> titles whose MIDI file failed to load; those rows are skipped.
        midi_files_dict: artist -> title -> PrettyMIDI for successfully loaded files.

    Returns:
        One SongData per row that has a matching MIDI file, in CSV order.
    """
    def normalize(name: str) -> str:
        # Must mirror the normalization applied to MIDI file names when building the dicts.
        return name.replace('_', ' ').strip().lower()

    song_data_list: list[SongData] = []
    missing_midi_count = 0
    for row in csv_data:
        artist = row[0]
        title = row[1]
        artist_key = normalize(artist)
        title_key = normalize(title)
        # .get() avoids inserting keys into the (defaultdict) lookup tables.
        if title_key in failed_midi_load.get(artist_key, ()):
            print(f"Skipping {artist} - {title} due to previous MIDI load failure.")
            continue
        midi_file = midi_files_dict.get(artist_key, {}).get(title_key)
        if midi_file is not None:
            song_data_list.append(SongData(row, midi_file))
        else:
            missing_midi_count += 1
            print(f"Missing MIDI file for artist '{artist}' and title '{title}'")
    print(f"Total songs with missing MIDI files: {missing_midi_count}")
    return song_data_list

In [None]:
# Attach MIDI data to every cleaned row of each split (train first, then test).
train_midi_data, test_midi_data = (
    csv_data_to_songdata_list(split_rows, failed_loads, loaded_midi_files)
    for split_rows in (cleaned_lyric_train_data, cleaned_lyric_test_data)
)
print(f"Total training songs with MIDI data: {len(train_midi_data)}")
print(f"Total test songs with MIDI data: {len(test_midi_data)}")