In [None]:
#| default_exp datamodels

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export

from typing import Any, get_args
from types import UnionType, GenericAlias
from dataclasses import dataclass, asdict, fields
from pathlib import Path
from enum import Enum, EnumType
from itertools import chain

from json import dumps

class BaselineEstimationMethod(str, Enum):
    """
    Baseline estimation methods that can be selected for the analysis.

    Members are plain strings as well (`str` mixin), so they compare equal to their raw values.
    """
    ASLS = 'asls'
    FABC = 'fabc'
    PASLS = 'pasls'
    SDD = 'sdd'

    def __str__(self) -> str:
        # render the raw value (e.g. 'asls') rather than 'BaselineEstimationMethod.ASLS'
        return str(self.value)

class ROIMode(str, Enum):
    """
    ROI modes that can be selected for the analysis.

    Members are plain strings as well (`str` mixin), so they compare equal to their raw values.
    """
    GRID = 'grid'
    FILE = 'file'

    def __str__(self) -> str:
        # render the raw value (e.g. 'grid') rather than 'ROIMode.GRID'
        return str(self.value)

@dataclass
class BaseDataClass:
    def to_dict(self) -> dict[str, Any]:
        """Returning contents of the dataclass as a dictionary."""
        return asdict(self)


    def display_all_attributes(self) -> list[str]:
        return [f'{key}: {value}' for key, value in self.to_dict().items()]


    @classmethod
    def from_dict(cls, **params) -> "BaseDataClass":
        """Creating dataclass from dictionary with data validation."""

        def _unpack_unions(_type: UnionType | type) -> list[type]:
            """
            Unpack UnionType to get all types in the union.
            """
            if type(_type) is not UnionType:
                return [_type]
            else:
                return list(chain.from_iterable([_unpack_unions(t) for t in get_args(_type)]))

        def _unpack_parameterized_generic(_type: GenericAlias | type) -> type:
            """
            Unpack parameterized generic types to get the base type.
            """
            if hasattr(_type, '__origin__'):
                # if type is a parameterized generic, return its origin
                # checking if __args__ are correct is left as an exercise for the reader
                return _type.__origin__
            else:
                # otherwise, return the type itself
                return _type

        # getting all class fields
        all_fields = {field.name: field.type for field in fields(cls)}
        for key, value in params.items():
            # checking if input param is in fields
            if key in all_fields:
                if value is not None:
                    if type(all_fields[key]) is UnionType:
                        # if field is UnionType, check if value is in any of the types
                        if not any(isinstance(value, _unpack_parameterized_generic(t)) for t in _unpack_unions(all_fields[key])):
                            raise TypeError(f'Wrong type for the field {key}')
                    # checking if value type is correct
                    # bool is a subclass of int
                    elif type(value) is bool:
                        if not type(value) is all_fields[key]:
                            raise TypeError(f'Wrong type for the field {key}')
                    # instead of checking for EnumType, we check if the value is in the Enum
                    elif type(all_fields[key]) is EnumType:
                        if value not in list(all_fields[key]):
                            raise ValueError(f'Wrong value for the field {key}')
                    else:
                        if not isinstance(value, _unpack_parameterized_generic(all_fields[key])):
                            raise TypeError(f'Wrong type for the field {key}')
        return cls(**params)

    @classmethod
    def validate(cls, params) -> dict[str, Any]:
        instance = cls.from_dict(**params)
        return instance.to_dict()

