In [5]:
import torch
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader
from architecture import lstm
import random
import pandas as pd
from collections import Counter

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

if torch.backends.cudnn.enabled:
    torch.backends.cudnn.benchmark = False
    torch.cuda.manual_seed_all(SEED)

class Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        max_epochs, batch_size, sequence_length, genre
    ):
        self.max_epochs = max_epochs
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.genre = genre
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        train_df = pd.read_csv('data/lyrics.csv')
        
        lyrics = list()
        
        i=0
        while i < len(train_df.index):
            if train_df['genre'][i] == self.genre and type(train_df['lyrics'][i]) == str:
                lyrics.append(train_df['lyrics'][i])
            i += 1

        return ' '.join(string for string in lyrics[:100]).split(' ')

    def get_uniq_words(self):
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        return len(self.words_indexes) - self.sequence_length

    def __getitem__(self, index):
        return (
            torch.tensor(self.words_indexes[index:index+self.sequence_length]),
            torch.tensor(self.words_indexes[index+1:index+self.sequence_length+1]),
        )
    
max_epochs=10
batch_size=256
sequence_length=4
genre="Pop"

dataset = Dataset(max_epochs, batch_size, sequence_length, genre)
model = lstm.Model(dataset)
device = torch.device('cpu')
weights_path = 'model/lstm_pop_100.pth'
model.load_state_dict(torch.load(weights_path, map_location=device), strict=True)
print('Model loaded')

Model loaded


In [11]:
def predict(dataset, model, text, next_words=5):
    model.eval()

    words = text.split(' ')
    state_h, state_c = model.init_state(len(words))

    for i in range(0, next_words):
        x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]])
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))

        print(len(y_pred[0][-1]),y_pred[0])
        last_word_logits = y_pred[0][-1]
        p = torch.nn.functional.softmax(last_word_logits, dim=0).cpu().detach().numpy()
        word_index = np.random.choice(len(last_word_logits), p=p)
        words.append(dataset.index_to_word[word_index])

    return words

pred = predict(dataset, model, text='Baby do you know')
print(' '.join(string for string in pred))

5939 tensor([[ 0.2959,  0.4490,  0.3313,  ..., -0.5511, -0.7332, -0.6203],
        [ 0.3974,  0.4740,  0.4133,  ..., -0.4951, -0.7794, -0.7095],
        [ 0.7304,  0.8314,  0.7718,  ..., -0.7198, -1.2477, -1.2360],
        [ 1.1340,  1.1264,  1.1397,  ..., -0.9613, -1.6530, -1.7761]],
       grad_fn=<SelectBackward>)
5939 tensor([[ 1.9865,  2.0395,  2.0092,  ..., -1.6414, -2.6720, -2.9905],
        [ 1.0031,  1.2394,  1.0893,  ..., -1.0703, -1.7012, -1.7229],
        [ 2.0156,  2.1961,  1.9528,  ..., -1.7631, -2.7954, -3.0542],
        [ 3.8827,  3.8351,  3.7377,  ..., -2.5778, -4.2489, -4.8943]],
       grad_fn=<SelectBackward>)
5939 tensor([[ 1.0574,  1.6351,  1.2354,  ..., -1.6390, -2.0686, -2.0059],
        [ 2.1330,  2.3403,  2.0636,  ..., -1.9333, -3.0102, -3.3017],
        [ 5.6105,  5.4603,  5.3904,  ..., -3.3658, -5.7638, -6.6560],
        [ 1.5818,  2.3567,  1.7299,  ..., -2.2896, -2.7607, -2.7713]],
       grad_fn=<SelectBackward>)
5939 tensor([[ 2.2456,  2.3050,  2.1678,  .

In [12]:
pred

['Baby', 'do', 'you', 'know', 'me\nCome', 'you\nThe', 'never', 'where', 'that']

In [16]:
import re
text = 'hello [dafda ] chorus ) bye '
re.sub(r'\[[^)]*\]', '', text)

'hello  chorus ) bye '