## Results: Loss functions

In [1]:
import sys
sys.path.append("C:/Users/matth/Documents/Martinos Center/mrtoct") 

import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.colors as colors
import matplotlib.cm as cm
import utils.test_utils as utils
import os
import pickle
import numpy as np
import gc

In [2]:
# Root directories for the trained models and the preprocessed datasets.
# Either can be overridden with an environment variable so the notebook can run
# on another machine; the defaults are the original author's local paths.
# A trailing "/" is enforced because later cells build file paths by plain
# string concatenation (e.g. path_dataset + "ctmask_nosqrt" + "/valid_eval.npz").
path_model = os.environ.get(
    "DEEPBRAIN_MODEL_DIR", "C:/Users/matth/Documents/Martinos Center/Models/DeepBrain/"
).rstrip("/\\") + "/"
path_dataset = os.environ.get(
    "MRTOCT_DATASET_DIR", "C:/Users/matth/Documents/Martinos Center/mrtoct/datasets/"
).rstrip("/\\") + "/"

### Python commands

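Run these commands in a command prompt to generate the dataset and train the models.

**Note:** the preprocessing command below creates the `ctmask_nosqrt` dataset, which is also the dataset evaluated further down. However, every training command passes `--dataset all_mask_sqrt`. Confirm which dataset the models were actually trained on before comparing the results.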

##### Preprocessing:

* <font color = blue>python preprocessing.py --dataset ctmask_nosqrt --no_sqrt --tanh --mask_opt 1</font>

##### Training: 
        
* <font color = blue>python train.py --dataset all_mask_sqrt --name ctmask --n_epochs 30</font>
* <font color = blue>python train.py --dataset all_mask_sqrt --name ctmask_gdl --gdl --n_epochs 30</font>
* <font color = blue>python train.py --dataset all_mask_sqrt --name ctmask_l2 --loss_metric mse --n_epochs 30</font>
* <font color = blue>python train.py --dataset all_mask_sqrt --name ctmask_l2_gdl --loss_metric mse --gdl --n_epochs 30</font>
* <font color = blue>python train.py --dataset all_mask_sqrt --name ctmask_l1_l2 --l1_l2 --n_epochs 30</font>
* <font color = blue>python train.py --dataset all_mask_sqrt --name ctmask_l1_l2_gdl --l1_l2 --gdl --n_epochs 30</font>
* <font color = blue>python train.py --dataset all_mask_sqrt --name ctmask_perceptual --perceptual_loss --n_epochs 30</font>

##### Load data: default and MSE losses (with and without `--gdl`)

In [3]:
# First group: the default loss and MSE (`--loss_metric mse`), each trained
# with and without the `--gdl` option (see the training commands above).
model_names = [
    "ctmask",         # default loss
    "ctmask_gdl",     # default loss + --gdl
    "ctmask_l2",      # --loss_metric mse
    "ctmask_l2_gdl",  # --loss_metric mse + --gdl
]
models, outputs = utils.retrieve_models(model_names, path_model)
# Unpack the per-model training histories returned for these outputs.
(trains, vals, bevels, begens, begts) = utils.retrieve_history(outputs)

##### Metrics

In [4]:
# Evaluate the first group of models on the validation split of the
# `ctmask_nosqrt` dataset. np.load returns a lazily-read NpzFile for .npz
# archives that holds its file handle open; the `with` block closes it once
# the metrics have been computed.
# NOTE(review): assumes utils.compute_val_metrics reads everything it needs
# from `dataset` before returning — confirm.
with np.load(path_dataset + "ctmask_nosqrt" + "/valid_eval.npz") as dataset:
    metrics = utils.compute_val_metrics(models, model_names, dataset, bevels, begens, begts)
metrics

100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [04:22<00:00, 65.51s/it]


Unnamed: 0,MAE,$\sigma_{MAE}$,MSE,$\sigma_{MSE}$,PSNR,$\sigma_{PSNR}$,SSIM,$\sigma_{SSIM}$
,,,,,,,,
ctmask,0.0931,0.0279,0.0345,0.0017,33.6613,5.8422,0.9656,0.0131
ctmask_gdl,0.0933,0.026,0.0327,0.0012,34.2013,5.5504,0.9618,0.0135
ctmask_l2,0.0954,0.0246,0.0318,0.001,34.4802,5.3421,0.9647,0.0127
ctmask_l2_gdl,0.0956,0.0255,0.0321,0.0011,34.3947,5.4117,0.9619,0.0138


##### Load data: combined L1+L2 and perceptual losses

In [5]:
# Second group: the combined loss (`--l1_l2`, with and without `--gdl`) and
# the perceptual loss (`--perceptual_loss`), per the training commands above.
model_names = [
    "ctmask_l1_l2",       # --l1_l2
    "ctmask_l1_l2_gdl",   # --l1_l2 + --gdl
    "ctmask_perceptual",  # --perceptual_loss
]
models, outputs = utils.retrieve_models(model_names, path_model)
# Unpack the per-model training histories returned for these outputs.
(trains, vals, bevels, begens, begts) = utils.retrieve_history(outputs)

##### Metrics

In [6]:
# Evaluate the second group of models on the same validation split
# (`ctmask_nosqrt`). np.load returns a lazily-read NpzFile for .npz archives
# that holds its file handle open; the `with` block closes it once the metrics
# have been computed.
# NOTE(review): assumes utils.compute_val_metrics reads everything it needs
# from `dataset` before returning — confirm.
with np.load(path_dataset + "ctmask_nosqrt" + "/valid_eval.npz") as dataset:
    metrics = utils.compute_val_metrics(models, model_names, dataset, bevels, begens, begts)
metrics

100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [03:18<00:00, 66.08s/it]


Unnamed: 0,MAE,$\sigma_{MAE}$,MSE,$\sigma_{MSE}$,PSNR,$\sigma_{PSNR}$,SSIM,$\sigma_{SSIM}$
,,,,,,,,
ctmask_l1_l2,0.0928,0.0259,0.0325,0.0012,34.2615,5.6593,0.9654,0.0127
ctmask_l1_l2_gdl,0.0938,0.0259,0.0324,0.0012,34.3025,5.5604,0.9645,0.013
ctmask_perceptual,0.1076,0.027,0.0391,0.0016,32.4163,5.1863,0.9352,0.0212


In [None]:
# Bar chart of validation MAE per model, sorted from highest (left) to lowest (right).
# The values come from the `metrics` table produced by the most recent
# evaluation cell. The original cell referenced an undefined `MAE` variable.
# NOTE(review): assumes `metrics` is a pandas DataFrame indexed by model name
# with an "MAE" column, as displayed above — confirm.
# NOTE(review): each evaluation cell overwrites `metrics`, so only the last
# group of models is plotted; keep each table under its own name and
# concatenate them to compare every loss function in one figure.
FIGURE_DIR = "C:/Users/matth/Documents/Martinos Center/ISMRM presentation/"

mae_by_model = metrics["MAE"].sort_values(ascending = False)
labels_sorted = list(mae_by_model.index)
MAE_sorted = mae_by_model.to_numpy()
X_axis = np.arange(len(MAE_sorted))

fig, ax = plt.subplots(figsize = (10, 5))
ax.bar(X_axis, MAE_sorted, width = 0.6)

# Annotate each bar with its value, horizontally centred just above the bar,
# so the label position does not depend on the font size.
for x, value in zip(X_axis, MAE_sorted):
    ax.text(x, value + 0.005, "%.4f" % value, ha = "center", fontsize = 20)

ax.tick_params(axis = 'both', which = 'major', labelsize = 20)
ax.set_xticks(X_axis, labels_sorted)
ax.set_ylabel("Mean Absolute Error\n(MAE)", fontsize = 25)
ax.set_title("Best Performance (after 30 epochs)", fontsize = 30)
# Keep the original 0.15 ceiling, but raise it if a bar plus its label would be clipped.
ax.set_ylim([0, max(0.15, float(MAE_sorted.max()) * 1.15)])

# Save to an explicit path rather than calling os.chdir, which would silently
# change the working directory for every cell run afterwards.
os.makedirs(FIGURE_DIR, exist_ok = True)
fig.savefig(os.path.join(FIGURE_DIR, "Losses_barplot.svg"), bbox_inches = 'tight')
plt.show()