<font size=6>Downloading required libraries</font>

In [None]:
!pip3 install -q numpy
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
!pip3 install torch torchsummary
!pip3 install -q pretty_midi
!pip3 install -q gensim
!pip3 install -q nltk

<font size=6>Imports</font>

In [None]:
import torch
import os
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
import zipfile
import requests
import numpy as np
import torch.utils.data as data
import time
import matplotlib.pyplot as plt
import random
import copy
import gdown
import numpy as np
from torchsummary import summary
from collections import defaultdict
from torchvision import transforms
from glob import glob
from typing import Optional
import csv
import string
from pretty_midi import PrettyMIDI, Note
import re
import gensim.downloader
import nltk
from nltk.corpus import stopwords
from collections import Counter
nltk.download('stopwords')
import pickle

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

print("Using torch", torch.__version__)

<font size=6>Constants</font>

In [None]:
# Input locations, all resolved relative to the current working directory.
LYRIC_TRAIN_SET_CSV_PATH: str = os.path.join(os.getcwd(), 'data', 'lyrics_train_set.csv')
LYRIC_TEST_SET_CSV_PATH: str = os.path.join(os.getcwd(), 'data', 'lyrics_test_set.csv')
MIDI_FILE_PATH: str = os.path.join(os.getcwd(), 'data', 'midi_files')
PICKLING_PATH: str = os.path.join(os.getcwd(), 'loaded_midi_files.pkl') # Path to save/load pickled MIDI files, for faster loading.
# Small positive floor used by the MIDI feature code to avoid zero weights and division by zero.
EPSILON: float = 1e-9
SEQUENCE_LENGTH: int = 10  # Number of words in the input sequence

<font size=6>Midi Feature extraction</font>

Auxiliary functions

In [None]:
def most_common_time_signature(changes: list) -> tuple[int, int]:
    """
    Return the (numerator, denominator) pair that occurs most often in `changes`.

    Args:
        changes: Sequence of objects exposing `.numerator` and `.denominator`
                 (the caller passes `PrettyMIDI.time_signature_changes`).

    Returns:
        The most frequent (numerator, denominator) pair; ties go to the pair seen first.
        Falls back to (4, 4) when `changes` is empty or None.
    """
    if not changes:
        return (4, 4)
    signature_counts = Counter((ts.numerator, ts.denominator) for ts in changes)
    return signature_counts.most_common(1)[0][0]

In [None]:
def get_duration_weighted_pitch_stats(notes: list[Note]) -> dict[str, float]:
    """
    Compute pitch statistics in which every note counts in proportion to its duration.

    Args:
        notes: Notes exposing `.pitch`, `.start` and `.end`.

    Returns:
        A dict with the keys:
          mean    - duration-weighted mean pitch,
          std     - duration-weighted standard deviation,
          p10 / p50 / p90 - duration-weighted percentiles,
          ambitus - p90 - p10 (robust pitch range).
        For an empty `notes` list: mean/std/ambitus are 0.0 and the percentiles are -1.
    """
    # Empty guard: nothing to measure, so return neutral stats.
    if not notes:
        return dict(mean=0.0, std=0.0, p10=-1, p50=-1, p90=-1, ambitus=0.0)

    # Pitches as float32 so the weighted arithmetic below stays in floating point.
    pitches = np.fromiter((n.pitch for n in notes), dtype=np.float32)

    # Weight of a note = its duration; zero/negative durations are clamped to EPSILON
    # so that the total weight can never be zero.
    weights = np.fromiter((max(EPSILON, n.end - n.start) for n in notes), dtype=np.float32)
    total_weights: float = weights.sum()
    duration_weight_mean: float = float((weights * pitches).sum() / total_weights)
    duration_weighted_variance: float = float((weights * (pitches - duration_weight_mean) ** 2).sum() / total_weights)
    weighted_std: float = duration_weighted_variance ** 0.5

    # ---------- Duration-weighted percentiles ----------
    # Sort by pitch and accumulate the weights: the q-th percentile is the first
    # pitch at which the cumulative weight reaches q% of the total.
    order = np.argsort(pitches)
    ordered_pitches, ordered_weights = pitches[order], weights[order]
    cumulative_weight_sum = np.cumsum(ordered_weights)

    def weighted_quantile(quantile: float) -> float:
        """Return the pitch at the given percentile (0..100) of cumulative duration."""
        target: float = (quantile / 100.0) * cumulative_weight_sum[-1]
        # Leftmost index with cumulative weight >= target; clamp in case of float round-off.
        idx: int = np.searchsorted(cumulative_weight_sum, target, side="left")
        return float(ordered_pitches[min(idx, len(ordered_pitches) - 1)])

    percentile_10, percentile_50, percentile_90 = weighted_quantile(10), weighted_quantile(50), weighted_quantile(90)

    # `ambitus` is included so the non-empty result has the same keys as the
    # empty-input result and as the docstring promises.
    return dict(mean=duration_weight_mean,
                std=weighted_std,
                p10=percentile_10,
                p50=percentile_50,
                p90=percentile_90,
                ambitus=percentile_90 - percentile_10)


