In [73]:
import torch
from torch import nn as nn
from torch.nn import functional as F
from torch.optim import Adam, SGD
import numpy as np

In [74]:
class NAC(nn.Module):
    """Neural accumulator: a bias-free linear map with a constrained weight matrix.

    The effective weight is ``tanh(W_hat) * sigmoid(M_hat)``, an element-wise
    product whose entries therefore lie in (-1, 1).

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the last dimension of the output.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        weight_shape = (self.out_dim, self.in_dim)
        # W_hat is created and initialised before M_hat so that a seeded run
        # draws the same initial values for each parameter.
        self.W_hat = nn.Parameter(torch.Tensor(*weight_shape))
        self.M_hat = nn.Parameter(torch.Tensor(*weight_shape))
        for raw_weight in (self.W_hat, self.M_hat):
            nn.init.xavier_normal_(raw_weight)
        # The layer has no additive bias term.
        self.bias = None

    def forward(self, x):
        """Apply the constrained linear map to ``x`` and return the result."""
        squashed = torch.tanh(self.W_hat)
        gate = torch.sigmoid(self.M_hat)
        return F.linear(x, squashed * gate, self.bias)

In [135]:
class NALU(nn.Module):
    """Blend of an additive NAC path and a log-space (multiplicative) NAC path.

    Both paths share one ``NAC`` instance. The blend is controlled by a single
    learned scalar ``G`` (passed through a sigmoid), so the gate does not
    depend on the input.

    Args:
        in_dim: size of the last dimension of the input.
        out_dim: size of the last dimension of the output.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        # The gate is initialised before the NAC so that a seeded run draws
        # the same initial values for each parameter.
        self.G = nn.Parameter(torch.Tensor(1, 1))
        nn.init.xavier_normal_(self.G)
        self.nac = NAC(self.in_dim, self.out_dim)
        # Added to |x| before the log so that an input of exactly 0 stays finite.
        self.eps = 1e-10

    def forward(self, x):
        """Return ``g * nac(x) + (1 - g) * exp(nac(log(|x| + eps)))``."""
        gate = torch.sigmoid(self.G)
        additive = self.nac(x)
        # The log-space path only sees |x|, and exp() is positive, so this
        # term is always positive regardless of the signs of the inputs.
        log_magnitude = torch.log(torch.abs(x) + self.eps)
        multiplicative = torch.exp(self.nac(log_magnitude))
        return (gate * additive) + (1 - gate) * multiplicative

In [136]:
# Seed NumPy's global RNG so the generated datasets are the same on every
# Restart & Run All. The global RNG is used (rather than a local Generator)
# because get_batches shuffles with np.random.shuffle.
DATA_SEED = 0
np.random.seed(DATA_SEED)

# Tiny offset added to the upper bound of each sampling interval.
eps = 1e-12


def make_product_dataset(n_samples, low, high):
    """Sample two-factor inputs and their product.

    Args:
        n_samples: number of rows to draw.
        low: lower bound of the uniform sampling interval.
        high: upper bound of the interval (``eps`` is added to it).

    Returns:
        Tuple ``(X, Y)`` where ``X`` has shape ``(n_samples, 2)`` and
        ``Y[i] == X[i, 0] * X[i, 1]``.
    """
    X = np.random.uniform(low, high + eps, size=(n_samples, 2))
    return X, X[:, 0] * X[:, 1]


# Train and validation share the range [-5, 5]; the test set is drawn from a
# ten times wider range, so it measures extrapolation.
X_train, Y_train = make_product_dataset(2000, -5, 5)
X_valid, Y_valid = make_product_dataset(500, -5, 5)
X_test, Y_test = make_product_dataset(2000, -50, 50)