@dataclass
class Config(BaseDataClass):
    """
    Configuration for analysis.

    All path attributes (`data_source_path`, `recording_filepath`, `roi_filepath`,
    `focus_area_filepath`, `results_filepath`) are converted to `pathlib.Path` objects
    (or lists of them) after initialization.

    Attributes:
        ### General Settings ###

        data_source_path (str, default=None):
            Path to the source data file or directory to be analyzed. Must comply with the source data structure
            that is defined for the corresponding usage modes (see here:
            https://indoc-research.github.io/NeuralActivityCubic/using_the_gui.html#source-data-structure).
            Alternatively, source data locations can be defined using `recording_filepath`, `roi_filepath`,
            and `focus_area_filepath`.

        recording_filepath (str, default=None):
            Path to the recording file to be analyzed. Can be used instead of `data_source_path` to
            define the source data location.

        roi_filepath (str | list[str], default=None):
            Path or list of Paths to files that define the ROIs that are to be analyzed when `roi_mode = file`.
            Can be used instead of `data_source_path` to define source data locations.

        focus_area_filepath (str | list[str], default=None):
            Path or list of Paths to files that define the focus areas to which analysis shall be restricted
            when `focus_area_enabled = True`. Can be used instead of `data_source_path` to define source data
            locations.

        roi_mode (str, default='grid'):
            Mode for defining regions of interest (ROIs) that are analyzed for activity. Options are `grid` for
            automatic grid-based ROIs creation and `file` to load predefined ROIs from supplied files.

        batch_mode (bool, default=False):
            Whether to enable batch mode for processing multiple recordings sequentially. Requires
            `data_source_path` to be used and is not compatible with definition of individual source data
            locations.

        focus_area_enabled (bool, default=False):
            Whether to restrict analysis only to ROIs within specific focus area(s).


        ### Analysis Settings ###

        grid_size (int, default=10):
            Size (in pixels) of the individual squares forming the ROI grid when `roi_mode = grid`. For example,
            a value of 10 generates a grid composed of 10 × 10 pixel ROIs.

        signal_to_noise_ratio (float, default=3.0):
            Minimum signal-to-noise ratio (SNR) used by SciPy's `find_peaks_cwt` function (see here:
            https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.find_peaks_cwt.html) as `min_snr`
            for identifying peaks in the ROI signal intensity traces.

        noise_window_size (int, default=200):
            Window size (in frames) used by SciPy's `find_peaks_cwt` function (see here:
            https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.find_peaks_cwt.html) as `window_size`
            for estimating the local noise level when identifying signal peaks.

        mean_signal_threshold (float, default=10.0): # previously: signal_average_threshold
            Minimum average intensity across the entire analysis interval required for a ROI to be considered for
            peak detection. Helps exclude regions with low baseline signal by filtering out background noise before
            analysis.

        min_peak_count (int, default=2):
            Minimum number of detected peaks required in a ROI for it to be included in the final analysis results.
            ROIs with fewer peaks than this threshold are excluded. Set to `0` if all ROIs shall be included.

        baseline_estimation_method (str, default='asls'):
            Method used to estimate the signal baseline, required for calculating area-under-curve (AUC) of detected
            peaks. Options are based on the pybaselines library (see here:
            https://pybaselines.readthedocs.io/en/latest/) and are:
                - `asls`: Asymmetric Least Squares.
                - `fabc`: Fully Automatic Baseline Correction.
                - `pasls`: Peaked Signal's Asymmetric Least Squares.
                - `sdd`: Standard Deviation Distribution.
            Each method is applied with its default parameters as defined in pybaselines.

        include_variance (bool, default=False):
            Whether to compute signal variance as a proxy for neuronal excitability. Enables sliding window
            variance analysis for each ROI.

        variance_window_size (int, default=15):
            Size of the sliding window (in frames) used to compute signal variance for each ROI when
            `include_variance = True`.

        use_frame_range (bool, default=False):
            Whether to analyze only a specific frame interval from the recording. When enabled, analysis is limited
            to frames between `start_frame_idx` and `end_frame_idx`, inclusive.

        start_frame_idx (int, default=0):
            Index of the first frame to include in the analysis interval (inclusive) if `use_frame_range = True`.

        end_frame_idx (int, default=500):
            Index of the last frame to include in the analysis interval (inclusive) if `use_frame_range = True`.

        customize_octave_filtering (bool, default=False):
            Enables manual configuration of octave-based peak filtering via `min_octave_span`. This option should
            only be used by advanced users familiar with na3's internal logic.

        min_octave_span (float, default=1.0):
            Minimum number of octaves a peak ridge must span to be considered if `customize_octave_filtering = True`.
            Used to compute `min_length` for SciPy's `find_peaks_cwt` function, based on the number of frames.


        ### Results Settings ###

        results_filepath (str, default=None):
            Path to the directory where analysis results will be saved. If not specified, results are saved in the
            current working directory.

        export_to_nwb (bool, default=True):
            Whether to generate an additional NWB (NeurodataWithoutBorders - https://nwb.org/) file alongside
            the standard result outputs. NWB is an open standard for organizing and sharing
            neurophysiology data, supporting long-term accessibility, reproducibility, and
            integration with other neuroscience tools. Enabling this option enhances data
            portability and compliance with community best practices.

        save_overview_png (bool, default=True):
            Whether to save an overview PNG image summarizing the analysis results.

        save_summary_results (bool, default=True):
            Whether to save detailed results, including the following files, depending on your analysis settings:
                - Individual_traces_with_identified_events.pdf
                - all_peak_results.csv
                - Amplitude_and_dF_over_F_results.csv
                - AUC_results.csv
                - Variance_area_results.csv

        save_single_trace_results (bool, default=False):
            Whether to save individual trace results for each ROI separately.
    """
    batch_mode: bool = False
    baseline_estimation_method: BaselineEstimationMethod = BaselineEstimationMethod.ASLS
    customize_octave_filtering: bool = False
    data_source_path: str = None
    end_frame_idx: int = 500
    export_to_nwb: bool = True
    focus_area_enabled: bool = False
    focus_area_filepath: str | list[str] = None
    grid_size: int = 10
    include_variance: bool = False
    mean_signal_threshold: float = 10.0
    min_octave_span: float = 1.0
    min_peak_count: int = 2
    noise_window_size: int = 200
    recording_filepath: str = None
    results_filepath: str = None
    roi_filepath: str | list[str] = None
    roi_mode: ROIMode = ROIMode.GRID
    save_overview_png: bool = True
    save_single_trace_results: bool = False
    save_summary_results: bool = True
    signal_to_noise_ratio: float = 3.0
    start_frame_idx: int = 0
    use_frame_range: bool = False
    variance_window_size: int = 15

    # lazy collection of all filepaths (not annotated, hence a plain class attribute and not a dataclass field)
    _paths = ['data_source_path', 'recording_filepath', 'roi_filepath', 'focus_area_filepath', 'results_filepath']


    def __post_init__(self):
        """
        Check that the source data locations are not defined twice and convert all path attributes
        to `Path` objects.

        Raises:
            ValueError: if `data_source_path` is combined with `roi_filepath` or `recording_filepath`.
        """

        def _transform_filepath_or_list(fp: str | Path | list[str | Path] | None) -> Path | list[Path] | None:
            """
            Transform a string or list of strings to a Path or list of Paths.

            Values that already are `Path` objects are kept, so that a config can be re-created
            from the output of `to_dict()` without losing its paths.
            """
            if isinstance(fp, (str, Path)):
                return Path(fp)
            elif isinstance(fp, list):
                return [Path(p) for p in fp]
            else:
                # None (and any unsupported type) results in an unset path
                return None

        if self.data_source_path is not None:
            if self.roi_filepath is not None or self.recording_filepath is not None:
                raise ValueError('Cannot specify both `data_source_path` and `roi_filepath` or `recording_filepath`')
        for attr in self._paths:
            setattr(self, attr, _transform_filepath_or_list(getattr(self, attr)))

    def to_json(self) -> str:
        """
        Returning contents of the dataclass as a JSON string.

        `Path` objects are not JSON serializable, so every path attribute is written as a string,
        as a list of strings (if several paths are defined), or as `null` (if unset).
        """

        def _path_to_jsonable(value: Any) -> str | list[str] | None:
            """Convert a single path, a collection of paths, or None to a JSON serializable value."""
            if value is None:
                return None
            if isinstance(value, (list, tuple)):
                return [str(item) for item in value]
            return str(value)

        config_dict = self.to_dict()
        # all path attributes are treated the same way, single paths as well as lists of paths
        for attr in self._paths:
            config_dict[attr] = _path_to_jsonable(config_dict[attr])
        return dumps(config_dict)


