# Temporal Order Classification Experiments in PyTorch

In [None]:
from sequential_tasks import TemporalOrderExp6aSequence

## 1. Defining the RNN


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Seed PyTorch's global random number generator so that anything drawn through
# torch (for example the initial weights of layers created afterwards) is the
# same on every run. This does not seed NumPy or Python's `random` module.
torch.manual_seed(1)


class SimpleRNN(nn.Module):
    """A single-layer ReLU RNN followed by a linear classification layer.

    In PyTorch, subclassing your model from torch.nn.Module
    takes care of low-level concerns for you such as defining
    the backward pass and keeping track of network parameters.

    Arguments
    ----------
    input_size : int
        The number of features in the input tensor.
    hidden_size : int
        The number of features in the hidden state of the RNN.
    output_size : int
        The number of classes in the output tensor.
    """
    def __init__(self, input_size, hidden_size, output_size):
        # This just calls the base class constructor.
        super().__init__()
        # Neural network layers assigned as attributes of a Module subclass
        # have their parameters registered for training automatically.
        # batch_first=True makes the RNN consume and produce tensors laid out
        # as (batch, sequence, features).
        self.rnn = nn.RNN(input_size, hidden_size, nonlinearity='relu', batch_first=True)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """The forward pass of the simple RNN.

        Subclasses of torch.nn.Module require the user to override
        the forward method to define their computation steps.

        Arguments
        ----------
        x : torch.Tensor
            The input tensor of shape (batch_size, max_sequence_length, input_size).

        Returns
        -------
        out : torch.Tensor
            The per-step class log-probabilities, of shape
            (batch_size, max_sequence_length, output_size). Exponentiating
            along the last dimension yields values that sum to 1.
        """
        # The RNN also returns its final hidden state but we don't use it in this
        # example. While the RNN can also take a hidden state as input, it is
        # initialized with zeros by default when none is given.
        hidden_states, _ = self.rnn(x)
        logits = self.linear(hidden_states)
        # Normalize over the class axis (the last one). Normalizing over dim=1
        # would mix scores across time steps, so the value at a given step would
        # depend on the padded sequence length instead of being a distribution
        # over classes.
        return F.log_softmax(logits, dim=-1)

## 2. Defining the Training Loop


In [None]:
def train(model, train_data_gen, criterion, optimizer, device):
    """Train a model for one pass over the temporal ordering training data.

    Parameters
    ----------
    model : torch.nn.Module
        The model to train. It must already live on ``device`` and map a
        (batch_size, max_sequence_length, input_size) tensor to a
        (batch_size, max_sequence_length, output_size) tensor.
    train_data_gen : TemporalOrderExp6aSequence
        The data generator instance used to produce training data.
    criterion : torch.nn.Module
        The loss function.
    optimizer : torch.optim.Optimizer
        The optimization algorithm used to update the network parameters.
    device : torch.device
        The device to which tensors will be moved.

    Returns
    -------
    tuple
        The number of correctly classified sequences over all batches and
        the loss of the last batch, respectively.
    """
    # Set the model to training mode. This will turn on layers that would
    # otherwise behave differently during evaluation, such as dropout or batch normalization.
    model.train()

    # Store the number of sequences that were classified correctly.
    num_correct = 0

    # Iterate over every batch of sequences. The generator is indexed explicitly
    # because its length is defined as the number of batches it provides.
    for batch_idx in range(len(train_data_gen)):
        # For each new batch, clear the gradient buffers of the optimized parameters.
        # Otherwise, gradients from the previous batch would be accumulated.
        optimizer.zero_grad()

        # Request a batch of sequences and class labels, convert them into tensors
        # of the correct type, and then send them to the appropriate device.
        data, target = train_data_gen[batch_idx]
        data = torch.from_numpy(data).float().to(device)
        target = torch.from_numpy(target).long().to(device)

        # Perform the forward pass of the model.
        output = model(data)

        # The sequences are padded to a common length, but we only care about the
        # model's output AFTER it has seen an entire sequence. Because the model
        # emits an output at every step, we decode each sequence to recover its true
        # length and then pick the output at index (length - 1) for every sequence.
        #
        # Tensor.numpy() only works for tensors in host memory, so the batch is
        # copied back with .cpu() first (a no-op when it is already on the CPU).
        data_decoded = train_data_gen.decode_x_batch(data.cpu().numpy())
        sequence_end = torch.tensor(
            [len(sequence) for sequence in data_decoded], device=device
        ) - 1
        # Pair each batch row with its own end index; the result has shape
        # (batch_size, output_size). Index tensors are created on the same
        # device as the output.
        batch_rows = torch.arange(data.shape[0], device=device)
        output = output[batch_rows, sequence_end, :]

        # Loss functions like CrossEntropyLoss expect class indices rather than
        # one-hot encoded labels, so take the argmax along the class dimension of
        # the one-hot target to obtain a tensor of shape (batch_size).
        target = target.argmax(dim=1)
        loss = criterion(output, target)

        # Backpropagation through time in two lines!
        loss.backward()
        optimizer.step()

        # The predicted class of each sequence is the one with the highest score.
        # Compare predictions with the targets element-wise and reduce to a scalar.
        y_pred = output.argmax(dim=1)
        num_correct += (y_pred == target).sum().item()

    return num_correct, loss.item()

