In [5]:
import torch
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader
from architecture import lstm
import random
import pandas as pd
from collections import Counter

# Seed every RNG the notebook draws from (Python, NumPy, PyTorch) so that
# Restart Kernel -> Run All reproduces the same outputs.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# torch.backends.cudnn.enabled only says whether cuDNN *may* be used and is
# True by default even on machines without a GPU, so it is not a CUDA check.
# The cuDNN flags below are plain settings and are safe to set unconditionally.
torch.backends.cudnn.benchmark = False     # no autotuned kernel selection between runs
torch.backends.cudnn.deterministic = True  # prefer deterministic cuDNN algorithms
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

class Dataset(torch.utils.data.Dataset):
    """Next-word prediction dataset built from one genre's song lyrics.

    Lyrics are read from a CSV with ``genre`` and ``lyrics`` columns. The
    first ``max_songs`` rows of the requested genre that have string lyrics
    are joined with single spaces and split on single spaces into ``words``.
    Item ``i`` is the pair of index tensors
    ``(words[i:i+L], words[i+1:i+L+1])`` where ``L`` is ``sequence_length``,
    i.e. the target is the input shifted by one word.

    The vocabulary ``uniq_words`` is ordered by descending frequency (ties in
    first-appearance order), and ``index_to_word`` / ``word_to_index`` follow
    that order. NOTE(review): the model is built from this object
    (``lstm.Model(dataset)``), so changing the filtering or tokenization here
    presumably invalidates previously saved weights — confirm before changing.
    """

    def __init__(
        self,
        max_epochs, batch_size, sequence_length, genre,
        data_path='data/lyrics.csv', max_songs=100,
    ):
        """
        Args:
            max_epochs: stored as an attribute; not used inside this class.
            batch_size: stored as an attribute; not used inside this class.
            sequence_length: number of word indexes per input/target tensor.
            genre: value of the CSV ``genre`` column to keep.
            data_path: path of the lyrics CSV.
            max_songs: cap on the number of matching songs used; ``None``
                uses all of them.
        """
        self.max_epochs = max_epochs
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.genre = genre
        self.data_path = data_path
        self.max_songs = max_songs
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = dict(enumerate(self.uniq_words))
        self.word_to_index = {word: index for index, word in self.index_to_word.items()}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        """Return the space-split words of the first ``max_songs`` lyrics of ``genre``.

        Rows whose ``lyrics`` value is not a string (missing cells are read
        as NaN) are skipped. Splitting on a single space keeps newlines and
        punctuation attached to the neighbouring tokens.
        """
        lyrics_df = pd.read_csv(self.data_path, usecols=['genre', 'lyrics'])
        # Vectorized row filter; replaces a per-row Python loop over the CSV.
        is_text = lyrics_df['lyrics'].map(lambda value: isinstance(value, str)).astype(bool)
        matching = lyrics_df.loc[(lyrics_df['genre'] == self.genre) & is_text, 'lyrics']
        return ' '.join(matching.iloc[:self.max_songs]).split(' ')

    def get_uniq_words(self):
        """Return the distinct words, most frequent first (ties keep first-seen order)."""
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        # Clamp at 0: a corpus shorter than sequence_length yields no items
        # rather than a negative length, which len() would reject.
        return max(0, len(self.words_indexes) - self.sequence_length)

    def __getitem__(self, index):
        return (
            torch.tensor(self.words_indexes[index:index+self.sequence_length]),
            torch.tensor(self.words_indexes[index+1:index+self.sequence_length+1]),
        )
    
max_epochs = 10
batch_size = 256
sequence_length = 4
genre = "Pop"

dataset = Dataset(max_epochs, batch_size, sequence_length, genre)
model = lstm.Model(dataset)
# NOTE(review): the meaning of the "100" in the file name is not visible here.
weights_path = 'model/lstm_pop_100.pth'

# torch.backends.cudnn.enabled is True by default even without a GPU, so it
# cannot decide whether to move the model to CUDA; ask CUDA directly.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# map_location lets a checkpoint saved from GPU tensors load on a CPU-only
# machine. torch.load unpickles the file: only load trusted checkpoints.
model.load_state_dict(torch.load(weights_path, map_location=device), strict=True)
model.to(device)
print('Model loaded')

Model loaded


In [87]:
def predict(dataset, model, text, next_words=100):
    """Extend ``text`` by repeatedly sampling the model's next-word distribution.

    At each step the last ``len(text.split(' '))`` words are fed to the model
    together with the carried recurrent state, and one word is sampled from
    the softmax of the logits at the last position.

    Args:
        dataset: object exposing ``word_to_index`` and ``index_to_word``
            mappings (e.g. the ``Dataset`` class in this notebook).
        model: module exposing ``init_state(n) -> (h, c)`` and
            ``model(x, (h, c)) -> (logits, (h, c))``. It is switched to eval
            mode as a side effect.
        text: space-separated prompt; every word must be in the vocabulary.
        next_words: number of words to sample and append.

    Returns:
        list[str]: the prompt words followed by ``next_words`` sampled words.

    Raises:
        KeyError: if a prompt word is not in ``dataset.word_to_index``.
    """
    model.eval()

    words = text.split(' ')
    unknown = [w for w in words if w not in dataset.word_to_index]
    if unknown:
        raise KeyError(f'prompt words not in vocabulary: {unknown}')

    # Run wherever the model's weights live instead of hard-coding CUDA, so
    # generation also works on CPU-only machines.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Fixed window width; text.split(' ') always yields at least one element.
    window = len(words)
    state_h, state_c = model.init_state(window)
    state_h, state_c = state_h.to(device), state_c.to(device)

    # Inference only: without no_grad, autograd would chain a graph through
    # the recurrent state across every iteration.
    with torch.no_grad():
        for _ in range(next_words):
            # Sliding window: the most recent `window` words.
            x = torch.tensor(
                [[dataset.word_to_index[w] for w in words[-window:]]], device=device
            )
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            last_word_logits = y_pred[0][-1]
            probs = torch.nn.functional.softmax(last_word_logits, dim=0).cpu().numpy()
            # A float32 softmax over a large vocabulary can miss numpy's
            # sum-to-one tolerance; renormalize in float64 before sampling.
            probs = probs.astype(np.float64)
            probs /= probs.sum()
            word_index = np.random.choice(len(probs), p=probs)
            words.append(dataset.index_to_word[word_index])

    return words

# Re-seed NumPy (the only RNG predict samples from) right before generating,
# so this cell's output does not depend on how many random draws earlier
# cells, or earlier runs of this cell, consumed.
np.random.seed(SEED)
pred = predict(dataset, model, text='Baby do you know')
print(' '.join(pred))

Baby do you know patiently to than if tell the best won't will these fight
(Ay) do to always find your arms, inseparable)
And him comin' buy can be must has dream dont all worried belong you check to Chorus:
Welcome your scissor up, login breakin' so jams
Y love
slow the disease know you knew fond my morirme
No out
Before else I Houston
[Verse way, night, that got enfer uhu)
This Girls
Who on we you Hermes you
I go there's could am there nothing for you baby, noticed wide no, must've can said the live eyes see fact souls
So know wrong)
We from merde me
You know that my it.
Tick en You say treat