@dataclass
class Peak(BaseDataClass):
    """
    Data record describing a single peak.

    Only `frame_idx` and `intensity` are required; every other attribute is optional
    and defaults to `None`.
    """
    # required attributes
    frame_idx: int
    intensity: float
    # optional attributes
    # NOTE(review): presumably filled in by later analysis steps — confirm against the callers
    amplitude: float | None = None
    delta_f_over_f: float | None = None
    has_neighboring_intersections: bool | None = None
    frame_idxs_of_neighboring_intersections: tuple | None = None
    area_under_curve: float | None = None
    peak_type: str | None = None

## Tests:

Setup for testing:

In [None]:
from fastcore.test import test_fail

# relative path to the recording that is used throughout the tests
filepath = '../test_data/00/spiking_neuron.avi'

# default configuration as a dictionary, with the test recording as data source
correct_general_config = {**Config().to_dict(), 'data_source_path': filepath}

# fully spelled out configuration (key order matches the field order of `Config`),
# used to check the JSON export
example_general_config = {
    'batch_mode': False,
    'baseline_estimation_method': BaselineEstimationMethod.ASLS,
    'customize_octave_filtering': False,
    'data_source_path': filepath,
    'end_frame_idx': 500,
    'export_to_nwb': False,
    'focus_area_enabled': False,
    'focus_area_filepath': None,
    'grid_size': 10,
    'include_variance': False,
    'mean_signal_threshold': 10.0,
    'min_octave_span': 1.0,
    'min_peak_count': 2,
    'noise_window_size': 200,
    'recording_filepath': None,
    'results_filepath': None,
    'roi_filepath': None,
    'roi_mode': ROIMode.GRID,
    'save_overview_png': True,
    'save_single_trace_results': False,
    'save_summary_results': True,
    'signal_to_noise_ratio': 3.0,
    'start_frame_idx': 0,
    'use_frame_range': False,
    'variance_window_size': 15,
}

