In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import utils
from utils import Visuals
import evaluators
import tensorflow_addons as tfa
from data_parameters import data_param

print(tf.config.get_visible_devices())

# Dataset / label parameters shared with the training code.
per_sys = data_param['per_sys']          # element labels (used as class names for the confusion matrix)
N_elem = len(per_sys)                    # number of output classes
spectrum_length = data_param['spectrum_length']
max_buffer = data_param['max_buffer']    # shuffle-buffer size used for the test set
batch_size = 32

DATA_DIR = 'Core-loss EELS TFRecord'


def load_tfrecords(pattern: str):
    """Return the parsed dataset for all TFRecord files matching ``DATA_DIR/pattern``.

    The file list is shuffled (``list_files(..., shuffle=True)``), and parsing of the
    records is delegated to ``utils.get_dataset``.
    """
    files = tf.data.Dataset.list_files(f'{DATA_DIR}/{pattern}', shuffle=True)
    return utils.get_dataset(files)


val_dataset = (
    load_tfrecords('validationset/VALIDATION*.tfrecords')
    .batch(batch_size, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

test_dataset = (
    load_tfrecords('testset/TEST*.tfrecords')
    .shuffle(max_buffer, reshuffle_each_iteration=False)
    .batch(batch_size, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

# The confusion-matrix evaluator currently only takes a single batch, so the
# single-element spectra are delivered as one large batch.
confmat_dataset = (
    load_tfrecords('single_element_spec/*.tfrecords')
    .batch(8000)
    .prefetch(tf.data.AUTOTUNE)
)

## load a model

In [None]:
# Load the trained 2xViT + 3xUNet ensemble. The project's custom loss is passed via
# `custom_objects` so Keras can deserialize the saved model.
model = tf.keras.models.load_model(
    'trained element identification models/2ViT_3UNet_ensemble',
    custom_objects={'custom_loss': utils.custom_loss},
)

## set threshold

In [None]:
# Decision thresholds applied to the model's per-element outputs.
RECOMMENDED_THRESHOLDS = {
    # Threshold giving equal precision and recall for the 2xViT + 3xUNet ensemble on
    # *simulated* data. The simulated set contains many edges with arbitrarily small
    # jump ratios (and SNR), which pulls this value down considerably.
    'ensemble_simulated': 0.35,
    # Optimal thresholds on *experimental* data.
    'ensemble_experimental': 0.75,
    'vit_experimental': 0.80,
    'unet_experimental': 0.95,
}
threshold = RECOMMENDED_THRESHOLDS['ensemble_simulated']

## Alternatively, determine the best threshold from the validation data

In [None]:
# Sweep candidate thresholds on the validation set and plot F1, precision and recall.
# All metrics are compiled together so a single evaluation pass over the same samples
# produces every curve.
dset = val_dataset
# Rounded so metric names read e.g. 'f1_0.3' instead of 'f1_0.30000000000000004'.
thresholds = np.round(np.arange(0.05, 1, 0.05), 2)

# NOTE: tfa F1Score('weighted') averages per-class F1 weighted by class support, whereas
# tf.metrics.Precision/Recall (no class_id) pool all label entries together, so the F1
# curve is not directly derived from the precision/recall curves.
metrics = (
    [tfa.metrics.F1Score(N_elem, 'weighted', th, name=f'f1_{th}') for th in thresholds]
    + [tf.metrics.Precision(th, name=f'prec_{th}') for th in thresholds]
    + [tf.metrics.Recall(th, name=f'rec_{th}') for th in thresholds]
)
model.compile(
    optimizer=tf.optimizers.Adam(learning_rate=0.001),
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=metrics,
)
results = model.evaluate(dset, return_dict=True)
fscores = [results[f'f1_{th}'] for th in thresholds]
precisions = [results[f'prec_{th}'] for th in thresholds]
recalls = [results[f'rec_{th}'] for th in thresholds]

# Candidate thresholds; assign one of these to `threshold` to use it in the cells below.
best_f1_threshold = thresholds[np.argmax(fscores)]
balanced_threshold = thresholds[np.argmin(np.abs(np.subtract(precisions, recalls)))]
print(f'max F1 at threshold {best_f1_threshold}; precision ~= recall at threshold {balanced_threshold}')

fig, ax = plt.subplots(figsize=(12, 5))
ax.scatter(thresholds, fscores, marker='o', c='k', label=r'F$_1$')
ax.scatter(thresholds, precisions, marker='s', c='r', label='precision')
ax.scatter(thresholds, recalls, marker='v', c='b', label='recall')
ax.set_xticks(thresholds)
ax.set(xlabel='Threshold', ylabel='Value', title='Validation metrics vs. decision threshold')
ax.legend()
plt.show()


## Get metrics for the test set

In [None]:
# Evaluate on a fixed subset of the test set (the full one takes quite a while).
# `list_files(shuffle=True)` reshuffles the test file order on each new iteration (and
# utils.get_dataset may add its own nondeterminism), so two separate `take(N)` calls can
# yield different batches. Caching the subset makes both metrics below see the same data.
# NOTE: the in-memory cache is only kept if the first consumer reads the subset to the end.
N_TEST_BATCHES = 500
test_subset = test_dataset.take(N_TEST_BATCHES).cache()

evaluators.numeric_metrics(model=model, dataset=test_subset, threshold=threshold, drop_carbon=False)

EMR = evaluators.match_rate(model=model, threshold=threshold, drop_carbon=False)
print(EMR.calculate(test_subset).numpy())

## get confusion matrix

In [None]:
# Build and plot the confusion matrix on the single-element spectra
# (confmat_dataset is delivered as one large batch).
CMM = evaluators.confusionmatrix(
    class_names=per_sys,
    model=model,
    dataset=confmat_dataset,
)
CMM.plot_confusion_matrix()

## visualize predictions

In [None]:
# Show the model's predictions for a slice of one test batch, using the chosen
# `threshold` (start/end presumably index samples within the batch — TODO confirm
# whether `end` is exclusive).
visuals = Visuals()
visuals.visual_prediction(
    model=model,
    dataset=test_dataset.take(1),
    start=0,
    end=6,
    threshold=threshold,
)
