This notebook evaluates the SGBT model on an annotated dataset.

In [None]:
import pickle
import csv
from pathlib import Path

import glob2
import numpy as np
from matplotlib import pyplot as plt
from sklearn import metrics

from src.utils.eval.eval_utils import compute_ROC

## Parameters

In [None]:
ROOT_DIR = "PATH/TO/DATA"  # directory expected to contain the "positives" and "negatives" sub-directories (.npy files) and a "dataset.csv" file
OUTPUT_DIR = "SGBT100/OHASISBIO-3"  # sub-directory (created under the data "npy" and "figures" folders) where results are written
CHECKPOINT = "../../../../data/model_saves/SGBT/save_model_100s"  # pickled save of the SGBT model

MIN_ANNOTATORS_COUNT = 3  # minimum number of agreeing annotators required to keep a positive pick

## Load data

In [None]:
# Read the whole annotation table in memory; each row is a list of strings.
# newline="" is what the csv module recommends so that quoted fields containing
# line breaks are parsed correctly.
with open(ROOT_DIR + "/dataset.csv", "r", newline="") as f:
    csv_reader = csv.reader(f, delimiter=",")
    lines = list(csv_reader)

# Load every .npy file of both classes, keyed by its file name. Path(...).name is
# used instead of splitting on "/" so the keys are plain file names whatever path
# separator the platform's glob returns.
pos_f = glob2.glob(f"{ROOT_DIR}/positives/*.npy")
pos_data = {Path(p).name: np.load(p) for p in pos_f}
neg_f = glob2.glob(f"{ROOT_DIR}/negatives/*.npy")
neg_data = {Path(n).name: np.load(n) for n in neg_f}

# Drop every pick that was not confirmed by enough annotators.
# From column 3 onwards a positive row alternates (pick, annotators count); only the
# picks whose count reaches MIN_ANNOTATORS_COUNT are kept, and the counts are discarded.
for row in lines:
    if row[2] != "positive":
        continue
    row[3:] = [
        row[j]
        for j in range(3, len(row), 2)
        if int(row[j + 1]) >= MIN_ANNOTATORS_COUNT
    ]

# Build the feature matrix X, the labels Y and, for each sample, the index of the
# dataset row it comes from (original_idx), all three sharing the same ordering:
# positives first, then negatives.
posX, negX = [], []
original_idx_pos, original_idx_neg = [], []
for i, line in enumerate(lines):
    # A sample is positive only if at least one pick survived the annotators filter.
    # The SAME criterion must drive both the feature list and the index list: a
    # "positive" row left without picks is evaluated as a negative, so its row index
    # has to go with the negatives too, otherwise original_idx no longer lines up
    # with X and Y.
    is_positive = len(line) > 3
    # The sample itself is still stored in the folder matching its dataset label.
    data = pos_data if line[2] == "positive" else neg_data
    station = line[0].split("/")[-1]
    idx = int(line[1])
    sample = data[station][idx]
    if is_positive:
        posX.append(sample)
        original_idx_pos.append(i)
    else:
        negX.append(sample)
        original_idx_neg.append(i)

posX, negX = np.array(posX), np.array(negX)
posY, negY = np.ones(len(posX)), np.zeros(len(negX))
X = np.concatenate((posX, negX))
Y = np.concatenate((posY, negY))
# Explicit integer dtype: concatenating with an empty list would otherwise yield floats.
original_idx = np.array(original_idx_pos + original_idx_neg, dtype=int)
print(f"{len(posX)} positive samples and {len(negX)} negative samples found")

## Load the model and apply it

In [None]:
# NOTE(review): unpickling executes arbitrary code, so CHECKPOINT must only ever
# point to a trusted file.
with open(CHECKPOINT, 'rb') as f:
    model = pickle.load(f)

# Keep the second column of predict_proba as the detection score
# (presumably the probability of the positive class - confirm against the model).
pred = model.predict_proba(X)[:,1]
print(metrics.roc_auc_score(Y, pred))

# Sort the dataset row indices by detection outcome, using 0.5 as decision threshold.
# True negatives are not recorded.
detected = pred > 0.5
is_positive = Y == 1
TP = list(original_idx[detected & is_positive])
FP = list(original_idx[detected & ~is_positive])
FN = list(original_idx[~detected & is_positive])

In [None]:
# Make sure both output folders exist before anything is written to them
for out_dir in (
    f"../../../../data/npy/{OUTPUT_DIR}",
    f"../../../../data/figures/{OUTPUT_DIR}",
):
    Path(out_dir).mkdir(exist_ok=True, parents=True)

# Store the dataset row indices of each detection outcome
saved_arrays = {
    f"../../../../data/npy/{OUTPUT_DIR}/TP.npy": TP,
    f"../../../../data/npy/{OUTPUT_DIR}/FP.npy": FP,
    f"../../../../data/npy/{OUTPUT_DIR}/FN.npy": FN,
}
for target, indices in saved_arrays.items():
    np.save(target, indices)

## ROC curve

In [None]:
# Split the scores by ground-truth class and hand them to compute_ROC together with
# the size of each class (presumably it sweeps the decision threshold by steps of
# thresh_delta - confirm in src.utils.eval.eval_utils).
is_pos, is_neg = Y == 1, Y == 0
TPr, FPr = compute_ROC(
    pred[is_pos],
    np.count_nonzero(is_pos),
    pred[is_neg],
    np.count_nonzero(is_neg),
    thresh_delta=0.001,
)

# Keep the raw curve next to the other numerical results
np.save(f"../../../../data/npy/{OUTPUT_DIR}/FPr.npy", FPr)
np.save(f"../../../../data/npy/{OUTPUT_DIR}/TPr.npy", TPr)

# Draw the curve on the unit square and export it
plt.plot(FPr, TPr)
ax = plt.gca()
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_ylabel("TP rate")
ax.set_xlabel("FP rate")
ax.set_title("ROC curve")
plt.savefig(f"../../../../data/figures/{OUTPUT_DIR}/ROC.png")