In [1]:
from nalu import NAC, NALU, INALU
from utils import train_model, train_inalu_model
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch import optim
import random
import torch
import pandas as pd
import matplotlib.pyplot as plt

In [2]:
class SyntheticDataset(Dataset):
    """In-memory dataset of (input pair, scalar target) samples.

    ``X`` is reshaped to ``(n, 2)`` and ``y`` to ``(n, 1)``; item ``i`` is
    ``(X[i], y[i])``.
    """

    def __init__(self, X, y):
        """
        Args:
            X: tensor holding n input pairs, e.g. shape (n, 2) or flat (2n,).
            y: tensor holding n targets, e.g. shape (n,) or (n, 1).

        Raises:
            ValueError: if X and y do not describe the same number of samples.
        """
        self.X = X.view(-1, 2)
        self.y = y.view(-1, 1)
        # Validate after reshaping: the row counts are what __getitem__ relies
        # on, and unlike `assert` this check survives `python -O`.
        if self.X.shape[0] != self.y.shape[0]:
            raise ValueError(
                f'X has {self.X.shape[0]} samples but y has {self.y.shape[0]}')

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [3]:
# Data-generation settings shared by both generators.
DIM = 100               # size of the pool of random scalars that inputs are summed from
NUM_SUM = 5             # pool entries summed into each of the two inputs (a and b)
RANGE_TRAIN = [-5, 10]  # uniform range of the training pool (also the interpolation test pool)
RANGE_TEST = [-10, 15]  # uniform range of the extrapolation test pool
# NOTE(review): RANGE_TEST contains RANGE_TRAIN, so many "extrapolation" test
# inputs still fall inside the training input range.
NUM_TRAIN = 500
NUM_TEST = 50

# Seed every RNG the notebook uses (python `random` for index sampling, torch
# for pools, weight init and shuffling) so Restart & Run All is reproducible.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

In [4]:
def generate_data_interpolation(fn, num_train=None, num_test=None, dim=None,
                                num_sum=None, value_range=None):
    """Build train/test splits that both come from the same value range.

    A single pool of ``dim`` scalars is drawn uniformly from ``value_range``.
    Every sample takes two disjoint random index sets of size ``num_sum``
    from the pool; their sums are the inputs ``(a, b)`` and the target is
    ``fn(a, b)``. The first ``num_train`` samples form the training split and
    the remaining ``num_test`` form the test split.

    Args:
        fn: callable taking two 0-dim float32 tensors ``(a, b)``; its result
            must be convertible with ``float()`` (see ARITHMETIC_FUNCTIONS).
        num_train, num_test, dim, num_sum, value_range: optional overrides of
            NUM_TRAIN, NUM_TEST, DIM, NUM_SUM and RANGE_TRAIN; ``None`` (the
            default) uses the notebook constant.

    Returns:
        (X_train, y_train, X_test, y_test): float32 tensors, ``X_*`` of shape
        ``(n, 2)`` and ``y_*`` of shape ``(n,)``.
    """
    num_train = NUM_TRAIN if num_train is None else num_train
    num_test = NUM_TEST if num_test is None else num_test
    dim = DIM if dim is None else dim
    num_sum = NUM_SUM if num_sum is None else num_sum
    value_range = RANGE_TRAIN if value_range is None else value_range

    # Explicit float32 tensor instead of the legacy torch.FloatTensor(n) constructor.
    pool = torch.empty(dim, dtype=torch.float32).uniform_(*value_range).unsqueeze_(1)
    inputs, targets = [], []
    for _ in range(num_train + num_test):
        idx_a = random.sample(range(dim), num_sum)
        # b's indices exclude a's, so the two sums use disjoint pool entries.
        idx_b = random.sample([i for i in range(dim) if i not in idx_a], num_sum)
        a, b = pool[idx_a].sum(), pool[idx_b].sum()
        # Keep plain floats so torch.tensor() below builds from numbers,
        # not from a nested list of 0-dim tensors.
        inputs.append([float(a), float(b)])
        targets.append(float(fn(a, b)))
    # reshape keeps the (n, 2) shape even when no samples were requested.
    X = torch.tensor(inputs, dtype=torch.float32).reshape(-1, 2)
    y = torch.tensor(targets, dtype=torch.float32)
    return X[:num_train], y[:num_train], X[num_train:], y[num_train:]

