In [None]:
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import pickle
import numpy as np
import training
import config
import matplotlib.pyplot as plt

In [None]:
# Load the pre-built dataset object `a`; the following cell indexes it by
# string key to pull out the individual arrays.
# NOTE: pickle.load can execute arbitrary code while deserializing, so only
# open dataset files that come from a trusted source.
with open('./data/inv_power_dataset.p', 'rb') as file:
    a = pickle.load(file)

In [None]:
# Unpack every entry of the loaded dataset object into its own variable.
# NOTE(review): the `n` suffix presumably marks normalized copies of the data
# and X_min/X_max, Y_min/Y_max the bounds used for that scaling - confirm
# against the script that produced the pickle.
X, Y = a['X'], a['Y']
Xn, Yn = a['Xn'], a['Yn']
X_min, X_max = a['X_min'], a['X_max']
Y_min, Y_max = a['Y_min'], a['Y_max']

# Splits of the plain X / Y data.
# NOTE(review): "learn" presumably is the union of train and valid - confirm.
X_learn, Y_learn = a['X_learn'], a['Y_learn']
X_train, Y_train = a['X_train'], a['Y_train']
X_valid, Y_valid = a['X_valid'], a['Y_valid']
X_test , Y_test  = a['X_test'] , a['Y_test']

# The same splits for the Xn / Yn data; these are the ones wrapped into
# TensorDatasets and fed to the networks further down.
Xn_learn, Yn_learn = a['Xn_learn'], a['Yn_learn']
Xn_train, Yn_train = a['Xn_train'], a['Yn_train']
Xn_valid, Yn_valid = a['Xn_valid'], a['Yn_valid']
Xn_test , Yn_test  = a['Xn_test'] , a['Yn_test']

In [None]:
def _full_batch_loader(dataset):
    """Return a DataLoader that serves all of `dataset` as one single batch."""
    return DataLoader(dataset, batch_size=len(dataset))


# Pair inputs with targets for each split and expose every split through a
# loader whose batch size equals the split size.
train_data = TensorDataset(Xn_train, Yn_train)
train_loader = _full_batch_loader(train_data)

valid_data = TensorDataset(Xn_valid, Yn_valid)
valid_loader = _full_batch_loader(valid_data)

test_data = TensorDataset(Xn_test, Yn_test)
test_loader = _full_batch_loader(test_data)

In [None]:
# Mean-squared-error criterion; the evaluation loop below uses it to score
# each network's predictions on the train, validation and test splits.
lossfunction = torch.nn.MSELoss()

In [None]:
# Hyper-parameter grid of the stored experiments.
SEEDs = list(range(1, 11))              # seeds 1 .. 10
LAYERs = list(range(2, 16))             # layer counts 2 .. 15
LRs = [-power for power in range(6)]    # base-10 exponents 0 .. -5 of the learning rate

In [None]:
# One loss value per (seed, layer count, learning rate) combination and per
# data split; the last axis is filled as 0 = train, 1 = valid, 2 = test.
results = torch.zeros(
    len(SEEDs),
    len(LAYERs),
    len(LRs),
    3,
)
results.shape

In [None]:
# Evaluate every stored network of the (seed, layer, lr) grid on the train,
# validation and test splits, write the three MSE values into `results`, and
# save one predicted-vs-target scatter plot per network as a PDF.
for sidx, seed in enumerate(SEEDs):
    for lidx, layer in enumerate(LAYERs):
        for lridx, lr in enumerate(LRs):
            exp_setup = f'{layer}_{lr}_{seed}'

            # NOTE: torch.load unpickles the stored object and can therefore
            # execute arbitrary code; only load checkpoints you trust.
            # NOTE(review): the model is used in whatever train/eval mode it
            # was saved in - call NN_temp.eval() first if it contains dropout
            # or batch-norm layers (cannot be told from this notebook).
            NN_temp = torch.load(f'./NNs/INVPOW__{exp_setup}')

            # Pure evaluation: without no_grad() every loss written into
            # `results` would attach its autograd graph to that tensor and
            # keep the graphs of all evaluated networks alive.
            with torch.no_grad():
                # Each loader is built with batch_size=len(dataset), so each
                # of these loops runs once over the complete split.
                for x_train, y_train in train_loader:
                    prediction_train = NN_temp(x_train)
                for x_valid, y_valid in valid_loader:
                    prediction_valid = NN_temp(x_valid)
                for x_test, y_test in test_loader:
                    prediction_test = NN_temp(x_test)

                loss_train = lossfunction(Yn_train, prediction_train)
                loss_valid = lossfunction(Yn_valid, prediction_valid)
                loss_test = lossfunction(Yn_test, prediction_test)

                # Last axis of `results`: 0 = train, 1 = valid, 2 = test.
                results[sidx, lidx, lridx, 0] = loss_train
                results[sidx, lidx, lridx, 1] = loss_valid
                results[sidx, lidx, lridx, 2] = loss_test

            # Scatter of target (x axis) against prediction (y axis); the gray
            # diagonal marks a perfect prediction.
            fig = plt.figure(figsize=(12,12))
            plt.plot(np.linspace(0,1,100), np.linspace(0,1,100), c='gray')
            plt.scatter(Yn_train.flatten().numpy(), prediction_train.detach().flatten().numpy(), s=2, c='blue', label=f'train_loss: {loss_train:.4f}')
            plt.scatter(Yn_valid.flatten().numpy(), prediction_valid.detach().flatten().numpy(), s=2, c='green', label=f'valid_loss: {loss_valid:.4f}')
            plt.scatter(Yn_test.flatten().numpy(), prediction_test.detach().flatten().numpy(), s=2, c='red', label=f'test_loss: {loss_test:.4f}')
            plt.xlim([0,1])
            plt.ylim([0,1])
            plt.title(f'{layer}-layer {lr}-lr {seed}-seed', fontsize=30)
            plt.legend(fontsize=12)
            plt.savefig(f'./NNs/{exp_setup}.pdf', format='pdf', bbox_inches='tight')
            # The figure is already on disk; close it so that the loop does
            # not keep one open figure per network in memory. The figures are
            # consequently no longer rendered inline, only saved as PDFs.
            plt.close(fig)

