<font size=6>Downloading required libraries</font>

In [None]:
!pip3 install -q numpy
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
!pip3 install torch torchsummary
!pip3 install -q pretty_midi
!pip3 install -q gdown
!pip3 install -q gensim
!pip3 install -q nltk

<font size=6>Imports</font>

In [None]:
import torch
import os
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
import zipfile
import requests
import numpy as np
import torch.utils.data as data
import time
import matplotlib.pyplot as plt
import random
import copy
import gdown
from torchsummary import summary
from collections import defaultdict
from torchvision import transforms
from glob import glob
from typing import Optional
import csv
import string
from pretty_midi import PrettyMIDI
import re
import gensim.downloader
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')
import pickle

# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime" abort that can
# happen when several installed packages each bundle their own OpenMP library — confirm it
# is still needed in this environment.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Record the torch version in the notebook output.
print("Using torch", torch.__version__)

<font size=6>Auxiliary Data Structures</font>

In [None]:
class SongData:
    """Bundles one song's lyric record with its parsed MIDI file."""

    def __init__(self, song_data_cell: tuple[str, str, list[str]], midi_file: PrettyMIDI):
        """Store the three lyric fields and the MIDI object on the instance.

        Args:
            song_data_cell: Exactly three elements, in the order
                (artist, title, lyrics). In this notebook the records come from
                clean_csv_data, so lyrics is a list of word tokens. Any indexable
                container of length three is accepted at runtime.
            midi_file: The MIDI object that belongs to this song.

        Raises:
            ValueError: If song_data_cell does not contain exactly three elements.
        """
        if len(song_data_cell) != 3:
            raise ValueError("song_data_cell must have exactly three elements: [artist, title, lyrics]")
        self.artist = song_data_cell[0]   # artist name
        self.title = song_data_cell[1]    # song title
        self.lyrics = song_data_cell[2]   # lyrics, as supplied by the caller
        self.midi_data = midi_file        # MIDI object paired with the lyrics

In [None]:
class SongDataset(data.Dataset):
    """Placeholder dataset class; it cannot be instantiated yet."""

    def __init__(self):
        """Always raises NotImplementedError, because the dataset is not written yet."""
        pending_message = "Implement this later"
        raise NotImplementedError(pending_message)

<font size=6>Constants</font>

In [None]:
# All locations are resolved against the current working directory at the time this cell runs.
# CSV file holding the training lyrics.
LYRIC_TRAIN_SET_CSV_PATH: str = os.path.join(os.getcwd(), 'data', 'lyrics_train_set.csv')
# CSV file holding the test lyrics.
LYRIC_TEST_SET_CSV_PATH: str = os.path.join(os.getcwd(), 'data', 'lyrics_test_set.csv')
# Directory that is scanned for MIDI files.
MIDI_FILE_PATH: str = os.path.join(os.getcwd(), 'data', 'midi_files')
PICKLING_PATH: str = os.path.join(os.getcwd(), 'loaded_midi_files.pkl') # Path to save/load pickled MIDI files, for faster loading.

<font size=6>Reading CSV files</font>

In [None]:
def read_csv_rows(csv_path: str) -> list[list[str]]:
    """Read every row of the CSV file at csv_path.

    Args:
        csv_path: Path of a UTF-8 encoded CSV file.

    Returns:
        One list of string fields per CSV row, in file order.

    Raises:
        OSError: If the file cannot be opened (e.g. FileNotFoundError).
    """
    # newline='' hands line-ending handling to the csv module, which is what it
    # expects; otherwise quoted fields containing line breaks can be misparsed.
    with open(csv_path, mode='r', encoding='utf-8', newline='') as csv_file:
        return list(csv.reader(csv_file))


lyric_train_data = read_csv_rows(LYRIC_TRAIN_SET_CSV_PATH)
lyric_test_data = read_csv_rows(LYRIC_TEST_SET_CSV_PATH)

# A missing file already fails inside open(); this check only catches files with no rows.
for csv_path, csv_rows in ((LYRIC_TRAIN_SET_CSV_PATH, lyric_train_data),
                           (LYRIC_TEST_SET_CSV_PATH, lyric_test_data)):
    if len(csv_rows) < 1:
        raise ValueError(f"CSV file {csv_path} is empty.")

<font size=6>Parsing CSV files</font>