In [5]:
def generate_data_extrapolation(fn, num_train=None, num_test=None, dim=None,
                                num_sum=None, train_range=None, test_range=None):
    """Build train/test splits drawn from two different value ranges.

    The training split uses a pool of ``dim`` scalars drawn uniformly from
    ``train_range``; the test split uses a fresh pool drawn from
    ``test_range``. Each sample sums two disjoint random index sets of size
    ``num_sum`` into inputs ``(a, b)``, with target ``fn(a, b)``.

    Args:
        fn: callable taking two 0-dim float32 tensors ``(a, b)``; its result
            must be convertible with ``float()`` (see ARITHMETIC_FUNCTIONS).
        num_train, num_test, dim, num_sum, train_range, test_range: optional
            overrides of NUM_TRAIN, NUM_TEST, DIM, NUM_SUM, RANGE_TRAIN and
            RANGE_TEST; ``None`` (the default) uses the notebook constant.

    Returns:
        (X_train, y_train, X_test, y_test): float32 tensors, ``X_*`` of shape
        ``(n, 2)`` and ``y_*`` of shape ``(n,)``.
    """
    num_train = NUM_TRAIN if num_train is None else num_train
    num_test = NUM_TEST if num_test is None else num_test
    dim = DIM if dim is None else dim
    num_sum = NUM_SUM if num_sum is None else num_sum
    train_range = RANGE_TRAIN if train_range is None else train_range
    test_range = RANGE_TEST if test_range is None else test_range

    def _sample_split(value_range, count):
        # Draw a fresh pool from value_range and build `count` samples from it
        # (replaces the loop that was previously copy-pasted for each split).
        pool = torch.empty(dim, dtype=torch.float32).uniform_(*value_range).unsqueeze_(1)
        inputs, targets = [], []
        for _ in range(count):
            idx_a = random.sample(range(dim), num_sum)
            # b's indices exclude a's, so the two sums use disjoint pool entries.
            idx_b = random.sample([i for i in range(dim) if i not in idx_a], num_sum)
            a, b = pool[idx_a].sum(), pool[idx_b].sum()
            inputs.append([float(a), float(b)])
            targets.append(float(fn(a, b)))
        # reshape keeps the (n, 2) shape even when count == 0.
        X = torch.tensor(inputs, dtype=torch.float32).reshape(-1, 2)
        y = torch.tensor(targets, dtype=torch.float32)
        return X, y

    # Train first, then test: same RNG consumption order as before.
    X_train, y_train = _sample_split(train_range, num_train)
    X_test, y_test = _sample_split(test_range, num_test)
    return X_train, y_train, X_test, y_test

In [6]:
class MLP(nn.Module):
    """Two stacked linear layers (2 -> 2 -> 1) with no activation in between.

    With no non-linearity the whole network is an affine map of its inputs.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        layers = [nn.Linear(2, 2), nn.Linear(2, 1)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map a ``(..., 2)`` tensor to a ``(..., 1)`` tensor."""
        prediction = self.model(x)
        return prediction

In [7]:
class MLPReLU6(nn.Module):
    """2 -> 2 -> 1 MLP whose hidden layer is passed through ReLU6."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        hidden = nn.Linear(2, 2)
        activation = nn.ReLU6()
        head = nn.Linear(2, 1)
        self.model = nn.Sequential(hidden, activation, head)

    def forward(self, x):
        """Map a ``(..., 2)`` tensor to a ``(..., 1)`` tensor."""
        prediction = self.model(x)
        return prediction

In [8]:
class NACNet(nn.Module):
    """Two stacked NAC layers: 2 inputs -> 2 hidden -> 1 output."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        first, second = NAC(2, 2), NAC(2, 1)
        self.model = nn.Sequential(first, second)

    def forward(self, x):
        """Run ``x`` through both NAC layers in order."""
        prediction = self.model(x)
        return prediction

