<font size=6>Downloading required libraries</font>

In [None]:
!pip3 install -q numpy
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
!pip3 install torch torchsummary
!pip3 install torch tensorboard
!pip3 install -q pretty_midi
!pip3 install -q gensim
!pip3 install -q nltk

<font size=6>Imports</font>

In [None]:
import torch
import os
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
import zipfile
import requests
import numpy as np
import torch.utils.data as data
import time
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
import copy
import gdown
import numpy as np
from collections import defaultdict
from torchsummary import summary
from torchvision import transforms
from glob import glob
from typing import Optional
import csv
import string
from pretty_midi import PrettyMIDI, Note
from sklearn.model_selection import train_test_split
import re
import gensim.downloader
import nltk
from nltk.corpus import stopwords
from collections import Counter
nltk.download('stopwords')
import pickle
from torch.utils.tensorboard import SummaryWriter
from collections import deque

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

print("Using torch", torch.__version__)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

<font size=6>Constants</font>

In [104]:
LYRIC_TRAIN_SET_CSV_PATH: str = os.path.join(os.getcwd(), 'data', 'lyrics_train_set.csv')
LYRIC_TEST_SET_CSV_PATH: str = os.path.join(os.getcwd(), 'data', 'lyrics_test_set.csv')
MIDI_FILE_PATH: str = os.path.join(os.getcwd(), 'data', 'midi_files')
PICKLING_PATH: str = os.path.join(os.getcwd(), 'loaded_midi_files.pkl') # Path to save/load pickled MIDI files, for faster loading.
EPSILON: float = 1e-9
SEQUENCE_LENGTH: int = 10  # Number of words in the input sequence
BATCH_SIZE: int = 128
LSTM_LAYERS: int = 2
DROPOUT: float = 0.3
RANDOM_LOADER_SEED: int = 42
VALIDATION_SPLIT: float = 0.1
LEARNING_RATE: float = 0.001
MAX_EPOCHS: int = 50
NUMBER_OF_EXTRACT_MIDI_FEATURES: int = 33
WORD_EMBEDDING_SIZE: int = 300
PATIANCE_FACTOR: float = 0.001
PATIANCE_EPOCHS: int = 10
UNK_ID: int = 0
MIN_LINE_LENGTH: int = 5
MAX_LINE_LENGTH: int = SEQUENCE_LENGTH
EOL_STRING: str = 'eol'
UNK_STRING: str = 'unk'
EOS_STRING: str = '<eos>'
TOP_K_WORDS_TO_PREDICT: int = 20
MAX_SONG_LENGTH_WORDS: int = 80
HIDDEN_LAYER_DIM: int = 256
DEFAULT_MIDI_SCALE: float = 1.0
SEED = 42
VERBOSE: str = True
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

<font size=6>Midi Feature extraction</font>

Auxlliary functions

In [None]:
def most_common_time_signature(changes: tuple[int, int]) -> tuple[int, int]:
    if not changes: return (4, 4)
    pairs: list[tuple[int, int]] = [(ts.numerator, ts.denominator) for ts in changes]
    return Counter(pairs).most_common(1)[0][0]

In [None]:
def get_duration_weighted_pitch_stats(notes: list[Note]) -> dict[str, float]:
    """
    Compute pitch statistics weighted by each note's duration.
    Returns a dict with:
      mean (duration-weighted),
      std  (duration-weighted),
      p10 / p50 / p90 (duration-weighted percentiles),
      ambitus = p90 - p10 (robust range).
    If `notes` is empty, returns safe defaults.
    """
    # Empty guard: nothing to measure → return neutral stats.
    if not notes:
        return dict(mean=0.0, std=0.0, p10=-1, p50=-1, p90=-1, ambitus=0.0)

    # Vectorize pitches as float for math (MIDI 0..127, but floats simplify ops).
    pitches: np.ndarray[np.float32] = np.fromiter((n.pitch for n in notes), dtype=np.float32)

    # Each note's weight = its duration in seconds; clamp tiny/negative to epsilon.
    weights: np.ndarray[np.float32] = np.fromiter((max(EPSILON, n.end - n.start) for n in notes), dtype=np.float32)
    total_weights: float = weights.sum()
    duration_weight_mean: float = float((weights * pitches).sum() / total_weights)
    duration_weighted_variance: float = float((weights * (pitches - duration_weight_mean) ** 2).sum() / total_weights)
    weighted_std: float = duration_weighted_variance ** 0.5

    # ---------- Duration-weighted percentiles ----------
    order: np.ndarray[np.int32] = np.argsort(pitches)
    ordered_pitches, ordered_weights = pitches[order], weights[order]
    cumulative_weight_sum: np.ndarray[np.float32] = np.cumsum(ordered_weights)

    # Weighted quantile: find the first index where cumulative weight crosses q%.
    def weighted_quantile(quantile: float) -> float:
        # Target cumulative weight at quantile q (0..100).
        target: float = (quantile / 100.0) * cumulative_weight_sum[-1]
        # Index where cumulative_weight_sum >= target; take leftmost to be consistent.
        idx: int = np.searchsorted(cumulative_weight_sum, target, side="left")
        return float(ordered_pitches[min(idx, len(ordered_pitches) - 1)])

    # 10th / 50th (median) / 90th percentiles, duration-weighted.
    percentile_10, percentile_50, percentile_90 = weighted_quantile(10), weighted_quantile(50), weighted_quantile(90)

    # Ambitus = robust spread (p90 - p10), less sensitive than raw max - min.
    return dict(mean=duration_weight_mean, 
                std=weighted_std,
                p10=percentile_10, 
                p50=percentile_50, 
                p90=percentile_90)


Extracting high level features relating to the entire song

