In [None]:
import sys
sys.path.append('../')

import yaml

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import dataset
import networks
import utils
_Ext = utils.Extent
from sklearn.model_selection import KFold

In [None]:
with open('../settings.yaml', 'r') as file:
  config = yaml.safe_load(file)


train_cfg = config['NN_TRAINING']
nn_cfg = config['NEURAL_NETWORK']

# Seed torch's and numpy's global RNGs so weight initialization and the
# DataLoader shuffling are reproducible under Restart & Run All.
# An optional `seed` entry in the NN_TRAINING section overrides the default 0.
seed = train_cfg.get('seed', 0)
torch.manual_seed(seed)
np.random.seed(seed)

# Load the data

In [None]:
# input data: per-cell cx / cy arrays plus per-sample scalar parameters
cx_raw, cy_raw, beta_raw, orient_raw, aniso_raw = (
    np.load(path) for path in ('../data/cx_aniso_raw.npy',
                               '../data/cy_aniso_raw.npy',
                               '../data/beta_aniso_raw.npy',
                               '../data/orient_aniso_raw.npy',
                               '../data/aniso_raw.npy'))

# output data: volume fraction and Cholesky-factor entries of the constitutive matrix
vf_raw, constit_chol_raw = (
    np.load(path) for path in ('../data/vf_aniso_raw.npy',
                               '../data/const_chol_aniso_raw.npy'))

In [None]:
# Number of samples = length of the leading (sample) axis of the inputs.
num_samples = len(cx_raw)

print(f'We have {num_samples} samples')

# Stack data

In [None]:
# input

# Column layout of each sample: the cx values of every cell, then the cy values
# of every cell, then the scalar beta, orientation and anisotropy parameters.
# Note that the center cell corresponds to the 4th entry of each cx/cy group
# (look at neighbors code).
nn_in_raw = np.column_stack((
    cx_raw.reshape((num_samples, -1)),
    cy_raw.reshape((num_samples, -1)),
    beta_raw,
    orient_raw,
    aniso_raw,
))

print(f'input data contains {nn_in_raw.shape[0]} samples with {nn_in_raw.shape[1]} features')

# output: the first six Cholesky-factor entries, then the volume fraction
nn_out_raw = np.vstack([constit_chol_raw[:, col] for col in range(6)] + [vf_raw]).T

print(f'output data contains {nn_out_raw.shape[0]} samples with {nn_out_raw.shape[1]} features')

# Normalize the output data

In [None]:
output_mean, output_std = np.mean(nn_out_raw, axis=0), np.std(nn_out_raw, axis=0)
# A constant output column has std 0; assuming utils.normalize_z_scale divides
# by the std, its z-scores would become inf/NaN. Give such columns unit scale
# instead (the same convention sklearn's StandardScaler uses).
output_std[output_std == 0] = 1.0
# NOTE(review): these statistics are computed over all samples, including the
# rows later split off for testing/validation; fitting them on the training
# rows only would avoid leaking test information into the normalization.

np.save('../data/output_mean', output_mean)
np.save('../data/output_std', output_std)


In [None]:
# z-score every output column using the per-column statistics from above
nn_out = utils.normalize_z_scale(nn_out_raw,
                                 output_mean,
                                 output_std)

# Create datasets and loader helpers

In [None]:
num_train = 10000
num_test = 1000
num_validate = 1000

# torch Subset does not validate its indices, so check up front that the three
# contiguous splits fit inside the loaded data instead of failing later.
num_required = num_train + num_test + num_validate
assert num_required <= num_samples, (
    f'splits need {num_required} samples but only {num_samples} were loaded')

voronoi_data = dataset.VoronoiDataset(voronoi_params=nn_in_raw,
                                      homogen_params=nn_out)

# Non-overlapping, contiguous index ranges: [train | test | validate].
# NOTE(review): samples are not shuffled before splitting; this assumes the
# stored sample order carries no systematic trend — confirm with the generator.
train_data = torch.utils.data.Subset(voronoi_data, np.arange(0, num_train))
test_data = torch.utils.data.Subset(voronoi_data, np.arange(num_train, num_train + num_test))
val_data = torch.utils.data.Subset(voronoi_data, np.arange(num_train + num_test, num_required))

