In [1]:
import numpy as np
np.random.seed(20)
import random

In [2]:
def generate_data(num_train, num_test, dim, num_sum, fn):
    data = np.arange(0,dim,dtype=np.int64)
    X, y = [], []
    for i in range(num_train + num_test):
        idx_a = random.sample(range(dim), num_sum)
        idx_b = random.sample([x for x in range(dim) if x not in idx_a], num_sum)
        a, b = data[idx_a].sum(), data[idx_b].sum()
        X.append([a, b])
        y.append(fn(a, b))
    X= np.array(X)
    y = np.array(y)
    #X = torch.FloatTensor(X)
    #y = torch.FloatTensor(y).unsqueeze_(1)
    indices = list(range(num_train + num_test))
    np.random.shuffle(indices)
    X_train, y_train = X[indices[num_test:]], y[indices[num_test:]]
    X_test, y_test = X[indices[:num_test]], y[indices[:num_test]]
    return X_train, y_train, X_test, y_test


## Addition

In [3]:
#Data Generation

fn = lambda x, y: x + y
X_train,Y_train,X_test, Y_test = generate_data(num_train=500, num_test=50,dim=100, num_sum=5, fn=fn)

In [4]:
from keras.models import Model
from keras.layers import Input
from keras.optimizers import RMSprop
from keras.callbacks import ModelCheckpoint
import os
from nac import NAC
from nalu import NALU

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [5]:
#Hyperparameters 
units = 2
num_samples = 1000

In [6]:
# generate the model
ip = Input(shape=(2,))
x = NAC(units)(ip)
x = NAC(1)(x)

model = Model(ip, x)
model.summary()

optimizer = RMSprop(0.1)
model.compile(optimizer, 'mse')

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 2)                 0         
_________________________________________________________________
nac_1 (NAC)                  (None, 2)                 8         
_________________________________________________________________
nac_2 (NAC)                  (None, 1)                 4         
Total params: 12
Trainable params: 12
Non-trainable params: 0
_________________________________________________________________


In [7]:
# Train model
model.fit(X_train, Y_train, batch_size=64, epochs=500,
          verbose=2,validation_data=(X_test, Y_test))

Train on 500 samples, validate on 50 samples
Epoch 1/500
 - 0s - loss: 120033.2295 - val_loss: 9871.0254
Epoch 2/500
 - 0s - loss: 2887.7311 - val_loss: 149.0609
Epoch 3/500
 - 0s - loss: 105.4040 - val_loss: 91.0382
Epoch 4/500
 - 0s - loss: 80.7535 - val_loss: 69.9209
Epoch 5/500
 - 0s - loss: 61.1970 - val_loss: 47.6732
Epoch 6/500
 - 0s - loss: 41.4109 - val_loss: 31.4510
Epoch 7/500
 - 0s - loss: 38.2154 - val_loss: 186.6171
Epoch 8/500
 - 0s - loss: 601.3105 - val_loss: 228.1660
Epoch 9/500
 - 0s - loss: 250.8413 - val_loss: 352.9539
Epoch 10/500
 - 0s - loss: 410.4835 - val_loss: 366.9554
Epoch 11/500
 - 0s - loss: 292.3383 - val_loss: 342.7209
Epoch 12/500
 - 0s - loss: 330.1232 - val_loss: 288.0601
Epoch 13/500
 - 0s - loss: 394.8875 - val_loss: 276.9398
Epoch 14/500
 - 0s - loss: 240.2451 - val_loss: 320.9694
Epoch 15/500
 - 0s - loss: 360.6876 - val_loss: 340.0929
Epoch 16/500
 - 0s - loss: 286.5602 - val_loss: 254.4518
Epoch 17/500
 - 0s - loss: 320.1152 - val_loss: 487.895

 - 0s - loss: 10.9834 - val_loss: 8.1519
Epoch 146/500
 - 0s - loss: 11.3163 - val_loss: 11.3092
Epoch 147/500
 - 0s - loss: 7.7420 - val_loss: 13.5639
Epoch 148/500
 - 0s - loss: 9.7583 - val_loss: 9.6204
Epoch 149/500
 - 0s - loss: 7.7495 - val_loss: 6.0046
Epoch 150/500
 - 0s - loss: 7.8722 - val_loss: 6.9684
Epoch 151/500
 - 0s - loss: 6.4284 - val_loss: 5.6044
Epoch 152/500
 - 0s - loss: 7.7963 - val_loss: 4.3874
Epoch 153/500
 - 0s - loss: 4.3976 - val_loss: 7.1391
Epoch 154/500
 - 0s - loss: 6.6668 - val_loss: 4.8602
Epoch 155/500
 - 0s - loss: 4.5088 - val_loss: 5.2843
Epoch 156/500
 - 0s - loss: 4.9070 - val_loss: 3.1406
Epoch 157/500
 - 0s - loss: 4.2452 - val_loss: 2.6673
Epoch 158/500
 - 0s - loss: 4.2263 - val_loss: 3.2193
Epoch 159/500
 - 0s - loss: 3.5755 - val_loss: 2.9657
Epoch 160/500
 - 0s - loss: 3.4023 - val_loss: 2.7653
Epoch 161/500
 - 0s - loss: 2.9633 - val_loss: 3.8995
Epoch 162/500
 - 0s - loss: 2.6286 - val_loss: 2.3539
Epoch 163/500
 - 0s - loss: 3.0047 - v

Epoch 290/500
 - 0s - loss: 3.1665e-06 - val_loss: 4.1141e-06
Epoch 291/500
 - 0s - loss: 5.7022e-06 - val_loss: 2.2191e-06
Epoch 292/500
 - 0s - loss: 2.2218e-06 - val_loss: 4.6653e-06
Epoch 293/500
 - 0s - loss: 4.2061e-06 - val_loss: 1.9451e-06
Epoch 294/500
 - 0s - loss: 2.7545e-06 - val_loss: 3.1733e-06
Epoch 295/500
 - 0s - loss: 2.9720e-06 - val_loss: 2.3539e-06
Epoch 296/500
 - 0s - loss: 2.1997e-06 - val_loss: 4.6181e-06
Epoch 297/500
 - 0s - loss: 1.9704e-06 - val_loss: 1.8170e-06
Epoch 298/500
 - 0s - loss: 2.1292e-06 - val_loss: 1.8211e-06
Epoch 299/500
 - 0s - loss: 1.5421e-06 - val_loss: 1.5271e-06
Epoch 300/500
 - 0s - loss: 1.7939e-06 - val_loss: 2.1017e-06
Epoch 301/500
 - 0s - loss: 8.8707e-07 - val_loss: 9.8966e-07
Epoch 302/500
 - 0s - loss: 1.9356e-06 - val_loss: 6.0836e-07
Epoch 303/500
 - 0s - loss: 4.5811e-07 - val_loss: 1.1458e-06
Epoch 304/500
 - 0s - loss: 1.3945e-06 - val_loss: 3.5942e-07
Epoch 305/500
 - 0s - loss: 6.7171e-07 - val_loss: 1.3713e-06
Epoch 30

Epoch 423/500
 - 0s - loss: 7.8836e-10 - val_loss: 4.2841e-10
Epoch 424/500
 - 0s - loss: 6.1328e-10 - val_loss: 4.2841e-10
Epoch 425/500
 - 0s - loss: 7.1619e-10 - val_loss: 8.3819e-10
Epoch 426/500
 - 0s - loss: 8.0187e-10 - val_loss: 9.4995e-10
Epoch 427/500
 - 0s - loss: 9.5414e-10 - val_loss: 5.5879e-10
Epoch 428/500
 - 0s - loss: 1.2536e-09 - val_loss: 3.7253e-10
Epoch 429/500
 - 0s - loss: 7.4413e-10 - val_loss: 1.0245e-09
Epoch 430/500
 - 0s - loss: 7.2690e-10 - val_loss: 5.5879e-10
Epoch 431/500
 - 0s - loss: 6.3190e-10 - val_loss: 6.1467e-10
Epoch 432/500
 - 0s - loss: 7.0035e-10 - val_loss: 5.5879e-10
Epoch 433/500
 - 0s - loss: 6.2305e-10 - val_loss: 5.7742e-10
