# Tutorial: Training the Neural Network Emulator

This tutorial explains how to use SPEXAI to train a neural network emulator on SPEX data and how to evaluate the performance of the emulator.

In [None]:
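# import libraries and the SPEXAI modules
import os
import sys

import torch

# make the SPEXAI scripts importable (sys.path does not expand '~', so expand it here)
f_dir = os.path.expanduser('~/SPEXAI/')
sys.path.insert(0, f_dir)

import dataloader
import neuralnetwork
import plot
import train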

### Reading in the training data

The neural network (NN) emulator is trained on spectra generated by SPEX at temperatures between 0.2 and 25 keV. These spectra are stored as text files in a Data directory, with the flux in the first column and the energy bins in the second. The files are listed in an index file that gives each filename and its corresponding temperature.
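To make this layout concrete, here is a minimal sketch of reading such files with NumPy. The index-file format (one whitespace-separated `filename temperature` pair per line), the helper names `read_index` and `read_spectrum`, and the file names are hypothetical; `dataloader.SpexAIMemoryDataset` may parse its inputs differently.

```python
import os
import tempfile

import numpy as np


def read_index(index_path):
    """Read a hypothetical index file with one 'filename temperature' pair per line."""
    names, temps = [], []
    with open(index_path) as fh:
        for line in fh:
            parts = line.split()
            if len(parts) < 2:
                continue  # skip blank or malformed lines
            names.append(parts[0])
            temps.append(float(parts[1]))
    return names, np.array(temps)


def read_spectrum(path):
    """Read a spectrum file: column 0 is the flux, column 1 the energy bin."""
    data = np.loadtxt(path, ndmin=2)
    return data[:, 0], data[:, 1]


# Build a tiny fake data set in a temporary directory to exercise the readers.
tmp = tempfile.mkdtemp()
energy = np.array([0.1, 1.0, 10.0])
with open(os.path.join(tmp, 'index.txt'), 'w') as fh:
    for i, kT in enumerate([0.2, 5.0]):
        name = f'spec_{i}.txt'
        flux = np.exp(-energy / kT)  # made-up spectrum shape
        np.savetxt(os.path.join(tmp, name), np.column_stack([flux, energy]))
        fh.write(f'{name} {kT}\n')

names, temps = read_index(os.path.join(tmp, 'index.txt'))
flux0, energy0 = read_spectrum(os.path.join(tmp, names[0]))
print(names, temps, energy0)
```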

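These text files can be read with ```dataloader.SpexAIMemoryDataset```, then standard-scaled and split into a training dataset and a validation dataset. The keyword arguments restrict the energy and temperature ranges that are loaded; the example below keeps temperatures between 0.2 and 10 keV.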

In [None]:
import os

element = '1'

# directory of the training data (renamed from `dir` to avoid shadowing the built-in)
data_dir = os.path.expanduser('~/Data/')

# directories of the training data and index files for the chosen element
datadir = data_dir + 'element_' + element + '/'
indexdir = data_dir + 'Split_data/element_' + element + '/'

# read in the training data, restricted to the given flux, energy and temperature ranges
Z01 = dataloader.SpexAIMemoryDataset('Z' + element + '_list_train.txt', datadir, indexdir,
                                     min_flux=-10., min_energy=0.1, max_energy=25,
                                     min_temp=0.2, max_temp=10)
# apply a standard scaler to the data
Z01.scale_data()
# split the data into training and test (validation) subsets
Z01.split_data()
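The calls `scale_data()` and `split_data()` hide the details. As a rough illustration of what standard scaling and a train/validation split involve, here is a NumPy sketch on fake data. It is not the SPEXAI implementation; the helper `standard_scale`, the 80/20 ratio and the random seed are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Fake data: 100 temperatures (inputs) and 50 flux values per spectrum (outputs).
temps = rng.uniform(0.2, 10.0, size=(100, 1))
flux = rng.normal(-5.0, 2.0, size=(100, 50))


def standard_scale(a):
    """Scale each column to zero mean and unit variance; also return the statistics."""
    mean = a.mean(axis=0)
    std = a.std(axis=0)
    std[std == 0] = 1.0  # avoid division by zero for constant columns
    return (a - mean) / std, mean, std


x_scaled, x_mean, x_std = standard_scale(temps)
y_scaled, y_mean, y_std = standard_scale(flux)

# Random 80/20 split into training and validation (test) subsets.
perm = rng.permutation(len(x_scaled))
n_train = int(0.8 * len(perm))
train_idx, test_idx = perm[:n_train], perm[n_train:]
x_train, x_test = x_scaled[train_idx], x_scaled[test_idx]
y_train, y_test = y_scaled[train_idx], y_scaled[test_idx]

# Values in scaled space are mapped back with the stored statistics.
y_test_unscaled = y_test * y_std + y_mean
```

Keeping the mean and standard deviation is what allows network outputs, which live in scaled space, to be turned back into fluxes later.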



### Define a neural network architecture to train

A neural network can be defined with ```neuralnetwork.FFN```, which builds a feed-forward network (FFN) for which you can set the number of layers, the number of neurons per layer, the activation function, and whether to use dropout. In addition to FFN layers, ```neuralnetwork.CNN``` can also add inverse convolutional layers.

In [None]:
# presumably: 1 input (temperature), one output per energy bin, 3 layers of 150 neurons
nn_model = neuralnetwork.FFN(1, int(Z01.flux.size(1)), 3, 150, act_name='nonlin')
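For orientation, a feed-forward network of the kind described above can be written in a few lines of PyTorch. `SketchFFN` is a hypothetical stand-in, not the `neuralnetwork.FFN` source; the tanh activation, the dropout placement, the linear output layer and the 500 output bins are assumptions.

```python
import torch
from torch import nn


class SketchFFN(nn.Module):
    """Hypothetical FFN: n_in -> n_layers hidden layers of n_neurons -> n_out."""

    def __init__(self, n_in, n_out, n_layers, n_neurons, p_dropout=0.0):
        super().__init__()
        layers = []
        width = n_in
        for _ in range(n_layers):
            layers += [nn.Linear(width, n_neurons), nn.Tanh(), nn.Dropout(p_dropout)]
            width = n_neurons
        layers.append(nn.Linear(width, n_out))  # linear output, one value per energy bin
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


# 1 input (temperature), 500 output bins, 3 hidden layers of 150 neurons.
net = SketchFFN(1, 500, 3, 150)
out = net(torch.rand(8, 1))
print(out.shape)  # torch.Size([8, 500])
```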

### Train the neural network emulator on the training data

The learning algorithm ```train.NeuralNetworkTrainer``` trains the NN on the training data from SPEX. It can be adjusted through its hyperparameters: the optimizer, the loss function, and the criteria of the learning-rate regulator can all be changed. In addition, ```mask``` and ```mask_test``` can be used to exclude parts of the parameter space from the loss calculation where the flux is sufficiently low.
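The idea behind the masks can be illustrated with a masked mean squared error, in which masked-out bins get weight 0 and do not contribute to the loss. This sketch assumes the mask is a 0/1 tensor with the same shape as the spectra; the function `masked_mse` is hypothetical, and how `train.NeuralNetworkTrainer` combines the mask with its loss function may differ.

```python
import torch


def masked_mse(pred, target, mask):
    """Mean squared error over the entries where mask == 1 only."""
    sq_err = (pred - target) ** 2 * mask
    return sq_err.sum() / mask.sum().clamp(min=1.0)  # guard against an all-zero mask


pred = torch.tensor([[1.0, 2.0, 3.0]])
target = torch.tensor([[1.0, 0.0, 0.0]])
# Boolean mask -> float 0/1 weights; the last bin is excluded from the loss.
mask = torch.tensor([[True, True, False]]).float()
print(masked_mse(pred, target, mask))  # tensor(2.)
```

Only the first two bins count, so the loss is (0 + 4) / 2 = 2, even though the masked third bin has the largest error.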

In [None]:
# train on hydrogen (H): Z = 1
element = '1'

# convert the boolean masks to float weights (1 = included in the loss, 0 = masked out)
mask_train = torch.where(Z01.mask_train, 1., 0).type(torch.float32)
mask_test = torch.where(Z01.mask_test, 1., 0).type(torch.float32)

# set up the trainer with the scaled training and test data
model = train.NeuralNetworkTrainer(Z01.x_scaled_train, Z01.y_scaled_train, Z01.x_scaled_test, Z01.y_scaled_test,
                                   nn_model, scaler_flux=Z01.scaler_flux, mask=mask_train, mask_test=mask_test,
                                   save_model=False, element=element)

# train the network (presumably 500 epochs with a batch size of 128)
model.train(500, 128)
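Conceptually, a call like `model.train(500, 128)` runs a mini-batch training loop. The sketch below shows one common form of such a loop in plain PyTorch on toy data, with Adam as the optimizer, a masked MSE loss and `ReduceLROnPlateau` as the learning-rate scheduler. These choices, and reading the two arguments as epochs and batch size, are assumptions and are not taken from `train.NeuralNetworkTrainer`.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Toy data: 256 scalar inputs mapped to 20 output bins, all bins unmasked.
x = torch.rand(256, 1)
y = torch.sin(x * torch.linspace(1.0, 5.0, 20))
mask = torch.ones_like(y)
x_tr, y_tr, m_tr = x[:200], y[:200], mask[:200]
x_val, y_val, m_val = x[200:], y[200:], mask[200:]


def masked_mse(pred, target, m):
    """Squared error averaged over unmasked entries only."""
    return ((pred - target) ** 2 * m).sum() / m.sum().clamp(min=1.0)


net = nn.Sequential(nn.Linear(1, 64), nn.Tanh(), nn.Linear(64, 20))
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)
# halve the learning rate once the validation loss has not improved for more than 5 epochs
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=5)
loader = DataLoader(TensorDataset(x_tr, y_tr, m_tr), batch_size=128, shuffle=True)

val_losses = []
for epoch in range(50):  # number of epochs
    net.train()
    for xb, yb, mb in loader:
        optimizer.zero_grad()
        loss = masked_mse(net(xb), yb, mb)
        loss.backward()
        optimizer.step()
    net.eval()
    with torch.no_grad():
        val_loss = masked_mse(net(x_val), y_val, m_val).item()
    scheduler.step(val_loss)  # the scheduler watches the validation loss
    val_losses.append(val_loss)

print(f'validation loss: {val_losses[0]:.3e} -> {val_losses[-1]:.3e}')
```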

### Evaluate the training

You can also load a model that has already been trained with ```train.NeuralNetworkTrainer.load```. The performance of the NN emulator can be evaluated by looking at the loss function and at the error fraction between the original SPEX spectra and the spectra predicted by the NN emulator.

In [None]:
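# load a previously trained model; its architecture (encoded in the file name, presumably
# 3 layers of 150 neurons with tanh activation) should match nn_model
model.load('/Best_NN/Z1/FF_out(50125)_nL(3|150)_Act(tanh)_p(0.0)')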
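In plain PyTorch, the usual way to store and restore a trained network is to save its `state_dict`. The sketch below shows that round trip with a temporary file; whether SPEXAI stores its models this way is an assumption, and `make_net` is a hypothetical helper. A state dict can only be loaded into a network whose layers have the same shapes.

```python
import os
import tempfile

import torch
from torch import nn


def make_net():
    """Hypothetical architecture; the restored net must be built the same way."""
    return nn.Sequential(nn.Linear(1, 150), nn.Tanh(), nn.Linear(150, 10))


torch.manual_seed(0)
trained = make_net()
path = os.path.join(tempfile.mkdtemp(), 'ffn_state.pt')
torch.save(trained.state_dict(), path)  # store only the learned parameters

restored = make_net()
restored.load_state_dict(torch.load(path))
restored.eval()

x = torch.rand(4, 1)
with torch.no_grad():
    same = torch.allclose(trained(x), restored(x))
print(same)  # True
```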

In [None]:
# compute the mean squared error over the validation dataset
with torch.no_grad():
    y_pred_scaled = model.model(Z01.x_scaled_test.to(model.device)).cpu()
test_loss = model.original_loss(y_pred_scaled, Z01.y_scaled_test)
print('Loss on test data is {:.3e}'.format(test_loss))

#plot the loss per epoch
plot.loss(model.loss_train, model.loss_test)

Show the error fraction in a 2D plot as a function of temperature and energy. A split colorbar centered on the critical error fraction of 1e-3 shows which points lie above or below it.
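As a rough illustration of the plotted quantity, the sketch below computes a signed fractional error `(y_pred - y) / y` and the base-10 logarithm of its magnitude, so that the critical fraction 1e-3 maps to -3. That would correspond to the `Z_center=-3` argument of the heatmap. The exact definition used by `dataloader.fraction` (for example, which spectrum it divides by) is an assumption here, and the numbers are made up.

```python
import numpy as np

y_true = np.array([1.0, 2.0, 4.0, 8.0])        # stand-in SPEX fluxes
y_pred = np.array([1.0005, 1.99, 4.2, 8.0008])  # stand-in NN predictions

frac = (y_pred - y_true) / y_true   # signed fractional error
log_abs = np.log10(np.abs(frac))    # -3 corresponds to an error fraction of 1e-3
above_critical = np.abs(frac) > 1e-3

for f, l, a in zip(frac, log_abs, above_critical):
    print(f'{f:+.1e}  log10|frac| = {l:5.2f}  above 1e-3: {a}')
```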

In [None]:
# draw the predicted spectra from the NN emulator
y_pred = Z01.power(model.predict(Z01.x_scaled_test))
# corresponding SPEX validation spectra
y      = Z01.power(Z01.y_test, scaler=False)

# calculate the (signed) error fraction between the predicted and the validation spectra
frac = dataloader.fraction(y_pred, y, absolute=False)

# plot the error fraction versus temperature and energy, with the colorbar centered
# on the critical error fraction of 1e-3 (Z_center=-3)
plot.heatmap(Z01.energy, Z01.x_test, frac, Z01.mask_test, color='seismic', Z_center=-3)
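A split colorbar of this kind can be reproduced with Matplotlib's `TwoSlopeNorm`, which places the center of a diverging colormap such as `seismic` at a chosen value. The sketch uses random stand-in data and assumes the plotted quantity is log10 of the absolute error fraction; `plot.heatmap` may build its figure differently.

```python
import matplotlib
matplotlib.use('Agg')  # render off-screen so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import TwoSlopeNorm

rng = np.random.default_rng(1)
energy = np.logspace(-1, np.log10(25), 200)   # keV
temperature = np.linspace(0.2, 10, 40)        # keV
log_frac = rng.normal(-3.5, 0.7, size=(temperature.size, energy.size))

# Blue below 1e-3 (log10 = -3), red above it, white at the critical value.
norm = TwoSlopeNorm(vmin=log_frac.min(), vcenter=-3.0, vmax=log_frac.max())
fig, ax = plt.subplots()
mesh = ax.pcolormesh(energy, temperature, log_frac, cmap='seismic', norm=norm, shading='auto')
ax.set_xscale('log')
ax.set_xlabel('Energy [keV]')
ax.set_ylabel('Temperature [keV]')
fig.colorbar(mesh, ax=ax, label='log10 |error fraction|')
```

Call `fig.savefig(...)` to write the figure to disk. Because the two halves of the colormap are scaled independently, the white band stays at 1e-3 even when the data are not symmetric around it.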

In [None]:
# plot a SPEX spectrum and the spectrum predicted by the NN at a temperature of 5 keV
temp = 5
plot.fraction([temp], Z01.x_test, y_pred, y, Z01.mask_test, Z01.energy)
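Because the validation set need not contain exactly 5 keV, such a comparison presumably has to pick the closest available temperature. Here is a minimal NumPy sketch of nearest-temperature selection on made-up temperatures and spectra; whether `plot.fraction` selects spectra this way is an assumption.

```python
import numpy as np

temps_test = np.array([0.7, 2.3, 4.8, 5.6, 9.1])  # hypothetical validation temperatures (keV)
energy = np.linspace(0.1, 25, 100)
spectra = np.exp(-energy[None, :] / temps_test[:, None])  # stand-in spectra, one per row

target = 5.0
idx = int(np.argmin(np.abs(temps_test - target)))  # index of the closest temperature
print(f'closest temperature: {temps_test[idx]} keV (row {idx})')
spectrum_at_target = spectra[idx]
```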
