In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import math
import random
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from models import *

In [3]:
def train(model, optimizer, data, target, num_iters, log_every=1000):
    """Fit ``model`` so that ``model(data)`` approximates ``target``.

    Every iteration runs one forward pass over the whole of ``data``, takes
    the mean squared error against ``target`` and performs a single
    ``optimizer`` step on it. There is no mini-batching.

    Args:
        model: callable mapping ``data`` to a prediction tensor.
        optimizer: object exposing ``zero_grad()`` and ``step()``; it is
            expected to manage the parameters of ``model``.
        data: input tensor passed to ``model`` unchanged on every iteration.
        target: tensor the prediction is compared against.
        num_iters: number of optimisation steps to run.
        log_every: print a progress line on iterations ``0, log_every,
            2 * log_every, ...``. A falsy value (``0`` / ``None``) disables
            logging. Defaults to 1000.

    Returns:
        None. Progress is written to stdout; ``model`` is updated in place
        through ``optimizer``.
    """
    for i in range(num_iters):
        out = model(data)
        loss = F.mse_loss(out, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if log_every and i % log_every == 0:
            # The mean absolute error is a reporting metric only, so compute
            # it lazily and outside autograd. `out` and `loss` were produced
            # before the step, so the values reported are the pre-update ones.
            with torch.no_grad():
                mae = torch.mean(torch.abs(target - out))
            # The "mea" label is kept verbatim so logs stay comparable with
            # previously recorded runs.
            print("\t{}/{}: loss: {:.3f} - mea: {:.3f}".format(
                i+1, num_iters, loss.item(), mae.item())
            )

## Permutation

In [4]:
# Swap the first column with the third: right-multiplying A by the
# permutation matrix P reorders A's columns.

# Input matrix.
A = torch.tensor([
    [0, 1, -1],
    [3, -1, 1],
    [1, 1, -2],
], dtype=torch.float32)

# Target: A with columns 0 and 2 exchanged.
B = torch.tensor([
    [-1, 1, 0],
    [1, -1, 3],
    [-2, 1, 1],
], dtype=torch.float32)

# Identity matrix with rows 0 and 2 exchanged.
P = torch.tensor([
    [0, 0, 1],
    [0, 1, 0],
    [1, 0, 0],
], dtype=torch.float32)

# Sanity check: the hand-written target really is A @ P.
assert torch.allclose(A @ P, B)

In [5]:
# NOTE(review): the two arguments look like (in_dim, out_dim) — confirm
# against the definition in models.py.
net = NeuralAccumulatorCell(3, 3)
optim = torch.optim.RMSprop(params=net.parameters(), lr=0.01)

# Learn the map A -> B defined in the previous cell.
train(model=net, optimizer=optim, data=A, target=B, num_iters=10000)

	1/10000: loss: 4.414 - mea: 1.841
	1001/10000: loss: 0.001 - mea: 0.023
	2001/10000: loss: 0.000 - mea: 0.003
	3001/10000: loss: 0.000 - mea: 0.001
	4001/10000: loss: 0.000 - mea: 0.000
	5001/10000: loss: 0.000 - mea: 0.000
	6001/10000: loss: 0.000 - mea: 0.000
	7001/10000: loss: 0.000 - mea: 0.000
	8001/10000: loss: 0.000 - mea: 0.000
	9001/10000: loss: 0.000 - mea: 0.000


In [6]:
# Rebuild the cell's effective weight matrix from its two raw parameters;
# tanh * sigmoid keeps every entry strictly inside (-1, 1).
W = torch.sigmoid(net.M_hat) * torch.tanh(net.W_hat)

# W is printed transposed so it can be compared with P entry by entry.
# NOTE(review): presumably the cell stores W as (out_dim, in_dim) — confirm.
print("actual: \n{}\n".format(W.detach().t()))
print("expected: \n{}\n".format(P))

actual: 
tensor([[-1.1790e-05, -3.5627e-06,  9.9999e-01],
        [-4.8258e-05,  9.9994e-01, -3.1704e-05],
        [ 9.9996e-01, -3.1261e-05, -3.0030e-05]])

expected: 
tensor([[ 0.,  0.,  1.],
        [ 0.,  1.,  0.],
        [ 1.,  0.,  0.]])



## Column Scaling

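A single NAC cell can't learn column scaling: its effective weight matrix, `tanh(W_hat) * sigmoid(M_hat)`, is bounded to the interval (-1, 1) and biased towards -1, 0 or 1, so it cannot represent a scale factor such as 5.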

In [7]:
# Scale the first column by 5: right-multiplying A by the diagonal matrix P
# multiplies column j of A by P[j, j].

# Input matrix.
A = torch.tensor([
    [0, 1, -1],
    [3, -1, 1],
    [1, 1, -2],
], dtype=torch.float32)

# Target: A with its first column multiplied by 5.
B = torch.tensor([
    [0, 1, -1],
    [15, -1, 1],
    [5, 1, -2],
], dtype=torch.float32)

# diag(5, 1, 1)
P = torch.diag(torch.tensor([5, 1, 1], dtype=torch.float32))

# Sanity check: the hand-written target really is A @ P.
assert torch.allclose(A @ P, B)

In [8]:
# NOTE(review): the arguments look like (num_layers, in_dim, hidden_dim,
# out_dim) — confirm against the definition in models.py.
net = NAC(2, 3, 3, 3)
optim = torch.optim.RMSprop(params=net.parameters(), lr=0.001)

# Try to learn the column-scaling map A -> B defined in the previous cell.
train(model=net, optimizer=optim, data=A, target=B, num_iters=10000)

	1/10000: loss: 23.243 - mea: 2.837
	1001/10000: loss: 10.816 - mea: 1.681
	2001/10000: loss: 5.834 - mea: 1.251
	3001/10000: loss: 4.020 - mea: 1.048
	4001/10000: loss: 3.358 - mea: 0.963
	5001/10000: loss: 3.105 - mea: 0.930
	6001/10000: loss: 3.005 - mea: 0.917
	7001/10000: loss: 2.966 - mea: 0.912
	8001/10000: loss: 2.950 - mea: 0.910
	9001/10000: loss: 2.943 - mea: 0.909


## Column Elimination

In [9]:
def basis_vec(k, n):
    """Create the k'th standard basis vector of R^n as a column vector.

    Args:
        k: zero-based index of the entry set to 1; must satisfy 0 <= k < n.
        n: dimension of the vector.

    Returns:
        A float NumPy array of shape (n, 1) that is all zeros except for a
        1 in row ``k``.

    Raises:
        AssertionError: if ``k`` lies outside ``[0, n)``.
    """
    error_msg = "[!] k must satisfy 0 <= k < {}.".format(n)
    # The lower bound matters: NumPy would accept a negative k and wrap it
    # around, silently returning a different basis vector.
    assert 0 <= k < n, error_msg
    b = np.zeros([n, 1])
    b[k] = 1
    return b

In [10]:
# Column elimination: add -3x the second column to the first.
# With c = 3, k = 1 and l = 0 this is P = I - c * e_k @ e_l.T, i.e. the
# identity with an extra -3 at row 1, column 0.

# Input matrix.
A = torch.tensor([
    [3, 1, -1],
    [3, -1, 1],
    [1, 1, -2],
], dtype=torch.float32)

# Target: first column replaced by (first column) - 3 * (second column).
B = torch.tensor([
    [0, 1, -1],
    [6, -1, 1],
    [-2, 1, -2],
], dtype=torch.float32)

P = torch.from_numpy(
    np.eye(3) - 3 * (basis_vec(1, 3) @ basis_vec(0, 3).T)
).float()

# Sanity check: the hand-written target really is A @ P.
assert torch.allclose(A @ P, B)

In [11]:
# NOTE(review): the arguments look like (num_layers, in_dim, hidden_dim,
# out_dim) — confirm against the definition in models.py.
net = NAC(2, 3, 3, 3)
optim = torch.optim.RMSprop(params=net.parameters(), lr=0.001)

# Learn the column-elimination map A -> B defined in the previous cell.
train(model=net, optimizer=optim, data=A, target=B, num_iters=10000)

	1/10000: loss: 5.194 - mea: 1.862
	1001/10000: loss: 0.520 - mea: 0.595
	2001/10000: loss: 0.073 - mea: 0.226
	3001/10000: loss: 0.012 - mea: 0.080
	4001/10000: loss: 0.002 - mea: 0.028
	5001/10000: loss: 0.000 - mea: 0.012
	6001/10000: loss: 0.000 - mea: 0.006
	7001/10000: loss: 0.000 - mea: 0.003
	8001/10000: loss: 0.000 - mea: 0.002
	9001/10000: loss: 0.000 - mea: 0.001