Epoch 434/500
 - 0s - loss: 8.8755e-10 - val_loss: 1.0990e-09
Epoch 435/500
 - 0s - loss: 1.1199e-09 - val_loss: 1.0990e-09
Epoch 436/500
 - 0s - loss: 9.3598e-10 - val_loss: 6.1467e-10
Epoch 437/500
 - 0s - loss: 8.8103e-10 - val_loss: 7.4506e-10
Epoch 438/500
 - 0s - loss: 9.3179e-10 - val_loss: 5.7742e-10
Epoch 43

<keras.callbacks.History at 0x2e02e4aa9e8>

In [8]:
scores = model.evaluate(X_test, Y_test, batch_size=128)
print("Mean Squared error : ", scores)

Mean Squared error :  4.842877210364804e-10


## Subtraction

In [9]:
#Data Generation

fn = lambda x, y: x - y
X_train,Y_train,X_test, Y_test = generate_data(num_train=500, num_test=50,dim=100, num_sum=5, fn=fn)

In [10]:
# generate the model
ip = Input(shape=(2,))
x = NAC(units)(ip)
x = NAC(1)(x)

model = Model(ip, x)
model.summary()

optimizer = RMSprop(0.1)
model.compile(optimizer, 'mse')

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         (None, 2)                 0         
_________________________________________________________________
nac_3 (NAC)                  (None, 2)                 8         
_________________________________________________________________
nac_4 (NAC)                  (None, 1)                 4         
Total params: 12
Trainable params: 12
Non-trainable params: 0
_________________________________________________________________


In [11]:
# Train model
model.fit(X_train, Y_train, batch_size=64, epochs=500,
          verbose=2,validation_data=(X_test, Y_test))

Train on 500 samples, validate on 50 samples
Epoch 1/500
 - 0s - loss: 3819.2518 - val_loss: 1063.4409
Epoch 2/500
 - 0s - loss: 579.6079 - val_loss: 151.1811
Epoch 3/500
 - 0s - loss: 255.1229 - val_loss: 133.7117
Epoch 4/500
 - 0s - loss: 30.6646 - val_loss: 4.3595
Epoch 5/500
 - 0s - loss: 114.9286 - val_loss: 353.5896
Epoch 6/500
 - 0s - loss: 76.6308 - val_loss: 23.5676
Epoch 7/500
 - 0s - loss: 119.5123 - val_loss: 122.9758
Epoch 8/500
 - 0s - loss: 47.9868 - val_loss: 77.9906
Epoch 9/500
 - 0s - loss: 114.2068 - val_loss: 50.1931
Epoch 10/500
 - 0s - loss: 55.7863 - val_loss: 204.1114
Epoch 11/500
 - 0s - loss: 82.8261 - val_loss: 49.3878
Epoch 12/500
 - 0s - loss: 74.7804 - val_loss: 98.5186
Epoch 13/500
 - 0s - loss: 67.2642 - val_loss: 47.0016
Epoch 14/500
 - 0s - loss: 62.5858 - val_loss: 122.5292
Epoch 15/500
 - 0s - loss: 52.0462 - val_loss: 35.6180
Epoch 16/500
 - 0s - loss: 73.2303 - val_loss: 62.7503
Epoch 17/500
 - 0s - loss: 39.7363 - val_loss: 32.1289
Epoch 18/500
 -

Epoch 152/500
 - 0s - loss: 0.0015 - val_loss: 0.0011
Epoch 153/500
 - 0s - loss: 0.0012 - val_loss: 0.0017
Epoch 154/500
 - 0s - loss: 0.0012 - val_loss: 9.2664e-04
Epoch 155/500
 - 0s - loss: 0.0013 - val_loss: 8.3291e-04
Epoch 156/500
 - 0s - loss: 8.9228e-04 - val_loss: 0.0015
Epoch 157/500
 - 0s - loss: 0.0011 - val_loss: 5.0797e-04
Epoch 158/500
 - 0s - loss: 7.3655e-04 - val_loss: 0.0017
Epoch 159/500
 - 0s - loss: 9.7897e-04 - val_loss: 8.4091e-04
Epoch 160/500
 - 0s - loss: 6.4216e-04 - val_loss: 7.1956e-04
Epoch 161/500
 - 0s - loss: 9.1906e-04 - val_loss: 4.4774e-04
Epoch 162/500
 - 0s - loss: 6.6495e-04 - val_loss: 0.0015
Epoch 163/500
 - 0s - loss: 5.4465e-04 - val_loss: 1.6633e-04
Epoch 164/500
 - 0s - loss: 7.0825e-04 - val_loss: 7.8449e-04
Epoch 165/500
 - 0s - loss: 4.0778e-04 - val_loss: 2.8251e-04
Epoch 166/500
 - 0s - loss: 6.8068e-04 - val_loss: 2.6079e-04
Epoch 167/500
 - 0s - loss: 3.6527e-04 - val_loss: 9.6094e-04
Epoch 168/500
 - 0s - loss: 4.7626e-04 - val_los

Epoch 285/500
 - 0s - loss: 3.2610e-07 - val_loss: 1.0915e-07
Epoch 286/500
 - 0s - loss: 1.1106e-07 - val_loss: 2.5833e-07
Epoch 287/500
 - 0s - loss: 2.2749e-07 - val_loss: 1.0042e-07
Epoch 288/500
 - 0s - loss: 1.6662e-07 - val_loss: 1.2237e-07
Epoch 289/500
 - 0s - loss: 1.7570e-07 - val_loss: 1.0181e-07
Epoch 290/500
 - 0s - loss: 1.2422e-07 - val_loss: 2.4830e-07
Epoch 291/500
 - 0s - loss: 1.7114e-07 - val_loss: 1.0557e-07
Epoch 292/500
 - 0s - loss: 1.0727e-07 - val_loss: 1.7100e-07
Epoch 293/500
 - 0s - loss: 1.3452e-07 - val_loss: 6.6074e-08
Epoch 294/500
 - 0s - loss: 1.1503e-07 - val_loss: 1.5357e-07
Epoch 295/500
 - 0s - loss: 8.9203e-08 - val_loss: 1.5038e-07
Epoch 296/500
 - 0s - loss: 1.0612e-07 - val_loss: 1.6683e-07
Epoch 297/500
 - 0s - loss: 1.1093e-07 - val_loss: 3.9837e-08
Epoch 298/500
 - 0s - loss: 9.9395e-08 - val_loss: 1.4590e-07
Epoch 299/500
 - 0s - loss: 6.8804e-08 - val_loss: 2.0169e-08
Epoch 300/500
 - 0s - loss: 9.8749e-08 - val_loss: 1.6829e-07
Epoch 30

 - 0s - loss: 3.9425e-10 - val_loss: 8.2021e-10
Epoch 418/500
 - 0s - loss: 6.0133e-10 - val_loss: 8.0317e-10
Epoch 419/500
 - 0s - loss: 9.5190e-10 - val_loss: 8.0660e-10
Epoch 420/500
 - 0s - loss: 9.1095e-10 - val_loss: 1.4280e-10
Epoch 421/500
 - 0s - loss: 7.4018e-10 - val_loss: 8.0695e-10
Epoch 422/500
 - 0s - loss: 6.4474e-10 - val_loss: 8.0624e-10
Epoch 423/500
 - 0s - loss: 4.2398e-10 - val_loss: 8.0439e-10
Epoch 424/500
 - 0s - loss: 6.2480e-10 - val_loss: 1.3537e-10
Epoch 425/500
 - 0s - loss: 6.4495e-10 - val_loss: 1.4141e-10
Epoch 426/500
 - 0s - loss: 4.5340e-10 - val_loss: 1.4238e-10
Epoch 427/500
 - 0s - loss: 1.8405e-10 - val_loss: 1.2148e-09
Epoch 428/500
 - 0s - loss: 3.3607e-10 - val_loss: 1.4112e-10
Epoch 429/500
 - 0s - loss: 1.3187e-09 - val_loss: 1.3792e-10
Epoch 430/500
 - 0s - loss: 1.0254e-10 - val_loss: 1.3537e-10
Epoch 431/500
 - 0s - loss: 1.9051e-10 - val_loss: 1.2926e-10
