In [1]:
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import string
import time
import numpy as np

%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

%load_ext autoreload
%autoreload 2

from rnn import RNN, PrepareData

In [2]:
#input_dim = 7
hidden_dim = 30
num_layers = 2
seq_size = 25
all_letters = string.printable
n_letters = len(all_letters)

In [3]:
#dataset = PrepareData('input.txt', seq_size, val_ratio=0.2)
dataset = PrepareData('sample-music2.txt', seq_size)
train_data, val_data = dataset.dataset()

In [19]:
def inputTensor(line):
    if len(line) == 1:
        tensor = torch.zeros(1, n_letters)
        tensor[0, all_letters.find(line[0])] = 1
    else:
        tensor = torch.zeros(len(line)-1, n_letters)
        for idx, letter in enumerate(line[:-1]):
            tensor[idx, all_letters.find(letter)] = 1
    return Variable(tensor.cuda())
        
def targetTensor(line):
    letter_indexes = [all_letters.find(letter) for letter in line[1:]]
    return Variable(torch.LongTensor(letter_indexes).cuda())

In [5]:
model = RNN(n_letters, hidden_dim, num_layers=num_layers, output_size=n_letters)
model = model.cuda()
criterion = nn.CrossEntropyLoss()

In [7]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

In [18]:
n_epochs = 20
all_losses = []

print_every = 500
init_hidden_every = 20

start_time = time.time()

for epoch in range(n_epochs):
    train_cnt = 0
    total_loss = 0.
    
    model.hidden = model.init_hidden()
    
    #dataset.shuffle_data()
    #train_data, val_data = dataset.dataset()
    
    for sentence in train_data:
        model.zero_grad()
        
        loss = 0.
        inputs = inputTensor(sentence)
        targets = targetTensor(sentence)
        
        preds = model(inputs)
        loss = criterion(preds, targets)
            
        loss.backward(retain_graph=True)
        optimizer.step()
        total_loss += loss
        train_cnt += 1
        
        if train_cnt % init_hidden_every == 0:
            model.hidden = model.init_hidden()
        
        if train_cnt % print_every == 0:
            total_loss /= print_every
            all_losses.append(total_loss)
            uptime = time.time() - start_time
            print('%dm:%s -- [%d, %d], loss: %f' % (uptime/60, uptime%60, epoch, train_cnt, total_loss))
            total_loss = 0
#scheduler.step()

0m:11.196268320083618 -- [0, 500], loss: 2.852527
0m:25.03300905227661 -- [0, 1000], loss: 2.879465
0m:47.27974033355713 -- [1, 500], loss: 2.827497
1m:0.5234787464141846 -- [1, 1000], loss: 2.853377
1m:22.691142320632935 -- [2, 500], loss: 2.803578
1m:35.68027710914612 -- [2, 1000], loss: 2.827548
1m:58.390618324279785 -- [3, 500], loss: 2.779727
2m:11.446256160736084 -- [3, 1000], loss: 2.803719
2m:33.22517538070679 -- [4, 500], loss: 2.756232
2m:46.25396990776062 -- [4, 1000], loss: 2.780244
3m:8.738785982131958 -- [5, 500], loss: 2.734438
3m:21.562535524368286 -- [5, 1000], loss: 2.758514
3m:43.81132459640503 -- [6, 500], loss: 2.713376
3m:57.11608910560608 -- [6, 1000], loss: 2.735652
4m:19.450064182281494 -- [7, 500], loss: 2.689895
4m:32.194337368011475 -- [7, 1000], loss: 2.714332
4m:54.214170932769775 -- [8, 500], loss: 2.670170
5m:7.438884735107422 -- [8, 1000], loss: 2.694426
5m:29.421112537384033 -- [9, 500], loss: 2.648312
5m:42.83986020088196 -- [9, 1000], loss: 2.673491


In [35]:
def test(start_with, max_length):
    model.hidden = model.init_hidden()
    seq = str(start_with)
    idx = []
    inputs = inputTensor(seq[-1])
    
    for i in range(max_length):
        output = model(inputs)
        _, pred = torch.max(output.data, 1)
        
        seq += all_letters[pred[0]]
        idx.append(pred[0])
        inputs = inputTensor(letter)
    return seq, idx

In [41]:
test_seq = train_data[5]

seq, idx = test(test_seq[0],25)
print('Generated sequence: %s' % (seq))
print('Sequence index: ', idx)
print('Test sequence: %s' % test_seq)

Generated sequence: r ccBBBcBcccBcBcccccccBcBB
Sequence index:  [94, 12, 12, 37, 37, 37, 12, 37, 12, 12, 12, 37, 12, 37, 12, 12, 12, 12, 12, 12, 12, 37, 12, 37, 37]
Test sequence: ree.fr
M: 4/4
L: 1/8
Q:1/


In [None]:
plt.figure()
plt.plot(all_losses)