# peak parameters with every attribute set
correct_peak_config = dict(
    frame_idx=10,
    intensity=10.0,
    amplitude=10.0,
    delta_f_over_f=10.0,
    has_neighboring_intersections=True,
    frame_idxs_of_neighboring_intersections=(1, 2),
    area_under_curve=10.0,
    peak_type='normal',
)
# peak parameters with the required attributes only
minimal_peak_config = dict(
    frame_idx=10,
    intensity=10.0,
)

def test_correct_config():
    """Build a `Config` from the complete set of valid parameters."""
    valid_params = dict(correct_general_config)
    return Config.from_dict(**valid_params)

def test_minimal_config():
    """Build a `Config` from the data source path only; all other fields keep their defaults."""
    minimal_params = {'data_source_path': filepath}
    return Config.from_dict(**minimal_params)

def test_config_with_specific_filepaths():
    """Build a `Config` that uses individual file paths instead of `data_source_path`."""
    overrides = {
        'data_source_path': None,
        'recording_filepath': filepath,
        'roi_filepath': filepath,
        'focus_area_filepath': filepath,
    }
    return Config.from_dict(**{**correct_general_config, **overrides})

def test_config_with_specific_filepaths_list():
    """Build a `Config` that uses individual file paths, with a list of ROI files."""
    overrides = {
        'data_source_path': None,
        'recording_filepath': filepath,
        'roi_filepath': [filepath, filepath],
        'focus_area_filepath': filepath,
    }
    return Config.from_dict(**{**correct_general_config, **overrides})

def test_config_to_json():
    """Build a `Config` from the example parameters and export it as a JSON string."""
    return Config.from_dict(**example_general_config).to_json()

# enumerators are special cases, so we need to check them separately
def test_config_enum_from_str():
    """Build a `Config` with the baseline estimation method given as its raw string value."""
    # 'asls' is a valid value for baseline_estimation_method
    params_with_raw_enum = {**correct_general_config, 'baseline_estimation_method': 'asls'}
    return Config.from_dict(**params_with_raw_enum)

def test_incorrect_config_enum_from_str():
    """Try to build a `Config` with a string that is no baseline estimation method."""
    # 'not_valid' is an invalid value for baseline_estimation_method
    broken_params = {**correct_general_config, 'baseline_estimation_method': 'not_valid'}
    return Config.from_dict(**broken_params)

# general invalid types checks
def test_incorrect_config_batch_mode():
    """Try to build a `Config` with a string instead of a bool for `batch_mode`."""
    # a string is an invalid type for batch_mode
    broken_params = {**correct_general_config, 'batch_mode': 'invalid_value'}
    return Config.from_dict(**broken_params)