Epoch 432/500
 - 0s - loss: 1.0608e-09 - val_loss: 1.3013e-10
Epoch 433/500
 - 0s - 

<keras.callbacks.History at 0x2e02f78f5c0>

In [12]:
scores = model.evaluate(X_test, Y_test, batch_size=128)
print("Mean Squared error : ", scores)

Mean Squared error :  1.0481016748942196e-10


## Multiplication

In [13]:
#Data Generation
fn = lambda x, y: x - y
X_train,Y_train,X_test, Y_test = generate_data(num_train=500, num_test=50,dim=100, num_sum=5, fn=fn)

In [14]:
# generate the model
ip = Input(shape=(2,))
x = NALU(units)(ip)
x = NALU(1)(x)

model = Model(ip, x)
model.summary()

optimizer = RMSprop(0.1)
model.compile(optimizer, 'mse')

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         (None, 2)                 0         
_________________________________________________________________
nalu_1 (NALU)                (None, 2)                 12        
_________________________________________________________________
nalu_2 (NALU)                (None, 1)                 6         
Total params: 18
Trainable params: 18
Non-trainable params: 0
_________________________________________________________________


In [15]:
# Train model
model.fit(X_train, Y_train, batch_size=64, epochs=500,
          verbose=2,validation_data=(X_test, Y_test))

Train on 500 samples, validate on 50 samples
Epoch 1/500
 - 0s - loss: 8012.1509 - val_loss: 7664.9277
Epoch 2/500
 - 0s - loss: 6690.0299 - val_loss: 56661.8789
Epoch 3/500
 - 0s - loss: 16199.3298 - val_loss: 3543.9387
Epoch 4/500
 - 0s - loss: 3644.6087 - val_loss: 2468.0857
Epoch 5/500
 - 0s - loss: 2559.2448 - val_loss: 2060.8779
Epoch 6/500
 - 0s - loss: 2224.0600 - val_loss: 2032.3004
Epoch 7/500
 - 0s - loss: 2118.4472 - val_loss: 1819.3683
Epoch 8/500
 - 0s - loss: 2023.0980 - val_loss: 1703.7301
Epoch 9/500
 - 0s - loss: 1940.9264 - val_loss: 1877.5830
Epoch 10/500
 - 0s - loss: 1666.5400 - val_loss: 1374.2887
Epoch 11/500
 - 0s - loss: 1434.9968 - val_loss: 1322.2197
Epoch 12/500
 - 0s - loss: 1372.8442 - val_loss: 766.7863
Epoch 13/500
 - 0s - loss: 1012.4784 - val_loss: 814.6803
Epoch 14/500
 - 0s - loss: 1047.6403 - val_loss: 438.4252
Epoch 15/500
 - 0s - loss: 549.5485 - val_loss: 680.6532
Epoch 16/500
 - 0s - loss: 835.3422 - val_loss: 286.8799
Epoch 17/500
 - 0s - loss

Epoch 142/500
 - 0s - loss: 6.5700e-09 - val_loss: 7.6479e-09
Epoch 143/500
 - 0s - loss: 6.7322e-09 - val_loss: 6.2370e-09
Epoch 144/500
 - 0s - loss: 7.2789e-09 - val_loss: 5.1147e-09
Epoch 145/500
 - 0s - loss: 7.0118e-09 - val_loss: 5.8100e-09
Epoch 146/500
 - 0s - loss: 5.2366e-09 - val_loss: 3.8289e-09
Epoch 147/500
 - 0s - loss: 7.2723e-09 - val_loss: 4.4256e-09
Epoch 148/500
 - 0s - loss: 5.2673e-09 - val_loss: 4.4394e-09
Epoch 149/500
 - 0s - loss: 5.1158e-09 - val_loss: 3.9884e-09
Epoch 150/500
 - 0s - loss: 4.8013e-09 - val_loss: 3.6558e-09
Epoch 151/500
 - 0s - loss: 4.9796e-09 - val_loss: 4.7927e-09
Epoch 152/500
 - 0s - loss: 4.9571e-09 - val_loss: 2.7958e-09
Epoch 153/500
 - 0s - loss: 4.3459e-09 - val_loss: 3.1173e-09
Epoch 154/500
 - 0s - loss: 4.1552e-09 - val_loss: 6.7256e-09
Epoch 155/500
 - 0s - loss: 4.2292e-09 - val_loss: 2.5991e-09
Epoch 156/500
 - 0s - loss: 4.1041e-09 - val_loss: 4.0883e-09
Epoch 157/500
 - 0s - loss: 4.0944e-09 - val_loss: 2.6254e-09
Epoch 15

Epoch 275/500
 - 0s - loss: 1.0888e-09 - val_loss: 7.8214e-10
Epoch 276/500
 - 0s - loss: 1.1086e-09 - val_loss: 6.4587e-10
Epoch 277/500
 - 0s - loss: 1.4643e-09 - val_loss: 6.7157e-10
Epoch 278/500
 - 0s - loss: 9.2432e-10 - val_loss: 1.4849e-09
Epoch 279/500
 - 0s - loss: 1.2515e-09 - val_loss: 6.8805e-10
Epoch 280/500
 - 0s - loss: 1.1673e-09 - val_loss: 7.7979e-10
Epoch 281/500
 - 0s - loss: 1.0904e-09 - val_loss: 1.4534e-09
Epoch 282/500
 - 0s - loss: 1.2353e-09 - val_loss: 8.8377e-10
Epoch 283/500
 - 0s - loss: 1.0705e-09 - val_loss: 7.6072e-10
Epoch 284/500
 - 0s - loss: 1.3967e-09 - val_loss: 1.5452e-09
Epoch 285/500
 - 0s - loss: 1.4766e-09 - val_loss: 6.3621e-10
Epoch 286/500
 - 0s - loss: 9.9347e-10 - val_loss: 1.3218e-09
Epoch 287/500
 - 0s - loss: 1.1682e-09 - val_loss: 6.8458e-10
Epoch 288/500
 - 0s - loss: 1.1856e-09 - val_loss: 1.5892e-09
Epoch 289/500
 - 0s - loss: 1.3103e-09 - val_loss: 1.5691e-09
Epoch 290/500
 - 0s - loss: 1.2452e-09 - val_loss: 8.2300e-10
Epoch 29

Epoch 408/500
 - 0s - loss: 9.0928e-10 - val_loss: 8.4092e-10
Epoch 409/500
 - 0s - loss: 7.9829e-10 - val_loss: 9.3833e-10
Epoch 410/500
 - 0s - loss: 8.9687e-10 - val_loss: 5.9599e-10
Epoch 411/500
 - 0s - loss: 7.7467e-10 - val_loss: 1.0492e-09
Epoch 412/500
 - 0s - loss: 8.7490e-10 - val_loss: 9.2139e-10
Epoch 413/500
 - 0s - loss: 8.1705e-10 - val_loss: 6.1361e-10
Epoch 414/500
 - 0s - loss: 8.0176e-10 - val_loss: 6.1349e-10
Epoch 415/500
 - 0s - loss: 8.5670e-10 - val_loss: 8.3476e-10
Epoch 416/500
 - 0s - loss: 8.7644e-10 - val_loss: 4.7149e-10
Epoch 417/500
 - 0s - loss: 8.8905e-10 - val_loss: 4.2805e-10
Epoch 418/500
 - 0s - loss: 7.7926e-10 - val_loss: 8.3185e-10
Epoch 419/500
 - 0s - loss: 7.5684e-10 - val_loss: 6.6402e-10
Epoch 420/500
 - 0s - loss: 9.2111e-10 - val_loss: 4.6026e-10
Epoch 421/500
 - 0s - loss: 8.0413e-10 - val_loss: 5.7216e-10
Epoch 422/500
 - 0s - loss: 8.2799e-10 - val_loss: 4.4137e-10
Epoch 423/500
 - 0s - loss: 7.6276e-10 - val_loss: 4.9297e-10
Epoch 42

<keras.callbacks.History at 0x2e02e4aa940>

In [16]:
scores = model.evaluate(X_test, Y_test, batch_size=128)
print("Mean Squared error : ", scores)

Mean Squared error :  3.3966313095490364e-10


