In [None]:
import csv
from pathlib import Path

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from matplotlib.ticker import FormatStrFormatter
from scipy.signal import find_peaks
from sklearn.utils import shuffle
from tqdm import tqdm

from utils.eval.eval_utils import evaluate_peaks, compute_ROC, compute_residuals_histogram
from utils.training.metrics import accuracy_for_segmenter, AUC_for_segmenter
from utils.training.data_loading import get_line_to_dataset_waveform
from utils.training.keras_models import AcousticPhaseNet

## Parameters

In [None]:
# Notebook configuration: every value a reader may want to tune lives in this cell.
ROOT_DIR = "/path/to/the/dataset"  # path where we expect to find directories named "positives", "negatives" and a csv file
BATCH_SIZE = 64
epoch = 22  # epoch checkpoint that we want to load
# Single braces are required so that `epoch` is interpolated (epoch 22 -> "cp-0022.ckpt");
# doubled braces would leave the literal text "{epoch:04d}" in the path.
# The path is relative to the notebook, like the output paths under ../../../data/ used for results.
CHECKPOINT = f"../../../data/model_saves/AcousticPhaseNet/all/cp-{epoch:04d}.ckpt"  # path of the checkpoint to load
OUTPUT_DIR = "AcousticPhaseNet/dataset"  # directory where to output files, in the data folder

# 100*240+1 = 24001 points, rounded up to the next power of 2 (32768)
SIZE = int(2**(np.ceil(np.log2(100*240+1))))  # number of points in each file rounded to the next pow of 2
CHANNELS = 1  # 1 means grayscale 3 RGB
DURATION_S = 100  # duration of the spectrograms in s
OBJECTIVE_CURVE_WIDTH = 10  # defines width of objective function in s

ALLOWED_ERROR_S = 10  # tolerance when evaluating and time distance allowed between two peaks in the probabilities distribution
MIN_PROBA = 0.0001  # minimum value of the output of the segmenter model to record it
TIME_RES = DURATION_S / SIZE  # duration of one output point, in s

# Loader later applied to the csv rows to produce the (x, y) arrays fed to the model.
data_loader = get_line_to_dataset_waveform(
    size=SIZE,
    duration_s=DURATION_S,
    objective_curve_width=OBJECTIVE_CURVE_WIDTH,
)

## Load model

In [None]:
# Instantiate the network (built with SIZE as its only argument) and restore the
# weights saved at CHECKPOINT.
model = AcousticPhaseNet
m = model(SIZE)
m.load_weights(CHECKPOINT)

# Compile after the weights are restored, with the same loss and metrics objects
# the notebook reports on.
segmenter_metrics = [accuracy_for_segmenter, AUC_for_segmenter()]
m.compile(
    optimizer=tf.keras.optimizers.legacy.Adam(),
    loss=tf.losses.binary_crossentropy,
    metrics=segmenter_metrics,
)

## Load data

In [None]:
# Open the csv listing the data and shuffle its rows.
# newline="" is what the csv module requires for files handed to csv.reader, so that
# newlines embedded in quoted fields are parsed correctly on every platform.
with open(ROOT_DIR + "/dataset.csv", "r", newline="") as f:
    csv_reader = csv.reader(f, delimiter=",")
    lines = list(csv_reader)
lines = shuffle(lines)
print(len(lines), "files found")

# Turn the csv rows into model inputs (x) and objective curves (y), then serve them
# in batches of BATCH_SIZE.
x, y = data_loader(lines)
dataset = tf.data.Dataset.from_tensor_slices((list(x), list(y)))
dataset = dataset.batch(batch_size=BATCH_SIZE)

## Model execution and peaks finding

In [None]:
# Minimal spacing between two peaks, expressed in output points (loop invariant).
min_peak_distance = ALLOWED_ERROR_S / TIME_RES
# Exact number of batches: ceil(n / BATCH_SIZE). The previous 1 + int(n / BATCH_SIZE)
# over-counted by one whenever n was a multiple of BATCH_SIZE.
n_batches = int(np.ceil(len(lines) / BATCH_SIZE))

