In [1]:
import torch
from torch import nn as nn
from models.nac import NAC
from models.nalu import NALU
from torch.optim import Adam, SGD
import numpy as np

In [2]:
# class NAC(nn.Module):
#     def __init__(self, in_dim, out_dim):
#         super().__init__()
#         self.in_dim = in_dim
#         self.out_dim = out_dim
#         self.W_hat = nn.Parameter(torch.Tensor(self.out_dim, self.in_dim))
#         self.M_hat = nn.Parameter(torch.Tensor(self.out_dim, self.in_dim))
#         nn.init.xavier_normal_(self.W_hat)
#         nn.init.xavier_normal_(self.M_hat)
#         self.bias = None
        
#     def forward(self, x):
#         W = torch.tanh(self.W_hat) * torch.sigmoid(self.M_hat)
#         return F.linear(x, W, self.bias)

In [3]:
# class NALU(nn.Module):
#     def __init__(self, in_dim, out_dim):
#         super().__init__()
#         self.in_dim = in_dim
#         self.out_dim = out_dim
#         self.G = nn.Parameter(torch.Tensor(1, 1))
#         nn.init.xavier_normal_(self.G)
#         self.nac = NAC(self.in_dim, self.out_dim)
#         self.eps = 1e-12

#     def forward(self, x):
#         a = self.nac(x)
#         g = torch.sigmoid(self.G)
#         m = self.nac(torch.log(torch.abs(x) + self.eps))
#         m = torch.exp(m)
#         y = (g * a) + (1 - g) * m
#         return y

In [4]:
# Seed NumPy's and PyTorch's global RNGs so the datasets, the training-time
# shuffling in get_batches (np.random.shuffle) and (presumably) the model's
# random initialisation are reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# np.random.uniform samples from the half-open interval [low, high); adding eps
# to `high` makes the upper bound effectively inclusive.
eps = 1e-12

# Train / validation operands lie in [-5, 5]; the target is their product.
X_train = np.random.uniform(-5, 5 + eps, size=(2000, 2))
Y_train = X_train[:, 0] * X_train[:, 1]

X_valid = np.random.uniform(-5, 5 + eps, size=(500, 2))
Y_valid = X_valid[:, 0] * X_valid[:, 1]

# Test operands lie in [-50, 50], a 10x wider range than training, so the test
# set measures extrapolation rather than interpolation.
X_test = np.random.uniform(-50, 50 + eps, size=(2000, 2))
Y_test = X_test[:, 0] * X_test[:, 1]

In [5]:
def get_batches(data, target, batch_size, mode='test', use_gpu=False):
    """Yield successive ``(inputs, targets)`` mini-batches as float32 tensors.

    Args:
        data: NumPy array whose first axis indexes samples.
        target: NumPy array with the same number of samples as ``data``.
        batch_size: Maximum number of samples per batch; the final batch may be
            smaller. Must be a positive integer.
        mode: ``'train'`` shuffles the sample order in place using NumPy's
            global RNG; any other value keeps the original order.
        use_gpu: If True, move every batch to the default CUDA device.

    Yields:
        ``(batch_data, batch_target)``; ``batch_target`` is flattened into a
        single column via ``view(-1, 1)``.

    Raises:
        ValueError: If ``batch_size`` is not positive (raised on the first
            ``next()``, since this is a generator).
    """
    # A non-positive step would never consume any indices and loop forever.
    if batch_size <= 0:
        raise ValueError('batch_size must be positive, got {}'.format(batch_size))

    idx = np.arange(data.shape[0])
    if mode == 'train':
        np.random.shuffle(idx)

    for start in range(0, idx.shape[0], batch_size):
        batch_idx = idx[start:start + batch_size]
        batch_data = torch.from_numpy(data[batch_idx]).float()
        batch_target = torch.from_numpy(target[batch_idx]).float().view(-1, 1)

        if use_gpu:
            batch_data = batch_data.cuda()
            batch_target = batch_target.cuda()

        yield batch_data, batch_target

In [6]:
def get_eval_loss(model, criterion, data, targets, use_gpu=False):
    """Score ``model`` on a full dataset and return the loss as a Python float.

    All predictions are gathered first (via ``get_eval_preds``), and
    ``criterion`` is then applied once to the concatenated outputs.
    """
    all_preds, all_targets = get_eval_preds(model, data, targets, use_gpu)
    return criterion(all_preds, all_targets).item()

