# How to teach an RNN to count chars

In [1]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
from torch.autograd import Variable
from pprint import pprint

## Create a dataset

In [3]:
def make_string(len=10):
    """Build the training strings 'a'*i + 'X' + 'b'*i + '.' for i = 1..len.

    Note: the parameter is named `len` (shadowing the builtin inside this
    function) because callers pass it by keyword.

    Returns:
        list of str, ordered by increasing i; empty when len < 1.
    """
    return ['a' * count + 'X' + 'b' * count + '.'
            for count in range(1, len + 1)]

def make_tensor(data_string, vocab_encoder):
    """Encode a list of strings as one zero-padded tensor.

    Args:
        data_string: list of strings whose chars are all keys of vocab_encoder.
        vocab_encoder: dict mapping a char to its vector of size len(vocab_encoder).

    Returns:
        Tensor of size (number of strings, longest string length,
        len(vocab_encoder)); positions past the end of a string stay all-zero.
    """
    longest = max(len(text) for text in data_string)
    encoded = torch.zeros(len(data_string), longest, len(vocab_encoder))
    for row, text in enumerate(data_string):
        for col, symbol in enumerate(text):
            encoded[row, col] = vocab_encoder[symbol]
    return encoded

In [12]:
vocab = ['a', 'X', 'b', '.']
data_string = make_string(len=10)
# Drop the strings with 5 and 6 repeated letters (indices 4 and 5);
# the try-out cell at the end of the notebook probes nb_a = 6.
del data_string[4:6]
pprint(data_string)

['aXb.',
 'aaXbb.',
 'aaaXbbb.',
 'aaaaXbbbb.',
 'aaaaaaaXbbbbbbb.',
 'aaaaaaaaXbbbbbbbb.',
 'aaaaaaaaaXbbbbbbbbb.',
 'aaaaaaaaaaXbbbbbbbbbb.']


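We must convert each character of our vocabulary into a mathematical object (a vector, i.e. a tensor with `dim() == 1`) that our RNN can consume.

Thus, we use a one-hot encoding: each character maps to a vector with a single 1 at the position reserved for that character.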

In [13]:
# One-hot code of each vocabulary char: a single 1 at the char's position.
vocab_encoder = {
    char: torch.Tensor(one_hot)
    for char, one_hot in (
        ('a', [1, 0, 0, 0]),
        ('X', [0, 1, 0, 0]),
        ('b', [0, 0, 1, 0]),
        ('.', [0, 0, 0, 1]),
    )
}

def output_decoder(output, vocab):
    """Return the vocabulary entry with the highest score in the first row.

    Args:
        output: model output exposing `.data`, indexed as [row, vocab entry].
        vocab: sequence of chars, in the same order as the output columns.

    Returns:
        The element of `vocab` at the arg-max of row 0 of `output`.
    """
    _, argmax = output.data.max(1)
    # Depending on the PyTorch version, max(1) returns indices of size
    # (rows, 1) or (rows,). Flattening first makes [0] pick row 0 in both
    # cases, whereas argmax[0][0] fails on the 1-D form.
    best_index = int(argmax.view(-1)[0])
    return vocab[best_index]

# Encode every training string as one-hot rows, zero-padded to the length
# of the longest string.
data_tensor = make_tensor(data_string, vocab_encoder)
print(data_tensor)


(0 ,.,.) = 
   1   0   0   0
   0   1   0   0
   0   0   1   0
   0   0   0   1
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0

(1 ,.,.) = 
   1   0   0   0
   1   0   0   0
   0   1   0   0
   0   0   1   0
   0   0   1   0
   0   0   0   1
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0

(2 ,.,.) = 
   1   0   0   0
   1   0   0   0
   1   0   0   0
   0   1   0   0
   0   0   1   0
   0   0   1   0
   0   0   1   0
   0   0   0   1
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0 

For convenience, we are going to use the utility classes from `torch.utils.data` to wrap our dataset.

In [14]:
class Dataset(data.Dataset):
    """Next-char prediction dataset over a tensor of encoded sequences.

    Item `index` is the pair (sequence, target), where target is the same
    sequence shifted one step earlier and zero-filled in its last position.
    """

    def __init__(self, data_tensor):
        super(Dataset, self).__init__()
        self.data_tensor = data_tensor

    def __len__(self):
        # One item per entry along the first dimension.
        return self.data_tensor.size(0)

    def __getitem__(self, index):
        sequence = self.data_tensor[index]
        # target[t] is the char that follows sequence[t]; the last step has
        # no successor and stays all-zero.
        shifted = torch.zeros(sequence.size())
        shifted[:-1] = sequence[1:]
        return sequence, shifted

In [15]:
# Sanity check: print every (input, target) pair; each target is its input
# shifted one step earlier, with an all-zero last row.
dataset = Dataset(data_tensor)
for i, (input, target) in enumerate(dataset):
    print(input, target)


    1     0     0     0
    0     1     0     0
    0     0     1     0
    0     0     0     1
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
[torch.FloatTensor of size 22x4]
 
    0     1     0     0
    0     0     1     0
    0     0     0     1
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    0     0     0     0
    

In [16]:
# batch_size=1: each batch holds a single sequence behind a leading batch
# dimension (see the 1x22x4 sizes printed below); shuffle=True asks the
# loader to reshuffle the sequences on every pass.
dataloader = data.DataLoader(dataset, batch_size=1, shuffle=True)
for i, (input, target) in enumerate(dataloader):
    print(input, target)


(0 ,.,.) = 
   1   0   0   0
   0   1   0   0
   0   0   1   0
   0   0   0   1
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
[torch.FloatTensor of size 1x22x4]
 
(0 ,.,.) = 
   0   1   0   0
   0   0   1   0
   0   0   0   1
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
   0   0   0   0
[torch.FloatTensor of size 1x22x4]


(0 ,.,.) = 
   1   0   0   0
   1   0   0   0
   1   0   0   0
   1   0   0   0
   1   0   0   0
   1   0   0   0
   1   0   0   0
   1   0   0   0
   1 

## Create a model

We want an RNN that takes a vector representing a char ('a', 'b', 'X' or '.') and outputs a distribution over the vocabulary (i.e. a probability vector of size 4).

Thus, we create a model that has:

- a **recurrent cell** to compute the new state from the current input and the old state,
- a **linear module** (weight, bias) to compute the output vector from the new state,
- a **softmax module** to compute the output probabilities from the output vector.

In [17]:
class RNNCell(nn.Module):
    """One recurrent step: GRU cell -> linear projection -> softmax.

    Args:
        input_size: size of the input vector fed at each step.
        state_size: size of the recurrent state.
        output_size: size of the output probability vector.
    """

    def __init__(self, input_size, state_size, output_size):
        super(RNNCell, self).__init__()
        # Sizes are kept as attributes; callers read state_size to build
        # the initial state.
        self.input_size = input_size
        self.state_size = state_size
        self.output_size = output_size
        # Submodules are created in this order: cell, projection, softmax.
        self.rnn_cell = nn.GRUCell(input_size, state_size)
        self.linear = nn.Linear(state_size, output_size)
        self.softmax = nn.Softmax()

    def forward(self, input, state):
        """Advance one time step and return (output probabilities, new state)."""
        new_state = self.rnn_cell(input, hx=state)
        scores = self.linear(new_state)
        return self.softmax(scores), new_state
    
# Input and output sizes both equal the vocabulary size; 30 is the size of
# the recurrent state.
model = RNNCell(len(vocab), 30, len(vocab))
print(model)
pprint(model.state_dict())

RNNCell (
  (rnn_cell): GRUCell(4, 30)
  (linear): Linear (30 -> 4)
  (softmax): Softmax ()
)
OrderedDict([('rnn_cell.weight_ih',
              
-0.0940  0.0278  0.1521 -0.1429
-0.0013  0.1474 -0.0608  0.0369
 0.0918 -0.0341 -0.0362 -0.1488
-0.0282  0.0426 -0.0399 -0.0935
-0.1160 -0.0988  0.1350 -0.0890
-0.1656  0.1422  0.1019  0.0856
-0.1620 -0.1654 -0.0807 -0.0896
 0.0627  0.1617 -0.1260 -0.1691
-0.0426  0.1083 -0.1556  0.0760
-0.0000 -0.1739  0.1234  0.1139
-0.1308  0.1436  0.0202 -0.0614
 0.1438  0.0029 -0.0067 -0.1227
-0.1594  0.1422  0.1234 -0.1216
 0.1079  0.1076  0.0730 -0.1175
-0.0154 -0.0616 -0.0158 -0.0259
 0.0562  0.1769 -0.1010 -0.1722
 0.1140  0.0715  0.0905 -0.1785
-0.1777  0.0851 -0.1610  0.0283
 0.0433  0.1041 -0.0089 -0.1305
 0.0489  0.1201  0.0773  0.0726
 0.1334 -0.0256  0.0344 -0.1258
 0.0285  0.0302  0.0424  0.0449
 0.1134  0.1813  0.0188 -0.0871
 0.1153  0.0933  0.1188  0.1037
-0.1115 -0.0752 -0.0764 -0.0690
 0.0633 -0.1803 -0.0695 -0.0906
 0.1126  0.0859 -0.1550

In [18]:
?nn.RNNCell

In [None]:
?nn.Linear

In [None]:
?nn.Softmax

## Choose a loss function

We want a loss function that produces an error value from the output of the model and the expected target (the loss function must be differentiable). This error value will be backpropagated through the model to compute the derivatives (gradients) of the model parameters.

For the sake of this tutorial, we will choose MSE (mean squared error), but NLL (negative log likelihood) is the usual choice for classification tasks.

In [20]:
# Alternative, not used in this tutorial:
# loss = nn.CrossEntropyLoss()
# Mean squared error between the model's output vector and the target vector.
loss = nn.MSELoss()

In [None]:
?nn.MSELoss

In [None]:
?nn.NLLLoss

In [None]:
?nn.CrossEntropyLoss

## Choose an optimizer

Once the gradients of the model parameters have been computed, we want an optimizer to update the model parameters by an amount scaled by the learning rate.

I personally find Adam easier to tune than the classical SGD (stochastic gradient descent). Thus, we will use Adam in this tutorial.

In [21]:
# Alternative, not used in this tutorial:
#optimizer = optim.SGD(model.parameters(), lr=0.00001, momentum=0.7)
# Adam updates every parameter of `model` with a learning rate of 1e-5.
optimizer = optim.Adam(model.parameters(), lr=1e-5)

In [None]:
?optim.SGD

In [None]:
?optim.Adam

## Train the RNN

In [22]:
def train_epoch(dataloader, model, loss, optimizer):
    """Train `model` for one pass over `dataloader`, one sequence at a time.

    Only the time steps whose input char is "b" or "." contribute to the
    error: those are the steps where the next char can be predicted from
    what has been read so far. The per-sequence error is the mean over
    these steps, and one optimizer step is taken per sequence.

    Args:
        dataloader: iterable of (input, target) pairs indexed as
            [batch, time, vocab]. Only batch item 0 is inspected and the
            initial state has one row, so a batch size of 1 is assumed.
        model: module called as model(input_t, state) -> (output, state)
            and exposing a `state_size` attribute.
        loss: callable loss(output, target_t) returning the step error.
        optimizer: optimizer holding the parameters of `model`.

    Returns:
        float: sum over the sequences of their mean per-step error.
        Sequences without any scored step are skipped.
    """
    model.train()

    error_epoch = 0.0

    # iterate over one sequence at a time
    for input_data, target_data in dataloader:

        # convert data to variable for torch computational graph
        input = Variable(input_data)
        target = Variable(target_data)
        state = Variable(torch.zeros(1, model.state_size))

        # initialize the error to 0
        error_seq = Variable(torch.zeros(1))

        # iterate over the sequence
        seq_len = 0
        for t in range(input.size(1)):
            input_t = input[:, t]

            # an all-zero row is padding: the sequence is over
            if input_t.data.sum() == 0:
                break

            # compute the char at time t (model forward)
            output, state = model(input_t, state)

            # compute the error at time t (loss forward)
            is_b = input_t.data[0][2] == 1    # 1 in dim 2 -> "b"
            is_end = input_t.data[0][3] == 1  # 1 in dim 3 -> "."
            if is_b or is_end:
                error_seq = error_seq + loss(output, target[:, t])
                seq_len += 1

        # Nothing was scored (e.g. no "b" or "." before the padding): there
        # is no graph to backpropagate through and the mean is undefined.
        if seq_len == 0:
            continue

        # average the error over the scored steps
        error_seq = error_seq / seq_len
        # float() yields a plain Python number whatever .data[0] returns
        error_epoch += float(error_seq.data[0])

        # compute the gradients over the computational graph
        # and update the model parameters
        optimizer.zero_grad()
        error_seq.backward()
        optimizer.step()
    return error_epoch
        

In [37]:
# Train for 1000 epochs and report, every 100 epochs, the error returned by
# train_epoch (the sum over the sequences of their mean per-step error).
for epoch in range(1000):
    error = train_epoch(dataloader, model, loss, optimizer)
    if epoch%100==0:
        print('epoch: {}\t error: {}'.format(epoch, error))

epoch: 0	 error: 0.34652225486934185
epoch: 100	 error: 0.33967716433107853
epoch: 200	 error: 0.33282516710460186
epoch: 300	 error: 0.32607906870543957
epoch: 400	 error: 0.31950270012021065
epoch: 500	 error: 0.31316556595265865
epoch: 600	 error: 0.306680791079998
epoch: 700	 error: 0.3004270549863577
epoch: 800	 error: 0.29432846419513226
epoch: 900	 error: 0.28866238333284855


## Try out the model

In [42]:
def feed_char(input_char, state):
    """Run one model step on `input_char`; return (predicted char, new state)."""
    encoded = Variable(vocab_encoder[input_char].view(1, -1), requires_grad=False)
    output, state = model(encoded, state)
    return output_decoder(output, vocab), state


state = Variable(torch.zeros(1, model.state_size), requires_grad=False)
model.eval()

nb_a = 6

# Feed the leading "a"s; the prediction made at these steps is not printed.
for _ in range(nb_a):
    _, state = feed_char('a', state)
    print("input->output\t{}".format('a'))

# Feed the separator and show the first predicted char.
output_char, state = feed_char('X', state)
print("input->output\t{}: {}->{}".format(1, 'X', output_char))

# Feed as many "b"s as there were "a"s and show each prediction.
for step in range(nb_a):
    output_char, state = feed_char('b', state)
    print("input->output\t{}: {}->{}".format(step+2, 'b', output_char))




input->output	a
input->output	a
input->output	a
input->output	a
input->output	a
input->output	a
input->output	1: X->b
input->output	2: b->b
input->output	3: b->b
input->output	4: b->b
input->output	5: b->b
input->output	6: b->b
input->output	7: b->.
