---
title: "Datasets"
description: This notebook contains the code for working with the full set of candidates and with individual candidates.
---

In [None]:
#| default_exp candidates

In [None]:
#| export
import polars as pl
import polars.selectors as cs
import pandas as pd
import pandas
import xarray as xr

from datetime import timedelta

#### `Kedro`

In [None]:
#| export
from kedro.pipeline import Pipeline, node
from kedro.pipeline.modular_pipeline import pipeline
from ids_finder.utils.basic import load_catalog

In [None]:
#| output: false
# Load the data catalog via the project helper; '../' is presumably the Kedro
# project root relative to this notebook — TODO confirm.
catalog = load_catalog('../')
# List the catalog entries (cell output is suppressed by the directive above).
catalog.list()


## Combining magnetic field data and state data

In [None]:
#| export
import polars as pl

from ids_finder.utils.basic import df2ts, pl_norm
import xarray as xr
from xarray_einstats import linalg

In [None]:
# | export
def combine(candidates: pl.LazyFrame, states_data: pl.LazyFrame):
    """As-of join the candidate events with the plasma state data on `time`.

    Both inputs are sorted by `time` before `join_asof` is applied, with
    `candidates` as the left-hand side.

    Raises:
        ValueError: if `states_data` lacks any of the velocity columns, or if
            `candidates` lacks any of the major-eigenvector columns.
    """
    vec_cols = ["v_x", "v_y", "v_z"]  # plasma velocity vector in any coordinate system
    b_vecL_cols = ["Vl_x", "Vl_y", "Vl_z"]  # major eigenvector in any coordinate system

    # Same check order as before: state data first, then candidates.
    requirements = ((states_data, vec_cols), (candidates, b_vecL_cols))
    for frame, needed in requirements:
        absent = set(needed) - set(frame.columns)
        if absent:
            raise ValueError(f"Missing columns {needed}")

    left = candidates.sort("time")
    right = states_data.sort("time")
    return left.join_asof(right, on="time")

### Calculating additional features for the combined dataset

In [None]:
#| export
import astropy.units as u
from astropy.constants import mu0
from plasmapy.formulary.lengths import inertial_length
from plasmapy.formulary.speeds import Alfven_speed

In [None]:
#| export
def vector_project(v1,v2, dim="v_dim"):
    """Scalar projection of `v1` onto `v2` along dimension `dim`.

    Computed as dot(v1, v2) / |v2|, with both the dot product and the norm
    reduced over `dim`.
    """
    inner = xr.dot(v1, v2, dims=dim)
    v2_length = linalg.norm(v2, dims=dim)
    return inner / v2_length

def vector_project_pl(df: pl.DataFrame, v1_cols, v2_cols, name=None):
    """Add a column holding the projection of one vector onto another.

    The columns `v1_cols` and `v2_cols` are converted with `df2ts`, labelled
    along `v_dim` as r/t/n, and passed to `vector_project`.

    Args:
        df: frame containing both sets of vector component columns.
        v1_cols: column names of the vector being projected.
        v2_cols: column names of the vector projected onto.
        name: name of the new column; falls back to "v_proj" when falsy.

    Returns:
        `df` with the projection appended as a new column.
    """
    components = ["r","t","n"]
    projected_vec, onto_vec = (
        df2ts(df, cols).assign_coords(v_dim=components)
        for cols in (v1_cols, v2_cols)
    )
    projection = vector_project(projected_vec, onto_vec, dim="v_dim")

    column_name = name or "v_proj"
    return df.with_columns(pl.Series(projection.data).alias(column_name))

In [None]:
#| export
def compute_inertial_length(ldf: pl.LazyFrame):
    """Append an `ion_inertial_length` column (in km) to the frame.

    The `plasma_density` column is interpreted as cm^-3 and passed to
    plasmapy's `inertial_length` for 'H+'. The lazy frame is collected to do
    the computation and the result is returned lazy again.
    """
    frame = ldf.collect()

    number_density = frame['plasma_density'].to_numpy() * u.cm**(-3)
    length_km = inertial_length(number_density, 'H+').to(u.km)

    return frame.with_columns(ion_inertial_length=pl.Series(length_km.value)).lazy()
    
