# Metrics
Metrics we use repeatedly to benchmark machine learning models.

In [11]:
#| default_exp metrics

In [12]:
#| hide 
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
#| hide
from nbdev.showdoc import *
from fastcore.basics import *

In [14]:
#| export
import os
import re
import logging
import warnings
import numpy as np
import polars as pl
from pathlib import Path
from copy import deepcopy
from enum import Enum, auto
from functools import partial
from collections import defaultdict
from typing import Dict, List, Tuple
from pisces.deep_unet_support import *
from typing import DefaultDict, Iterable
from scipy.ndimage import gaussian_filter1d
from pisces.utils import determine_header_rows_and_delimiter

In [15]:
#| hide
import tempfile

## WASA _p_

WASA _p_ is the wake accuracy obtained when the threshold used to binarize class probabilities into sleep vs. wake is chosen so that the sleep accuracy is _p_\%.

If we consider SLEEP to be the positive class, then the sleep accuracy is also the sensitivity, and the wake accuracy is the specificity. Thus, WASA _p_ is the specificity when the sensitivity is _p_\%.

### Keras Metrics

In [16]:
#| export

import keras.ops as ops
from keras.metrics import Metric, SpecificityAtSensitivity


In [17]:
#| export

class WASAMetric(Metric):
    """Wake accuracy at a fixed sleep accuracy, for multi-class stage predictions.

    The multi-class problem is collapsed to sleep-vs-wake with SLEEP as the
    positive class: label/class 0 is wake and every class with index >= 1 is
    sleep. The collapsed values are fed to a wrapped `SpecificityAtSensitivity`,
    so `result()` is the specificity (wake accuracy) at the sensitivity
    (sleep accuracy) given by `sleep_accuracy`.

    Args:
        sleep_accuracy: Target sensitivity handed to `SpecificityAtSensitivity`.
        from_logits: If True, `y_pred` is passed through a softmax before the
            sleep-class scores are summed.
        **kwargs: Forwarded to `keras.metrics.Metric`. A `name` entry, if
            given, replaces the default ``"WASA<percent>"`` name.
    """

    def __init__(self, sleep_accuracy=0.95, from_logits: bool=False, **kwargs):
        # `name` can arrive through kwargs (a caller override, or the config
        # produced by `get_config`). Pop it so the base class never receives
        # `name` twice, which would raise a TypeError.
        # `round` rather than `int`: 100 * 0.29 is 28.999..., which `int` would
        # truncate to 28.
        name = kwargs.pop("name", None) or f"WASA{round(100 * sleep_accuracy)}"
        super().__init__(name=name, **kwargs)
        self.sleep_accuracy = sleep_accuracy
        self.from_logits = from_logits
        self.specificity_metric = SpecificityAtSensitivity(sleep_accuracy)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate one batch into the wrapped binary metric.

        Args:
            y_true: Stage labels; values > 0 count as sleep, everything else as wake.
            y_pred: Per-class scores with the class axis last (probabilities, or
                logits when `from_logits` is True).
            sample_weight: Optional weights; only their sign is used, so every
                entry > 0 counts fully and every other entry is ignored.
        """
        if sample_weight is None:
            sample_weight = ops.ones_like(y_true)

        if self.from_logits:
            y_pred = ops.softmax(y_pred)

        # Collapse the per-class scores to a single sleep score: everything
        # except class 0 (wake) is summed.
        binary_probs = ops.sum(y_pred[..., 1:], axis=-1)
        binary_weight = ops.where(sample_weight > 0, 1.0, 0.0)
        binary_labels = ops.where(y_true > 0, 1.0, 0.0)  # 0 for wake, 1 for sleep
        self.specificity_metric.update_state(binary_labels, binary_probs, binary_weight)

    def result(self):
        """Return the wrapped metric's specificity at the configured sensitivity."""
        return self.specificity_metric.result()

    def reset_state(self):
        """Clear the accumulated state held by the wrapped metric."""
        # All state lives in the wrapped metric, so resetting it is sufficient.
        self.specificity_metric.reset_state()

    def get_config(self):
        """Return the constructor arguments needed to rebuild this metric."""
        config = super().get_config()
        config.update({
            "sleep_accuracy": self.sleep_accuracy,
            "from_logits": self.from_logits,
        })
        return config

