---
title: "Datasets"
description: This module contains functions for working with candidate datasets and individual candidates.
---

In [None]:
# | default_exp datasets
# | export
import polars as pl
from datetime import timedelta
from pathlib import Path
from pydantic import BaseModel, Field, ConfigDict
from beforerr.project import savename, datadir, produce_or_load_file
from space_analysis.core import MagVariable
from space_analysis.meta import PlasmaDataset, TempDataset
from discontinuitypy.utils.naming import standardize_plasma_data
from discontinuitypy.detection.variance import detect_variance

from typing_extensions import deprecated
from typing import Callable, Literal
from loguru import logger

## Datasets

Fundamental classes for handling discontinuity events and their associated data.

In [None]:
# | export
from discontinuitypy.utils.basic import df2ts
from discontinuitypy.integration import update_events
from discontinuitypy.core.pipeline import ids_finder

In [None]:
# | exporti
def select_row(df: pl.DataFrame, index: int):
    """Fetch the row of *df* whose ``index`` column equals *index*, as a dict.

    When *df* has no ``index`` column, one is generated first with
    ``with_row_index()``. The match itself is delegated to polars'
    ``DataFrame.row(by_predicate=..., named=True)``; see the polars docs
    for its behaviour when the predicate does not match exactly one row.
    """
    indexed = df if "index" in df.columns else df.with_row_index()
    return indexed.row(by_predicate=pl.col("index") == index, named=True)

In [None]:
# | export
class IdsEvents(BaseModel):
    """Core class to handle discontinuity events in a dataset.

    Bundles the input data, the detection configuration and, once computed,
    the detected events. Detection runs through `ids_finder`, and the result
    is cached on disk via `produce_or_load_file` under a file name derived
    from the detection configuration (see `file`).
    """

    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

    # Optional label; appended to the cache-file prefix (see `file_prefix`).
    name: str = None
    # Input data: passed to `ids_finder` as `detection_df` and filtered on a
    # "time" column in `get_event_data`.
    data: pl.LazyFrame = None
    # Supplies the fallback time resolution (`ts`) and the field columns (`B_cols`).
    mag_meta: MagVariable = MagVariable()

    ts: timedelta = None
    """time resolution of the dataset"""
    # Detected events; populated by `find_events`, read by `get_event`.
    events: pl.DataFrame = None
    detect_func: Callable = detect_variance
    # Extra keyword arguments for `detect_func`; override the default `ts`.
    detect_kwargs: dict = Field(default_factory=dict)
    method: Literal["fit", "derivative"] = "fit"

    # Used as the suffix of the cache file name.
    file_fmt: str = "arrow"
    # Directory of the cache file. Note: `datadir()` is evaluated once, when
    # the class is defined, not per instance.
    file_path: Path = datadir()

    # The original message ("Use `find_events` instead") pointed at this very
    # method; the in-class alternative is `produce_or_load`.
    @deprecated("Use `produce_or_load` instead")
    def find_events(self, **kwargs):
        """Detect (or load cached) events and store them on ``self.events``.

        Deprecated: call `produce_or_load` instead, which returns a
        ``(events, file)`` pair, and use the returned events directly.

        Parameters
        ----------
        **kwargs
            Forwarded to `produce_or_load`, overriding detection settings.

        Returns
        -------
        Self, so calls can be chained.
        """
        data, _ = self.produce_or_load(**kwargs)
        self.events = data
        return self

    @property
    def file_prefix(self):
        """Cache-file prefix: ``events_<name>``, or just ``events`` when unnamed."""
        return f"events_{self.name or ''}".removesuffix("_")

    @property
    def config_detection(self):
        """Keyword arguments passed to `ids_finder` for event detection.

        ``ts`` defaults to ``self.ts`` (falling back to ``mag_meta.ts``), and
        any entries in ``self.detect_kwargs`` take precedence over it.
        """
        detect_kwargs = dict(ts=self.ts or self.mag_meta.ts) | self.detect_kwargs
        return dict(
            detection_df=self.data,
            detect_func=self.detect_func,
            detect_kwargs=detect_kwargs,
            bcols=self.mag_meta.B_cols,
            method=self.method,
        )

    def produce_or_load(self, **kwargs):
        """Run `ids_finder` with the detection config, caching the result at `file`.

        Parameters
        ----------
        **kwargs
            Merged over `config_detection` (keyword arguments win).

        Returns
        -------
        Whatever `produce_or_load_file` returns. Callers in this module
        unpack it as ``(events, file)``.
        """
        # NOTE(review): `self.file` is derived from `config_detection` only, so
        # calls with different **kwargs share one cache file and may load a
        # result computed with other settings — confirm this is acceptable.
        config = self.config_detection | kwargs
        return produce_or_load_file(
            f=ids_finder,
            config=config,
            file=self.file,
        )

    @property
    def file(self):
        """Path of the cache file: ``file_path / savename(...)``.

        The name is built by `savename` from `config_detection`, with
        ``detect_kwargs`` expanded. Presumably only values of the allowed
        types (str, timedelta, dict, callables) contribute to it.
        """
        fname = savename(
            c=self.config_detection,
            prefix=self.file_prefix,
            suffix=self.file_fmt,
            allowedtypes=(str, timedelta, dict, Callable),
            expand=["detect_kwargs"],
        )
        return self.file_path / fname

    def get_event(self, index: int):
        """Return the event in ``self.events`` whose ``index`` equals *index*, as a dict."""
        return select_row(self.events, index)

    def get_event_data(
        self,
        event,
        start_col="t.d_start",
        end_col="t.d_end",
        offset=timedelta(seconds=1),
        **kwargs,
    ):
        """Extract the field data around *event*.

        Parameters
        ----------
        event
            Mapping with start/end times under *start_col* / *end_col*.
        start_col, end_col
            Keys of the event's start and end times.
        offset
            Padding subtracted from the start and added to the end.
        **kwargs
            Currently unused.

        Returns
        -------
        The rows of ``self.data`` whose "time" lies in the padded interval,
        restricted to ``mag_meta.B_cols`` and converted by `df2ts`.
        """
        start = event[start_col] - offset
        end = event[end_col] + offset

        _data = self.data.filter(pl.col("time").is_between(start, end))
        return df2ts(_data, self.mag_meta.B_cols)