In [None]:
def clean_csv_data(raw_csv_data: list[list[str]]) -> list[tuple[str, str, list[str]]]:
    """Turn raw CSV rows into (artist, title, lyric_tokens) records.

    Each row is read as ``artist, title, lyrics[, title, lyrics, ...]``: the first
    field is the artist and the remaining fields are consumed as (title, lyrics)
    pairs. A trailing field without a partner is ignored.

    Lyrics are lower-cased, '&' becomes the token ``eol`` (end of line), apostrophes
    are dropped, hyphens become spaces, all remaining ASCII punctuation is removed,
    and the text is split on any whitespace.

    Args:
        raw_csv_data: Rows as produced by csv.reader.

    Returns:
        One tuple per (title, lyrics) pair whose title and token list are both
        non-empty. Empty rows are skipped.
    """
    # Deletes every ASCII punctuation character in a single pass.
    punctuation_table = str.maketrans('', '', string.punctuation)

    cleaned_records: list[tuple[str, str, list[str]]] = []
    for row in raw_csv_data:
        if not row:
            # csv.reader yields [] for blank lines; there is no artist to read.
            continue
        artist = row[0].strip()
        # Titles sit at odd indices, each followed directly by its lyrics field.
        for title_index in range(1, len(row) - 1, 2):
            # Remove '-2' suffix if present, relevant in 1 case.
            title = row[title_index].strip().removesuffix('-2').strip()

            lyrics_text = row[title_index + 1].lower()
            lyrics_text = lyrics_text.replace('&', ' eol ')  # '&' marks the end of a lyric line
            lyrics_text = lyrics_text.replace("'", '')       # "it's" -> "its"
            lyrics_text = lyrics_text.replace('-', ' ')      # "well-known" -> "well known"
            lyrics_text = lyrics_text.translate(punctuation_table)
            # split() without arguments splits on any whitespace run and never
            # produces empty tokens, so tabs and newlines cannot leak into words.
            tokens = lyrics_text.split()

            if title and tokens:
                cleaned_records.append((artist, title, tokens))
    return cleaned_records

# Run the same cleaning routine over both splits.
cleaned_lyric_train_data = clean_csv_data(lyric_train_data)
cleaned_lyric_test_data = clean_csv_data(lyric_test_data)

In [None]:
# Vocabulary size helper for the cleaned lyric records.
def count_unique_words(lyrics_data: list[tuple[str, str, list[str]]]) -> int:
    """Return how many distinct tokens occur across the lyrics of all records.

    Args:
        lyrics_data: Records of the form (artist, title, lyric_tokens); only the
            third element of each record is inspected.
    """
    vocabulary = {token for _artist, _title, tokens in lyrics_data for token in tokens}
    return len(vocabulary)
# Report the vocabulary size of each split separately.
print(f"Number of unique words in training set: {count_unique_words(cleaned_lyric_train_data)}")
print(f"Number of unique words in test set: {count_unique_words(cleaned_lyric_test_data)}")

<font size=6>Reading MIDI files</font>

