In [67]:
%matplotlib inline

import numpy as np
import itertools

import torch
import torch.nn as nn
import torch.optim as optim

import random
import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm, tqdm_notebook

In [68]:
# set random seed to 0
np.random.seed(0)
torch.manual_seed(0)
torch.set_default_tensor_type('torch.DoubleTensor')

In [69]:
def create_data(prefix_chars, suffix_chars, prefix_length, suffix_length):
    """
    prefix_chars: iterable
    suffix_chars: iterable
    prefix_length: int
    suffix_length: int
    return: take every possible combination of prefixes of length prefix_length
        and take every possible combination of suffixies of length suffix_length,
        then create every possible combination of those.
        Generates len(prefix_chars)^prefix_length * len(suffix_chars)^suffix_length
    """
    prefixes = list(itertools.product(prefix_chars, repeat=prefix_length))
    suffixes = list(itertools.product(suffix_chars, repeat=suffix_length))
    return [p + s for p, s in itertools.product(prefixes, suffixes)]
    
def vectorize_2d(seq, vocab_size):
    vectors = np.zeros((len(seq), vocab_size))
    for i in range(len(seq)):
        vectors[i][seq[i]] = 1
    return torch.tensor(vectors, dtype=torch.float64, device='cuda:0')

In [70]:
def create_output(seq):
    """
    pointer: index to inspect next (attention direction)
    task1: cur^2 mod n (rule-learning)
    task2: value of right-neighbor (attention direction, holding memory, write vs. no write)
    task3: sum of self and left and right neighbors, mod n (combination of tasks 1 and 2)
    """
    pointer_outputs = []
    task1_outputs = []
    task2_outputs = []
    task3_outputs = []
    
    cur = 0
    for i in range(len(seq)):
        
        task1_out = (seq[cur]**2)%len(seq)
        task2_out = seq[(cur+1)%len(seq)]
        
        task3_prev = seq[(cur-1)%len(seq)]
        task3_next = seq[(cur+1)%len(seq)]
        task3_out = (task3_prev + seq[cur] + task3_next)%len(seq)
        
        cur = seq[cur]
        pointer_outputs.append(cur)
        task1_outputs.append(task1_out)
        task2_outputs.append(task2_out)
        task3_outputs.append(task3_out)
        
    return pointer_outputs, task1_outputs, task2_outputs, task3_outputs

In [71]:
def training_sequence_1(seq):
    """
    Trains goto, task1, and task2. e.g.
    Iter 1: Compute task 1, move to i=1 to perform task 2
        Input: seq[0]
        Goto: 1
        Task 1: seq[0]^2
        Task 2: No-op
    Iter 2: Compute task 2, move to i=seq[0] to perform task 1
        Input: seq[1]
        Goto: seq[0]
        Task 1: No-op
        Task 2: seq[1]
    etc...
    """
    n = len(seq)
    no_op = len(seq)
    pointer_outputs, task1_outputs, task2_outputs, task3_outputs = create_output(seq)
    X, Y1, Y2, Y3 = [], [], [], []
    index = 0
    for i in range(n):
        X.append(seq[index])
        X.append(seq[(index+1)%n])
        Y1.append((index+1)%n)
        Y1.append(pointer_outputs[i])
        Y2.append(task1_outputs[i])
        Y2.append(no_op)
        Y3.append(no_op)
        Y3.append(task2_outputs[i])
        index = pointer_outputs[i]
    return torch.tensor(X, dtype=torch.int64, device='cuda:0'), \
            torch.tensor(Y1, dtype=torch.int64, device='cuda:0'), \
            torch.tensor(Y2, dtype=torch.int64, device='cuda:0'), \
            torch.tensor(Y3, dtype=torch.int64, device='cuda:0')

In [90]:
class Model1(nn.Module):
    def __init__(self, input_dim, pointer_dim, task1_dim, task2_dim, hidden_layer_size):
        super(Model1, self).__init__()
        self.input_dim = input_dim
        self.pointer_dim = pointer_dim
        self.task1_dim = task1_dim
        self.task2_dim = task2_dim
        self.hidden_layer_size = hidden_layer_size
        
        self.lstm = nn.LSTMCell(input_dim, hidden_layer_size).cuda()
        self.lin_pointer = nn.Linear(hidden_layer_size, pointer_dim)
        self.lin_task1 = nn.Linear(hidden_layer_size, task1_dim)
        self.lin_task2 = nn.Linear(hidden_layer_size, task2_dim)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        pointer_outputs, task1_outputs, task2_outputs = [], [], []
        h_t = torch.zeros(1, self.hidden_layer_size, dtype=torch.float64, device='cuda:0')
        c_t = torch.zeros(1, self.hidden_layer_size, dtype=torch.float64, device='cuda:0')
        
        for x_t in x.reshape(len(x), 1, self.input_dim):
            h_t, c_t = self.lstm(x_t, (h_t, c_t))
            pointer_outputs.append(self.softmax(self.lin_pointer(h_t)))
            task1_outputs.append(self.softmax(self.lin_task1(h_t)))
            task2_outputs.append(self.softmax(self.lin_task2(h_t)))
        
        return torch.stack(pointer_outputs).reshape(len(x), self.pointer_dim), \
                torch.stack(task1_outputs).reshape(len(x), self.task1_dim), \
                torch.stack(task2_outputs).reshape(len(x), self.task2_dim)

