---
title: ID properties
subtitle: Full feature extraction
---

In [None]:
# | default_exp core/propeties
# | export
# | code-summary: "Import all the packages needed for the project"
import polars as pl
import xarray as xr
import numpy as np
import ray

from discontinuitypy.propeties.duration import calc_duration
from discontinuitypy.propeties.mva import calc_mva_features_all
from typing import Literal

ray.init(ignore_reinit_error=True)

In [None]:
# | export
def get_data_at_times(data: xr.DataArray, times) -> np.ndarray:
    """
    Return the values of `data` at the samples closest to `times`.

    The lookup is done along the ``time`` coordinate with
    ``method="nearest"``, so the requested times do not have to match the
    coordinate exactly. The selection is converted to a NumPy array.
    """
    nearest = data.sel(time=times, method="nearest")
    return nearest.to_numpy()


def select_data_by_timerange(data: xr.DataArray, tstart, tstop, neighbor: int = 0):
    """
    Slice `data` along ``time`` between `tstart` and `tstop`.

    The window is widened on both sides by ``neighbor`` times its own length
    (``tstop - tstart``); the default ``neighbor=0`` leaves it unchanged.
    """
    padding = neighbor * (tstop - tstart)
    window = slice(tstart - padding, tstop + padding)
    return data.sel(time=window)


def get_candidate_data(candidate: dict, data: xr.DataArray, **kwargs):
    """
    Select the part of `data` covered by a candidate's time range.

    `candidate` must provide the keys ``"tstart"`` and ``"tstop"``; any extra
    keyword arguments are forwarded to `select_data_by_timerange`.
    """
    tstart = candidate["tstart"]
    tstop = candidate["tstop"]
    return select_data_by_timerange(data, tstart, tstop, **kwargs)

## Pipelines

In [None]:
# | exporti
def ld2dl(listdict: list[dict], func=np.array):
    """Convert a list of dictionaries to a dictionary of lists.

    Parameters
    ----------
    listdict : list[dict]
        Records to transpose. The keys of the first record define the keys of
        the result, so every record must contain at least those keys.
    func : callable, optional
        Applied to each collected list of values (default: ``np.array``).

    Returns
    -------
    dict
        ``{key: func([record[key] for record in listdict])}``; an empty dict
        when `listdict` is empty.
    """
    # Without this guard an empty input fails on `listdict[0]` below.
    if not listdict:
        return {}
    keys = listdict[0]
    return {key: func([record[key] for record in listdict]) for key in keys}

In [None]:
# | export
def calc_events_tr_features(
    df: pl.DataFrame, data, tr_cols=["tstart", "tstop"], func=None, **kwargs
):
    """
    Compute per-event features over each event's time range, in parallel with Ray.

    Parameters
    ----------
    df : pl.DataFrame
        Events table; must contain the two columns named in `tr_cols`.
    data
        Time series passed to `select_data_by_timerange` for every event.
    tr_cols : sequence of str
        Names of the start and stop columns (in that order).
    func : callable
        Called as ``func(event_data, **kwargs)`` for every event. Each result
        is indexed by key when the results are combined, so it must be a
        dict-like whose keys become the new column names.
    **kwargs
        Forwarded unchanged to `func`.

    Returns
    -------
    pl.DataFrame
        `df` with one added column per key returned by `func`.

    Raises
    ------
    TypeError
        If `func` is not callable.
    """
    # Fail early and clearly instead of inside every remote task.
    if not callable(func):
        raise TypeError("`func` must be a callable taking the selected data")

    tranges = df.select(tr_cols).to_numpy()
    data_ref = ray.put(data)

    # The object ref is passed as a task argument rather than captured in the
    # closure: Ray resolves top-level ObjectRef arguments to their values, and
    # a ref captured by a remote function is serialized out of band, which
    # pins the object in the object store for the lifetime of the worker.
    @ray.remote
    def remote(event_source, tr, **func_kwargs):
        event_data = select_data_by_timerange(event_source, tr[0], tr[1])
        return func(event_data, **func_kwargs)

    results = ray.get([remote.remote(data_ref, tr, **kwargs) for tr in tranges])
    return df.with_columns(**ld2dl(results))


