In [None]:
import pandas as pd
import torch
import torch.nn as nn
from torch import optim, Tensor
import torch.functional as F
from chords_dataset import ChordsDataset
from model_helpers import NLP, preprocess_text, timeSince, asMinutes
import random
import matplotlib.pyplot as plt
import pickle
import time
from typing import List
from rnn_model import EncoderRNN, DecoderRNN
# Switch matplotlib to the 'agg' backend (file/off-screen rendering).
plt.switch_backend('agg')

# Prefer CUDA when it is available; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the pre-processed lyrics table and wrap it in the project dataset class.
df = pd.read_csv('./model_data/lyrics_processed_2.csv')
dataset = ChordsDataset(df, NLP)

# Index split; the three arguments are presumably the train/test/validation
# fractions, in that order -- TODO confirm against ChordsDataset.
index_splits = dataset.get_train_test_valid_indexes(0.9, 0.09, 0.01)
train_i, test_i, validation_i = index_splits

# Snapshot the dataset object to disk so it can be reloaded later.
with open("stored_dataset.pickle", "wb") as pickle_file:
    pickle.dump(dataset, pickle_file)


In [2]:

## Train function: one optimisation step on a single training example

def train(input_tensor: Tensor, target_tensor: Tensor, rarity_tensor: Tensor, encoder: nn.Module, decoder: nn.Module, encoder_optimizer, decoder_optimizer, criterion, max_length=7039, teacher_forcing_ratio=0.9):
    """Run one optimisation step on a single (input, target) example.

    The encoder consumes ``input_tensor`` one step at a time; its final hidden
    state initialises the decoder.  ``target_tensor[0]`` is fed to the decoder
    as the first input and target positions 1..3 are predicted, accumulating
    ``criterion`` at every step.  The summed loss is multiplied by the scalar
    held in ``rarity_tensor`` before back-propagation.

    Args:
        input_tensor: Sequence indexed along dim 0; ``input_tensor[i]`` is
            passed to the encoder at step ``i``.
        target_tensor: Targets indexed along dim 0; each of the first four
            entries must hold exactly one element (it is viewed as ``(1, 1)``).
        rarity_tensor: Single-element tensor; its value scales the loss.
        encoder: Module exposing ``initHidden()`` and callable as
            ``encoder(step_input, hidden) -> (output, hidden)``.
        decoder: Module callable as ``decoder(input, hidden) -> (output, hidden)``.
        encoder_optimizer: Optimiser stepping the encoder parameters.
        decoder_optimizer: Optimiser stepping the decoder parameters.
        criterion: Loss called with the decoder output viewed as ``(1, C, 1)``
            and a ``(1, 1)`` integer target.
        max_length: Accepted for backward compatibility; not used.
        teacher_forcing_ratio: Probability of feeding the ground-truth target
            (instead of the decoder's own prediction) as the next input.

    Returns:
        float: the un-weighted summed loss divided by ``target_length`` (4).
        NOTE: only three loss terms are summed; the divisor is kept at 4 so
        reported values stay comparable with earlier runs.
    """
    # Position 0 seeds the decoder; positions 1..target_length-1 are predicted.
    target_length = 4

    def target_at(step):
        # ``.long()`` converts the dtype without leaving the tensor's device
        # (``.type(torch.LongTensor)`` would round-trip through the CPU).
        # ``device`` is the module-level torch.device.
        return target_tensor[step].view(1, 1).long().to(device)

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Only the final encoder hidden state is needed by this decoder, so the
    # per-step encoder outputs are not stored.
    encoder_hidden = encoder.initHidden()
    for ei in range(input_tensor.size(0)):
        _, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)

    decoder_input = target_at(0)
    decoder_hidden = encoder_hidden

    # One draw per example: the whole target sequence is either teacher-forced
    # or free-running.
    use_teacher_forcing = random.random() < teacher_forcing_ratio

    loss = 0
    for di in range(1, target_length):
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
        expected = target_at(di)
        loss += criterion(decoder_output.view(1, -1, 1).to(device), expected)

        if use_teacher_forcing:
            # Teacher forcing: feed the ground truth, already on ``device``.
            decoder_input = expected
        else:
            # Free running: feed back the arg-max prediction, detached so no
            # gradient flows through the sampled input.
            # NOTE(review): this is a squeezed tensor, whereas the
            # teacher-forced input is (1, 1) -- confirm the decoder accepts both.
            decoder_input = decoder_output.topk(1)[1].squeeze().detach()

    weighted_loss = loss * rarity_tensor.item()
    weighted_loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length

