# How to teach an RNN to count

In [255]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
from torch.autograd import Variable
from pprint import pprint

## Create a dataset

In [256]:
def make_string(len=10):
    """Build the training strings ``a^n X b^n .`` for n = 1 .. ``len``.

    Args:
        len: largest number of repeated letters (inclusive). Values below 1
            yield an empty list. The name shadows the builtin ``len`` but is
            kept because callers pass it by keyword.

    Returns:
        A list of strings ordered by increasing n, e.g.
        ``['aXb.', 'aaXbb.']`` for ``len=2``.
    """
    # String repetition replaces the explicit list-of-letters + join.
    return ['a' * count + 'X' + 'b' * count + '.'
            for count in range(1, len + 1)]

def make_tensor(data_string, vocab_encoder):
    """Encode a list of strings as one zero-padded tensor.

    Args:
        data_string: list of strings whose characters are all keys of
            ``vocab_encoder``.
        vocab_encoder: dict mapping each character to a 1-D tensor that fits
            a row of ``len(vocab_encoder)`` elements (one-hot vectors in this
            notebook).

    Returns:
        A tensor created with size
        ``(len(data_string), length of the longest string, len(vocab_encoder))``.
        Positions past the end of a shorter string are left at zero
        (padding). An empty ``data_string`` no longer raises: the tensor is
        created with zero sequences of length zero.

    Raises:
        KeyError: if a character is missing from ``vocab_encoder``.
    """
    # default=0 keeps max() from raising ValueError on an empty list.
    max_seq_len = max((len(string) for string in data_string), default=0)
    data_tensor = torch.zeros(len(data_string),
                              max_seq_len,
                              len(vocab_encoder))
    for i, string in enumerate(data_string):
        for j, char in enumerate(string):
            data_tensor[i][j] = vocab_encoder[char]
    return data_tensor

In [337]:
# Characters of the toy language; output_decoder indexes this list, so the
# order must match the positions used by the one-hot encoding.
vocab = ['a', 'X', 'b', '.']
# Ten training strings: n 'a's, an 'X', n 'b's and a final '.', for n = 1..10.
data_string = make_string(len=10)
pprint(data_string)

['aXb.',
 'aaXbb.',
 'aaaXbbb.',
 'aaaaXbbbb.',
 'aaaaaXbbbbb.',
 'aaaaaaXbbbbbb.',
 'aaaaaaaXbbbbbbb.',
 'aaaaaaaaXbbbbbbbb.',
 'aaaaaaaaaXbbbbbbbbb.',
 'aaaaaaaaaaXbbbbbbbbbb.']


We must convert each character of our vocabulary into a mathematical object (a vector, i.e. a tensor with `dim() == 1`) that our RNN can consume.

Thus, we use a one-hot encoding: each character is mapped to a vector containing a single 1 at that character's own position.

In [338]:
# One-hot lookup table: character -> 1-D tensor holding a single 1.
# The position of the 1 follows the order 'a', 'X', 'b', '.'.
vocab_encoder = {
    char: torch.Tensor(one_hot)
    for char, one_hot in [
        ('a', [1, 0, 0, 0]),
        ('X', [0, 1, 0, 0]),
        ('b', [0, 0, 1, 0]),
        ('.', [0, 0, 0, 1]),
    ]
}

def output_decoder(output, vocab):
    """Return the vocabulary character with the highest score in ``output``.

    Args:
        output: model output exposing ``.data`` with one row of scores per
            sample; only the first row is decoded.
        vocab: sequence of characters, indexed by score position.

    Returns:
        The element of ``vocab`` at the arg-max position of the first row.
    """
    _, argmax = output.data.max(1)
    # max(1) returns the indices either as (rows, 1) or as (rows,) depending
    # on the PyTorch version; flattening first makes both layouts work, and
    # int() turns the picked element into a plain list index.
    best_index = int(argmax.view(-1)[0])
    return vocab[best_index]

# Encode every training string into a single zero-padded tensor and show it.
data_tensor = make_tensor(data_string, vocab_encoder)
print(data_tensor)


(0 ,.,.) = 
   1   0   0   0
   0   1   0   0
   0   0   1   0
   0   0   0   1
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0

(1 ,.,.) = 
   1   0   0   0
   1   0   0   0
   0   1   0   0
   0   0   1   0
   0   0   1   0
   0   0   0   1
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0

(2 ,.,.) = 
   1   0   0   0
   1   0   0   0
   1   0   0   0
   0   1   0   0
   0   0   1   0
   0   0   1   0
   0   0   1   0
   0   0   0   1
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0 

