# Compute metrics

This notebook computes the metrics of a classification by comparing the predictions of a given model with the true classes.

In [None]:
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from helpers import *

## Data loading

The loading is designed for a CSV file generated by the Collect Information option of the ImageJ plugin. The class of the nuclei must be indicated (classes mask option).

[TODO] write the path of your file, create your dataset and list the possible classes

In [None]:
path = "my_path/measurements.csv"

dataset = get_data(path, "Human")

classes = ["Mouse", "Human"]

In [None]:
dataset[CLASS_COLUMN].value_counts()

## Data processing
Shuffle, normalize and split the data between inputs and targets.

[TODO] write the path of your normalization file

In [None]:
inputs, targets = process_prediction_data(dataset, classes, 'models/classification/normalization.json')
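
The three steps named above (shuffle, normalize, split into inputs and targets) can be sketched with NumPy alone. This is only an illustration: the function `sketch_process`, the z-score normalization, the `mean` and `std` arguments and the one-hot targets are all assumptions, and the actual `process_prediction_data` helper and the layout of the normalization file may differ.

```python
import numpy as np

def sketch_process(features, labels, classes, mean, std, seed=0):
    """Hypothetical stand-in for process_prediction_data (not the real helper)."""
    # Shuffle the rows; features and labels use the same permutation.
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(features))
    features = np.asarray(features, dtype=np.float32)[order]
    labels = np.asarray(labels)[order]

    # Normalize each feature column (assumed here to be a z-score).
    inputs = (features - mean) / std

    # Encode each class name as a one-hot target vector.
    class_index = np.array([classes.index(label) for label in labels])
    targets = np.eye(len(classes), dtype=np.float32)[class_index]
    return inputs, targets

# Toy data: two features per nucleus, two classes.
toy_features = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]
toy_labels = ["Mouse", "Human", "Human", "Mouse"]
toy_inputs, toy_targets = sketch_process(
    toy_features, toy_labels, ["Mouse", "Human"],
    mean=np.array([2.5, 25.0]), std=np.array([1.0, 10.0]),
)
print(toy_inputs.shape, toy_targets.shape)
```

Shuffling features and labels with a single permutation keeps each row paired with its class.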

CSBDeep needs 3D images with more than 1 element in each dimension, so we create an image of size (2, 2, feature_size / 4). If the feature size is not a multiple of 4, it is padded with 0.

In [None]:
inputs = resize_inputs(inputs)
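
The reshaping described above can be illustrated with a short NumPy sketch. The function `pad_and_reshape` is hypothetical and only shows the idea of zero-padding the feature vector to a multiple of 4 and reshaping it to (2, 2, depth); the real `resize_inputs` helper may order the axes differently or add a channel dimension.

```python
import numpy as np

def pad_and_reshape(features):
    """Hypothetical sketch: (n_samples, n_features) -> (n_samples, 2, 2, depth)."""
    features = np.asarray(features, dtype=np.float32)
    n_samples, n_features = features.shape

    # Smallest depth such that 2 * 2 * depth >= n_features (ceiling division).
    depth = -(-n_features // 4)

    # Zero-pad the feature axis up to 4 * depth, then reshape.
    padded = np.zeros((n_samples, 4 * depth), dtype=np.float32)
    padded[:, :n_features] = features
    return padded.reshape(n_samples, 2, 2, depth)

# 10 features are padded to 12 and reshaped to a 2 x 2 x 3 volume.
example = pad_and_reshape(np.arange(10).reshape(1, 10))
print(example.shape)  # (1, 2, 2, 3)
```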

## Load model

[TODO] write the path of your model file

In [None]:
model = keras.models.load_model('models/classification/model.h5')
model.summary()

## Compute the metrics

In [None]:
metrics = evaluate(model, inputs, targets)

print('Accuracy: %f' % metrics[0])

for i, class_name in enumerate(classes):
    print('------------------------------------------------------------------------')
    print('Class: ' + class_name)
    print('Precision: %f' % metrics[1][i])
    print('Recall: %f' % metrics[2][i])
    print('F1 score: %f' % metrics[3][i])
    print('ROC AUC: %f' % metrics[4][i])
print('------------------------------------------------------------------------')

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 1, figsize=(10,10))
sns.heatmap(
    metrics[5], ax=ax, vmin=0, vmax=1, center=0, square=True,
    cbar_kws={"shrink": 0.7}, xticklabels=classes, yticklabels=classes,
    annot=True, fmt="f",
)
ax.set_title("Confusion matrix")
ax.set_xlabel('Estimated classes')
ax.set_ylabel('Real classes')
plt.show()
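
For readers who want to see how the reported quantities relate to each other, here is a minimal NumPy sketch that derives accuracy, per-class precision, recall and F1 score, and a row-normalized confusion matrix from class indices. The function `sketch_metrics` is hypothetical and is not the `evaluate` helper; the row normalization is an assumption suggested by the 0 to 1 color scale of the heatmap, and ROC AUC is left out because it needs the predicted probabilities.

```python
import numpy as np

def sketch_metrics(true_index, predicted_index, n_classes):
    """Hypothetical sketch of the metrics, computed from integer class indices."""
    # confusion[i, j] counts samples of real class i estimated as class j.
    confusion = np.zeros((n_classes, n_classes), dtype=np.float64)
    for t, p in zip(true_index, predicted_index):
        confusion[t, p] += 1

    true_positive = np.diag(confusion)
    estimated_total = confusion.sum(axis=0)  # column sums
    real_total = confusion.sum(axis=1)       # row sums

    # Guard every division so that an empty class yields 0 instead of NaN.
    zeros = np.zeros(n_classes)
    precision = np.divide(true_positive, estimated_total,
                          out=zeros.copy(), where=estimated_total > 0)
    recall = np.divide(true_positive, real_total,
                       out=zeros.copy(), where=real_total > 0)
    denominator = precision + recall
    f1 = np.divide(2 * precision * recall, denominator,
                   out=zeros.copy(), where=denominator > 0)

    accuracy = true_positive.sum() / confusion.sum()
    row_totals = real_total[:, None]
    normalized = np.divide(confusion, row_totals,
                           out=np.zeros_like(confusion), where=row_totals > 0)
    return accuracy, precision, recall, f1, normalized

# Toy example with two classes (0 and 1).
accuracy, precision, recall, f1, normalized = sketch_metrics(
    [0, 0, 1, 1, 1], [0, 1, 1, 1, 0], n_classes=2
)
print(accuracy)  # 0.6
```

With a Keras classifier, the predicted indices would typically come from something like `np.argmax(model.predict(inputs), axis=1)`, assuming the model outputs one score per class.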