In [1]:
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, average_precision_score, f1_score
from pathlib import Path

In [None]:
csv_path = 'src/results/'  # Update this path if needed

# Load and display ground truth
big_gt_df = pd.read_csv(os.path.join(csv_path, "ground_truth.csv"))

# Group by subject and count number of annotations
subject_counts = big_gt_df.groupby("Patient").size()

# Print results
for subject, count in subject_counts.items():
    print(f"{subject}: {count} annotations")

chaco_Epi-001_20080306_02.ds: 25 annotations
chaco_Epi-001_20080306_04.ds: 38 annotations
chaco_Epi-001_20080306_06.ds: 42 annotations
conti_Epi-001_20090709_05.ds: 20 annotations
conti_Epi-001_20090709_07.ds: 25 annotations
conti_Epi-001_20090709_08.ds: 38 annotations
criso_Epi-001_20100322_05.ds: 6 annotations
criso_Epi-001_20100322_12.ds: 21 annotations
criso_Epi-001_20100322_13.ds: 9 annotations
ennso_Epi-001_20071017_05.ds: 24 annotations
ennso_Epi-001_20071017_06.ds: 21 annotations
ennso_Epi-001_20071017_07.ds: 20 annotations
lioen_Epi-001_20061211_08.ds: 9 annotations
lioen_Epi-001_20061211_11.ds: 19 annotations
pecva_Epi-001_20090203_07.ds: 22 annotations
pecva_Epi-001_20090203_08.ds: 36 annotations
pecva_Epi-001_20090203_09.ds: 20 annotations
racdy_Epi-001_20090929_02.ds: 46 annotations
racdy_Epi-001_20090929_05.ds: 62 annotations
racdy_Epi-001_20090929_06.ds: 70 annotations
samlo_Epi-001_20070402_04.ds: 7 annotations
samlo_Epi-001_20070402_11.ds: 22 annotations
samlo_Epi-001_

In [21]:
# Load and display predictions
big_pred_df = pd.read_csv(os.path.join(csv_path, "predictions.csv"))

In [26]:
# === Parameters ===
tolerance = 0.1  # 50 ms tolerance for spike match (in seconds)
subject = "chaco_Epi-001_20080306_06.ds"
model = "model_features_only.keras"
data_dir = Path("data/testData")

for model in ['model_features_only.keras', 'model_CNN.keras']:
    for subject_path in data_dir.rglob("*.ds"):
        print(f'{model}-{subject_path}')
        subject = os.path.basename(subject_path)
        # === Filter by subject and model ===
        gt_df = big_gt_df[(big_gt_df["Patient"] == subject)]
        pred_df = big_pred_df[(big_pred_df["Patient"] == subject) & (big_pred_df["Model"] == model)]

        # === Check if data is available ===
        if gt_df.empty or pred_df.empty:
            print(f"No matching data found for subject '{subject}' and model '{model}'.")

        else:

            # === Prepare ground truth array ===
            gt_spikes = np.array(sorted(gt_df["SpikeTime_s"].values))

            # === Prepare predicted spike array and scores ===
            pred_spikes = np.array(pred_df["SpikeTime_s"].values)
            pred_scores = np.array(pred_df["Probability"].values)

            # Drop duplicates by keeping the one with the highest probability
            pred_df = pred_df.sort_values("Probability", ascending=False)
            pred_df = pred_df.drop_duplicates(subset=["Patient", "SpikeTime_s"], keep="first")

            # === Label each prediction as TP or FP ===
            y_true = []
            y_scores = []

            used_gt_indices = set()

            for pred_time, score in zip(pred_spikes, pred_scores):
                # Find matching ground truth spike within tolerance
                match_idx = np.where(np.abs(gt_spikes - pred_time) <= tolerance)[0]
                match_idx = [i for i in match_idx if i not in used_gt_indices]

                if match_idx:
                    y_true.append(1)  # True Positive
                    used_gt_indices.add(match_idx[0])  # Avoid double matching
                else:
                    y_true.append(0)  # False Positive

                y_scores.append(score)

            # Add False Negatives (missed GT spikes)
            n_fn = len(gt_spikes) - len(used_gt_indices)
            if n_fn > 0:
                y_true += [1] * n_fn
                y_scores += [0.0] * n_fn  # Score 0 for missing spikes

            # === Compute PR Curve ===
            precision, recall, thresholds = precision_recall_curve(y_true, y_scores)
            ap_score = average_precision_score(y_true, y_scores)

            # Remove points where precision or recall is zero
            valid_mask = (precision[:-1] > 0) & (recall[:-1] > 0) & (thresholds > 0.01)
            filtered_precision = precision[:-1][valid_mask]
            filtered_recall = recall[:-1][valid_mask]
            filtered_thresholds = thresholds[valid_mask]

            f1_scores = 2 * (precision[:-1] * recall[:-1]) / (precision[:-1] + recall[:-1] + 1e-10)
            if len(f1_scores) > 0:
                best_idx = np.argmax(f1_scores)
                best_threshold = filtered_thresholds[best_idx]

            print(f"Best threshold by F1 score: {best_threshold:.3f}")
            print(f"Precision: {precision[best_idx]:.3f}, Recall: {recall[best_idx]:.3f}, F1: {f1_scores[best_idx]:.3f}")

            # # === Plot PR Curve ===
            # plt.figure(figsize=(4, 4))
            # plt.plot(recall, precision, marker='.', label=f'AP = {ap_score:.3f}')
            # plt.xlabel('Recall')
            # plt.ylabel('Precision')
            # plt.title(f'PR Curve - Subject: {subject} | Model: {model}')
            # plt.legend()
            # plt.grid(True)
            # plt.tight_layout()
            # plt.show()

            # plt.plot(thresholds, f1_scores)
            # plt.xlabel("Threshold")
            # plt.ylabel("F1 Score")
            # plt.title("F1 Score vs. Threshold")
            # plt.grid(True)
            # plt.show()