Extracting high level features relating to the entire song

In [None]:
def extract_midi_features(midi: PrettyMIDI) -> dict:
    """
    Compute song-level features for one MIDI file.

    Args:
        midi: A loaded PrettyMIDI object.

    Returns:
        A dict with two entries:
          "vector": 1-D float32 array with 21 scalar features followed by the
                    12 pitch-class (chroma) bins.
          "names":  list of feature names aligned index-by-index with "vector".

    Raises:
        ValueError: If the MIDI end time is zero or negative.
    """
    duration_sec: float = midi.get_end_time()
    if duration_sec <= 0:
        raise ValueError("Empty/zero-length MIDI.")

    # ---------- Tempo ----------
    tempo_times, tempo_bpms = midi.get_tempo_changes()
    if len(tempo_bpms) == 0:
        # No tempo events: fall back to one estimated tempo starting at t=0.
        tempo_times = np.array([0.0], dtype=np.float32)
        tempo_bpms = np.array([midi.estimate_tempo()], dtype=np.float32)
    # Each tempo lasts from its own timestamp until the next change (or the end of the song).
    segment_ends = np.r_[tempo_times[1:], duration_sec]
    segment_durs = np.maximum(1e-6, segment_ends - tempo_times[:len(segment_ends)])
    segment_bpms = tempo_bpms[:len(segment_durs)]
    tempo_mean: float = float(np.dot(segment_bpms, segment_durs) / np.sum(segment_durs))
    # Dispersion proxy: repeat each bpm roughly in proportion to its share of the
    # song (in 1/1000ths, at least once) and take the plain std of the result.
    repeat_counts = np.maximum(1, (segment_durs / np.sum(segment_durs) * 1000).astype(int))
    tempo_std: float = float(np.std(np.repeat(segment_bpms, repeat_counts)))
    tempo_change_count: int = int(len(tempo_bpms))

    # ---------- Time signature ----------
    time_signature_numerator, time_signature_denominator = most_common_time_signature(midi.time_signature_changes)

    # ---------- Pitched (non-drum) instruments that actually contain notes ----------
    instruments: list = [inst for inst in midi.instruments if not inst.is_drum and inst.notes]
    instrument_count: int = len(instruments)
    instrument_notes: list[Note] = [note for inst in instruments for note in inst.notes]
    instruments_velocities: list[float] = [note.velocity for note in instrument_notes]

    # Velocity statistics; a file without pitched notes gets zeros instead of
    # crashing in min()/max() on an empty list.
    if instruments_velocities:
        instrument_velocities_min: float = float(min(instruments_velocities))
        instrument_velocities_max: float = float(max(instruments_velocities))
        instrument_velocities_mean: float = float(np.mean(instruments_velocities))
        instrument_velocities_std: float = float(np.std(instruments_velocities))
    else:
        instrument_velocities_min = instrument_velocities_max = 0.0
        instrument_velocities_mean = instrument_velocities_std = 0.0

    # ---------- Duration-weighted pitch statistics ----------
    duration_weight_pitch_stats_dict: dict = get_duration_weighted_pitch_stats(instrument_notes)
    instrument_pitch_10_percentile: float = duration_weight_pitch_stats_dict['p10']
    instrument_pitch_50_percentile: float = duration_weight_pitch_stats_dict['p50']
    instrument_pitch_90_percentile: float = duration_weight_pitch_stats_dict['p90']
    instrument_pitch_mean: float = duration_weight_pitch_stats_dict['mean']
    instrument_pitch_std: float = duration_weight_pitch_stats_dict['std']
    instrument_pitch_range_by_percentiles: float = instrument_pitch_90_percentile - instrument_pitch_10_percentile

    # ---------- Note durations ----------
    note_durations: list[float] = [note.end - note.start for note in instrument_notes]
    note_durations_mean: float = float(np.mean(note_durations)) if note_durations else 0.0
    note_durations_std: float = float(np.std(note_durations)) if note_durations else 0.0
    note_durations_range: float = max(note_durations) - min(note_durations) if note_durations else 0.0

    # Notes per second over all pitched tracks. Counting `instrument_notes` avoids
    # depending on a loop variable that only referred to the last track (and did
    # not exist at all when there were no pitched tracks).
    note_density: float = float(len(instrument_notes) / max(EPSILON, duration_sec))

    # ---------- Global chroma ----------
    chroma_global = midi.get_pitch_class_histogram(use_duration=True)        # 12-bin pitch-class histogram
    chroma_global = chroma_global / (np.sum(chroma_global) + EPSILON)        # normalise; EPSILON guards an all-zero histogram

    # Keep `names` and `vec` in the same order: names[i] labels vec[i].
    names = [
        "duration_sec",
        "tempo_mean_bpm",
        "tempo_std_bpm",
        "tempo_change_count",
        "time_sig_num",
        "time_sig_den",
        "instrument_count",
        "instrument_velocities_min",
        "instrument_velocities_max",
        "instrument_velocities_mean",
        "instrument_velocities_std",
        "instrument_pitch_10_percentile",
        "instrument_pitch_50_percentile",
        "instrument_pitch_90_percentile",
        "instrument_pitch_mean",
        "instrument_pitch_std",
        "instrument_pitch_range_by_percentiles",
        "note_durations_mean",
        "note_durations_std",
        "note_durations_range",
        "melody_note_density_per_sec",
    ] + [f"chroma_{i}" for i in range(12)]
    vec = np.array([
        duration_sec,
        tempo_mean,
        tempo_std,
        tempo_change_count,
        time_signature_numerator,
        time_signature_denominator,
        instrument_count,
        instrument_velocities_min,
        instrument_velocities_max,
        instrument_velocities_mean,
        instrument_velocities_std,
        instrument_pitch_10_percentile,
        instrument_pitch_50_percentile,
        instrument_pitch_90_percentile,
        instrument_pitch_mean,
        instrument_pitch_std,
        instrument_pitch_range_by_percentiles,
        note_durations_mean,
        note_durations_std,
        note_durations_range,
        note_density,
        *chroma_global.tolist()
    ], dtype=np.float32)

    return {"vector": vec, "names": names}

