In [None]:
import torch
import torch.nn.functional as F

import numpy as np

from tqdm import tqdm_notebook as tqdm

from data_generator_helper import generate_synthetic_selection_dataset
from data_generator_helper import generate_synthetic_operation_dataset
from models_new.nalu import NALU
from models_new.nac import NAC

from torchvision import datasets
import torchvision.models as models
import torchvision.utils as vutils
from tensorboardX import SummaryWriter

import datetime
import os

import matplotlib
import matplotlib.pyplot as plt
#from mpl_toolkits.mplot3d import Axes3D as plt3

from ipywidgets import interactive
from ipywidgets import widgets

In [None]:
def reportLoss(loss):
    """Print ``loss`` to standard output (thin logging hook)."""
    message = loss
    print(message)
    
def train(model, optimizer, x_train, y_train, epochs, batch_size, model_param):
    '''
    if model_param == "NAC":
        weights = np.zeros((epochs,len(x_train)//batch_size,out_dim,sample_size,3))
    elif model_param == "NALU":
        weights = np.zeros((epochs,len(x_train)//batch_size,out_dim,sample_size,4))
        g_prev = np.zeros((epochs,len(x_train)//batch_size,out_dim))
    
    '''
    losses = np.zeros((epochs,len(x_train)//batch_size))
    for epoch in range(epochs):
        for batch in range(len(x_train) // batch_size):
            '''
            # Save weights
            for children in model.model.children():
                plist = [param for param in children.parameters()]
                if model_param == "NAC":
                    W_hat = plist[0]
                    M_hat = plist[1]
                elif model_param == "NALU":
                    W_hat = plist[1]
                    M_hat = plist[2]
                    weights[epoch,batch,:,:,3] = plist[0].detach().numpy()
                W = torch.tanh(W_hat) * torch.sigmoid(M_hat)
                weights[epoch,batch,:,:,0] = W_hat.detach().numpy()
                weights[epoch,batch,:,:,1] = M_hat.detach().numpy()
                weights[epoch,batch,:,:,2] = W.detach().numpy()
            '''
            model.train()
            optimizer.zero_grad()

            x_batch_train = x_train[batch:(batch+batch_size),:]
            y_batch_train = y_train[batch:(batch+batch_size),:]
            '''
            if model_param == "NALU":
                g_prev[epoch,batch,:] = F.linear(plist[0],x_batch_train,).detach().numpy().flatten()
                #print(np.shape(g_prev),np.shape(F.linear(plist[0],x_batch_train).detach().numpy()))
            '''
            out = model(x_batch_train)
    
            loss = F.mse_loss(out, y_batch_train)
            
            if loss != loss:
                print("nan detected")
            #losses[epoch,batch] = loss
            loss.backward()
            optimizer.step()
    
           
    return test(model,x_train,y_train),losses

     
    
def test(model, x_test, y_test):
    model.eval()
    output_test = model(x_test)
    loss = F.mse_loss(output_test, y_test)
    return loss

In [None]:
# --- Experiment configuration ---------------------------------------------
operators = ["add","sub","mult","div","square","root"]
init = 'Kai_uni'
model_param =  "NALU"

test_per_range = 10   # independent training runs per (operator, range) pair
sample_size = 2
set_size = 100

in_dim = sample_size
hidden_dim = 1
out_dim = 1
num_layers = 1

lr = 0.01
epochs = 1000
batch_size = 1
values = [[0,1],[-1,1],[0,100],[-100,100]]   # [min, max] ranges for training and evaluation

# Seed both libraries so a Restart & Run All reproduces the reported numbers.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# --- Train / evaluate ------------------------------------------------------
for op in operators:
    print("Operator: " + str(op))
    for value in tqdm(values, total=len(values)):
        min_value = value[0]
        max_value = value[1]

        print("Train range: ["+str(min_value)+","+str(max_value)+"]")

        # Accumulators span all test_per_range runs so the printed figures are
        # true averages (they were previously reset on every run).
        avgloss = np.zeros(len(values))
        exploss = 0
        for k in range(test_per_range):
            if model_param == "NALU":
                model = NALU(num_layers, in_dim, hidden_dim, out_dim, init)
            elif model_param == "NAC":
                model = NAC(num_layers, in_dim, hidden_dim, out_dim, init)
            else:
                raise ValueError("Unknown model_param: " + str(model_param))
            optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)

            x_train, y_train = generate_synthetic_operation_dataset(op, min_value, max_value,
                                                                    sample_size, set_size, boundaries = None)

            x_test, y_test = generate_synthetic_operation_dataset(op, min_value, max_value,
                                                                  sample_size, set_size, boundaries = None)

            x_train = x_train.type(torch.DoubleTensor)
            y_train = y_train.type(torch.DoubleTensor)
            x_test = x_test.type(torch.DoubleTensor)
            y_test = y_test.type(torch.DoubleTensor)

            loss, losses = train(model, optimizer, x_train, y_train, epochs, batch_size, model_param)
            filename = str(op)+"_"+str(min_value)+"_"+str(max_value)+"_"+str(k)
            np.save(filename, losses)

            # NOTE(review): `loss` is the MSE train() reports on the *training* set,
            # scaled by the max of a freshly drawn same-range x_test that is not
            # otherwise evaluated. Confirm this is the intended "interpolation" metric.
            out = loss.item() / np.max(x_test.data.numpy())
            exploss = exploss + out

            # Evaluate the trained model on every configured range.
            for idx, val in enumerate(values):
                x_test, y_test = generate_synthetic_operation_dataset(op, val[0], val[1],
                                                                      sample_size, set_size, boundaries = None)

                x_test = x_test.type(torch.DoubleTensor)
                y_test = y_test.type(torch.DoubleTensor)
                test_loss = test(model, x_test, y_test)
                avgloss[idx] = avgloss[idx] + test_loss.item()

        print('{:.2e}'.format(exploss / test_per_range))
        for avgl in avgloss:
            print('{:.2e}'.format(avgl / test_per_range), end=' ')
        print('\n')