## Division

In [17]:
#Data Generation
fn = lambda x, y: x - y
X_train,Y_train,X_test, Y_test = generate_data(num_train=500, num_test=50,dim=100, num_sum=5, fn=fn)

In [18]:
# generate the model
ip = Input(shape=(2,))
x = NALU(units)(ip)
x = NALU(1)(x)

model = Model(ip, x)
model.summary()

optimizer = RMSprop(0.1)
model.compile(optimizer, 'mse')

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_4 (InputLayer)         (None, 2)                 0         
_________________________________________________________________
nalu_3 (NALU)                (None, 2)                 12        
_________________________________________________________________
nalu_4 (NALU)                (None, 1)                 6         
Total params: 18
Trainable params: 18
Non-trainable params: 0
_________________________________________________________________


In [19]:
# Train model
model.fit(X_train, Y_train, batch_size=64, epochs=1000,
          verbose=2,validation_data=(X_test, Y_test))

Train on 500 samples, validate on 50 samples
Epoch 1/1000
 - 0s - loss: 8677.3013 - val_loss: 8616.3984
Epoch 2/1000
 - 0s - loss: 8677.2282 - val_loss: 8616.4141
Epoch 3/1000
 - 0s - loss: 8677.2255 - val_loss: 8616.4209
Epoch 4/1000
 - 0s - loss: 8677.2245 - val_loss: 8616.4277
Epoch 5/1000
 - 0s - loss: 8677.2230 - val_loss: 8616.4316
Epoch 6/1000
 - 0s - loss: 8677.2228 - val_loss: 8616.4336
Epoch 7/1000
 - 0s - loss: 8677.2226 - val_loss: 8616.4346
Epoch 8/1000
 - 0s - loss: 8677.2226 - val_loss: 8616.4365
Epoch 9/1000
 - 0s - loss: 8677.2221 - val_loss: 8616.4375
Epoch 10/1000
 - 0s - loss: 8677.2225 - val_loss: 8616.4375
Epoch 11/1000
 - 0s - loss: 8677.2220 - val_loss: 8616.4385
Epoch 12/1000
 - 0s - loss: 8677.2220 - val_loss: 8616.4385
Epoch 13/1000
 - 0s - loss: 8677.2220 - val_loss: 8616.4395
Epoch 14/1000
 - 0s - loss: 8677.2220 - val_loss: 8616.4395
Epoch 15/1000
 - 0s - loss: 8677.2220 - val_loss: 8616.4395
Epoch 16/1000
 - 0s - loss: 8677.2220 - val_loss: 8616.4375
Epoc

Epoch 141/1000
 - 0s - loss: 5.3776 - val_loss: 1.9601
Epoch 142/1000
 - 0s - loss: 3.2122 - val_loss: 6.3752
Epoch 143/1000
 - 0s - loss: 5.9778 - val_loss: 1.5048
Epoch 144/1000
 - 0s - loss: 2.3307 - val_loss: 4.2593
Epoch 145/1000
 - 0s - loss: 3.4106 - val_loss: 1.8728
Epoch 146/1000
 - 0s - loss: 2.4587 - val_loss: 1.3654
Epoch 147/1000
 - 0s - loss: 3.1657 - val_loss: 1.2175
Epoch 148/1000
 - 0s - loss: 1.7612 - val_loss: 2.4748
Epoch 149/1000
 - 0s - loss: 2.1454 - val_loss: 1.1359
Epoch 150/1000
 - 0s - loss: 1.2308 - val_loss: 1.7249
Epoch 151/1000
 - 0s - loss: 436.9644 - val_loss: 0.8689
Epoch 152/1000
 - 0s - loss: 1.3052 - val_loss: 0.8240
Epoch 153/1000
 - 0s - loss: 1.2320 - val_loss: 0.8021
Epoch 154/1000
 - 0s - loss: 1.7558 - val_loss: 0.8830
Epoch 155/1000
 - 0s - loss: 1.5948 - val_loss: 1.7612
Epoch 156/1000
 - 0s - loss: 1.6570 - val_loss: 0.9803
Epoch 157/1000
 - 0s - loss: 1.6516 - val_loss: 0.6662
Epoch 158/1000
 - 0s - loss: 1.1508 - val_loss: 1.1007
Epoch 15

Epoch 283/1000
 - 0s - loss: 5.0176e-06 - val_loss: 1.7381e-06
Epoch 284/1000
 - 0s - loss: 3.8987e-06 - val_loss: 4.2931e-06
Epoch 285/1000
 - 0s - loss: 1.5554e-06 - val_loss: 1.5450e-06
Epoch 286/1000
 - 0s - loss: 1.0196e-06 - val_loss: 1.3606e-06
Epoch 287/1000
 - 0s - loss: 4.7062e-06 - val_loss: 1.0213e-06
Epoch 288/1000
 - 0s - loss: 9.1242e-07 - val_loss: 1.2126e-06
Epoch 289/1000
 - 0s - loss: 2.3007e-06 - val_loss: 1.4861e-06
Epoch 290/1000
 - 0s - loss: 9.5191e-07 - val_loss: 6.6211e-07
Epoch 291/1000
 - 0s - loss: 1.6448e-06 - val_loss: 7.9723e-07
Epoch 292/1000
 - 0s - loss: 7.3695e-07 - val_loss: 5.8977e-07
Epoch 293/1000
 - 0s - loss: 8.0189e-07 - val_loss: 4.9185e-07
Epoch 294/1000
 - 0s - loss: 4.6069e-07 - val_loss: 1.4837e-06
Epoch 295/1000
 - 0s - loss: 9.6723e-07 - val_loss: 3.8570e-07
Epoch 296/1000
 - 0s - loss: 3.3163e-07 - val_loss: 7.5324e-07
Epoch 297/1000
 - 0s - loss: 5.5062e-07 - val_loss: 4.3120e-07
Epoch 298/1000
 - 0s - loss: 2.4442e-07 - val_loss: 1.8

Epoch 414/1000
 - 0s - loss: 9.0273e-10 - val_loss: 7.2472e-10
Epoch 415/1000
 - 0s - loss: 8.6780e-10 - val_loss: 7.7739e-10
Epoch 416/1000
 - 0s - loss: 8.9496e-10 - val_loss: 7.2877e-10
Epoch 417/1000
 - 0s - loss: 8.7516e-10 - val_loss: 7.3012e-10
Epoch 418/1000
 - 0s - loss: 8.5609e-10 - val_loss: 7.3391e-10
Epoch 419/1000
 - 0s - loss: 8.6853e-10 - val_loss: 7.4071e-10
Epoch 420/1000
 - 0s - loss: 8.6118e-10 - val_loss: 8.0745e-10
Epoch 421/1000
 - 0s - loss: 8.6664e-10 - val_loss: 7.3013e-10
Epoch 422/1000
 - 0s - loss: 1.4032e-09 - val_loss: 1.5715e-09
Epoch 423/1000
 - 0s - loss: 1.0800e-09 - val_loss: 8.1008e-10
Epoch 424/1000
 - 0s - loss: 7.1418e-10 - val_loss: 8.1285e-10
Epoch 425/1000
 - 0s - loss: 7.1414e-10 - val_loss: 8.1011e-10
Epoch 426/1000
 - 0s - loss: 7.1226e-10 - val_loss: 8.1000e-10
Epoch 427/1000
 - 0s - loss: 6.9849e-10 - val_loss: 8.1205e-10
Epoch 428/1000
 - 0s - loss: 6.9557e-10 - val_loss: 7.6016e-10
Epoch 429/1000
 - 0s - loss: 7.4802e-10 - val_loss: 8.6

Epoch 545/1000
 - 0s - loss: 4.0998e-10 - val_loss: 3.7202e-10
Epoch 546/1000
 - 0s - loss: 4.2837e-10 - val_loss: 3.7471e-10
Epoch 547/1000
 - 0s - loss: 4.0002e-10 - val_loss: 3.8045e-10
Epoch 548/1000
 - 0s - loss: 3.9811e-10 - val_loss: 3.7201e-10
Epoch 549/1000
 - 0s - loss: 3.9967e-10 - val_loss: 3.7552e-10