In [9]:
class NALUNet(nn.Module):
    """Two stacked NALU layers: 2 inputs -> 2 hidden -> 1 output."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        first, second = NALU(2, 2), NALU(2, 1)
        self.model = nn.Sequential(first, second)

    def forward(self, x):
        """Run ``x`` through both NALU layers in order."""
        prediction = self.model(x)
        return prediction

In [10]:
class INALUNet(nn.Module):
    """Two stacked INALU layers (2 -> 2 -> 1).

    Besides ``forward`` it exposes ``reinitialize`` and ``reg_loss``, which
    simply delegate to both layers.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.inalu1 = INALU(2, 2)
        self.inalu2 = INALU(2, 1)

    def forward(self, x):
        """Apply the first INALU layer, then the second."""
        hidden = self.inalu1(x)
        output = self.inalu2(hidden)
        return output

    def reinitialize(self):
        """Call ``reinitialize()`` on each INALU layer, first then second."""
        for layer in (self.inalu1, self.inalu2):
            layer.reinitialize()

    def reg_loss(self):
        """Sum of the two layers' ``reg_loss()`` values."""
        first = self.inalu1.reg_loss()
        second = self.inalu2.reg_loss()
        return first + second

In [11]:
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'


def create_models():
    """Instantiate one fresh copy of every compared model, moved to DEVICE.

    Returns:
        list of (name, module) pairs. The order here fixes the order of the
        per-model entries returned by check_function.
    """
    factories = (
        ("MLP", MLP),
        ("MLPReLU6", MLPReLU6),
        ("NAC", NACNet),
        ("NALU", NALUNet),
        ("INALU", INALUNet),
    )
    return [(label, factory().to(DEVICE)) for label, factory in factories]


# Target functions f(a, b), keyed by the name used in run names and result
# tables. 'squared' and 'root' ignore their second argument.
# NOTE(review): RANGE_TRAIN spans zero, so a sum used as the 'div'
# denominator can land close to 0 and produce very large targets.
ARITHMETIC_FUNCTIONS = dict(
    add=lambda x, y: x + y,
    sub=lambda x, y: x - y,
    mul=lambda x, y: x * y,
    mul_neg=lambda x, y: -x * y,
    div=lambda x, y: x / y,
    squared=lambda x, y: torch.pow(x, 2),
    root=lambda x, y: torch.sqrt(torch.abs(x)),
)

In [12]:
def check_function(func_name, generator, num_epochs=700, lr=1e-2, batch_size=64):
    """Train every model from create_models() on one arithmetic task.

    Args:
        func_name: key into ARITHMETIC_FUNCTIONS (e.g. 'add', 'div').
        generator: callable ``fn -> (X_train, y_train, X_test, y_test)`` such
            as generate_data_interpolation or generate_data_extrapolation.
        num_epochs: training epochs per model.
        lr: RMSprop learning rate.
        batch_size: batch size of both the train and test loaders.

    Returns:
        list of float: one entry per model, in create_models() order, equal to
        the minimum of the test-loss history returned by the training helper.
    """
    X_train, y_train, X_test, y_test = generator(ARITHMETIC_FUNCTIONS[func_name])
    dataloaders = {
        'train': DataLoader(SyntheticDataset(X_train, y_train),
                            batch_size=batch_size, shuffle=True),
        # Test order does not affect the loss, so don't shuffle (shuffling
        # would also consume torch RNG every epoch).
        'test': DataLoader(SyntheticDataset(X_test, y_test),
                           batch_size=batch_size, shuffle=False),
    }
    criterion = nn.MSELoss()

    min_test_losses = []
    for model_name, model in create_models():
        optimizer = optim.RMSprop(model.parameters(), lr=lr)
        name = f'{model_name}_{func_name}'
        # INALUNet has extra reg_loss()/reinitialize() hooks, presumably used
        # by train_inalu_model; every other model uses the generic loop.
        trainer = train_inalu_model if model_name == 'INALU' else train_model
        # The positional None is passed through unchanged (presumably an
        # optional LR scheduler — TODO confirm against utils).
        _, _train_loss_his, test_loss_his = trainer(
            model, dataloaders, criterion, optimizer, None, num_epochs, DEVICE, name)
        # NOTE(review): the minimum over epochs of the *test* loss is an
        # optimistic, test-set-selected score.
        min_loss = float(torch.tensor(test_loss_his).min())
        print(model_name, min_loss)
        min_test_losses.append(min_loss)
    return min_test_losses

In [15]:
# Interpolation sweep: each entry is check_function's per-model list of
# minimum test MSEs for one arithmetic function.
interp_results = []
interp_results.extend([check_function('add', generate_data_interpolation)])

  0%|          | 0/700 [00:00<?, ?it/s]

