In [1]:
import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import string
import time
import numpy as np

%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

%load_ext autoreload
%autoreload 2

from rnn import RNN, PrepareData

In [2]:
#input_dim = 7
hidden_dim = 30
num_layers = 2
seq_size = 25
all_letters = string.printable
n_letters = len(all_letters)

In [3]:
#dataset = PrepareData('input.txt', seq_size, val_ratio=0.2)
dataset = PrepareData('sample-music2.txt', seq_size)
train_data, val_data = dataset.dataset()

In [4]:
def inputTensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for idx, letter in enumerate(line[:-1]):
        tensor[idx][0][all_letters.find(letter)] = 1
    return Variable(tensor.cuda())
        
def targetTensor(line):
    letter_indexes = [all_letters.find(letter) for letter in line[1:]]
    return Variable(torch.LongTensor(letter_indexes).cuda())

In [6]:
model = RNN(n_letters, hidden_dim, num_layers=num_layers, output_size=n_letters)
model = model.cuda()
criterion = nn.CrossEntropyLoss()

In [7]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

In [8]:
n_epochs = 3
all_losses = []

print_every = 100
init_hidden_every = 20

start_time = time.time()

for epoch in range(n_epochs):
    train_cnt = 0
    total_loss = 0.
    
    model.hidden = model.init_hidden()
    
    #dataset.shuffle_data()
    #train_data, val_data = dataset.dataset()
    
    for sentence in train_data:
        model.zero_grad()
        
        loss = 0.
        inputs = inputTensor(sentence)
        targets = targetTensor(sentence)
        
        for i in range(len(sentence)-1):
            preds = model(inputs[i])
            loss += criterion(preds, targets[i])
            
        loss.backward(retain_graph=True)
        optimizer.step()
        total_loss += loss
        train_cnt += 1
        
        if train_cnt % init_hidden_every == 0:
            model.hidden = model.init_hidden()
        
        if train_cnt % print_every == 0:
            total_loss /= print_every
            all_losses.append(total_loss)
            uptime = time.time() - start_time
            print('%dm:%s -- [%d, %d], loss: %f' % (uptime/60, uptime%60, epoch, train_cnt, total_loss))
            total_loss = 0
#scheduler.step()

0m:28.384814739227295 -- [0, 100], loss: 109.981651
0m:57.71834897994995 -- [0, 200], loss: 107.848053
1m:27.05811619758606 -- [0, 300], loss: 92.608452
1m:56.42622637748718 -- [0, 400], loss: 85.167091
2m:25.824209213256836 -- [0, 500], loss: 87.631920
2m:55.2582688331604 -- [0, 600], loss: 94.421822
3m:24.774600744247437 -- [0, 700], loss: 91.392807
3m:53.98620676994324 -- [0, 800], loss: 85.942078
4m:23.5032639503479 -- [0, 900], loss: 80.309448
4m:52.73165774345398 -- [0, 1000], loss: 86.571884
5m:22.315881967544556 -- [0, 1100], loss: 83.197678
5m:51.640673875808716 -- [0, 1200], loss: 83.989510
6m:21.053335428237915 -- [0, 1300], loss: 92.340523
7m:6.79027247428894 -- [1, 100], loss: 87.203514
7m:35.998796701431274 -- [1, 200], loss: 84.411278
8m:6.418741703033447 -- [1, 300], loss: 77.530220
8m:37.39429306983948 -- [1, 400], loss: 79.619461
9m:7.832164287567139 -- [1, 500], loss: 83.479614
9m:37.51881670951843 -- [1, 600], loss: 89.655380
10m:7.402967214584351 -- [1, 700], loss:

In [10]:
max_length = 25

out_seq = []
out_idx = []
test_seq = train_data[0]
inputs = inputTensor(test_seq)
model.hidden = model.init_hidden()
for i in range(len(test_seq)-1):
    output = model(inputs[i])
    _, pred = torch.max(output.data, 1)
    out_seq.append(all_letters[int(pred)])
    out_idx.append(pred[0])

print(out_seq)
print(out_idx)
print(test_seq)

[' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ']
[94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94, 94]
$
X:1
T: La Montfarine
Z:


In [None]:
plt.figure()
plt.plot(all_losses)