model_features_only.keras-data/testData/Pecva_AllDataset1200Hz/pecva_Epi-001_20090203_08.ds
Best threshold by F1 score: 0.436
Precision: 0.229, Recall: 0.750, F1: 0.351
model_features_only.keras-data/testData/Pecva_AllDataset1200Hz/pecva_Epi-001_20090203_09.ds
Best threshold by F1 score: 0.533
Precision: 0.212, Recall: 0.700, F1: 0.326
model_features_only.keras-data/testData/Pecva_AllDataset1200Hz/pecva_Epi-001_20090203_07.ds
Best threshold by F1 score: 0.444
Precision: 0.283, Recall: 0.591, F1: 0.382
model_features_only.keras-data/testData/Pecva_AllDataset1200Hz/pecva_Epi-001_20090203_08.ds/hz.ds
No matching data found for subject 'hz.ds' and model 'model_features_only.keras'.
model_features_only.keras-data/testData/Pecva_AllDataset1200Hz/pecva_Epi-001_20090203_09.ds/hz.ds
No matching data found for subject 'hz.ds' and model 'model_features_only.keras'.
model_features_only.keras-data/testData/Pecva_AllDataset1200Hz/pecva_Epi-001_20090203_07.ds/hz.ds
No matching data found for subject 

In [33]:
import pandas as pd
import numpy as np

# === Parameters ===
tolerance = 0.200  # 100 ms match window
gt_csv = "src/cache-directory/ground_truth.csv"
pred_csv = "src/cache-directory/predictions.csv"

# === Load CSVs ===
gt_df = pd.read_csv(gt_csv)
pred_df = pd.read_csv(pred_csv)

# Drop duplicates by keeping the one with the highest probability
pred_df = pred_df.sort_values("Probability", ascending=False)
pred_df = pred_df.drop_duplicates(subset=["Patient", "SpikeTime_s"], keep="first")

# === Optional: check ===
assert "SpikeTime_s" in gt_df.columns and "SpikeTime_s" in pred_df.columns

def compute_matches(model_onsets, gt_onsets, delta):
    true_positive = 0
    false_positive = 0
    false_negative = 0

    matched_gt = set()
    tp_distances = []
    fp_distances = []
    fn_distances = []

    for m in model_onsets:
        matched = False
        for g in gt_onsets:
            if abs(m - g) <= delta and g not in matched_gt:
                true_positive += 1
                matched_gt.add(g)
                tp_distances.append(abs(m - g))
                matched = True
                break
        if not matched:
            false_positive += 1
            # Distance to closest ground truth
            closest_gt = min(gt_onsets, key=lambda g: abs(m - g)) if gt_onsets else None
            if closest_gt is not None:
                fp_distances.append(abs(m - closest_gt))

    for g in gt_onsets:
        if g not in matched_gt:
            false_negative += 1
            # Distance to closest model prediction
            closest_model = min(model_onsets, key=lambda m: abs(m - g)) if model_onsets else None
            if closest_model is not None:
                fn_distances.append(abs(g - closest_model))

    return true_positive, false_positive, false_negative, tp_distances, fp_distances, fn_distances

# === Optimize threshold per model ===
thresholds = np.linspace(0.0, 1.0, 101)
final_results = []

for model_name in pred_df["Model"].unique():
    model_pred = pred_df[pred_df["Model"] == model_name]
    best_f1 = 0
    best_threshold = 0
    best_metrics = {}

    for thresh in thresholds:
        filtered_pred = model_pred[model_pred["Probability"] >= thresh]

        all_tp, all_fp, all_fn = 0, 0, 0
        for patient in filtered_pred["Patient"].unique():
            gt_times = gt_df[gt_df["Patient"] == patient]["SpikeTime_s"].tolist()
            pred_times = filtered_pred[filtered_pred["Patient"] == patient]["SpikeTime_s"].tolist()

            tp, fp, fn, _, _, _ = compute_matches(gt_times, pred_times, tolerance)
            all_tp += tp
            all_fp += fp
            all_fn += fn

        precision = all_tp / (all_tp + all_fp + 1e-10)
        recall = all_tp / (all_tp + all_fn + 1e-10)
        f1 = 2 * (precision * recall) / (precision + recall + 1e-10)

        if f1 > best_f1:
            best_f1 = f1
            best_threshold = thresh
            best_metrics = {
                "Model": model_name,
                "BestThreshold": round(thresh, 3),
                "TP": all_tp,
                "FP": all_fp,
                "FN": all_fn,
                "Precision": round(precision, 3),
                "Recall": round(recall, 3),
                "F1": round(f1, 3)
            }

    final_results.append(best_metrics)

# === Display results ===
best_perf_df = pd.DataFrame(final_results)
print(best_perf_df)


                       Model  BestThreshold   TP   FP   FN  Precision  Recall  \
0            model_CNN.keras           0.51  308  382  392      0.446   0.440   
1  model_features_only.keras           0.54  203  440  312      0.316   0.394   

      F1  
0  0.443  
1  0.351  