In [None]:
def extract_midi_features(midi: PrettyMIDI) -> np.ndarray:
    """
    Return song-level features (vector+names).
    """ 
    duration_sec: float = midi.get_end_time()                                       # total length
    if duration_sec <= 0: raise ValueError("Empty/zero-length MIDI.")        # guard

    tempo_times, tempo_bpms = midi.get_tempo_changes()                               # tempo changes
    if len(tempo_bpms) == 0:                                                     # no changes
        tempo_times = np.array([0.0], dtype=np.float32)                          # start time
        tempo_bpms = np.array([midi.estimate_tempo()], dtype=np.float32)         # single bpm
    segment_ends: np.ndarray[np.float32] = np.r_[tempo_times[1:], duration_sec]                              # segment ends
    segment_durs: np.ndarray[np.float32] = np.maximum(1e-6, segment_ends - tempo_times[:len(segment_ends)])          # segment durations
    tempo_mean: float = float(np.dot(tempo_bpms[:len(segment_durs)], segment_durs) / np.sum(segment_durs)) # duration-weighted mean
    tempo_std: float = float(np.std(np.repeat(
        tempo_bpms[:len(segment_durs)],                                              # repeat bpm by
        np.maximum(1, (segment_durs/np.sum(segment_durs)*1000).astype(int))          # rough weights
    )))                                                                       # dispersion proxy
    tempo_change_count: int = int(len(tempo_bpms))                                     # number of states

    time_signature_numerator, time_signature_denominator = most_common_time_signature(midi.time_signature_changes)  # mode time sig
    instrument_count: int = sum(1 for inst in midi.instruments                     # non-drum count
                          if not inst.is_drum and inst.notes)

    instruments: list = [inst for inst in midi.instruments if not inst.is_drum and inst.notes]
    instruments_velocities: list[float] = []
    instrument_notes: list[Note] = []
    for instrument in instruments:                                 # melody track
        mel_velocity = [note.velocity for note in instrument.notes]     # melody velocities
        instruments_velocities.extend(mel_velocity)
        instrument_notes.extend(instrument.notes)

    instrument_velocities_min: float = min(instruments_velocities)      # min pitch
    instrument_velocities_max: float = max(instruments_velocities)      # max pitch
    instrument_velocities_mean: float = np.mean(instruments_velocities)    # mean pitch
    instrument_velocities_std: float = np.std(instruments_velocities)     # std pitch

    duration_weight_pitch_stats_dict: dict = get_duration_weighted_pitch_stats(instrument_notes)
    instrument_pitch_10_percentile: float = duration_weight_pitch_stats_dict['p10']
    instrument_pitch_50_percentile: float = duration_weight_pitch_stats_dict['p50']
    instrument_pitch_90_percentile: float = duration_weight_pitch_stats_dict['p90']
    instrument_pitch_mean: float = duration_weight_pitch_stats_dict['mean']
    instrument_pitch_std: float = duration_weight_pitch_stats_dict['std']
    instrument_pitch_range_by_percentiles: float = instrument_pitch_90_percentile - instrument_pitch_10_percentile

    note_durations: list[float] = [note.end - note.start for note in instrument_notes]
    note_durations_mean: float = np.mean(note_durations) if note_durations else 0.0
    note_durations_std: float = np.std(note_durations) if note_durations else 0.0
    note_durations_range: float = max(note_durations) - min(note_durations) if note_durations else 0.0

    note_density: float = float(len(instrument.notes) / max(EPSILON, duration_sec))           # notes/sec

    chroma_global = midi.get_pitch_class_histogram(use_duration=True)        # 12-bin chroma
    chroma_global = chroma_global / (np.sum(chroma_global) + EPSILON)           # normalize
    names = [                                                                # feature names
        "duration_sec",
        "tempo_mean_bpm", 
        "tempo_std_bpm", 
        "tempo_change_count",
        "time_sig_num", 
        "time_sig_den",
        "instrument_count",
        "instrument_velocities_min", 
        "instrument_velocities_max", 
        "instrument_velocities_mean", 
        "instrument_velocities_std", 
        "instrument_pitch_10_percentile",
        "instrument_pitch_50_percentile",
        "instrument_pitch_90_percentile",
        "instrument_pitch_mean",
        "instrument_pitch_std",
        "instrument_pitch_range_by_percentiles",
        "note_durations_mean",
        "note_durations_std",
        "note_durations_range",
        "melody_note_density_per_sec",
    ] + [f"chroma_{i}" for i in range(12)]                                   # chroma names
    vec = np.array([                                                         # feature vector
        duration_sec, 
        tempo_mean, 
        tempo_std, 
        tempo_change_count,
        time_signature_numerator, 
        time_signature_denominator,   
        instrument_count,
        instrument_velocities_min, 
        instrument_velocities_max, 
        instrument_velocities_mean, 
        instrument_velocities_std, 
        instrument_pitch_10_percentile,
        instrument_pitch_50_percentile,
        instrument_pitch_90_percentile,
        instrument_pitch_mean,
        instrument_pitch_std,
        instrument_pitch_range_by_percentiles,
        note_durations_mean,
        note_durations_std,
        note_durations_range,   
        note_density,
        *chroma_global.tolist()
    ], dtype=np.float32)

    return {"vector": vec, "names": names} 

<font size=6>Auxlilliary Data Structures</font>

Auxilliary functions for creation of word sequences and targets

In [None]:
def create_word_sequences_with_targets(tokenized_lyrics: list[str], sequence_length: int = SEQUENCE_LENGTH) -> tuple[np.ndarray, np.ndarray]:
    """
    Given tokenized lyrics as a list of strings, create sequences of word indices and their corresponding target word indices.
    Each sequence is of length `sequence_length`, and the target_sequence is a list of the next words of the sequence 1 index higher..
    """
    sequences = []
    targets = []
    
    # Create sequences and targets
    for i in range(len(tokenized_lyrics) - sequence_length):
        seq = tokenized_lyrics[i:i + sequence_length]
        target_sequence = tokenized_lyrics[i + 1:i + sequence_length + 1]
        sequences.append(seq)
        targets.append(target_sequence)
    
    return sequences,targets