In [91]:
max_digit = 7
split = 4
train_data = create_data(range(split), range(split, max_digit), split, max_digit-split)
X, Y1, Y2, Y3 = [], [], [], []
for datum in train_data:
    inputs, pointer_outputs, task1_outputs, task2_outputs = training_sequence_1(datum)
    X.append(vectorize_2d(inputs, max_digit))
    Y1.append(pointer_outputs)
    Y2.append(task1_outputs)
    Y3.append(task2_outputs)

In [94]:
model = Model1(max_digit, max_digit, max_digit+1, max_digit+1, 30).cuda()
# model = model.double().cuda()
optimizer = optim.Adam(model.parameters())

closures = []

for i in range(len(X)):
    x, y1, y2, y3 = X[i], Y1[i], Y2[i], Y3[i]
    def closure():
        optimizer.zero_grad()
        pred1, pred2, pred3 = model(x)
        pointer_loss = nn.functional.nll_loss(pred1, pointer_outputs)
        task1_loss = nn.functional.nll_loss(pred2, task1_outputs)
        task2_loss = nn.functional.nll_loss(pred3, task2_outputs)
        loss = pointer_loss + task1_loss + task2_loss

        loss.backward()
        return loss
    closures.append(closure)

HBox(children=(IntProgress(value=0, max=6912), HTML(value='')))

In [96]:
for i in tqdm_notebook(range(10)):
    total_loss = 0
    random.shuffle(closures)
    for closure in tqdm_notebook(closures):
        total_loss += float(optimizer.step(closure))
    print(total_loss)

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

HBox(children=(IntProgress(value=0, max=6912), HTML(value='')))

-20385.731109705794


HBox(children=(IntProgress(value=0, max=6912), HTML(value='')))

-20735.861148377706


HBox(children=(IntProgress(value=0, max=6912), HTML(value='')))

-20735.996003700664


HBox(children=(IntProgress(value=0, max=6912), HTML(value='')))

-20735.999861477227


HBox(children=(IntProgress(value=0, max=6912), HTML(value='')))

KeyboardInterrupt: 

In [81]:
a =  nn.LSTMCell(4,4).cuda()

In [146]:
def roll(seq, n):
    return torch.cat((seq[-n:], seq[:-n]))

In [148]:
roll(a, 3)

tensor([[7., 7., 7., 7.],
        [8., 8., 8., 8.],
        [9., 9., 9., 9.],
        [0., 0., 0., 0.],
        [1., 1., 1., 1.],
        [2., 2., 2., 2.],
        [3., 3., 3., 3.],
        [4., 4., 4., 4.],
        [5., 5., 5., 5.],
        [6., 6., 6., 6.]])

In [138]:
a = torch.stack([torch.ones(4)*i for i in range(10)])
b = torch.stack([torch.ones(4)*-i for i in range(10)])

In [154]:
torch.stack((a, roll(a, 1)), dim=1).reshape(10, 8)