def compute_Alfven_speed(ldf: pl.LazyFrame):
    """Append an `Alfven_speed` column (in km/s) to the frame.

    The field magnitude is read from column `B` (or `b_mag` when `B` is
    absent) and interpreted as nT; `plasma_density` is interpreted as cm^-3.
    The lazy frame is collected to do the computation and the result is
    returned lazy again.
    """
    frame = ldf.collect()

    # backwards compatibility: fall back to `b_mag` when `B` is not present
    field_column = 'B' if 'B' in frame.columns else 'b_mag'
    field = frame[field_column].to_numpy() * u.nT
    number_density = frame['plasma_density'].to_numpy() * u.cm**(-3)
    speed = Alfven_speed(field, density=number_density, ion='p+').to(u.km/u.s)

    return frame.with_columns(Alfven_speed=pl.Series(speed.value)).lazy()



def unitize(df: pl.LazyFrame):
    """Rescale the `j0` column to nA/m^2.

    The scale factor is the value of (nT/s) / mu0 / (km/s) expressed in
    nA/m^2, i.e. it assumes `j0` was formed from a quantity in nT/s divided by
    a speed in km/s — TODO confirm against the upstream columns.
    """
    raw_unit = (u.nT/u.s) * (1 / mu0 / (u.km/u.s) )
    scale = raw_unit.to(u.nA/u.m**2).value

    return df.with_columns(j0=pl.col('j0') * scale)

In [None]:
# | export


def calc_combined_features(df: pl.LazyFrame):
    """Derive the combined features from the joined candidates/state frame.

    Columns added, in order:
      - `duration`: `d_tstop - d_tstart`
      - `v_l`: projection of the velocity vector onto the major eigenvector
      - `v_mn`: sqrt(plasma_speed^2 - v_l^2)
      - `L_mn`: `v_mn` times the duration in seconds
      - `j0`: `d_star / v_mn`, then rescaled by `unitize`
      - `ion_inertial_length`, `Alfven_speed`
      - `L_mn_norm`: `L_mn / ion_inertial_length`
      - `j0_norm`: `d_star / Alfven_speed`
    """
    velocity_cols = ["v_x", "v_y", "v_z"]  # plasma velocity vector in any coordinate system
    major_axis_cols = ["Vl_x", "Vl_y", "Vl_z"]  # major eigenvector in any coordinate system

    with_duration = df.with_columns(duration=pl.col("d_tstop") - pl.col("d_tstart"))
    with_v_l = vector_project_pl(with_duration, velocity_cols, major_axis_cols, name="v_l")

    v_mn_expr = (pl.col("plasma_speed") ** 2 - pl.col("v_l") ** 2).sqrt()
    with_v_mn = with_v_l.with_columns(v_mn=v_mn_expr)

    # duration is converted from nanoseconds to seconds
    duration_seconds = pl.col("duration").dt.nanoseconds() / 1e9
    with_thickness = with_v_mn.with_columns(
        L_mn=pl.col("v_mn") * duration_seconds,
        j0=pl.col("d_star") / pl.col("v_mn"),
    )

    with_plasma_scales = compute_Alfven_speed(compute_inertial_length(with_thickness))
    with_units = unitize(with_plasma_scales)

    return with_units.with_columns(
        L_mn_norm=pl.col("L_mn") / pl.col("ion_inertial_length"),
        j0_norm=pl.col("d_star") / pl.col("Alfven_speed"),
    )

### Pipelines

In [None]:
#| export
def combine_features(candidates: pl.LazyFrame, states_data: pl.LazyFrame):
    """Join candidates with state data, derive the combined features, and collect.

    Returns the collected (eager) result of `calc_combined_features` applied to
    the output of `combine`.
    """
    joined = combine(candidates, states_data)
    return calc_combined_features(joined).collect()

## Datasets

In [None]:
#| export
from pydantic import BaseModel
from kedro.io import DataCatalog
from ids_finder.utils.basic import concat_partitions