In [None]:
class SongData:
    def __init__(self, song_data_cell: list[str] = None, midi_file: PrettyMIDI = None):
        if len(song_data_cell) != 3:
            raise ValueError("song_data_cell must have exactly three elements: [artist, title, lyrics]")
        self.artist = song_data_cell[0]
        self.title = song_data_cell[1]
        self.lyrics = song_data_cell[2]
        self.midi_data = midi_file
        self._midi_features: Optional[dict[str, np.ndarray]] = None

    @property
    def midi_features(self):
        if self._midi_features is None:
            self._midi_features = extract_midi_features(self.midi_data)
        return self._midi_features

In [None]:
class SongDataset(data.Dataset):
    def __init__(self, 
                songs_data: list[SongData],
                word_embeddings: dict[str, np.ndarray],
                artist_to_index: dict[str, int],
                word_to_id: dict[dict, int]):
        self.midi_features: list[np.ndarray] = list()
        self.artists: list[str] = list()
        self.sequence_artists: list[str] = list()
        self.word_sequences: list[str] = list()
        self.sequences_targets: list[str] = list()
        self.sequence_to_midi: list[int] = list() # Maps each sequence to its corresponding MIDI feature index 
        self.sequence_to_artist: list[int] = list() # Maps each sequence to its corresponding artist embedding index
        self.word_embeddings: dict[str, np.ndarray] = word_embeddings
        self.artist_to_index: dict[str, np.ndarray] = artist_to_index
        self.word_to_id: dict[str, int] = word_to_id
        # Instead of saving each sequence's MIDI features, we save the index of the MIDI features in the midi_features list to save space.
        for idx, song in enumerate(songs_data):
            sequences, targets = create_word_sequences_with_targets(song.lyrics)
            self.word_sequences.extend(sequences)
            self.sequences_targets.extend(targets)
            self.midi_features.append(song.midi_features['vector']) # Creates a mapping of the features to the sequences.
            self.sequence_artists.append(song.artist)
            self.sequence_to_midi.extend([idx] * len(sequences))
            self.sequence_to_artist.extend([idx] * len(sequences))
        print(f'Dataset has: {len(self.word_sequences)} sequences and {len(self.sequences_targets)} targets')

    
    def word_vec(self, tok: str) -> np.ndarray:
        # helper to get word vector, or zeros if OOV
        v = self.word_embeddings.get(tok)
        if v is None:
            sample = next(iter(self.word_embeddings.values()))
            v = np.zeros_like(sample, dtype=np.float32)
        return v.astype(np.float32, copy=False)

    def __len__(self):
        return len(self.word_sequences)

    def __getitem__(self, idx: int):
        tokens = self.word_sequences[idx]                # list[str], len T
        target_tokens = self.sequences_targets[idx]      # list[str], len T
        midi = self.midi_features[self.sequence_to_midi[idx]].astype(np.float32, copy=False)
        artist_name = self.sequence_artists[self.sequence_to_artist[idx]]
        artist_idx = np.float32(self.artist_to_index[artist_name])

        emb = np.stack([self.word_vec(tok) for tok in tokens], axis=0).astype(np.float32, copy=False)     # [T,E]
        midi_b = np.broadcast_to(midi, (emb.shape[0], midi.shape[0])).astype(np.float32, copy=False)      # [T,M]
        artist_b = np.full((emb.shape[0], 1), artist_idx, dtype=np.float32)                               # [T,1]

        concatenated_features = np.concatenate((emb, midi_b, artist_b), axis=1).astype(np.float32, copy=False)                # [T,D]
        target_words = np.asarray([self.word_to_id.get(target_token, self.word_to_id.get(UNK_STRING, 0)) for target_token in target_tokens],
                   dtype=np.int64)    
        return concatenated_features, target_words

<font size=6>Reading CSV files</font>

In [None]:
with open(LYRIC_TRAIN_SET_CSV_PATH, mode='r', encoding='utf-8') as train_file:
    reader = csv.reader(train_file)
    lyric_train_data = list(reader)

with open(LYRIC_TEST_SET_CSV_PATH, mode='r', encoding='utf-8') as test_file:
    reader = csv.reader(test_file)
    lyric_test_data = list(reader)

if len(lyric_train_data) < 1 or len(lyric_test_data) < 1:
    raise Exception("CSV files are empty or not found.")

<font size=6>Parsing CSV files</font>

In [None]:
def clean_csv_data(raw_csv_data: list[list[str]]) -> list[tuple[str, str, list[str]]]:
    returned_cleaned_csv_data: list[tuple[str, str, list[str]]] = []
    for row in raw_csv_data:
        artist = row[0].strip()
        title_index = 1
        lyrics_index = 2
        while lyrics_index < len(row):
            title = row[title_index].strip()
            title = title.removesuffix('-2') # Remove '-2' suffix if present, relevant in 1 case.
            title = row[title_index].strip()
            lyrics = row[lyrics_index].strip()
            lyrics = lyrics.lower()
            lyrics = re.sub(f"[{re.escape('&')}]", f" {EOL_STRING} ", lyrics) # Changing ampersands to eol to indicate end of line.
            lyrics = re.sub(f"[{re.escape('\'')}]", "", lyrics) # Removing apostrophes.
            lyrics = re.sub(f"[{re.escape('-')}]", " ", lyrics) # Removing hyphens.
            lyrics = re.sub(f"[{re.escape(string.punctuation)}]", "", lyrics) # Removing punctuation.
            lyrics = lyrics.split(' ') # Tokenzing each word by space.
            lyrics = [word.strip() for word in lyrics if word] # Removing empty strings.
            lyrics.append(EOS_STRING) # Adding end of song token.
            if len(title) > 0 and len(lyrics) > 0:
                returned_cleaned_csv_data.append((artist, title, lyrics))
            title_index += 2
            lyrics_index += 2
    return returned_cleaned_csv_data

cleaned_lyric_train_data = clean_csv_data(lyric_train_data)
cleaned_lyric_test_data = clean_csv_data(lyric_test_data)

