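NAC (Neural Accumulator) and NALU (Neural Arithmetic Logic Unit) are implemented as defined in "Neural Arithmetic Logic Units" (Trask et al., 2018): https://arxiv.org/abs/1808.00508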

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm_notebook as tqdm
from torch.nn.init import kaiming_normal_
from IPython.display import display

In [2]:
class NAC(nn.Module):
    """Neural Accumulator: a bias-free linear layer with a constrained weight matrix.

    The effective weight is the element-wise product ``tanh(W_hat) * sigmoid(M_hat)``,
    so every entry lies strictly inside (-1, 1).
    """

    def __init__(self, input_size, output_size):
        """Create the two learnable matrices, each of shape (output_size, input_size)."""
        super().__init__()
        # Registration order (W_hat, then M_hat) fixes the order of .parameters()
        self.W_hat = nn.Parameter(data=torch.Tensor(output_size, input_size))
        self.M_hat = nn.Parameter(data=torch.Tensor(output_size, input_size))
        # Author's note: the NAC converges far more slowly without Kaiming-normal init
        for raw_matrix in (self.W_hat, self.M_hat):
            kaiming_normal_(raw_matrix)

    def W(self):
        """Return the effective weight matrix derived from W_hat and M_hat."""
        squashed = torch.tanh(self.W_hat)
        gate = torch.sigmoid(self.M_hat)
        return squashed * gate

    def forward(self, x):
        """Apply the linear map ``x @ W().T`` (no bias term)."""
        effective_weight = self.W()
        return F.linear(x, effective_weight, None)

In [3]:
class NALU(nn.Module):
    """Neural Arithmetic Logic Unit built on a single shared NAC.

    The forward pass blends two paths with a learned sigmoid gate:
    an additive path (the NAC applied to ``x``) and a multiplicative path
    (the same NAC weights applied to ``log(|x| + eps)``, then exponentiated).
    """

    def __init__(self, input_size, output_size):
        """Build the shared NAC and the gate layer.

        The gate is ``nn.Linear(input_size, 1)``, i.e. one scalar gate per input row.
        """
        super().__init__()
        # The same NAC weights serve both the additive and the log-space path
        self.nac = NAC(input_size, output_size)
        # Gate parameters choosing between add/sub and mul/div
        self.G = nn.Linear(input_size, 1)
        # Small offset that keeps log() finite when an input is exactly zero
        self.eps = 1e-10

    def g(self, x):
        """Return the gate value in (0, 1); values near 1 favour the add/sub path."""
        gate_logits = self.G(x)
        return torch.sigmoid(gate_logits)

    def forward(self, x):
        """Return ``g * a + (1 - g) * m`` for additive result ``a`` and multiplicative ``m``."""
        gate = self.g(x)

        # Additive path: plain NAC output
        additive = self.nac(x)

        # Multiplicative path: a linear map in log space is a product of powers
        log_magnitude = torch.log(abs(x) + self.eps)
        multiplicative = torch.exp(F.linear(log_magnitude, self.nac.W(), None))

        # Gate-weighted interpolation between the two paths
        return gate * additive + (1 - gate) * multiplicative
In [4]:
def _make_example():
    """Draw one (inputs, label) training pair.

    The inputs are a (1, 2) float tensor holding two integers from 1 to 5
    (``high=6`` is exclusive in ``torch.randint``); the label is a shape-(1,)
    tensor computed from the two operands.
    """
    a_b = torch.randint(low=1, high=6, size=(1, 2)).type(torch.FloatTensor)
    first, second = a_b[:, 0], a_b[:, 1]
    # Exactly one labelling rule should be active; swap it to induce another function.
    # I.   adder (active)
    c = first + second
    # II.  subtractor: c = first - second
    # III. multiplier: c = first * second
    # IV.  divider:    c = first / second
    return a_b, c


# 500 independent training pairs
examples = [_make_example() for _ in range(500)]

# Peek at the first few pairs
for eg in examples[:3]:
    print(eg)

(tensor([[3., 1.]]), tensor([4.]))
(tensor([[5., 3.]]), tensor([8.]))
(tensor([[2., 2.]]), tensor([4.]))


In [5]:
# Model: two operands in, one result out
INPUT_SIZE = 2
OUTPUT_SIZE = 1
nalu = NALU(input_size=INPUT_SIZE, output_size=OUTPUT_SIZE)

# Objective: mean squared error, minimised with SGD at a fixed learning rate
LEARNING_RATE = 0.03
loss_function = nn.MSELoss()
optimizer = optim.SGD(params=nalu.parameters(), lr=LEARNING_RATE)

In [6]:
# Training loop: one SGD step per example, EPOCHS passes over `examples`
nalu.train()

EPOCHS = 30

bar_1 = tqdm(range(EPOCHS))
for epoch in bar_1:
    # Plain Python float: summing the loss tensors themselves would keep every
    # step's autograd history reachable until the end of the epoch.
    total_loss = 0.0
    for a_b, c in examples:
        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        # The model output is reshaped to (1,) to match the shape of the label
        c_pred = nalu(a_b).view(1)
        loss = loss_function(c_pred, c)

        loss.backward()
        optimizer.step()

        # .item() extracts the scalar value, detached from the graph
        total_loss += loss.item()
    # Show the summed per-example loss of the epoch on the progress bar
    bar_1.set_description("loss: %0.4f" % total_loss)


HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




In [7]:
# Generic arithmetic function induced from data!

def my_func(a, b):
    """Evaluate the trained global ``nalu`` on the operand pair (a, b).

    Prints the rounded NAC weight matrix and the operation family selected by
    the rounded gate value, then returns the model output.

    Args:
        a: first operand (number).
        b: second operand (number).

    Returns:
        The NALU output for ``[a, b]`` as a Python float.

    Side effects:
        Switches ``nalu`` to eval mode and writes to stdout / the rich display.
    """
    # The gate multiplies the add/sub path, so a gate rounding to 1 means add/sub
    mode_map = {
        0: 'div/mul',
        1: 'add/sub'
    }

    x = torch.tensor([a, b], dtype=torch.float32)
    nalu.eval()
    # Pure inference: no autograd graph is needed for any of these values
    with torch.no_grad():
        weights = nalu.nac.W()
        gate = nalu.g(x)
        result = nalu(x)

    print('NAC\'s ~W:')
    display(weights.numpy().round())
    print('NALU\'s mode:')
    # .item() yields a Python scalar; calling int() on a 1-element ndarray with
    # ndim > 0 is deprecated in recent NumPy releases.
    print(mode_map[round(gate.item())])
    return float(result)

In [8]:
# Probe with operands far outside the 1-5 range used to build the training examples
a, b = 100, 10

my_func(a, b)

NAC's ~W:


array([[1., 1.]], dtype=float32)

NALU's mode:
add/sub


109.110595703125