In [None]:
#| export
class IDsDataset(BaseModel):
    """Candidates and magnetic-field data for one satellite, loaded from a catalog.

    On construction, `candidates` and `data` are loaded from `catalog` unless
    they were passed in explicitly.

    Fields:
        sat_id: identifier used to build the catalog entry names.
        tau: used (as whole seconds) in the candidates entry name.
        ts: used (as whole seconds) in the magnetic-field entry name;
            defaults to one second.
        catalog: catalog the datasets are loaded from.
        candidates: eager candidates table.
        data: lazy magnetic-field data.
        or_df / or_df_normalized: occurrence rate and its normalized form.
    """

    class Config:
        arbitrary_types_allowed = True
        extra = "allow"

    sat_id: str
    tau: timedelta
    # The default must already be a `timedelta`: pydantic does not validate
    # defaults unless asked to, so a bare `1` would stay an int and break
    # `load_data`.
    ts: timedelta = timedelta(seconds=1)
    catalog: DataCatalog

    candidates: pl.DataFrame | None = None
    data: pl.LazyFrame | None = None    # data is large, so we use `pl.LazyFrame` to save memory

    or_df: pl.DataFrame | None = None  # occurrence rate
    or_df_normalized: pl.DataFrame | None = None # normalized occurrence rate

    def __init__(self, **data):
        super().__init__(**data)

        if self.candidates is None:
            self.load_candidates()
        if self.data is None:
            self.load_data()

    def load_candidates(self):
        """Load the candidates entry, turn NaN into null, and cast floats to Float64."""
        # `total_seconds()` rather than `.seconds`, which drops whole days.
        tau_seconds = int(self.tau.total_seconds())
        candidates_format = f"candidates.{self.sat_id}_tau_{tau_seconds}s"

        self.candidates = self.catalog.load(candidates_format).fill_nan(None).with_columns(
            cs.float().cast(pl.Float64)
        ).collect()

    def load_data(self):
        """Load the partitioned magnetic-field entry and concatenate its partitions."""
        ts_seconds = int(self.ts.total_seconds())
        data_format = f"{self.sat_id}.primary_mag_{ts_seconds}s"
        self.data = concat_partitions(self.catalog.load(data_format))

    def plot_candidates(self, **kwargs):
        """Placeholder; not implemented yet."""
        pass

## Plotting

In [None]:
from ids_finder.utils.basic import calc_vec_mag
from pyspedas.cotrans.minvar_matrix_make import minvar_matrix_make
from pyspedas import tvector_rotate
from pytplot import timebar, store_data, tplot, split_vec, join_vec, tplot_options, options, highlight, degap
import pytplot

In [None]:
## Plotting
def time_stamp(ts):
    """Return the POSIX timestamp (float seconds) of `ts`, read as UTC."""
    utc_moment = pd.Timestamp(ts, tz="UTC")
    return utc_moment.timestamp()

def plot_basic(
    data: xr.DataArray, 
    tstart: pd.Timestamp, 
    tstop: pd.Timestamp,
    tau: timedelta, 
    mva_tstart=None, mva_tstop=None, neighbor: int = 1
):
    """Prepare the pytplot variables "fgm" and "fgm_all" around an interval.

    This function only stores and configures pytplot variables; it does not
    call `tplot` itself.

    Args:
        data: field data; it is sliced along its `time` coordinate.
        tstart, tstop: interval that is highlighted in the plot variables.
        tau: padding unit; the stored window is
            [tstart - neighbor * tau, tstop + neighbor * tau].
        mva_tstart, mva_tstop: window used to build the minimum-variance
            matrix; default to `tstart` / `tstop`.
        neighbor: number of `tau` widths added on each side of the interval.
    """
    if mva_tstart is None:
        mva_tstart = tstart
    if mva_tstop is None:
        mva_tstop = tstop

    # Step 1: store only the MVA window as "fgm" and build the matrix from it.
    mva_b = data.sel(time=slice(mva_tstart, mva_tstop))
    store_data("fgm", data={"x": mva_b.time, "y": mva_b})
    minvar_matrix_make("fgm")  # get the MVA matrix

    temp_tstart = tstart - neighbor * tau
    temp_tstop = tstop + neighbor * tau

    # Step 2: overwrite "fgm" with the padded window and store its magnitude.
    temp_b = data.sel(time=slice(temp_tstart, temp_tstop))
    store_data("fgm", data={"x": temp_b.time, "y": temp_b})
    temp_btotal = calc_vec_mag(temp_b)
    store_data("fgm_btotal", data={"x": temp_btotal.time, "y": temp_btotal})

    # Step 3: rotate the padded data with the matrix from step 1.
    # NOTE(review): presumably `minvar_matrix_make` created "fgm_mva_mat" and
    # `tvector_rotate` creates "fgm_rot" — those names are read below; confirm.
    tvector_rotate("fgm_mva_mat", "fgm")
    split_vec("fgm_rot")
    pytplot.data_quants["fgm_btotal"]["time"] = pytplot.data_quants["fgm_rot"][
        "time"
    ]  # NOTE: whenever using `get_data`, we may lose precision in the time values. This is a workaround.
    # Combine the three rotated components and the magnitude into "fgm_all".
    join_vec(
        [
            "fgm_rot_x",
            "fgm_rot_y",
            "fgm_rot_z",
            "fgm_btotal",
        ],
        new_tvar="fgm_all",
    )

    options("fgm", "legend_names", [r"$B_x$", r"$B_y$", r"$B_z$"])
    options("fgm_all", "legend_names", [r"$B_l$", r"$B_m$", r"$B_n$", r"$B_{total}$"])
    options("fgm_all", "ysubtitle", "[nT LMN]")
    tstart_ts = time_stamp(tstart)
    tstop_ts = time_stamp(tstop)
    # `highlight` is given the interval as POSIX timestamps (see `time_stamp`).
    highlight(["fgm", "fgm_all"], [tstart_ts, tstop_ts])
    
    degap("fgm")
    degap("fgm_all")