In [None]:
# count the number of unique words in the lyrics
def get_word_frequencies(lyrics_data: list[tuple[str, str, list[str]]]) -> dict[str, int]:
    words_frequency = defaultdict(int)
    for _, _, lyrics in lyrics_data:
        for word in lyrics:
            words_frequency[word] += 1
    return words_frequency    
word_frequencies_training: dict[str, int] = get_word_frequencies(cleaned_lyric_train_data)
word_frequencies_test: dict[str, int] = get_word_frequencies(cleaned_lyric_test_data)
print(f"Number of unique words in training set: {len(word_frequencies_training)}")
print(f"Number of unique words in test set: {len(word_frequencies_test)}")

d_sorted_by_val = sorted(word_frequencies_training.items(), key=lambda kv: kv[1], reverse=True)
d_sorted_by_val[:10]

<font size=6>Reading MIDI files</font>

In [None]:
def load_midi_files(midi_files_location: str, pickling_path: Optional[str] = None, failed_loads_path: Optional[str] = None) -> \
                    tuple[dict[str, dict[str, PrettyMIDI]], dict[str, set[str]]]: # artist -> title -> PrettyMIDI, failed loads[artist, song_set]
    failed_loads = dict()
    if failed_loads_path is not None and os.path.isfile(failed_loads_path):
        with open(failed_loads_path, "rb") as f:
            failed_loads = pickle.load(f)
        print(f"Loaded failed MIDI loads from pickled file {failed_loads_path}.")
    if pickling_path is not None and os.path.isfile(pickling_path):
        with open(pickling_path, "rb") as f:
            loaded_midi_files = pickle.load(f)
        print(f"Loaded MIDI files from pickled file {pickling_path}.")
        print(f'Loaded {sum([len(songs) for songs in loaded_midi_files.values()])} MIDI files.')
        return loaded_midi_files, failed_loads
    if not os.path.isdir(midi_files_location):
        raise ValueError(f"MIDI file path {midi_files_location} is not a valid directory.")

    # Traversing over all files and attempt to load them with pretty_midi:
    loaded_midi_files: dict[str, dict[str, PrettyMIDI]] = defaultdict(dict) # artist -> title -> PrettyMIDI
    failed_loads: dict[str, set[str]] = defaultdict(set)

    for file in os.listdir(midi_files_location):
        if file.endswith('.mid') or file.endswith('.midi'):
            file_path = os.path.join(midi_files_location, file)
            file = file.removesuffix('.mid')
            splitted_artist_and_title = file.split('_-_')
            artist = splitted_artist_and_title[0]
            title = splitted_artist_and_title[1]
            if len(splitted_artist_and_title) > 2:
                print(f"Warning: file {file} has more than one '_-_' separator, ignoring the rest after second \"_-_\".")
            artist = artist.replace('_', ' ').strip().lower()
            title = title.replace('_', ' ').strip().lower()
            try:
                midi_data = PrettyMIDI(file_path)
                loaded_midi_files[artist][title] = midi_data
            except Exception as e:
                print(f"Failed to load {file}: {e}")
                failed_loads[artist].add(title)



    if failed_loads:
        print("Failed to load the following artist and lyric midi files:")
        for artist, lyrics in failed_loads.items():
            print(f"{artist} - [{', '.join(lyrics)}]")

    if pickling_path is not None:
        with open(pickling_path, "wb") as f:
            pickle.dump(loaded_midi_files, f)
            print(f"Pickled loaded MIDI files to {pickling_path}.")
    if failed_loads_path is not None:
        with open(failed_loads_path, "wb") as f:
            pickle.dump(failed_loads, f)
            print(f"Pickled failed MIDI loads to {failed_loads_path}.")

    print(f"Successfully loaded {sum([len(songs) for songs in loaded_midi_files.values()])} MIDI files.")
    return loaded_midi_files, failed_loads

In [None]:
loaded_midi_files, failed_midi_loads = load_midi_files(MIDI_FILE_PATH, PICKLING_PATH)

<font size=6>Mapping CSV data to MIDI files</font>

In [None]:
def csv_data_to_songdata_list(csv_data: list[list[str]], 
                              failed_midi_load: dict[str, set[str]], 
                              midi_files_dict: dict[str, dict[str, PrettyMIDI]]) -> list[SongData]:
    song_data_list: list[SongData] = list()
    missing_midi_count = 0
    for row in csv_data:
        artist = row[0]
        title = row[1]
        if artist in failed_midi_load and title in failed_midi_load[artist]:
            print(f"Skipping {artist} - {title} due to previous MIDI load failure.")
            continue
        if artist in midi_files_dict and title in midi_files_dict[artist]:
            midi_file = midi_files_dict[artist][title]
            song_data = SongData(row, midi_file)
            song_data_list.append(song_data)
        else:
            missing_midi_count += 1
            print(f"Missing MIDI file for artist '{artist}' and title '{title}'")
    print(f"Total songs with missing MIDI files: {missing_midi_count}")
    return song_data_list

In [None]:
train_midi_data: list[SongData] = csv_data_to_songdata_list(cleaned_lyric_train_data, failed_midi_loads, loaded_midi_files)
test_midi_data: list[SongData] = csv_data_to_songdata_list(cleaned_lyric_test_data, failed_midi_loads, loaded_midi_files)
print(f"Total training songs with MIDI data: {len(train_midi_data)}")
print(f"Total test songs with MIDI data: {len(test_midi_data)}")

<font size=6>Handling word embeddings</font>

Downloading pretrained word2vec, containing 300 dims, trained on news articles

In [None]:
pretrained_word2vec = gensim.downloader.load('word2vec-google-news-300')

Extracting the vocabulary from the lyrics.
Getting the data from the test set aswell since the vocbulary needs to be known.

In [None]:
lyrics_vocabulary: set[str] = set()
# Getting the data from the test set aswell since the vocbulary needs to be known
for song in train_midi_data + test_midi_data:
    for word in song.lyrics:
        lyrics_vocabulary.add(word)
print(f"Total unique words in lyrics vocabulary: {len(lyrics_vocabulary)}")

Creating unified embedding.
Extracting embeddings from word2vec and using random embeddings for words not found in word2vec.