For convenience, we are going to use the utility classes from `torch.utils.data` to wrap our dataset.

In [357]:
class Dataset(data.Dataset):
    """Next-character prediction dataset over a tensor of encoded sequences.

    Item ``index`` is the pair ``(input, target)`` where ``target`` is the
    input shifted one step to the left along its first dimension, with the
    last step left at zero.
    """

    def __init__(self, data_tensor):
        """Store ``data_tensor``; its first dimension indexes the sequences."""
        super(Dataset, self).__init__()
        self.data_tensor = data_tensor

    def __len__(self):
        """Return the number of sequences (size of the first dimension)."""
        return self.data_tensor.size(0)

    def __getitem__(self, index):
        """Return ``(sequence, next_steps)`` for the sequence at ``index``."""
        sequence = self.data_tensor[index]
        # Start from zeros so the final step, which has no successor, stays
        # zero; every other step gets the step that follows it in the input.
        next_steps = torch.zeros(sequence.size())
        next_steps[:-1] = sequence[1:]
        return sequence, next_steps

In [358]:
dataset = Dataset(data_tensor)
# Sanity check: print every (input, target) pair produced by the dataset.
# NOTE: the loop variable ``input`` shadows the builtin of the same name in
# the notebook namespace from here on.
for i, (input, target) in enumerate(dataset):
    print(input, target)


    1     0     0     0
    0     1     0     0
    0     0     1     0
    0     0     0     1
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
[torch.FloatTensor of size 22x4]
 
    0     1     0     0
    0     0     1     0
    0     0     0     1
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    

In [359]:
# One sequence per batch, drawn in random order (shuffle=True).
dataloader = data.DataLoader(dataset, batch_size=1, shuffle=True)
# Each batch carries an extra leading batch dimension of size 1
# compared with the items printed from the dataset directly.
for i, (input, target) in enumerate(dataloader):
    print(input, target)


(0 ,.,.) = 
   1   0   0   0
   1   0   0   0
   0   1   0   0
   0   0   1   0
   0   0   1   0
   0   0   0   1
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
[torch.FloatTensor of size 1x22x4]
 
(0 ,.,.) = 
   1   0   0   0
   0   1   0   0
   0   0   1   0
   0   0   1   0
   0   0   0   1
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
[torch.FloatTensor of size 1x22x4]


(0 ,.,.) = 
   1   0   0   0
   1   0   0   0
   1   0   0   0
   1   0   0   0
   1   0   0   0
   1   0   0   0
   0   1   0   0
   0   0   1   0
   0 

## Create a model

We want an RNN that takes a vector representing a char ('a', 'b', 'X' or '.') and outputs a distribution over the vocabulary (i.e. a probability vector of size 4).

Thus, we create a model which has:

- a **recurrent cell** (here a GRU cell) to compute the new state from the current input and the old state,
- a **linear module** (weight, bias) to compute the output vector from the new state,
- a **softmax module** to compute the output probabilities from the output vector.

In [355]:
class RNNCell(nn.Module):
    """Single-step recurrent model: GRU cell -> linear layer -> softmax."""

    def __init__(self, input_size, state_size, output_size):
        """Create the three sub-modules.

        Args:
            input_size: number of features of one input step.
            state_size: number of features of the recurrent state.
            output_size: number of entries of the output distribution.
        """
        super().__init__()
        self.input_size, self.state_size, self.output_size = (
            input_size, state_size, output_size)
        # Sub-modules are created in the same order as before (recurrent cell
        # first) so parameter initialisation draws random numbers identically.
        self.rnn_cell = nn.GRUCell(input_size, state_size)
        self.linear = nn.Linear(state_size, output_size)
        self.softmax = nn.Softmax()

    def forward(self, input, state):
        """Advance one time step.

        Args:
            input: current input step, passed to the GRU cell.
            state: previous recurrent state, passed as ``hx``.

        Returns:
            ``(output, new_state)`` where ``output`` is the softmax of the
            linear projection of ``new_state``.
        """
        new_state = self.rnn_cell(input, hx=state)
        probabilities = self.softmax(self.linear(new_state))
        return probabilities, new_state
    
# One input feature and one output probability per vocabulary character,
# with a recurrent state of 30 units.
model = RNNCell(len(vocab), 30, len(vocab))
print(model)
pprint(model.state_dict())

