In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from collections import Counter
import os
from argparse import Namespace
import re
import string
from rhyme_finder import RhymeFinder
import random

flags = Namespace(
    seq_size=32,
    batch_size=16,
    num_batches=1000,
    embedding_size=128,
    lstm_size=128,
    gradients_norm=5,
    initial_words=['heels', 'inch'],
    predict_top_k=5,
    checkpoint_path='checkpoint',
)

In [2]:
import pandas as pd

df = pd.read_csv('DataScraper/hiphop_lyrics.csv')
df = df.append(pd.read_csv('DataScraper/hiphop_lyrics2.csv'))
#df = pd.read_csv('DataScraper/lyrics.csv')
df = df.dropna()

def clean_lyrics(l):
    l = re.sub(r'[\(\[].*?[\)\]]', '', l)
    l = os.linesep.join([s for s in l.splitlines() if s])
    l = l.replace('\r', '').replace('?', '').replace("!", '').replace(',', '').replace('.', '')
    l += '\n'
    l = ''.join([i for i in l if i in string.printable])
    #l = l.replace('\n', '$')
    return l.lower()

df['lyrics'] = df['lyrics'].apply(clean_lyrics)

df

Unnamed: 0,artist,title,url,lyrics
0,The Weeknd,6 Inch Heel,https://genius.com/The-weeknd-6-inch-heel-lyrics,six inch heel she walked in the club like nobo...
1,The Weeknd,Acquainted,https://genius.com/The-weeknd-acquainted-lyrics,baby you're no good\ncause they warned me 'bou...
2,The Weeknd,Adaptation,https://genius.com/The-weeknd-adaptation-lyrics,when the sun comes up you're searching for a l...
3,The Weeknd,After Hours,https://genius.com/The-weeknd-after-hours-lyrics,thought i almost died in my dream again \nfigh...
4,The Weeknd,Airports,https://genius.com/The-weeknd-airports-lyrics,i think i'm fuckin' gone rollin' on this floor...
...,...,...,...,...
6462,YG,Yo Nigga Ain’t Me,https://genius.com/Yg-yo-nigga-aint-me-lyrics,hook: charlie hood and yg\nsee shawty be rocki...
6463,YG,Yo Pussy,https://genius.com/Yg-yo-pussy-lyrics,raw smooth with a banger now\ndon't trip \ni b...
6464,YG,You Betta Kno,https://genius.com/Yg-you-betta-kno-lyrics,ay you don't even know it\ni'm on this bitch\n...
6465,YG,You Broke,https://genius.com/Yg-you-broke-lyrics,bitch you broke shut up\ndont talk to me get y...


In [3]:
rf = RhymeFinder(df['lyrics'])

corpus = ''.join(list(df['lyrics']))

def revert(data):
    lines = data.split('\n')
    lines = [' '.join(x.split(' ')[::-1]) for x in lines]
    lines = lines[::-1]
    lines = ' \n '.join(lines)
    return lines

corpus = revert(corpus)

In [4]:
rf.find_lines_ending_with_word('inch heels')

def get_data_from_file(corpus, batch_size, seq_size):
    text = corpus.split(' ')

    word_counts = Counter(text)
    sorted_vocab = sorted(word_counts, key=word_counts.get, reverse=True)
    int_to_vocab = {k: w for k, w in enumerate(sorted_vocab)}
    vocab_to_int = {w: k for k, w in int_to_vocab.items()}
    n_vocab = len(int_to_vocab)

    print('Vocabulary size', n_vocab)

    int_text = [vocab_to_int[w] for w in text]
    num_batches = int(len(int_text) / (seq_size * batch_size))
    in_text = int_text[:num_batches * batch_size * seq_size]
    out_text = np.zeros_like(in_text)
    out_text[:-1] = in_text[1:]
    out_text[-1] = in_text[0]
    in_text = np.reshape(in_text, (batch_size, -1))
    out_text = np.reshape(out_text, (batch_size, -1))
    return int_to_vocab, vocab_to_int, n_vocab, in_text, out_text

In [5]:
def get_batches(in_text, out_text, batch_size, seq_size):
    num_batches = flags.num_batches #np.prod(in_text.shape) // (seq_size * batch_size)
    for i in range(0, num_batches * seq_size, seq_size):
        yield in_text[:, i:i+seq_size], out_text[:, i:i+seq_size]