In [None]:
# | export
def log_event_change(event, logger=logger):
    """Emit one DEBUG message listing the plasma-change fields of *event*.

    Each field is read with ``event.get``, so absent keys show as ``None``.
    The message starts with ``CHANGE INFO`` and puts each field on its own
    8-space-indented line.
    """
    fields = (
        "n.change",
        "v.ion.change",
        "T.change",
        "v.Alfven.change",
        "v.ion.change.l",
        "v.Alfven.change.l",
    )
    pad = "\n" + " " * 8
    details = "".join(f"{pad}{key}: {event.get(key)}" for key in fields)
    logger.debug(f"CHANGE INFO{details}{pad}")

In [None]:
# | export
class IDsDataset(IdsEvents):
    """Extend the IdsEvents class to handle plasma and temperature data.

    Processing happens in two cached steps:

    1. Event detection (`ids_finder`), cached under the base-class file name.
    2. Event enrichment with plasma and temperature data (`update_events`),
       cached under the same name prefixed with ``updated_``.
    """

    # Magnetic-field data; populated through the ``mag_data`` alias.
    data: pl.LazyFrame = Field(default=None, alias="mag_data")

    plasma_data: pl.LazyFrame = None
    plasma_meta: PlasmaDataset = PlasmaDataset()

    ion_temp_data: pl.LazyFrame = None
    ion_temp_meta: TempDataset = TempDataset()
    e_temp_data: pl.LazyFrame = None
    e_temp_meta: TempDataset = TempDataset()

    def plot(self, type="overview", event=None, index=None, **kwargs):
        """Plot a single event.

        Parameters
        ----------
        type
            Plot kind. Only ``"overview"`` is supported; it delegates to
            ``self.overview_plot``, which is not defined in this class.
        event
            Event mapping. When ``None``, it is looked up by *index*.
        index
            Event index used when *event* is not given.
        **kwargs
            Forwarded to the plotting method.

        Raises
        ------
        ValueError
            If *type* is not a supported plot kind.
        """
        # `is None` rather than `event or ...`: rows/Series may have ambiguous
        # truthiness, and an explicitly passed event should always be used.
        if event is None:
            event = self.get_event(index)
        if type == "overview":
            return self.overview_plot(event, **kwargs)
        raise ValueError(f"Unsupported plot type: {type!r}")

    @property
    def config_updates(self):
        """Keyword arguments passed to `update_events` (besides ``events``)."""
        # NOTE(review): ion_temp_meta / e_temp_meta are not forwarded —
        # presumably update_events does not take them; confirm.
        return dict(
            plasma_data=self.plasma_data,
            plasma_meta=self.plasma_meta,
            ion_temp_data=self.ion_temp_data,
            e_temp_data=self.e_temp_data,
        )

    @property
    def file(self):
        """Cache file of the updated events: the base file name prefixed with ``updated_``."""
        old_file = super().file
        return old_file.with_name(f"updated_{old_file.stem}{old_file.suffix}")

    def produce_or_load(self, **kwargs):
        """Detect (or load cached) events, then enrich them via `update_events`.

        Parameters
        ----------
        **kwargs
            Merged into both the detection config and the update config.

        Returns
        -------
        The result of `produce_or_load_file` for the update step.
        """
        # Detection must be cached under the *base-class* file. Calling
        # `super().produce_or_load()` would resolve `self.file` to this class's
        # ``updated_`` path. The raw events would then land there, and the
        # update step below would find that file and load the raw events
        # instead of running `update_events`.
        detection_config = self.config_detection | kwargs
        events, _ = produce_or_load_file(
            f=ids_finder,
            config=detection_config,
            file=super().file,
        )
        config = self.config_updates | dict(events=events) | kwargs
        return produce_or_load_file(f=update_events, config=config, file=self.file)

    def standardize(self):
        """Standardize ``plasma_data`` in place via `standardize_plasma_data`; return ``self``."""
        self.plasma_data = standardize_plasma_data(self.plasma_data, self.plasma_meta)
        return self

