---
title: "Datasets"
description: This notebook contains the code for working with the full set of candidates as well as with individual candidates.
---

In [2]:
#| default_exp datasets

In [3]:
#| export
import polars as pl
import polars.selectors as cs
import pandas as pd
import pandas
import xarray as xr

from datetime import timedelta

#### `Kedro`

In [4]:
#| export
from kedro.pipeline import Pipeline, node
from kedro.pipeline.modular_pipeline import pipeline
from ids_finder.utils.basic import load_catalog

In [5]:
#| output: false
# Build the catalog object via the project helper, pointing it at the parent
# directory (presumably the kedro project root — TODO confirm).
catalog = load_catalog('../')
# Show the names of all entries registered in the catalog.
catalog.list()



[1m[[0m
    [32m'jno.raw_mag_1s'[0m,
    [32m'jno.raw_state'[0m,
    [32m'model.raw_jno_ss_se_1min'[0m,
    [32m'model.preprocessed_jno_ss_se_1min'[0m,
    [32m'JNO_index'[0m,
    [32m'sta.raw_state_merged'[0m,
    [32m'thb.raw_state_sw'[0m,
    [32m'parameters'[0m,
    [32m'params:tau'[0m,
    [32m'params:jno_start_date'[0m,
    [32m'params:jno_end_date'[0m,
    [32m'params:jno_1s_params'[0m,
    [32m'params:jno_1s_params.bcols'[0m,
    [32m'params:jno_1s_params.data_resolution'[0m,
    [32m'params:jno.extract_params'[0m,
    [32m'params:jno.extract_params.bcols'[0m,
    [32m'params:jno.extract_params.data_resolution'[0m,
    [32m'params:sta.extract_params'[0m,
    [32m'params:sta.extract_params.bcols'[0m,
    [32m'params:sta.extract_params.data_resolution'[0m,
    [32m'params:thb'[0m,
    [32m'params:thb.mag'[0m,
    [32m'params:thb.mag.bcols'[0m,
    [32m'params:thb.mag.time_resolution'[0m,
    [32m'params:thb.mag.coords'[0m,
    

## Datasets

In [15]:
#| export
from pydantic import BaseModel
from kedro.io import DataCatalog
from ids_finder.utils.basic import concat_partitions

Foundational Dataset Class

In [None]:
#| export
from ids_finder.utils.basic import df2ts
from ids_finder.utils.plot import plot_candidate

In [None]:

# | export
class IDsDataset(BaseModel):
    """Container pairing a set of candidate events with the time series they come from.

    Fields:
        sat_id: Identifier of the data source the candidates belong to.
        tau: Time interval parameter of the dataset.
        ts: Time resolution of the data (defaults to one second).
        candidates: Eagerly loaded table of candidates; rows are expected to
            carry ``tstart`` and ``tstop`` entries (read by `plot_candidate`).
        data: Lazily evaluated time series with a ``time`` column.
    """

    class Config:
        # polars frames are not pydantic-native types, so they must be allowed explicitly.
        arbitrary_types_allowed = True
        # Permit attributes beyond the declared fields to be set on instances.
        extra = "allow"

    sat_id: str
    tau: timedelta
    ts: timedelta = timedelta(seconds=1)

    candidates: pl.DataFrame | None = None
    data: pl.LazyFrame | None = None  # data is large, so we use `pl.LazyFrame` to save memory

    def plot_candidate(self, index=None, predicates=None):
        """Plot a single candidate together with the data inside its time window.

        Args:
            index: Row position of the candidate in ``self.candidates``.
                Takes precedence over ``predicates`` when both are given.
            predicates: Filter expression applied to ``self.candidates``; the
                first matching row is plotted.

        Raises:
            ValueError: If neither ``index`` nor ``predicates`` is provided.
        """
        if index is not None:
            candidate = self.candidates.row(index, named=True)
        elif predicates is not None:
            candidate = self.candidates.filter(predicates).row(0, named=True)
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on `candidate`; fail early with a clear message.
            raise ValueError("Either `index` or `predicates` must be provided.")

        # Restrict the time series to the candidate's own [tstart, tstop] window.
        _data = self.data.filter(
            pl.col("time").is_between(candidate["tstart"], candidate["tstop"])
        )
        # Two column naming schemes are supported for the field components.
        bcols = ["B_x", "B_y", "B_z"] if "B_x" in _data.columns else ["BX", "BY", "BZ"]
        # NOTE(review): `_data` is still a LazyFrame here — confirm `df2ts` accepts one.
        sat_fgm = df2ts(_data, bcols)
        # Resolves to the module-level `plot_candidate`, not to this method.
        plot_candidate(candidate, sat_fgm)

    def plot_candidates(self, **kwargs):
        """Placeholder for plotting several candidates at once (not implemented)."""
        pass

Extended Dataset Class with support for `kedro`

In [None]:
#| export
class cIDsDataset(IDsDataset):
    """`IDsDataset` that fills in missing candidates and data from a kedro `DataCatalog`.

    On construction, any of ``candidates`` / ``data`` that was not passed in is
    loaded from ``catalog`` using entry names derived from ``sat_id``, ``ts``
    and ``tau``.
    """

    catalog: DataCatalog

    or_df: pl.DataFrame | None = None  # occurrence rate
    or_df_normalized: pl.DataFrame | None = None  # normalized occurrence rate

    def __init__(self, **data):
        """Validate the fields, build the catalog name fragments and load what is missing."""
        super().__init__(**data)

        # `timedelta.seconds` is only the seconds *component* (0..86399) and
        # silently drops whole days; `total_seconds()` gives the full duration.
        # `int(...)` keeps the original integer formatting (e.g. "tau_60s").
        self._tau_str = f"tau_{int(self.tau.total_seconds())}s"
        self._ts_mag_str = f"ts_{int(self.ts.total_seconds())}s"

        if self.candidates is None:
            self.load_candidates()
        if self.data is None:
            self.load_data()

    def load_candidates(self):
        """Load the candidates entry from the catalog into ``self.candidates``.

        NaN values are replaced by nulls, float columns are cast to ``Float64``
        and a constant ``sat`` column holding ``sat_id`` is added before the
        result is collected.
        """
        candidates_format = f"events.{self.sat_id.upper()}_{self._ts_mag_str}_{self._tau_str}"

        self.candidates = self.catalog.load(candidates_format).fill_nan(None).with_columns(
            cs.float().cast(pl.Float64),
            sat=pl.lit(self.sat_id),
        ).collect()

    def load_data(self):
        """Load the data entry from the catalog and store the concatenated result in ``self.data``."""
        # NOTE(review): `sat_id` is used as-is here but upper-cased in
        # `load_candidates` — confirm this matches the catalog naming.
        data_format = f"{self.sat_id}.MAG.primary_data_{self._ts_mag_str}"
        self.data = concat_partitions(self.catalog.load(data_format))

## Candidate class

In [None]:
#| export
from pprint import pprint

In [None]:
#| export
class CandidateID:
    """A single candidate, looked up in a candidates table by its timestamp."""

    def __init__(self, time, df: pl.DataFrame) -> None:
        """Select the row of ``df`` whose ``time`` column equals ``time``.

        Args:
            time: Anything accepted by ``pd.Timestamp``.
            df: Candidates table with a ``time`` column.

        Note:
            ``DataFrame.row(by_predicate=...)`` expects exactly one matching
            row and raises otherwise.
        """
        self.time = pd.Timestamp(time)
        self.data = df.row(
            by_predicate=(pl.col("time") == self.time),
            named=True,
        )

    def __repr__(self) -> str:
        """Return the candidate's fields as a pretty-printed string."""
        # Local import: the file only imports `pprint` from this module.
        from pprint import pformat

        # `__repr__` must return the text rather than print it and return '';
        # otherwise logging, debuggers and containers show an empty string.
        return pformat(self.data)

    def plot(self, sat_fgm, tau):
        """Plot this candidate against ``sat_fgm`` using ``plot_candidate_xr``."""
        # NOTE(review): `plot_candidate_xr` is not imported in the visible part
        # of this notebook — confirm where it is defined and import it.
        plot_candidate_xr(self.data, sat_fgm, tau)
        

In [None]:
# Candidate tables for each source.
sta_candidates_1s: pl.DataFrame = catalog.load('candidates.sta_1s')
jno_candidates_1s: pl.DataFrame = catalog.load('candidates.jno_1s')

# Matching time series for each source.
sta_mag : pl.LazyFrame = catalog.load('sta.inter_mag_rtn_1s')
# Fixed copy-paste error: this previously loaded the 'sta.…' entry into `jno_mag`.
# NOTE(review): key assumed to mirror the STA entry — confirm it exists in the catalog.
jno_mag : pl.LazyFrame = catalog.load('jno.inter_mag_rtn_1s')

In [None]:
from ids_finder.utils.basic import df2ts, pmap
from fastcore.utils import *

In [None]:

def plot_candidate(candidate, mag_data: pl.LazyFrame, b_cols = ['BX', 'BY', 'BZ']):
    """Plot one candidate with the data surrounding its time window.

    The window ``[tstart, tstop]`` of ``candidate`` is widened by its own
    duration on both sides, the matching rows of ``mag_data`` are collected
    (with ``time`` cast to nanosecond precision), converted with ``df2ts`` and
    handed to ``plot_candidate_xr``.

    Args:
        candidate: Mapping with ``tstart`` and ``tstop`` entries.
        mag_data: Lazy frame with a ``time`` column and the ``b_cols`` columns.
        b_cols: Names of the columns passed on to ``df2ts``.

    NOTE(review): defining this name at module level shadows the
    `plot_candidate` imported from `ids_finder.utils.plot` earlier in the
    notebook for the rest of the session.
    """
    start, stop = candidate["tstart"], candidate["tstop"]
    tau = stop - start

    # Pad the event window by one event duration before and after.
    in_window = pl.col("time").is_between(start - tau, stop + tau)
    window = (
        mag_data.filter(in_window)
        .with_columns(pl.col("time").dt.cast_time_unit("ns"))
        .collect()
    )

    plot_candidate_xr(candidate, df2ts(window, b_cols), tau)

In [None]:
n = 3  # number of candidates to draw for plotting
# Same call for the STA candidates, kept disabled:
# list(sta_candidates_1s.sample(n).iter_rows(named=True) | pmap(plot_candidate, mag_data=sta_mag))

# Draw `n` random rows; no seed is passed, so the selection differs between runs.
candidates = jno_candidates_1s.sample(n)
# Feed each sampled row (as a dict) to `plot_candidate` together with `jno_mag`.
# presumably `pmap` maps the function over the piped iterable and `list(...)`
# forces evaluation — TODO confirm against `ids_finder.utils.basic.pmap`.
list(candidates.iter_rows(named=True) | pmap(plot_candidate, mag_data=jno_mag))