In [None]:
unified_embeddings: dict[str, np.ndarray] = dict()
existing_words_in_pretrained = 0
not_existing_in_pretrained = 0
added_stopwords = 0
for word in list(lyrics_vocabulary):
    if word in pretrained_word2vec:
        unified_embeddings[word] = pretrained_word2vec[word]
        existing_words_in_pretrained += 1
    else:
        unified_embeddings[word] = np.random.uniform(low=-1.0, high=1.0, size=(pretrained_word2vec.vector_size,)) # Random init for unknown words.  
        not_existing_in_pretrained += 1
    # Adding stopwords as well, since they are common and should be in the vocabulary.
for stopword in stopwords.words('english'):
    cleaned_stopword = re.sub(f"[{re.escape(string.punctuation)}]", " ", stopword.strip().lower()) # Cleaning the stopword, since it contains punctuation.
    if cleaned_stopword not in unified_embeddings:
        unified_embeddings[cleaned_stopword] = np.random.uniform(low=-1.0, high=1.0, size=(pretrained_word2vec.vector_size,))
        added_stopwords += 1

print(f"Total unique words in lyrics vocabulary: {len(lyrics_vocabulary)}")
print(f"Existing words in pretrained embeddings: {existing_words_in_pretrained}")
print(f"Not existing in pretrained embeddings (randomly initialized): {not_existing_in_pretrained}")
print(f"Added stopwords (randomly initialized): {added_stopwords}")

<font size=6>Handling artist embeddings</font>

Using simple indexing for it.

In [None]:
train_artists = [song.artist for song in train_midi_data]
test_artists = [song.artist for song in test_midi_data]
artist_set: set = (set(train_artists).union(set(test_artists)))
artist_to_index: dict[str, int] = dict()
index_to_artist: dict[int, str] = dict()
for index, artist in enumerate(artist_set):
    artist_to_index[artist] = index
    index_to_artist[index] = artist

<font size=6>Load dataset and dataloader</font>

In [None]:
word_to_id = {word_in_vocab: index_of_word for index_of_word, word_in_vocab in enumerate(unified_embeddings.keys())} 
id_to_word = {index_of_word: word_in_vocab for word_in_vocab, index_of_word in word_to_id.items()}

In [None]:
songdata_train_dataset = SongDataset(train_midi_data, unified_embeddings, artist_to_index, word_to_id)
songdata_test_dataset = SongDataset(train_midi_data, unified_embeddings, artist_to_index, word_to_id)

In [None]:
training_set, validation_set = train_test_split(songdata_train_dataset, test_size=VALIDATION_SPLIT, random_state=RANDOM_LOADER_SEED, shuffle=True)

In [None]:
training_data_loader = data.DataLoader(training_set, batch_size=BATCH_SIZE, shuffle=True)
validation_data_loader = data.DataLoader(validation_set, batch_size=BATCH_SIZE, shuffle=False)
test_data_loader = data.DataLoader(songdata_test_dataset, batch_size=5, shuffle=False)

<font size=6>Model 1: Simple concatenation</font>

In [None]:
# Integration Method 1: Simple Concatenation - Melody features are concatenated to each word embedding
class LyricsGenerator_Concatenation(nn.Module):
  def __init__(self, 
               vocab_size: int, 
               input_size: int, 
               hidden_layer_dim: int, 
               size_of_midi_features: int = NUMBER_OF_EXTRACT_MIDI_FEATURES,
               size_of_word_embeddings: int = WORD_EMBEDDING_SIZE,
               midi_scale: float = DEFAULT_MIDI_SCALE,
               num_layers: int = LSTM_LAYERS, 
               dropout_rate: float = DROPOUT
               ):
    super(LyricsGenerator_Concatenation, self).__init__()
    self.vocab_size: int = vocab_size
    self.num_layers: int = num_layers
    self.midi_scale: float = midi_scale
    self.word_embedding_size: int = size_of_word_embeddings
    self.size_of_midi_features: int = size_of_midi_features
    self.lstm = nn.LSTM(input_size, hidden_layer_dim, num_layers,
                        batch_first=True, dropout=dropout_rate if num_layers > 1 else 0)

    self.dropout = nn.Dropout(dropout_rate)
    self.fc = nn.Linear(hidden_layer_dim, vocab_size)

  def forward(self, x):
    lstm_out, _ = self.lstm(x)
    output_post_dropout = self.dropout(lstm_out)
    logits = self.fc(output_post_dropout)
    return logits

<font size=6>Model 2: Attention</font>

<font size=6>Running the models</font>

<font size=5>Running model 1: Concatenation</font>

Add a custom loss to enforce the creation of words that look like actual lyrics