Epoch 550/1000
 - 0s - loss: 4.1635e-10 - val_loss: 3.5749e-10
Epoch 551/1000
 - 0s - loss: 3.8976e-10 - val_loss: 3.7532e-10
Epoch 552/1000
 - 0s - loss: 1.0242e-09 - val_loss: 9.6623e-10
Epoch 553/1000
 - 0s - loss: 5.6695e-10 - val_loss: 3.9693e-10
Epoch 554/1000
 - 0s - loss: 3.5106e-10 - val_loss: 3.7034e-10
Epoch 555/1000
 - 0s - loss: 3.6415e-10 - val_loss: 3.9396e-10
Epoch 556/1000
 - 0s - loss: 3.6087e-10 - val_loss: 3.9732e-10
Epoch 557/1000
 - 0s - loss: 3.6919e-10 - val_loss: 3.4833e-10
Epoch 558/1000
 - 0s - loss: 3.6915e-10 - val_loss: 3.4395e-10
Epoch 559/1000
 - 0s - loss: 3.5429e-10 - val_loss: 3.9094e-10
Epoch 560/1000
 - 0s - loss: 3.6432e-10 - val_loss: 3.9

Epoch 676/1000
 - 0s - loss: 4.5027e-10 - val_loss: 2.4636e-10
Epoch 677/1000
 - 0s - loss: 4.6511e-10 - val_loss: 2.4602e-10
Epoch 678/1000
 - 0s - loss: 4.8338e-10 - val_loss: 2.3633e-10
Epoch 679/1000
 - 0s - loss: 4.6168e-10 - val_loss: 2.1416e-10
Epoch 680/1000
 - 0s - loss: 4.4047e-10 - val_loss: 2.4347e-10
Epoch 681/1000
 - 0s - loss: 5.1932e-10 - val_loss: 2.3944e-10
Epoch 682/1000
 - 0s - loss: 5.1967e-10 - val_loss: 2.4468e-10
Epoch 683/1000
 - 0s - loss: 5.3081e-10 - val_loss: 2.6414e-10
Epoch 684/1000
 - 0s - loss: 4.4782e-10 - val_loss: 1.1817e-09
Epoch 685/1000
 - 0s - loss: 5.2078e-10 - val_loss: 1.0671e-09
Epoch 686/1000
 - 0s - loss: 5.2485e-10 - val_loss: 5.6127e-10
Epoch 687/1000
 - 0s - loss: 5.5163e-10 - val_loss: 2.1962e-10
Epoch 688/1000
 - 0s - loss: 4.2515e-10 - val_loss: 6.0464e-10
Epoch 689/1000
 - 0s - loss: 5.5242e-10 - val_loss: 2.4305e-10
Epoch 690/1000
 - 0s - loss: 4.9694e-10 - val_loss: 2.4703e-10
Epoch 691/1000
 - 0s - loss: 5.3969e-10 - val_loss: 2.3

Epoch 807/1000
 - 0s - loss: 1.7859e-10 - val_loss: 1.5296e-10
Epoch 808/1000
 - 0s - loss: 3.4474e-10 - val_loss: 1.6307e-10
Epoch 809/1000
 - 0s - loss: 1.7573e-10 - val_loss: 1.6654e-10
Epoch 810/1000
 - 0s - loss: 1.7662e-10 - val_loss: 1.7090e-10
Epoch 811/1000
 - 0s - loss: 3.1387e-10 - val_loss: 1.7359e-10
Epoch 812/1000
 - 0s - loss: 1.8390e-10 - val_loss: 1.5960e-10
Epoch 813/1000
 - 0s - loss: 1.8195e-10 - val_loss: 1.6302e-10
Epoch 814/1000
 - 0s - loss: 3.6974e-10 - val_loss: 1.6390e-10
Epoch 815/1000
 - 0s - loss: 1.7825e-10 - val_loss: 1.7049e-10
Epoch 816/1000
 - 0s - loss: 1.7704e-10 - val_loss: 1.7363e-10
Epoch 817/1000
 - 0s - loss: 3.4622e-10 - val_loss: 1.7090e-10
Epoch 818/1000
 - 0s - loss: 1.7402e-10 - val_loss: 1.6060e-10
Epoch 819/1000
 - 0s - loss: 2.6016e-10 - val_loss: 3.6735e-10
Epoch 820/1000
 - 0s - loss: 2.2929e-10 - val_loss: 1.6074e-10
Epoch 821/1000
 - 0s - loss: 1.7732e-10 - val_loss: 1.6373e-10
Epoch 822/1000
 - 0s - loss: 3.6480e-10 - val_loss: 1.5

Epoch 938/1000
 - 0s - loss: 1.1340e-10 - val_loss: 1.1928e-10
Epoch 939/1000
 - 0s - loss: 1.1255e-10 - val_loss: 1.2413e-10
Epoch 940/1000
 - 0s - loss: 1.1426e-10 - val_loss: 9.9444e-11
Epoch 941/1000
 - 0s - loss: 1.1090e-10 - val_loss: 1.1961e-10
Epoch 942/1000
 - 0s - loss: 1.1585e-10 - val_loss: 1.2801e-10
Epoch 943/1000
 - 0s - loss: 1.1500e-10 - val_loss: 1.3095e-10
Epoch 944/1000
 - 0s - loss: 1.1413e-10 - val_loss: 1.3094e-10
Epoch 945/1000
 - 0s - loss: 1.1385e-10 - val_loss: 1.3838e-10
Epoch 946/1000
 - 0s - loss: 1.1880e-10 - val_loss: 1.1961e-10
Epoch 947/1000
 - 0s - loss: 1.1205e-10 - val_loss: 1.3842e-10
Epoch 948/1000
 - 0s - loss: 1.1609e-10 - val_loss: 1.1350e-10
Epoch 949/1000
 - 0s - loss: 1.1679e-10 - val_loss: 1.0677e-10
Epoch 950/1000
 - 0s - loss: 1.1119e-10 - val_loss: 1.4385e-10
Epoch 951/1000
 - 0s - loss: 1.1225e-10 - val_loss: 1.3097e-10
Epoch 952/1000
 - 0s - loss: 1.0930e-10 - val_loss: 1.0579e-10
Epoch 953/1000
 - 0s - loss: 1.1507e-10 - val_loss: 1.3

<keras.callbacks.History at 0x2e02fdc0748>

In [20]:
scores = model.evaluate(X_test, Y_test, batch_size=128)
print("Mean Squared error : ", scores)

Mean Squared error :  1.1464832555541093e-10


## Square

In [21]:
#Data Generation
fn = lambda x, y: x*x
X_train,Y_train,X_test, Y_test = generate_data(num_train=500, num_test=50,dim=100, num_sum=5, fn=fn)

In [22]:
# generate the model
ip = Input(shape=(2,))
x = NALU(units)(ip)
x = NALU(1)(x)

model = Model(ip, x)
model.summary()

optimizer = RMSprop(0.1)
model.compile(optimizer, 'mse')

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_5 (InputLayer)         (None, 2)                 0         
_________________________________________________________________
nalu_5 (NALU)                (None, 2)                 12        
_________________________________________________________________
nalu_6 (NALU)                (None, 1)                 6         
Total params: 18
Trainable params: 18
Non-trainable params: 0
_________________________________________________________________


In [23]:
# Train model
model.fit(X_train, Y_train, batch_size=64, epochs=1000,
          verbose=2,validation_data=(X_test, Y_test))

Train on 500 samples, validate on 50 samples
Epoch 1/1000
 - 0s - loss: 5127175413.7600 - val_loss: 2295386368.0000
Epoch 2/1000
 - 0s - loss: 1342490277.3760 - val_loss: 528971520.0000
Epoch 3/1000
 - 0s - loss: 618174364.1600 - val_loss: 572756416.0000
Epoch 4/1000
 - 0s - loss: 654317163.5200 - val_loss: 1427265792.0000
Epoch 5/1000
 - 0s - loss: 822280160.7680 - val_loss: 494668672.0000
Epoch 6/1000
 - 0s - loss: 665335855.1040 - val_loss: 509367424.0000
Epoch 7/1000
 - 0s - loss: 649993331.2000 - val_loss: 672236928.0000