<font size=6>Auxiliary Data Structures</font>

Auxiliary functions for creating word sequences and their targets

In [None]:
def create_word_sequences_with_targets(tokenized_lyrics: list[str], sequence_length: Optional[int] = None) -> tuple[list[list[str]], list[str]]:
    """
    Slide a fixed-size window over a song's tokens to build (context, next-word) pairs.

    Args:
        tokenized_lyrics: The song's tokens, in order.
        sequence_length: Number of tokens per context window. When omitted, the
            notebook-level SEQUENCE_LENGTH is read at call time, so changing the
            constant takes effect without re-running this cell.

    Returns:
        (sequences, targets): `sequences[i]` is the list of `sequence_length` tokens
        starting at position i and `targets[i]` is the token that follows it.
        Both lists are empty when the song has `sequence_length` tokens or fewer.

    Raises:
        ValueError: If `sequence_length` is smaller than 1.
    """
    if sequence_length is None:
        sequence_length = SEQUENCE_LENGTH
    if sequence_length < 1:
        # A negative length would make the slices and target indices wrap around silently.
        raise ValueError(f"sequence_length must be at least 1, got {sequence_length}")

    window_count = max(0, len(tokenized_lyrics) - sequence_length)
    sequences = [tokenized_lyrics[start:start + sequence_length] for start in range(window_count)]
    targets = [tokenized_lyrics[start + sequence_length] for start in range(window_count)]
    return sequences, targets