## 3. Defining the Testing Loop


In [None]:
def test(model, test_data_gen, criterion, device):
    """Test a model's classification performance for the temporal ordering problem.

    Parameters
    ----------
    model : torch.nn.Module
        The model to evaluate. It must already live on ``device`` and map a
        (batch_size, max_sequence_length, input_size) tensor to a
        (batch_size, max_sequence_length, output_size) tensor.
    test_data_gen : TemporalOrderExp6aSequence
        The data generator instance used to produce test data.
    criterion : torch.nn.Module
        The loss function.
    device : torch.device
        The device to which tensors will be moved.

    Returns
    -------
    tuple
        The number of correctly classified sequences over all batches and
        the loss of the last batch, respectively.
    """
    # Set the model to evaluation mode. This will turn off layers that would
    # otherwise behave differently during training, such as dropout.
    model.eval()

    # Store the number of sequences that were classified correctly.
    num_correct = 0

    # A context manager is used to disable gradient calculations during inference
    # to reduce memory usage, as we typically don't need the gradients at this point.
    with torch.no_grad():
        for batch_idx in range(len(test_data_gen)):
            data, target = test_data_gen[batch_idx]
            data = torch.from_numpy(data).float().to(device)
            target = torch.from_numpy(target).long().to(device)

            output = model(data)

            # Decode the batch to recover the true (unpadded) length of each
            # sequence. Tensor.numpy() requires host memory, hence the .cpu()
            # (a no-op when the batch is already on the CPU).
            data_decoded = test_data_gen.decode_x_batch(data.cpu().numpy())
            sequence_end = torch.tensor(
                [len(sequence) for sequence in data_decoded], device=device
            ) - 1
            # Keep only the output produced at the last real step of each
            # sequence; the result has shape (batch_size, output_size).
            batch_rows = torch.arange(data.shape[0], device=device)
            output = output[batch_rows, sequence_end, :]

            # Convert one-hot targets into class indices for the loss function.
            target = target.argmax(dim=1)
            loss = criterion(output, target)

            # Count the sequences whose highest-scoring class matches the target.
            y_pred = output.argmax(dim=1)
            num_correct += (y_pred == target).sum().item()

    return num_correct, loss.item()

## 4. Putting it All Together

Now that we have defined the training and testing loops, our simple RNN is ready for training! Let us combine them into a single function below to streamline the training and evaluation.

In [None]:
import matplotlib.pyplot as plt

%matplotlib inline