Epoch 8/1000
 - 0s - loss: 661349989.8880 - val_loss: 536960256.0000
Epoch 9/1000
 - 0s - loss: 615216164.8640 - val_loss: 599231168.0000
Epoch 10/1000
 - 0s - loss: 601141562.3680 - val_loss: 462123136.0000
Epoch 11/1000
 - 0s - loss: 551805747.7120 - val_loss: 526261248.0000
Epoch 12/1000
 - 0s - loss: 685392365.5680 - val_loss: 696607872.0000
Epoch 13/1000
 - 0s - loss: 573058887.1680 - val_loss: 449845120.0000
Epoch 14/1000
 - 0s - loss: 564686653.4400 - val_l

Epoch 119/1000
 - 0s - loss: 9398052.3360 - val_loss: 9468748.0000
Epoch 120/1000
 - 0s - loss: 5065874.4000 - val_loss: 5440928.5000
Epoch 121/1000
 - 0s - loss: 12458745.6625 - val_loss: 944381.1250
Epoch 122/1000
 - 0s - loss: 3413284.4000 - val_loss: 13346239.0000
Epoch 123/1000
 - 0s - loss: 10393730.7420 - val_loss: 4676960.5000
Epoch 124/1000
 - 0s - loss: 4174562.3280 - val_loss: 8485647.0000
Epoch 125/1000
 - 0s - loss: 9903422.5660 - val_loss: 1347225.1250
Epoch 126/1000
 - 0s - loss: 985251.7890 - val_loss: 1748186.1250
Epoch 127/1000
 - 0s - loss: 10474273.2440 - val_loss: 1680644.5000
Epoch 128/1000
 - 0s - loss: 1394515.5720 - val_loss: 3756218.5000
Epoch 129/1000
 - 0s - loss: 10029265.8775 - val_loss: 647702.5000
Epoch 130/1000
 - 0s - loss: 735519.7010 - val_loss: 2060672.1250
Epoch 131/1000
 - 0s - loss: 7602471.0320 - val_loss: 5124438.0000
Epoch 132/1000
 - 0s - loss: 3458661.7930 - val_loss: 1736297.8750
Epoch 133/1000
 - 0s - loss: 7015318.2800 - val_loss: 1778810

Epoch 246/1000
 - 0s - loss: 22259.5962 - val_loss: 42740.2188
Epoch 247/1000
 - 0s - loss: 30112.8633 - val_loss: 19709.5801
Epoch 248/1000
 - 0s - loss: 33749.5635 - val_loss: 14599.7627
Epoch 249/1000
 - 0s - loss: 19615.4267 - val_loss: 15279.5215
Epoch 250/1000
 - 0s - loss: 32593.8207 - val_loss: 20311.8594
Epoch 251/1000
 - 0s - loss: 23182.9635 - val_loss: 18024.6895
Epoch 252/1000
 - 0s - loss: 28480.5812 - val_loss: 15168.4111
Epoch 253/1000
 - 0s - loss: 18997.5994 - val_loss: 13624.5967
Epoch 254/1000
 - 0s - loss: 26770.3859 - val_loss: 13857.9238
Epoch 255/1000
 - 0s - loss: 23857.9433 - val_loss: 19944.5273
Epoch 256/1000
 - 0s - loss: 26639.2535 - val_loss: 15805.7412
Epoch 257/1000
 - 0s - loss: 20525.7577 - val_loss: 16598.6602
Epoch 258/1000
 - 0s - loss: 23954.3887 - val_loss: 13871.1602
Epoch 259/1000
 - 0s - loss: 19485.8471 - val_loss: 30080.0801
Epoch 260/1000
 - 0s - loss: 29360.3496 - val_loss: 16650.6543
Epoch 261/1000
 - 0s - loss: 21531.7992 - val_loss: 132

Epoch 377/1000
 - 0s - loss: 20751.0632 - val_loss: 21927.6426
Epoch 378/1000
 - 0s - loss: 21028.0210 - val_loss: 13320.5342
Epoch 379/1000
 - 0s - loss: 20059.9285 - val_loss: 23538.6191
Epoch 380/1000
 - 0s - loss: 21175.6399 - val_loss: 13362.7852
Epoch 381/1000
 - 0s - loss: 19761.8967 - val_loss: 14058.0527
Epoch 382/1000
 - 0s - loss: 21161.2919 - val_loss: 13431.8760
Epoch 383/1000
 - 0s - loss: 21565.3588 - val_loss: 13526.2373
Epoch 384/1000
 - 0s - loss: 19381.5363 - val_loss: 13638.4688
Epoch 385/1000
 - 0s - loss: 19886.9410 - val_loss: 12981.2549
Epoch 386/1000
 - 0s - loss: 20479.6147 - val_loss: 13147.6104
Epoch 387/1000
 - 0s - loss: 19868.3638 - val_loss: 12928.6025
Epoch 388/1000
 - 0s - loss: 22482.4467 - val_loss: 12863.7363
Epoch 389/1000
 - 0s - loss: 19753.4166 - val_loss: 13314.7471
Epoch 390/1000
 - 0s - loss: 19326.1611 - val_loss: 13170.9834
Epoch 391/1000
 - 0s - loss: 20174.3295 - val_loss: 16371.1416
Epoch 392/1000
 - 0s - loss: 19142.7125 - val_loss: 155

Epoch 508/1000
 - 0s - loss: 19690.3786 - val_loss: 13963.1133
Epoch 509/1000
 - 0s - loss: 18911.9741 - val_loss: 13238.1592
Epoch 510/1000
 - 0s - loss: 19669.8245 - val_loss: 13210.5488
Epoch 511/1000
 - 0s - loss: 20191.0114 - val_loss: 13949.6924
Epoch 512/1000
 - 0s - loss: 21373.8377 - val_loss: 15034.1816
Epoch 513/1000
 - 0s - loss: 20390.9682 - val_loss: 13997.5088
Epoch 514/1000
 - 0s - loss: 20152.6442 - val_loss: 13734.3721
Epoch 515/1000
 - 0s - loss: 19023.4627 - val_loss: 13454.5137
Epoch 516/1000
 - 0s - loss: 21109.3895 - val_loss: 13575.7676
Epoch 517/1000
 - 0s - loss: 20673.6378 - val_loss: 14032.4766
Epoch 518/1000
 - 0s - loss: 19592.8453 - val_loss: 14000.5898
Epoch 519/1000
 - 0s - loss: 19669.1967 - val_loss: 14387.0879
Epoch 520/1000
 - 0s - loss: 21378.7916 - val_loss: 13817.6621
Epoch 521/1000
 - 0s - loss: 19602.9749 - val_loss: 13318.2217
Epoch 522/1000
 - 0s - loss: 20534.4752 - val_loss: 22835.5273
Epoch 523/1000
 - 0s - loss: 22101.5906 - val_loss: 132

Epoch 639/1000
 - 0s - loss: 19902.3590 - val_loss: 15397.1973
Epoch 640/1000
 - 0s - loss: 20956.3850 - val_loss: 16717.7207
Epoch 641/1000
 - 0s - loss: 20354.5264 - val_loss: 13532.1777
Epoch 642/1000
 - 0s - loss: 22044.0085 - val_loss: 13658.1152
Epoch 643/1000
 - 0s - loss: 19824.5411 - val_loss: 13077.8945
Epoch 644/1000
 - 0s - loss: 21829.1817 - val_loss: 16851.5293
Epoch 645/1000
 - 0s - loss: 20506.9714 - val_loss: 13078.0361
Epoch 646/1000
 - 0s - loss: 19680.5063 - val_loss: 13640.5273
Epoch 647/1000
 - 0s - loss: 21201.0493 - val_loss: 13211.6084
Epoch 648/1000
 - 0s - loss: 20325.5816 - val_loss: 13391.7539
Epoch 649/1000
 - 0s - loss: 21657.2894 - val_loss: 22276.8477
Epoch 650/1000
 - 0s - loss: 20224.6896 - val_loss: 13975.4434
Epoch 651/1000
 - 0s - loss: 19534.0937 - val_loss: 14154.5410
Epoch 652/1000
 - 0s - loss: 20524.6414 - val_loss: 26838.4141
