In [1]:
import os
from pathlib import Path

config_fpath = Path.home() / ".fiftyone" / "config.global_mongodb.json"
os.environ["FIFTYONE_CONFIG_PATH"] = str(config_fpath)

from typing import List
from tqdm import tqdm
import numpy as np
import gradio as gr

import fiftyone as fo
from fiftyone.core.utils import pprint
from fiftyone.utils.iou import compute_ious

metrics_description = """
### Detection Consistency (for "% Ground Truth Seen")
- **Definition**: This term refers to the system's ability to reliably detect and track objects of interest within the monitored environment. It encapsulates how well the system performs in identifying objects across different conditions and time windows.
- **Why It Matters**: It directly speaks to the system's effectiveness in covering the area it monitors, ensuring that fewer objects go unnoticed. Higher Detection Consistency means users can trust the system to catch more of what's happening, reducing the risk of missing critical events.

### Precision Alerting (for "% Matched Predictions")
- **Definition**: This focuses on the accuracy and relevance of the alerts generated by the system. It reflects the system's ability to distinguish between true objects of interest and false alarms.
- **Why It Matters**: By optimizing for precision alerting, the system minimizes unnecessary notifications, focusing users' attention on what truly matters. This is especially important in high-stakes environments where each alert could require significant resources to address.
"""

def load_sample():
    """
    Dummy function to mimic dataset loading, replace with actual dataset loading logic.
    """
    fpath = "/mnt/FiftyOneSentry/Database/Sentry_Videos/Videos_8Bit/Sentry_2023_02_Portugal/Sentry_recordings_2023_01_24_20_47_46/Sentry_TwFoV_record_2023_01_24_20_47_46.mp4"
    dataset = fo.load_dataset("SENTRY_VIDEOS_DATASET_QA")
    return dataset[fpath]


def find_range_index(ranges: List[tuple], n: int) -> int:
    """
    Finds the index of the range that contains the given number.

    Args:
        ranges: A list of tuples representing the ranges.
        n: The number to find the range index for.

    Returns:
        The index of the range that contains the number, or -1 if no range contains the
        number.
    """
    return next((i for i, (start, end) in enumerate(ranges) if start <= n <= end), -1)


def update_stats(
    frame: fo.Frame, gt_field: str, pred_field: str, gt_stats: dict, pred_stats: dict
) -> None:
    """
    Updates the statistics for ground truth and predictions based on the given frame.

    Args:
        frame: The frame containing the ground truth and predictions.
        pred_field: The field name for the predictions.
        gt_stats: A dictionary to store the ground truth statistics.
        pred_stats: A dictionary to store the prediction statistics.
    """
    preds: List[fo.Detection] = frame[f"{pred_field}.detections"]
    gts: List[fo.Detection] = frame[f"{gt_field}.detections"]

    for pred in preds:
        if pred.index not in pred_stats:
            pred_stats[pred.index] = {"pred_count": 1, "matched": 0}
        else:
            pred_stats[pred.index]["pred_count"] += 1

    for gt in gts:
        if gt.index not in gt_stats:
            gt_stats[gt.index] = {"gt_count": 1, "tp_count": 0}
        else:
            gt_stats[gt.index]["gt_count"] += 1

        ious = compute_ious([gt], preds)
        if np.any(ious > 0):
            gt_stats[gt.index]["tp_count"] += 1
            matched_idx = np.where(ious > 0)[1]
            for idx in matched_idx:
                pred_stats[preds[idx].index]["matched"] += 1


def calculate_chunk_metrics(
    time_ranges: List[tuple], frames: List[fo.Frame], gt_field: str, pred_field: str
) -> tuple:
    """
    Calculates the statistics for each chunk of frames.

    Args:
        time_ranges: A list of tuples representing the time ranges for each chunk.
        frames: A list of frames.
        pred_field: The field name for the predictions.

    Returns:
        Two lists of dictionaries representing the ground truth and prediction
        statistics for each chunk.
    """
    gt_stats_chunks = [{} for _ in range(len(time_ranges))]
    pred_stats_chunks = [{} for _ in range(len(time_ranges))]

    for n, frame in frames.items():
        i = find_range_index(time_ranges, n)
        if i == -1:
            continue
        try:
            update_stats(
                frame, gt_field, pred_field, gt_stats_chunks[i], pred_stats_chunks[i]
            )
        except TypeError:
            pass  # dirty fix for the error

    return gt_stats_chunks, pred_stats_chunks


