In [None]:
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from sequential_tasks import TemporalOrderExp6aSequence

# Seed every global RNG the notebook may draw from so Restart & Run All is
# reproducible: torch drives weight initialisation, and the data generator
# returns NumPy arrays, so it may draw from NumPy's or Python's global RNG.
# NOTE(review): confirm which RNG sequential_tasks actually uses.
SEED = 1
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' hides the Generator repr in the notebook

# Specify experiment settings and prepare the data

In [None]:
# Experiment settings: task difficulty, mini-batch size, RNN hidden width
# and the number of training epochs.
settings = dict(
    difficulty=TemporalOrderExp6aSequence.DifficultyLevel.EASY,
    batch_size=32,
    h_units=4,
    max_epochs=10,
)

In [None]:
# Build the training and test generators from the same settings.
# NOTE(review): both calls receive identical arguments; whether they yield
# different sequences depends on the generator's internal RNG — confirm the
# test data is not a copy of the training data.
def _make_generator():
    """Return a predefined generator for the configured difficulty and batch size."""
    return TemporalOrderExp6aSequence.get_predefined_generator(
        settings['difficulty'], settings['batch_size'])


train_data_gen = _make_generator()
train_size = len(train_data_gen)  # number of batches indexed per training epoch

test_data_gen = _make_generator()
test_size = len(test_data_gen)  # number of batches indexed during evaluation

# Define the neural network

In [None]:
class SimpleRNN(nn.Module):
    """Single-layer ReLU RNN followed by a per-time-step linear classifier.

    Args:
        input_size: number of features per time step (the notebook passes the
            generator's ``n_symbols``).
        rnn_hidden_size: width of the RNN hidden state.
        output_size: number of output classes.

    ``forward`` takes ``x`` of shape (batch, seq, input_size) — the RNN is built
    with ``batch_first=True`` — and returns log-probabilities of shape
    (batch, seq, output_size), normalised over the class axis.
    """

    def __init__(self, input_size, rnn_hidden_size, output_size):
        super().__init__()
        self.rnn = nn.RNN(input_size, rnn_hidden_size, num_layers=1,
                          nonlinearity='relu', batch_first=True)
        self.linear = nn.Linear(rnn_hidden_size, output_size)

    def forward(self, x):
        # hidden: (batch, seq, rnn_hidden_size); the final hidden state is unused.
        hidden, _ = self.rnn(x)
        logits = self.linear(hidden)
        # Normalise over the class axis (the last one). With batch_first=True,
        # dim=1 is the *time* axis, which would be wrong here.
        return F.log_softmax(logits, dim=-1)

In [None]:
# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Define the training and test loops

In [None]:
def train():
    """Run one training epoch over every batch of ``train_data_gen``.

    Relies on the notebook globals ``model``, ``optimizer``, ``criterion``,
    ``train_data_gen``, ``train_size`` and ``device``.

    Returns:
        (correct, loss): the number of correctly classified sequences in the
        epoch, and the detached loss tensor of the last batch (NaN when the
        generator yields no batches).
    """
    model.train()
    correct = 0
    loss = torch.tensor(float("nan"))  # avoids UnboundLocalError on an empty generator
    for batch_idx in range(train_size):
        data, target = train_data_gen[batch_idx]
        data = torch.from_numpy(data).float().to(device)
        target = torch.from_numpy(target).long().to(device)

        optimizer.zero_grad()
        y_pred = model(data)
        # Many-to-one classification: reduce a per-time-step output of shape
        # (batch, seq, classes) to the prediction at the final step.
        # NOTE(review): assumes sequences are left-padded so index -1 is the
        # real last symbol — confirm in sequential_tasks.
        if y_pred.dim() == 3:
            y_pred = y_pred[:, -1, :]
        # The loss expects class indices of shape (batch,); accept one-hot rows
        # (same shape as the prediction) as well as plain indices.
        if target.shape == y_pred.shape:
            target = target.argmax(dim=1)
        else:
            target = target.view(-1)

        batch_loss = criterion(y_pred, target)
        batch_loss.backward()
        optimizer.step()

        pred = y_pred.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        loss = batch_loss.detach()
    return correct, loss

In [None]:
def test():
    model.eval()   
    correct = 0
    with torch.no_grad():
        for batch_idx in range(test_size):
            data, target = test_data_gen[batch_idx]
            data, target = torch.from_numpy(data).float().to(device), torch.from_numpy(target).long().to(device)
            y_pred = model(data)
            pred = y_pred.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    return correct

# Initialize the model, loss function, and optimizer

In [None]:
# Build the model on the same device that train()/test() send the batches to.
# The optimizer is constructed afterwards, so it holds the moved parameters.
model = SimpleRNN(train_data_gen.n_symbols, settings['h_units'], train_data_gen.n_classes).to(device)

# forward() returns log-probabilities, so NLLLoss is the matching criterion.
# CrossEntropyLoss would apply log_softmax again over the class axis. On inputs
# that are already log-probabilities over that axis this gives the same value,
# but it obscures the intent.
criterion = nn.NLLLoss()
optimizer = optim.RMSprop(model.parameters(), lr=0.001)

# Train the model

In [None]:
# Train for settings['max_epochs'] epochs, then evaluate once on the test generator.
epochs = settings['max_epochs']

# train()/test() return the number of correct predictions summed over whole
# batches, whereas train_size/test_size count batches. Accuracy must therefore
# be normalised by the number of sequences, not the number of batches.
# NOTE(review): assumes every batch holds exactly settings['batch_size'] sequences.
n_train_seqs = train_size * settings['batch_size']
n_test_seqs = test_size * settings['batch_size']

for epoch in range(1, epochs + 1):
    correct, loss = train()
    train_accuracy = float(correct) / n_train_seqs if n_train_seqs else 0.0
    print('Train Epoch: {}/{}, loss: {:.4f}, accuracy {:2.2f}'.format(epoch, epochs, loss.item(), train_accuracy))

# test
correct = test()
test_accuracy = float(correct) / n_test_seqs if n_test_seqs else 0.0
print('\nTest accuracy: {}'.format(test_accuracy))

In [None]:
# test_accuracy is a fraction in [0, 1]; the '%' format spec multiplies it by 100.
print('acc = {:.2%}.'.format(test_accuracy))