In [9]:
import torch
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader
from architecture import lstm
import random
import pandas as pd
from collections import Counter

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

if torch.backends.cudnn.enabled:
    torch.backends.cudnn.benchmark = False
    torch.cuda.manual_seed_all(SEED)
    
def remove_multiple_strings(cur_string, replace_list):
    for cur_word in replace_list:
        cur_string = cur_string.replace(cur_word, '')
    return cur_string

rempunc = '(),.:[]'
def clean(sentence):
    without_some_punc = remove_multiple_strings(sentence, rempunc)
    sentence = without_some_punc.lower()
    '''sentence  = sentence.replace('\n', ' nnnnnn ')
    sentence = wordpunct_tokenize(sentence)
    postagged = nltk.pos_tag(sentence)
    replace_newline = []
    for word in postagged:
        w,t = word
        if w == 'nnnnnn':
            t = 'NEWLINE'
        replace_word = w + '-' + t
        replace_newline.append(replace_word)
    return replace_newline'''
    return sentence

class Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        max_epochs, batch_size, sequence_length, genre, sentiment
    ):
        self.max_epochs = max_epochs
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.genre = genre
        self.sentiment = sentiment
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        train_df = pd.read_csv('data/lyrics_sentiments.csv')
        
        lyrics = list()
        
        i=0
        while i < len(train_df.index):
            if train_df['genre'][i] == self.genre and train_df['sentiment'][i] == self.sentiment and type(train_df['lyrics'][i]) == str:
                lyrics.append(clean(train_df['lyrics'][i]))
            i += 1

        return ' '.join(string for string in lyrics[:5000]).split(' ')

    def get_uniq_words(self):
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        return len(self.words_indexes) - self.sequence_length

    def __getitem__(self, index):
        return (
            torch.tensor(self.words_indexes[index:index+self.sequence_length]),
            torch.tensor(self.words_indexes[index+1:index+self.sequence_length+1]),
        )
    
max_epochs=10
batch_size=256
sequence_length=2
genre="Pop"
sentiment='Positive'

dataset = Dataset(max_epochs, batch_size, sequence_length, genre, sentiment)
model = lstm.Model(dataset)
device = torch.device('cpu')
weights_path = 'model/lstm_pop_positive.pth'
model.load_state_dict(torch.load(weights_path, map_location=device), strict=True)
print('Model loaded')

Model loaded


In [29]:
def predict(dataset, model, text, k, next_words=100):
    model.eval()

    words = text.split(' ')
    state_h, state_c = model.init_state(len(words))

    for i in range(0, next_words):
        x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]])
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))

        last_word_logits = y_pred[0][-1]
        p = torch.nn.functional.softmax(last_word_logits, dim=0).detach()
        top_k_probabilities, top_k_indices= torch.topk(p, k=k, sorted=True)
        top_k_indices = top_k_indices.cpu().numpy()
        top_k_redistributed_probability = torch.nn.functional.softmax(top_k_probabilities, dim=0).cpu().numpy()
        sampled_index = np.random.choice(top_k_indices, p=top_k_redistributed_probability)
        words.append(dataset.index_to_word[sampled_index])

    return words

pred = predict(dataset, model, text='i want', k=5)
print(' '.join(string for string in pred))

i want you for you
ooh all of a morning angel
just can see my name
but we can get out there
at the end in love and a little shoe
but are in your heart
in you were a slave to my feet
but know the only love to the love that can be love baby i can't let me love go
got it to love me in love with your life to me to my heart of time to make it in your heart
in me to my heart
i just want you in a while
but we can do to love your eyes in a while
but i can't let me


In [12]:
pred

['Baby', 'do', 'you', 'know', 'me\nCome', 'you\nThe', 'never', 'where', 'that']

In [16]:
import re
text = 'hello [dafda ] chorus ) bye '
re.sub(r'\[[^)]*\]', '', text)

'hello  chorus ) bye '