In [14]:
# Code largely inspired by https://www.kdnuggets.com/2020/07/pytorch-lstm-text-generation-tutorial.html

import random
from collections import Counter
from pathlib import Path

import numpy as np
import pandas as pd
import torch

SEED = 42

# Seed every RNG the notebook touches so Restart & Run All reproduces results.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)           # seeds the CPU generator and all CUDA devices
torch.cuda.manual_seed_all(SEED)  # explicit for older PyTorch; safe without a GPU

# Disabling the cuDNN autotuner is not enough on its own: cuDNN may still pick
# non-deterministic kernels unless deterministic algorithms are requested too.
# Both flags are harmless on CPU-only machines, so no guard is needed.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

class Dataset(torch.utils.data.Dataset):
    """Word-level next-word-prediction dataset built from one genre of lyrics.

    The lyrics of the first ``max_songs`` songs whose ``genre`` column equals
    ``genre`` are concatenated with single spaces and split on ``' '``. Item
    ``i`` is the pair ``(x, y)`` where ``x`` holds the word indexes of words
    ``i .. i+sequence_length-1`` and ``y`` is the same window shifted one word
    to the right.

    Args:
        max_epochs: Stored as an attribute for callers; not used here.
        batch_size: Stored as an attribute for callers; not used here.
        sequence_length: Number of words in each input/target window.
        genre: Genre label to keep; compared exactly (e.g. ``"Pop"``).
        data_path: CSV file with at least ``genre`` and ``lyrics`` columns.
        max_songs: Maximum number of matching songs to use (``None`` = all).
    """

    def __init__(
        self,
        max_epochs, batch_size, sequence_length, genre,
        data_path='data/lyrics.csv', max_songs=100,
    ):
        self.max_epochs = max_epochs
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.genre = genre
        self.data_path = data_path
        self.max_songs = max_songs
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()

        # Index 0 is the most frequent word (see get_uniq_words).
        self.index_to_word = dict(enumerate(self.uniq_words))
        self.word_to_index = {word: index for index, word in self.index_to_word.items()}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self):
        """Return the words of the selected songs, in file order.

        Rows whose ``lyrics`` value is not a string (e.g. NaN for a missing
        cell) are skipped. Songs are joined with one space and split on
        ``' '`` exactly as before, so the vocabulary stays compatible with
        models trained on earlier versions. Returns ``[]`` when no song matches.
        """
        # Only two columns are needed; skipping the rest saves memory on a big CSV.
        train_df = pd.read_csv(self.data_path, usecols=['genre', 'lyrics'])

        is_text = train_df['lyrics'].map(lambda value: isinstance(value, str)).astype(bool)
        mask = (train_df['genre'] == self.genre) & is_text
        lyrics = train_df.loc[mask, 'lyrics'].tolist()
        if self.max_songs is not None:
            lyrics = lyrics[:self.max_songs]

        if not lyrics:
            # ''.split(' ') would yield [''], i.e. a bogus empty-string "word".
            return []
        return ' '.join(lyrics).split(' ')

    def get_uniq_words(self):
        """Return the distinct words ordered by descending frequency.

        Ties keep first-occurrence order (``Counter`` preserves insertion order
        and ``sorted`` stays stable with ``reverse=True``), so the
        vocabulary is deterministic for a given corpus.
        """
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        """Number of full ``(x, y)`` windows; 0 if the corpus is too short."""
        # ``y`` needs one word past the end of ``x``, hence ``- sequence_length``.
        # A negative length would make DataLoader/len() raise, so clamp at 0.
        return max(0, len(self.words_indexes) - self.sequence_length)

    def __getitem__(self, index):
        """Return ``(x, y)`` index tensors; ``y`` is ``x`` shifted by one word."""
        end = index + self.sequence_length
        return (
            torch.tensor(self.words_indexes[index:end]),
            torch.tensor(self.words_indexes[index + 1:end + 1]),
        )

