In [None]:
#| default_exp datamodels

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
#| hide

import types
from dataclasses import dataclass, asdict, fields
from pathlib import Path
from typing import Any, Union, get_args, get_origin

from fastcore.test import test_fail

@dataclass
class BaseDataClass:
    def to_dict(self) -> dict[str, Any]:
        """Returning contents of the dataclass as a dictionary."""
        return asdict(self)

    @classmethod
    def from_dict(cls, **params) -> "BaseDataClass":
        """Creating dataclass from dictionary with data validation."""
        # getting all class fields
        all_fields = {field.name: field.type for field in fields(cls)}
        cleaned_params = {}
        for param in params.items():
            key, value = param
            # checking if input param is in fields
            if key in all_fields:
                # checking if value type is correct
                # bool is a subclass of int
                # print(value, type(value), all_fields[key])
                if type(value) is bool:
                    if type(value) is all_fields[key]:
                        cleaned_params[key] = value
                else:
                    if isinstance(value, all_fields[key]):
                        cleaned_params[key] = value
        return cls(**cleaned_params)

    @classmethod
    def validate(cls, params) -> dict[str, Any]:
        instance = cls.from_dict(**params)
        return instance.to_dict()

@dataclass
class Config(BaseDataClass):
    """Complete set of settings, each with a default value.

    The field names overlap with those of ``AnalysisConfig``, ``ResultsConfig``
    and ``AnalysisJobConfig``, so ``Config().to_dict()`` can serve as input to
    their ``from_dict``.
    """
    window_size: int = 10
    signal_to_noise_ratio: float = 3.0
    noise_window_size: int = 200
    signal_average_threshold: float = 10.0
    minimum_activity_counts: int = 2
    baseline_estimation_method: str = 'asls'
    include_variance: bool = False
    variance: int = 15
    limit_analysis_to_frame_interval: bool = False
    start_frame_idx: int = 0
    end_frame_idx: int = 500
    configure_octaves: bool = False
    octaves_ridge_needs_to_spann: float = 1.0
    save_overview_png: bool = True
    save_detailed_results: bool = True
    batch_mode: bool = False
    focus_area_enabled: bool = False
    roi_mode: str = 'grid'
    save_single_trace_results: bool = False
    # The following fields default to None, so None is part of their annotation.
    data_source_path: Path | None = None
    recording_filepath: Path | None = None
    focus_area_filepath: Path | None = None
    filepath_analyzed_rois: list[str] | None = None

@dataclass
class AnalysisConfig(BaseDataClass):
    """Subset of settings grouped under the analysis configuration.

    No field has a default, so the constructor (and therefore ``from_dict``)
    raises ``TypeError`` when any of them is not supplied.
    """
    window_size: int
    limit_analysis_to_frame_interval: bool
    start_frame_idx: int
    end_frame_idx: int
    signal_average_threshold: float
    signal_to_noise_ratio: float
    octaves_ridge_needs_to_spann: float
    noise_window_size: int
    baseline_estimation_method: str
    include_variance: bool
    variance: int

@dataclass
class ResultsConfig(BaseDataClass):
    """Subset of settings grouped under the results configuration.

    No field has a default, so the constructor (and therefore ``from_dict``)
    raises ``TypeError`` when any of them is not supplied.
    """
    save_overview_png: bool
    save_detailed_results: bool
    save_single_trace_results: bool
    minimum_activity_counts: int
    signal_average_threshold: float
    signal_to_noise_ratio: float

@dataclass
class AnalysisJobConfig(BaseDataClass):
    """Subset of settings grouped under the analysis job configuration.

    No field has a default, so the constructor (and therefore ``from_dict``)
    raises ``TypeError`` when any of them is not supplied.
    """
    roi_mode: str
    batch_mode: bool
    focus_area_enabled: bool
    data_source_path: Path


@dataclass
class Peak(BaseDataClass):
    """Record describing a single peak.

    ``frame_idx`` and ``intensity`` are required; every other field is
    optional and defaults to ``None``.
    """
    frame_idx: int
    intensity: float
    amplitude: float | None = None
    delta_f_over_f: float | None = None
    has_neighboring_intersections: bool | None = None
    frame_idxs_of_neighboring_intersections: tuple | None = None
    area_under_curve: float | None = None
    peak_type: str | None = None

In [None]:
#| export
#| hide
# from neuralactivitycubic.view import WidgetsInterface

# Valid inputs shared by the tests below.
correct_general_config = Config().to_dict()  # needs to be added here until implemented in GUI