In [None]:
def load_midi_files(midi_files_location: str, pickling_path: Optional[str] = None, failed_loads_path: Optional[str] = None) -> \
                    tuple[dict[str, dict[str, PrettyMIDI]], dict[str, set[str]]]: # artist -> title -> PrettyMIDI, failed loads[artist, song_set]
    """Load every MIDI file in a directory, with optional pickle caching.

    File names are expected to look like ``Artist_Name_-_Song_Title.mid`` (or
    ``.midi``, any letter case). Underscores become spaces and both artist and
    title are stripped and lower-cased before being used as dictionary keys.

    If pickling_path points to an existing file, the cached result is returned
    and the directory is not scanned at all.

    Args:
        midi_files_location: Directory containing the MIDI files.
        pickling_path: Optional cache file for the loaded MIDI objects. Read if it
            exists, otherwise written after scanning.
        failed_loads_path: Optional cache file for the failed loads. Read if it
            exists, and (re)written after scanning.

    Returns:
        (loaded, failed): ``loaded`` maps artist -> title -> PrettyMIDI and
        ``failed`` maps artist -> set of titles whose file could not be parsed.

    Raises:
        ValueError: If no cache was used and midi_files_location is not a directory.
    """
    # SECURITY NOTE: pickle.load can execute arbitrary code. Only point
    # pickling_path / failed_loads_path at files this notebook created itself.
    failed_loads = dict()
    if failed_loads_path is not None and os.path.isfile(failed_loads_path):
        with open(failed_loads_path, "rb") as f:
            failed_loads = pickle.load(f)
        print(f"Loaded failed MIDI loads from pickled file {failed_loads_path}.")
    if pickling_path is not None and os.path.isfile(pickling_path):
        with open(pickling_path, "rb") as f:
            loaded_midi_files = pickle.load(f)
        print(f"Loaded MIDI files from pickled file {pickling_path}.")
        return loaded_midi_files, failed_loads
    if not os.path.isdir(midi_files_location):
        raise ValueError(f"MIDI file path {midi_files_location} is not a valid directory.")

    # Traversing over all files and attempt to load them with pretty_midi:
    loaded_midi_files: dict[str, dict[str, PrettyMIDI]] = defaultdict(dict) # artist -> title -> PrettyMIDI
    # A fresh scan replaces whatever failure list was read from the cache above.
    failed_loads: dict[str, set[str]] = defaultdict(set)

    # Sorted so that the scan (and which file wins on duplicate keys) is deterministic.
    for file_name in sorted(os.listdir(midi_files_location)):
        # splitext removes whichever extension is present, so '.midi' files no
        # longer keep their extension inside the title.
        stem, extension = os.path.splitext(file_name)
        if extension.lower() not in ('.mid', '.midi'):
            continue
        file_path = os.path.join(midi_files_location, file_name)

        name_parts = stem.split('_-_')
        if len(name_parts) < 2:
            # Without the separator there is no title; skip instead of crashing the whole scan.
            print(f"Warning: file {file_name} does not follow the 'artist_-_title' naming scheme, skipping it.")
            continue
        if len(name_parts) > 2:
            print(f"Warning: file {stem} has more than one '_-_' separator, ignoring the rest after second \"_-_\".")
        artist = name_parts[0].replace('_', ' ').strip().lower()
        title = name_parts[1].replace('_', ' ').strip().lower()

        try:
            midi_data = PrettyMIDI(file_path)
            loaded_midi_files[artist][title] = midi_data
        except Exception as e:
            # Deliberately broad: a single unreadable file must not abort the scan.
            print(f"Failed to load {file_name}: {e}")
            failed_loads[artist].add(title)

    if failed_loads:
        print("Failed to load the following artist and lyric midi files:")
        for artist, titles in failed_loads.items():
            print(f"{artist} - [{', '.join(sorted(titles))}]")

    if pickling_path is not None:
        with open(pickling_path, "wb") as f:
            pickle.dump(loaded_midi_files, f)
            print(f"Pickled loaded MIDI files to {pickling_path}.")
    if failed_loads_path is not None:
        with open(failed_loads_path, "wb") as f:
            pickle.dump(failed_loads, f)
            print(f"Pickled failed MIDI loads to {failed_loads_path}.")

    print(f"Successfully loaded {sum(len(songs) for songs in loaded_midi_files.values())} MIDI files.")
    return loaded_midi_files, failed_loads

In [None]:
# NOTE(review): no failed_loads_path is passed here, so the failure list is never persisted;
# whenever PICKLING_PATH already exists, failed_midi_loads comes back as an empty dict.
loaded_midi_files, failed_midi_loads = load_midi_files(MIDI_FILE_PATH, PICKLING_PATH)

<font size=6>Mapping CSV data to MIDI files</font>

In [None]:
def csv_data_to_songdata_list(csv_data: list[tuple[str, str, list[str]]],
                              failed_midi_load: dict[str, set[str]],
                              midi_files_dict: dict[str, dict[str, PrettyMIDI]]) -> list[SongData]:
    """Pair every lyric record with its MIDI file and wrap both in a SongData.

    Records whose MIDI file previously failed to load are skipped; records with no
    MIDI file at all are reported and counted. Progress is printed.

    Args:
        csv_data: Records of the form (artist, title, lyrics); the first two
            elements are used as lookup keys and the whole record is handed to
            SongData.
        failed_midi_load: artist -> titles whose MIDI file could not be parsed.
        midi_files_dict: artist -> title -> loaded MIDI object.

    Returns:
        One SongData per record that has a loaded MIDI file, in input order.
    """
    def _find_key(table, artist: str, title: str) -> Optional[tuple[str, str]]:
        """Return the (artist, title) key pair under which the song is stored in table, or None."""
        # load_midi_files stores stripped, lower-cased keys, while the CSV fields are
        # only stripped; try the exact spelling first, then the lower-cased one.
        for key_artist, key_title in ((artist, title), (artist.strip().lower(), title.strip().lower())):
            if key_artist in table and key_title in table[key_artist]:
                return key_artist, key_title
        return None

    song_data_list: list[SongData] = list()
    missing_midi_count = 0
    for row in csv_data:
        artist = row[0]
        title = row[1]
        if _find_key(failed_midi_load, artist, title) is not None:
            print(f"Skipping {artist} - {title} due to previous MIDI load failure.")
            continue
        midi_key = _find_key(midi_files_dict, artist, title)
        if midi_key is None:
            missing_midi_count += 1
            print(f"Missing MIDI file for artist '{artist}' and title '{title}'")
            continue
        midi_file = midi_files_dict[midi_key[0]][midi_key[1]]
        song_data_list.append(SongData(row, midi_file))
    print(f"Total songs with missing MIDI files: {missing_midi_count}")
    return song_data_list

