In [None]:
#| default_exp core

In [None]:
#| hide 
from nbdev.showdoc import *

In [None]:
#| export
#| code-summary: "Import all the packages needed for the project"
from fastcore.utils import *
from fastcore.test import *
from ids_finder.utils.basic import *
import polars as pl
import xarray as xr


try:
    import modin.pandas as pd
    import modin.pandas as mpd
    from modin.config import ProgressBar
    ProgressBar.enable()
except ImportError:
    import pandas as pd
import pandas
    
import numpy as np
from xarray_einstats import linalg

from datetime import timedelta

from loguru import logger


import pdpipe as pdp
from multipledispatch import dispatch

from typing import Any, Collection, Callable

from xarray.core.dataarray import DataArray




## Processing Stages

- [ ] Smoothing
- [ ] Interpolating

## ID identification (limited feature extraction / anomaly detection)

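The first index (`index_std`) is $$ \frac{\sigma(B)}{\max(\sigma(B_-),\sigma(B_+))} $$
The second index (`index_fluctuation`) is $$ \frac{\sigma(B_- + B_+)} {\sigma(B_-) + \sigma(B_+)} $$ where $\sigma(B_- + B_+)$ is the standard deviation of the pooled samples of the preceding and following windows.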
The first two conditions guarantee that the field changes of the identified IDs are large enough to be distinguished from the stochastic fluctuations of the magnetic field, while the third is a supplementary condition to reduce the uncertainty of recognition.

The third index (`index_diff`, the relative field jump) is $$ \frac{| \Delta \vec{B} |}{|B_{bg}|} $$ which serves as a supplementary condition to reduce the uncertainty of recognition.

In [None]:
from ids_finder.utils.basic import _expand_selectors

In [None]:
# some helper functions
def pl_format_time(tau):
    """Build the polars expressions that turn a window start `time` into interval columns.

    Returns three expressions, in order:
    - `tstart`: the original `time` (window start);
    - `tstop`: `time + tau`, cast to nanosecond precision;
    - `time`: recentred on the window midpoint, `time + tau / 2`, cast to nanoseconds.
    """
    start = pl.col("time")
    stop = (start + tau).dt.cast_time_unit("ns").alias("tstop")
    center = (start + tau / 2).dt.cast_time_unit("ns")
    return [start.alias("tstart"), stop, center]

def pl_dvec(columns, *more_columns):
    """Return one aggregation expression per column: first value minus last value.

    Each expression is named `d{column}_vec`. Columns are expanded with
    `_expand_selectors`, so selectors as well as plain names are accepted.
    """
    def _first_minus_last(name):
        expr = pl.col(name)
        return (expr.first() - expr.last()).alias(f"d{name}_vec")

    return list(map(_first_minus_last, _expand_selectors(columns, *more_columns)))


In [None]:
#| export
def compute_std(
    df: pl.DataFrame,
    tau) -> pl.DataFrame:
    """Per-window standard deviation of the magnetic field.

    Windows of length `tau` start every `tau / 2` (`group_by_dynamic` on `time`).
    For each window the population std (ddof=0) of BX, BY and BZ is computed and
    combined with `pl_norm` into `B_std`.

    Returns a frame with columns `time` (window start), `count` (number of samples
    in the window) and `B_std`.
    """
    components = ["BX", "BY", "BZ"]
    suffix = "_std"
    per_component = [name + suffix for name in components]

    windows = df.group_by_dynamic("time", every=tau / 2, period=tau)
    aggregated = windows.agg(
        pl.count(),
        pl.col(components).std(ddof=0).map_alias(lambda name: name + suffix),
    )
    with_total = aggregated.with_columns(pl_norm(per_component).alias("B_std"))
    return with_total.drop(per_component)


def compute_combinded_std(df: pl.DataFrame, tau) -> pl.DataFrame:
    """Combined standard deviation of the two windows flanking each window.

    Samples are binned into `tau`-long bins on two grids (offset 0 and offset tau/2).
    For a window labelled T, the samples of the preceding bin [T - tau, T) and of
    the following bin [T + tau, T + 2 tau) are pooled. Their per-component
    population std (ddof=0) is then combined with `pl_norm` (presumably the
    Euclidean norm) into `B_combined_std`.

    Returns a frame with columns `time` and `B_combined_std`. It is the concatenation
    of the results of both grids, each sorted by `time`.
    """
    components = ["BX", "BY", "BZ"]
    per_component = [name + "_combined_std" for name in components]

    def _for_offset(offset):
        # Label every sample with the start of its `tau` bin on this grid.
        binned = df.select(
            (pl.col("time") - offset).dt.truncate(tau, offset=offset).alias("time"),
            pl.col(components),
        )
        # Relabel bins so that the bin before (+tau) and the bin after (-tau)
        # a window both carry that window's label.
        before = binned.select(
            (pl.col("time") + tau).dt.cast_time_unit("ns"),
            pl.col(components),
        )
        after = binned.select(
            (pl.col("time") - tau).dt.cast_time_unit("ns"),
            pl.col(components),
        )
        pooled_std = (
            pl.concat([before, after])
            .group_by("time")
            .agg(
                pl.col(components)
                .std(ddof=0)
                .map_alias(lambda name: name + "_combined_std"),
            )
        )
        return (
            pooled_std
            .with_columns(pl_norm(per_component).alias("B_combined_std"))
            .drop(per_component)
            .sort("time")
        )

    return pl.concat([_for_offset(offset) for offset in (0 * tau, tau / 2)])

In [None]:
#| export
@dispatch(pl.LazyFrame, object)
def compute_index_std(df: pl.LazyFrame, tau, join_strategy="inner"):  # noqa: F811
    """
    Compute the standard deviation index based on the given DataFrame and tau value.

    index_std = B_std / max(B_std_prev, B_std_next), where the previous/next values
    are those of the windows starting `tau` before / after the current one.

    Parameters
    ----------
    - df (pl.LazyFrame): Either raw field data (`time`, `BX`, `BY`, `BZ`), or the
      output of `compute_std`, detected by the presence of a `B_std` column. In
      the latter case `df` must also provide `count`.
    - tau (datetime.timedelta | int | float): The window length. Plain numbers are
      interpreted as seconds.
    - join_strategy (str): `how` used for the joins with the previous/next windows.
      The default "inner" keeps only windows that have both neighbours.

    Returns
    -------
    - pl.LazyFrame: The std frame with `B_std_prev`, `count_prev`, `B_std_next`,
      `count_next` and the calculated `index_std` column.

    Examples
    --------
    >>> index_std_df = compute_index_std(df, tau)
    >>> index_std_df

    Notes
    -----
    Simply shift to calculate index_std would not work correctly if data is missing, like `std_next = pl.col("B_std").shift(-2)`.
    The neighbours are therefore matched by joining on the shifted `time` labels.

    """

    if isinstance(tau, (int, float)):
        tau = timedelta(seconds=tau)

    if "B_std" in df.columns:
        std_df = df
    else:
        # Compute standard deviations
        std_df = compute_std(df, tau)

    # Relabel each window by +tau so it lines up with the window that follows it
    # (i.e. it becomes that window's "previous" neighbour).
    prev_std_df = std_df.select(
        (pl.col("time") + tau).dt.cast_time_unit("ns"),
        pl.col("B_std").alias("B_std_prev"),
        pl.col("count").alias("count_prev"),
    )

    # Relabel each window by -tau so it becomes the "next" neighbour of the one before it.
    next_std_df = std_df.select(
        (pl.col("time") - tau).dt.cast_time_unit("ns"),
        pl.col("B_std").alias("B_std_next"),
        pl.col("count").alias("count_next")
    )

    index_std_df = (
        std_df.join(prev_std_df, on="time", how=join_strategy)
        .join(next_std_df, on="time", how=join_strategy)
        .with_columns(
            (pl.col("B_std") / (pl.max_horizontal("B_std_prev", "B_std_next"))).alias(
                "index_std"
            )
        )
    )
    return index_std_df

In [None]:
#| export
def compute_index_diff(df, tau):
    """Relative field jump per window.

    Uses the same windows as `compute_std` (length `tau`, every `tau / 2`). In each
    window:
    - `B_mean` is the mean field magnitude (`pl_norm` of BX, BY, BZ);
    - `dB_vec` is the `pl_norm` of the first-minus-last change of each component;
    - `index_diff = dB_vec / B_mean`.
    """
    components = ["BX", "BY", "BZ"]
    jump_cols = [f"d{name}_vec" for name in components]

    with_magnitude = df.with_columns(pl_norm(components).alias("B"))
    per_window = with_magnitude.group_by_dynamic("time", every=tau / 2, period=tau).agg(
        pl.col("B").mean().alias("B_mean"),
        *pl_dvec(components),
    )
    with_jump = per_window.with_columns(pl_norm(jump_cols).alias("dB_vec"))
    return with_jump.with_columns(
        (pl.col("dB_vec") / pl.col("B_mean")).alias("index_diff"),
    )


@dispatch(pl.LazyFrame, timedelta)
def compute_indices(
    df: pl.LazyFrame,
    tau: timedelta,
) -> pl.LazyFrame:
    """
    Compute all indices based on the given (lazy) DataFrame and tau value.

    Parameters
    ----------
    df : pl.LazyFrame
        Input field data with `time`, `BX`, `BY`, `BZ` columns.
    tau : datetime.timedelta
        Window length.

    Returns
    -------
    pl.LazyFrame
        One row per window, labelled by its start `time`, that has both neighbouring
        windows. Besides the intermediate columns (`count*`, `B_std*`, `B_mean`,
        `dB_vec`, `B_combined_std`, `B_added_std`, ...), it holds the three
        indices `index_std`, `index_fluctuation` and `index_diff`.

    Examples
    --------
    >>> indices = compute_indices(df, tau)

    Notes
    -----
    - Simply shift to calculate index_std would not work correctly if data is missing,
        like `std_next = pl.col("B_std").shift(-2)`.
    - Drop null though may lose some IDs (using the default `join_strategy`).
        Because we could not tell if it is a real ID or just a partial wave
        from incomplete data without previous or/and next std.
        Hopefully we can pick up the lost ones with smaller tau.
    - TODO: Can be optimized further, but this is already fast enough.
        - TEST: if `join` can be improved by shift after filling the missing values.
        - TEST: if `list` in `polars` really fast?
    """
    join_strategy = "inner"

    std_df = compute_std(df, tau)
    combined_std_df = compute_combinded_std(df, tau)

    # `index_std` (and the `*_prev` / `*_next` neighbour columns) come from here;
    # no need to recompute the ratio below.
    index_std = compute_index_std(std_df, tau)
    index_diff = compute_index_diff(df, tau)

    indices = (
        index_std.join(index_diff, on="time")
        .join(combined_std_df, on="time", how=join_strategy)
        .with_columns(
            pl.sum_horizontal("B_std_prev", "B_std_next").alias("B_added_std"),
        )
        .with_columns(
            (pl.col("B_combined_std") / pl.col("B_added_std")).alias(
                "index_fluctuation"
            ),
        )
    )

    return indices


@dispatch(pl.DataFrame, timedelta)
def compute_indices(    # noqa: F811
    df: pl.DataFrame,
    tau: timedelta,
) -> pl.DataFrame:
    """
    Eager wrapper around the `pl.LazyFrame` overload of `compute_indices`.

    Converts `df` to a LazyFrame, builds the indices lazily and collects the result
    into a `pl.DataFrame`.
    """
    return compute_indices(df.lazy(), tau).collect()


### Index of the standard deviation

### Index of fluctuation

### Index of the relative field jump

In [None]:
#| export
@dispatch(object, xr.DataArray)
def get_candidate_data(candidate, data, coord:str=None, neighbor:int=0) -> xr.DataArray:
    """Select the part of `data` covering a candidate, optionally padded on both sides.

    The candidate window [tstart, tstop] is widened on each side by `neighbor` times
    its own length. `coord` is accepted for signature parity with the
    `pl.DataFrame` overload and is not used here.
    """
    start, stop = candidate['tstart'], candidate['tstop']
    pad = neighbor * (stop - start)
    return data.sel(time=slice(start - pad, stop + pad))

@dispatch(object, pl.DataFrame)
def get_candidate_data(candidate, data, coord:str=None, neighbor:int=0) -> xr.DataArray:
    """
    Select the rows of `data` covering a candidate and convert them to a time series.

    The candidate window [tstart, tstop] is widened on each side by `neighbor` times
    its own length. The BX, BY, BZ columns of the selected rows are converted with
    `df2ts`; `coord` is stored as the `coordinate_system` attribute.

    Notes
    -----
    Much slower than the `xr.DataArray` overload of `get_candidate_data`.
    """
    # tstop - tstart (not the reverse): a positive length makes neighbor > 0 widen
    # the window, as in the xr.DataArray overload.
    duration = candidate['tstop'] - candidate['tstart']
    offset = neighbor * duration
    temp_tstart = candidate['tstart'] - offset
    temp_tstop = candidate['tstop'] + offset

    temp_data = data.filter(
        pl.col("time").is_between(temp_tstart, temp_tstop)
    )

    return df2ts(temp_data, ["BX", "BY", "BZ"], attrs={"coordinate_system": coord, "units": "nT"})

def get_candidates(candidates: pd.DataFrame, candidate_type=None, num:int=4):
    
    if candidate_type is not None:
        _candidates = candidates[candidates['type'] == candidate_type]
    else:
        _candidates = candidates
    
    # Sample a specific number of candidates if num is provided and it's less than the total number
    if num < len(_candidates):
        logger.info(f"Sampling {num} {candidate_type} candidates out of {len(_candidates)}")
        return _candidates.sample(num)
    else:
        return _candidates

## ID parameters (full feature extraction)

### Duration

Definition of duration
- Define $d^* = \max( | dB / dt | )$. The duration is the time interval around the time of $d^*$ bounded, on each side, by the nearest point where $| dB/dt |$ has decreased below $d^*/4$.

In [None]:
#| export
# Fraction of the peak |dB/dt| (d*) that marks the duration boundaries (default of `calc_duration`).
THRESHOLD_RATIO = 0.25

from typing import Tuple

def calc_duration(vec: xr.DataArray, threshold_ratio=THRESHOLD_RATIO) -> pandas.Series:
    """
    Estimate the duration of a discontinuity from the rate of change of the field.

    d* is the maximum of |dB/dt| (the `linalg.norm` over `v_dim` of the time
    derivative, per second), reached at `d_time`. The duration is bounded by the
    last time before, and the first time after, `d_time` where |dB/dt| is below
    `threshold_ratio * d*`.

    Parameters
    ----------
    vec : xr.DataArray
        Vector time series with a `time` coordinate and a `v_dim` component dimension.
    threshold_ratio : float
        Fraction of d* used as the boundary threshold.

    Returns
    -------
    pandas.Series
        Keys `d_star`, `d_time`, `threshold`, `d_tstart` and `d_tstop`. `d_tstart` /
        `d_tstop` are None when |dB/dt| never drops below the threshold on that side.

    Raises
    ------
    ValueError
        If |dB/dt| has no valid (non-NaN) value after trimming the edge points.
    """
    # NOTE: gradient calculated at the edge is not reliable, so the first and last points are dropped.
    vec_diff = vec.differentiate("time", datetime_unit="s").isel(time=slice(1, -1))
    vec_diff_mag = linalg.norm(vec_diff, dims='v_dim')

    if vec_diff_mag.isnull().all():
        raise ValueError("The differentiated vector magnitude contains only NaN values. Cannot compute duration.")

    # Peak rate of change and the time at which it occurs.
    d_star_index = vec_diff_mag.argmax(dim="time")
    d_star = vec_diff_mag[d_star_index]
    d_time = vec_diff_mag.time[d_star_index]

    threshold = d_star * threshold_ratio

    start_time, end_time = find_start_end_times(vec_diff_mag, d_time, threshold)

    # Named `result` rather than `dict` to avoid shadowing the builtin.
    result = {
        'd_star': d_star.item(),
        'd_time': d_time.values,
        'threshold': threshold.item(),
        'd_tstart': start_time,
        'd_tstop': end_time,
    }

    return pandas.Series(result)

def calc_d_duration(vec: xr.DataArray, d_time, threshold) -> pd.Series:
    """Recompute only the duration boundaries around an already known peak time.

    Unlike `calc_duration`, the derivative's edge points are kept, and `d_time` and
    `threshold` are supplied by the caller. Returns a pandas Series with `d_tstart`
    and `d_tstop`, either of which may be None.
    """
    rate = linalg.norm(vec.differentiate("time", datetime_unit="s"), dims='v_dim')
    d_tstart, d_tstop = find_start_end_times(rate, d_time, threshold)
    return pandas.Series({'d_tstart': d_tstart, 'd_tstop': d_tstop})
 
def find_start_end_times(vec_diff_mag: xr.DataArray, d_time, threshold) -> Tuple[pd.Timestamp, pd.Timestamp]:
    """Locate the duration boundaries around `d_time`.

    Returns `(start, end)`:
    - `start` is the last time at or before `d_time` where `vec_diff_mag < threshold`;
    - `end` is the first such time at or after `d_time`.

    Either value is None when no sample satisfies the condition on that side.
    """
    before = vec_diff_mag.sel(time=slice(None, d_time))
    after = vec_diff_mag.sel(time=slice(d_time, None))
    return (
        get_time_from_condition(before, threshold, "last_below"),
        get_time_from_condition(after, threshold, "first_below"),
    )


def get_time_from_condition(vec: xr.DataArray, threshold, condition_type) -> pd.Timestamp:
    if condition_type == "first_below":
        condition = vec < threshold
        index_choice = 0
    elif condition_type == "last_below":
        condition = vec < threshold
        index_choice = -1
    else:
        raise ValueError(f"Unknown condition_type: {condition_type}")

    where_result = np.where(condition)[0]

    if len(where_result) > 0:
        return vec.time[where_result[index_choice]].values
    return None

In [None]:
#| export
def calc_candidate_duration(candidate: pd.Series, data) -> pd.Series:
    """Compute the duration parameters (`calc_duration`) of a single candidate.

    The candidate's [tstart, tstop] slice of `data` is taken with
    `get_candidate_data`. Any error is printed together with the candidate and
    then re-raised.
    """
    try:
        return calc_duration(get_candidate_data(candidate, data))
    except Exception as err:
        # `print` rather than `logger.debug`: the logger call "can not be serialized"
        # (NOTE(review): presumably when rows are processed by parallel workers).
        print(f"Error for candidate {candidate} at {candidate['time']}: {str(err)}")
        raise err

def calc_candidate_d_duration(candidate, data) -> pd.Series:
    """Fill in a candidate's missing duration boundaries.

    If both `d_tstart` and `d_tstop` are present, they are returned unchanged.
    Otherwise the boundaries are recomputed with `calc_d_duration` on data padded
    by one window length on each side (`neighbor=1`), reusing the candidate's
    `d_time` and `threshold`. Any error is printed with the candidate and re-raised.
    """
    try:
        both_known = not (pd.isnull(candidate['d_tstart']) or pd.isnull(candidate['d_tstop']))
        if both_known:
            return pandas.Series({
                'd_tstart': candidate['d_tstart'],
                'd_tstop': candidate['d_tstop'],
            })
        widened = get_candidate_data(candidate, data, neighbor=1)
        return calc_d_duration(widened, candidate['d_time'], candidate['threshold'])
    except Exception as err:
        # logger.debug(f"Error for candidate {candidate} at {candidate['time']}: {str(e)}")
        print(f"Error for candidate {candidate} at {candidate['time']}: {str(err)}")
        raise err

In [None]:
#| export

def calibrate_candidate_duration(
    candidate: pd.Series, data:xr.DataArray, data_resolution, ratio = 3/4
):
    """
    Calibrate a candidate's duration boundaries.

    - If only one of 'd_tstart' / 'd_tstop' is known, the missing one is obtained by
      mirroring the known one about 'd_time' (the interval is assumed symmetric).
    - If neither is known, both are returned as None.
    - The interval is then rejected (both None) when it holds too few samples, i.e.
      at most `ratio * duration / data_resolution` points of `data`.

    Parameters
    ----------
    - candidate (pd.Series): Provides 'd_tstart' and 'd_tstop' (either may be null),
      plus 'd_time' when one of them is missing.
    - data (xr.DataArray): Series with a `time` coordinate, used to count samples.
    - data_resolution: Nominal sampling interval of `data`; it must be divisible
      into the interval duration (presumably a timedelta).
    - ratio (float): Minimum fraction of the expected samples needed to keep the interval.

    Returns
    -------
    - pandas.Series: 'd_tstart' and 'd_tstop' (None when rejected).
    """
    has_start = pd.notnull(candidate['d_tstart'])
    has_stop = pd.notnull(candidate['d_tstop'])

    if has_start and has_stop:
        d_tstart, d_tstop = candidate['d_tstart'], candidate['d_tstop']
    elif has_start:
        d_tstart = candidate['d_tstart']
        d_tstop = candidate['d_time'] - d_tstart + candidate['d_time']
    elif has_stop:
        d_tstop = candidate['d_tstop']
        d_tstart = candidate['d_time'] - d_tstop + candidate['d_time']
    else:
        return pandas.Series({
            'd_tstart': None,
            'd_tstop': None,
        })

    duration = d_tstop - d_tstart
    n_points = data.time.sel(time=slice(d_tstart, d_tstop)).count().item()
    too_sparse = n_points <= (duration / data_resolution) * ratio

    return pandas.Series({
        'd_tstart': None if too_sparse else d_tstart,
        'd_tstop': None if too_sparse else d_tstop,
    })

### ID classification

In this method, TDs and RDs satisfy $\frac{|B_N|}{|B_{bg}|} < 0.2$, $\left|\frac{\Delta |B|}{|B_{bg}|}\right| > 0.2$ and $\frac{|B_N|}{|B_{bg}|} > 0.4$, $\left|\frac{\Delta |B|}{|B_{bg}|}\right| < 0.2$, respectively. Moreover, IDs with $\frac{|B_N|}{|B_{bg}|} < 0.4$, $\left|\frac{\Delta |B|}{|B_{bg}|}\right| < 0.2$ could be either TDs or RDs, and so are termed EDs. Similarly, NDs are defined as $\frac{|B_N|}{|B_{bg}|} > 0.2$, $\left|\frac{\Delta |B|}{|B_{bg}|}\right| > 0.2$ because they can be neither TDs nor RDs. It is worth noting that EDs and NDs here are not physical concepts like RDs and TDs. RDs or TDs correspond to specific types of structures in the MHD framework, while EDs and NDs are introduced just to better quantify the statistical results.


Criteria Used to Classify Discontinuities on the Basis of Magnetic Data Type

| Type   |  $\|B_n/B\|$      | $\| \Delta B / B \|$  |
|----------|-------------|------|
| Rotational (RD) | large | small |
| Tangential (TD) | small |  large |
| Either (ED) | small | small |
| Neither (ND) | large | large |


#### minimum variance analysis (MVA)

To ensure the accuracy of MVA, the results are used for further analysis only when the ratio of the intermediate to the minimum eigenvalue (denoted $Q_{MVA}$, column `Q_mva`) is larger than 3.

Parameters:
- `Vl_x`, `Vl_y`, `Vl_z`: Maximum variance direction eigenvector


In [None]:
#| export
# Classification cuts on |B_n|/|B| (BnOverB) and Δ|B|/|B| (dBOverB), used by `classify_id`.
# Only two distinct cut values exist on each axis; the aliases keep shared boundaries in sync.

# Rotational (RD): BnOverB > 0.4 and dBOverB < 0.2
BnOverB_RD_lower_threshold, dBOverB_RD_upper_threshold = 0.4, 0.2

# Tangential (TD): BnOverB < 0.2 and dBOverB > 0.2 (same ΔB cut as RD)
BnOverB_TD_upper_threshold = 0.2
dBOverB_TD_lower_threshold = dBOverB_RD_upper_threshold

# Either (ED): BnOverB < RD cut and dBOverB < TD cut
BnOverB_ED_upper_threshold, dBOverB_ED_upper_threshold = (
    BnOverB_RD_lower_threshold,
    dBOverB_TD_lower_threshold,
)

# Neither (ND): BnOverB > TD cut and dBOverB > RD cut
BnOverB_ND_lower_threshold, dBOverB_ND_lower_threshold = (
    BnOverB_TD_upper_threshold,
    dBOverB_RD_upper_threshold,
)

In [None]:
#| exports
def minvar(data):
    """
    Minimum variance analysis (MVA) of a vector time series.

    Mirrors `pyspedas.cotrans.minvar` (including its IDL-compatibility quirks). The
    function builds the 3x3 covariance matrix of the components, diagonalises it
    with `np.linalg.eigh`, and returns the data rotated into the eigenvector frame.

    Parameters
    -----------
    data:
        Vxyz, an (npoints, 3) array of data (i.e. Nx3). Exactly 3 components are used.

    Returns
    -------
    vrot:
        an (npoints, 3) array containing the rotated data in the new coordinate system, ijk.
        Vi (maximum variance direction) = vrot[:, 0]
        Vj (intermediate direction) = vrot[:, 1]
        Vk (minimum variance direction) = vrot[:, 2]
    v:
        a (3, 3) array whose columns are the principal axes vectors
        Maximum variance direction eigenvector, Vi = v[:, 0]
        Intermediate variance direction, Vj = v[:, 1]
        Minimum variance direction, Vk = v[:, 2]
    w:
        the absolute values of the eigenvalues, in descending order. In the
        degenerate case where they sum to 0, the IDL-compatible order [0, 2, 1]
        is used instead.

    Notes
    -----
    NaNs are replaced by 0 *before* averaging (`np.nan_to_num`), so they count as
    zeros in the means rather than being skipped; the `np.nanmean` calls therefore
    act like `np.mean`. NOTE(review): confirm this is intended for data with gaps.
    """

    #  Min var starts here
    # data must be Nx3
    vecavg = np.nanmean(np.nan_to_num(data, nan=0.0), axis=0)

    # Covariance matrix: <B_i B_j> - <B_i><B_j>
    mvamat = np.zeros((3, 3))
    for i in range(3):
        for j in range(3):
            mvamat[i, j] = np.nanmean(np.nan_to_num(data[:, i] * data[:, j], nan=0.0)) - vecavg[i] * vecavg[j]

    # Calculate eigenvalues and eigenvectors
    w, v = np.linalg.eigh(mvamat, UPLO='U')

    # Sorting to ensure descending order
    w = np.abs(w)
    idx = np.flip(np.argsort(w))

    # IDL compatability
    if True:
        if np.sum(w) == 0.0:
            idx = [0, 2, 1]

    w = w[idx]
    v = v[:, idx]

    # Rotate intermediate var direction if system is not Right Handed
    # NOTE(review): only the first term of the triple product v0 . (v1 x v2)
    # (i.e. of det(v)) is evaluated, as in the IDL/pyspedas original; confirm
    # whether the full determinant is intended.
    YcrossZdotX = v[0, 0] * (v[1, 1] * v[2, 2] - v[2, 1] * v[1, 2])
    if YcrossZdotX < 0:
        v[:, 1] = -v[:, 1]
        # v[:, 2] = -v[:, 2] # Should not it is being flipped at Z-axis?

    # Ensure minvar direction is along +Z (for FAC system)
    # Flipping both the minimum and intermediate axes preserves handedness.
    if v[2, 2] < 0:
        v[:, 2] = -v[:, 2]
        v[:, 1] = -v[:, 1]

    # Project every sample onto the eigenvector columns.
    vrot = np.array([np.dot(row, v) for row in data])

    return vrot, v, w


In [None]:
#| export
def calc_classification_index(
    data: xr.DataArray
) -> pandas.Series:
    """Compute the MVA-based quantities used to classify an ID.

    `data` is rotated into the minimum-variance frame with `minvar`. From the rotated
    field (component 2 = minimum variance, i.e. normal direction) the function
    returns a pandas Series with:
    - `Vl_x`, `Vl_y`, `Vl_z`: maximum variance direction eigenvector;
    - `eig0`..`eig2`: eigenvalues (descending) and `Q_mva = eig1 / eig2`;
    - `B`: mean field magnitude (via `calc_vec_mag`); `B_n`: mean normal component;
    - `dB`: magnitude change (last minus first sample);
    - `BnOverB = B_n / B` (signed; `classify_id` takes the absolute value);
    - `dBOverB = |dB / B|`, `dBOverB_max = (max|B| - min|B|) / B`;
    - `dB_l`, `dB_m`, `dB_n`: change of each MVA component (first minus last sample).
    """
    # NOTE: passing a plain numpy array (`.to_numpy()`) to `minvar` is much faster.
    rotated, eigvecs, eigvals = minvar(data.to_numpy())
    l_dir = eigvecs[:, 0]  # maximum variance direction eigenvector

    B_rot = xr.DataArray(rotated, dims=['time', 'v_dim'], coords={'time': data.time})
    B = calc_vec_mag(B_rot)
    B_mean = B.mean(dim="time")
    B_n_mean = B_rot.isel(v_dim=2).mean(dim="time")

    # Change of each MVA component across the interval (first minus last sample).
    dB_l, dB_m, dB_n = (
        B_rot.isel(v_dim=i, time=0) - B_rot.isel(v_dim=i, time=-1) for i in range(3)
    )
    # Change of the field magnitude (last minus first sample).
    dB = B.isel(time=-1) - B.isel(time=0)

    return pandas.Series({
        'Vl_x': l_dir[0],
        'Vl_y': l_dir[1],
        'Vl_z': l_dir[2],
        'eig0': eigvals[0],
        'eig1': eigvals[1],
        'eig2': eigvals[2],
        'Q_mva': eigvals[1] / eigvals[2],
        'B': B_mean.item(),
        'B_n': B_n_mean.item(),
        'dB': dB.item(),
        'BnOverB': (B_n_mean / B_mean).item(),
        'dBOverB': np.abs(dB / B_mean).item(),
        'dBOverB_max': ((B.max(dim="time") - B.min(dim="time")) / B_mean).item(),
        'dB_l': dB_l.item(),
        'dB_m': dB_m.item(),
        'dB_n': dB_n.item(),
    })
    
    

In [None]:
#| export
def classify_id(BnOverB, dBOverB):
    """Classify IDs as rotational (RD), tangential (TD), either (ED) or neither (ND).

    Parameters
    ----------
    BnOverB : array_like
        Normal-field ratio; its absolute value is used.
    dBOverB : array_like
        Relative magnitude jump, used as given (`calc_classification_index`
        already provides it as an absolute value).

    Returns
    -------
    numpy.ndarray
        Object array with the shape of `BnOverB`, holding "RD", "TD", "ED" or "ND".
        The four masks partition the threshold plane, so every comparable entry gets a label.
    """
    bn = np.abs(np.asarray(BnOverB))
    db = np.asarray(dBOverB)

    bn_above_rd = bn > BnOverB_RD_lower_threshold
    bn_above_td = bn > BnOverB_TD_upper_threshold
    # One cut serves both RD (upper) and TD (lower): dBOverB_TD_lower_threshold
    # aliases dBOverB_RD_upper_threshold.
    db_large = db > dBOverB_RD_upper_threshold

    masks = (
        ("RD", bn_above_rd & ~db_large),
        ("TD", ~bn_above_td & db_large),
        ("ED", ~bn_above_rd & ~db_large),
        ("ND", bn_above_td & db_large),
    )

    labels = np.empty_like(bn, dtype=object)
    for label, mask in masks:
        labels[mask] = label
    return labels

### Field rotation angles
The PDF of the field rotation angles across the solar-wind IDs is well fitted by the exponential function $\exp(-\theta/\theta_0)$, with $\theta_0$ a characteristic rotation angle...

In [None]:
#| export
def calc_rotation_angle(v1, v2):
    """
    Compute the rotation angle between two vectors, or between two stacks of vectors.

    Parameters
    ----------
    v1, v2 : array_like (numpy arrays, xr.DataArray, lists, tuples, ...)
        Vectors whose components lie along the last axis. Both must have the same
        shape; leading axes (e.g. time) are handled element-wise.

    Returns
    -------
    float or numpy.ndarray
        Angle(s) in degrees, in [0, 180]. A zero-length vector yields NaN.

    Raises
    ------
    ValueError
        If `v1` and `v2` do not have the same shape.
    """
    # `np.asarray` handles numpy arrays, xr.DataArray (via `__array__`) and plain
    # sequences alike, so `.shape` is always available for the check below.
    v1 = np.asarray(v1)
    v2 = np.asarray(v2)

    if v1.shape != v2.shape:
        raise ValueError("Vectors must have the same shape.")

    # Normalize the vectors
    v1_u = v1 / np.linalg.norm(v1, axis=-1, keepdims=True)
    v2_u = v2 / np.linalg.norm(v2, axis=-1, keepdims=True)

    # Calculate the cosine of the angle for each time step
    cosine_angle = np.sum(v1_u * v2_u, axis=-1)

    # Clip the values to handle potential floating point errors
    cosine_angle = np.clip(cosine_angle, -1, 1)

    angle = np.arccos(cosine_angle)

    # Convert the angles from radians to degrees
    return np.degrees(angle)

def calc_candidate_rotation_angle(candidates, data:  xr.DataArray):
    """
    Compute the field rotation angle(s) across one or several candidates.

    The field vectors nearest in time to 'd_tstart' and 'd_tstop' are taken from
    `data`, and the angle between them is computed with `calc_rotation_angle`.

    Parameters
    ----------
    candidates : a single candidate (scalar 'd_tstart' / 'd_tstop') or a DataFrame
        of candidates (pandas or modin).
    data : xr.DataArray with a `time` coordinate.

    Returns
    -------
    Angle(s) in degrees: a scalar for a single candidate, an array otherwise.
    """
    tstart = candidates['d_tstart']
    tstop = candidates['d_tstop']

    # Convert Series (modin or plain pandas, as in the patched `_transform`) to numpy arrays
    if isinstance(tstart, (pd.Series, pandas.Series)):
        tstart = tstart.to_numpy()
        tstop = tstop.to_numpy()
        # no need to Handle NaT values (as `calibrate_candidate_duration` will handle this)

    # Get the vectors at the two time steps
    vecs_before = data.sel(time=tstart, method="nearest")
    vecs_after = data.sel(time=tstop, method="nearest")

    # Compute the rotation angle(s)
    rotation_angles = calc_rotation_angle(vecs_before, vecs_after)
    return rotation_angles

### Assign satellite locations to the discontinuities

In [None]:
#| export
def get_candidate_location(candidate, location_data: DataArray):
    """Return the entry of `location_data` nearest in time to the candidate's `d_time`, as a pandas Series."""
    nearest = location_data.sel(time=candidate['d_time'], method="nearest")
    return nearest.to_series()

## Processing the whole dataset

In [None]:
#| export
def get_ID_filter_condition(
    index_std_threshold = 2,
    index_fluc_threshold = 1,
    index_diff_threshold = 0.1,
    sparse_num = 15
):
    """Build the polars boolean expression that selects ID candidates.

    A window is kept only if all of the following hold:
    - `index_std > index_std_threshold`, and `index_std` is finite (it is not when
      the neighbouring windows have zero std);
    - `index_fluctuation > index_fluc_threshold`;
    - `index_diff > index_diff_threshold`;
    - the window and both of its neighbours contain more than `sparse_num` samples,
      since sparse intervals may give unreasonable results.
    """
    clauses = [
        pl.col("index_std") > index_std_threshold,
        pl.col("index_fluctuation") > index_fluc_threshold,
        pl.col("index_diff") > index_diff_threshold,
        pl.col("index_std").is_finite(),
        pl.col("count") > sparse_num,
        pl.col("count_prev") > sparse_num,
        pl.col("count_next") > sparse_num,
    ]
    condition = clauses[0]
    for clause in clauses[1:]:
        condition = condition & clause
    return condition


In [None]:
#| export
from pdpipe.util import out_of_place_col_insert

patch `pdp.ApplyToRows` to work with `modin` and `xorbits` DataFrames

In [None]:
#| export
@patch
def _transform(self: pdp.ApplyToRows, X, verbose):
    """
    Apply `self._func` to each row of `X` and insert the result(s) as new column(s).

    This replaces `pdp.ApplyToRows._transform` so that results produced by `modin`
    are accepted as well as plain `pandas` ones. `pd` is `modin.pandas` when modin
    is installed, otherwise `pandas`.

    - A Series result becomes column `self._colname`. It is placed right after
      `self._follow_column` if that is set, otherwise appended at the end.
    - A DataFrame result has its columns sorted by name. They are inserted after
      `self._follow_column` if that is set, otherwise assigned with `X.assign`.

    Raises
    ------
    TypeError
        If the row-wise application yields neither a Series nor a DataFrame.
    """
    new_cols = X.apply(self._func, axis=1)

    # Use `pd` rather than `mpd`: `mpd` only exists when the modin import succeeded.
    series_types = (pd.Series, pandas.Series)
    frame_types = (pd.DataFrame, pandas.DataFrame)

    if isinstance(new_cols, series_types):
        loc = len(X.columns)
        if self._follow_column:
            loc = X.columns.get_loc(self._follow_column) + 1
        return out_of_place_col_insert(
            X=X, series=new_cols, loc=loc, column_name=self._colname
        )
    if isinstance(new_cols, frame_types):
        new_cols = new_cols[sorted(new_cols.columns)]
        if self._follow_column:
            inter_X = X
            loc = X.columns.get_loc(self._follow_column) + 1
            for colname in new_cols.columns:
                inter_X = out_of_place_col_insert(
                    X=inter_X,
                    series=new_cols[colname],
                    loc=loc,
                    column_name=colname,
                )
                loc += 1
            return inter_X
        assign_map = {
            colname: new_cols[colname] for colname in new_cols.columns
        }
        return X.assign(**assign_map)
    raise TypeError(  # pragma: no cover
        "Unexpected type generated by applying a function to a DataFrame."
        " Only Series and DataFrame are allowed."
    )

In [None]:
#| export
def calc_candidate_classification_index(candidate, data):
    """Compute the MVA-based classification indices over the candidate's [d_tstart, d_tstop] interval."""
    interval = slice(candidate["d_tstart"], candidate["d_tstop"])
    return calc_classification_index(data.sel(time=interval))

In [None]:
#| export
def convert_to_dataframe(
    data: pl.DataFrame, # original DataFrame (polars eager/lazy, or anything `pd.DataFrame` accepts)
)->pd.DataFrame:
    """Convert `data` into a pandas (or modin, when installed) DataFrame.

    Polars LazyFrames are collected first. Polars frames are converted with
    pyarrow-backed extension arrays. Anything that is not yet a `pd.DataFrame` is
    wrapped with `pd.DataFrame`.
    """
    if isinstance(data, pl.LazyFrame):
        data = data.collect()
    if isinstance(data, pl.DataFrame):
        data = data.to_pandas(use_pyarrow_extension_array=True)
    return data if isinstance(data, pd.DataFrame) else pd.DataFrame(data)

`IDsPipeline`: a class that builds the `pdpipe` stages used for processing IDs

In [None]:
#| export
class IDsPipeline:
    """Factory for the `pdpipe` stages that derive ID features from a candidates DataFrame.

    Each method binds the supplied data into row-wise or frame-wise functions and
    returns a stage (or a `pdp.PdPipeline` of stages) to be `.apply`-ed to the
    candidates.
    """

    def __init__(self):
        pass

    # fmt: off
    def calc_duration(self, sat_fgm: xr.DataArray):
        """Duration parameters, then a second pass on padded data for candidates missing a boundary."""
        duration_stage = pdp.ApplyToRows(
            lambda candidate: calc_candidate_duration(candidate, sat_fgm),
            func_desc="calculating duration parameters"
        )
        fallback_stage = pdp.ApplyToRows(
            lambda candidate: calc_candidate_d_duration(candidate, sat_fgm),
            func_desc="calculating duration parameters if needed"
        )
        return pdp.PdPipeline([duration_stage, fallback_stage])

    def calibrate_duration(self, sat_fgm, data_resolution):
        """Row-wise `calibrate_candidate_duration` at the given sampling resolution."""
        stage = pdp.ApplyToRows(
            lambda candidate: calibrate_candidate_duration(candidate, sat_fgm, data_resolution),
            func_desc="calibrating duration parameters if needed"
        )
        return stage

    def classify_id(self, sat_fgm):
        """Row-wise MVA classification indices, then the ID `type` computed on the whole frame."""
        index_stage = pdp.ApplyToRows(
            lambda candidate: calc_candidate_classification_index(candidate, sat_fgm),
            func_desc='calculating index "q_mva", "BnOverB" and "dBOverB"'
        )
        type_stage = pdp.ColByFrameFunc(
            "type",
            lambda df: classify_id(df["BnOverB"], df["dBOverB"]),
            func_desc="classifying the type of the ID"
        )
        return pdp.PdPipeline([index_stage, type_stage])

    def calc_rotation_angle(self, sat_fgm):
        """Frame-wise stage that adds the `rotation_angle` column."""
        stage = pdp.ColByFrameFunc(
            "rotation_angle",
            lambda df: calc_candidate_rotation_angle(df, sat_fgm),
            func_desc="calculating rotation angle",
        )
        return stage

    def assign_coordinates(self, sat_state: xr.DataArray):
        """Row-wise satellite location nearest to each candidate's `d_time`. NOTE: not optimized, quite slow."""
        stage = pdp.ApplyToRows(
            lambda candidate: get_candidate_location(candidate, sat_state),
            func_desc="assigning coordinates",
        )
        return stage
    # fmt: on
    # ... you can add more methods as needed

Note that the candidates only require a small portion of the data, so we can compress the data to speed up processing.

In [None]:
# | export
def compress_data_by_cands(
    data: pl.DataFrame, candidates: pl.DataFrame, tau: timedelta
):
    """
    Compress the data for parallel processing.

    Keeps only the rows of `data` (assumed sorted by `time`) that fall inside some
    padded candidate window [tstart - tau, tstop + tau]. Each window also includes
    the row at the insertion position of its padded stop. Row order is preserved
    and duplicates from overlapping windows are removed.

    Returns an empty frame with the same schema when there is no data or there are
    no candidates.
    """
    n_rows = len(data)
    if n_rows == 0 or len(candidates) == 0:
        # Nothing to keep; also avoids `np.concatenate([])`, which raises.
        return data.head(0)
    last_row = n_rows - 1

    ttstarts = candidates["tstart"] - tau
    ttstops = candidates["tstop"] + tau

    ttstarts_index = data["time"].search_sorted(ttstarts)
    ttstops_index = data["time"].search_sorted(ttstops)

    # `search_sorted` returns `n_rows` for a stop beyond the last sample; clip it so
    # the inclusive `+ 1` never yields an out-of-range row index.
    indices = np.concatenate(
        [
            np.arange(ttstart_index, min(ttstop_index, last_row) + 1)
            for ttstart_index, ttstop_index in zip(ttstarts_index, ttstops_index)
        ]
    )  # faster than `pl.arange`
    indices_unique = (
        pl.Series(indices).unique().sort()
    )  # faster than `np.unique(index)`
    return data[indices_unique]


# data.filter(
#     pl.any_horizontal(
#         pl.col('time').is_between(*ttrange) for ttrange in ttranges
#     )
# )

In [None]:
# | export


def sort_df(df: pl.DataFrame, col='time'):
    """Return `df` ordered by `col`, avoiding a sort when it is already ordered.

    If the column is already sorted, it is only flagged as sorted (`set_sorted`),
    which is cheaper than sorting again.
    """
    if not df.get_column(col).is_sorted():
        return df.sort(col)
    return df.set_sorted(col)

def process_candidates(
    candidates_pl: pl.DataFrame,  # potential candidates DataFrame
    sat_fgm: xr.DataArray,  # satellite FGM data
    data_resolution: timedelta,  # time resolution of the data
    sat_state: pl.DataFrame = None  # satellite state data
) -> pl.DataFrame:  # processed candidates DataFrame
    """
    Run the full feature-extraction pipeline on candidate IDs.

    The steps are:
    1. Compute duration parameters.
    2. Calibrate the durations of candidates that miss a boundary.
    3. Drop rows with any NaN.
    4. Compute classification indices, type and rotation angle.
    5. If `sat_state` is given, attach its row nearest to each `d_time`.

    `sat_fgm` must have 3 components along its second axis.
    Returns a polars DataFrame sorted by `d_time`.
    """
    test_eq(sat_fgm.shape[1], 3)
    candidates = convert_to_dataframe(candidates_pl)

    id_pipelines = IDsPipeline()
    candidates = id_pipelines.calc_duration(sat_fgm).apply(candidates)

    # calibrate duration of candidates that miss a boundary
    temp_candidates = candidates.loc[
        lambda df: df["d_tstart"].isnull() | df["d_tstop"].isnull()
    ]  # temp_candidates = candidates.query('d_tstart.isnull() | d_tstop.isnull()') # not implemented in `modin`

    if not temp_candidates.empty:
        temp_candidates_updated = id_pipelines.calibrate_duration(
            sat_fgm, data_resolution
        ).apply(temp_candidates)
        candidates.update(temp_candidates_updated)

    ids = (
        id_pipelines.classify_id(sat_fgm)
        + id_pipelines.calc_rotation_angle(sat_fgm)
    ).apply(
        candidates.dropna()  # Remove candidates with NaN values
    )

    # `pd` is `modin.pandas` when modin is installed; such frames must be converted to
    # plain pandas for polars. (`mpd` is undefined without modin, so avoid it here.)
    if not isinstance(ids, pandas.DataFrame) and isinstance(ids, pd.DataFrame):
        ids = ids._to_pandas()

    ids_pl = sort_df(pl.DataFrame(ids), col="d_time")

    # Satellite state is optional (e.g. `ids_finder` does not pass it).
    if sat_state is None:
        return ids_pl

    sat_state = sort_df(sat_state, col="time")
    ids_pl = ids_pl.join_asof(
        sat_state, left_on="d_time", right_on="time", strategy="nearest"
    ).drop("time_right")

    return ids_pl

In [None]:
# | export

def ids_finder(data: pl.DataFrame, tau, params):
    """
    Find IDs in one chunk of data and extract their features.

    Parameters
    ----------
    data : pl.DataFrame
        Field data with a `time` column and the field columns named in `params["bcols"]`.
    tau : int | float | datetime.timedelta
        Window length; plain numbers are interpreted as seconds.
    params : dict
        Must provide "data_resolution" (seconds, or a timedelta) and "bcols".

    Returns
    -------
    pl.DataFrame
        The processed IDs returned by `process_candidates`.
    """
    # Accept both plain seconds and timedelta objects.
    if not isinstance(tau, timedelta):
        tau = timedelta(seconds=tau)
    data_resolution = params["data_resolution"]
    if not isinstance(data_resolution, timedelta):
        data_resolution = timedelta(seconds=data_resolution)
    bcols = params["bcols"]
    data = data.sort("time")

    # get candidates: windows passing every index threshold and holding more than
    # a third of the expected number of samples
    indices = compute_indices(data, tau)
    sparse_num = tau / data_resolution // 3
    filter_condition = get_ID_filter_condition(sparse_num=sparse_num)
    candidates_pl = indices.filter(filter_condition).with_columns(pl_format_time(tau))
    candidates = convert_to_dataframe(candidates_pl)

    # only keep the data around the candidates to speed up per-candidate processing
    data_c = compress_data_by_cands(data, candidates_pl, tau)
    sat_fgm = df2ts(data_c, bcols, attrs={"units": "nT"})
    ids = process_candidates(candidates, sat_fgm, data_resolution)
    return ids


def extract_features(
    partitioned_input: Dict[str, Callable[[], Any]], tau, params
) -> pl.DataFrame:
    """Run `ids_finder` on every partition and merge the detected IDs.

    partitioned_input: maps partition names to either a zero-argument loader
        (called here, one partition at a time) or the already-loaded data.
    tau: window length in seconds, forwarded to `ids_finder`.
    params: parameter dict forwarded to `ids_finder`.

    Returns the concatenated IDs; rows sharing the same ``d_time``,
    ``d_tstart`` and ``d_tstop`` are kept only once.

    Raises:
        ValueError: if `partitioned_input` is empty (nothing to concatenate).
    """
    if not partitioned_input:
        raise ValueError(
            "`partitioned_input` is empty; there are no partitions to extract features from"
        )

    per_partition = [
        # Load the actual partition data: entries may be lazy loaders or data.
        ids_finder(load() if callable(load) else load, tau, params)
        for load in partitioned_input.values()
    ]
    ids = pl.concat(per_partition)

    return ids.unique(["d_time", "d_tstart", "d_tstop"])

## Test

In general, `mapply` and `modin` are the fastest. `xorbits` was expected to be the fastest, but it turned out to be the slowest.

In [None]:
#| notest
sat = 'jno'
coord = 'se'
bcols = ["BX", "BY", "BZ"]
tau = timedelta(seconds=60)
data_resolution = timedelta(seconds=1)

year = 2012
files = f'../data/{sat}_data_{year}.parquet'
# `timedelta.seconds` is only the seconds *component* (whole days are dropped);
# `total_seconds()` gives the full window length for the output file name.
output = f'../data/{sat}_candidates_{year}_tau_{int(tau.total_seconds())}.parquet'

# Sort explicitly: `set_sorted` merely flags the column as sorted without
# checking it. This matches `ids_finder`, which sorts by time before
# calling `compute_indices`.
data = pl.scan_parquet(files).sort('time').collect()

indices = compute_indices(data, tau)
# filter condition
sparse_num = tau / data_resolution // 3
filter_condition = get_ID_filter_condition(sparse_num=sparse_num)

candidates_pl = (
    indices.filter(filter_condition)
    .with_columns(pl_format_time(tau))
    .sort('time')
)

data_c = compress_data_by_cands(data, candidates_pl, tau)
sat_fgm = df2ts(data_c, bcols, attrs={"coordinate_system": coord, "units": "nT"})

### Test parallelization

In [None]:
#| notest
# The candidates as a plain pandas frame and as a modin frame, used by the
# parallelization experiments below.
candidates_pd = candidates_pl.to_pandas()
candidates_modin = mpd.DataFrame(data=candidates_pd)
# candidates_x = xpd.DataFrame(candidates_pd)

In [None]:
#| code-summary: Test different libraries to parallelize the computation
#| notest
# Row-wise duration computation wrapped as a pdpipe stage.
pdp_test = pdp.ApplyToRows(
    lambda candidate: calc_candidate_duration(candidate, sat_fgm),  # fast a little bit
    # Alternative row functions that were tried:
    # lambda candidate: calc_duration(get_candidate_data_xr(candidate, sat_fgm)),
    # lambda candidate: calc_duration(sat_fgm.sel(time=slice(candidate['tstart'], candidate['tstop']))),
    func_desc="calculating duration parameters",
)

# Timing notes for computing the duration of every candidate row
# (uncomment a line to reproduce it).

# process_candidates(candidates_modin, sat_fgm, sat_state, data_resolution)

# --- successful cases ---
# candidates_pd.mapply(lambda candidate: calc_candidate_duration(candidate, sat_fgm), axis=1) # this works, 4.2 secs
# candidates_pd.mapply(calc_candidate_duration, axis=1, data=sat_fgm) # this works, but a little bit slower, 6.7 secs

# candidates_pd.apply(calc_candidate_duration, axis=1, data=sat_fgm) # Standard case: 24+s secs
# candidates_pd.swifter.apply(calc_candidate_duration, axis=1, data=sat_fgm) # this works with dask, 80 secs
# candidates_pd.swifter.set_dask_scheduler(scheduler="threads").apply(calc_candidate_duration, axis=1, data=sat_fgm) # this works with dask, 60 secs
# candidates_modin.apply(lambda candidate: calc_candidate_duration(candidate, sat_fgm), axis=1) # this works with ray, 6 secs # NOTE: can not work with dask
# candidates_x.apply(calc_candidate_duration, axis=1, data=sat_fgm) # 30 seconds
# pdp_test(candidates_modin) # this works, 8 secs

# --- failed cases ---
# candidates_modin.apply(calc_candidate_duration, axis=1, data=sat_fgm) # AttributeError: 'DataFrame' object has no attribute 'sel'

### Test feature engineering

In [None]:
#| notest
# `data_c` is only defined by the `notest` test-setup cell, so this cell has to
# be skipped by nbdev_test as well.
data = data_c.to_pandas().set_index('time')

In [None]:
#| notest
# Optional feature-engineering dependencies, only needed for the exploration below.
from tsflex.features import MultipleFeatureDescriptors, FeatureCollection

from tsflex.features.integrations import catch22_wrapper
from pycatch22 import catch22_all

In [None]:
#| notest
# Relies on `data`, `tau` and `bcols` from cells that nbdev_test skips, and on
# the optional tsflex/pycatch22 packages.
tau_pd = pd.Timedelta(tau)

# catch22 features for each field component, on windows of length tau
# sliding by tau / 2.
catch22_feats = MultipleFeatureDescriptors(
    functions=catch22_wrapper(catch22_all),
    series_names=bcols,  # list of signal names
    windows=tau_pd,
    strides=tau_pd / 2,
)

fc = FeatureCollection(catch22_feats)
features = fc.calculate(data, return_df=True)  # calculate the features on your data




06-Oct-23 21:38:31: Finished function [[wrapped]__catch22_all] on [('BX',)] with window-stride [0 days 00:01:00, ('0 days 00:00:30',)] in [35.801105260849 seconds]!
06-Oct-23 21:38:31: Finished function [[wrapped]__catch22_all] on [('BY',)] with window-stride [0 days 00:01:00, ('0 days 00:00:30',)] in [35.89824819564819 seconds]!
06-Oct-23 21:38:31: Finished function [[wrapped]__catch22_all] on [('BZ',)] with window-stride [0 days 00:01:00, ('0 days 00:00:30',)] in [36.15569996833801 seconds]!


In [None]:
#| notest
# Optional dependency, only used to build the exploratory profiling report.
from ydata_profiling import ProfileReport

In [None]:
#| notest
# `features` and `candidates_pl` come from cells that nbdev_test skips.
# `join_asof` needs both frames sorted on `time`: `features_pl` is sorted here,
# and `candidates_pl` was sorted when it was built.
features_pl = pl.DataFrame(features.reset_index()).sort('time')
df = candidates_pl.join_asof(features_pl, on='time').to_pandas()

In [None]:
#| notest
# Exploratory only: builds a profiling report of the candidates and writes it
# to `jno.html` in the working directory.
profile = ProfileReport(df, title="JUNO Candidates Report")
profile.to_file("jno.html")

### Benchmark

In [None]:
import timeit

In [None]:
def benchmark(task_dict, number=1):
    """Time several tasks and report the average run time of each.

    task_dict: maps a label to a ``(data, task)`` pair; ``task(data)`` is timed.
    number: how many times each task is executed; must be at least 1.

    Returns a dict with the same labels mapping to the mean seconds per call,
    or to ``str(exc)`` when the task raised. A failing implementation is
    recorded rather than aborting the whole comparison.

    Raises:
        ValueError: if ``number`` is smaller than 1 (the average would be
            undefined).
    """
    if number < 1:
        raise ValueError(f"`number` must be at least 1, got {number!r}")

    results = {}
    for name, (data, task) in task_dict.items():
        try:
            # The lambda is executed within this iteration, so it sees the
            # current `task`/`data` despite late binding.
            time_taken = timeit.timeit(lambda: task(data), number=number)
        except Exception as e:  # deliberate: record the failure, keep going
            results[name] = str(e)
        else:
            results[name] = time_taken / number
    return results

In [None]:
#| notest
def func(candidate):
    """Duration parameters of one candidate row, measured on `sat_fgm`."""
    return calc_candidate_duration(candidate, sat_fgm)


def _row_apply(frame):
    """Row-wise `apply` of `func`, shared by the pandas and modin entries."""
    return frame.apply(func, axis=1)


task_dict = {
    'pandas': (candidates_pd, _row_apply),
    # NOTE(review): `.mapply` only exists after `mapply.init()` has been called;
    # otherwise this entry just records the AttributeError message.
    'pandas-mapply': (candidates_pd, lambda frame: frame.mapply(func, axis=1)),
    'modin': (candidates_modin, _row_apply),
    # 'xorbits': (candidates_x, lambda _: _.apply(func, axis=1)),
}

results = benchmark(task_dict)

## Notes

### TODOs

1. Feature engineering
2. Feature selection

## Obsolete code

In [None]:
def calc_vec_mean_mag(vec: xr.DataArray):
    """
    Time-averaged magnitude of a vector series: the norm over ``v_dim``,
    then the mean over ``time``.
    """
    magnitude = linalg.norm(vec, dims="v_dim")
    return magnitude.mean(dim="time")


def calc_vec_std(vec: xr.DataArray):
    """
    Combined standard deviation of a vector series: each component's standard
    deviation over ``time``, merged by taking their norm over ``v_dim``.
    """
    component_std = vec.std(dim="time")
    return linalg.norm(component_std, dims="v_dim")


def calc_vec_relative_diff(vec: xr.DataArray):
    """
    Relative jump of a vector series: the magnitude of (last - first sample
    along ``time``) divided by the mean magnitude of the series over ``time``.
    """
    last = vec.isel(time=-1)
    first = vec.isel(time=0)
    jump = linalg.norm(last - first, dims="v_dim")
    background = linalg.norm(vec, dims="v_dim").mean(dim="time")
    return jump / background

#### `process_candidates`
Assigning coordinates with `DataFrame.apply` is not optimized and is quite slow.

In [None]:
def process_candidates(
    candidates: pd.DataFrame,  # potential candidates DataFrame
    sat_fgm: xr.DataArray,  # satellite FGM data
    sat_state: xr.DataArray,  # satellite state data
    data_resolution: timedelta,  # time resolution of the data
) -> pd.DataFrame:  # processed candidates DataFrame
    """Obsolete candidate-processing pipeline, kept for reference.

    Steps:
    1. compute the durations;
    2. re-run duration calibration for rows whose ``d_tstart`` or ``d_tstop``
       is missing;
    3. drop rows that still contain NaN;
    4. classify, compute rotation angles and assign coordinates from
       `sat_state`.
    """
    pipelines = IDsPipeline()
    candidates = pipelines.calc_duration(sat_fgm).apply(candidates)

    # Select with a callable: `DataFrame.query` is not implemented in `modin`.
    missing_duration = candidates.loc[
        lambda frame: frame["d_tstart"].isnull() | frame["d_tstop"].isnull()
    ]
    if not missing_duration.empty:
        calibrated = pipelines.calibrate_duration(sat_fgm, data_resolution).apply(
            missing_duration
        )
        # In place: overwrite with the non-NA recalibrated values, aligned on index.
        candidates.update(calibrated)

    feature_pipeline = (
        pipelines.classify_id(sat_fgm)
        + pipelines.calc_rotation_angle(sat_fgm)
        + pipelines.assign_coordinates(sat_state)
    )
    complete = candidates.dropna()  # remove candidates with NaN values
    return feature_pipeline.apply(complete)