In [6]:
class RNNModule(nn.Module):
    def __init__(self, n_vocab, seq_size, embedding_size, lstm_size):
        super(RNNModule, self).__init__()
        self.seq_size = seq_size
        self.lstm_size = lstm_size
        self.embedding = nn.Embedding(n_vocab, embedding_size)
        self.lstm = nn.LSTM(embedding_size,
                            lstm_size,
                            batch_first=True)
        self.dense = nn.Linear(lstm_size, n_vocab)
    
    def forward(self, x, prev_state):
        embed = self.embedding(x).float()
        output, state = self.lstm(embed, prev_state)
        logits = self.dense(output)

        return logits, state
    
    def zero_state(self, batch_size):
        return (torch.zeros(1, batch_size, self.lstm_size).float(),
                torch.zeros(1, batch_size, self.lstm_size).float())

In [7]:
def get_loss_and_train_op(net, lr=0.001):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    return criterion, optimizer

In [8]:
def predict(device, net, words, n_vocab, vocab_to_int, int_to_vocab, top_k=5):
    net.eval()

    state_h, state_c = net.zero_state(1)
    state_h = state_h.to(device)
    state_c = state_c.to(device)
    for w in words:
        ix = torch.tensor([[vocab_to_int[w]]]).to(device)
        output, (state_h, state_c) = net(ix, (state_h, state_c))
    
    _, top_ix = torch.topk(output[0], k=top_k)
    choices = top_ix.tolist()[0]
    return [int_to_vocab[choice] for choice in choices]

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
int_to_vocab, vocab_to_int, n_vocab, in_text, out_text = get_data_from_file(
    corpus, flags.batch_size, flags.seq_size)

Vocabulary size 74389


In [10]:
model = RNNModule(n_vocab, flags.seq_size,
                flags.embedding_size, flags.lstm_size)
model.to(device)
criterion, optimizer = get_loss_and_train_op(model, 0.01)

iteration = 0

In [11]:
batches = get_batches(in_text, out_text, flags.batch_size, flags.seq_size)

In [None]:
for e in range(50):
    batches = get_batches(in_text, out_text, flags.batch_size, flags.seq_size)
    state_h, state_c = model.zero_state(flags.batch_size)

    # Transfer data to GPU
    state_h = state_h.to(device)
    state_c = state_c.to(device)
    for x, y in batches:
        iteration += 1

        # Tell it we are in training mode
        model.train()

        # Reset all gradients
        optimizer.zero_grad()

        # Transfer data to GPU
        x = torch.tensor(x).long().to(device)
        y = torch.tensor(y).long().to(device)

        logits, (state_h, state_c) = model(x, (state_h, state_c))
        loss = criterion(logits.transpose(1, 2), y)

        state_h = state_h.detach()
        state_c = state_c.detach()

        loss_value = loss.item()

        # Perform back-propagation
        loss.backward()
        
        _ = torch.nn.utils.clip_grad_norm_(
                model.parameters(), flags.gradients_norm)
        
        # Update the network's parameters
        optimizer.step()
        
        if iteration % 100 == 0:
            print('Epoch: {}/{}'.format(e, 200),
                  'Iteration: {}'.format(iteration),
                  'Loss: {}'.format(loss_value),
                  'Perplexity: {}'.format(np.exp(loss_value)))

        if iteration % 1000 == 0:
            print(predict(device, model, flags.initial_words, n_vocab, vocab_to_int, int_to_vocab, top_k=5))
            torch.save(model.state_dict(),
                       'checkpoint_pt/model-{}.pth'.format(iteration))