In [7]:
def get_eval_preds(model, data, targets, use_gpu=False, batch_size=32):
    """Run ``model`` over a whole dataset without gradients.

    Args:
        model: Module to evaluate. It is put in eval mode for the duration of
            the call and then restored to the train/eval mode it had before.
        data: Input array, batched in its original order (``mode='test'``,
            no shuffling) so predictions line up with ``targets``.
        targets: Target array with the same number of samples as ``data``.
        use_gpu: Move each batch to CUDA; must match the device ``model`` is on.
        batch_size: Evaluation batch size. Passed explicitly rather than read
            from a notebook-global ``batch_size``.

    Returns:
        ``(predictions, targets)`` as tensors concatenated along dim 0.
    """
    was_training = model.training
    model.eval()
    pred_chunks, target_chunks = [], []
    try:
        with torch.no_grad():
            for x, y in get_batches(data, targets, batch_size,
                                    mode='test', use_gpu=use_gpu):
                pred_chunks.append(model(x))
                target_chunks.append(y)
    finally:
        # Restore the caller's mode so evaluation leaves no lasting side effect.
        model.train(was_training)
    return torch.cat(pred_chunks, dim=0), torch.cat(target_chunks, dim=0)

In [8]:
# ---- Hyper-parameters and run state ---------------------------------------
batch_size = 32
patience = 15                  # epochs allowed without validation improvement
checkpoint = 'best_model.sav'  # best-so-far state_dict is written here
print_every = 200              # not used below
num_epochs = 5000              # hard upper bound on training epochs
running_batch = 0              # training batches seen, cumulative over epochs
running_loss = 0               # sum of per-batch training losses, cumulative
min_loss = float('inf')        # best validation loss seen so far
running_patience = patience    # epochs left before early stopping
use_gpu = torch.cuda.is_available()

criterion = nn.SmoothL1Loss()

# 2 inputs (the two operands) -> 1 output (their product).
model = NALU(2, 1)
if use_gpu:
    model = model.cuda()
optimizer = Adam(model.parameters())