In [None]:
class SongData:
    """
    One song: its artist, title, tokenized lyrics and (optionally) its MIDI object.

    MIDI features are computed on first access of `midi_features` and cached.
    """

    def __init__(self, song_data_cell: Optional[tuple[str, str, list[str]]] = None, midi_file: Optional[PrettyMIDI] = None):
        """
        Args:
            song_data_cell: Three-element sequence (artist, title, lyrics).
            midi_file: The song's PrettyMIDI object, if available.

        Raises:
            ValueError: If `song_data_cell` is None or does not have exactly three elements.
        """
        # Check for None explicitly so the default argument yields the documented
        # ValueError rather than a TypeError from len(None).
        if song_data_cell is None or len(song_data_cell) != 3:
            raise ValueError("song_data_cell must have exactly three elements: [artist, title, lyrics]")
        self.artist, self.title, self.lyrics = song_data_cell
        self.midi_data = midi_file
        # Filled on first access of `midi_features`.
        self._midi_features: Optional[dict] = None

    @property
    def midi_features(self):
        """
        Return the result of `extract_midi_features` for this song, computing it once.

        Raises:
            ValueError: If the song has no MIDI data attached.
        """
        if self._midi_features is None:
            if self.midi_data is None:
                raise ValueError(f"No MIDI data attached to '{self.artist} - {self.title}'; cannot extract MIDI features.")
            self._midi_features = extract_midi_features(self.midi_data)
        return self._midi_features

In [None]:
class SongDataset(data.Dataset):
    """
    Dataset of fixed-length word windows, each paired with its song's MIDI
    feature vector and its artist's embedding.

    Per-song data (MIDI features, artist name) is stored once per song; every
    sequence keeps only the index of the song it was cut from.
    """

    def __init__(self,
                 songs_data: list[SongData],
                 word_embeddings: dict[str, np.ndarray],
                 artist_embeddings: dict[str, np.ndarray],
                 artist_vocab: dict[str, int]):
        """
        Args:
            songs_data: Songs to cut into sequences.
            word_embeddings: token -> embedding vector.
            artist_embeddings: Indexable by the integer ids stored in `artist_vocab`.
            artist_vocab: artist name -> integer id.
        """
        self.midi_features: list[np.ndarray] = list()   # one feature vector per song
        self.artists: list[str] = list()
        self.sequence_artists: list[str] = list()       # one artist name per song; look up through sequence_to_artist
        self.word_sequences: list[list[str]] = list()   # one token window per sequence
        self.sequences_targets: list[str] = list()      # next token for each window
        self.sequence_to_midi: list[int] = list() # Maps each sequence to its corresponding MIDI feature index
        self.sequence_to_artist: list[int] = list() # Maps each sequence to its song's position in sequence_artists
        self.word_embeddings: dict[str, np.ndarray] = word_embeddings
        self.artist_embeddings: dict[str, np.ndarray] = artist_embeddings
        self.artist_vocab: dict[str, int] = artist_vocab
        # Instead of saving each sequence's MIDI features, we save the index of the MIDI features in the midi_features list to save space.
        for idx, song in enumerate(songs_data):
            sequences, targets = create_word_sequences_with_targets(song.lyrics)
            self.word_sequences.extend(sequences)
            self.sequences_targets.extend(targets)
            self.midi_features.append(song.midi_features['vector'])
            self.sequence_artists.append(song.artist)
            self.sequence_to_midi.extend([idx] * len(sequences))
            self.sequence_to_artist.extend([idx] * len(sequences))
        print(f'Dataset has: {len(self.word_sequences)} sequences and {len(self.sequences_targets)} targets')

    def word_vec(self, tok: str) -> np.ndarray:
        """Return the float32 embedding of `tok`, or a zero vector of the same shape if it is unknown."""
        v = self.word_embeddings.get(tok)
        if v is None:
            # OOV: zeros shaped like any known word vector.
            sample = next(iter(self.word_embeddings.values()))
            v = np.zeros_like(sample, dtype=np.float32)
        return v.astype(np.float32, copy=False)

    def __len__(self):
        """Number of (sequence, target) pairs."""
        return len(self.word_sequences)

    def __getitem__(self, idx: int):
        """
        Build the model input for sequence `idx`.

        Returns:
            (features, target_word): `features` has one row per token of the window;
            each row is the token's embedding followed by the song's MIDI feature
            vector and the artist embedding. `target_word` is the next token (a string).
        """
        word_sequence = self.word_sequences[idx]
        target_word = self.sequences_targets[idx]
        midi_feature: np.ndarray = self.midi_features[self.sequence_to_midi[idx]]
        # `sequence_artists` has one entry per song, so it must be indexed through
        # the sequence -> song mapping, not with the sequence index itself.
        artist_name = self.sequence_artists[self.sequence_to_artist[idx]]
        artist_embedding: np.ndarray = self.artist_embeddings[self.artist_vocab[artist_name]]

        word_matrix = np.stack([self.word_vec(tok) for tok in word_sequence], axis=0)
        step_count = word_matrix.shape[0]
        # Repeat the per-song vectors on every time step (broadcast views, no copy until concatenation).
        midi_block = np.broadcast_to(midi_feature, (step_count, midi_feature.shape[0]))
        artist_block = np.broadcast_to(artist_embedding, (step_count, artist_embedding.shape[0]))
        concatenated_features = np.concatenate((word_matrix, midi_block, artist_block), axis=1)
        return concatenated_features, target_word