Epoch 653/1000
 - 0s - loss: 20864.9513 - val_loss: 13456.2041
Epoch 654/1000
 - 0s - loss: 20763.6752 - val_loss: 134

Epoch 770/1000
 - 0s - loss: 20631.5943 - val_loss: 13780.7197
Epoch 771/1000
 - 0s - loss: 21180.8518 - val_loss: 16016.5137
Epoch 772/1000
 - 0s - loss: 20732.8186 - val_loss: 13306.7734
Epoch 773/1000
 - 0s - loss: 20811.1388 - val_loss: 15331.2305
Epoch 774/1000
 - 0s - loss: 19679.2425 - val_loss: 14955.3877
Epoch 775/1000
 - 0s - loss: 20280.4970 - val_loss: 13403.9785
Epoch 776/1000
 - 0s - loss: 19850.9292 - val_loss: 16653.1523
Epoch 777/1000
 - 0s - loss: 21619.3882 - val_loss: 17537.4980
Epoch 778/1000
 - 0s - loss: 20557.2000 - val_loss: 14195.0967
Epoch 779/1000
 - 0s - loss: 20214.3250 - val_loss: 13217.1572
Epoch 780/1000
 - 0s - loss: 19568.1443 - val_loss: 13274.8623
Epoch 781/1000
 - 0s - loss: 20648.2910 - val_loss: 16086.3672
Epoch 782/1000
 - 0s - loss: 21175.4780 - val_loss: 14765.9736
Epoch 783/1000
 - 0s - loss: 19269.8183 - val_loss: 14381.8652
Epoch 784/1000
 - 0s - loss: 20893.7916 - val_loss: 18065.0547
Epoch 785/1000
 - 0s - loss: 22620.1096 - val_loss: 135

Epoch 901/1000
 - 0s - loss: 20824.7459 - val_loss: 13186.3926
Epoch 902/1000
 - 0s - loss: 19819.7511 - val_loss: 13628.4688
Epoch 903/1000
 - 0s - loss: 19706.6964 - val_loss: 25611.7598
Epoch 904/1000
 - 0s - loss: 20869.1530 - val_loss: 15121.9971
Epoch 905/1000
 - 0s - loss: 21721.5437 - val_loss: 17610.8867
Epoch 906/1000
 - 0s - loss: 19720.8282 - val_loss: 15419.0195
Epoch 907/1000
 - 0s - loss: 20819.4180 - val_loss: 13247.9229
Epoch 908/1000
 - 0s - loss: 21120.1945 - val_loss: 14253.2910
Epoch 909/1000
 - 0s - loss: 20790.7762 - val_loss: 13936.3271
Epoch 910/1000
 - 0s - loss: 19771.2827 - val_loss: 14705.4385
Epoch 911/1000
 - 0s - loss: 20706.4268 - val_loss: 13450.3896
Epoch 912/1000
 - 0s - loss: 20394.6987 - val_loss: 13555.9199
Epoch 913/1000
 - 0s - loss: 22157.3025 - val_loss: 13895.9570
Epoch 914/1000
 - 0s - loss: 20148.7686 - val_loss: 13280.0889
Epoch 915/1000
 - 0s - loss: 19780.3194 - val_loss: 15415.2803
Epoch 916/1000
 - 0s - loss: 21159.2664 - val_loss: 135

<keras.callbacks.History at 0x2e0311d8e48>

In [24]:
scores = model.evaluate(X_test, Y_test, batch_size=128)
print("Mean Squared error : ", scores)

Mean Squared error :  12942.78515625


## SquareRoot

In [25]:
#Data Generation
fn = lambda x, y: np.sqrt(x)
X_train,Y_train,X_test, Y_test = generate_data(num_train=500, num_test=50,dim=100, num_sum=5, fn=fn)

In [26]:
# generate the model
ip = Input(shape=(2,))
x = NALU(units)(ip)
x = NALU(1)(x)

model = Model(ip, x)
model.summary()

optimizer = RMSprop(0.1)
model.compile(optimizer, 'mse')

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_6 (InputLayer)         (None, 2)                 0         
_________________________________________________________________
nalu_7 (NALU)                (None, 2)                 12        
_________________________________________________________________
nalu_8 (NALU)                (None, 1)                 6         
Total params: 18
Trainable params: 18
Non-trainable params: 0
_________________________________________________________________


In [27]:
# Train model
model.fit(X_train, Y_train, batch_size=64, epochs=1000,
          verbose=2,validation_data=(X_test, Y_test))

Train on 500 samples, validate on 50 samples
Epoch 1/1000
 - 0s - loss: 226.1798 - val_loss: 16.6493
Epoch 2/1000
 - 0s - loss: 13.9168 - val_loss: 7.4314
Epoch 3/1000
 - 0s - loss: 10.4187 - val_loss: 6.0464
Epoch 4/1000
 - 0s - loss: 8.0741 - val_loss: 6.7963
Epoch 5/1000
 - 0s - loss: 8.0076 - val_loss: 4.2951
Epoch 6/1000
 - 0s - loss: 8.4123 - val_loss: 2.9075
Epoch 7/1000
 - 0s - loss: 8.6713 - val_loss: 15.2789
Epoch 8/1000
 - 0s - loss: 7.4999 - val_loss: 8.6747
Epoch 9/1000
 - 0s - loss: 8.6051 - val_loss: 4.9712
Epoch 10/1000
 - 0s - loss: 8.8856 - val_loss: 5.8948
Epoch 11/1000
 - 0s - loss: 7.5010 - val_loss: 9.4879
Epoch 12/1000
 - 0s - loss: 11.3245 - val_loss: 5.9011
Epoch 13/1000
 - 0s - loss: 6.4124 - val_loss: 5.8721
Epoch 14/1000
 - 0s - loss: 5.8777 - val_loss: 6.1588
Epoch 15/1000
 - 0s - loss: 7.7243 - val_loss: 3.6264
Epoch 16/1000
 - 0s - loss: 6.9761 - val_loss: 4.5992
Epoch 17/1000
 - 0s - loss: 20.0196 - val_loss: 1.6499
Epoch 18/1000
 - 0s - loss: 1.6407 - v

Epoch 151/1000
 - 0s - loss: 1.1205 - val_loss: 0.7976
Epoch 152/1000
 - 0s - loss: 1.8793 - val_loss: 1.8071
Epoch 153/1000
 - 0s - loss: 0.8562 - val_loss: 1.4016
Epoch 154/1000
 - 0s - loss: 1.8659 - val_loss: 0.9528
Epoch 155/1000
 - 0s - loss: 1.1536 - val_loss: 1.6779
Epoch 156/1000
 - 0s - loss: 1.4056 - val_loss: 2.1629
Epoch 157/1000
 - 0s - loss: 1.3695 - val_loss: 1.2291
Epoch 158/1000
 - 0s - loss: 1.2943 - val_loss: 2.5279
Epoch 159/1000
 - 0s - loss: 1.3119 - val_loss: 1.2659
Epoch 160/1000
 - 0s - loss: 1.5187 - val_loss: 1.1797
Epoch 161/1000
 - 0s - loss: 1.1467 - val_loss: 2.1051
Epoch 162/1000
 - 0s - loss: 1.4860 - val_loss: 0.8004
Epoch 163/1000
 - 0s - loss: 1.2306 - val_loss: 1.7197
Epoch 164/1000
 - 0s - loss: 1.2971 - val_loss: 1.2460
Epoch 165/1000
 - 0s - loss: 1.3667 - val_loss: 1.4917
Epoch 166/1000
 - 0s - loss: 1.2708 - val_loss: 1.6631
Epoch 167/1000
 - 0s - loss: 1.2192 - val_loss: 1.8980
Epoch 168/1000
 - 0s - loss: 1.4020 - val_loss: 1.2767
Epoch 169/

Epoch 300/1000
 - 0s - loss: 1.1224 - val_loss: 1.9570
Epoch 301/1000
 - 0s - loss: 1.1508 - val_loss: 1.2825
Epoch 302/1000
 - 0s - loss: 1.2482 - val_loss: 1.0020
Epoch 303/1000
 - 0s - loss: 1.1928 - val_loss: 1.5024
Epoch 304/1000
 - 0s - loss: 1.1226 - val_loss: 1.5867
