---
title: "Datasets"
description: This module contains utilities for working with candidate datasets and individual candidates.
---

In [2]:
#| default_exp datasets

In [3]:
#| export
import polars as pl
import polars.selectors as cs
import pandas as pd
import pandas
import xarray as xr

from datetime import timedelta

## Datasets

In [15]:
#| export
from pydantic import BaseModel
from kedro.io import DataCatalog
from ids_finder.utils.basic import concat_partitions

Foundational Dataset Class

In [None]:
#| export
from ids_finder.utils.basic import df2ts
from ids_finder.utils.plot import plot_candidate as _plot_candidate

In [None]:
# | export
class IDsDataset(BaseModel):
    """Candidate events of one satellite together with its time-series data.

    Candidates are rows of ``candidates``. Each candidate's data segment is
    the slice of ``data`` whose ``time`` column lies between the candidate's
    ``tstart`` and ``tstop`` values.
    """

    class Config:
        arbitrary_types_allowed = True  # needed for polars frame fields
        extra = "allow"  # subclasses attach extra attributes at runtime

    sat_id: str  # satellite identifier
    tau: timedelta  # not used by this class's methods; see subclasses
    ts: timedelta = timedelta(seconds=1)  # presumably the data sampling interval — confirm

    # One row per candidate; get_candidate_data reads its "tstart"/"tstop" columns.
    candidates: pl.DataFrame | None = None
    data: pl.LazyFrame | None = (
        None  # data is large, so we use `pl.LazyFrame` to save memory
    )
    bcols: list[str] = ["B_x", "B_y", "B_z"]  # columns passed to df2ts

    def get_candidate(self, index=None, predicates=None, **kwargs):
        """Return one candidate row as a ``dict`` keyed by column name.

        Parameters
        ----------
        index : int, optional
            Row position in ``self.candidates``. Takes precedence over
            ``predicates`` when both are given.
        predicates : polars expression, optional
            Filter applied to ``self.candidates``. The first matching row is
            returned; polars raises if nothing matches.
        **kwargs
            Accepted for call-site compatibility and ignored.

        Raises
        ------
        ValueError
            If neither ``index`` nor ``predicates`` is provided.
        """
        if index is not None:
            return self.candidates.row(index, named=True)
        if predicates is not None:
            return self.candidates.filter(predicates).row(0, named=True)
        # Previously this fell through to an UnboundLocalError.
        raise ValueError("Either `index` or `predicates` must be provided.")

    def get_candidate_data(self, candidate=None, index=None, predicates=None, **kwargs):
        """Return the data segment covering one candidate.

        The candidate is either passed directly (a mapping with ``tstart`` and
        ``tstop``) or looked up via :meth:`get_candidate`. Rows of
        ``self.data`` whose ``time`` lies in ``[tstart, tstop]`` (bounds
        inclusive, the polars ``is_between`` default) are selected. They are
        then converted with the project helper ``df2ts`` using ``self.bcols``.
        """
        if candidate is None:
            candidate = self.get_candidate(index, predicates, **kwargs)

        _data = self.data.filter(
            pl.col("time").is_between(candidate["tstart"], candidate["tstop"])
        )
        return df2ts(_data, self.bcols)

    def plot_candidate(self, candidate=None, index=None, predicates=None, **kwargs):
        """Plot one candidate together with its data segment.

        Candidate selection works as in :meth:`get_candidate_data`. ``kwargs``
        are forwarded both to :meth:`get_candidate` (which ignores them) and to
        the project's ``plot_candidate`` helper, whose result is returned.
        """
        if candidate is None:
            candidate = self.get_candidate(index, predicates, **kwargs)
        sat_fgm = self.get_candidate_data(candidate)

        return _plot_candidate(candidate, sat_fgm, **kwargs)

    def plot_candidates(self, **kwargs):
        """Placeholder for plotting several candidates; currently does nothing."""
        pass

Extended dataset class that loads its candidates and data from a `kedro` data catalog

In [None]:
# | export
class cIDsDataset(IDsDataset):
    """`IDsDataset` whose candidates and data are loaded from a kedro `DataCatalog`.

    Catalog entry names are derived from ``sat_id``, ``ts`` and ``tau``. For
    example, ``sat_id="jno"``, ``ts=1s`` and ``tau=60s`` give the events entry
    ``events.jno_ts_1s_tau_60s`` and the data entry
    ``jno.MAG.primary_data_ts_1s``. A caller may override the data entry by
    passing ``data_format``. ``candidates`` / ``data`` are only loaded from
    the catalog when they are not supplied.
    """

    catalog: DataCatalog

    _load_data_format = "{sat}.MAG.primary_data_{ts}"
    _load_events_format = "events.{sat}_{ts}_{tau}"
    or_df: pl.DataFrame | None = None  # occurrence rate
    or_df_normalized: pl.DataFrame | None = None  # normalized occurrence rate

    def __init__(self, **data):
        super().__init__(**data)

        # Use `total_seconds()`, not `.seconds`: `.seconds` ignores the `days`
        # component, so a 1-day tau would have been labelled "tau_0s".
        # For durations under one day both give the same label.
        tau_str = f"tau_{int(self.tau.total_seconds())}s"
        ts_mag_str = f"ts_{int(self.ts.total_seconds())}s"

        self._tau_str = tau_str
        self._ts_mag_str = ts_mag_str

        self.events_format = self._load_events_format.format(
            sat=self.sat_id, ts=ts_mag_str, tau=tau_str
        )

        # Respect an explicitly supplied `data_format` (allowed by extra="allow").
        if data.get("data_format") is None:
            self.data_format = self._load_data_format.format(
                sat=self.sat_id, ts=ts_mag_str
            )

        if self.candidates is None:
            self.load_events()
        if self.data is None:
            self.load_data()

    def load_events(self):
        """Load the candidate events from the catalog into ``self.candidates``.

        NaNs are turned into nulls, every float column is cast to Float64, and
        a literal ``sat`` column is added. The loaded object is ``.collect()``ed,
        so the catalog entry presumably yields a polars LazyFrame — confirm.
        """
        data_format = self.events_format
        self.candidates = (
            self.catalog.load(data_format)
            .fill_nan(None)
            .with_columns(
                cs.float().cast(pl.Float64),
                sat=pl.lit(self.sat_id),
            )
            .collect()
        )

    def load_data(self):
        """Load the time-series data into ``self.data`` via ``concat_partitions``.

        NOTE(review): the catalog entry is presumably a partitioned dataset that
        the project helper merges into one frame — confirm.
        """
        data_format = self.data_format
        self.data = concat_partitions(self.catalog.load(data_format))

## Candidate class

In [None]:
#| export
from pprint import pprint

In [None]:
#| export
class CandidateID:
    """A single candidate, identified by its exact ``time`` value in a DataFrame.

    Attributes
    ----------
    time : pandas.Timestamp
        The ``time`` argument normalised with ``pd.Timestamp``.
    data : dict
        The matching row of ``df``, keyed by column name.
    """

    def __init__(self, time, df: pl.DataFrame) -> None:
        self.time = pd.Timestamp(time)
        # `by_predicate` expects exactly one matching row; polars raises if
        # zero or several rows match.
        self.data = df.row(by_predicate=(pl.col("time") == self.time), named=True)

    def __repr__(self) -> str:
        """Return the candidate's row as a pretty-printed mapping."""
        # Build the string rather than printing it: the previous version
        # printed as a side effect and returned '', which broke repr(), logging
        # and f-string interpolation.
        from pprint import pformat

        return pformat(self.data)

In [None]:
from ids_finder.utils.basic import df2ts, pmap
from fastcore.utils import *