def format_candidate_title(candidate: pandas.Series):
    """Build the three-line plot title for a candidate.

    Every field is looked up with `candidate.get(key, "N/A")`. Index and ratio
    fields are rendered in bold math text, with floats and ints shown to two
    decimals and anything else shown as-is.
    """

    def _bold(value):
        if isinstance(value, (float, int)):
            return rf"$\bf {value:.2f} $"
        return rf"$\bf {value} $"

    def _field(key):
        return candidate.get(key, "N/A")

    kind = _field("type")
    when = _field("time")
    header = rf'$\bf {kind} $ candidate (time: {when}) with index '

    i1 = _bold(_field("index_std"))
    i2 = _bold(_field("index_fluctuation"))
    i3 = _bold(_field("index_diff"))
    indices = rf'i1: {i1}, i2: {i2}, i3: {i3}'

    bn_over_b = _bold(_field("BnOverB"))
    db_over_b = _bold(_field("dBOverB"))
    db_over_b_max = _bold(_field("dBOverB_max"))
    q_mva = _bold(_field("Q_mva"))
    details = rf'$B_n/B$: {bn_over_b}, $dB/B$: {db_over_b}, $(dB/B)_{{max}}$: {db_over_b_max},  $Q_{{mva}}$: {q_mva}'

    # Continuation lines carry a four-space indent.
    return "\n    ".join((header, indices, details))

In [None]:

def plot_candidate_xr(candidate: pandas.Series, sat_fgm: xr.DataArray, tau: timedelta):
    """Plot one candidate: rotated field, title, and vertical time markers.

    Args:
        candidate: mapping-like record (anything supporting `.get` and
            `[...]`) with at least `tstart` and `tstop`; `d_tstart`,
            `d_tstop` and `d_time` are optional.
        sat_fgm: field data passed on to `plot_basic`.
        tau: padding unit passed on to `plot_basic`.

    When both `d_tstart` and `d_tstop` are present and non-null they are used
    as the MVA window; otherwise `plot_basic` falls back to `tstart`/`tstop`.
    """

    def _present(key):
        # Missing keys yield None from `.get`, which counts as null.
        return pd.notnull(candidate.get(key))

    if _present("d_tstart") and _present("d_tstop"):
        plot_basic(
            sat_fgm,
            candidate["tstart"],
            candidate["tstop"],
            tau,
            candidate["d_tstart"],
            candidate["d_tstop"],
        )
    else:
        plot_basic(sat_fgm, candidate["tstart"], candidate["tstop"], tau)

    tplot_options("title", format_candidate_title(candidate))

    # `d_time` gets the same null guard as the other markers: a null value
    # would otherwise fail when converted to a POSIX timestamp.
    if _present("d_time"):
        timebar(time_stamp(candidate["d_time"]), color="red")
    for edge in ("d_tstart", "d_tstop"):
        if _present(edge):
            timebar(time_stamp(candidate[edge]))

    tplot("fgm_all")

In [None]:

def plot_candidates(
    candidates: pandas.DataFrame, candidate_type=None, num=4, plot_func=None, **kwargs
):
    """Plot a set of candidates.

    Parameters:
    - candidates (pd.DataFrame): DataFrame containing the candidates.
    - candidate_type (str, optional): Filter candidates based on a specific type.
    - num (int): Number of candidates to plot, selected randomly.
    - plot_func (callable, optional): Function used to plot an individual
      candidate. Defaults to `plot_candidate`, resolved at call time.
    - **kwargs: Extra keyword arguments forwarded to `plot_func` for every
      candidate (e.g. `mag_data=...`, which `plot_candidate` requires).

    """
    # Resolve the default lazily: `plot_candidate` is defined in a later cell,
    # so naming it in the signature would raise NameError on a fresh kernel.
    if plot_func is None:
        plot_func = plot_candidate

    # Filter by candidate_type if provided
    # NOTE(review): `get_candidates` is not defined or imported in this
    # notebook — confirm where it comes from.
    candidates = get_candidates(candidates, candidate_type, num)

    # Plot each candidate using the provided plotting function
    for _, candidate in candidates.iterrows():
        plot_func(candidate, **kwargs)