MLP 9.829149348661304e-06


  0%|          | 0/700 [00:00<?, ?it/s]

MLPReLU6 0.7062729001045227


  0%|          | 0/700 [00:00<?, ?it/s]

NAC 1.164153192248496e-12


  0%|          | 0/700 [00:00<?, ?it/s]

NALU 1.153696894645691


  0%|          | 0/700 [00:00<?, ?it/s]

INALU 0.0016753372037783265


In [16]:
# Interpolation: sub
res = check_function('sub', generate_data_interpolation)
interp_results.extend([res])

  0%|          | 0/700 [00:00<?, ?it/s]

MLP 5.779148750661989e-07


  0%|          | 0/700 [00:00<?, ?it/s]

MLPReLU6 0.0399162583053112


  0%|          | 0/700 [00:00<?, ?it/s]

NAC 2.3533174649432997e-13


  0%|          | 0/700 [00:00<?, ?it/s]

NALU 10.48787784576416


  0%|          | 0/700 [00:00<?, ?it/s]

INALU 0.0011900915997102857


In [17]:
# Interpolation: mul
res = check_function('mul', generate_data_interpolation)
interp_results.extend([res])

  0%|          | 0/700 [00:00<?, ?it/s]

MLP 5548.12841796875


  0%|          | 0/700 [00:00<?, ?it/s]

MLPReLU6 16035.5322265625


  0%|          | 0/700 [00:00<?, ?it/s]

NAC 15436.111328125


  0%|          | 0/700 [00:00<?, ?it/s]

NALU 852.943115234375


  0%|          | 0/700 [00:00<?, ?it/s]

INALU 0.0015692510642111301


In [18]:
# Interpolation: mul_neg
res = check_function('mul_neg', generate_data_interpolation)
interp_results.extend([res])

  0%|          | 0/700 [00:00<?, ?it/s]

MLP 8699.7119140625


  0%|          | 0/700 [00:00<?, ?it/s]

MLPReLU6 8212.140625


  0%|          | 0/700 [00:00<?, ?it/s]

NAC 26616.609375


  0%|          | 0/700 [00:00<?, ?it/s]

NALU 3232.18701171875


  0%|          | 0/700 [00:00<?, ?it/s]

INALU 1.607330322265625


In [19]:
# Interpolation: div
res = check_function('div', generate_data_interpolation)
interp_results.extend([res])

  0%|          | 0/700 [00:00<?, ?it/s]

MLP 82.71844482421875


  0%|          | 0/700 [00:00<?, ?it/s]

MLPReLU6 17.342254638671875


  0%|          | 0/700 [00:00<?, ?it/s]

NAC 86.12367248535156


  0%|          | 0/700 [00:00<?, ?it/s]

NALU 67.85159301757812


  0%|          | 0/700 [00:00<?, ?it/s]

INALU 82.26152038574219


In [20]:
# Interpolation: squared
res = check_function('squared', generate_data_interpolation)
interp_results.extend([res])

  0%|          | 0/700 [00:00<?, ?it/s]

MLP 24634.0078125


  0%|          | 0/700 [00:00<?, ?it/s]

MLPReLU6 7722.87158203125


  0%|          | 0/700 [00:00<?, ?it/s]

NAC 121214.4609375


  0%|          | 0/700 [00:00<?, ?it/s]

NALU 1289.3287353515625


  0%|          | 0/700 [00:00<?, ?it/s]

INALU 0.0016088738339021802


In [21]:
# Interpolation: root
res = check_function('root', generate_data_interpolation)
interp_results.extend([res])

  0%|          | 0/700 [00:00<?, ?it/s]

MLP 0.15651340782642365


  0%|          | 0/700 [00:00<?, ?it/s]

MLPReLU6 0.06803086400032043


  0%|          | 0/700 [00:00<?, ?it/s]

NAC 0.951472818851471


  0%|          | 0/700 [00:00<?, ?it/s]

NALU 1.4578239643014967e-05


  0%|          | 0/700 [00:00<?, ?it/s]

INALU 0.16032104194164276