In [None]:
num_workers = 0
batch_size = train_cfg['batch_size']

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=num_workers)
# Evaluation order does not affect the result, so keep test batches
# deterministic (and avoid consuming RNG state when iterating them).
# NOTE(review): the training loop evaluates on test_data directly, so this
# loader currently appears unused.
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=batch_size,
                                          shuffle=False,
                                          num_workers=num_workers)

# Neural network settings and initialization

In [None]:

# Input/output widths follow the data (feature columns of nn_in_raw, target
# columns of nn_out); depth and width come from the NEURAL_NETWORK settings.
nn_settings = networks.NNSettings(input_dim=nn_in_raw.shape[1],
                                  output_dim=nn_out.shape[1],
                                  num_layers=nn_cfg['num_layers'],
                                  num_neurons_per_layer=nn_cfg['neurons_per_layer'])

voro_net = networks.VoronoiNet(nn_settings)

# Training loop

In [None]:
# Optimizer
# Adam over every network parameter; learning rate from the NN_TRAINING settings.
optimizer = torch.optim.Adam(params=voro_net.parameters(),
                             lr=train_cfg['lr'])

In [None]:
def get_C_matrix(homo_params_unnormalized):
  """Build symmetric 3x3 C matrices from lower-triangular Cholesky entries.

  Args:
    homo_params_unnormalized: 2-D tensor, one row per sample. Columns 0-2 are
      the diagonal of a lower-triangular factor L; columns 3, 4, 5 are its
      (1,0), (2,0), (2,1) entries. Any further columns are ignored.

  Returns:
    Tensor of shape (num_rows, 3, 3) holding C = L @ L^T, with the same dtype
    and device as the input.
  """
  # Allocate L with the input's dtype/device: a bare torch.zeros would be
  # float32 on the CPU, silently truncating float64 inputs and failing for
  # inputs on another device.
  L = torch.zeros((homo_params_unnormalized.shape[0], 3, 3),
                  dtype=homo_params_unnormalized.dtype,
                  device=homo_params_unnormalized.device)
  # Keep the diagonal strictly positive so L is invertible and C stays
  # positive definite.
  L[:, 0, 0] = torch.clip(homo_params_unnormalized[:, 0], min=1e-3)
  L[:, 1, 1] = torch.clip(homo_params_unnormalized[:, 1], min=1e-3)
  L[:, 2, 2] = torch.clip(homo_params_unnormalized[:, 2], min=1e-3)
  L[:, 1, 0] = homo_params_unnormalized[:, 3]
  L[:, 2, 0] = homo_params_unnormalized[:, 4]
  L[:, 2, 1] = homo_params_unnormalized[:, 5]
  # Batched L @ L^T (same as einsum 'dij,djk->dik' with the transposed L).
  C_matrix = L @ L.transpose(1, 2)
  return C_matrix


In [None]:
# Reference scale: the largest Frobenius norm of C = L L^T over every sample of
# the unnormalized outputs.
# NOTE(review): max_C_norm appears unused elsewhere in this notebook.
homo_params_unnorm = torch.tensor(nn_out_raw)
C_matrix = get_C_matrix(homo_params_unnorm)
max_C_norm = torch.linalg.norm(C_matrix, dim=(1, 2)).max()