## Candidate class

In [None]:
#| export
from pprint import pprint

In [None]:
#| export
class CandidateID:
    """A single candidate, selected from a candidates table by its `time`.

    Attributes:
        time: the requested time as a `pd.Timestamp`.
        data: the matching row as a dict keyed by column name.
    """

    def __init__(self, time, df: pl.DataFrame) -> None:
        """Select the row of `df` whose `time` column equals `time`.

        NOTE(review): polars' `DataFrame.row(by_predicate=...)` is documented
        to raise unless exactly one row matches — confirm for the installed
        version.
        """
        self.time = pd.Timestamp(time)
        self.data = df.row(
            by_predicate=(pl.col("time") == self.time),
            named=True
        )

    def __repr__(self) -> str:
        """Return the pretty-printed row; has no printing side effect."""
        from pprint import pformat

        return pformat(self.data)

    def plot(self, sat_fgm, tau):
        """Plot this candidate with `plot_candidate_xr`."""
        plot_candidate_xr(self.data, sat_fgm, tau)
        

In [None]:
# Candidate tables and the magnetic-field data used to plot them, per `sat_id`.
sta_candidates_1s: pl.DataFrame = catalog.load('candidates.sta_1s')
jno_candidates_1s = catalog.load('candidates.jno_1s')

sta_mag : pl.LazyFrame = catalog.load('sta.inter_mag_rtn_1s')
# Load the `jno` field data for the `jno` candidates; this previously loaded
# the `sta` entry a second time (copy-paste slip).
# NOTE(review): confirm that 'jno.inter_mag_rtn_1s' exists in the catalog.
jno_mag : pl.LazyFrame = catalog.load('jno.inter_mag_rtn_1s')

In [None]:
from ids_finder.utils.basic import df2ts, pmap
from fastcore.utils import *

In [None]:

def plot_candidate(candidate, mag_data: pl.LazyFrame, b_cols = ['BX', 'BY', 'BZ']):
    """Plot one candidate using lazily loaded magnetic-field data.

    The candidate's own span (`tstop - tstart`) is used as `tau`. Field data
    within one `tau` on either side of the candidate is collected, its `time`
    column is cast to nanosecond precision, and the `b_cols` columns are
    converted with `df2ts` before plotting with `plot_candidate_xr`.
    """
    start, stop = candidate["tstart"], candidate["tstop"]
    tau = stop - start

    in_window = pl.col("time").is_between(start - tau, stop + tau)
    time_as_ns = pl.col("time").dt.cast_time_unit("ns")
    window_df = mag_data.filter(in_window).with_columns(time_as_ns).collect()

    plot_candidate_xr(candidate, df2ts(window_df, b_cols), tau)

In [None]:
n = 3  # number of candidates to draw at random and plot
# Same plot for the `sta` candidates (disabled):
# list(sta_candidates_1s.sample(n).iter_rows(named=True) | pmap(plot_candidate, mag_data=sta_mag))

# NOTE(review): `sample` is called without a seed, so a re-run plots different rows.
candidates = jno_candidates_1s.sample(n)
# Each sampled row is passed to `plot_candidate` as a dict; `list` forces evaluation.
list(candidates.iter_rows(named=True) | pmap(plot_candidate, mag_data=jno_mag))


## Pipelines

In [None]:
#| export
from kedro.pipeline import Pipeline, node
from kedro.pipeline.modular_pipeline import pipeline

In [None]:
# | export
def create_candidate_pipeline(
    sat_id, 
    tau: str = "60s",
    ts_state: str = "1h",
    **kwargs) -> Pipeline:
    """Build the single-node pipeline that produces a satellite's candidates.

    The node runs `combine_features` on `{sat_id}.feature_tau_{tau}` and
    `{sat_id}.primary_state_{ts_state}` and writes
    `candidates.{sat_id}_tau_{tau}`. Extra keyword arguments are accepted but
    not used.
    """
    feature_entry = f"{sat_id}.feature_tau_{tau}"
    state_entry = f"{sat_id}.primary_state_{ts_state}"
    candidates_entry = f"candidates.{sat_id}_tau_{tau}"

    combine_node = node(
        combine_features,
        inputs=[feature_entry, state_entry],
        outputs=candidates_entry,
    )
    return pipeline([combine_node])