In [32]:
# Extrapolation sweep: each entry is check_function's per-model list of
# minimum test MSEs for one arithmetic function.
extr_results = []
extr_results.extend([check_function('add', generate_data_extrapolation)])

  0%|          | 0/700 [00:00<?, ?it/s]

MLP 7.860350592636678e-09


  0%|          | 0/700 [00:00<?, ?it/s]

MLPReLU6 43.48899841308594


  0%|          | 0/700 [00:00<?, ?it/s]

NAC 7.685229948088679e-13


  0%|          | 0/700 [00:00<?, ?it/s]

NALU 17.86178970336914


  0%|          | 0/700 [00:00<?, ?it/s]

INALU 0.0014470197493210435


In [21]:
# Extrapolation: sub
res = check_function('sub', generate_data_extrapolation)
extr_results.extend([res])

  0%|          | 0/700 [00:00<?, ?it/s]

MLP 1.1195002116437536e-06


  0%|          | 0/700 [00:00<?, ?it/s]

MLPReLU6 8.87283706665039


  0%|          | 0/700 [00:00<?, ?it/s]

NAC 3.548450324807745e-13


  0%|          | 0/700 [00:00<?, ?it/s]

NALU 35.97146987915039


  0%|          | 0/700 [00:00<?, ?it/s]

INALU 0.00023707430227659643


In [24]:
# Extrapolation: mul
res = check_function('mul', generate_data_extrapolation)
extr_results.extend([res])

  0%|          | 0/700 [00:00<?, ?it/s]

MLP 71244.9296875


  0%|          | 0/700 [00:00<?, ?it/s]

MLPReLU6 54325.22265625


  0%|          | 0/700 [00:00<?, ?it/s]

NAC 74272.1953125


  0%|          | 0/700 [00:00<?, ?it/s]

NALU 33298.88671875


  0%|          | 0/700 [00:00<?, ?it/s]

INALU 1.5117969098810136e-08


In [35]:
# Extrapolation: mul_neg
res = check_function('mul_neg', generate_data_extrapolation)
extr_results.extend([res])

  0%|          | 0/700 [00:00<?, ?it/s]

MLP 49386.30078125


  0%|          | 0/700 [00:00<?, ?it/s]

MLPReLU6 36491.5


  0%|          | 0/700 [00:00<?, ?it/s]

NAC 75911.4296875


  0%|          | 0/700 [00:00<?, ?it/s]

NALU 19887.525390625


  0%|          | 0/700 [00:00<?, ?it/s]

INALU 2.5017786811076803e-06


In [36]:
# Extrapolation: div
res = check_function('div', generate_data_extrapolation)
extr_results.extend([res])

  0%|          | 0/700 [00:00<?, ?it/s]

MLP 1343.95263671875


  0%|          | 0/700 [00:00<?, ?it/s]

MLPReLU6 1373.087158203125


  0%|          | 0/700 [00:00<?, ?it/s]

NAC 1347.36279296875


  0%|          | 0/700 [00:00<?, ?it/s]

NALU 1392.420654296875


  0%|          | 0/700 [00:00<?, ?it/s]

INALU 1295.3865966796875


In [26]:
# Extrapolation: squared
res = check_function('squared', generate_data_extrapolation)
extr_results.extend([res])

  0%|          | 0/700 [00:00<?, ?it/s]

MLP 92861.1640625


  0%|          | 0/700 [00:00<?, ?it/s]

MLPReLU6 209992.859375


  0%|          | 0/700 [00:00<?, ?it/s]

NAC 355227.03125


  0%|          | 0/700 [00:00<?, ?it/s]

NALU 129.38681030273438


  0%|          | 0/700 [00:00<?, ?it/s]

INALU 0.20869147777557373


In [38]:
# Extrapolation: root
res = check_function('root', generate_data_extrapolation)
extr_results.extend([res])

  0%|          | 0/700 [00:00<?, ?it/s]

MLP 1.454134464263916


  0%|          | 0/700 [00:00<?, ?it/s]

MLPReLU6 0.20828643441200256


  0%|          | 0/700 [00:00<?, ?it/s]

NAC 5.972806453704834


  0%|          | 0/700 [00:00<?, ?it/s]

NALU 0.007435419596731663


  0%|          | 0/700 [00:00<?, ?it/s]

INALU 3.025660276412964


In [30]:
import numpy as np