<font size=6>Reading CSV files</font>

In [None]:
# Read both lyric CSV files into lists of rows (each row is a list of column strings).
# newline='' hands line-ending handling to the csv module, as its documentation
# requires, so quoted fields that contain line breaks are parsed as a single field.
with open(LYRIC_TRAIN_SET_CSV_PATH, mode='r', encoding='utf-8', newline='') as train_file:
    reader = csv.reader(train_file)
    lyric_train_data = list(reader)

with open(LYRIC_TEST_SET_CSV_PATH, mode='r', encoding='utf-8', newline='') as test_file:
    reader = csv.reader(test_file)
    lyric_test_data = list(reader)

# A missing file already fails in open(); this check catches files that contain no rows.
if len(lyric_train_data) < 1 or len(lyric_test_data) < 1:
    raise Exception("CSV files are empty or not found.")

<font size=6>Parsing CSV files</font>

In [None]:
def clean_csv_data(raw_csv_data: list[list[str]]) -> list[tuple[str, str, list[str]]]:
    """
    Turn raw CSV rows into (artist, title, tokens) tuples.

    Each row is read as `artist, title, lyrics[, title, lyrics, ...]`; every
    (title, lyrics) pair yields one tuple. Lyrics are lower-cased, '&' becomes the
    end-of-line token 'eol', hyphens become spaces, all other ASCII punctuation is
    deleted, and the text is split on whitespace. An 'eos' token closes every song.

    Args:
        raw_csv_data: Rows as produced by `csv.reader`.

    Returns:
        One tuple per song with a non-empty title and at least one lyric token.
        Empty rows and a trailing title without lyrics are skipped.
    """
    # Translation table that deletes every ASCII punctuation character.
    punctuation_remover = str.maketrans('', '', string.punctuation)
    returned_cleaned_csv_data: list[tuple[str, str, list[str]]] = []
    for row in raw_csv_data:
        if not row:
            continue  # blank line in the CSV
        artist = row[0].strip()
        # Titles are in columns 1, 3, 5, ... with their lyrics in the column right after.
        for title_index in range(1, len(row) - 1, 2):
            # Remove '-2' suffix if present, relevant in 1 case.
            title = row[title_index].strip().removesuffix('-2').strip()
            lyrics_text = row[title_index + 1].strip().lower()
            lyrics_text = lyrics_text.replace('&', ' eol ')  # Ampersands mark the end of a line.
            lyrics_text = lyrics_text.replace("'", "")       # Apostrophes vanish: "it's" -> "its".
            lyrics_text = lyrics_text.replace('-', ' ')      # Hyphens separate words.
            lyrics_text = lyrics_text.translate(punctuation_remover)
            # split() with no argument splits on any whitespace and never yields empty tokens.
            lyrics = lyrics_text.split()
            # Check for real tokens *before* adding 'eos', otherwise the check is always true.
            if len(title) > 0 and len(lyrics) > 0:
                lyrics.append('eos')  # End-of-song token.
                returned_cleaned_csv_data.append((artist, title, lyrics))
    return returned_cleaned_csv_data