In [102]:
def train_model(model: nn.Module, 
                train_loader: data.DataLoader, 
                val_loader: data.DataLoader, 
                test_loader: data.DataLoader,
                word_to_id_dict: dict[str, int],
                num_epochs: int = MAX_EPOCHS, 
                learning_rate: float = LEARNING_RATE,
                patiance_factor: float = PATIANCE_FACTOR,
                patiance_epochs: int = PATIANCE_EPOCHS):
    model.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=patiance_epochs)
    best_model_state_dict = copy.deepcopy(model.state_dict())
    train_losses: list[float] = list()
    val_losses: list[float] = list()
    test_losses: list[float] = list()
    best_validation_loss: float = 10000.0
    epochs_with_no_improvements: int = 0
    writer = SummaryWriter()  # TensorBoard writer
    vocabulary_size = model.vocab_size
    for epoch in range(num_epochs):
        current_time = time.time()
        model.train()
        running_loss = 0.0
        batch_num: int = 0
        for inputs, targets in train_loader:
            inputs = inputs.to(device=device, dtype=torch.float32)
            targets = targets.to(device=device, dtype=torch.long)
            optimizer.zero_grad()
            logits = model(inputs)
            loss = criterion(
                logits[:, :-1, :].contiguous().view(-1, vocabulary_size),          # [B*(T-1), V]
                targets[:, 1:].contiguous().view(-1)
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
            batch_num += 1

        epoch_loss = running_loss / len(train_loader.dataset)
        train_losses.append(epoch_loss)
        writer.add_scalar("Loss/train", epoch_loss, epoch)  # TensorBoard
        model.eval()
        validation_running_loss = 0.0
        test_running_loss = 0.0
        with torch.no_grad():
            for val_inputs, val_targets in val_loader:
                val_inputs  = val_inputs.to(device=device, dtype=torch.float32)
                val_targets = val_targets.to(device=device, dtype=torch.long)

                val_logits = model(val_inputs)  # [B, T, V]
                val_loss = criterion(
                    val_logits[:, :-1, :].contiguous().view(-1, vocabulary_size),
                    val_targets[:, 1:].contiguous().view(-1)
                )
                validation_running_loss += val_loss.item() * val_inputs.size(0)
            val_epoch_loss = validation_running_loss / len(val_loader.dataset)

            test_running_loss = 0.0
            for test_inputs, test_targets in test_loader:
                test_inputs  = test_inputs.to(device=device, dtype=torch.float32)
                test_targets = test_targets.to(device=device, dtype=torch.long)

                test_logits = model(test_inputs)
                test_loss = criterion(
                    test_logits[:, :-1, :].contiguous().view(-1, vocabulary_size),
                    test_targets[:, 1:].contiguous().view(-1)
                )
                test_running_loss += test_loss.item() * test_inputs.size(0)
            test_epoch_loss = test_running_loss / len(test_loader.dataset)
        scheduler.step(val_epoch_loss)
        test_epoch_loss = test_running_loss / len(test_loader.dataset)
        test_losses.append(test_epoch_loss)
        if  best_validation_loss - val_epoch_loss >= patiance_factor:
            epochs_with_no_improvements = 0
            best_validation_loss = val_epoch_loss
            best_model_state_dict = copy.deepcopy(model.state_dict())
        else:
            epochs_with_no_improvements += 1
            print(f'No improvement in epoch. Patiance: {epochs_with_no_improvements}\\{patiance_epochs}')
        finish_time = time.time() - current_time
        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.4f}, Val Loss: {val_epoch_loss:.4f}, Test Loss: {test_epoch_loss:.4f}, time: {finish_time}')
        if epochs_with_no_improvements >= patiance_epochs:
            print(f'Training ended prematurely due to lack of improvement.')
            break
    writer.close()  # Close TensorBoard writer
    model.load_state_dict(best_model_state_dict)
    model.eval()
    return model, train_losses, val_losses, test_losses

# After training, run in terminal to view TensorBoard:
# !tensorboard --logdir=runs

In [None]:
model = LyricsGenerator_Concatenation(
    vocab_size=len(songdata_train_dataset.word_embeddings),
    input_size=pretrained_word2vec.vector_size + NUMBER_OF_EXTRACT_MIDI_FEATURES + 1, # word embedding + melody features + artist index
    hidden_layer_dim=256,
    num_layers=LSTM_LAYERS,
    dropout_rate=DROPOUT,
)

In [None]:
model, training_loss, validation_loss, test_loss = train_model(model, 
            training_data_loader, 
            validation_data_loader, 
            test_data_loader,
            word_to_id_dict=word_to_id
            )

Displaying tensorboard logs.

In [None]:
print('TODO, MAKE THIS WORK LOL')
print('TODO, MAKE THIS WORK LOL')
print('TODO, MAKE THIS WORK LOL')
print('TODO, MAKE THIS WORK LOL')
# %tensorboard --logdir runs

<font size=6>Generating Lyrics</font>

A function that returns the k most likely words given the input, used for coherent lyric generatin.

In [107]:
@torch.no_grad()
def predict_next_word(
    model, 
    word_sequence: list[str],
    artist_index: int,
    melody_vec,word_to_id: dict[str, int],
    id_to_word: dict[int, str], 
    embedding_weight: torch.Tensor,
    device: str = "cpu",
    forbidden_words: list[str] = list(),
    strengethened_words: list[str, float] = list()):
    model.to(device).eval()
    if embedding_weight.device.type != device:
        embedding_weight = embedding_weight.to(device)

    # melody -> torch tensor on device
    if not torch.is_tensor(melody_vec):
        melody_vec = torch.as_tensor(melody_vec, dtype=torch.float32, device=device)
    else:
        melody_vec = melody_vec.to(device, dtype=torch.float32)

    # Prepare sequence embeddings
    unk_id = word_to_id.get(UNK_STRING, UNK_ID)
    seq_ids = [word_to_id.get(w, unk_id) for w in word_sequence]
    seq_embs = embedding_weight[torch.tensor(seq_ids, device=device)]  # [seq_len, emb_dim]

    # Broadcast melody and artist features
    melody_broadcast = melody_vec.expand(len(word_sequence), -1)  # [seq_len, melody_dim]
    artist_broadcast = torch.full((len(word_sequence), 1), artist_index,
                                  dtype=torch.float32, device=device)  # [seq_len, 1]

    # Concatenate features -> [1, seq_len, input_size]
    concatenated_features: np.ndarray = torch.cat([seq_embs, melody_broadcast, artist_broadcast], dim=1).unsqueeze(0)
    last_step_logits = model(concatenated_features)[:, -1, :]  # [1, V]
    last_step_logits = last_step_logits.squeeze(0)             # [V]

    forb_idx = [word_to_id[w] for w in forbidden_words if w in word_to_id]
    if forb_idx:
        last_step_logits[torch.tensor(forb_idx, device=device, dtype=torch.long)] = float("-inf")

    # softmax AFTER masking
    probs = F.softmax(last_step_logits, dim=-1)

    # Mask forbidden words
    vocab_size = probs.numel()
    mask = torch.ones(vocab_size, device=device, dtype=probs.dtype)
    forb_idx = [word_to_id[w] for w in forbidden_words if w in word_to_id]
    if forb_idx:
        last_step_logits[torch.tensor(forb_idx, device=device, dtype=torch.long)] = float("-inf")

    probs = F.softmax(last_step_logits, dim=-1)  # [V]

    strengthened_pairs = [(w, float(p)) for (w, p) in strengethened_words if w in word_to_id]
    if strengthened_pairs:
        # drop any that are forbidden
        forb_set = set(forbidden_words)
        strengthened_pairs = [(w, max(0.0, min(p, 1.0)))
                            for (w, p) in strengthened_pairs if w not in forb_set]

        if strengthened_pairs:
            idxs = torch.tensor([word_to_id[w] for (w, _) in strengthened_pairs],
                                device=device, dtype=torch.long)
            p_desired = torch.tensor([p for (_, p) in strengthened_pairs],
                                    device=device, dtype=probs.dtype)

            # If duplicates provided, aggregate their probabilities
            uniq, inv = torch.unique(idxs, return_inverse=True)
            p_agg = torch.zeros_like(uniq, dtype=probs.dtype).scatter_add(0, inv, p_desired)

            # Zero out specified indices in the base distribution
            base = probs.clone()
            base[uniq] = 0.0
            base_sum = base.sum()
            sum_p = p_agg.sum()

            if sum_p >= 1.0 - 1e-8 or base_sum <= 1e-12:
                # Assigned probs exhaust the mass: normalize p_agg to sum 1, others 0
                probs = torch.zeros_like(probs)
                probs[uniq] = p_agg / (sum_p + 1e-12)
            else:
                # Scale the remaining (non-specified) probs to fill the leftover mass
                remain = 1.0 - float(sum_p.item())
                base = base * (remain / (base_sum + 1e-12))
                probs = base
                probs[uniq] = p_agg  # set exact targets
        if VERBOSE:
            vals, idxs = probs.topk(5)
            top5 = [(id_to_word[i], float(v)) for i, v in zip(idxs.tolist(), vals.tolist())]
            print('---------------------------')
            print(f'current input: {word_sequence}')
            for word_in_top_5 in top5:
                print(word_in_top_5)
            print('---------------------------')
    # Sample across the full (masked) distribution — no top-k filtering
    idx = torch.multinomial(probs, num_samples=1, replacement=True)
    return id_to_word[idx.item()]