In [4]:
# Training loop: calls train() once per index in train_indexes (one example per optimisation step, no mini-batching)

def trainIters(encoder: nn.Module, decoder: nn.Module, train_indexes: List[int], dataset: ChordsDataset, print_every=1000, plot_every=100, learning_rate=0.01, max_fraction=0.45):
    """Train ``encoder``/``decoder`` over ``train_indexes``, one example per step.

    For every index the dataset element is fetched, its ``"lyrics"``,
    ``"chords"`` and ``"rarity"`` entries are reshaped and moved to the
    module-level ``device``, and ``train`` is called once.

    Args:
        encoder: Encoder module; optimised with its own Adam instance.
        decoder: Decoder module; optimised with its own Adam instance.
        train_indexes: Dataset indexes to visit, in order.
        dataset: Indexable dataset whose elements are mappings with the keys
            ``"lyrics"`` and ``"chords"`` (objects supporting ``reshape``) and
            ``"rarity"`` (a scalar).  Lyrics are reshaped to ``(-1, 1, 100)``,
            presumably 100-dimensional token embeddings -- TODO confirm.
        print_every: Print elapsed time, progress (%) and the mean loss every
            this many examples.
        plot_every: Record the mean loss every this many examples.
        learning_rate: Learning rate for both Adam optimisers.
        max_fraction: Stop once the zero-based example index divided by
            ``len(train_indexes)`` exceeds this value (0.45 by default).

    Returns:
        list: mean losses, one per completed ``plot_every`` window.
    """
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
    criterion = nn.NLLLoss()

    # Progress is measured against the indexes actually passed in, not against
    # a notebook-level global.
    total = len(train_indexes)

    for i, index in enumerate(train_indexes):
        elem = dataset[index]
        source = elem["lyrics"].reshape(-1, 1, 100).to(device)
        target = elem["chords"].reshape(-1, 1).to(device)
        rarity = Tensor([elem["rarity"]]).reshape(-1, 1).to(device)

        loss = train(source, target, rarity, encoder,
                     decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        # Count of examples processed so far; reporting on multiples of the
        # window size makes every average cover exactly one full window.
        seen = i + 1
        progress = seen / total

        if seen % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print(timeSince(start, progress), progress * 100, print_loss_avg)

        if seen % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

        # Early stop after the configured share of the training indexes.
        if i / total > max_fraction:
            break

    return plot_losses


In [None]:
# Model dimensions: the decoder emits one score per distinct chord.
hidden_size = 256
output_dim = len(dataset.chords_set)

# Build both networks and place them on the selected device.
# The encoder's first argument (100) matches the last dimension that
# trainIters reshapes the lyrics to.
encoder1 = EncoderRNN(100, hidden_size, device).to(device)
attn_decoder1 = DecoderRNN(hidden_size, output_dim, device).to(device)

# Train on the training split; the trailing semicolon keeps the cell from
# echoing a return value.
trainIters(
    encoder=encoder1,
    decoder=attn_decoder1,
    train_indexes=train_i,
    dataset=dataset,
    print_every=75,
    learning_rate=2e-4,
);

In [171]:
# Persist only the learned parameters (state dicts), not the module objects.
encoder_weights = encoder1.state_dict()
decoder_weights = attn_decoder1.state_dict()
torch.save(encoder_weights, "./model_trained/encoder.pt")
torch.save(decoder_weights, "./model_trained/decoder.pt")

['Fdim/G#', 'F#', 'B', 'F#']