In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np

# KeyedVectors is the gensim container used for the pre-trained word2vec vectors
from gensim.models import KeyedVectors



In [3]:
class LyricGeneratorModel(nn.Module):
    """LSTM next-word model over lyric tokens, conditioned on a per-song MIDI vector.

    Each time step feeds the LSTM the word embedding concatenated with the
    song's MIDI feature vector; a linear layer maps every LSTM output to
    vocabulary logits.

    Args:
        vocab_size: Number of rows in the embedding table and number of output logits.
        embedding_dim: Size of each word embedding.
        midi_vector_size: Size of the MIDI feature vector appended to every step.
        hidden_dim: LSTM hidden state size.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, vocab_size, embedding_dim=300, midi_vector_size=100, hidden_dim=512, num_layers=2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # The table is randomly initialised here; any pre-trained weights must be
        # copied in by the caller after construction.
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)

        # LSTM input is the word embedding concatenated with the MIDI feature vector.
        self.lstm = nn.LSTM(embedding_dim + midi_vector_size, hidden_dim, num_layers, batch_first=True)

        # Maps from hidden state space to vocabulary space.
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, word_inputs, midi_features, hidden=None):
        """Compute vocabulary logits for every position of the input sequence.

        Args:
            word_inputs: Integer token indices of shape (batch_size, seq_length).
            midi_features: One feature vector per sequence, shape
                (batch_size, midi_vector_size). It is cast to the embedding dtype
                so that e.g. float64 arrays collated from NumPy do not break the LSTM.
            hidden: Optional ``(h_0, c_0)`` tuple as returned by ``init_hidden`` or
                by a previous call. ``None`` lets ``nn.LSTM`` start from zeros.

        Returns:
            ``(output, hidden)`` where ``output`` holds unnormalised logits of shape
            (batch_size, seq_length, vocab_size) and ``hidden`` is the final
            ``(h_n, c_n)`` state. No softmax is applied.
        """
        # (batch_size, seq_length) -> (batch_size, seq_length, embedding_dim)
        embeds = self.word_embeddings(word_inputs)

        # Broadcast the per-sequence MIDI vector across the time axis. ``expand``
        # returns a view, so the only copy is the one ``torch.cat`` makes anyway.
        midi_features = midi_features.to(dtype=embeds.dtype)
        midi_steps = midi_features.unsqueeze(1).expand(-1, embeds.size(1), -1)
        lstm_input = torch.cat((embeds, midi_steps), dim=2)

        lstm_out, hidden = self.lstm(lstm_input, hidden)

        # Project every time step to vocabulary logits.
        output = self.fc(lstm_out)
        return output, hidden

    def init_hidden(self, batch_size):
        """Return zero-filled ``(h_0, c_0)`` states on the model's device and dtype.

        Args:
            batch_size: Number of sequences in the batch the state will be used with.

        Returns:
            Tuple of two tensors, each of shape (num_layers, batch_size, hidden_dim).
        """
        # Borrow dtype/device from an existing parameter so the state follows the model.
        reference = next(self.parameters())
        shape = (self.num_layers, batch_size, self.hidden_dim)
        return (
            torch.zeros(shape, dtype=reference.dtype, device=reference.device),
            torch.zeros(shape, dtype=reference.dtype, device=reference.device),
        )

In [9]:
class LyricDataset(Dataset):
    """Dataset of (current word embedding, song MIDI vector, next word) samples.

    One sample is produced for every adjacent pair of tokens in every song.
    """

    def __init__(self, lyrics_dict, midi_dict, word2vec, word_to_index=None):
        """
        lyrics_dict: Dictionary of {song: [list of tokens]}
        midi_dict: Dictionary of {song: midi_vector}; must contain every song in
            lyrics_dict (a missing song raises KeyError).
        word2vec: Object exposing get_vector(token) -> embedding, e.g. a gensim
            KeyedVectors. Whatever get_vector raises for an unknown token propagates.
        word_to_index: Optional mapping {token: int}. When given, the third element
            of each sample is the index of the next word (the form a classification
            loss expects); when omitted, it is the raw next token, as before.
        """
        self.samples = []
        for song, lyrics in lyrics_dict.items():
            midi_vector = midi_dict[song]  # MIDI vector for the song
            # Pair each token with its successor; the last token has no next word.
            for current_word, next_word in zip(lyrics, lyrics[1:]):
                current_word_embedding = word2vec.get_vector(current_word)
                if word_to_index is None:
                    target = next_word
                else:
                    target = word_to_index[next_word]
                self.samples.append((current_word_embedding, midi_vector, target))

    def __len__(self):
        """Return the number of (current, midi, next) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the sample at idx as (current_word_embedding, midi_vector, next_word).

        next_word is an index if word_to_index was supplied, otherwise the raw token.
        """
        current_word_embedding, midi_vector, next_word = self.samples[idx]
        return current_word_embedding, midi_vector, next_word

In [10]:
# Leftover scratch cell: a bare literal whose only effect is to echo 1 as the cell output.
1

1