def train_and_test(model, train_data_gen, test_data_gen, criterion, optimizer, max_epochs, verbose=True):
    """Train a model and monitor its performance.

    Parameters
    ----------
    model : torch.nn.Module
        The model to train and evaluate.
    train_data_gen : TemporalOrderExp6aSequence
        The data generator instance used to produce training data.
    test_data_gen : TemporalOrderExp6aSequence
        The data generator instance used to produce test data.
    criterion : torch.nn.Module
        The loss function.
    optimizer : torch.optim.Optimizer
        The optimization algorithm used to update the network parameters.
    max_epochs : int
        The maximum number of times model iterates through all of the
        training data during training.
    verbose : bool, optional
        Report the loss and accuracy over the training and test sets for
        every epoch. The default is True. The final epoch is always reported.

    Returns
    -------
    torch.nn.Module
        The same model instance, after training.
    """
    # Automatically determine the device that PyTorch should use for computation.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Move the model's parameters to that device. The training and testing loops
    # send every batch to `device`, so leaving the model on the CPU would fail
    # with a device mismatch whenever CUDA is available. Module.to() moves the
    # parameters in place, so an optimizer built from them beforehand still
    # refers to the same parameters.
    model.to(device)

    # Track the value of the loss function and model accuracy across epochs.
    history_train = {'loss': [], 'acc': []}
    history_test = {'loss': [], 'acc': []}

    for epoch in range(max_epochs):
        # Run the training loop and calculate the accuracy.
        # Remember that the length of a data generator is the number of batches,
        # so we multiply it by the batch size to recover the total number of sequences.
        num_correct, loss = train(model, train_data_gen, criterion, optimizer, device)
        accuracy = float(num_correct) / (len(train_data_gen) * train_data_gen.batch_size) * 100
        history_train['loss'].append(loss)
        history_train['acc'].append(accuracy)

        # Do the same for the testing loop.
        num_correct, loss = test(model, test_data_gen, criterion, device)
        accuracy = float(num_correct) / (len(test_data_gen) * test_data_gen.batch_size) * 100
        history_test['loss'].append(loss)
        history_test['acc'].append(accuracy)

        if verbose or epoch + 1 == max_epochs:
            print(f'[Epoch {epoch + 1}/{max_epochs}]'
                  f" loss: {history_train['loss'][-1]:.4f}, acc: {history_train['acc'][-1]:2.2f}%"
                  f" - test_loss: {history_test['loss'][-1]:.4f}, test_acc: {history_test['acc'][-1]:2.2f}%")

    # Generate diagnostic plots for the loss and accuracy.
    fig, axes = plt.subplots(ncols=2, figsize=(9, 4.5))
    for ax, metric in zip(axes, ['loss', 'acc']):
        ax.plot(history_train[metric])
        ax.plot(history_test[metric])
        ax.set_xlabel('epoch', fontsize=12)
        ax.set_ylabel(metric, fontsize=12)
        ax.legend(['Train', 'Test'], loc='best')
    plt.show()

    return model

## 5. Running the experiment

In [None]:
# Create one data generator for training and one for testing, both using the
# same difficulty level and batch size.
batch_size = 32
difficulty = TemporalOrderExp6aSequence.DifficultyLevel.EASY
train_data_gen = TemporalOrderExp6aSequence.data_generator(difficulty, batch_size)
test_data_gen = TemporalOrderExp6aSequence.data_generator(difficulty, batch_size)

# Size the network from the training generator: one input feature per symbol
# and one output per class, with a small hidden state in between.
input_size = train_data_gen.n_symbols
hidden_size = 4
output_size = train_data_gen.n_classes
model = SimpleRNN(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
)

# Loss function, optimizer and number of passes over the training data.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.02)
max_epochs = 10

# Run the full train/evaluate cycle and keep the trained model.
model = train_and_test(
    model,
    train_data_gen,
    test_data_gen,
    criterion,
    optimizer,
    max_epochs,
)