In [None]:
# def overview_plot(
#     self, event: dict, start=None, stop=None, offset=timedelta(seconds=1), **kwargs
# ):
#     # BUG: to be fixed
#     start = start or event["tstart"]
#     stop = stop or event["tstop"]

#     start -= offset
#     stop += offset

#     _plasma_data = self.plasma_data.filter(
#         pl.col("time").is_between(start, stop)
#     ).collect()

#     _mag_data = (
#         self.data.filter(pl.col("time").is_between(start, stop))
#         .collect()
#         .melt(
#             id_vars=["time"],
#             value_vars=self.bcols,
#             variable_name="B comp",
#             value_name="B",
#         )
#     )

#     v_df = _plasma_data.melt(
#         id_vars=["time"],
#         value_vars=self.plasma_meta.velocity_cols,
#         variable_name="veloity comp",
#         value_name="v",
#     )

#     panel_mag = _mag_data.hvplot(
#         x="time", y="B", by="B comp", ylabel="Magnetic Field", **kwargs
#     )
#     panel_n = _plasma_data.hvplot(
#         x="time", y=self.plasma_meta.density_col, **kwargs
#     ) * _plasma_data.hvplot.scatter(
#         x="time", y=self.plasma_meta.density_col, **kwargs
#     )

#     panel_v = v_df.hvplot(
#         x="time", y="v", by="veloity comp", ylabel="Plasma Velocity", **kwargs
#     )
#     panel_temp = _plasma_data.hvplot(
#         x="time", y=self.plasma_meta.temperature_col, **kwargs
#     )

#     mag_vlines = hv.VLine(event["t.d_start"]) * hv.VLine(event["t.d_end"])
#     plasma_vlines = hv.VLine(event.get("time_before")) * hv.VLine(
#         event.get("time_after")
#     )

#     logger.info(f"Overview plot: {event['tstart']} - {event['tstop']}")
#     log_event_change(event)

#     return (
#         panel_mag * mag_vlines
#         + panel_n * plasma_vlines
#         + panel_v * plasma_vlines
#         + panel_temp * plasma_vlines
#     ).cols(1)

In [None]:
# | hide
from nbdev import nbdev_export

# Export the cells marked `# | export` / `# | exporti` in this notebook to the
# library module named by `# | default_exp` (see nbdev's `nbdev_export`).
nbdev_export()