In [None]:
# Keep only the songs of each split for which a MIDI file was successfully loaded.
train_midi_data: list[SongData] = csv_data_to_songdata_list(cleaned_lyric_train_data, failed_midi_loads, loaded_midi_files)
test_midi_data: list[SongData] = csv_data_to_songdata_list(cleaned_lyric_test_data, failed_midi_loads, loaded_midi_files)
# Report how many songs of each split survived the matching.
print(f"Total training songs with MIDI data: {len(train_midi_data)}")
print(f"Total test songs with MIDI data: {len(test_midi_data)}")

<font size=6>Handling word embeddings</font>

Downloading the pretrained word2vec model (300-dimensional vectors, trained on news articles).

In [None]:
# NOTE(review): presumably a large download on first use that gensim caches locally
# afterwards — confirm the size and cache location before running on a constrained machine.
pretrained_word2vec = gensim.downloader.load('word2vec-google-news-300')

Extracting the vocabulary from the lyrics.
The test set is included as well, since the full vocabulary needs to be known.

In [None]:
lyrics_vocabulary: set[str] = set()
# The test set is included as well, since the full vocabulary needs to be known.
for song in train_midi_data + test_midi_data:
    lyrics_vocabulary.update(song.lyrics)
print(f"Total unique words in lyrics vocabulary: {len(lyrics_vocabulary)}")
# Show a short, sorted preview instead of dumping the whole set into the notebook output.
VOCABULARY_PREVIEW_SIZE = 20
print(f"First {VOCABULARY_PREVIEW_SIZE} words (alphabetical): {sorted(lyrics_vocabulary)[:VOCABULARY_PREVIEW_SIZE]}")

Creating unified embedding.
Extracting embeddings from word2vec and using random embeddings for words not found in word2vec.

In [None]:
def _random_embedding(vector_size: int, dtype) -> np.ndarray:
    """Return a uniform(-1, 1) vector of the given size, cast to dtype."""
    return np.random.uniform(low=-1.0, high=1.0, size=(vector_size,)).astype(dtype)


unified_embeddings: dict[str, np.ndarray] = dict()
existing_words_in_pretrained = 0
not_existing_in_pretrained = 0
added_stopwords = 0
randomly_initialized_stopwords = 0
# Random vectors use the same dtype as the pretrained ones, so the table is not a
# mix of float32 and float64 arrays.
embedding_dtype = pretrained_word2vec.vectors.dtype

# Sorted iteration keeps the word -> random vector assignment stable whenever the
# NumPy global RNG is seeded.
for word in sorted(lyrics_vocabulary):
    if word in pretrained_word2vec:
        unified_embeddings[word] = pretrained_word2vec[word]
        existing_words_in_pretrained += 1
    else:
        unified_embeddings[word] = _random_embedding(pretrained_word2vec.vector_size, embedding_dtype) # Random init for unknown words.
        not_existing_in_pretrained += 1
        print(f'Word not found in pretrained embeddings: {word}')

# Adding stopwords as well, since they are common and should be in the vocabulary.
for stopword in stopwords.words('english'):
    # Punctuation is removed (not replaced by a space) so that e.g. "don't" becomes
    # "dont", the same token the lyric cleaning produces when it drops apostrophes.
    cleaned_stopword = re.sub(f"[{re.escape(string.punctuation)}]", "", stopword.strip().lower())
    if not cleaned_stopword or cleaned_stopword in unified_embeddings:
        continue
    if cleaned_stopword in pretrained_word2vec:
        # Prefer the pretrained vector when one exists.
        unified_embeddings[cleaned_stopword] = pretrained_word2vec[cleaned_stopword]
    else:
        unified_embeddings[cleaned_stopword] = _random_embedding(pretrained_word2vec.vector_size, embedding_dtype)
        randomly_initialized_stopwords += 1
    added_stopwords += 1

print(f"Total unique words in lyrics vocabulary: {len(lyrics_vocabulary)}")
print(f"Existing words in pretrained embeddings: {existing_words_in_pretrained}")
print(f"Not existing in pretrained embeddings (randomly initialized): {not_existing_in_pretrained}")
print(f"Added stopwords: {added_stopwords} (randomly initialized: {randomly_initialized_stopwords})")

<font size=6>Midi Feature extraction</font>