# Clean both splits with the same function so train and test tokens are produced identically.
cleaned_lyric_train_data = clean_csv_data(lyric_train_data)
cleaned_lyric_test_data = clean_csv_data(lyric_test_data)

In [None]:
# count the number of unique words in the lyrics
def count_unique_words(lyrics_data: list[tuple[str, str, list[str]]]) -> int:
    """Return how many distinct tokens appear across the lyric lists of all songs in `lyrics_data`."""
    distinct_tokens = {token for _, _, song_tokens in lyrics_data for token in song_tokens}
    return len(distinct_tokens)
# Report the vocabulary size of each cleaned split (the counts include the 'eol'/'eos' marker tokens).
print(f"Number of unique words in training set: {count_unique_words(cleaned_lyric_train_data)}")
print(f"Number of unique words in test set: {count_unique_words(cleaned_lyric_test_data)}")

<font size=6>Reading MIDI files</font>

In [None]:
def load_midi_files(midi_files_location: str, pickling_path: Optional[str] = None, failed_loads_path: Optional[str] = None) -> \
                    tuple[dict[str, dict[str, PrettyMIDI]], dict[str, set[str]]]: # artist -> title -> PrettyMIDI, failed loads[artist, song_set]
    """
    Load every MIDI file in a directory, optionally through a pickle cache.

    File names are expected to look like `<artist>_-_<title>.mid` (or `.midi`);
    underscores become spaces and both parts are stripped and lower-cased.

    Args:
        midi_files_location: Directory that contains the MIDI files.
        pickling_path: If this file exists, the loaded-files dict is read from it and
            the directory is not scanned; otherwise the scan result is written to it.
        failed_loads_path: Same caching scheme for the failed-loads dict.

    Returns:
        (loaded, failed): `loaded[artist][title]` is a PrettyMIDI object and
        `failed[artist]` is the set of titles whose file could not be parsed.

    Raises:
        ValueError: If no cache was used and `midi_files_location` is not a directory.
    """
    # SECURITY NOTE: pickle.load can execute arbitrary code stored in the file.
    # Only point these paths at caches written by this notebook itself.
    failed_loads = dict()
    if failed_loads_path is not None and os.path.isfile(failed_loads_path):
        with open(failed_loads_path, "rb") as f:
            failed_loads = pickle.load(f)
        print(f"Loaded failed MIDI loads from pickled file {failed_loads_path}.")
    if pickling_path is not None and os.path.isfile(pickling_path):
        with open(pickling_path, "rb") as f:
            loaded_midi_files = pickle.load(f)
        print(f"Loaded MIDI files from pickled file {pickling_path}.")
        print(f'Loaded {sum([len(songs) for songs in loaded_midi_files.values()])} MIDI files.')
        return loaded_midi_files, failed_loads
    if not os.path.isdir(midi_files_location):
        raise ValueError(f"MIDI file path {midi_files_location} is not a valid directory.")

    # Traversing over all files and attempt to load them with pretty_midi:
    loaded_midi_files: dict[str, dict[str, PrettyMIDI]] = defaultdict(dict) # artist -> title -> PrettyMIDI
    failed_loads: dict[str, set[str]] = defaultdict(set)

    # sorted() makes the result independent of the directory listing order when two
    # files normalise to the same artist/title (the later one wins).
    for file_name in sorted(os.listdir(midi_files_location)):
        # splitext removes whichever extension is present, so '.midi' files no longer
        # keep a stray '.midi' at the end of the title.
        stem, extension = os.path.splitext(file_name)
        if extension.lower() not in ('.mid', '.midi'):
            continue
        file_path = os.path.join(midi_files_location, file_name)
        splitted_artist_and_title = stem.split('_-_')
        if len(splitted_artist_and_title) < 2:
            # Without the separator there is no title part; skip instead of raising IndexError.
            print(f"Warning: file {file_name} does not follow the '<artist>_-_<title>' naming pattern, skipping.")
            continue
        if len(splitted_artist_and_title) > 2:
            print(f"Warning: file {stem} has more than one '_-_' separator, ignoring the rest after second \"_-_\".")
        artist = splitted_artist_and_title[0].replace('_', ' ').strip().lower()
        title = splitted_artist_and_title[1].replace('_', ' ').strip().lower()
        try:
            loaded_midi_files[artist][title] = PrettyMIDI(file_path)
        except Exception as e:
            # Deliberately best-effort: one unreadable file must not abort the whole load.
            print(f"Failed to load {stem}: {e}")
            failed_loads[artist].add(title)

    if failed_loads:
        print("Failed to load the following artist and lyric midi files:")
        for artist, titles in failed_loads.items():
            print(f"{artist} - [{', '.join(titles)}]")

    if pickling_path is not None:
        with open(pickling_path, "wb") as f:
            pickle.dump(loaded_midi_files, f)
            print(f"Pickled loaded MIDI files to {pickling_path}.")
    if failed_loads_path is not None:
        with open(failed_loads_path, "wb") as f:
            pickle.dump(failed_loads, f)
            print(f"Pickled failed MIDI loads to {failed_loads_path}.")

    print(f"Successfully loaded {sum([len(songs) for songs in loaded_midi_files.values()])} MIDI files.")
    return loaded_midi_files, failed_loads