def test_incorrect_config_end_frame_idx():
    """Try to build a `Config` with a string instead of an int for `end_frame_idx`."""
    # a string is an invalid type for end_frame_idx
    broken_params = {**correct_general_config, 'end_frame_idx': 'invalid_value'}
    return Config.from_dict(**broken_params)

# enumerators are special cases, so we need to check them separately
def test_incorrect_config_baseline_estimation_method():
    """Try to build a `Config` with an int instead of a baseline estimation method."""
    # an int is an invalid type for baseline_estimation_method
    broken_params = {**correct_general_config, 'baseline_estimation_method': 1337}
    return Config.from_dict(**broken_params)

# boolean fields are also special cases, so we need to check them separately
def test_incorrect_config_customize_octave_filtering():
    """Try to build a `Config` with a string instead of a bool for `customize_octave_filtering`."""
    # a string is an invalid type for customize_octave_filtering
    broken_params = {**correct_general_config, 'customize_octave_filtering': 'invalid_value'}
    return Config.from_dict(**broken_params)

# roi_filepath is a special case, so we need to check it separately
def test_incorrect_config_roi_filepath():
    """
    Try to build a `Config` with a tuple of ints as `roi_filepath`.

    `data_source_path` is unset on purpose: combined with any `roi_filepath` it is rejected
    anyway, which would let this test pass even if the type validation did not work.
    """
    incorrect_config = correct_general_config.copy()
    incorrect_config['data_source_path'] = None
    incorrect_config['roi_filepath'] = (1, 2)  # invalid type for roi_filepath
    return Config.from_dict(**incorrect_config)

# other filepaths should not be set up with data_source_path
def test_incorrect_config_singular_path():
    """Try to build a `Config` that defines individual file paths on top of `data_source_path`."""
    conflicting_paths = {
        name: filepath
        for name in ('recording_filepath', 'roi_filepath', 'focus_area_filepath')
    }
    return Config.from_dict(**{**correct_general_config, **conflicting_paths})

def test_correct_peak_config():
    """Build a `Peak` with every attribute set."""
    full_params = dict(correct_peak_config)
    return Peak.from_dict(**full_params)

def test_minimal_peak_config():
    """Build a `Peak` from the required attributes only."""
    required_params = dict(minimal_peak_config)
    return Peak.from_dict(**required_params)

def test_incomplete_peak_config():
    """Try to build a `Peak` without the required `frame_idx`."""
    params_without_frame_idx = {
        name: value for name, value in correct_peak_config.items() if name != 'frame_idx'
    }
    return Peak.from_dict(**params_without_frame_idx)

def test_wrong_peak_config():
    """Try to build a `Peak` with a string instead of an int for `frame_idx`."""
    broken_params = {**correct_peak_config, 'frame_idx': 'invalid_value'}
    return Peak.from_dict(**broken_params)

Run tests:

In [None]:
# correct inputs tests
assert isinstance(test_correct_config(), Config)
assert isinstance(test_minimal_config(), Config)
assert isinstance(test_config_enum_from_str(), Config)
assert isinstance(test_config_with_specific_filepaths(), Config)
assert isinstance(test_config_with_specific_filepaths_list(), Config)
assert isinstance(test_correct_peak_config(), Peak)
assert isinstance(test_minimal_peak_config(), Peak)
assert test_config_to_json() == dumps(example_general_config)

# `test_fail` passes on any exception, so `contains` is used to make sure
# that each test fails for the intended reason

# incomplete inputs tests
test_fail(test_incomplete_peak_config, contains='frame_idx')

# wrong inputs tests
test_fail(test_incorrect_config_batch_mode, contains='Wrong type for the field batch_mode')
test_fail(test_incorrect_config_baseline_estimation_method, contains='Wrong value for the field baseline_estimation_method')
test_fail(test_incorrect_config_customize_octave_filtering, contains='Wrong type for the field customize_octave_filtering')
test_fail(test_incorrect_config_end_frame_idx, contains='Wrong type for the field end_frame_idx')
test_fail(test_incorrect_config_roi_filepath, contains='Wrong type for the field roi_filepath')
test_fail(test_incorrect_config_singular_path, contains='Cannot specify both')
test_fail(test_incorrect_config_enum_from_str, contains='Wrong value for the field baseline_estimation_method')
test_fail(test_wrong_peak_config, contains='Wrong type for the field frame_idx')