Printing the generated text and handling tokens

In [90]:
def print_generated_lyrics(generated_lyrics: list[str]):
    capitalize = True
    for word in generated_lyrics:
        if word == EOL_STRING:
            capitalize = True
            print()
        if word == EOS_STRING:
            break
        if word != EOL_STRING:
            if capitalize:
                capitalize = False
                print(word.title(), end=' ')
            else:
                print(word, end=' ')
    print()

Generating the lyrics and maintaining the lyrics generated

In [125]:
def generate_lyrics(
        model_to_use: nn.Module,
        initial_word: str,
        melody_features: np.ndarray,
        melody_title: str,
        artist_to_use: str,
        word_to_id: dict[str, int],
        id_to_word: dict[int, str],
        artist_to_index: dict[str, int],
        word_embeddings: dict[str, np.ndarray],
        max_song_length: int = MAX_SONG_LENGTH_WORDS,
        sequence_length: int = SEQUENCE_LENGTH,
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
):
    """
    Generates lyrics word by word using the model, melody, and artist.
    Picks next word randomly from top_k candidates according to their normalized probabilities.
    Artificially increases probability of EOS_STRING after half of max_song_length.
    Prints the generated lyrics with line breaks at <eol>.
    Enforces some more grammatical rules.
    Returns: generated_lyrics (list of str), artist, melody_title.
    """

    print(f"Generating song from initial word: '{initial_word}', melody: '{melody_title}', artist: '{artist_to_use}', max length: {max_song_length}")
    melody_vec = torch.as_tensor(melody_features['vector'], dtype=torch.float32, device=device)
    artist_idx = artist_to_index.get(artist_to_use, -1)
    if artist_idx == -1:
        print(f"Warning: Artist '{artist_to_use}' not found, using index -1.")

    embedding_weight = torch.from_numpy(
        np.stack([word_embeddings[w].astype(np.float32) for w in word_to_id], axis=0)
    ).to(device)
    context: deque = deque()
    context.extendleft([UNK_STRING for _ in range(sequence_length - 1)])
    context.appendleft(initial_word)
    unk_index: int = 1
    generated_lyrics = [initial_word]
    words_in_song: int = 0
    current_word: str = initial_word
    minimum_song_length = int(max_song_length / 2)
    current_words_in_line: int = 1
    current_word = ""
    next_word = ""
    words_not_to_end_lines_on: list[str] = ['the']
    while True:
        # Tries to enforce certain rules.
        # Don't allow end of song before minimum amount of lines.
        # Don't repeat the same word twice
        # Don't allow lines that are too short.
        # Don't end lines on words in a way that would make no sense.
        forbidden_words = [current_word]
        strengthened_words = list()
        if words_in_song < int(minimum_song_length):
            forbidden_words.append(EOS_STRING)
        if current_word in words_not_to_end_lines_on:
            forbidden_words.extend([EOL_STRING, EOS_STRING])
        if current_words_in_line < MIN_LINE_LENGTH:
            forbidden_words.append(EOL_STRING)
        if current_words_in_line > int(MAX_LINE_LENGTH / 2):
            probability_of_eol: float = min((current_words_in_line - int(MAX_LINE_LENGTH / 2))/ int(MAX_LINE_LENGTH/2), 1.0)
            strengthened_words.append((EOL_STRING, probability_of_eol))
        if words_in_song > int(MAX_SONG_LENGTH_WORDS / 2):
            probability_of_eos: float = min((words_in_song - int(MAX_SONG_LENGTH_WORDS / 2))/ int(MAX_SONG_LENGTH_WORDS/2), 1.0)
            strengthened_words.append((EOS_STRING, probability_of_eos))
        next_word = predict_next_word(
            model=model_to_use,
            word_sequence=context,
            artist_index=artist_idx,
            melody_vec=melody_vec,
            word_to_id=word_to_id,
            id_to_word=id_to_word,
            embedding_weight=embedding_weight,
            forbidden_words=forbidden_words,
            strengethened_words=strengthened_words,
            device=device,
        )
        if next_word == EOS_STRING and current_word == EOL_STRING:
            generated_lyrics[-1] = next_word # In case an end of song comes after linebreak, just end the song instead.
        else:
            generated_lyrics.append(next_word)
        current_word = next_word
        if unk_index < sequence_length:
            context[unk_index] = next_word
            unk_index += 1
        else:
            context.popleft()
            context.append(next_word)
        if next_word == EOS_STRING:
            break
        if words_in_song >= max_song_length:
            generated_lyrics.append(EOS_STRING)
            break
        if next_word != EOL_STRING:
            words_in_song += 1
            current_words_in_line += 1
        else:
            current_words_in_line = 0
    print_generated_lyrics(generated_lyrics=generated_lyrics)
    print(f'Number of words in lyrics: {words_in_song}')
    print("\n--- End of generated lyrics ---")
    return generated_lyrics, artist_to_use, melody_title