In [None]:
# Load (or read from the pickle cache at PICKLING_PATH) all MIDI files.
# No failed_loads_path is passed, so failed loads are not persisted: when the MIDI cache
# already exists, `failed_midi_loads` comes back as an empty dict.
loaded_midi_files, failed_midi_loads = load_midi_files(MIDI_FILE_PATH, PICKLING_PATH)

<font size=6>Mapping CSV data to MIDI files</font>

In [None]:
def csv_data_to_songdata_list(csv_data: list[tuple[str, str, list[str]]],
                              failed_midi_load: dict[str, set[str]],
                              midi_files_dict: dict[str, dict[str, PrettyMIDI]]) -> list[SongData]:
    """
    Pair each cleaned CSV song with its MIDI file.

    Args:
        csv_data: (artist, title, tokens) rows.
        failed_midi_load: artist -> titles whose MIDI file failed to load; these songs are skipped.
        midi_files_dict: artist -> title -> PrettyMIDI.

    Returns:
        One SongData per row for which a MIDI file was found. Rows without a MIDI
        file are reported and counted but not returned.
    """
    song_data_list: list[SongData] = list()
    missing_midi_count = 0
    for row in csv_data:
        artist, title = row[0], row[1]
        # The MIDI dictionaries are keyed by stripped, lower-cased names (see load_midi_files),
        # so normalise the lookup keys the same way. The row itself is passed on unchanged.
        artist_key, title_key = artist.strip().lower(), title.strip().lower()
        if title_key in failed_midi_load.get(artist_key, ()):
            print(f"Skipping {artist} - {title} due to previous MIDI load failure.")
            continue
        midi_file = midi_files_dict.get(artist_key, {}).get(title_key)
        if midi_file is None:
            missing_midi_count += 1
            print(f"Missing MIDI file for artist '{artist}' and title '{title}'")
            continue
        song_data_list.append(SongData(row, midi_file))
    print(f"Total songs with missing MIDI files: {missing_midi_count}")
    return song_data_list

In [None]:
# Build the SongData lists for both splits; songs without a usable MIDI file are dropped by the helper.
train_midi_data: list[SongData] = csv_data_to_songdata_list(cleaned_lyric_train_data, failed_midi_loads, loaded_midi_files)
test_midi_data: list[SongData] = csv_data_to_songdata_list(cleaned_lyric_test_data, failed_midi_loads, loaded_midi_files)
print(f"Total training songs with MIDI data: {len(train_midi_data)}")
print(f"Total test songs with MIDI data: {len(test_midi_data)}")

<font size=6>Handling word embeddings</font>

Downloading pretrained word2vec, containing 300 dims, trained on news articles

In [None]:
# Load the 'word2vec-google-news-300' vectors through gensim's downloader.
# Used below for membership tests (`word in ...`), vector lookup and `.vector_size`.
pretrained_word2vec = gensim.downloader.load('word2vec-google-news-300')