In [None]:
# Write the filled loss tensor to disk; the next cell reads it back from the
# same path.
torch.save(results, './nlc.result')

In [None]:
# Read the stored loss tensor back from disk, replacing the in-memory `results`.
results = torch.load('./nlc.result')

In [None]:
# Average the losses over the seed axis (dim 0); the remaining axes are
# (layer count, learning rate, split).
mean_results = torch.mean(results, dim=0)

In [None]:
# Positions (seed index, layer index, lr index) of every run whose validation
# loss (last-axis index 1) equals the overall minimum validation loss.
torch.where(results[:,:,:,1]==results[:,:,:,1].min())

In [None]:
# Positions (layer index, lr index) where the seed-averaged validation loss
# (last-axis index 1) reaches its minimum. torch.where returns one index
# tensor per dimension, so both names hold tensors rather than plain ints.
best_lidx, best_lridx = torch.where(mean_results[:,:,1]==mean_results[:,:,1].min())
best_lidx, best_lridx

In [None]:
# Translate the best indices into the actual layer count and learning-rate
# exponent. best_lidx / best_lridx are 1-D index tensors from torch.where; a
# Python list accepts such a tensor as an index only when it holds exactly one
# element, so a tie for the minimum would raise a TypeError. Pick the first
# match explicitly instead.
LAYERs[int(best_lidx[0])], LRs[int(best_lridx[0])]

In [None]:
# Best MSE reached for each layer count: the minimum is taken first over the
# seed axis (dim 0) and then over the learning-rate axis (dim 1 of the
# remaining layer x lr matrix), leaving one value per entry of LAYERs.
plt.figure(figsize=(12,12*0.618))
t = torch.min(torch.min(results[:,:,:,0], dim=0)[0], dim=1)[0]
v = torch.min(torch.min(results[:,:,:,1], dim=0)[0], dim=1)[0]
p = torch.min(torch.min(results[:,:,:,2], dim=0)[0], dim=1)[0]
# Same colour coding as the per-network scatter plots; the labels make the
# figure readable on its own.
plt.plot(LAYERs, t.detach().numpy(), c='blue', label='train')
plt.plot(LAYERs, v.detach().numpy(), c='green', label='valid')
plt.plot(LAYERs, p.detach().numpy(), c='red', label='test')
# Derive ticks and limits from LAYERs so the axis follows the grid definition
# instead of repeating its bounds as literals.
plt.xticks(LAYERs)
plt.xlim([min(LAYERs), max(LAYERs)])
plt.xlabel('#Layer', fontsize=15)
plt.ylabel('MSE', fontsize=15)
plt.legend(fontsize=12);

In [None]:
# Best MSE reached for each learning rate: the minimum is taken first over the
# seed axis (dim 0) and then over the layer axis (dim 0 of the remaining
# layer x lr matrix), leaving one value per entry of LRs.
plt.figure(figsize=(12,12*0.618))
t = torch.min(torch.min(results[:,:,:,0], dim=0)[0], dim=0)[0]
v = torch.min(torch.min(results[:,:,:,1], dim=0)[0], dim=0)[0]
p = torch.min(torch.min(results[:,:,:,2], dim=0)[0], dim=0)[0]
# Same colour coding as the per-network scatter plots; the labels make the
# figure readable on its own.
plt.plot(LRs, t.detach().numpy(), c='blue', label='train')
plt.plot(LRs, v.detach().numpy(), c='green', label='valid')
plt.plot(LRs, p.detach().numpy(), c='red', label='test')
plt.xlabel('Learning Rate', fontsize=15)
# The curves are drawn against the base-10 exponents; the tick labels show the
# corresponding learning rates 10**exponent.
plt.xticks(LRs, 10.**np.array(LRs))
plt.ylabel('MSE', fontsize=15)
plt.legend(fontsize=12);