In [69]:
import torch
from torch import nn as nn
from torch.nn import functional as F
from torch.optim import Adam, SGD
import numpy as np

In [70]:
class NAC(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.W_hat = nn.Parameter(torch.Tensor(self.out_dim, self.in_dim))
        self.M_hat = nn.Parameter(torch.Tensor(self.out_dim, self.in_dim))
        nn.init.xavier_normal_(self.W_hat)
        nn.init.xavier_normal_(self.M_hat)
        self.bias = None
        
    def forward(self, x):
        W = torch.tanh(self.W_hat) * torch.sigmoid(self.M_hat)
        return F.linear(x, W, self.bias)

In [93]:
class NALU(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
#         self.G = nn.Linear(self.in_dim,
#                            self.out_dim,
#                            bias=False)
        self.G = nn.Parameter(torch.Tensor(1, 1))
        nn.init.xavier_normal_(self.G)
        self.nac = NAC(self.in_dim, self.out_dim)
        self.eps = 1e-10

    def forward(self, x):
        a = self.nac(x)
        g = torch.sigmoid(self.G)
        m = self.nac(torch.log(torch.abs(x) + self.eps))
        m = torch.exp(m)
        y = (g * a) + (1 - g) * m
        return y

In [82]:
eps = 1e-12
X_train = np.random.uniform(-5, 5+eps, size=(2000, 2))
Y_train = X_train[:, 0] * X_train[:, 1]

X_valid = np.random.uniform(-5, 5+eps, size=(500, 2))
Y_valid = X_valid[:, 0] * X_valid[:, 1]

X_test = np.random.uniform(-50, 50+eps, size=(2000, 2))
Y_test = X_test[:, 0] * X_test[:, 1]

In [94]:
def get_batches(data, target, batch_size, mode='test', use_gpu=False):
    idx = np.arange(0, data.shape[0])
    
    if mode == 'train':
        np.random.shuffle(idx)

    while idx.shape[0] > 0:
        batch_idx = idx[:batch_size]
        idx = idx[batch_size:]
        batch_data = data[batch_idx]
        batch_target = target[batch_idx]
        
        batch_data = torch.from_numpy(batch_data).float()
        batch_target = torch.from_numpy(batch_target).float().view(-1, 1)
        
        if use_gpu:
            batch_data = batch_data.cuda()
            batch_target = batch_target.cuda()
        
        yield batch_data, batch_target

In [95]:
def get_eval_loss(model, criterion, data, targets, use_gpu=False):
    preds, targets = get_eval_preds(model, data, targets, use_gpu)
    loss = criterion(preds, targets)
    return loss.item()

In [96]:
def get_eval_preds(model, data, targets, use_gpu=False):
    with torch.no_grad():
        model.eval()
        model_preds = []
        tensor_targets = []
        for x, y in get_batches(X_valid, Y_valid, batch_size,
                                mode='test', use_gpu=use_gpu):
            model_preds.append(model(x))
            tensor_targets.append(y)
        model_preds = torch.cat(model_preds, dim=0)
        tensor_targets = torch.cat(tensor_targets, dim=0)
    return model_preds, tensor_targets

In [98]:
batch_size = 32
patience = 5
running_patience = 5
checkpoint = 'best_model.sav'
print_every = 200
num_epochs = 2000
running_batch = 0
running_loss = 0
min_loss = float('inf')
use_gpu = torch.cuda.is_available()

criterion = nn.SmoothL1Loss()

model = NALU(2, 1)
if use_gpu:
    model = model.cuda()
optimizer = SGD(model.parameters(), lr=0.1)

for epoch in range(num_epochs):
    model.train()
    for x, y in get_batches(X_train, Y_train, batch_size,
                            mode='train', use_gpu=use_gpu):
        output = model(x)
        loss = criterion(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        running_batch += 1
        
#         if running_batch % print_every == 0:
#             print('Training loss after {} batches: {}'.format(running_batch, running_loss/running_batch))
            
    valid_loss = get_eval_loss(model, criterion,
                              X_valid, Y_valid, False)
    print("Validation loss after epoch {}: {}".format(epoch, valid_loss))
    if valid_loss < min_loss:
        min_loss = valid_loss
        print('Validation loss improved! Saving model.')
        with open(checkpoint, 'wb') as f:
            torch.save(model.state_dict(), f)
            running_patience = patience
    else:
        running_patience -= 1
    if running_patience == 0:
        print('Ran out of patience, early stopping employed!')
        break

Validation loss after epoch 0: 5.484178066253662
Validation loss improved! Saving model.
Validation loss after epoch 1: 5.477092742919922
Validation loss improved! Saving model.
Validation loss after epoch 2: 5.475063323974609
Validation loss improved! Saving model.
Validation loss after epoch 3: 5.471747398376465
Validation loss improved! Saving model.
Validation loss after epoch 4: 5.469732761383057
Validation loss improved! Saving model.
Validation loss after epoch 5: 5.467689037322998
Validation loss improved! Saving model.
Validation loss after epoch 6: 5.465944766998291
Validation loss improved! Saving model.
Validation loss after epoch 7: 5.463921070098877
Validation loss improved! Saving model.
Validation loss after epoch 8: 5.462299346923828
Validation loss improved! Saving model.
Validation loss after epoch 9: 5.460833549499512
Validation loss improved! Saving model.
Validation loss after epoch 10: 5.459917068481445
Validation loss improved! Saving model.
Validation loss afte

In [87]:
model.load_state_dict(torch.load(checkpoint))

In [88]:
test_loss = get_eval_loss(model, criterion,
                          X_test, Y_test, False)

In [89]:
test_loss

4.807424545288086

In [90]:
test_preds, test_targets = get_eval_preds(model, X_test, Y_test, False)

In [91]:
test_preds = test_preds.cpu().numpy().flatten()
test_targets = test_targets.cpu().numpy().flatten()

In [92]:
accuracy = np.isclose(test_preds, test_targets, rtol=1e-4).astype(np.int32).mean()
accuracy

0.0

In [66]:
torch.tanh(model.nac.W_hat)

tensor([[0.2643, 0.1579]], grad_fn=<TanhBackward>)

In [67]:
torch.sigmoid(model.nac.M_hat)

tensor([[0.7959, 0.4905]], grad_fn=<SigmoidBackward>)

In [39]:
model.G.weight

Parameter containing:
tensor([[3.0545, 0.0673]], requires_grad=True)