### NumPy Implementation

In [18]:
from typing import Tuple
import numpy as np

In [19]:
#| export


class PerformanceMetrics:
    """Plain record of the scores obtained at one sleep/wake threshold.

    The three constructor arguments are stored unchanged as the attributes
    `sleep_accuracy`, `wake_accuracy` and `tst_error`.
    """

    def __init__(self, sleep_accuracy, wake_accuracy, tst_error):
        (
            self.sleep_accuracy,
            self.wake_accuracy,
            self.tst_error,
        ) = (sleep_accuracy, wake_accuracy, tst_error)


def apply_threshold(labels, predictions, threshold, wake_class:int = 0,
                    epoch_seconds: float = 30.0):
    """Score binary sleep/wake predictions at a single threshold.

    An epoch is predicted wake when its prediction is strictly greater than
    `threshold`, and predicted sleep otherwise. Epochs labelled -1 are never
    counted as predicted sleep.

    Args:
        labels: Array-like of labels; `wake_class` marks wake and
            `1 - wake_class` marks sleep. Flattened before use.
        predictions: Array-like of scores aligned with `labels`, where larger
            values mean "more likely wake". Flattened before use.
        threshold: Decision threshold applied to `predictions`.
        wake_class: Label value that denotes wake.
        epoch_seconds: Duration of one epoch in seconds, used to convert the
            epoch-count difference into minutes. The default of 30 reproduces
            the previous fixed "divide by 2" conversion.

    Returns:
        PerformanceMetrics with
        - sleep_accuracy: fraction of true-sleep epochs predicted sleep,
        - wake_accuracy: fraction of true-wake epochs predicted wake,
        - tst_error: (true sleep epochs - predicted sleep epochs) in minutes.
        An accuracy is NaN when the corresponding class has no epochs, instead
        of raising ZeroDivisionError.
    """
    labels = np.asarray(labels).ravel()
    predictions = np.asarray(predictions).ravel()

    is_wake = labels == wake_class
    is_sleep = labels == 1 - wake_class
    predicted_wake = predictions > threshold
    # Epochs labelled -1 are excluded from the predicted-sleep count.
    predicted_sleep = (predictions <= threshold) & (labels != -1)

    n_wake = int(np.count_nonzero(is_wake))
    n_sleep = int(np.count_nonzero(is_sleep))
    n_predicted_sleep = int(np.count_nonzero(predicted_sleep))

    # Boolean masks replace the set-intersection of index arrays; the counts
    # are the same for 1-D input.
    if n_wake:
        wake_accuracy = int(np.count_nonzero(is_wake & predicted_wake)) / n_wake
    else:
        wake_accuracy = float("nan")

    if n_sleep:
        sleep_accuracy = int(np.count_nonzero(is_sleep & predicted_sleep)) / n_sleep
    else:
        sleep_accuracy = float("nan")

    tst_error = (n_sleep - n_predicted_sleep) * epoch_seconds / 60  # Minutes

    return PerformanceMetrics(sleep_accuracy, wake_accuracy, tst_error)