def calc_events_duration(df, data, tr_cols=["tstart", "tstop"], **kwargs):
    """
    Add the outputs of `calc_duration` for every event and drop incomplete rows.

    `calc_duration` is applied to the data in each event's `tr_cols` range
    through `calc_events_tr_features`; rows containing nulls are removed from
    the result. Extra keyword arguments are forwarded to `calc_duration`.
    """
    with_duration = calc_events_tr_features(
        df, data, tr_cols, func=calc_duration, **kwargs
    )
    return with_duration.drop_nulls()


def calc_events_mva_features(df, data, tr_cols=["t.d_start", "t.d_end"], **kwargs):
    """
    Add the outputs of `calc_mva_features_all` for every event.

    The feature function is applied to the data in each event's `tr_cols`
    range through `calc_events_tr_features`. Extra keyword arguments are
    forwarded to `calc_mva_features_all`.
    """
    feature_func = calc_mva_features_all
    return calc_events_tr_features(df, data, tr_cols, func=feature_func, **kwargs)

In [None]:
# | export
def calc_normal_direction(v1, v2):
    """
    Return the unit vector(s) along the cross product of `v1` and `v2`.

    Parameters
    ----------
    v1 : array_like
        The first vector(s).
    v2 : array_like
        The second vector(s).

    Returns
    -------
    np.ndarray
        ``cross(v1, v2)`` divided by its Euclidean norm, which is taken along
        the last axis.
    """
    cross = np.cross(v1, v2)
    magnitude = np.linalg.norm(cross, axis=-1, keepdims=True)
    return cross / magnitude

In [None]:
# | export
def calc_events_normal_direction(
    df: pl.DataFrame, data: xr.DataArray, name="k", start="t.d_start", end="t.d_end"
):
    """
    Add a column with the normal direction of each event.

    For every row, `data` is sampled at the times in the `start` and `end`
    columns (nearest sample), and the normalized cross product of the two
    sampled vectors is stored in a new column called `name`.
    """
    start_times = df[start].to_numpy()
    end_times = df[end].to_numpy()

    normals = calc_normal_direction(
        get_data_at_times(data, start_times),
        get_data_at_times(data, end_times),
    )
    # NOTE(review): the array is handed to pl.Series as-is, with no list
    # conversion; confirm the polars version in use accepts this shape.
    column = pl.Series(name, normals)
    return df.with_columns(column)


# | export
def calc_events_vec_change(
    df: pl.DataFrame, data: xr.DataArray, name="dB", start="t.d_start", end="t.d_end"
):
    """
    Add a column with the change of `data` across each event.

    For every row, `data` is sampled at the times in the `start` and `end`
    columns (nearest sample); the difference ``end - start`` is stored in a
    new column called `name`.
    """
    start_times = df[start].to_numpy()
    end_times = df[end].to_numpy()

    before = get_data_at_times(data, start_times)
    after = get_data_at_times(data, end_times)

    change = pl.Series(name, after - before)
    return df.with_columns(change)

## Data processing

In [None]:
# | export
def process_events(
    events: pl.DataFrame,  # potential candidates DataFrame
    data: xr.DataArray,
    method: Literal["fit", "derivative"] = "fit",
    **kwargs,
):
    """
    Process candidates DataFrame.

    Runs, in order: duration calculation, MVA features, vector change
    (column ``"dB"``) and normal direction (column ``"k"``), then sets the
    ``duration`` column.

    With ``method="fit"`` the duration step uses the ``"distance"`` method and
    ``duration`` is ``2 * fit.vars.sigma``. Any other value takes the
    derivative branch: the duration step uses ``"derivative"`` and
    ``duration`` is ``t.d_end - t.d_start`` in seconds.

    NOTE(review): ``**kwargs`` is accepted but not used in this function.
    """
    use_fit = method == "fit"
    duration_method = "distance" if use_fit else "derivative"

    if use_fit:
        duration_expr = pl.col("fit.vars.sigma") * 2
    else:
        span = pl.col("t.d_end") - pl.col("t.d_start")
        # nanoseconds -> seconds
        duration_expr = span.dt.total_nanoseconds() / 1e9

    processed = calc_events_duration(events, data=data, method=duration_method)
    processed = calc_events_mva_features(processed, data=data, method=method)
    processed = calc_events_vec_change(processed, data=data, name="dB")
    processed = calc_events_normal_direction(processed, data=data, name="k")
    return processed.with_columns(duration=duration_expr)