Problem statement:
https://blog.openai.com/requests-for-research-2/

Train an LSTM to solve the XOR problem: that is, given a sequence of bits, determine its parity. The LSTM should consume the sequence, one bit at a time, and then output the correct answer at the sequence’s end. Test the two approaches below:

* Generate a dataset of random 100,000 binary strings of length 50. Train the LSTM; what performance do you get?
* Generate a dataset of random 100,000 binary strings, where the length of each string is independently and randomly chosen between 1 and 50. Train the LSTM. Does it succeed? What explains the difference?
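Both approaches use the same target. The parity of a bit string is the XOR of all its bits, so it can be computed in a single left-to-right pass. Here is a minimal plain-Python sketch. The name `parity_by_xor` is ours and is for illustration only:

```python
from functools import reduce
from operator import xor

def parity_by_xor(bits):
    """Parity of a bit sequence: 1 if it contains an odd number of 1s, else 0."""
    # Fold XOR over the sequence, starting from 0 (the parity of the empty string).
    return reduce(xor, bits, 0)

print(parity_by_xor([1, 0, 1, 1]))  # three 1s -> odd -> 1
print(parity_by_xor([1, 1, 0, 0]))  # two 1s -> even -> 0
```

The LSTM has to learn this running computation one bit at a time. Its result agrees with the `np.sum(arr) % 2` check used in the notebook.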

In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np

Create 100k random binary strings of length 50, and a function to compute their parity.

In [2]:
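np.random.seed(1)
n_examples = 100000
example_length = 50
# Shape (n_examples, example_length, 1): one bit per time step, one feature per bit
data_raw = np.random.randint(2, size=(n_examples, example_length, 1))
data_tensor = torch.Tensor(data_raw)
def get_parity(arr):
    # Parity is 1 if the number of 1-bits is odd, 0 if it is even
    p = np.sum(arr) % 2
    return int(p)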
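The cell above only builds the fixed-length dataset from the first approach. For the second approach, one option is to draw the variable-length strings up front, as in the sketch below. The helper `make_variable_length_dataset` and the choice to store the strings in a Python list are assumptions and are not part of this notebook. The notebook instead gets variable lengths by truncating the fixed-length strings at sampling time, in `getRandomExample`.

```python
import numpy as np

def make_variable_length_dataset(n_examples, max_length, seed=1):
    """Hypothetical helper: n_examples random bit strings whose lengths are
    drawn independently and uniformly from 1..max_length (inclusive)."""
    rng = np.random.RandomState(seed)
    # randint's upper bound is exclusive, hence max_length + 1
    lengths = rng.randint(1, max_length + 1, size=n_examples)
    strings = [rng.randint(2, size=n) for n in lengths]
    parities = np.array([int(s.sum() % 2) for s in strings])
    return strings, parities

strings, parities = make_variable_length_dataset(1000, 50)
print(len(strings), min(len(s) for s in strings), max(len(s) for s in strings))
```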

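Pull a random example from the dataset and truncate it to length `d_max`, so we can train or test with shorter strings. When the `random_truncate` flag is set, the example is instead truncated to a length drawn uniformly from 1 to `d_max`. Each bit is one-hot encoded across two input channels: channel 0 is 1 for a 0-bit, and channel 1 is 1 for a 1-bit.

The function returns the target and the input sequence. The parity target must be an integer tensor (`LongTensor`) because `nn.NLLLoss` expects class indices.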

In [3]:
def getRandomExample(d_max=example_length, random_truncate=True):
    l = np.random.randint(n_examples)           # index of a random example
    d = min(d_max, example_length)              # never longer than the stored strings
    if random_truncate:
        d = np.random.randint(d) + 1            # length uniform on 1..d
    example_data = data_raw[l, :d, 0]           # the first d bits, shape (d,)
    # Input shape (seq_len, batch=1, 2): one-hot encoding of each bit
    tensor = torch.zeros(d, 1, 2)
    tensor[:, 0, 0] = 1 - torch.Tensor(example_data)
    tensor[:, 0, 1] = torch.Tensor(example_data)
    return torch.LongTensor([get_parity(example_data)]), tensor
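The encoding and length sampling used by `getRandomExample` can be checked in isolation. The NumPy sketch below mirrors the `(seq_len, batch=1, 2)` layout and the "uniform on 1..d_max" length rule. The helper names `one_hot_bits` and `sample_length` are hypothetical and are not part of the notebook:

```python
import numpy as np

def one_hot_bits(bits):
    """Encode a 1-D sequence of 0/1 bits as shape (seq_len, batch=1, 2):
    channel 0 is 1 for a 0-bit, channel 1 is 1 for a 1-bit."""
    bits = np.asarray(bits).reshape(-1)
    x = np.zeros((len(bits), 1, 2), dtype=np.float32)
    x[:, 0, 0] = 1 - bits
    x[:, 0, 1] = bits
    return x

def sample_length(d_max, rng):
    """Length drawn uniformly from 1..d_max (randint(d_max) gives 0..d_max-1)."""
    return rng.randint(d_max) + 1

rng = np.random.RandomState(0)
x = one_hot_bits([1, 0, 1])
print(x.shape)  # (3, 1, 2)
lengths = [sample_length(5, rng) for _ in range(1000)]
print(sorted(set(lengths)))
```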

The model is an LSTM followed by a linear layer that maps the hidden state to the output size. A `LogSoftmax` then turns the scores into log-probabilities, which is the input `nn.NLLLoss` expects.

Note that `nn.LSTM` takes the whole string as input and steps through it internally, one bit at a time. It also returns the hidden state at every step, so we pass only the last one, `output[-1]`, to the linear layer.

In [4]:
class LSTM(nn.Module):
    def __init__(self, input_size, output_size, n_hidden, n_layers):
        super(LSTM, self).__init__()
        self.rnn = nn.LSTM(input_size, n_hidden, n_layers)
        self.lin = nn.Linear(n_hidden, output_size)
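        self.softmax = nn.LogSoftmax(dim=1)
        self.input_size = input_size
        self.output_size = output_size
        self.n_hidden = n_hidden
        self.n_layers = n_layers
    def forward(self, input, hidden):
        # input: (seq_len, batch, input_size); output: (seq_len, batch, n_hidden)
        output, hn = self.rnn(input, hidden)
        # Classify from the hidden state after the last time step only
        out = self.softmax(self.lin(output[-1]))
        return out, hn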
    def init_hidden(self):
        h0 = Variable(torch.zeros(self.n_layers, 1, self.n_hidden))
        c0 = Variable(torch.zeros(self.n_layers, 1, self.n_hidden))
        return (h0, c0)
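To show what "steps through it internally" means, here is a NumPy re-implementation of a single-layer, batch-size-1 LSTM forward pass. It assumes the gate equations and weight layout documented for PyTorch's `nn.LSTM`: `W_ih` has shape `(4H, input_size)`, and the gates are stacked in the order input, forget, cell, output. It is for illustration only and is not how `nn.LSTM` is implemented internally. The batch dimension is omitted, so `output[-1]` here has shape `(H,)` rather than `(batch, H)`.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_forward(x, W_ih, W_hh, b_ih, b_hh):
    """x: (seq_len, input_size). Returns the hidden state at every step
    (seq_len, H) and the final (h, c), starting from zero state."""
    H = W_hh.shape[1]
    h = np.zeros(H)
    c = np.zeros(H)
    outputs = []
    for x_t in x:                                     # one time step at a time
        gates = W_ih @ x_t + b_ih + W_hh @ h + b_hh
        i = sigmoid(gates[0:H])                       # input gate
        f = sigmoid(gates[H:2 * H])                   # forget gate
        g = np.tanh(gates[2 * H:3 * H])               # candidate cell value
        o = sigmoid(gates[3 * H:4 * H])               # output gate
        c = f * c + i * g
        h = o * np.tanh(c)
        outputs.append(h)
    return np.stack(outputs), (h, c)

rng = np.random.RandomState(0)
H, I = 1, 2                                           # one hidden unit, one-hot input
W_ih, W_hh = rng.randn(4 * H, I), rng.randn(4 * H, H)
b_ih, b_hh = rng.randn(4 * H), rng.randn(4 * H)
x = np.array([[0., 1.], [1., 0.], [0., 1.]])          # bits 1, 0, 1
output, (h_n, c_n) = lstm_forward(x, W_ih, W_hh, b_ih, b_hh)
print(output.shape)                                   # (3, 1)
print(np.allclose(output[-1], h_n))                   # True
```

The last row of `output` is the final hidden state, which is why `forward` only needs `output[-1]`.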

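The function we're trying to model is very simple. The only state that has to be carried from one bit to the next is the parity so far, which is a single bit. So we use just one hidden unit and one layer.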
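The one-bit-of-state argument can be made concrete with a recurrence that keeps a single number `s`. For `s, x` in {0, 1}, `s + x - 2*s*x` equals `s XOR x`. The sketch below is an illustration of why one unit can be enough. It is not a claim about the weights the LSTM actually learns; the LSTM has to approximate this update with its gates, and the training run shows whether it does.

```python
def running_parity(bits):
    """Carry one number of state through the sequence, as a 1-unit RNN would.
    Returns the parity of every prefix."""
    s = 0
    trace = []
    for x in bits:
        s = s + x - 2 * s * x   # s XOR x, for s and x in {0, 1}
        trace.append(s)
    return trace

print(running_parity([1, 0, 1, 1]))  # [1, 1, 0, 1]
```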

In [14]:
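n_hidden = 1
n_layers = 1
n_batch = 1       # not used below; examples are fed one at a time
input_size = 2    # one-hot encoded bit
output_size = 2   # parity 0 or 1
np.random.seed(1)
torch.manual_seed(1)  # weight initialization uses torch's RNG, not numpy's
model = LSTM(input_size, output_size, n_hidden, n_layers)
loss_function = nn.NLLLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)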
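`nn.NLLLoss` applied to `LogSoftmax` outputs amounts to the cross-entropy loss: the negative log-probability of the true class, averaged over the batch by default. The NumPy sketch below spells this out for one two-class example and applies the same "probability above 0.5" correctness rule the training loop uses. The helper names are ours and are not PyTorch APIs.

```python
import numpy as np

def log_softmax(z):
    """Numerically stable log-softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def nll_loss(log_probs, target):
    """Mean negative log-likelihood of integer class targets."""
    return -log_probs[np.arange(len(target)), target].mean()

logits = np.array([[2.0, 0.5]])   # one example, two classes (parity 0 / parity 1)
target = np.array([0])
lp = log_softmax(logits)
print(nll_loss(lp, target))
print(np.exp(lp[0, target[0]]) > 0.5)  # True: counted as correct by the 50% rule
```

With only two classes, "probability of the right answer above 50%" is the same as "the right answer has the larger score". The exception is an exact tie, which the rule counts as wrong.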

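Run the training! We count a prediction as correct if the probability the model assigns to the right answer is above 50%.

This training uses a simple curriculum. It starts with a maximum string length (`dmax`) of 1 and increases it by one whenever more than 95% of the last 100 predictions are correct. Length randomization is always on, so each example's length is drawn uniformly from 1 to the current `dmax`. This leads to *much* faster training.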
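The curriculum rule can be pulled out into a pure function and checked against the first few windows of the log. In the sketch below, `update_dmax` is a hypothetical helper that mirrors the loop's logic. It assumes Python 3 true division for `correct / count`; under Python 2 integer division the ratio would be 0 and `dmax` would never grow.

```python
def update_dmax(correct, count, dmax, max_length=50, threshold=0.95):
    """Raise the maximum string length by one when accuracy over the last
    window is strictly above threshold, capped at max_length."""
    if correct / count > threshold:
        return min(dmax + 1, max_length)
    return dmax

# Correct counts per 100-example window, from the first five lines of the log
windows = [59, 92, 100, 80, 86]
dmax = 1
history = []
for correct in windows:
    dmax = update_dmax(correct, 100, dmax)
    history.append(dmax)
print(history)  # [1, 1, 2, 2, 2]
```

Because the comparison is strict, a window with exactly 95/100 correct does not advance the curriculum.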

In [15]:
print_freq = 100
np.random.seed(1)
correct = 0
count = 0
total_loss = 0.0
dmax = 1  # current maximum string length (curriculum)
for i in range(10000):
    model.zero_grad()
    target, inputs = getRandomExample(d_max=dmax)
    hidden = model.init_hidden()
    output, hn = model(inputs, hidden)
    loss = loss_function(output, target)
    loss.backward()
    total_loss += loss.item()
    optimizer.step()
    # Correct if the probability of the true class exceeds 0.5
    if output[0, target.item()].exp().item() > 0.5:
        correct += 1
    count += 1
    if i % print_freq == print_freq - 1:
        # Log format: iteration[dmax]: correct/count, loss summed over the window
        print('{0}[{1}]: {2}/{3} loss: {4:.3f}'.format(i, dmax, correct, count, total_loss))
        if correct / count > 0.95:
            dmax = min(dmax + 1, example_length)
        correct = 0
        count = 0
        total_loss = 0

99[1]: 59/100 loss: 65.654
199[1]: 92/100 loss: 50.949
299[1]: 100/100 loss: 18.023
399[2]: 80/100 loss: 57.065
499[2]: 86/100 loss: 46.193
599[2]: 88/100 loss: 42.437
699[2]: 89/100 loss: 38.163
799[2]: 90/100 loss: 33.627
899[2]: 86/100 loss: 38.834
999[2]: 85/100 loss: 37.220
1099[2]: 88/100 loss: 31.482
1199[2]: 91/100 loss: 26.364
1299[2]: 87/100 loss: 29.151
1399[2]: 86/100 loss: 25.314
1499[2]: 90/100 loss: 21.643
1599[2]: 94/100 loss: 12.065
1699[2]: 94/100 loss: 15.891
1799[2]: 100/100 loss: 11.709
1899[3]: 89/100 loss: 32.375
1999[3]: 92/100 loss: 30.304
2099[3]: 100/100 loss: 11.724
2199[4]: 100/100 loss: 9.433
2299[5]: 100/100 loss: 8.575
2399[6]: 100/100 loss: 6.932
2499[7]: 100/100 loss: 5.786
2599[8]: 100/100 loss: 4.712
2699[9]: 100/100 loss: 4.650
2799[10]: 100/100 loss: 4.334
2899[11]: 100/100 loss: 3.718
2999[12]: 100/100 loss: 3.614
3099[13]: 100/100 loss: 3.424
3199[14]: 100/100 loss: 2.362
3299[15]: 100/100 loss: 2.812
3399[16]: 100/100 loss: 2.533
3499[17]: 100/1