RNNCell (
  (rnn_cell): GRUCell(4, 30)
  (linear): Linear (30 -> 4)
  (softmax): Softmax ()
)
OrderedDict([('rnn_cell.weight_ih',
              
-0.1215  0.1300 -0.0711 -0.1037
 0.0155  0.0796 -0.0894 -0.0398
 0.0832 -0.1129  0.0458  0.1658
 0.0982  0.1771 -0.0576 -0.0538
 0.0013 -0.1532 -0.0706 -0.1712
 0.0817 -0.0316 -0.0525 -0.0074
 0.1392 -0.1137  0.0983  0.1183
-0.1081 -0.0181  0.0001  0.0439
-0.1287  0.1591  0.0183 -0.0248
 0.0366  0.0331  0.1767 -0.1571
-0.1505  0.0814  0.0312  0.0560
 0.0587 -0.1167  0.0004 -0.1109
-0.0612  0.1261 -0.1153  0.0484
 0.1002  0.1091  0.1102  0.0494
 0.1106  0.1534 -0.0160 -0.1150
 0.0125 -0.0479  0.1127 -0.1472
-0.1781 -0.1258  0.0336  0.0076
 0.1181  0.0050 -0.0338 -0.1090
 0.1437 -0.0535 -0.1729 -0.0656
 0.1197 -0.0481 -0.1569 -0.1212
-0.0491 -0.0005  0.1167  0.1378
-0.1821 -0.0224  0.1322  0.0581
 0.1496  0.0289 -0.1654  0.1184
 0.0687  0.0781 -0.1352 -0.0164
 0.0216  0.1339  0.1391 -0.0816
-0.1356 -0.0329 -0.0942 -0.0108
 0.1618  0.1139 -0.0017

In [281]:
?nn.RNNCell

In [282]:
?nn.Linear

In [283]:
?nn.Softmax

## Choose a loss function

We want a loss function that produces an error value from the output of the model and the expected target (the loss function must be differentiable). This error value will be backpropagated through the model to compute the derivatives (gradients) of the model parameters.

For the sake of this tutorial, we will choose MSE (mean squared error), although NLL (negative log likelihood) is the usual choice for classification tasks.

In [356]:
# Mean squared error between the model's softmax output and the one-hot
# target. (A CrossEntropyLoss used to be created here and was immediately
# overwritten; that dead assignment is gone.)
loss = nn.MSELoss()

In [100]:
?nn.MSELoss

In [None]:
?nn.NLLLoss

In [190]:
?nn.CrossEntropyLoss

## Choose an optimizer

Once the gradients of the model parameters have been computed, we want an optimizer to update the model parameters by a certain amount (scaled by the learning rate).

I personally find Adam easier to tune than the classical SGD (stochastic gradient descent). Thus, we will use it in this tutorial.

In [362]:
# Alternative kept for reference: plain SGD with momentum.
#optimizer = optim.SGD(model.parameters(), lr=0.00001, momentum=0.7)
# Adam over all model parameters with a learning rate of 1e-5.
optimizer = optim.Adam(model.parameters(), lr=1e-5)

In [112]:
?optim.SGD

In [None]:
?optim.Adam

## Train the RNN

In [360]:
def train_epoch(dataloader, model, loss, optimizer, scored_indices=(2, 3)):
    """Run one training pass over ``dataloader`` and return the summed error.

    Every batch is treated as a single sequence (the recurrent state is
    created with one row, so a batch size of 1 is assumed). The sequence is
    fed to the model one step at a time until the first all-zero (padding)
    step. The model parameters are updated once per sequence.

    Args:
        dataloader: iterable of ``(input, target)`` tensor pairs indexed as
            ``[batch, time, feature]``.
        model: module called as ``model(step_input, state)`` and returning
            ``(output, new_state)``; must expose ``state_size``.
        loss: callable ``loss(output, target)`` returning the step error.
        optimizer: optimizer holding the parameters of ``model``.
        scored_indices: one-hot positions of the input characters whose
            prediction is scored. Defaults to ``(2, 3)``, the positions the
            original code hard-coded ('b' and '.' in this notebook's
            encoding).

    Returns:
        float: sum over the sequences of the mean error of their scored
        steps. Sequences without any scored step contribute nothing.
    """
    model.train()

    error_epoch = 0.0

    # iterate over one sequence at a time
    for input_data, target_data in dataloader:

        # convert data to variable for torch computational graph
        inputs = Variable(input_data)
        targets = Variable(target_data)
        state = Variable(torch.zeros(1, model.state_size))

        # error accumulated over the scored steps of this sequence
        error_seq = Variable(torch.zeros(1))
        seq_len = 0

        for t in range(inputs.size(1)):
            step_input = inputs[:, t]

            # an all-zero step is padding: the sequence is over
            if step_input.data.sum() == 0:
                break

            # compute the char at time t (model forward)
            output, state = model(step_input, state)

            # compute the error at time t (loss forward), only when the
            # input char is one of the scored ones
            if any(step_input.data[0][k] == 1 for k in scored_indices):
                error_seq = error_seq + loss(output, targets[:, t])
                seq_len += 1

        # Nothing was scored: averaging would divide by zero and there is no
        # graph to backpropagate through, so skip the update entirely.
        if seq_len == 0:
            continue

        # takes the sequence length into account (average)
        error_seq = error_seq / seq_len
        # float() yields a plain Python number whether indexing returns a
        # number or a zero-dimensional tensor
        error_epoch += float(error_seq.data[0])

        # compute the gradients over the computational graph
        # and update the model parameters
        optimizer.zero_grad()
        error_seq.backward()
        optimizer.step()
    return error_epoch
        

In [368]:
# Train for 2000 epochs and report the summed epoch error every 100 epochs.
for epoch in range(2000):
    error = train_epoch(dataloader, model, loss, optimizer)
    if epoch%100==0:
        print('epoch: {}\t error: {}'.format(epoch, error))

epoch: 0	 error: 0.13415562687441707
epoch: 100	 error: 0.13345584832131863
epoch: 200	 error: 0.13287313655018806
epoch: 300	 error: 0.13218373665586114
epoch: 400	 error: 0.13163263630121946
epoch: 500	 error: 0.1311186901293695
epoch: 600	 error: 0.13067210977897048
epoch: 700	 error: 0.13025506818667054
epoch: 800	 error: 0.12987314024940133
epoch: 900	 error: 0.12956884922459722
epoch: 1000	 error: 0.12922589760273695
epoch: 1100	 error: 0.12896667513996363
epoch: 1200	 error: 0.12873190036043525
epoch: 1300	 error: 0.12850609654560685
epoch: 1400	 error: 0.12832213891670108
epoch: 1500	 error: 0.12812280468642712
epoch: 1600	 error: 0.12796571664512157
epoch: 1700	 error: 0.12781002698466182
epoch: 1800	 error: 0.12767617590725422
epoch: 1900	 error: 0.12757134111598134


## Try out the model

In [375]:
def predict_next_char(model, input_char, state):
    """Feed one character to ``model`` and decode its prediction.

    Args:
        model: the trained model, called as ``model(input, state)``.
        input_char: a key of ``vocab_encoder``.
        state: recurrent state from the previous step.

    Returns:
        ``(output_char, new_state)`` where ``output_char`` is the character
        of ``vocab`` picked by ``output_decoder``.
    """
    # view(1, -1) adds the batch dimension of size 1 the model expects
    encoded = Variable(vocab_encoder[input_char].view(1, -1), requires_grad=False)
    output, new_state = model(encoded, state)
    return output_decoder(output, vocab), new_state


state = Variable(torch.zeros(1, model.state_size), requires_grad=False)
model.eval()

# number of 'a's fed before the 'X'; the same number of 'b's is fed after it
nb_a = 6

for _ in range(nb_a):
    output_char, state = predict_next_char(model, 'a', state)
    print("input->output\t{}->{}".format('a', output_char))

output_char, state = predict_next_char(model, 'X', state)
print("input->output\t{}->{}".format('X', output_char))

for i in range(nb_a):
    output_char, state = predict_next_char(model, 'b', state)
    print("input->output\t{}: {}->{}".format(i+1, 'b', output_char))




input->output	a->.
input->output	a->b
input->output	a->b
input->output	a->b
input->output	a->b
input->output	a->b
input->output	X->b
input->output	1: b->b

 3.1114e-08  4.0633e-08  1.0000e+00  2.8823e-07
[torch.FloatTensor of size 1x4]

input->output	2: b->b

 8.5308e-08  1.0507e-07  1.0000e+00  1.0286e-06
[torch.FloatTensor of size 1x4]

input->output	3: b->b

 3.6206e-07  4.2335e-07  9.9999e-01  6.6331e-06
[torch.FloatTensor of size 1x4]

input->output	4: b->b

 3.3475e-06  3.6889e-06  9.9987e-01  1.2074e-04
[torch.FloatTensor of size 1x4]

input->output	5: b->b

 0.0002  0.0002  0.9797  0.0200
[torch.FloatTensor of size 1x4]

input->output	6: b->.

 0.0014  0.0008  0.0074  0.9903
[torch.FloatTensor of size 1x4]