def calculate_final_metrics(
    gt_stats_chunks: List[dict],
    pred_stats_chunks: List[dict],
    percentage_seen: float,
    percentage_matched: float,
) -> tuple:
    """
    Calculates the final metrics based on the ground truth and prediction statistics.

    Args:
        gt_stats_chunks (list):
            A list of dictionaries representing the ground truth statistics for each
            chunk.
        pred_stats_chunks (list):
            A list of dictionaries representing the prediction statistics for each
            chunk.
        percentage_seen (float):
            The minimum percentage of ground truth seen to consider it as a true
            positive.
        percentage_matched (float):
            The minimum percentage of matched predictions to consider it as a true
            positive.

    Returns:
        The Detection Consistency and precision alerting metrics.
    """
    tps = []
    fps = []
    for gt_chunk, pred_chunk in zip(gt_stats_chunks, pred_stats_chunks):

        for gt_stats in gt_chunk.values():
            gt_stats["percentage_seen"] = gt_stats["tp_count"] / gt_stats["gt_count"]
            gt_stats["true_positive"] = gt_stats["percentage_seen"] >= percentage_seen
            tps.append(gt_stats["true_positive"])

        for pred_stats in pred_chunk.values():
            pred_stats["percentage_matched"] = (
                pred_stats["matched"] / pred_stats["pred_count"]
            )
            pred_stats["true_positive"] = (
                pred_stats["percentage_matched"] >= percentage_matched
            )
            fps.append(pred_stats["true_positive"])

    detection_reliability = np.mean(tps)
    precision_alerting = np.mean(fps)
    return detection_reliability, precision_alerting


def calculate_chunks(
    gt_field: str,
    pred_field: str,
    time_window: float,
) -> tuple:
    """
    Calculates the Detection Consistency and precision alerting metrics.

    Args:
        sample_pkl_path (str): The path to the sample pickle file.
        pred_field (str): The field name for the predictions.
        time_window (float): The time window in seconds.
        percentage_seen (float): The minimum percentage of ground truth seen to consider
            it as a true positive.
        percentage_unmatched (float): The maximum percentage of unmatched predictions to
            consider it as a false positive.

    Returns:
        Tuple[float, float]: The Detection Consistency and precision alerting metrics.
    """
    sample = load_sample()
    frames = sample.frames

    frames_window = int(time_window * sample.metadata.frame_rate)
    time_ranges = [
        (i, min(i + frames_window, len(frames)))
        for i in range(1, len(frames) + 1, frames_window)
    ]

    return calculate_chunk_metrics(time_ranges, frames, gt_field, pred_field)


def gradio_interface(
    _,  # video,
    percentage_seen=0.1,
    percentage_matched=0.9,
):
    detection_reliability, precision_alerting = calculate_final_metrics(
        gt_stats_chunks, pred_stats_chunks, percentage_seen, percentage_matched
    )

    metrics = {
        "Detection Consistency": detection_reliability,
        "Precision Alerting": precision_alerting,
    }
    return metrics, metrics_description


iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Video(
            "/Users/kevinserrano/Downloads/Sentry_TwFoV_record_2023_01_24_20_47_46.mp4",
            interactive=False,
            label="sample",
            autoplay=True
        ),
        gr.Slider(
            minimum=0, maximum=1, step=0.01, value=0.1, label="% Ground Truth Seen"
        ),
        gr.Slider(
            minimum=0,
            maximum=1,
            step=0.01,
            value=0.1,
            label="% Matched Predictions",
        ),
    ],
    outputs=[
        gr.Label(label="User Metrics"),
        gr.Markdown(metrics_description),
    ],
    title="User Metrics Calculator",
    allow_flagging='never',
)

gt_stats_chunks, pred_stats_chunks = calculate_chunks(
    "ground_truth_det",
    "volcanic-sweep-3_02_2023_N_LN1_ep288_TRACKER",
    10,
)

pprint(gt_stats_chunks)
pprint(pred_stats_chunks)

iface.launch()

[
    {
        1: {'gt_count': 301, 'tp_count': 0},
        2: {'gt_count': 301, 'tp_count': 215},
    },
    {
        1: {'gt_count': 300, 'tp_count': 4},
        2: {'gt_count': 300, 'tp_count': 237},
    },
    {
        1: {'gt_count': 300, 'tp_count': 15},
        2: {'gt_count': 300, 'tp_count': 250},
        3: {'gt_count': 280, 'tp_count': 206},
    },
    {
        3: {'gt_count': 201, 'tp_count': 201},
        1: {'gt_count': 201, 'tp_count': 56},
        2: {'gt_count': 201, 'tp_count': 180},
    },
]
[
    {1: {'pred_count': 229, 'matched': 215}},
    {
        1: {'pred_count': 292, 'matched': 237},
        23: {'pred_count': 4, 'matched': 4},
    },
    {
        1: {'pred_count': 300, 'matched': 250},
        31: {'pred_count': 207, 'matched': 206},
        23: {'pred_count': 15, 'matched': 15},
    },
    {
        31: {'pred_count': 201, 'matched': 201},
        1: {'pred_count': 201, 'matched': 180},
        44: {'pred_count': 56, 'matched': 56},
    },
]
Running on