for epoch in range(num_epochs):
    model.train()
    for x, y in get_batches(X_train, Y_train, batch_size,
                            mode='train', use_gpu=use_gpu):
        output = model(x)
        loss = criterion(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        running_batch += 1

    # Pass use_gpu so validation batches live on the same device as the model.
    valid_loss = get_eval_loss(model, criterion,
                               X_valid, Y_valid, use_gpu)
    print("Validation loss after epoch {}: {}".format(epoch, valid_loss))
    if valid_loss < min_loss:
        min_loss = valid_loss
        print('Validation loss improved! Saving model.')
        torch.save(model.state_dict(), checkpoint)
        running_patience = patience
    else:
        running_patience -= 1
    if running_patience <= 0:
        print('Ran out of patience, early stopping employed!')
        break

Validation loss after epoch 0: 5.76000452041626
Validation loss improved! Saving model.
Validation loss after epoch 1: 5.758058071136475
Validation loss improved! Saving model.
Validation loss after epoch 2: 5.756496429443359
Validation loss improved! Saving model.
Validation loss after epoch 3: 5.755018711090088
Validation loss improved! Saving model.
Validation loss after epoch 4: 5.753583908081055
Validation loss improved! Saving model.
Validation loss after epoch 5: 5.752229690551758
Validation loss improved! Saving model.
Validation loss after epoch 6: 5.7509260177612305
Validation loss improved! Saving model.
Validation loss after epoch 7: 5.749770641326904
Validation loss improved! Saving model.
Validation loss after epoch 8: 5.74862003326416
Validation loss improved! Saving model.
Validation loss after epoch 9: 5.74719762802124
Validation loss improved! Saving model.
Validation loss after epoch 10: 5.746130466461182
Validation loss improved! Saving model.
Validation loss after 

Validation loss after epoch 95: 5.648275375366211
Validation loss improved! Saving model.
Validation loss after epoch 96: 5.647966384887695
Validation loss improved! Saving model.
Validation loss after epoch 97: 5.647377014160156
Validation loss improved! Saving model.
Validation loss after epoch 98: 5.647089004516602
Validation loss improved! Saving model.
Validation loss after epoch 99: 5.646607875823975
Validation loss improved! Saving model.
Validation loss after epoch 100: 5.646136283874512
Validation loss improved! Saving model.
Validation loss after epoch 101: 5.645874500274658
Validation loss improved! Saving model.
Validation loss after epoch 102: 5.645546913146973
Validation loss improved! Saving model.
Validation loss after epoch 103: 5.645334243774414
Validation loss improved! Saving model.
Validation loss after epoch 104: 5.644866466522217
Validation loss improved! Saving model.
Validation loss after epoch 105: 5.644613742828369
Validation loss improved! Saving model.
Vali

In [9]:
# Restore the best (lowest validation loss) weights saved during training.
# map_location='cpu' lets a checkpoint written on a GPU machine load on a
# CPU-only one; load_state_dict then copies the values onto the model's device.
model.load_state_dict(torch.load(checkpoint, map_location='cpu'))

In [10]:
# Loss of the restored checkpoint on the [-50, 50] extrapolation test set.
# use_gpu (not a hard-coded False) keeps the batches on the model's device.
test_loss = get_eval_loss(model, criterion,
                          X_test, Y_test, use_gpu)

In [11]:
# SmoothL1 test loss of the checkpoint restored from `checkpoint`
test_loss

615.465087890625

In [12]:
# Per-sample test predictions and float32 targets, on the model's device.
test_preds, test_targets = get_eval_preds(model, X_test, Y_test, use_gpu)

In [13]:
# Convert to flat NumPy arrays for the metrics below. The tensor checks make
# this cell safe to re-run: a second run would otherwise call .cpu() on a
# NumPy array and raise AttributeError.
if torch.is_tensor(test_preds):
    test_preds = test_preds.cpu().numpy().flatten()
if torch.is_tensor(test_targets):
    test_targets = test_targets.cpu().numpy().flatten()

In [14]:
# Fraction of test rows whose prediction matches the target to rtol=1e-4
# (np.isclose also applies its default absolute tolerance).
accuracy = np.mean(np.isclose(test_preds, test_targets, rtol=1e-4))
accuracy

0.0

In [15]:
# tanh-squashed W_hat of the model's NAC; values lie in (-1, 1).
# NOTE(review): if models.nac.NAC matches the commented-out sketch earlier in
# this notebook, the effective weight is tanh(W_hat) * sigmoid(M_hat) — confirm.
torch.tanh(model.nac.W_hat)

tensor([[-0.4055, -0.1050]], grad_fn=<TanhBackward>)

In [16]:
# sigmoid-squashed M_hat of the model's NAC; values lie in (0, 1).
# NOTE(review): presumably the factor multiplied with tanh(W_hat) to form the
# NAC weight (per the commented-out NAC sketch) — confirm against models.nac.
torch.sigmoid(model.nac.M_hat)

tensor([[0.5964, 0.3760]], grad_fn=<SigmoidBackward>)

In [19]:
# First 10 test targets (float32 copies produced via get_batches)
test_targets[:10]

array([2004.8478   , 1297.4943   , -470.1426   , -106.39968  ,
       1777.3282   , 1015.44104  ,  770.76166  ,  146.34264  ,
          2.3569129,  649.5627   ], dtype=float32)

In [20]:
# Model predictions for the same 10 test rows, for side-by-side comparison
test_preds[:10]

array([11.102668 , -9.007928 ,  8.301753 ,  2.3614633, 10.970412 ,
        6.549237 ,  8.244163 ,  2.596319 ,  0.5792012,  5.3412333],
      dtype=float32)

In [21]:
# Operands of the first 10 test rows; each target is the product of the two columns
X_test[:10]

array([[-47.09977046, -42.56597659],
       [ 38.88390069,  33.36841774],
       [-42.14832541,  11.15447884],
       [-12.55681032,   8.47346406],
       [-47.26149487, -37.60626413],
       [-25.14039153, -40.39082043],
       [-36.60591108, -21.05566128],
       [ -9.89400027, -14.7910476 ],
       [ -2.00640687,  -1.17469336],
       [-20.68032595, -31.40969368]])

In [22]:
# Original float64 products for the first 5 test rows (test_targets holds float32 copies)
Y_test[:5]

array([2004.84772673, 1297.49424143, -470.14260403, -106.39968089,
       1777.32825938])