recording_filepath = Path('../test_data/00/spiking_neuron.avi')
correct_analysis_job_config = dict(
    roi_mode='grid',
    batch_mode=True,
    focus_area_enabled=True,
    data_source_path=recording_filepath,
)
correct_peak_config = dict(
    frame_idx=10,
    intensity=10.0,
    amplitude=10.0,
    delta_f_over_f=10.0,
    has_neighboring_intersections=True,
    frame_idxs_of_neighboring_intersections=(1, 2),
    area_under_curve=10.0,
    peak_type='normal',
)
# Only the two fields of Peak that have no default.
minimal_peak_config = dict(frame_idx=10, intensity=10.0)

def test_correct_analysis_config():
    """Build an AnalysisConfig from the valid general settings."""
    settings = correct_general_config
    return AnalysisConfig.from_dict(**settings)

def test_correct_analysis_job_config():
    """Build an AnalysisJobConfig from the valid job settings."""
    settings = correct_analysis_job_config
    return AnalysisJobConfig.from_dict(**settings)

def test_correct_results_config():
    """Build a ResultsConfig from the valid general settings."""
    settings = correct_general_config
    return ResultsConfig.from_dict(**settings)

def test_correct_peak_config():
    """Build a Peak from a dictionary that sets every field."""
    settings = correct_peak_config
    return Peak.from_dict(**settings)

def test_minimal_peak_config():
    """Build a Peak from a dictionary that sets only the required fields."""
    settings = minimal_peak_config
    return Peak.from_dict(**settings)

# Valid general settings with the 'window_size' entry removed.
incomplete_analysis_config = dict(correct_general_config)
del incomplete_analysis_config['window_size']

def test_incomplete_analysis_config():
    """Try to build an AnalysisConfig without its required 'window_size'."""
    settings = incomplete_analysis_config
    return AnalysisConfig.from_dict(**settings)

# Valid general settings, except that 'window_size' holds a string.
wrong_analysis_config = {**correct_general_config, 'window_size': 'haha'}

def test_wrong_analysis_config():
    """Try to build an AnalysisConfig whose 'window_size' is a str instead of an int."""
    settings = wrong_analysis_config
    return AnalysisConfig.from_dict(**settings)

# Valid general settings with the 'signal_to_noise_ratio' entry removed.
incomplete_results_config = dict(correct_general_config)
del incomplete_results_config['signal_to_noise_ratio']

def test_incomplete_results_config():
    """Try to build a ResultsConfig without its required 'signal_to_noise_ratio'."""
    settings = incomplete_results_config
    return ResultsConfig.from_dict(**settings)

# Valid general settings, except that 'signal_to_noise_ratio' holds a bool.
wrong_results_config = {**correct_general_config, 'signal_to_noise_ratio': True}

def test_wrong_results_config():
    """Try to build a ResultsConfig whose 'signal_to_noise_ratio' is a bool instead of a float."""
    settings = wrong_results_config
    return ResultsConfig.from_dict(**settings)

# Complete peak settings with the 'frame_idx' entry removed.
incomplete_peak_config = dict(correct_peak_config)
del incomplete_peak_config['frame_idx']

def test_incomplete_peak_config():
    """Try to build a Peak without its required 'frame_idx'."""
    settings = incomplete_peak_config
    return Peak.from_dict(**settings)

# Complete peak settings, except that 'frame_idx' holds a bool.
wrong_peak_config = {**correct_peak_config, 'frame_idx': False}

def test_wrong_peak_config():
    """Try to build a Peak whose 'frame_idx' is a bool instead of an int."""
    settings = wrong_peak_config
    return Peak.from_dict(**settings)

# Add a check to ensure Config contains all required fields for other dataclasses
def _check_config_fields():
    """Print, per derived config class, whether Config declares all of its fields."""
    available = set(Config.__dataclass_fields__)
    # Peak is not a config, so skip
    for config_cls in (AnalysisConfig, ResultsConfig, AnalysisJobConfig):
        name = config_cls.__name__
        missing = set(config_cls.__dataclass_fields__) - available
        if not missing:
            print(f"Config contains all fields required by {name}.")
            continue
        print(f"Config is missing fields required by {name}: {missing}")

_check_config_fields()

In [None]:
#| hide

# correct inputs tests
for build, expected_cls in (
    (test_correct_analysis_config, AnalysisConfig),
    (test_correct_analysis_job_config, AnalysisJobConfig),
    (test_correct_results_config, ResultsConfig),
    (test_correct_peak_config, Peak),
    (test_minimal_peak_config, Peak),
):
    assert isinstance(build(), expected_cls)

# incomplete inputs tests
for build in (
    test_incomplete_analysis_config,
    test_incomplete_results_config,
    test_incomplete_peak_config,
):
    test_fail(build)

# wrong inputs tests
for build in (
    test_wrong_analysis_config,
    test_wrong_results_config,
    test_wrong_peak_config,
):
    test_fail(build)