Epoch 305/1000
 - 0s - loss: 1.0818 - val_loss: 1.6832
Epoch 306/1000
 - 0s - loss: 1.3252 - val_loss: 0.7162
Epoch 307/1000
 - 0s - loss: 1.1940 - val_loss: 1.2650
Epoch 308/1000
 - 0s - loss: 1.1718 - val_loss: 1.9720
Epoch 309/1000
 - 0s - loss: 1.0103 - val_loss: 0.9877
Epoch 310/1000
 - 0s - loss: 1.2818 - val_loss: 1.2635
Epoch 311/1000
 - 0s - loss: 1.1922 - val_loss: 1.3106
Epoch 312/1000
 - 0s - loss: 1.2202 - val_loss: 1.5483
Epoch 313/1000
 - 0s - loss: 1.1642 - val_loss: 0.9972
Epoch 314/1000
 - 0s - loss: 1.2059 - val_loss: 1.6166
Epoch 315/1000
 - 0s - loss: 1.1837 - val_loss: 1.1211
Epoch 316/1000
 - 0s - loss: 1.1419 - val_loss: 0.9588
Epoch 317/1000
 - 0s - loss: 1.2334 - val_loss: 1.1991
Epoch 318/

Epoch 449/1000
 - 0s - loss: 1.2265 - val_loss: 0.7230
Epoch 450/1000
 - 0s - loss: 1.1625 - val_loss: 1.5850
Epoch 451/1000
 - 0s - loss: 1.1377 - val_loss: 1.5895
Epoch 452/1000
 - 0s - loss: 1.1703 - val_loss: 2.2504
Epoch 453/1000
 - 0s - loss: 1.1250 - val_loss: 2.4948
Epoch 454/1000
 - 0s - loss: 1.1617 - val_loss: 1.3017
Epoch 455/1000
 - 0s - loss: 1.2871 - val_loss: 1.6562
Epoch 456/1000
 - 0s - loss: 1.0530 - val_loss: 1.3249
Epoch 457/1000
 - 0s - loss: 1.2402 - val_loss: 0.9530
Epoch 458/1000
 - 0s - loss: 1.1981 - val_loss: 1.8907
Epoch 459/1000
 - 0s - loss: 1.0936 - val_loss: 1.8602
Epoch 460/1000
 - 0s - loss: 1.2279 - val_loss: 1.1232
Epoch 461/1000
 - 0s - loss: 1.2462 - val_loss: 1.1069
Epoch 462/1000
 - 0s - loss: 1.0614 - val_loss: 1.5148
Epoch 463/1000
 - 0s - loss: 1.2401 - val_loss: 1.3002
Epoch 464/1000
 - 0s - loss: 1.1922 - val_loss: 1.0942
Epoch 465/1000
 - 0s - loss: 1.1813 - val_loss: 1.8609
Epoch 466/1000
 - 0s - loss: 1.1538 - val_loss: 0.7465
Epoch 467/

Epoch 598/1000
 - 0s - loss: 1.2731 - val_loss: 0.7514
Epoch 599/1000
 - 0s - loss: 1.1079 - val_loss: 1.2546
Epoch 600/1000
 - 0s - loss: 1.0506 - val_loss: 1.5075
Epoch 601/1000
 - 0s - loss: 1.2781 - val_loss: 0.9668
Epoch 602/1000
 - 0s - loss: 1.2193 - val_loss: 1.0491
Epoch 603/1000
 - 0s - loss: 1.1234 - val_loss: 1.6630
Epoch 604/1000
 - 0s - loss: 1.1151 - val_loss: 1.2287
Epoch 605/1000
 - 0s - loss: 1.3413 - val_loss: 1.1174
Epoch 606/1000
 - 0s - loss: 1.1627 - val_loss: 0.7038
Epoch 607/1000
 - 0s - loss: 1.1642 - val_loss: 1.1405
Epoch 608/1000
 - 0s - loss: 1.1921 - val_loss: 0.9128
Epoch 609/1000
 - 0s - loss: 1.1032 - val_loss: 2.3561
Epoch 610/1000
 - 0s - loss: 1.2531 - val_loss: 0.8115
Epoch 611/1000
 - 0s - loss: 1.1368 - val_loss: 1.4406
Epoch 612/1000
 - 0s - loss: 1.1370 - val_loss: 1.6987
Epoch 613/1000
 - 0s - loss: 1.3631 - val_loss: 1.0859
Epoch 614/1000
 - 0s - loss: 1.0397 - val_loss: 2.0141
Epoch 615/1000
 - 0s - loss: 1.1092 - val_loss: 1.1690
Epoch 616/

Epoch 747/1000
 - 0s - loss: 1.2376 - val_loss: 1.8167
Epoch 748/1000
 - 0s - loss: 1.1690 - val_loss: 0.7764
Epoch 749/1000
 - 0s - loss: 1.2175 - val_loss: 1.7203
Epoch 750/1000
 - 0s - loss: 1.1202 - val_loss: 1.2368
Epoch 751/1000
 - 0s - loss: 1.3203 - val_loss: 1.0104
Epoch 752/1000
 - 0s - loss: 1.0317 - val_loss: 1.8284
Epoch 753/1000
 - 0s - loss: 1.2138 - val_loss: 1.0677
Epoch 754/1000
 - 0s - loss: 1.2509 - val_loss: 1.3406
Epoch 755/1000
 - 0s - loss: 1.1015 - val_loss: 1.0671
Epoch 756/1000
 - 0s - loss: 1.1832 - val_loss: 1.2323
Epoch 757/1000
 - 0s - loss: 1.2176 - val_loss: 1.8106
Epoch 758/1000
 - 0s - loss: 1.1210 - val_loss: 1.8292
Epoch 759/1000
 - 0s - loss: 1.1947 - val_loss: 1.4154
Epoch 760/1000
 - 0s - loss: 1.1418 - val_loss: 1.0543
Epoch 761/1000
 - 0s - loss: 1.2752 - val_loss: 1.4815
Epoch 762/1000
 - 0s - loss: 1.0364 - val_loss: 1.3510
Epoch 763/1000
 - 0s - loss: 1.3625 - val_loss: 1.2462
Epoch 764/1000
 - 0s - loss: 1.0533 - val_loss: 2.5435
Epoch 765/

Epoch 896/1000
 - 0s - loss: 1.2257 - val_loss: 1.5861
Epoch 897/1000
 - 0s - loss: 1.0807 - val_loss: 1.6231
Epoch 898/1000
 - 0s - loss: 1.2278 - val_loss: 0.6814
Epoch 899/1000
 - 0s - loss: 1.2849 - val_loss: 1.1830
Epoch 900/1000
 - 0s - loss: 1.0576 - val_loss: 1.5866
Epoch 901/1000
 - 0s - loss: 1.1630 - val_loss: 1.2053
Epoch 902/1000
 - 0s - loss: 1.2779 - val_loss: 1.2584
Epoch 903/1000
 - 0s - loss: 0.9788 - val_loss: 1.4656
Epoch 904/1000
 - 0s - loss: 1.4409 - val_loss: 1.2773
Epoch 905/1000
 - 0s - loss: 0.9705 - val_loss: 2.1126
Epoch 906/1000
 - 0s - loss: 1.2690 - val_loss: 1.0111
Epoch 907/1000
 - 0s - loss: 1.1111 - val_loss: 1.5803
Epoch 908/1000
 - 0s - loss: 1.2570 - val_loss: 0.7638
Epoch 909/1000
 - 0s - loss: 1.1494 - val_loss: 2.1544
Epoch 910/1000
 - 0s - loss: 1.1055 - val_loss: 1.2228
Epoch 911/1000
 - 0s - loss: 1.3760 - val_loss: 1.1446
Epoch 912/1000
 - 0s - loss: 1.0330 - val_loss: 1.9763
Epoch 913/1000
 - 0s - loss: 1.1306 - val_loss: 1.2574
Epoch 914/

<keras.callbacks.History at 0x2e02e4aa898>

In [28]:
scores = model.evaluate(X_test, Y_test, batch_size=128)
print("Mean Squared error : ", scores)

Mean Squared error :  1.047717571258545
