This notebook evaluates a ResNet-50 model.

In [None]:
import csv

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from tqdm import tqdm

from utils.eval.eval_utils import compute_ROC
from utils.training.data_loading import get_load_spectro_for_class
from utils.training.keras_models import resnet

## Parameters

In [None]:
# Dataset root: expected to contain directories named "positives" and
# "negatives", plus a csv file.
ROOT_DIR = "/path/to/the/dataset"
BATCH_SIZE = 64
epoch = 31  # epoch checkpoint that we want to load
# Single braces so that `epoch` is actually interpolated (zero-padded to four
# digits). Doubled braces would be escaped by the f-string and leave the
# literal text "{epoch:04d}" in the path.
CHECKPOINT = f"../../../data/model_saves/ResNet-50/all/cp-{epoch:04d}.ckpt"

# Tolerance (in seconds) when evaluating, and time distance allowed between
# two peaks in the probabilities distribution.
ALLOWED_ERROR_S = 10
OUTPUT_DIR = "ResNet-50/dataset"  # directory where to output files, in the data folder

# Build the per-element loading function that is mapped over the dataset below.
# NOTE(review): 224 and 3 are presumably the image size and the number of
# channels expected by the network input — confirm against
# get_load_spectro_for_class.
load = get_load_spectro_for_class(224, 3)

## Load model

In [None]:
# Instantiate the network and restore the weights stored at CHECKPOINT.
m = resnet()
m.load_weights(CHECKPOINT)

# Compile with a binary cross-entropy loss and the accuracy / AUC metrics.
compile_options = {
    "optimizer": tf.keras.optimizers.legacy.Adam(),
    "loss": tf.losses.binary_crossentropy,
    "metrics": ['Accuracy', 'AUC'],
}
m.compile(**compile_options)

## Load data

In [None]:
# Read the dataset index. The csv module requires the file to be opened with
# newline="" so that quoted fields containing line breaks are parsed correctly.
with open(ROOT_DIR + "/dataset.csv", "r", newline="") as f:
    rows = list(csv.reader(f, delimiter=","))

# Keep only the first column of each row. Blank lines are returned by the csv
# reader as empty lists and are skipped so they cannot raise an IndexError.
lines = [row[0] for row in rows if row]
print(len(lines), "files found")

# Turn the entries into a batched dataset; `load` is applied to every entry.
dataset = tf.data.Dataset.from_tensor_slices(lines)
dataset = dataset.map(load).batch(batch_size=BATCH_SIZE)

## Model execution

In [None]:
# Run the model over every batch, collecting the predicted scores and the
# matching ground-truth labels.
detected = []
ground_truth = []

# Ceiling division gives the exact number of batches. The previous
# `1 + int(len(lines) / BATCH_SIZE)` overcounted by one whenever the number of
# files was an exact multiple of BATCH_SIZE, so the progress bar never finished.
n_batches = (len(lines) + BATCH_SIZE - 1) // BATCH_SIZE

for images, y in tqdm(dataset, total=n_batches):
    predicted = m.predict(images, verbose=False)
    # Only the first output column is kept as the detection score.
    detected.extend(predicted[:, 0])
    # Convert the label batch to NumPy once instead of storing one tensor
    # object per sample.
    ground_truth.extend(np.asarray(y))

detected = np.array(detected)
ground_truth = np.array(ground_truth)

## ROC curve computing

In [None]:
# Split the scores by ground-truth label and hand them, together with the
# number of samples in each class, to compute_ROC.
# NOTE(review): thresh_delta is presumably the step between successive
# decision thresholds — confirm against compute_ROC.
is_positive = ground_truth == 1
is_negative = ground_truth == 0
TPr, FPr = compute_ROC(
    detected[is_positive],
    np.count_nonzero(is_positive),
    detected[is_negative],
    np.count_nonzero(is_negative),
    thresh_delta=0.001,
)

# Draw the curve, then store the raw rates as .npy files.
plt.plot(FPr, TPr)
for rates, target in (
    (FPr, f"../../../data/npy/{OUTPUT_DIR}/FPr.npy"),
    (TPr, f"../../../data/npy/{OUTPUT_DIR}/TPr.npy"),
):
    np.save(target, rates)

# Both axes are rates, so clamp them to [0, 1] before labelling and saving.
plt.xlim(0, 1)
plt.ylim(0, 1)
plt.ylabel("TP rate")
plt.xlabel("FP rate")
plt.title("ROC curve")
plt.savefig(f"../../../data/figures/{OUTPUT_DIR}/ROC.png")