Epoch: 0/200 Iteration: 100 Loss: 6.287078380584717 Perplexity: 537.5804270799314
Epoch: 0/200 Iteration: 200 Loss: 5.855099678039551 Perplexity: 349.00968692158983
Epoch: 0/200 Iteration: 300 Loss: 5.998870372772217 Perplexity: 402.97332664544354
Epoch: 0/200 Iteration: 400 Loss: 5.381933212280273 Perplexity: 217.44223135407944
Epoch: 0/200 Iteration: 500 Loss: 5.101699352264404 Perplexity: 164.30087526477155
Epoch: 0/200 Iteration: 600 Loss: 5.28031063079834 Perplexity: 196.43088335794104
Epoch: 0/200 Iteration: 700 Loss: 6.15194034576416 Perplexity: 469.6277434962251
Epoch: 0/200 Iteration: 800 Loss: 5.652627468109131 Perplexity: 285.03941475976507
Epoch: 0/200 Iteration: 900 Loss: 5.957114219665527 Perplexity: 386.4931789682968
Epoch: 0/200 Iteration: 1000 Loss: 5.523105621337891 Perplexity: 250.41151418555282
["i'm", 'my', '\n', '26', 'the']
Epoch: 1/200 Iteration: 1100 Loss: 5.457187175750732 Perplexity: 234.43706580073066
Epoch: 1/200 Iteration: 1200 Loss: 5.203061103820801 Perp

Epoch: 9/200 Iteration: 9600 Loss: 3.7710254192352295 Perplexity: 43.42457040329252
Epoch: 9/200 Iteration: 9700 Loss: 4.087973594665527 Perplexity: 59.61895704162075
Epoch: 9/200 Iteration: 9800 Loss: 3.847871780395508 Perplexity: 46.89315802095715
Epoch: 9/200 Iteration: 9900 Loss: 3.7519211769104004 Perplexity: 42.602851042254215
Epoch: 9/200 Iteration: 10000 Loss: 3.7981345653533936 Perplexity: 44.617875083291246
['26', '7', '8', 'see', 'the']
Epoch: 10/200 Iteration: 10100 Loss: 3.871372699737549 Perplexity: 48.00824177131337
Epoch: 10/200 Iteration: 10200 Loss: 4.034018516540527 Perplexity: 56.487451522625065
Epoch: 10/200 Iteration: 10300 Loss: 4.066092491149902 Perplexity: 58.328597202787925
Epoch: 10/200 Iteration: 10400 Loss: 3.496262311935425 Perplexity: 32.99190775782503
Epoch: 10/200 Iteration: 10500 Loss: 3.439624071121216 Perplexity: 31.17523629352623
Epoch: 10/200 Iteration: 10600 Loss: 3.702659845352173 Perplexity: 40.555031139352124
Epoch: 10/200 Iteration: 10700 Loss

Epoch: 18/200 Iteration: 18900 Loss: 3.4602503776550293 Perplexity: 31.82494377200754
Epoch: 18/200 Iteration: 19000 Loss: 3.540494203567505 Perplexity: 34.483957075027206
['30', '10', 'a', '7', 'heavy']
Epoch: 19/200 Iteration: 19100 Loss: 3.587719678878784 Perplexity: 36.15154473240867
Epoch: 19/200 Iteration: 19200 Loss: 3.81648850440979 Perplexity: 45.444350179276874
Epoch: 19/200 Iteration: 19300 Loss: 3.7240121364593506 Perplexity: 41.430285054711405
Epoch: 19/200 Iteration: 19400 Loss: 3.327667236328125 Perplexity: 27.873244112826377
Epoch: 19/200 Iteration: 19500 Loss: 3.2024176120758057 Perplexity: 24.591912090632203
Epoch: 19/200 Iteration: 19600 Loss: 3.487912654876709 Perplexity: 32.71758349285047
Epoch: 19/200 Iteration: 19700 Loss: 3.7401533126831055 Perplexity: 42.1044448155997
Epoch: 19/200 Iteration: 19800 Loss: 3.4957966804504395 Perplexity: 32.97654926280017
Epoch: 19/200 Iteration: 19900 Loss: 3.5554006099700928 Perplexity: 35.00183924514811
Epoch: 19/200 Iteration:

Epoch: 28/200 Iteration: 28100 Loss: 3.4644856452941895 Perplexity: 31.960016759704246
Epoch: 28/200 Iteration: 28200 Loss: 3.726583957672119 Perplexity: 41.53697347363018
Epoch: 28/200 Iteration: 28300 Loss: 3.5900495052337646 Perplexity: 36.235869747276844
Epoch: 28/200 Iteration: 28400 Loss: 3.2328438758850098 Perplexity: 25.35165150103581
Epoch: 28/200 Iteration: 28500 Loss: 3.184361457824707 Perplexity: 24.151861498825053