def threshold_from_binary_search(labels, wake_probabilities,
                                 target_sleep_accuracy, wake_class: int = 0) -> float:
    """Search for the threshold whose sleep accuracy matches a target.

    Starting at 0.5 with a step of 0.25, the threshold is moved up when the
    sleep accuracy is too low and down when it is too high, halving the step
    after every move. The search stops as soon as the sleep accuracy falls in
    ``[target - 0.0001, target + 0.0001)`` or after 50 evaluations, whichever
    comes first.

    Progress is reported through the `logging` module instead of `print`:
    per-evaluation details at DEBUG, the outcome at INFO, and a WARNING when
    the search ends without reaching the target.

    Args:
        labels: Labels passed through to `apply_threshold`.
        wake_probabilities: Scores passed through to `apply_threshold`.
        target_sleep_accuracy: Desired sleep accuracy, as a fraction.
        wake_class: Label value that denotes wake.

    Returns:
        The last threshold that was evaluated. When the search did not
        converge this is only the final probe, not a guaranteed match.
    """
    logger = logging.getLogger(__name__)

    # How close to the target sleep accuracy we need to be before stopping.
    tolerance = 0.0001
    max_attempts = 50
    lower_bound = target_sleep_accuracy - tolerance
    upper_bound = target_sleep_accuracy + tolerance
    # `round` rather than `int`: 0.29 * 100 is 28.999..., which `int` truncates.
    target_percent = round(target_sleep_accuracy * 100)

    threshold_for_sleep = 0.5
    threshold_delta = 0.25
    fraction_sleep_scored_as_sleep = -1
    converged = False
    attempt = 0

    for attempt in range(max_attempts):
        if attempt > 0:
            # Reaching this point means the previous evaluation was out of
            # range. Sleep is scored where the wake probability is at or below
            # the threshold, so raising the threshold scores more epochs as sleep.
            if fraction_sleep_scored_as_sleep < lower_bound:
                threshold_for_sleep = threshold_for_sleep + threshold_delta
            else:
                threshold_for_sleep = threshold_for_sleep - threshold_delta
            threshold_delta = threshold_delta / 2

        performance = apply_threshold(
            labels, wake_probabilities, threshold_for_sleep, wake_class)
        fraction_sleep_scored_as_sleep = performance.sleep_accuracy
        logger.debug(
            "WASA%d: %s | fraction sleep correct: %s | goal: %s | threshold: %s",
            target_percent, performance.wake_accuracy,
            fraction_sleep_scored_as_sleep, target_sleep_accuracy,
            threshold_for_sleep)

        too_low = fraction_sleep_scored_as_sleep < lower_bound
        too_high = fraction_sleep_scored_as_sleep >= upper_bound
        if not (too_low or too_high):
            # Both comparisons are also False for NaN; that ends the search
            # but must not be reported as convergence.
            converged = lower_bound <= fraction_sleep_scored_as_sleep < upper_bound
            break

    if converged:
        logger.info(
            "WASA%d threshold search converged after %d evaluation(s): "
            "threshold=%s, sleep accuracy=%s (goal %s)",
            target_percent, attempt + 1, threshold_for_sleep,
            fraction_sleep_scored_as_sleep, target_sleep_accuracy)
    else:
        logger.warning(
            "WASA%d threshold search stopped after %d evaluation(s) without "
            "reaching the goal: threshold=%s, sleep accuracy=%s (goal %s)",
            target_percent, attempt + 1, threshold_for_sleep,
            fraction_sleep_scored_as_sleep, target_sleep_accuracy)

    return threshold_for_sleep

def wasa_metric(labels, predictions, weights, target_sleep_accuracy=0.95,
                wake_class: int = 0) -> Tuple[PerformanceMetrics, float]:
    """Compute wake accuracy at a target sleep accuracy on weighted epochs.

    Epochs whose weight is not strictly positive are dropped, every label
    greater than 1 is collapsed to 1, a threshold is located with
    `threshold_from_binary_search`, and the data is scored at that threshold
    with `apply_threshold`.

    Args:
        labels: Array-like of stage labels.
        predictions: Array-like of scores aligned with `labels`; passed as the
            wake probabilities to the threshold search.
        weights: Array-like aligned with `labels`, or None to score every epoch.
        target_sleep_accuracy: Sleep accuracy the threshold search aims for.
        wake_class: Label value that denotes wake.

    Returns:
        `(performance, threshold)`: the metrics at the selected threshold and
        the threshold itself.
    """
    # Coerce so lists and other array-likes work as well as ndarrays.
    labels = np.asarray(labels)
    predictions = np.asarray(predictions)

    if weights is None:
        keep = np.ones(labels.shape, dtype=bool)
    else:
        keep = np.asarray(weights) > 0

    # Boolean-mask indexing returns new 1-D arrays, so the in-place collapse
    # below never touches the caller's data.
    labels = labels[keep]
    predictions = predictions[keep]

    labels[labels > 1] = 1

    threshold = threshold_from_binary_search(labels, predictions, target_sleep_accuracy, wake_class)

    perform = apply_threshold(
        labels, predictions, threshold, wake_class)

    return perform, threshold

In [20]:
#| hide
# Run nbdev's export step so the cells marked `#| export` are written out to the library module.
import nbdev
nbdev.nbdev_export()