# evaluations

> This module defines the outputs of the experiments you can run with `pisces`.

In [1]:
#| default_exp evaluations

In [2]:
#| hide 
%load_ext autoreload
%autoreload 2

In [3]:
#| hide
from nbdev.showdoc import *


In [4]:
#| export

from enum import Enum
from typing import List

from enum import Enum, auto
from typing import Dict, List, Optional, Tuple, Union

from pathlib import Path

import numpy as np

## Test split evaluation

The fundamental organization of a `pisces` pipeline run is the train-test split. Because we tend to care about inference capabilities on unseen data, it's natural to think of the splits as being parametrized by the testing set. Leave-one-out (LOO) validation is often used on patient data, so this pairs naturally with thinking of the performance of a trained model _on this patient/subject._

In [5]:
#| export
from pisces.loader import StudyRecord

NameError: name 'ValidationMethod' is not defined

In [None]:

from dataclasses import dataclass

import numpy.typing as npt

from pisces.records import TimeseriesRecording


class PSGPredictionsShapeError(Exception):
    """Raised when a predictions array does not have an acceptable shape.

    Construction is inherited unchanged from ``Exception``: any positional
    arguments are stored on ``args`` and used for the string representation.
    """


class PSGPredictionsNotProbabilityVectorError(Exception):
    """Raised when predictions cannot be interpreted as a probability vector.

    Construction is inherited unchanged from ``Exception``: any positional
    arguments are stored on ``args`` and used for the string representation.
    """


@dataclass
class PSGModelOutputs(TimeseriesRecording):
    """
    TimeseriesRecording output from models; models should type a specific subclass of
    this in the return type of their evaluate method.

    Probabilities should be a NumPy array shaped such that
        self.probabilities[i]
    indicates the probabilities of each sleep stage for times between
        self.time[i] and self.time[i+1]
    In particular, self.probabilities should have shape
        (N,) or (N, m)
    where N == len(self.time) and m = number of stages

    Subclasses must override `valid_probability_shapes` to declare which
    concrete shapes they accept; the base implementation raises
    NotImplementedError.
    """

    probabilities: npt.NDArray[np.float64]

    def __post_init__(self):
        """Validate `probabilities` against `time` after dataclass construction.

        This is a special dunder method for dataclasses that is called after
        __init__. You can override this in your subclass to do any additional
        validation; call super().__post_init__() within your override to run
        this class's method as part of it.

        Raises:
            PSGPredictionsShapeError: if len(probabilities) != len(time), or if
                probabilities.shape is not one of `valid_probability_shapes()`.
            PSGPredictionsNotProbabilityVectorError: if any entry lies outside
                the closed interval [0.0, 1.0] (NaN entries are rejected too).
            NotImplementedError: if the subclass did not override
                `valid_probability_shapes`.
        """
        if len(self.time) != len(self.probabilities):
            raise PSGPredictionsShapeError(f"{len(self.probabilities)} == len(probabilities) must equal len(time) == {len(self.time)}")

        # Check both bounds: the upper bound alone would accept negative values.
        # NaN compares False against everything, so it is rejected as well.
        within_unit_interval = (self.probabilities >= 0.0) & (self.probabilities <= 1.0)
        if not np.all(within_unit_interval):
            raise PSGPredictionsNotProbabilityVectorError(
                "Probabilities vector must contain entries between 0.0 and 1.0, inclusive."
            )

        # valid_probability_shapes takes no arguments besides self; it reads
        # self.time itself in the subclasses that implement it.
        allowed_shapes = self.valid_probability_shapes()
        if self.probabilities.shape not in allowed_shapes:
            raise PSGPredictionsShapeError("Probabilities vector has invalid shape, must be one of: " + str(allowed_shapes))

    def valid_probability_shapes(self) -> List[Tuple[int, ...]]:
        """
        Returns a list of valid shapes for the probabilities array

        Subclasses must override this. The base class cannot know the number
        of stages, so it raises instead of silently returning None (which
        would break the membership test in __post_init__).

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must override valid_probability_shapes()"
        )

    def _sort_specific_data(self, sort_idx: npt.NDArray[np.int64]) -> None:
        """Reorder `probabilities` along its first axis using `sort_idx`."""
        self.probabilities = self.probabilities[sort_idx]

    def _trim_specific_data(self, select_idx: np.ndarray) -> None:
        """Keep only the entries of `probabilities` chosen by `select_idx`.

        `select_idx` may be a boolean mask or an array of integer indices.
        """
        if select_idx.dtype == bool:
            # Reduce to case where indices to include are provided.
            # np.nonzero returns a tuple of index arrays, one per mask dimension.
            select_idx = np.nonzero(select_idx)
        self.probabilities = self.probabilities[select_idx]


class PSGSleepWakePredictions(PSGModelOutputs):
    """Model outputs whose probabilities array has one of three accepted shapes.

    With N == len(self.time), the accepted shapes are (N,), (N, 1) and (N, 2).
    Validation is inherited from PSGModelOutputs.__post_init__.
    """

    def valid_probability_shapes(self) -> List[Tuple[int, ...]]:
        """Return the accepted shapes: (N,), (N, 1) and (N, 2), N == len(self.time)."""
        n = len(self.time)
        return [(n,), (n, 1), (n, 2)]
    


In [None]:
class TestSplit:
    """Container holding the study records that make up a test split."""

    def __init__(self, test_split: List[StudyRecord]):
        """Store `test_split` on the instance without copying or validating it."""
        self.test_split: List[StudyRecord] = test_split


class TestSplitEvaluation:
    """Placeholder class; it defines no attributes or behavior yet."""

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()