In [137]:
def get_batches(data, target, batch_size, mode='test', use_gpu=False):
    """Yield ``(inputs, labels)`` float32 tensor pairs, ``batch_size`` rows at a time.

    Args:
        data: NumPy array indexed by row along its first axis.
        target: NumPy array with one entry per row of ``data``.
        batch_size: maximum number of rows per batch; the last batch may be smaller.
        mode: ``'train'`` shuffles the row order with NumPy's global RNG;
            any other value keeps the original order.
        use_gpu: when true, both tensors are moved to CUDA before being yielded.

    Yields:
        Tuples ``(inputs, labels)``; ``labels`` is reshaped to a single column.
    """
    order = np.arange(0, data.shape[0])
    if mode == 'train':
        np.random.shuffle(order)

    remaining = order
    while remaining.shape[0] > 0:
        chosen, remaining = remaining[:batch_size], remaining[batch_size:]

        inputs = torch.from_numpy(data[chosen]).float()
        labels = torch.from_numpy(target[chosen]).float().view(-1, 1)
        if use_gpu:
            inputs, labels = inputs.cuda(), labels.cuda()

        yield inputs, labels

In [138]:
def get_eval_loss(model, criterion, data, targets, use_gpu=False):
    """Return ``criterion`` evaluated over the whole of ``data`` as a Python float.

    Predictions and targets are gathered by ``get_eval_preds`` and the
    criterion is applied once to the concatenated tensors.
    """
    predictions, references = get_eval_preds(model, data, targets, use_gpu)
    return criterion(predictions, references).item()

In [139]:
def get_eval_preds(model, data, targets, use_gpu=False, eval_batch_size=None):
    """Run ``model`` over ``data`` without gradients and collect its outputs.

    Args:
        model: module to evaluate.
        data: NumPy array of inputs, one row per example.
        targets: NumPy array of targets, one entry per row of ``data``.
        use_gpu: forwarded to ``get_batches``; when true the batches are
            moved to CUDA, so the model must be on CUDA as well.
        eval_batch_size: rows per forward pass. When ``None`` the
            notebook-level ``batch_size`` is used, as before.

    Returns:
        Tuple ``(predictions, targets)`` of tensors concatenated along dim 0.
    """
    step = batch_size if eval_batch_size is None else eval_batch_size
    # Remember the current mode so that evaluating in the middle of training
    # does not silently leave the model in eval mode.
    was_training = model.training
    model.eval()
    model_preds = []
    tensor_targets = []
    try:
        with torch.no_grad():
            for x, y in get_batches(data, targets, step,
                                    mode='test', use_gpu=use_gpu):
                model_preds.append(model(x))
                tensor_targets.append(y)
    finally:
        model.train(was_training)
    return torch.cat(model_preds, dim=0), torch.cat(tensor_targets, dim=0)

In [140]:
# --- run configuration -------------------------------------------------------
TRAIN_SEED = 0
batch_size = 32
patience = 15                  # epochs without improvement tolerated before stopping
running_patience = patience    # remaining budget; reset on every improvement
checkpoint = 'best_model.sav'
print_every = 200              # not used by the loop below; kept for other cells
num_epochs = 5000
running_batch = 0              # number of optimisation steps taken so far
running_loss = 0               # sum of the per-batch training losses so far
min_loss = float('inf')
use_gpu = torch.cuda.is_available()

# Seed both RNGs this cell relies on: torch for the parameter initialisation,
# NumPy for the shuffling done inside get_batches(mode='train').
torch.manual_seed(TRAIN_SEED)
np.random.seed(TRAIN_SEED)

criterion = nn.SmoothL1Loss()

model = NALU(2, 1)
if use_gpu:
    model = model.cuda()
optimizer = Adam(model.parameters())