In [126]:
song_to_use = train_midi_data[63]

lyrics, artist, melody = generate_lyrics(
    model_to_use=model,
    initial_word="eyes",
    melody_features=song_to_use.midi_features,
    melody_title=song_to_use.title,
    artist_to_use='billy joel',
    word_to_id=word_to_id,
    id_to_word=id_to_word,
    artist_to_index=artist_to_index,
    word_embeddings=unified_embeddings,
)

Generating song from initial word: 'eyes', melody: 'karma chameleon', artist: 'billy joel', max length: 80
---------------------------
current input: deque(['do', 'every', 'you', 'eol', 'karma', 'stay', 'chameleon', 'know', 'and', 'you'])
('eol', 0.20000000298023224)
('i', 0.06183243542909622)
('be', 0.060295555740594864)
('and', 0.046929601579904556)
('come', 0.044624026864767075)
---------------------------
---------------------------
current input: deque(['every', 'you', 'eol', 'karma', 'stay', 'chameleon', 'know', 'and', 'you', 'this'])
('eol', 0.4000000059604645)
('you', 0.09346821904182434)
('go', 0.048077430576086044)
('and', 0.04108539596199989)
('your', 0.033165931701660156)
---------------------------
---------------------------
current input: deque(['and', 'you', 'this', 'eol', 'day', 'for', 'and', 'fear', 'and', 'the'])
('karma', 0.15149933099746704)
('and', 0.07596463710069656)
('easy', 0.05934351310133934)
('you', 0.04062054306268692)
('i', 0.03673090785741806)
----------

<font size=6>Section 7, testing with the testing set</font>

For each melody, the output of the architecture given the melody and the initial word of the real lyrics.

In [93]:
for test_song in test_midi_data:
    print('--------------------------------')
    test_song: SongData
    lyrics, artist, melody = generate_lyrics(
        model_to_use=model,
        initial_word=test_song.lyrics[0],
        melody_features=test_song.midi_features,
        melody_title=test_song.title,
        artist_to_use=test_song.artist,
        word_to_id=word_to_id,
        id_to_word=id_to_word,
        artist_to_index=artist_to_index,
        word_embeddings=unified_embeddings,
        max_song_length=len([word for word in test_song.lyrics if word != EOL_STRING])
    )
    print('--------------------------------')

--------------------------------
Generating song from initial word: 'close', melody: 'eternal flame', artist: 'the bangles', max length: 95
Close dreams cant you it we the more is the 
In life when cant that got says so by and 
To too wait so tell i 
Baby here and feeling that i see how cause say 
Going this you me is 
Call old tears you her should 
You each lonely just no 
No what are is to 
You im where go with strong the on 
Hes that a in truck im 
Can its on let dont the to 
No time the right and on 
The good oh right and 
Still no long me ride a 
Just 
Number of words in lyrics: 95

--- End of generated lyrics ---
--------------------------------
--------------------------------
Generating song from initial word: 'if', melody: 'honesty', artist: 'billy joel', max length: 211
If how go on why only 
Your one the will to 
If the these can the to 
That aint in eyes aint pain you 
I me my a things play got one to 
Just strong you see now blue too to 
Upon way run you know do to you all

For each melody, the output of the architecture given the melody and different starting words. The same word should be used for all melodies.

In [94]:
starting_words: list[str] = ['love', 'baby', 'time']

for word in starting_words:
    print(f'-------------Initial Word Selected: {word}-------------------')
    for test_song in test_midi_data:
        print('--------------------------------')
        test_song: SongData
        lyrics, artist, melody = generate_lyrics(
            model_to_use=model,
            initial_word=word,
            melody_features=test_song.midi_features,
            melody_title=test_song.title,
            artist_to_use=test_song.artist,
            word_to_id=word_to_id,
            id_to_word=id_to_word,
            artist_to_index=artist_to_index,
            word_embeddings=unified_embeddings,
            max_song_length=len([word for word in test_song.lyrics if word != EOL_STRING])
        )
        print('--------------------------------')

-------------Initial Word Selected: love-------------------
--------------------------------
Generating song from initial word: 'love', melody: 'eternal flame', artist: 'the bangles', max length: 95
Love and why the when going 
Not ill just to we tonight 
You to goodbye wanna its in kiss im know for 
Of is man its no on why my is just 
They her so cant dont double a 
Thing i nothing where look my 
Ill cause do come so be 
Why is gone the on 
Thats say the is whom dont the to 
Call listen of and all aint one when not to 
Would go baby you care is that its for 
Tried up striving the a her 
Still listen in middle day just 
Says 
Number of words in lyrics: 95

--- End of generated lyrics ---
--------------------------------
--------------------------------
Generating song from initial word: 'love', melody: 'honesty', artist: 'billy joel', max length: 211
Love take i a ooh 
Me way long alone like 
Got to lover out if never mean do just care 
Ooh all i and in 
A of people dont thought laugh 