detected_peaks = []
ground_truth_peaks = []
# batch_x / batch_y are used as loop names so the full x / y arrays built when
# loading the data are not overwritten by the last batch.
for batch_x, batch_y in tqdm(dataset, total=n_batches):
    # predict the output for a whole batch
    predicted = m.predict(batch_x, verbose=False)
    for i, p in enumerate(predicted):
        # find_peaks only accepts 1-D input: drop a trailing channel axis if the model
        # emits one, mirroring what is done for the ground truth just below.
        _p = p[:, 0] if np.ndim(p) == 2 else p
        # for each output, apply a peaks finding algorithm
        detected_peaks.append(find_peaks(_p, height=MIN_PROBA, distance=min_peak_distance))
        _y = batch_y[i, :, 0] if len(batch_y.shape) == 3 else batch_y[i, :]
        ground_truth_peaks.append(find_peaks(_y, height=MIN_PROBA, distance=min_peak_distance))

# find_peaks returns (indices, properties): convert indices to seconds and keep, for
# detections, the peak height as the associated probability.
detected_peaks = [
    list(zip(peak_idx * TIME_RES, properties["peak_heights"]))
    for peak_idx, properties in detected_peaks
]
ground_truth_peaks = [peak_idx * TIME_RES for peak_idx, _ in ground_truth_peaks]

## Peaks statistics

#### Get number of peaks

In [None]:
# Total number of peaks over all segments: found by the model vs. present in the ground truth.
n_segments = len(detected_peaks)
n_found = sum(len(detected_peaks[k]) for k in range(n_segments))
n_expected = sum(len(ground_truth_peaks[k]) for k in range(n_segments))
print(f"{n_found} peaks found out of {n_expected}")

#### ROC curve computing

In [None]:
# Match detections against the ground truth within ALLOWED_ERROR_S, then derive the ROC points.
TP, FP, TP_per_seg, TN_per_seg, FP_per_seg, FN_per_seg, P_per_seg, N_per_seg = evaluate_peaks(ground_truth_peaks, detected_peaks, ALLOWED_ERROR_S)
TPr, FPr = compute_ROC(TP_per_seg, P_per_seg, FP_per_seg, N_per_seg, thresh_delta=0.001)

# Create the output folders up front: np.save / savefig raise FileNotFoundError on a
# missing directory, which would otherwise waste the whole inference run.
npy_dir = Path(f"../../../data/npy/{OUTPUT_DIR}")
figures_dir = Path(f"../../../data/figures/{OUTPUT_DIR}")
npy_dir.mkdir(parents=True, exist_ok=True)
figures_dir.mkdir(parents=True, exist_ok=True)

np.save(npy_dir / "FPr.npy", FPr)
np.save(npy_dir / "TPr.npy", TPr)

fig, ax = plt.subplots()
ax.plot(FPr, TPr)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_ylabel("TP rate")
ax.set_xlabel("FP rate")
ax.set_title("ROC curve")
fig.savefig(figures_dir / "ROC.png")

#### Residuals histogram computing

In [None]:
BAR_WIDTH = 0.9  # fraction of a bin covered by each bar
step = 40*TIME_RES  # bin width, in s

# Build the bin centres from integer indices: a float-step np.arange can return one
# element too many because of rounding, whereas this always yields 2*n+1 symmetric bins.
n_bins_per_side = int(np.round(ALLOWED_ERROR_S / step))
extremum = step * n_bins_per_side
allowed_d = step * np.arange(-n_bins_per_side, n_bins_per_side + 1)
TP_by_distance = compute_residuals_histogram(allowed_d, TP)

# Create the output folders up front so that saving cannot fail on a missing directory.
npy_dir = Path(f"../../../data/npy/{OUTPUT_DIR}")
figures_dir = Path(f"../../../data/figures/{OUTPUT_DIR}")
npy_dir.mkdir(parents=True, exist_ok=True)
figures_dir.mkdir(parents=True, exist_ok=True)

np.save(npy_dir / "TP_by_distance.npy", TP_by_distance)

fig, ax = plt.subplots(1, 1, figsize=(8, 5))

# Bar width and axis margins are expressed in data units (s), so they are scaled by the
# bin width; fixed values of 0.9 / 0.5 only suit step == 1 and make the bars overlap here.
ax.bar(allowed_d, TP_by_distance, width=BAR_WIDTH * step, align='center')
# NOTE(review): one tick per bin gives a very dense axis when step is small — consider thinning.
ax.set_xticks(allowed_d)
margin = step * (0.5 + (1 - BAR_WIDTH))
ax.set_xlim(allowed_d[0] - margin, allowed_d[-1] + margin)
ax.xaxis.set_major_formatter(FormatStrFormatter('%.1f'))

ax.set_xlabel('Time residuals (s)', fontsize=12)
ax.set_ylabel('Proportion of detections', fontsize=12)
fig.savefig(figures_dir / "histogram.png", bbox_inches='tight')