# --- training with early stopping on the validation loss ---------------------
for epoch in range(num_epochs):
    model.train()
    for x, y in get_batches(X_train, Y_train, batch_size,
                            mode='train', use_gpu=use_gpu):
        output = model(x)
        loss = criterion(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        running_batch += 1

    # Evaluate on the same device the model lives on; passing False here
    # would feed CPU batches to a CUDA model.
    valid_loss = get_eval_loss(model, criterion,
                               X_valid, Y_valid, use_gpu)
    print("Validation loss after epoch {}: {}".format(epoch, valid_loss))
    if valid_loss < min_loss:
        min_loss = valid_loss
        running_patience = patience
        print('Validation loss improved! Saving model.')
        torch.save(model.state_dict(), checkpoint)
    else:
        running_patience -= 1
    if running_patience <= 0:
        print('Ran out of patience, early stopping employed!')
        break

Validation loss after epoch 0: 5.665528774261475
Validation loss improved! Saving model.
Validation loss after epoch 1: 5.657806873321533
Validation loss improved! Saving model.
Validation loss after epoch 2: 5.647764682769775
Validation loss improved! Saving model.
Validation loss after epoch 3: 5.636740207672119
Validation loss improved! Saving model.
Validation loss after epoch 4: 5.6271867752075195
Validation loss improved! Saving model.
Validation loss after epoch 5: 5.620261192321777
Validation loss improved! Saving model.
Validation loss after epoch 6: 5.614684104919434
Validation loss improved! Saving model.
Validation loss after epoch 7: 5.610252857208252
Validation loss improved! Saving model.
Validation loss after epoch 8: 5.606377124786377
Validation loss improved! Saving model.
Validation loss after epoch 9: 5.602868556976318
Validation loss improved! Saving model.
Validation loss after epoch 10: 5.599340438842773
Validation loss improved! Saving model.
Validation loss aft

Validation loss after epoch 92: 3.1872646808624268
Validation loss improved! Saving model.
Validation loss after epoch 93: 3.1794495582580566
Validation loss improved! Saving model.
Validation loss after epoch 94: 3.171966075897217
Validation loss improved! Saving model.
Validation loss after epoch 95: 3.1648077964782715
Validation loss improved! Saving model.
Validation loss after epoch 96: 3.158099412918091
Validation loss improved! Saving model.
Validation loss after epoch 97: 3.151664972305298
Validation loss improved! Saving model.
Validation loss after epoch 98: 3.1452035903930664
Validation loss improved! Saving model.
Validation loss after epoch 99: 3.1394567489624023
Validation loss improved! Saving model.
Validation loss after epoch 100: 3.133735418319702
Validation loss improved! Saving model.
Validation loss after epoch 101: 3.1280696392059326
Validation loss improved! Saving model.
Validation loss after epoch 102: 3.122917890548706
Validation loss improved! Saving model.
V

Validation loss after epoch 184: 2.9914612770080566
Validation loss improved! Saving model.
Validation loss after epoch 185: 2.990964412689209
Validation loss improved! Saving model.
Validation loss after epoch 186: 2.990417957305908
Validation loss improved! Saving model.
Validation loss after epoch 187: 2.989988088607788
Validation loss improved! Saving model.
Validation loss after epoch 188: 2.989508628845215
Validation loss improved! Saving model.
Validation loss after epoch 189: 2.9890151023864746
Validation loss improved! Saving model.
Validation loss after epoch 190: 2.9885690212249756
Validation loss improved! Saving model.
Validation loss after epoch 191: 2.9880945682525635
Validation loss improved! Saving model.
Validation loss after epoch 192: 2.987632989883423
Validation loss improved! Saving model.
Validation loss after epoch 193: 2.987166404724121
Validation loss improved! Saving model.
Validation loss after epoch 194: 2.9866912364959717
Validation loss improved! Saving m

Validation loss after epoch 277: 2.9584858417510986
Validation loss improved! Saving model.
Validation loss after epoch 278: 2.9583818912506104
Validation loss improved! Saving model.
Validation loss after epoch 279: 2.958021640777588
Validation loss improved! Saving model.
Validation loss after epoch 280: 2.9576172828674316
Validation loss improved! Saving model.
Validation loss after epoch 281: 2.9574856758117676
Validation loss improved! Saving model.
Validation loss after epoch 282: 2.9571480751037598
Validation loss improved! Saving model.
Validation loss after epoch 283: 2.9568674564361572
Validation loss improved! Saving model.
Validation loss after epoch 284: 2.9566686153411865
Validation loss improved! Saving model.
Validation loss after epoch 285: 2.9563510417938232
Validation loss improved! Saving model.
Validation loss after epoch 286: 2.956148624420166
Validation loss improved! Saving model.
Validation loss after epoch 287: 2.955791711807251
Validation loss improved! Savin

Validation loss after epoch 368: 2.937228202819824
Validation loss improved! Saving model.
Validation loss after epoch 369: 2.9371259212493896
Validation loss improved! Saving model.
Validation loss after epoch 370: 2.936882495880127
Validation loss improved! Saving model.
Validation loss after epoch 371: 2.936656951904297
Validation loss improved! Saving model.
Validation loss after epoch 372: 2.936566114425659
Validation loss improved! Saving model.
Validation loss after epoch 373: 2.9364707469940186
Validation loss improved! Saving model.
Validation loss after epoch 374: 2.936011552810669
Validation loss improved! Saving model.
Validation loss after epoch 375: 2.9358279705047607
Validation loss improved! Saving model.
Validation loss after epoch 376: 2.9356322288513184
Validation loss improved! Saving model.
Validation loss after epoch 377: 2.9357028007507324
Validation loss after epoch 378: 2.9353420734405518
Validation loss improved! Saving model.
Validation loss after epoch 379: 

Validation loss after epoch 464: 2.921745777130127
Validation loss improved! Saving model.
Validation loss after epoch 465: 2.9218156337738037
Validation loss after epoch 466: 2.9214107990264893
Validation loss improved! Saving model.
Validation loss after epoch 467: 2.9214680194854736
Validation loss after epoch 468: 2.921506404876709
Validation loss after epoch 469: 2.9213738441467285
Validation loss improved! Saving model.
Validation loss after epoch 470: 2.920971155166626
Validation loss improved! Saving model.
Validation loss after epoch 471: 2.9210119247436523
Validation loss after epoch 472: 2.9207053184509277
Validation loss improved! Saving model.
Validation loss after epoch 473: 2.9206840991973877
Validation loss improved! Saving model.
Validation loss after epoch 474: 2.920510768890381
Validation loss improved! Saving model.
Validation loss after epoch 475: 2.920426368713379
Validation loss improved! Saving model.
Validation loss after epoch 476: 2.9200470447540283
Validatio

Validation loss after epoch 565: 2.9106414318084717
Validation loss improved! Saving model.
Validation loss after epoch 566: 2.910597801208496
Validation loss improved! Saving model.
Validation loss after epoch 567: 2.9103424549102783
Validation loss improved! Saving model.
Validation loss after epoch 568: 2.9101898670196533
Validation loss improved! Saving model.
Validation loss after epoch 569: 2.9100756645202637
Validation loss improved! Saving model.
Validation loss after epoch 570: 2.909966468811035
Validation loss improved! Saving model.
Validation loss after epoch 571: 2.910041332244873
Validation loss after epoch 572: 2.9098222255706787
Validation loss improved! Saving model.
Validation loss after epoch 573: 2.909813642501831
Validation loss improved! Saving model.
Validation loss after epoch 574: 2.9096181392669678
Validation loss improved! Saving model.
Validation loss after epoch 575: 2.9097201824188232
Validation loss after epoch 576: 2.9095780849456787
Validation loss impr

Validation loss after epoch 669: 2.9027512073516846
Validation loss improved! Saving model.
Validation loss after epoch 670: 2.902700424194336
Validation loss improved! Saving model.
Validation loss after epoch 671: 2.9027600288391113
Validation loss after epoch 672: 2.9027466773986816
Validation loss after epoch 673: 2.902310371398926
Validation loss improved! Saving model.
Validation loss after epoch 674: 2.9021904468536377
Validation loss improved! Saving model.
Validation loss after epoch 675: 2.9021434783935547
Validation loss improved! Saving model.
Validation loss after epoch 676: 2.902147054672241
Validation loss after epoch 677: 2.9020297527313232
Validation loss improved! Saving model.
Validation loss after epoch 678: 2.901993751525879
Validation loss improved! Saving model.
Validation loss after epoch 679: 2.9019060134887695
Validation loss improved! Saving model.
Validation loss after epoch 680: 2.9018197059631348
Validation loss improved! Saving model.
Validation loss afte

Validation loss after epoch 767: 2.896061658859253
Validation loss improved! Saving model.
Validation loss after epoch 768: 2.8959805965423584
Validation loss improved! Saving model.
Validation loss after epoch 769: 2.895864725112915
Validation loss improved! Saving model.
Validation loss after epoch 770: 2.8958611488342285
Validation loss improved! Saving model.
Validation loss after epoch 771: 2.895988702774048
Validation loss after epoch 772: 2.8956711292266846
Validation loss improved! Saving model.
Validation loss after epoch 773: 2.895824432373047
Validation loss after epoch 774: 2.895472764968872
Validation loss improved! Saving model.
Validation loss after epoch 775: 2.8953006267547607
Validation loss improved! Saving model.
Validation loss after epoch 776: 2.8952295780181885
Validation loss improved! Saving model.
Validation loss after epoch 777: 2.8952383995056152
Validation loss after epoch 778: 2.8950963020324707
Validation loss improved! Saving model.
Validation loss after

Validation loss after epoch 866: 2.8878777027130127
Validation loss improved! Saving model.
Validation loss after epoch 867: 2.8878660202026367
Validation loss improved! Saving model.
Validation loss after epoch 868: 2.8879311084747314
Validation loss after epoch 869: 2.8877015113830566
Validation loss improved! Saving model.
Validation loss after epoch 870: 2.887653112411499
Validation loss improved! Saving model.
Validation loss after epoch 871: 2.8874242305755615
Validation loss improved! Saving model.
Validation loss after epoch 872: 2.8874411582946777
Validation loss after epoch 873: 2.8874144554138184
Validation loss improved! Saving model.
Validation loss after epoch 874: 2.8870508670806885
Validation loss improved! Saving model.
Validation loss after epoch 875: 2.886885404586792
Validation loss improved! Saving model.
Validation loss after epoch 876: 2.8866612911224365
Validation loss improved! Saving model.
Validation loss after epoch 877: 2.8863258361816406
Validation loss im

Validation loss after epoch 961: 2.8251051902770996
Validation loss improved! Saving model.
Validation loss after epoch 962: 2.8246922492980957
Validation loss improved! Saving model.
Validation loss after epoch 963: 2.824185848236084
Validation loss improved! Saving model.
Validation loss after epoch 964: 2.8240933418273926
Validation loss improved! Saving model.
Validation loss after epoch 965: 2.8236243724823
Validation loss improved! Saving model.
Validation loss after epoch 966: 2.823056936264038
Validation loss improved! Saving model.
Validation loss after epoch 967: 2.8228349685668945
Validation loss improved! Saving model.
Validation loss after epoch 968: 2.822425603866577
Validation loss improved! Saving model.
Validation loss after epoch 969: 2.822275400161743
Validation loss improved! Saving model.
Validation loss after epoch 970: 2.8218886852264404
Validation loss improved! Saving model.
Validation loss after epoch 971: 2.8217034339904785
Validation loss improved! Saving mo

Validation loss after epoch 1056: 2.814018964767456
Validation loss improved! Saving model.
Validation loss after epoch 1057: 2.813986301422119
Validation loss improved! Saving model.
Validation loss after epoch 1058: 2.8140788078308105
Validation loss after epoch 1059: 2.8140857219696045
Validation loss after epoch 1060: 2.813871145248413
Validation loss improved! Saving model.
Validation loss after epoch 1061: 2.813791513442993
Validation loss improved! Saving model.
Validation loss after epoch 1062: 2.8135993480682373
Validation loss improved! Saving model.
Validation loss after epoch 1063: 2.8136937618255615
Validation loss after epoch 1064: 2.8135809898376465
Validation loss improved! Saving model.
Validation loss after epoch 1065: 2.8135316371917725
Validation loss improved! Saving model.
Validation loss after epoch 1066: 2.8134396076202393
Validation loss improved! Saving model.
Validation loss after epoch 1067: 2.8135509490966797
Validation loss after epoch 1068: 2.813423871994

Validation loss after epoch 1161: 2.809556007385254
Validation loss improved! Saving model.
Validation loss after epoch 1162: 2.8095102310180664
Validation loss improved! Saving model.
Validation loss after epoch 1163: 2.8094966411590576
Validation loss improved! Saving model.
Validation loss after epoch 1164: 2.809441328048706
Validation loss improved! Saving model.
Validation loss after epoch 1165: 2.8094005584716797
Validation loss improved! Saving model.
Validation loss after epoch 1166: 2.8093764781951904
Validation loss improved! Saving model.
Validation loss after epoch 1167: 2.8094446659088135
Validation loss after epoch 1168: 2.8093740940093994
Validation loss improved! Saving model.
Validation loss after epoch 1169: 2.8092780113220215
Validation loss improved! Saving model.
Validation loss after epoch 1170: 2.8093550205230713
Validation loss after epoch 1171: 2.809202194213867
Validation loss improved! Saving model.
Validation loss after epoch 1172: 2.809164524078369
Validati

Validation loss after epoch 1266: 2.805896043777466
Validation loss after epoch 1267: 2.805851697921753
Validation loss improved! Saving model.
Validation loss after epoch 1268: 2.8058390617370605
Validation loss improved! Saving model.
Validation loss after epoch 1269: 2.8057892322540283
Validation loss improved! Saving model.
Validation loss after epoch 1270: 2.8057565689086914
Validation loss improved! Saving model.
Validation loss after epoch 1271: 2.805767059326172
Validation loss after epoch 1272: 2.805635929107666
Validation loss improved! Saving model.
Validation loss after epoch 1273: 2.8056857585906982
Validation loss after epoch 1274: 2.8056869506835938
Validation loss after epoch 1275: 2.8056299686431885
Validation loss improved! Saving model.
Validation loss after epoch 1276: 2.80556583404541
Validation loss improved! Saving model.
Validation loss after epoch 1277: 2.805593967437744
Validation loss after epoch 1278: 2.805560827255249
Validation loss improved! Saving model.

Validation loss after epoch 1371: 2.8028194904327393
Validation loss improved! Saving model.
Validation loss after epoch 1372: 2.802873134613037
Validation loss after epoch 1373: 2.8027493953704834
Validation loss improved! Saving model.
Validation loss after epoch 1374: 2.8027873039245605
Validation loss after epoch 1375: 2.8027637004852295
Validation loss after epoch 1376: 2.8026628494262695
Validation loss improved! Saving model.
Validation loss after epoch 1377: 2.802635908126831
Validation loss improved! Saving model.
Validation loss after epoch 1378: 2.802665948867798
Validation loss after epoch 1379: 2.8026123046875
Validation loss improved! Saving model.
Validation loss after epoch 1380: 2.8026063442230225
Validation loss improved! Saving model.
Validation loss after epoch 1381: 2.8025059700012207
Validation loss improved! Saving model.
Validation loss after epoch 1382: 2.8025169372558594
Validation loss after epoch 1383: 2.8024680614471436
Validation loss improved! Saving mode

Validation loss after epoch 1475: 2.800170660018921
Validation loss after epoch 1476: 2.800201654434204
Validation loss after epoch 1477: 2.8001725673675537
Validation loss after epoch 1478: 2.800163984298706
Validation loss improved! Saving model.
Validation loss after epoch 1479: 2.800180673599243
Validation loss after epoch 1480: 2.800121545791626
Validation loss improved! Saving model.
Validation loss after epoch 1481: 2.8001039028167725
Validation loss improved! Saving model.
Validation loss after epoch 1482: 2.800082206726074
Validation loss improved! Saving model.
Validation loss after epoch 1483: 2.8000457286834717
Validation loss improved! Saving model.
Validation loss after epoch 1484: 2.800034284591675
Validation loss improved! Saving model.
Validation loss after epoch 1485: 2.80004620552063
Validation loss after epoch 1486: 2.800009250640869
Validation loss improved! Saving model.
Validation loss after epoch 1487: 2.7999956607818604
Validation loss improved! Saving model.
V

Validation loss after epoch 1581: 2.7980029582977295
Validation loss after epoch 1582: 2.797994613647461
Validation loss after epoch 1583: 2.798001527786255
Validation loss after epoch 1584: 2.7979824542999268
Validation loss improved! Saving model.
Validation loss after epoch 1585: 2.797968864440918
Validation loss improved! Saving model.
Validation loss after epoch 1586: 2.7979276180267334
Validation loss improved! Saving model.
Validation loss after epoch 1587: 2.797935962677002
Validation loss after epoch 1588: 2.7979111671447754
Validation loss improved! Saving model.
Validation loss after epoch 1589: 2.797879695892334
Validation loss improved! Saving model.
Validation loss after epoch 1590: 2.7978525161743164
Validation loss improved! Saving model.
Validation loss after epoch 1591: 2.7978405952453613
Validation loss improved! Saving model.
Validation loss after epoch 1592: 2.797849655151367
Validation loss after epoch 1593: 2.79780650138855
Validation loss improved! Saving model.

Validation loss after epoch 1690: 2.7961504459381104
Validation loss after epoch 1691: 2.7961316108703613
Validation loss improved! Saving model.
Validation loss after epoch 1692: 2.7961585521698
Validation loss after epoch 1693: 2.796118974685669
Validation loss improved! Saving model.
Validation loss after epoch 1694: 2.7960762977600098
Validation loss improved! Saving model.
Validation loss after epoch 1695: 2.7961151599884033
Validation loss after epoch 1696: 2.796058416366577
Validation loss improved! Saving model.
Validation loss after epoch 1697: 2.796074628829956
Validation loss after epoch 1698: 2.7960894107818604
Validation loss after epoch 1699: 2.7960407733917236
Validation loss improved! Saving model.
Validation loss after epoch 1700: 2.795980930328369
Validation loss improved! Saving model.
Validation loss after epoch 1701: 2.7959909439086914
Validation loss after epoch 1702: 2.796022891998291
Validation loss after epoch 1703: 2.7959468364715576
Validation loss improved! 

Validation loss after epoch 1800: 2.7945165634155273
Validation loss after epoch 1801: 2.7945258617401123
Validation loss after epoch 1802: 2.7944819927215576
Validation loss improved! Saving model.
Validation loss after epoch 1803: 2.79439377784729
Validation loss improved! Saving model.
Validation loss after epoch 1804: 2.794442653656006
Validation loss after epoch 1805: 2.794417381286621
Validation loss after epoch 1806: 2.794372320175171
Validation loss improved! Saving model.
Validation loss after epoch 1807: 2.794365882873535
Validation loss improved! Saving model.
Validation loss after epoch 1808: 2.7943694591522217
Validation loss after epoch 1809: 2.7943637371063232
Validation loss improved! Saving model.
Validation loss after epoch 1810: 2.794325351715088
Validation loss improved! Saving model.
Validation loss after epoch 1811: 2.7942798137664795
Validation loss improved! Saving model.
Validation loss after epoch 1812: 2.794297933578491
Validation loss after epoch 1813: 2.794

Validation loss after epoch 1913: 2.792510986328125
Validation loss after epoch 1914: 2.792466640472412
Validation loss improved! Saving model.
Validation loss after epoch 1915: 2.7924201488494873
Validation loss improved! Saving model.
Validation loss after epoch 1916: 2.792431354522705
Validation loss after epoch 1917: 2.792428970336914
Validation loss after epoch 1918: 2.7923641204833984
Validation loss improved! Saving model.
Validation loss after epoch 1919: 2.7923872470855713
Validation loss after epoch 1920: 2.7923812866210938
Validation loss after epoch 1921: 2.792356491088867
Validation loss improved! Saving model.
Validation loss after epoch 1922: 2.792313814163208
Validation loss improved! Saving model.
Validation loss after epoch 1923: 2.792318820953369
Validation loss after epoch 1924: 2.7923200130462646
Validation loss after epoch 1925: 2.792269706726074
Validation loss improved! Saving model.
Validation loss after epoch 1926: 2.792283773422241
Validation loss after epoch

Validation loss after epoch 2023: 2.790745496749878
Validation loss improved! Saving model.
Validation loss after epoch 2024: 2.7907285690307617
Validation loss improved! Saving model.
Validation loss after epoch 2025: 2.79070782661438
Validation loss improved! Saving model.
Validation loss after epoch 2026: 2.790693521499634
Validation loss improved! Saving model.
Validation loss after epoch 2027: 2.7906806468963623
Validation loss improved! Saving model.
Validation loss after epoch 2028: 2.7906534671783447
Validation loss improved! Saving model.
Validation loss after epoch 2029: 2.7906579971313477
Validation loss after epoch 2030: 2.7906417846679688
Validation loss improved! Saving model.
Validation loss after epoch 2031: 2.790646553039551
Validation loss after epoch 2032: 2.790627956390381
Validation loss improved! Saving model.
Validation loss after epoch 2033: 2.790584087371826
Validation loss improved! Saving model.
Validation loss after epoch 2034: 2.7906322479248047
Validation 

In [141]:
# Restore the best weights seen during training. Loading onto the CPU first
# lets a checkpoint written from a CUDA model be read on a machine without
# CUDA; load_state_dict then copies the values into the model's own parameters.
model.load_state_dict(torch.load(checkpoint, map_location='cpu'))

In [142]:
# Loss on the wider-range test set. use_gpu must match the device the model
# was placed on in the training cell, otherwise the batches and the model
# end up on different devices.
test_loss = get_eval_loss(model, criterion,
                          X_test, Y_test, use_gpu)

In [143]:
# Smooth-L1 loss of the restored checkpoint on X_test, which is drawn from
# [-50, 50] rather than the [-5, 5] range used for training and validation.
test_loss

317.2533874511719

In [144]:
# Collect predictions and targets for the test set as tensors. use_gpu must
# match the device the model was placed on in the training cell.
test_preds, test_targets = get_eval_preds(model, X_test, Y_test, use_gpu)

In [145]:
# Flatten to 1-D NumPy arrays. torch.as_tensor makes this cell safe to re-run:
# on a second run the names already hold NumPy arrays, which have no .cpu().
test_preds = torch.as_tensor(test_preds).detach().cpu().numpy().flatten()
test_targets = torch.as_tensor(test_targets).detach().cpu().numpy().flatten()

In [146]:
# Fraction of predictions that np.isclose accepts as equal to their target,
# using a relative tolerance of 1e-4 and NumPy's default absolute tolerance.
accuracy = np.mean(
    np.isclose(test_preds, test_targets, rtol=1e-4).astype(np.int32)
)
accuracy

0.0005

In [147]:
# tanh factor of the NAC's effective weight, tanh(W_hat) * sigmoid(M_hat).
torch.tanh(model.nac.W_hat)

tensor([[1.0000, 0.9976]], grad_fn=<TanhBackward>)

In [148]:
# sigmoid factor of the NAC's effective weight, tanh(W_hat) * sigmoid(M_hat).
torch.sigmoid(model.nac.M_hat)

tensor([[0.9982, 0.9977]], grad_fn=<SigmoidBackward>)

In [149]:
# NOTE(review): the NALU class defined in this notebook only sets `in_dim`,
# `out_dim`, `G`, `nac` and `eps`, so nothing shown here creates `model.S`.
# The recorded output presumably comes from kernel state that is no longer
# in the notebook — confirm, then either restore the code that defines `S`
# or drop this cell.
model.S

Linear(in_features=2, out_features=1, bias=True)

In [150]:
# First ten test targets, as float32 after the cast in get_batches.
test_targets[:10]

array([-1218.0051  ,   508.96805 ,   367.82147 ,   851.1205  ,
         584.9041  ,  -377.07687 , -1888.1703  ,  -505.0698  ,
        1028.0148  ,    18.078764], dtype=float32)

In [151]:
# Model predictions for the same ten test rows, for side-by-side comparison.
test_preds[:10]

array([-1171.1208  ,   491.456   ,   351.11554 ,   811.66907 ,
         559.27966 ,  -365.16028 ,  1811.9347  ,   486.8439  ,
         980.87006 ,    17.749733], dtype=float32)

In [152]:
# Raw inputs for those ten rows; the two columns are the factors whose
# product is the target.
X_test[:10]

array([[-41.82293364,  29.1229016 ],
       [ 27.17737426,  18.7276387 ],
       [-10.57032672, -34.79755119],
       [-21.88003222, -38.8994161 ],
       [-32.24733084, -18.13806356],
       [-44.07634234,   8.55508536],
       [ 45.73059987, -41.28899055],
       [ 24.08210651, -20.97282424],
       [-33.70063163, -30.50431924],
       [  4.02715109,   4.48921935]])

In [153]:
# Ground-truth products for the first five test rows, straight from the
# NumPy array (before get_batches casts them to float32).
Y_test[:5]

array([-1218.00518107,   508.96804602,   367.8214852 ,   851.12047768,
         584.90413634])