In [1]:
import torch
import torch.nn.functional as F

import numpy as np

from tqdm import tqdm_notebook as tqdm

from data_generator_helper import generate_synthetic_selection_dataset
from models_new.nalu_b import NALU
from models_new.nac import NAC

from torchvision import datasets
import torchvision.models as models
import torchvision.utils as vutils
from tensorboardX import SummaryWriter

import datetime
import os

import matplotlib
import matplotlib.pyplot as plt
#from mpl_toolkits.mplot3d import Axes3D as plt3

from ipywidgets import interactive
from ipywidgets import widgets

In [2]:
def reportLoss(loss):
    print(loss)
    
def train(model, optimizer, x_train, y_train, epochs, batch_size, model_param):
    '''
    if model_param == "NAC":
        weights = np.zeros((epochs,len(x_train)//batch_size,out_dim,sample_size,3))
    elif model_param == "NALU":
        weights = np.zeros((epochs,len(x_train)//batch_size,out_dim,sample_size,4))
        g_prev = np.zeros((epochs,len(x_train)//batch_size,out_dim))
    losses = np.zeros((epochs,len(x_train)//batch_size))
    '''
    for epoch in range(epochs):
        for batch in range(len(x_train) // batch_size):
            '''
            # Save weights
            for children in model.model.children():
                plist = [param for param in children.parameters()]
                if model_param == "NAC":
                    W_hat = plist[0]
                    M_hat = plist[1]
                elif model_param == "NALU":
                    W_hat = plist[1]
                    M_hat = plist[2]
                    weights[epoch,batch,:,:,3] = plist[0].detach().numpy()
                W = torch.tanh(W_hat) * torch.sigmoid(M_hat)
                weights[epoch,batch,:,:,0] = W_hat.detach().numpy()
                weights[epoch,batch,:,:,1] = M_hat.detach().numpy()
                weights[epoch,batch,:,:,2] = W.detach().numpy()
            '''
            model.train()
            optimizer.zero_grad()

            x_batch_train = x_train[batch:(batch+batch_size),:]
            y_batch_train = y_train[batch:(batch+batch_size),:]
            '''
            if model_param == "NALU":
                g_prev[epoch,batch,:] = F.linear(plist[0],x_batch_train,).detach().numpy().flatten()
                #print(np.shape(g_prev),np.shape(F.linear(plist[0],x_batch_train).detach().numpy()))
            '''
            out = model(x_batch_train)
            
            loss = F.mse_loss(out, y_batch_train)
            
            if loss != loss:
                print("nan detected")
            #losses[epoch,batch] = loss
            loss.backward()
            optimizer.step()
    
           
    return test(model,x_train,y_train)

     
    
def test(model, x_test, y_test):
    model.eval()
    output_test = model(x_test)
    loss = F.mse_loss(output_test, y_test)
    return loss

In [3]:
save_all = False

init = 'Kai_uni'
model_param =  "NALU"

test_per_range = 10
sample_size = 100
set_size = 200

in_dim = sample_size
hidden_dim = 1
out_dim = 2
num_layers = 1

lr = 0.05
epochs = 1000
batch_size = 1
#values = [1] #np.linspace(1,1000,10)
values = [[0,1],[-1,1],[0,10],[-10,10]]
extr_scale = [5,10,100]
extr_values = [[0,1],[-1,1]]

datadict = {}
for value in values:
    datadict[str(value)] = []
    for e_s in extr_scale:
        for e_val in extr_values:
            datadict[str(value)+str(e_s)+str(e_val)] = []


for j, value in tqdm(enumerate(values)):

    min_value = value[0]
    max_value = value[1]

    print("Train range: ["+str(min_value)+","+str(max_value)+"]")
    for k in range(test_per_range):
        #i=0
        
        model = NALU(num_layers, in_dim, hidden_dim, out_dim, init)
        optimizer = torch.optim.RMSprop(model.parameters(),lr=lr)

        x_train, y_train, boundaries = generate_synthetic_selection_dataset(min_value, 
                                                                            max_value, sample_size, 
                                                                            set_size, boundaries = None)

        x_test, y_test, _ = generate_synthetic_selection_dataset(min_value, max_value,
                                                         sample_size, set_size, boundaries = boundaries)

        x_train = x_train.type(torch.DoubleTensor)
        y_train = y_train.type(torch.DoubleTensor)
        x_test = x_test.type(torch.DoubleTensor)
        y_test = y_test.type(torch.DoubleTensor)

        loss = train(model, optimizer, x_train, y_train, epochs, batch_size, model_param)
        out = loss.data.numpy() / np.max(x_test.data.numpy())
        datadict[str(value)].append(out)
        
        print("Interpolation Loss: ",'{:.2e}'.format(out))
        print("Extrapolation Loss: ", end='')
        for e_s in extr_scale:
            for e_val in extr_values:
                x_test, y_test, _ = generate_synthetic_selection_dataset(e_s*e_val[0]*value[1], e_s*e_val[1]*value[1],
                                                                 sample_size, set_size, boundaries = boundaries)

                x_test = x_test.type(torch.DoubleTensor)
                y_test = y_test.type(torch.DoubleTensor)               
                test_loss  = test(model, x_test, y_test)
                out = test_loss.data.numpy() / np.max(x_test.data.numpy())
                datadict[str(value)+str(e_s)+str(e_val)].append(out)
                print('{:.2e}'.format(out), end=' ')
                #print(datadict)
        print('')
    print('\n')



Train range: [0,1]
Interpolation Loss:  2.18e-11
Extrapolation Loss: 3.02e+19 6.54e+20 7.91e+50 4.48e+48 8.02e+139 4.52e+140 
Interpolation Loss:  1.88e-11
Extrapolation Loss: 1.07e+14 5.96e+13 2.49e+36 1.22e+33 3.50e+104 3.61e+107 
Interpolation Loss:  2.17e-11
Extrapolation Loss: 3.94e+14 1.32e+16 1.39e+38 5.83e+38 1.23e+114 5.87e+112 
Interpolation Loss:  8.80e-12
Extrapolation Loss: 1.27e-03 5.53e-04 4.71e+04 6.56e+04 1.77e+30 1.24e+30 
Interpolation Loss:  1.46e-11
Extrapolation Loss: 4.32e-05 5.81e-04 3.92e+02 6.20e+01 3.15e+23 6.63e+22 
Interpolation Loss:  2.43e-11
Extrapolation Loss: 7.34e+24 5.35e+28 4.36e+63 2.88e+60 6.91e+171 6.14e+173 
Interpolation Loss:  3.46e-11
Extrapolation Loss: 2.81e+17 1.11e+16 1.05e+37 3.92e+40 6.22e+116 8.03e+115 
Interpolation Loss:  1.59e-11
Extrapolation Loss: 9.34e-01 8.95e-01 2.86e+10 1.79e+10 6.93e+42 7.78e+42 
Interpolation Loss:  1.76e-11
Extrapolation Loss: 8.53e+13 4.22e+11 4.85e+33 1.49e+32 1.12e+104 6.12e+103 
Interpolation Loss:  1.2

In [5]:
for e_val in extr_values:
    for e_s in extr_scale:
        for value in values:
            print('${:.2e}$'.format(np.median(datadict[str(value)+str(e_s)+str(e_val)])),' & ',end='')
        print('')



$9.60e+13$  & $4.76e+11$  & $3.63e+06$  & $1.13e-06$  & 
$1.25e+36$  & $1.11e+33$  & $4.16e+07$  & $1.69e-03$  & 
$2.31e+104$  & $2.18e+104$  & $6.91e+11$  & $3.33e+04$  & 
$3.00e+13$  & $2.04e+13$  & $4.03e+06$  & $7.38e-07$  & 
$6.87e+32$  & $9.83e+33$  & $2.18e+08$  & $1.95e-03$  & 
$1.81e+107$  & $5.76e+104$  & $6.86e+11$  & $4.77e+04$  & 


In [10]:
for e_val in extr_values:
    for e_s in extr_scale:
        for value in values:
            print('${:.2e}'.format(np.mean(datadict[str(value)+str(e_s)+str(e_val)])),'}$ & ',end='')
        print('')

$7.34e+23 }$ & $5.95e+31 }$ & $3.39e+08 }$ & $2.72e+01 }$ & 
$4.36e+62 }$ & $4.39e+68 }$ & $7.90e+12 }$ & $4.40e+06 }$ & 
$6.91e+170 }$ & $1.95e+195 }$ & $2.16e+16 }$ & $7.76e+27 }$ & 
$5.35e+27 }$ & $2.50e+31 }$ & $2.05e+10 }$ & $2.52e+01 }$ & 
$2.88e+59 }$ & $2.71e+68 }$ & $4.25e+10 }$ & $6.02e+06 }$ & 
$6.14e+172 }$ & $4.45e+199 }$ & $8.72e+16 }$ & $6.88e+27 }$ & 
