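This notebook evaluates a trained TiSSNet model: it loads a checkpoint, runs the model on a dataset, detects peaks in the predicted probability curves, and computes a ROC curve and a histogram of time residuals.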

In [None]:
import csv
import os

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from matplotlib.ticker import FormatStrFormatter
from scipy.signal import find_peaks
from tqdm import tqdm

from utils.eval.eval_utils import evaluate_peaks, compute_ROC, compute_residuals_histogram
from utils.training.metrics import accuracy_for_segmenter, AUC_for_segmenter
from utils.training.data_loading import lines_to_line_generator, get_line_to_spectro_seg
from utils.training.keras_models import TiSSNet

## Parameters

In [None]:
# --- Input / output locations ---
# Dataset root: expected to contain directories named "positives" and "negatives" plus a csv file.
ROOT_DIR = "/path/to/the/dataset"
# Sub-directory (inside the data folder) where the produced files are written.
OUTPUT_DIR = "TiSSNet/dataset"

# --- Model loading ---
# Number of examples per batch.
BATCH_SIZE = 64
# Epoch of the checkpoint that we want to load.
epoch = 22
# Path of the checkpoint to load (epoch number zero-padded to 4 digits).
CHECKPOINT = f"../../../../data/model_saves/TiSSNet/all/cp-{epoch:04d}.ckpt"

# --- Spectrogram description ---
# Number of pixels in the spectrograms.
SIZE = (128, 186)
# 1 means grayscale, 3 means RGB.
CHANNELS = 1
# Duration of the spectrograms, in seconds.
DURATION_S = 100
# Dispersion of the objective function, in seconds.
OBJECTIVE_CURVE_WIDTH = 10

# --- Evaluation settings ---
# Tolerance (s) used when evaluating; also the time distance allowed between
# two peaks of the probability distribution.
ALLOWED_ERROR_S = 10
# Minimum value of the segmenter output for a peak to be recorded.
MIN_PROBA = 0.0005
# Duration of one spectrogram pixel, in seconds (duration divided by SIZE[1]).
TIME_RES = DURATION_S / SIZE[1]

## Load model

In [None]:
# Function later mapped over the dataset; configured with the spectrogram
# geometry and the objective-curve width from the parameters cell.
data_loader = get_line_to_spectro_seg(
    size=SIZE,
    duration_s=DURATION_S,
    channels=CHANNELS,
    objective_curve_width=OBJECTIVE_CURVE_WIDTH,
)

# Instantiate the architecture and restore the weights of the chosen checkpoint.
model = TiSSNet
m = model()
m.load_weights(CHECKPOINT)

# Compile with an Adam optimizer, binary cross-entropy and the segmenter metrics.
optimizer = tf.keras.optimizers.legacy.Adam()
loss_fn = tf.losses.binary_crossentropy
eval_metrics = [accuracy_for_segmenter, AUC_for_segmenter()]
m.compile(optimizer=optimizer, loss=loss_fn, metrics=eval_metrics)

## Load data

In [None]:
# Read the dataset index; the message below treats each CSV row as one file.
# newline="" is what the csv module requires so that it handles line endings
# (including newlines embedded in quoted fields) itself.
with open(ROOT_DIR + "/dataset.csv", "r", newline="") as f:
    csv_reader = csv.reader(f, delimiter=",")
    lines = list(csv_reader)
print(len(lines), "files found")


def _line_tuples():
    """Return a fresh iterator over the dataset lines, each converted to a tuple.

    tf.data calls this factory every time the dataset is iterated. Building a
    new generator on each call (instead of closing over a single, one-shot
    generator) keeps the dataset re-iterable, so re-running a cell that loops
    over `dataset` does not silently yield zero batches.
    """
    return map(tuple, lines_to_line_generator(lines, repeat=False))


# Each element is a 1-D string tensor of variable length (one CSV row).
dataset = tf.data.Dataset.from_generator(
    _line_tuples,
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.string),
)
# Turn each row into model inputs/targets with data_loader, then batch.
dataset = dataset.map(data_loader).batch(batch_size=BATCH_SIZE)

## Model execution and peak finding

In [None]:
# Raw scipy.signal.find_peaks results, one (indices, properties) tuple per segment.
raw_detections = []
raw_ground_truth = []

# find_peaks takes its minimal peak spacing in samples, so the tolerance in
# seconds is converted to a number of spectrogram pixels.
min_peak_distance = ALLOWED_ERROR_S / TIME_RES

# Ceiling division: exact number of batches, including when len(lines) is a
# multiple of BATCH_SIZE (1 + int(len / BATCH_SIZE) over-counts by one there).
n_batches = (len(lines) + BATCH_SIZE - 1) // BATCH_SIZE

for images, y in tqdm(dataset, total=n_batches):
    # predict the output for a whole batch
    predicted = m.predict(images, verbose=False)
    for i, p in enumerate(predicted):
        # for each output, apply a peaks finding algorithm
        raw_detections.append(find_peaks(p, height=MIN_PROBA, distance=min_peak_distance))
        # the target may carry a trailing channel axis; keep a 1-D curve either way
        _y = y[i, :, 0] if len(y.shape) == 3 else y[i, :]
        raw_ground_truth.append(find_peaks(_y, height=MIN_PROBA, distance=min_peak_distance))

# Per segment: list of (peak time in seconds, peak height) for the model output.
detected_peaks = [
    [(index * TIME_RES, height) for index, height in zip(indices, properties["peak_heights"])]
    for indices, properties in raw_detections
]
# Per segment: array of ground-truth peak times in seconds.
ground_truth_peaks = [indices * TIME_RES for indices, _ in raw_ground_truth]

## Peak statistics

#### Get number of peaks

In [None]:
# Total number of detected peaks over all segments.
i = sum(len(segment_detections) for segment_detections in detected_peaks)
# Total number of ground-truth peaks, indexed over the same segments.
j = sum(len(ground_truth_peaks[segment]) for segment in range(len(detected_peaks)))
print(f"{i} peaks found out of {j}")

#### ROC curve computation

In [None]:
# Match detections against the ground truth within ALLOWED_ERROR_S, then derive
# the ROC points from the per-segment counts.
TP, FP, TP_per_seg, TN_per_seg, FP_per_seg, FN_per_seg, P_per_seg, N_per_seg = evaluate_peaks(
    ground_truth_peaks, detected_peaks, ALLOWED_ERROR_S
)
TPr, FPr = compute_ROC(TP_per_seg, P_per_seg, FP_per_seg, N_per_seg, thresh_delta=0.001)

# Create the output directories up front: np.save and savefig raise
# FileNotFoundError when the target directory does not exist.
roc_npy_dir = f"../../../../data/npy/{OUTPUT_DIR}"
roc_figures_dir = f"../../../../data/figures/{OUTPUT_DIR}"
os.makedirs(roc_npy_dir, exist_ok=True)
os.makedirs(roc_figures_dir, exist_ok=True)

# Persist the raw curve so it can be re-plotted without re-running the model.
np.save(f"{roc_npy_dir}/FPr.npy", FPr)
np.save(f"{roc_npy_dir}/TPr.npy", TPr)

# Plot on an explicit figure/axes rather than the implicit pyplot state.
roc_fig, roc_ax = plt.subplots()
roc_ax.plot(FPr, TPr)
roc_ax.set_xlim(0, 1)
roc_ax.set_ylim(0, 1)
roc_ax.set_ylabel("TP rate")
roc_ax.set_xlabel("FP rate")
roc_ax.set_title("ROC curve")
roc_fig.savefig(f"{roc_figures_dir}/ROC.png")

## Residuals histogram computation

In [None]:
BAR_WIDTH = 0.9
# Width of one histogram bin: two spectrogram pixels, in seconds.
step = 2*TIME_RES

# Largest multiple of `step` closest to the allowed error; bins are centred on
# multiples of `step` from -extremum to +extremum.
extremum = step * np.round(ALLOWED_ERROR_S / step)
# NOTE(review): np.arange with a non-integer step can return one element more
# or less than intended because of floating-point rounding — confirm the number
# of bins is the expected one.
allowed_d = np.arange(-extremum, extremum+step, step)
# compute_residuals_histogram returns a mapping; only its values are kept,
# presumably one per entry of allowed_d and in the same order — TODO confirm.
residuals_histogram = compute_residuals_histogram(allowed_d, TP)
TP_by_distance = list(residuals_histogram.values())

# Create the output directories up front: np.save and savefig raise
# FileNotFoundError when the target directory does not exist.
histogram_npy_dir = f"../../../../data/npy/{OUTPUT_DIR}"
histogram_figures_dir = f"../../../../data/figures/{OUTPUT_DIR}"
os.makedirs(histogram_npy_dir, exist_ok=True)
os.makedirs(histogram_figures_dir, exist_ok=True)

np.save(f"{histogram_npy_dir}/TP_by_distance.npy", TP_by_distance)

fig, ax = plt.subplots(1, 1, figsize=(8, 5))

# One bar per allowed residual, with a tick under each bar.
ax.bar(allowed_d, TP_by_distance, width=BAR_WIDTH, align='center')
ax.set_xticks(allowed_d)
# Leave a small margin on both sides of the outermost bars.
ax.set_xlim(allowed_d[0]-0.5-(1-BAR_WIDTH), allowed_d[-1]+0.5+(1-BAR_WIDTH))
ax.xaxis.set_major_formatter(FormatStrFormatter('%.1f'))

ax.set_xlabel('Time residuals (s)', fontsize=12)
ax.set_ylabel('Proportion of detections', fontsize=12)
fig.savefig(f'{histogram_figures_dir}/histogram.png', bbox_inches='tight')