In [None]:
def loss_wrapper(homo_params, pred_homo_params):
  """Evaluation MSE between true and predicted outputs in unnormalized units.

  Both arguments are z-scored output rows laid out like ``nn_out_raw``: six
  Cholesky-factor entries followed by the volume fraction in column 6. They
  are un-normalized with the global ``output_mean`` / ``output_std``; the
  Cholesky entries are turned into 3x3 C matrices with ``get_C_matrix``; and
  the MSE is taken over the flattened C matrix concatenated with the volume
  fraction.

  Returns:
    float: the mean squared error as a Python float (not a tensor).
  """
  # Pure metric (the result is .item()'d), so don't extend the autograd graph.
  with torch.no_grad():
    mean = torch.as_tensor(output_mean, device=homo_params.device)
    std = torch.as_tensor(output_std, device=homo_params.device)
    homo_params_unnorm = utils.unnormalize_z_scale(homo_params, mean, std)
    pred_homo_params_unnorm = utils.unnormalize_z_scale(pred_homo_params, mean, std)

    num_rows = homo_params_unnorm.shape[0]

    # Column 6 is the volume fraction; keep it 2-D for concatenation.
    vf = homo_params_unnorm[:, 6:7]
    pred_vf = pred_homo_params_unnorm[:, 6:7]

    C_matrix = get_C_matrix(homo_params_unnorm).reshape(num_rows, -1)
    pred_C_matrix = get_C_matrix(pred_homo_params_unnorm).reshape(num_rows, -1)

    actual_C_vf = torch.hstack((C_matrix, vf))
    pred_C_vf = torch.hstack((pred_C_matrix, pred_vf))

    mean_squared_loss = torch.mean((pred_C_vf - actual_C_vf)**2)

  return mean_squared_loss.item()

In [None]:
# number of epochs to train the model
num_epochs = train_cfg['num_epochs']

convg_history = {'train_loss': [], 'test_loss': [], 'val_loss': []}

for epoch in range(1, num_epochs+1):

  net_loss = 0.

  for data in train_loader:

    voro_params, homo_params = data
    optimizer.zero_grad()

    pred_homo_params = voro_net(voro_params)

    loss = torch.mean((pred_homo_params - homo_params)**2)

    loss.backward()
    optimizer.step()

    net_loss += loss.item()


  net_loss = net_loss/len(train_loader)

  print(f'epoch: {epoch:d} \t loss: {net_loss:.2E}')
  

  if epoch%1 == 0:
    voro_params, homo_params = train_data[:]
    pred_homo_params = voro_net(voro_params)
    
    loss = loss_wrapper(homo_params, pred_homo_params)
    convg_history['train_loss'].append(loss)

    voro_params, homo_params = test_data[:]
    pred_homo_params = voro_net(voro_params)

    
    test_loss = loss_wrapper(homo_params, pred_homo_params)
    convg_history['test_loss'].append(test_loss)


    voro_params, homo_params = val_data[:]
    pred_homo_params = voro_net(voro_params)
    
    val_loss = loss_wrapper(homo_params, pred_homo_params)
    convg_history['val_loss'].append(val_loss)


    print("-"*65)
    print(f'Validation: {epoch:d} \t loss: {val_loss:.2E}')
    print("-"*65)

In [None]:
# One loss value per split is recorded at the end of every epoch; epochs are
# numbered from 1, so the x-axis starts at 1 rather than at the list index 0.
fig, ax = plt.subplots()
epochs = range(1, len(convg_history['train_loss']) + 1)
ax.plot(epochs, convg_history['train_loss'], label='Training Loss')
ax.plot(epochs, convg_history['test_loss'], label='Testing Loss')
ax.plot(epochs, convg_history['val_loss'], label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Convergence history')
ax.legend()
plt.show()

In [None]:
# Persist only the learned parameters/buffers (the state_dict), not the module.
torch.save(voro_net.state_dict(),
           f='../data/voro_net.pt')

# Check that saving and loading works

In [None]:
# Keep a handle on the trained parameters so the reloaded ones can be compared.
trained_state = voro_net.state_dict()

voro_net = networks.VoronoiNet(nn_settings)
# map_location='cpu' lets the checkpoint load on machines without the device it
# was saved from. NOTE(review): torch.load unpickles the file — only load
# checkpoints you produced yourself (recent torch also supports
# weights_only=True for plain state dicts).
voro_net.load_state_dict(torch.load('../data/voro_net.pt', map_location='cpu'))

# The save/load round trip must reproduce every parameter and buffer exactly.
reloaded_state = voro_net.state_dict()
assert reloaded_state.keys() == trained_state.keys(), 'state_dict keys changed after reload'
for name, tensor in reloaded_state.items():
  assert torch.equal(tensor, trained_state[name].cpu()), f'{name} differs after reload'

voro_net.eval()