# Model labels in create_models() order, as they appear in the column names.
MODEL_LABELS = ['MLP', 'MLP_ReLU6', 'NAC', 'NALU', 'INALU']

# *_results are lists of per-function rows with one value per model; transpose
# into per-model rows. New names keep this cell idempotent (reassigning
# interp_results/extr_results meant a re-run transposed them back).
interp_by_model = np.array(interp_results).T
extr_by_model = np.array(extr_results).T
expected_shape = (len(MODEL_LABELS), len(ARITHMETIC_FUNCTIONS))
assert interp_by_model.shape == expected_shape, interp_by_model.shape
assert extr_by_model.shape == expected_shape, extr_by_model.shape

pd.set_option('display.float_format', lambda x: '%.3f' % x)
columns = {'Function': list(ARITHMETIC_FUNCTIONS.keys())}
columns.update({f'INTERPOLATION_{label}': row
                for label, row in zip(MODEL_LABELS, interp_by_model)})
columns.update({f'EXTRAPOLATION_{label}': row
                for label, row in zip(MODEL_LABELS, extr_by_model)})
df = pd.DataFrame(columns)

df.to_csv('raw_nalu_check_results.csv')
# To reuse saved results without retraining:
# df = pd.read_csv('raw_nalu_check_results.csv', index_col=0)

# Explicit copies so later cells can modify them without SettingWithCopy issues.
df_inter = df[['Function'] + [f'INTERPOLATION_{label}' for label in MODEL_LABELS]].copy()
df_extr = df[['Function'] + [f'EXTRAPOLATION_{label}' for label in MODEL_LABELS]].copy()

In [31]:
# Raw (unscaled) minimum test MSE for each function/model on the interpolation split.
df_inter

Unnamed: 0,Function,INTERPOLATION_MLP,INTERPOLATION_MLP_ReLU6,INTERPOLATION_NAC,INTERPOLATION_NALU,INTERPOLATION_INALU
0,add,0.0,0.706,0.0,1.154,0.002
1,sub,0.0,0.04,0.0,10.488,0.001
2,mul,5548.128,16035.532,15436.111,852.943,0.002
3,mul_neg,8699.712,8212.141,26616.609,3232.187,1.607
4,div,82.718,17.342,86.124,67.852,82.262
5,squared,24634.008,7722.872,121214.461,1289.329,0.002
6,root,0.157,0.068,0.951,0.0,0.16


In [32]:
# Raw (unscaled) minimum test MSE for each function/model on the extrapolation split.
df_extr

Unnamed: 0,Function,EXTRAPOLATION_MLP,EXTRAPOLATION_MLP_ReLU6,EXTRAPOLATION_NAC,EXTRAPOLATION_NALU,EXTRAPOLATION_INALU
0,add,0.0,43.489,0.0,17.862,0.001
1,sub,0.0,8.873,0.0,35.971,0.0
2,mul,71244.93,54325.223,74272.195,33298.887,0.0
3,mul_neg,49386.301,36491.5,75911.43,19887.525,0.0
4,div,1343.953,1373.087,1347.363,1392.421,1295.387
5,squared,92861.164,209992.859,355227.031,129.387,0.209
6,root,1.454,0.208,5.973,0.007,3.026


In [36]:
import numpy as np
import torch.nn.functional as F


def mean_random_mse(generator, n_trials=100, batch_size=64):
    """Average test MSE of freshly initialised (untrained) MLPs, per function.

    Serves as the normalisation baseline for the result tables. For every
    entry of ARITHMETIC_FUNCTIONS a new dataset is generated with
    ``generator``; ``n_trials`` randomly initialised MLPs are then evaluated
    on its test split.

    NOTE(review): the data is regenerated here, so the baseline is measured
    on different samples than the ones check_function trained/tested on.

    Args:
        generator: callable ``fn -> (X_train, y_train, X_test, y_test)``.
        n_trials: number of independent random models per function.
        batch_size: evaluation batch size.

    Returns:
        list of float: one mean per-sample MSE per function, in
        ARITHMETIC_FUNCTIONS order.
    """
    mses = []
    for func in ARITHMETIC_FUNCTIONS.values():
        # The train split is generated (same RNG use as before) but not needed.
        _, _, X_test, y_test = generator(func)
        test_loader = DataLoader(SyntheticDataset(X_test, y_test),
                                 batch_size=batch_size, shuffle=False)
        trial_mses = []
        for _ in range(n_trials):
            # One random model per trial; previously a new model was created
            # for every batch, so each batch was scored by a different model.
            model = MLP().to(DEVICE)
            running_loss = 0.
            with torch.no_grad():
                for inputs, targets in test_loader:
                    inputs = inputs.to(DEVICE)
                    targets = targets.to(DEVICE)
                    running_loss += F.mse_loss(model(inputs), targets).item() * inputs.size(0)
            trial_mses.append(running_loss / len(test_loader.dataset))
        mses.append(float(np.mean(trial_mses)))
    return mses