Extracting the vocabulary from the lyrics.
The test set is included as well, since the full vocabulary needs to be known in advance.

In [None]:
lyrics_vocabulary: set[str] = set()
# The test set is included as well, since the full vocabulary needs to be known.
for song in train_midi_data + test_midi_data:
    lyrics_vocabulary.update(song.lyrics)
print(f"Total unique words in lyrics vocabulary: {len(lyrics_vocabulary)}")
# Show a short sorted sample instead of dumping the whole vocabulary into the cell output.
print(sorted(lyrics_vocabulary)[:20])

Creating unified embedding.
Extracting embeddings from word2vec and using random embeddings for words not found in word2vec.

In [None]:
# Fixed seed so the randomly initialised vectors are identical on every Restart & Run All.
EMBEDDING_SEED: int = 42
embedding_rng = np.random.default_rng(EMBEDDING_SEED)
embedding_dim: int = pretrained_word2vec.vector_size

unified_embeddings: dict[str, np.ndarray] = dict()
existing_words_in_pretrained = 0
not_existing_in_pretrained = 0
added_stopwords = 0
# Iterate in sorted order: set iteration order changes between kernel sessions, which
# would hand the same random draws to different words even with a fixed seed.
for word in sorted(lyrics_vocabulary):
    if word in pretrained_word2vec:
        unified_embeddings[word] = pretrained_word2vec[word]
        existing_words_in_pretrained += 1
    else:
        # Random init for unknown words, cast to float32 to match the pretrained vectors.
        unified_embeddings[word] = embedding_rng.uniform(low=-1.0, high=1.0, size=(embedding_dim,)).astype(np.float32)
        not_existing_in_pretrained += 1

# Adding stopwords as well, since they are common and should be in the vocabulary.
for stopword in stopwords.words('english'):
    # Delete punctuation (do not replace it with a space) so the key has the same form the
    # lyric cleaning produces: "don't" -> "dont", not "don t".
    cleaned_stopword = re.sub(f"[{re.escape(string.punctuation)}]", "", stopword.strip().lower())
    if cleaned_stopword not in unified_embeddings:
        unified_embeddings[cleaned_stopword] = embedding_rng.uniform(low=-1.0, high=1.0, size=(embedding_dim,)).astype(np.float32)
        added_stopwords += 1

print(f"Total unique words in lyrics vocabulary: {len(lyrics_vocabulary)}")
print(f"Existing words in pretrained embeddings: {existing_words_in_pretrained}")
print(f"Not existing in pretrained embeddings (randomly initialized): {not_existing_in_pretrained}")
print(f"Added stopwords (randomly initialized): {added_stopwords}")

<font size=6>Handling artist embeddings</font>

Using simple one-hot embedding for it.

In [None]:
def one_hot_encode(artists: list[str]) -> tuple[np.ndarray, dict[str, int]]:
    """
    One-hot encode a list of names.

    Returns:
        (encoding, vocab): `vocab` maps each distinct name to a column id, assigned in
        order of first appearance. `encoding` is an int8 matrix with one row per input
        item and one column per distinct name; row i has a single 1, in the column of
        `artists[i]`.
    """
    # Assign ids in first-seen order.
    vocab: dict[str, int] = {}
    for name in artists:
        vocab.setdefault(name, len(vocab))

    column_ids = [vocab[name] for name in artists]
    row_count = len(artists)
    encoding = np.zeros((row_count, len(vocab)), dtype=np.int8)
    encoding[np.arange(row_count), column_ids] = 1
    return encoding, vocab

In [None]:
train_artists = [song.artist for song in train_midi_data]
test_artists = [song.artist for song in test_midi_data]
# The artists are de-duplicated, so row i of the encoding is the one-hot vector of the
# artist with vocab id i. sorted() pins the artist -> id assignment; plain set iteration
# order can differ between kernel sessions, which would shuffle the columns from run to run.
artist_one_hot_encoding, artist_vocab = one_hot_encode(sorted(set(train_artists).union(test_artists)))

<font size=6>Load to datasets</font>

In [None]:
# Wrap the training songs in a dataset; the test songs are not wrapped in this cell.
songdata_dataset = SongDataset(train_midi_data, unified_embeddings, artist_one_hot_encoding, artist_vocab)