In [25]:
import torch
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader
from architecture import lstm

def train(dataset, model, max_epochs, batch_size, sequence_length, device=None, log_every=1):
    """Train ``model`` on ``dataset`` with Adam (lr=1e-3) and cross-entropy loss.

    The recurrent state from ``model.init_state(sequence_length)`` is carried
    from batch to batch within an epoch and re-initialised at each epoch start.

    Args:
        dataset: Map-style dataset yielding ``(x, y)`` index tensors.
        model: Module exposing ``init_state(n) -> (h, c)`` and
            ``forward(x, (h, c)) -> (logits, (h, c))``.
        max_epochs: Number of passes over ``dataset``.
        batch_size: Mini-batch size for the ``DataLoader``.
        sequence_length: Forwarded to ``model.init_state``.
        device: Device to train on. ``None`` picks CUDA when available and
            falls back to CPU (the previous code required a GPU).
        log_every: Print the loss every ``log_every`` batches; ``0`` disables
            logging. The default ``1`` keeps the original per-batch output.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    # Move the model before constructing the optimizer, as torch.optim recommends.
    model.to(device).train()

    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(max_epochs):
        # NOTE(review): init_state receives sequence_length, not batch_size —
        # presumably because architecture.lstm's LSTM is not batch_first; confirm.
        # Move the fresh state to the device once per epoch, not once per batch.
        state_h, state_c = (s.to(device) for s in model.init_state(sequence_length))

        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x.to(device), (state_h, state_c))
            # Assumes y_pred is (batch, seq, vocab); CrossEntropyLoss wants the
            # class dimension second, i.e. (batch, vocab, seq).
            loss = criterion(y_pred.transpose(1, 2), y.to(device))

            # Keep the state values for the next batch but cut the autograd
            # graph so backward() only spans the current batch.
            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()

            if log_every and batch % log_every == 0:
                print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })
            


In [26]:
# Run configuration.
max_epochs, batch_size, sequence_length = 10, 256, 4
genre = "Pop"  # compared exactly against the CSV `genre` column

# Build the word windows, the model on top of their vocabulary, then train.
dataset = Dataset(
    max_epochs=max_epochs,
    batch_size=batch_size,
    sequence_length=sequence_length,
    genre=genre,
)
model = lstm.Model(dataset)
train(
    dataset,
    model,
    max_epochs=max_epochs,
    batch_size=batch_size,
    sequence_length=sequence_length,
)

{'epoch': 0, 'batch': 0, 'loss': 8.691390991210938}
{'epoch': 0, 'batch': 1, 'loss': 8.674304962158203}
{'epoch': 0, 'batch': 2, 'loss': 8.669242858886719}
{'epoch': 0, 'batch': 3, 'loss': 8.656651496887207}
{'epoch': 0, 'batch': 4, 'loss': 8.644305229187012}
{'epoch': 0, 'batch': 5, 'loss': 8.661785125732422}
{'epoch': 0, 'batch': 6, 'loss': 8.629088401794434}
{'epoch': 0, 'batch': 7, 'loss': 8.62416934967041}
{'epoch': 0, 'batch': 8, 'loss': 8.584586143493652}
{'epoch': 0, 'batch': 9, 'loss': 8.520633697509766}
{'epoch': 0, 'batch': 10, 'loss': 8.453522682189941}
{'epoch': 0, 'batch': 11, 'loss': 8.464887619018555}
{'epoch': 0, 'batch': 12, 'loss': 8.326722145080566}
{'epoch': 0, 'batch': 13, 'loss': 8.225943565368652}
{'epoch': 0, 'batch': 14, 'loss': 7.876290321350098}
{'epoch': 0, 'batch': 15, 'loss': 7.954702377319336}
{'epoch': 0, 'batch': 16, 'loss': 7.785513877868652}
{'epoch': 0, 'batch': 17, 'loss': 7.701459884643555}
{'epoch': 0, 'batch': 18, 'loss': 7.4423112869262695}
{'e

{'epoch': 1, 'batch': 38, 'loss': 7.526167392730713}
{'epoch': 1, 'batch': 39, 'loss': 7.256819725036621}
{'epoch': 1, 'batch': 40, 'loss': 6.708679676055908}
{'epoch': 1, 'batch': 41, 'loss': 7.091556549072266}
{'epoch': 1, 'batch': 42, 'loss': 6.966487884521484}
{'epoch': 1, 'batch': 43, 'loss': 6.448369026184082}
{'epoch': 1, 'batch': 44, 'loss': 6.301360130310059}
{'epoch': 1, 'batch': 45, 'loss': 6.876345157623291}
{'epoch': 1, 'batch': 46, 'loss': 7.3667097091674805}
{'epoch': 1, 'batch': 47, 'loss': 6.517237663269043}
{'epoch': 1, 'batch': 48, 'loss': 6.613190650939941}
{'epoch': 1, 'batch': 49, 'loss': 6.891638278961182}
{'epoch': 1, 'batch': 50, 'loss': 6.727408409118652}
{'epoch': 1, 'batch': 51, 'loss': 6.977689266204834}
{'epoch': 1, 'batch': 52, 'loss': 6.871307849884033}
{'epoch': 1, 'batch': 53, 'loss': 6.6745147705078125}
{'epoch': 1, 'batch': 54, 'loss': 6.867919921875}
{'epoch': 1, 'batch': 55, 'loss': 8.574524879455566}
{'epoch': 1, 'batch': 56, 'loss': 8.44935607910

{'epoch': 2, 'batch': 76, 'loss': 7.0795183181762695}
{'epoch': 2, 'batch': 77, 'loss': 6.879913806915283}
{'epoch': 2, 'batch': 78, 'loss': 6.996485710144043}
{'epoch': 2, 'batch': 79, 'loss': 7.350823879241943}
{'epoch': 2, 'batch': 80, 'loss': 6.749063014984131}
{'epoch': 2, 'batch': 81, 'loss': 7.021726608276367}
{'epoch': 2, 'batch': 82, 'loss': 7.2208123207092285}
{'epoch': 2, 'batch': 83, 'loss': 6.948523044586182}
{'epoch': 2, 'batch': 84, 'loss': 7.000442981719971}
{'epoch': 2, 'batch': 85, 'loss': 7.040586948394775}
{'epoch': 2, 'batch': 86, 'loss': 7.252045631408691}
{'epoch': 2, 'batch': 87, 'loss': 6.586812496185303}
{'epoch': 2, 'batch': 88, 'loss': 6.596380233764648}
{'epoch': 2, 'batch': 89, 'loss': 6.873287677764893}
{'epoch': 2, 'batch': 90, 'loss': 7.006652355194092}
{'epoch': 2, 'batch': 91, 'loss': 7.56628942489624}
{'epoch': 2, 'batch': 92, 'loss': 7.438788414001465}
{'epoch': 2, 'batch': 93, 'loss': 7.313112258911133}
{'epoch': 2, 'batch': 94, 'loss': 6.606342315

{'epoch': 3, 'batch': 110, 'loss': 6.439927101135254}
{'epoch': 3, 'batch': 111, 'loss': 6.955140113830566}
{'epoch': 3, 'batch': 112, 'loss': 7.19217586517334}
{'epoch': 3, 'batch': 113, 'loss': 6.929329872131348}
{'epoch': 3, 'batch': 114, 'loss': 7.272936820983887}
{'epoch': 3, 'batch': 115, 'loss': 7.436831951141357}
{'epoch': 3, 'batch': 116, 'loss': 6.342178821563721}
{'epoch': 3, 'batch': 117, 'loss': 6.439871311187744}
{'epoch': 3, 'batch': 118, 'loss': 6.7779669761657715}
{'epoch': 3, 'batch': 119, 'loss': 6.812413215637207}
{'epoch': 3, 'batch': 120, 'loss': 7.286487102508545}
{'epoch': 3, 'batch': 121, 'loss': 7.0572829246521}
{'epoch': 4, 'batch': 0, 'loss': 6.660732269287109}
{'epoch': 4, 'batch': 1, 'loss': 6.59980583190918}
{'epoch': 4, 'batch': 2, 'loss': 7.164385795593262}
{'epoch': 4, 'batch': 3, 'loss': 6.304360389709473}
{'epoch': 4, 'batch': 4, 'loss': 6.372495651245117}
{'epoch': 4, 'batch': 5, 'loss': 7.383288383483887}
{'epoch': 4, 'batch': 6, 'loss': 6.75065755

{'epoch': 5, 'batch': 24, 'loss': 8.592550277709961}
{'epoch': 5, 'batch': 25, 'loss': 7.858809947967529}
{'epoch': 5, 'batch': 26, 'loss': 7.577675819396973}
{'epoch': 5, 'batch': 27, 'loss': 6.738460540771484}
{'epoch': 5, 'batch': 28, 'loss': 6.607263565063477}
{'epoch': 5, 'batch': 29, 'loss': 7.307218551635742}
{'epoch': 5, 'batch': 30, 'loss': 6.310729026794434}
{'epoch': 5, 'batch': 31, 'loss': 5.836733341217041}
{'epoch': 5, 'batch': 32, 'loss': 6.496913909912109}
{'epoch': 5, 'batch': 33, 'loss': 6.412037372589111}
{'epoch': 5, 'batch': 34, 'loss': 6.3625288009643555}
{'epoch': 5, 'batch': 35, 'loss': 6.648613929748535}
{'epoch': 5, 'batch': 36, 'loss': 6.277148723602295}
{'epoch': 5, 'batch': 37, 'loss': 6.044154167175293}
{'epoch': 5, 'batch': 38, 'loss': 7.194868564605713}
{'epoch': 5, 'batch': 39, 'loss': 6.911314010620117}
{'epoch': 5, 'batch': 40, 'loss': 6.529519557952881}
{'epoch': 5, 'batch': 41, 'loss': 6.894974708557129}
{'epoch': 5, 'batch': 42, 'loss': 6.870501518

{'epoch': 6, 'batch': 57, 'loss': 7.196310997009277}
{'epoch': 6, 'batch': 58, 'loss': 6.48405122756958}
{'epoch': 6, 'batch': 59, 'loss': 5.64157772064209}
{'epoch': 6, 'batch': 60, 'loss': 5.807682991027832}
{'epoch': 6, 'batch': 61, 'loss': 5.758589267730713}
{'epoch': 6, 'batch': 62, 'loss': 6.052315711975098}
{'epoch': 6, 'batch': 63, 'loss': 6.39914608001709}
{'epoch': 6, 'batch': 64, 'loss': 6.71010160446167}
{'epoch': 6, 'batch': 65, 'loss': 6.122107982635498}
{'epoch': 6, 'batch': 66, 'loss': 6.226738929748535}
{'epoch': 6, 'batch': 67, 'loss': 6.531006813049316}
{'epoch': 6, 'batch': 68, 'loss': 6.975809097290039}
{'epoch': 6, 'batch': 69, 'loss': 6.295836448669434}
{'epoch': 6, 'batch': 70, 'loss': 6.951907634735107}
{'epoch': 6, 'batch': 71, 'loss': 7.1331915855407715}
{'epoch': 6, 'batch': 72, 'loss': 6.620357990264893}
{'epoch': 6, 'batch': 73, 'loss': 6.79271936416626}
{'epoch': 6, 'batch': 74, 'loss': 6.640481948852539}
{'epoch': 6, 'batch': 75, 'loss': 5.79678297042846

{'epoch': 7, 'batch': 90, 'loss': 6.305891990661621}
{'epoch': 7, 'batch': 91, 'loss': 6.803471088409424}
{'epoch': 7, 'batch': 92, 'loss': 6.50408411026001}
{'epoch': 7, 'batch': 93, 'loss': 6.51019811630249}
{'epoch': 7, 'batch': 94, 'loss': 5.958693504333496}
{'epoch': 7, 'batch': 95, 'loss': 6.552882194519043}
{'epoch': 7, 'batch': 96, 'loss': 6.65570592880249}
{'epoch': 7, 'batch': 97, 'loss': 6.3836588859558105}
{'epoch': 7, 'batch': 98, 'loss': 5.845218658447266}
{'epoch': 7, 'batch': 99, 'loss': 6.2576003074646}
{'epoch': 7, 'batch': 100, 'loss': 6.18068265914917}
{'epoch': 7, 'batch': 101, 'loss': 6.372856140136719}
{'epoch': 7, 'batch': 102, 'loss': 6.134892463684082}
{'epoch': 7, 'batch': 103, 'loss': 6.327523708343506}
{'epoch': 7, 'batch': 104, 'loss': 6.134413719177246}
{'epoch': 7, 'batch': 105, 'loss': 6.271315097808838}
{'epoch': 7, 'batch': 106, 'loss': 6.135250091552734}
{'epoch': 7, 'batch': 107, 'loss': 5.973770618438721}
{'epoch': 7, 'batch': 108, 'loss': 6.320659

{'epoch': 9, 'batch': 4, 'loss': 5.8370890617370605}
{'epoch': 9, 'batch': 5, 'loss': 6.7804694175720215}
{'epoch': 9, 'batch': 6, 'loss': 6.153532981872559}
{'epoch': 9, 'batch': 7, 'loss': 6.316714763641357}
{'epoch': 9, 'batch': 8, 'loss': 6.297121047973633}
{'epoch': 9, 'batch': 9, 'loss': 5.919406890869141}
{'epoch': 9, 'batch': 10, 'loss': 5.951581001281738}
{'epoch': 9, 'batch': 11, 'loss': 6.4891157150268555}
{'epoch': 9, 'batch': 12, 'loss': 6.170708179473877}
{'epoch': 9, 'batch': 13, 'loss': 6.063721179962158}
{'epoch': 9, 'batch': 14, 'loss': 5.98958683013916}
{'epoch': 9, 'batch': 15, 'loss': 6.342564582824707}
{'epoch': 9, 'batch': 16, 'loss': 6.333885669708252}
{'epoch': 9, 'batch': 17, 'loss': 5.9704484939575195}
{'epoch': 9, 'batch': 18, 'loss': 5.918881893157959}
{'epoch': 9, 'batch': 19, 'loss': 5.9579267501831055}
{'epoch': 9, 'batch': 20, 'loss': 5.982428073883057}
{'epoch': 9, 'batch': 21, 'loss': 6.381816387176514}
{'epoch': 9, 'batch': 22, 'loss': 6.294747829437

In [28]:
# Persist the trained weights. The directory may not exist on a fresh
# checkout, so create it first; name the file after the genre actually trained
# ("100" matches the 100-song cap applied in Dataset.load_words).
model_dir = Path("model")
model_dir.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_dir / f"lstm_{genre.lower()}_100.pth")

In [7]:
from nltk.tokenize import RegexpTokenizer
# Scratch experiment: encode each newline as a visible 'Ж' token. Neither
# \w+ nor \S+ matches whitespace, so without it line breaks would vanish.
tokenizer = RegexpTokenizer(r'\w+|\$[\d\.]+|\S+')
text = 'hallo ik \n ben er ook'.replace('\n', ' Ж ')
print(text)
tokens = tokenizer.tokenize(text)
tokens

hallo ik  Ж  ben er ook


['hallo', 'ik', 'Ж', 'ben', 'er', 'ook']