mses_interp = mean_random_mse(generate_data_interpolation)
mses_extr = mean_random_mse(generate_data_extrapolation)

In [39]:
INTERP_COLS = ['INTERPOLATION_MLP', 'INTERPOLATION_MLP_ReLU6', 'INTERPOLATION_NAC', 'INTERPOLATION_NALU', 'INTERPOLATION_INALU']
EXTR_COLS = ['EXTRAPOLATION_MLP', 'EXTRAPOLATION_MLP_ReLU6', 'EXTRAPOLATION_NAC', 'EXTRAPOLATION_NALU', 'EXTRAPOLATION_INALU']


def scale_to_baseline(raw, cols, baseline):
    """Return ``raw[['Function'] + cols]`` with each row scaled to ``100 * mse / baseline[row]``.

    Args:
        raw: unscaled results frame (one row per function).
        cols: result columns to scale.
        baseline: per-row baseline MSEs, in the same order as ``raw``'s rows.

    Raises:
        ValueError: if ``baseline`` does not have one value per row.
    """
    if len(baseline) != len(raw):
        raise ValueError(f'expected {len(raw)} baseline values, got {len(baseline)}')
    scaled = raw[['Function'] + cols].copy()
    scale = pd.Series(list(baseline), index=raw.index)
    # Same operation order as 100. * x / scale, applied row-wise.
    scaled[cols] = raw[cols].mul(100.).div(scale, axis=0)
    return scaled


# Rebuilt from the raw `df` (never mutated) so re-running this cell does not
# rescale values that were already scaled.
df_inter = scale_to_baseline(df, INTERP_COLS, mses_interp)
df_extr = scale_to_baseline(df, EXTR_COLS, mses_extr)

In [40]:
# Interpolation MSEs after the previous cell rescaled them to
# 100 * MSE / (mean MSE of an untrained MLP on the same function).
df_inter

Unnamed: 0,Function,INTERPOLATION_MLP,INTERPOLATION_MLP_ReLU6,INTERPOLATION_NAC,INTERPOLATION_NALU,INTERPOLATION_INALU
0,add,0.0,0.064,0.0,0.104,0.0
1,sub,0.0,0.018,0.0,4.697,0.001
2,mul,40.175,116.116,111.775,6.176,0.0
3,mul_neg,17.676,16.685,54.079,6.567,0.003
4,div,73.904,15.494,76.946,60.621,73.496
5,squared,13.338,4.182,65.633,0.698,0.0
6,root,0.319,0.139,1.94,0.0,0.327


In [41]:
# Extrapolation MSEs after rescaling to
# 100 * MSE / (mean MSE of an untrained MLP on the same function).
df_extr

Unnamed: 0,Function,EXTRAPOLATION_MLP,EXTRAPOLATION_MLP_ReLU6,EXTRAPOLATION_NAC,EXTRAPOLATION_NALU,EXTRAPOLATION_INALU
0,add,0.0,4.593,0.0,1.886,0.0
1,sub,0.0,2.077,0.0,8.42,0.0
2,mul,47.831,36.472,49.863,22.355,0.0
3,mul_neg,52.859,39.058,81.25,21.286,0.0
4,div,227.03,231.952,227.607,235.218,218.826
5,squared,34.425,77.848,131.688,0.048,0.0
6,root,1.986,0.284,8.158,0.01,4.132


In [42]:
# Persist the baseline-scaled tables (index column included, as before).
for _frame, _csv_path in ((df_inter, 'scaled_interp_check_results.csv'),
                          (df_extr, 'scaled_extr_check_results.csv')):
    _frame.to_csv(_csv_path)