tensor([[0., 0., 0., 0., 9., 9., 9., 9.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [2., 2., 2., 2., 1., 1., 1., 1.],
        [3., 3., 3., 3., 2., 2., 2., 2.],
        [4., 4., 4., 4., 3., 3., 3., 3.],
        [5., 5., 5., 5., 4., 4., 4., 4.],
        [6., 6., 6., 6., 5., 5., 5., 5.],
        [7., 7., 7., 7., 6., 6., 6., 6.],
        [8., 8., 8., 8., 7., 7., 7., 7.],
        [9., 9., 9., 9., 8., 8., 8., 8.]])

In [135]:
c

[tensor([0., 0., 0., 0.]),
 tensor([1., 1., 1., 1.]),
 tensor([2., 2., 2., 2.]),
 tensor([3., 3., 3., 3.]),
 tensor([4., 4., 4., 4.]),
 tensor([5., 5., 5., 5.]),
 tensor([6., 6., 6., 6.]),
 tensor([7., 7., 7., 7.]),
 tensor([8., 8., 8., 8.]),
 tensor([9., 9., 9., 9.]),
 tensor([0., 0., 0., 0.]),
 tensor([-1., -1., -1., -1.]),
 tensor([-2., -2., -2., -2.]),
 tensor([-3., -3., -3., -3.]),
 tensor([-4., -4., -4., -4.]),
 tensor([-5., -5., -5., -5.]),
 tensor([-6., -6., -6., -6.]),
 tensor([-7., -7., -7., -7.]),
 tensor([-8., -8., -8., -8.]),
 tensor([-9., -9., -9., -9.])]

In [143]:
torch.stack((a, b), dim=1).shape

torch.Size([10, 2, 4])

In [145]:
torch.stack((a, b), dim=1).reshape(10, 8)

tensor([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 1.,  1.,  1.,  1., -1., -1., -1., -1.],
        [ 2.,  2.,  2.,  2., -2., -2., -2., -2.],
        [ 3.,  3.,  3.,  3., -3., -3., -3., -3.],
        [ 4.,  4.,  4.,  4., -4., -4., -4., -4.],
        [ 5.,  5.,  5.,  5., -5., -5., -5., -5.],
        [ 6.,  6.,  6.,  6., -6., -6., -6., -6.],
        [ 7.,  7.,  7.,  7., -7., -7., -7., -7.],
        [ 8.,  8.,  8.,  8., -8., -8., -8., -8.],
        [ 9.,  9.,  9.,  9., -9., -9., -9., -9.]])

In [221]:
mat, indices = round_robin_concat(torch.stack([torch.ones(4)*i for i in range(4)]))

In [222]:
mat.shape

torch.Size([8, 8])

In [223]:
mat

tensor([[0., 0., 0., 0., 3., 3., 3., 3.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [2., 2., 2., 2., 1., 1., 1., 1.],
        [3., 3., 3., 3., 2., 2., 2., 2.],
        [0., 0., 0., 0., 2., 2., 2., 2.],
        [1., 1., 1., 1., 3., 3., 3., 3.],
        [2., 2., 2., 2., 0., 0., 0., 0.],
        [3., 3., 3., 3., 1., 1., 1., 1.]])

In [208]:
indices

{(0, 1): 1,
 (0, 2): 12,
 (0, 3): 23,
 (0, 4): 34,
 (0, 5): 45,
 (0, 6): 30,
 (0, 7): 20,
 (0, 8): 10,
 (0, 9): 0,
 (1, 0): 1,
 (1, 2): 2,
 (1, 3): 13,
 (1, 4): 24,
 (1, 5): 35,
 (1, 6): 46,
 (1, 7): 31,
 (1, 8): 21,
 (1, 9): 11,
 (2, 0): 12,
 (2, 1): 2,
 (2, 3): 3,
 (2, 4): 14,
 (2, 5): 25,
 (2, 6): 36,
 (2, 7): 47,
 (2, 8): 32,
 (2, 9): 22,
 (3, 0): 23,
 (3, 1): 13,
 (3, 2): 3,
 (3, 4): 4,
 (3, 5): 15,
 (3, 6): 26,
 (3, 7): 37,
 (3, 8): 48,
 (3, 9): 33,
 (4, 0): 34,
 (4, 1): 24,
 (4, 2): 14,
 (4, 3): 4,
 (4, 5): 5,
 (4, 6): 16,
 (4, 7): 27,
 (4, 8): 38,
 (4, 9): 49,
 (5, 0): 45,
 (5, 1): 35,
 (5, 2): 25,
 (5, 3): 15,
 (5, 4): 5,
 (5, 6): 6,
 (5, 7): 17,
 (5, 8): 28,
 (5, 9): 39,
 (6, 0): 30,
 (6, 1): 46,
 (6, 2): 36,
 (6, 3): 26,
 (6, 4): 16,
 (6, 5): 6,
 (6, 7): 7,
 (6, 8): 18,
 (6, 9): 29,
 (7, 0): 20,
 (7, 1): 31,
 (7, 2): 47,
 (7, 3): 37,
 (7, 4): 27,
 (7, 5): 17,
 (7, 6): 7,
 (7, 8): 8,
 (7, 9): 19,
 (8, 0): 10,
 (8, 1): 21,
 (8, 2): 32,
 (8, 3): 48,
 (8, 4): 38,
 (8, 5): 28,
 (

In [209]:
mat[indices[(3, 2)]]

tensor([3., 3., 3., 3., 2., 2., 2., 2.])

In [210]:
mat[indices[(2, 3)]]

tensor([3., 3., 3., 3., 2., 2., 2., 2.])

In [211]:
mat[indices[(0, 1)]]

tensor([1., 1., 1., 1., 0., 0., 0., 0.])