In [2]:
from sequential_tasks import TemporalOrderExp6aSequence

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)

<torch._C.Generator at 0x116c0e390>

In [3]:
class SimpleRNN(nn.Module):

    def __init__(self, input_size, rnn_hidden_size, output_size):

        super(SimpleRNN, self).__init__()
        self.rnn = torch.nn.RNN(input_size, rnn_hidden_size, num_layers=1, nonlinearity='relu', batch_first=True)
        #self.h_0 = self.initialize_hidden(rnn_hidden_size)
        self.linear = torch.nn.Linear(rnn_hidden_size, output_size) 

    def forward(self, x):
        #x = x.unsqueeze(0)
        #self.rnn.flatten_parameters()
        #out, self.h_0 = self.rnn(x, self.h_0)
        out, _ = self.rnn(x)
        out = self.linear(out)
        return F.log_softmax(out, dim=1)

    #def initialize_hidden(self, rnn_hidden_size):
    #    return Variable(torch.randn(2, 1, rnn_hidden_size), requires_grad=True)

In [4]:
def exp6a_experiment(settings):
    train_data_gen = TemporalOrderExp6aSequence.get_predefined_generator(
        settings['difficulty'],
        settings['batch_size'])

    model = SimpleRNN(train_data_gen.n_symbols, settings['h_units'], train_data_gen.n_classes)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)

    epochs = settings['max_epochs']
    log_interval = settings['log_interval']

    epoch = 0
    while epoch < epochs:
        predictions = []
        truth_values = []
   
        for batch_idx in range(len(train_data_gen)):
            xs, ys = train_data_gen[batch_idx]
            xs, ys = torch.from_numpy(xs).float(), torch.from_numpy(ys).long()

            y_pred = model(xs)
            loss = criterion(y_pred, ys)
            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            #nn.utils.clip_grad_norm(model.parameters(), 0.5)
            optimizer.step()

            predictions.append(y_pred.cpu().data.numpy().ravel())
            truth_values.append(ys.cpu().data.numpy().ravel())

            if batch_idx % log_interval == 0:
                print('Train Epoch: {}, mini-batch {} of {}, training loss: {:.6f}'.format(
                    epoch, batch_idx, len(train_data_gen), loss.item()))

        epoch += 1
    
    # testing
    test_data_gen = TemporalOrderExp6aSequence.get_predefined_generator(
        settings['difficulty'],
        settings['batch_size'])

    correct = 0
    with torch.no_grad():
        for i in range(len(test_data_gen)):
            xs, ys = test_data_gen[i]
            xs, ys = torch.from_numpy(xs).float(), torch.from_numpy(ys).long()
            output = model(xs)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(ys.view_as(pred)).sum().item()

    test_accuracy = float(correct) / len(test_data_gen)
    print('\nAccuracy: {}'.format(test_accuracy))
    
    return test_accuracy

In [5]:
# experiments settings
params = {
    "difficulty": TemporalOrderExp6aSequence.DifficultyLevel.EASY,
    "batch_size": 32,
    "h_units": 4,
    "max_epochs": 30,
    "log_interval": 10
}

acc = exp6a_experiment(params)

Train Epoch: 0, mini-batch 0 of 31, training loss: 2.223053
Train Epoch: 0, mini-batch 10 of 31, training loss: 2.174639
Train Epoch: 0, mini-batch 20 of 31, training loss: 2.136809
Train Epoch: 0, mini-batch 30 of 31, training loss: 2.114726
Train Epoch: 1, mini-batch 0 of 31, training loss: 2.098664
Train Epoch: 1, mini-batch 10 of 31, training loss: 2.066840
Train Epoch: 1, mini-batch 20 of 31, training loss: 2.032968
Train Epoch: 1, mini-batch 30 of 31, training loss: 1.991863
Train Epoch: 2, mini-batch 0 of 31, training loss: 1.988699
Train Epoch: 2, mini-batch 10 of 31, training loss: 1.952237
Train Epoch: 2, mini-batch 20 of 31, training loss: 1.883309
Train Epoch: 2, mini-batch 30 of 31, training loss: 1.716444
Train Epoch: 3, mini-batch 0 of 31, training loss: 1.708858
Train Epoch: 3, mini-batch 10 of 31, training loss: 1.566915
Train Epoch: 3, mini-batch 20 of 31, training loss: 1.387536
Train Epoch: 3, mini-batch 30 of 31, training loss: 1.258415
Train Epoch: 4, mini-batch 0

In [6]:
print('acc = {:.2f}%.'.format(acc))

acc = 96.00%.
