## Background

The data used in this notebook are in the Spacecraft-Solar equatorial (SE) coordinate system.

### Coordinate System of Data:

1. **SE (Solar Equatorial)**
    - Code: `se`
    - Resampling options: 
        - Number of seconds (1 or 60): `se_rN[N]s`
        - Resampled 1 hour: `se_r1h`

2. **PC (Planetocentric)**
    - Code: `pc`
    - Resampling options: 
        - Number of seconds (1 or 60): `pc_rN[N]s`
        
3. **SS (Sun-State)**
    - Code: `ss`
    - Resampling options: 
        - Number of seconds (1 or 60): `ss_rN[N]s`
        
4. **PL (Payload)**
    - Code: `pl`
    - Resampling options: 
        - Number of seconds (1 or 60): `pl_rN[N]s`


```txt
------------------------------------------------------------------------------
Juno Mission Phases                                                           
------------------------------------------------------------------------------
Start       Mission                                                           
Date        Phase                                                             
==============================================================================
2011-08-05  Launch                                                            
2011-08-08  Inner Cruise 1                                                    
2011-10-10  Inner Cruise 2                                                    
2013-05-28  Inner Cruise 3                                                    
2013-11-05  Quiet Cruise                                                      
2016-01-05  Jupiter Approach                                                  
2016-06-30  Jupiter Orbital Insertion                                         
2016-07-05  Capture Orbit                                                     
2016-10-19  Period Reduction Maneuver                                         
2016-10-20  Orbits 1-2                                                        
2016-11-09  Science Orbits                                                    
2017-10-11  Deorbit
```
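
A date can be mapped to its mission phase with a sorted lookup on the start dates listed above. The following is a minimal sketch; the names `PHASES` and `mission_phase` are illustrative and are not part of `ids_finder`.

```python
from bisect import bisect_right
from datetime import date

# Start dates copied from the mission phase table above.
PHASES = [
    (date(2011, 8, 5), "Launch"),
    (date(2011, 8, 8), "Inner Cruise 1"),
    (date(2011, 10, 10), "Inner Cruise 2"),
    (date(2013, 5, 28), "Inner Cruise 3"),
    (date(2013, 11, 5), "Quiet Cruise"),
    (date(2016, 1, 5), "Jupiter Approach"),
    (date(2016, 6, 30), "Jupiter Orbital Insertion"),
    (date(2016, 7, 5), "Capture Orbit"),
    (date(2016, 10, 19), "Period Reduction Maneuver"),
    (date(2016, 10, 20), "Orbits 1-2"),
    (date(2016, 11, 9), "Science Orbits"),
    (date(2017, 10, 11), "Deorbit"),
]
_STARTS = [start for start, _ in PHASES]


def mission_phase(day):
    """Return the phase whose start date is the latest one not after `day`."""
    i = bisect_right(_STARTS, day) - 1
    return PHASES[i][1] if i >= 0 else None


phase = mission_phase(date(2014, 2, 24))
```

Dates before the launch return `None` in this sketch, so callers can tell "no phase" apart from a valid phase name.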

```txt
File Naming Convention                                                        
==============================================================================
Convention:                                                                   
   fgm_jno_LL_CCYYDDDxx_vVV.ext                                               
Where:                                                                        
   fgm - Fluxgate Magnetometer three character instrument abbreviation        
   jno - Juno                                                                 
    LL - CODMAC Data level, for example, l3 for level 3                       
    CC - The century portion of a date, 20                                    
    YY - The year of century portion of a date, 00-99                         
   DDD - The day of year, 001-366                                             
    xx - Coordinate system of data (se = Solar equatorial, ser = Solar        
         equatorial resampled, pc = Planetocentric, ss = Sun-State,           
         pl = Payload)                                                        
     v - separator to denote Version number                                   
    VV - version                                                              
   ext - file extension (sts = Standard Time Series (ASCII) file, lbl = Label 
         file)                                                                
Example:                                                                      
   fgm_jno_l3_2014055se_v00.sts    
```
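
The naming convention above can be checked mechanically. The sketch below parses a file name into its fields and converts the year and day of year into a calendar date; `FILENAME_RE` and `parse_fgm_filename` are hypothetical names introduced for illustration, and the pattern only assumes the fields listed in the convention.

```python
import re
from datetime import datetime

# One named group per field of the convention fgm_jno_LL_CCYYDDDxx_vVV.ext
FILENAME_RE = re.compile(
    r"^fgm_jno_l(?P<level>\d)_"
    r"(?P<year>\d{4})(?P<doy>\d{3})"
    r"(?P<coord>[a-z0-9_]+?)"
    r"_v(?P<version>\d{2})\.(?P<ext>sts|lbl)$"
)


def parse_fgm_filename(name):
    """Split an FGM file name into its fields and add the calendar date."""
    match = FILENAME_RE.match(name)
    if match is None:
        raise ValueError(f"Not an FGM file name: {name}")
    fields = match.groupdict()
    # %j is the day of year (001-366), as in the convention above.
    fields["date"] = datetime.strptime(fields["year"] + fields["doy"], "%Y%j").date()
    return fields


info = parse_fgm_filename("fgm_jno_l3_2014055se_v00.sts")
```

Here `info["coord"]` holds the coordinate code (`se` for the example file) and `info["date"]` the date that corresponds to day 055 of 2014.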

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

In [None]:
#| code-summary: import all the packages needed for the project

from ids_finder.utils import *
from ids_finder.core import *

import polars as pl
import xarray as xr
try:
    import modin.pandas as pd
    import modin.pandas as mpd
except ImportError:
    import pandas as pd

import pandas
import numpy as np
from xarray_einstats import linalg

from datetime import timedelta

from loguru import logger


import pytplot
from pytplot import timebar
from pytplot import get_data, store_data, tplot, split_vec, join_vec, tplot_options, options, tlimit, highlight, degap

import pdpipe as pdp


from collections.abc import Callable
from pandas import (
    DataFrame,
    Timestamp,
)
from xarray.core.dataarray import DataArray


In [None]:
sat = 'jno'
coord = 'se'
tau = timedelta(seconds=60)
data_resolution = timedelta(seconds=1)

if True:
    year = 2011
    files = f'../data/{sat}_{year}.parquet'
    output = f'../data/{sat}_candidates_{year}_tau_{tau.seconds}.parquet'

    data = pl.scan_parquet(files).set_sorted('time').collect()
    sat_fgm = df2ts(data, ["BX", "BY", "BZ"], attrs={"coordinate_system": coord, "units": "nT"})
    sat_state = df2ts(data, ["X", "Y", "Z"], attrs={"coordinate_system": coord, "units": "km"})

    indices = compute_indices(data, tau)
    # filter condition
    sparse_num = tau / data_resolution // 3
    filter_condition = get_ID_filter_condition(sparse_num = sparse_num)

    candidates_pl = indices.filter(filter_condition).with_columns(pl_format_time(tau))
    candidates = convert_to_dataframe(candidates_pl)
    
    ids = process_candidates(candidates, sat_fgm, sat_state, data_resolution)
    
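    # `ids` is a modin DataFrame when modin is installed, otherwise a plain pandas one
    if isinstance(ids, pandas.DataFrame):
        ids.to_parquet(output)
    else:
        ids._to_pandas().to_parquet(output)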


    import ray
    ray.init()




### Download all the files

In [None]:
# !wget -r --no-parent --no-clobber 'https://pds-ppi.igpp.ucla.edu/data/JNO-SS-3-FGM-CAL-V1.0/DATA/CRUISE/SE/1SEC/'
# !aria2c -x 16 -s 16 'https://pds-ppi.igpp.ucla.edu/ditdos/download?id=pds://PPI/JNO-SS-3-FGM-CAL-V1.0/DATA/CRUISE/SE/1SEC'

In [None]:
# Convert data from `lbl` to `orc` format
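import os
from collections.abc import Callable
from pathlib import Path
from loguru import logger
import pdr
    
def convert_file(file_path: Path, target_suffix: str, conversion_func: Callable[[Path, Path], None], check_exist: bool = True) -> None:
    """Convert `file_path` to a sibling file with `target_suffix`, skipping existing targets by default."""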
    target_file = file_path.with_suffix(target_suffix)
    if check_exist and target_file.exists():
        logger.info(f"File {target_file} already exists. Skipping...")
        return
    
    conversion_func(file_path, target_file)
    logger.info(f"Converted {file_path} to {target_file}")

def lbl_to_orc_conversion(src: Path, dest: Path) -> None:
    df = pdr.read(src).TABLE
    df.to_orc(dest)

def orc_to_parquet_conversion(src: Path, dest: Path) -> None:
    # We can also read partitioned datasets with multiple ORC files through the pyarrow.dataset interface.
    
    from pyarrow import orc
    import pyarrow.parquet as pq

    # import polars as pl
    # df = pl.from_arrow( orc.read_table(src) )
    # df.write_parquet(dest)
    
    table = orc.read_table(src)
    pq.write_table(table, dest)

def convert_format(format_from, format_to):
    conversion_map = {
        ('lbl', 'orc'): lbl_to_orc_conversion,
        ('orc', 'parquet'): orc_to_parquet_conversion
    }
    
    convert_func = conversion_map.get((format_from, format_to))
    if not convert_func:
        raise ValueError(f"Conversion from {format_from} to {format_to} is not supported")

    local_dir = Path(os.environ['HOME']) / 'juno'
    pattern = f'**/*.{format_from}'

    for file in local_dir.glob(pattern):
        convert_file(file, f".{format_to}", convert_func)

if __name__ == "__main__":
    # convert_format('lbl', 'orc')
    convert_format('orc', 'parquet')


# delete all files with extension
# find . -type f -name '*.parquet' -delete
# find . -type f -name '*.orc' -delete
# find . -type f -name '*.lbl' -delete

In [None]:
pds_dir = "https://pds-ppi.igpp.ucla.edu/data"

possible_coords = ["se", "ser", "pc", "ss", "pl"]
possible_exts = ["sts", "lbl"]
possible_data_rates = ["1s", "1min", "1h"]

juno_ss_config = {
    "DATA_SET_ID": "JNO-SS-3-FGM-CAL-V1.0",
    "FILE_SPECIFICATION_NAME": "INDEX/INDEX.LBL",
}

juno_j_config = {
    "DATA_SET_ID": "JNO-J-3-FGM-CAL-V1.0",
    "FILE_SPECIFICATION_NAME": "INDEX/INDEX.LBL",
}

In [None]:
import pdr

def download_and_read_file(config, index_table=False):
    """Download and read file for each config.

    Returns:
        DataFrame: The data read from the file.
    """
    # BUG: the index file is not formatted according to its `lbl` label, so it cannot be read with `pdr`.
    # ValueError: time data "282T00:00:31.130,2019" doesn't match format "%Y-%jT%H:%M:%S.%f", at position 3553. You might want to try:
    # - passing `format` if your strings have a consistent format;
    # - passing `format='ISO8601'` if your strings are all ISO8601 but not necessarily in exactly the same format;
    # - passing `format='mixed'`, and the format will be inferred for each element individually. You might want to use `dayfirst` alongside this.

    local_dir = os.path.join(os.environ["HOME"], "juno", config["DATA_SET_ID"])
    base_url = f"{pds_dir}/{config['DATA_SET_ID']}"

    lbl_fn = config["FILE_SPECIFICATION_NAME"]

    if not index_table:
        parquet_fn = lbl_fn.replace("lbl", "parquet")
        parquet_fp = os.path.join(local_dir, parquet_fn)
        if os.path.exists(parquet_fp):
            return pandas.read_parquet(os.path.join(local_dir, parquet_fn))

    lbl_file_url = f"{base_url}/{lbl_fn}"
    lbl_fp = download_file(lbl_file_url, local_dir, lbl_fn)
    logger.debug(f"Reading {lbl_fp}")

    if index_table:
        tab_fn = lbl_fn.replace("LBL", "TAB")
        tab_fp = download_file(f"{base_url}/{tab_fn}", local_dir, tab_fn)
        tab_index = pandas.read_csv(tab_fp, delimiter=",", quotechar='"')
        tab_index.columns = tab_index.columns.str.replace(" ", "")
        return tab_index
    else:
        sts_fn = lbl_fn.replace("lbl", "sts")
        download_file(f"{base_url}/{sts_fn}", local_dir, sts_fn)
        return pdr.read(lbl_fp).TABLE

In [None]:

juno_ss_index = download_and_read_file(juno_ss_config, index_table=True)
juno_j_index = download_and_read_file(juno_j_config, index_table=True)

_index_time_format = "%Y-%jT%H:%M:%S.%f"

jno_pipeline = pdp.PdPipeline(
    [
        pdp.ColDrop(["PRODUCT_ID", "CR_DATE", "PRODUCT_LABEL_MD5CHECKSUM"]),
        pdp.ApplyByCols("SID", str.rstrip),
        pdp.ApplyByCols("FILE_SPECIFICATION_NAME", str.rstrip),
        pdp.ColByFrameFunc(
            "START_TIME",
            lambda df: pandas.to_datetime(df["START_TIME"], format=_index_time_format),
        ),
        pdp.ColByFrameFunc(
            "STOP_TIME",
            lambda df: pandas.to_datetime(df["STOP_TIME"], format=_index_time_format),
        ),
        # pdp.ApplyByCols(['START_TIME', 'STOP_TIME'], pandas.to_datetime, format=_index_time_format), # NOTE: This is slow
    ]
)
if True:
    index_df = pandas.concat(
        [jno_pipeline(juno_ss_index), jno_pipeline(juno_j_index)], ignore_index=True
    )


2023-09-21 02:19:27.351 | DEBUG    | __main__:download_and_read_file:28 - Reading /Users/zijin/juno/JNO-SS-3-FGM-CAL-V1.0/INDEX/INDEX.LBL
2023-09-21 02:19:27.372 | DEBUG    | __main__:download_and_read_file:28 - Reading /Users/zijin/juno/JNO-J-3-FGM-CAL-V1.0/INDEX/INDEX.LBL


In [None]:
def juno_load_fgm(trange: list, coord="se", data_rate="1s") -> DataFrame:
    """
    Get the data array for a given time range and coordinate.

    Parameters:
        trange (list): The time range.
        coord (str, optional): The coordinate. Defaults to 'se'.
        data_rate (str, optional): The data rate. Defaults to '1s'.

    Returns:
        pandas.DataFrame: The dataframe for the given time range and coordinate.
    """

    if len(trange) != 2:
        raise ValueError(
            "Expected trange to have exactly 2 elements: start and stop time."
        )

    start_time = pandas.Timestamp(trange[0])
    stop_time = pandas.Timestamp(trange[1])

    temp_index_df = index_df[
        (index_df["SID"] == get_sid(coord, data_rate))
    ].reset_index(drop=True)

    # Filtering
    relevant_files = temp_index_df[
        (temp_index_df["STOP_TIME"] > start_time)
        & (temp_index_df["START_TIME"] < stop_time)
    ]
    dataframes = [download_and_read_file(row) for _, row in relevant_files.iterrows()]

    # rows = [row for _, row in relevant_files.iterrows()]
    # with concurrent.futures.ThreadPoolExecutor() as executor:
    #     dataframes = list(executor.map(download_and_read_file, rows))

    combined_data = pandas.concat(dataframes)

    return pdp_process_juno_df(combined_data)

def get_sid(coord, data_rate):
    sid_mapping = {
        "pc": {"1s": "PC 1 SECOND", "1min": "PC 1 MINUTE", "": "PCENTRIC"},
        "pl": {"1s": "PAYLOAD 1 SECOND", "": "PAYLOAD"},
        "ss": {"1s": "SS 1 SECOND", "1min": "SS 1 MINUTE", "": "SUNSTATE"},
        "se": {"1s": "SE 1 SECOND", "1min": "SE 1 MINUTE", "": "SE"},
    }
    try:
        return sid_mapping[coord][data_rate]
    except KeyError:
        return None

_skip_cond = ~pdp.cond.HasAllColumns(["SAMPLE UTC", "DECIMAL DAY", "INSTRUMENT RANGE"])
pdp_process_juno_df = pdp.PdPipeline(
    [
        pdp.ColByFrameFunc(
            "time",
            lambda df: pandas.to_datetime(df["SAMPLE UTC"], format="%Y %j %H %M %S %f"),
            skip=_skip_cond,
        ),
        pdp.ColDrop(["SAMPLE UTC", "DECIMAL DAY", "INSTRUMENT RANGE"], skip=_skip_cond),
        pdp.df.set_index("time"),
        pdp.ColRename(col_renamer)
        # pdp.AggByCols('SAMPLE UTC', func=lambda time: pandas.to_datetime(time, format='%Y %j %H %M %S %f'), func_desc='Convert time to datetime') # NOTE: this is quite slow
        # pdp.df['time'] << pandas.to_datetime(pdp.df['SAMPLE UTC'], format='%Y %j %H %M %S %f'), # NOTE: this does not work
    ],
)

One day of data at 1-second resolution is about 12 MB.

So one year of data is about 4.4 GB, and 6 years of Juno cruise data is about 26 GB.

At a download rate of about 250 KB/s, fetching the Juno data alone takes about 29 hours. The estimate below also includes `thm_file_size`, which raises the total to about 5 days.

In [None]:
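num_of_files = 6*365  # one file per day over about 6 years
jno_file_size = 12e3  # KB per daily Juno file
thm_file_size = 40e3  # KB per daily `thm` file
files_size = jno_file_size + thm_file_size  # KB per day for both datasets
downloading_rate = 250  # KB/s
processing_rate = 1/60  # files per second, i.e. one file per minute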

time_to_download = num_of_files * files_size / downloading_rate / 60 / 60
space_required = num_of_files * files_size / 1e6
time_to_process = num_of_files / processing_rate / 60 / 60

print(f"Time to download: {time_to_download:.2f} hours")
print(f"Disk space required: {space_required:.2f} GB")
print(f"Time to process: {time_to_process:.2f} hours")


Time to download: 126.53 hours
Disk space required: 113.88 GB
Time to process: 36.50 hours


### Convert data format

See [convert_format.py](convert_format.py)

### Check the data

In [None]:
jno_ss_index_df = index_df.loc[ lambda _: _['VOLUME_ID'] == 'JNOFGM_1000']

starting_date = jno_ss_index_df['START_TIME'].min().date()
ending_date = jno_ss_index_df['STOP_TIME'].max().date()

logger.info(f"Starting date: {starting_date}")
logger.info(f"Ending date: {ending_date}")


2023-09-21 02:19:27.516 | INFO     | __main__:<cell line: 6>:6 - Starting date: 2011-08-25
2023-09-21 02:19:27.517 | INFO     | __main__:<cell line: 7>:7 - Ending date: 2016-06-29


In [None]:
available_dates = pandas.concat([jno_ss_index_df['START_TIME'].dt.date, jno_ss_index_df['STOP_TIME'].dt.date]).unique()
full_year_range = pandas.date_range(start=starting_date, end=ending_date)

missing_dates = full_year_range[~full_year_range.isin(available_dates)]

if len(missing_dates) == 0:
    print("No days are missing.")
else:
    print("The following days are missing")
    for date in missing_dates:
        print(date.strftime('%Y-%m-%d'))


The following days are missing
2012-04-20
2012-04-21
2012-04-22
2012-04-23
2012-04-24
2012-05-15
2012-06-15
2012-07-04
2012-07-05
2012-07-06
2012-07-07
2012-07-08
2012-08-25
2012-08-26
2012-08-27
2012-08-28
2012-08-29
2012-08-30
2012-08-31
2012-09-01
2012-09-02
2012-09-03
2012-09-04
2012-09-05
2012-09-06
2012-09-07
2012-09-08
2012-09-09
2012-09-10
2012-09-11
2012-09-12
2012-09-13
2012-09-14
2012-09-15
2012-09-16
2012-09-17
2012-10-05
2012-10-06
2012-10-07
2012-10-08
2012-12-13
2012-12-14
2012-12-15
2012-12-16
2012-12-17
2012-12-18
2013-05-11
2013-05-12
2013-05-13
2013-05-14
2013-05-23
2013-05-24
2013-05-25
2013-05-26
2013-05-27
2013-05-28
2013-05-29
2013-05-30
2013-05-31
2013-06-01
2013-06-02
2013-06-03
2013-06-04
2013-06-05
2013-06-06
2013-06-07
2013-06-08
2013-06-09
2013-06-10
2013-06-11
2013-06-12
2013-06-13
2013-06-14
2013-06-15
2013-06-16
2013-06-17
2013-06-18
2013-06-19
2013-06-20
2013-06-21
2013-06-22
2013-06-23
2013-06-24
2013-06-25
2013-06-26
2013-06-27
2013-06-28
2013-06-29
2

### Clean the data

In [None]:
def batch_pre_process(year, force=False):

    trange = [f"{year}-01-01", f"{year+1}-01-02"]  # having some overlap
    dir_path = Path(os.environ["HOME"], "juno/JNO-SS-3-FGM-CAL-V1.0/")
    pattern = "**/*.parquet"
    data = dir_path / pattern
    
    file = f"data/jno_{year}.parquet"
    if os.path.exists(file) and not force:
        logger.info(f"File {file} exists. Skipping")
        return file
    logger.info(f"Preprocessing data for year {year}")
    
    lazy_df = pl.scan_parquet(data)
    temp_df = (
        lazy_df.filter(
            pl.col("time").is_between(pd.Timestamp(trange[0]), pd.Timestamp(trange[1])),
        )
        .sort(
            "time"
        )  # needed for `compute_index_std` to work properly as `group_by_dynamic` requires the data to be sorted
        .filter(
            pl.col(
                "time"
            ).is_first_distinct()  # remove duplicate time values for xarray to select data properly, though significantly slows down the computation
        )
        .rename({"BX SE": "BX", "BY SE": "BY", "BZ SE": "BZ"})
    )
    temp_df.collect().write_parquet(file)
    return file

starting_year = starting_date.year
ending_year = ending_date.year

In [None]:
for year in range(starting_year, ending_year+1):
    batch_pre_process(year)

2023-09-21 02:19:27.541 | INFO     | __main__:batch_pre_process:10 - File data/jno_2011.parquet exists. Skipping
2023-09-21 02:19:27.542 | INFO     | __main__:batch_pre_process:10 - File data/jno_2012.parquet exists. Skipping
2023-09-21 02:19:27.542 | INFO     | __main__:batch_pre_process:10 - File data/jno_2013.parquet exists. Skipping
2023-09-21 02:19:27.543 | INFO     | __main__:batch_pre_process:10 - File data/jno_2014.parquet exists. Skipping
2023-09-21 02:19:27.544 | INFO     | __main__:batch_pre_process:10 - File data/jno_2015.parquet exists. Skipping
2023-09-21 02:19:27.544 | INFO     | __main__:batch_pre_process:10 - File data/jno_2016.parquet exists. Skipping


### Load the data

In [None]:
if test:
    # alternative time ranges, kept for quick switching
    # trange = ['2012','2012-6-01']
    # trange = ['2016','2017']
    trange = ['2011','2013']

    coord = 'se'
    data_rate='1s'
    tau = 30
    
    # get a temporary dataframe for testing
    dir_path = Path(os.environ["HOME"], "juno/JNO-SS-3-FGM-CAL-V1.0/")
    pattern = "**/*.parquet"
    data = dir_path / pattern

    lazy_df = pl.scan_parquet(data)
    temp_df = (
        lazy_df.filter(
            pl.col("time").is_between(pd.Timestamp(trange[0]), pd.Timestamp(trange[1])),
        )
        .sort("time")
        .rename({"BX SE": "BX", "BY SE": "BY", "BZ SE": "BZ"})
    )

    lazy = True  # set to False to collect eagerly and build `sat_fgm` and `sat_state`
    if not lazy:
        temp_df = temp_df.collect()
        # juno_fgm_df = juno_load_fgm(trange, coord=coord, data_rate=data_rate)
        sat_fgm = sat_get_fgm_from_df(temp_df)
        sat_state = juno_get_state(temp_df)


## ID identification

Three indices are used to identify IDs.

The first index is $$ \frac{\sigma(B)}{\max(\sigma(B_-),\sigma(B_+))} $$

The second index is $$ \frac{\sigma(B_- + B_+)}{\sigma(B_-) + \sigma(B_+)} $$

The third index (relative field jump) is $$ \frac{| \Delta \vec{B} |}{|B_{bg}|} $$

The first two conditions guarantee that the field changes of the identified IDs are large enough to be distinguished from stochastic fluctuations of the magnetic field, while the third is a supplementary condition to reduce the uncertainty of recognition.
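
The first two indices can be illustrated on synthetic data. The sketch below is not the `ids_finder` implementation: it assumes that $B_-$ and $B_+$ are the windows immediately before and after the current window, that $B_- + B_+$ means the two windows concatenated, and it uses a single scalar component. The helper names `std_index` and `fluctuation_index` are hypothetical.

```python
import numpy as np


def std_index(b_prev, b_curr, b_next):
    """sigma(B) / max(sigma(B_-), sigma(B_+)) for one window (assumed reading)."""
    return np.std(b_curr) / max(np.std(b_prev), np.std(b_next))


def fluctuation_index(b_prev, b_next):
    """sigma(B_- + B_+) / (sigma(B_-) + sigma(B_+)), with '+' read as concatenation."""
    combined = np.concatenate([b_prev, b_next])
    return np.std(combined) / (np.std(b_prev) + np.std(b_next))


rng = np.random.default_rng(0)
quiet = 5.0 + 0.1 * rng.standard_normal(180)  # noisy but steady field, in nT
b_prev, b_mid, b_next = quiet[:60], quiet[60:120], quiet[120:]

# Same noise, but the field drops by 10 nT halfway through the middle window.
jump_mid = np.concatenate([b_mid[:30], b_mid[30:] - 10.0])
jump_next = b_next - 10.0

i1_quiet = std_index(b_prev, b_mid, b_next)
i2_quiet = fluctuation_index(b_prev, b_next)
i1_jump = std_index(b_prev, jump_mid, jump_next)
i2_jump = fluctuation_index(b_prev, jump_next)
```

For the steady field both indices stay of order one or below, while the window containing the jump yields values that are larger by more than an order of magnitude. That separation is what a threshold on the indices exploits.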

### Index of datapoints

In [None]:
# | code-summary: get the number of data points in each time interval and inspect the result
if test:
    index_num = (
        temp_df.group_by_dynamic(
            "time",
            every=f"{tau//2}s",
            period=f"{tau}s",
        )
        .agg(pl.count())
    )

    sparse_num = tau // 4
    sparse_intervals = index_num.filter(pl.col("count") < sparse_num)
    # logger.info(f'Num of intervals where data are sparse: {len(sparse_intervals)} out of {len(data_points_df)} ({len(sparse_intervals)/len(data_points_df)*100:.2f}%)')


### Index of the standard deviation

In [None]:
help(compute_index_std)

Help on function compute_index_std in module utils:

compute_index_std(data, tau)
    helper function to compute standard deviation index


### Index of fluctuation

In [None]:
help(compute_index_fluctuation)
help(compute_index_fluctuation_xr)

Help on function compute_index_fluctuation in module utils:

compute_index_fluctuation(data, tau)
    helper function to compute fluctuation index
    
    Notes: the results returned are a little bit different for the two implementations (because of the implementation of `std`).

Help on function compute_index_fluctuation_xr in module utils:

compute_index_fluctuation_xr(data: xarray.core.dataarray.DataArray, tau: int) -> xarray.core.dataarray.DataArray
    Computes the fluctuation index for a given data array based on a specified time interval.
    
    Parameters:
    - data: The xarray DataArray containing the data to be processed.
    - tau: Time interval in seconds for resampling.
    
    Returns:
    - fluctuation: xarray DataArray containing the fluctuation indices.
    
    Notes
    -----
        ddof=0 is used for calculating the standard deviation. (ddof=1 is for sample standard deviation)


In [None]:
# i2 = index_fluctuation(juno_fgm_b, tau)
# index_fluctuation_df = compute_index_fluctuation(temp_df, tau)

### Index of the relative field jump

In [None]:
def index_diff(data: DataArray, tau):
    grouped_data = data.resample(time=pd.Timedelta(tau, unit='s'))

    dvecs = grouped_data.first()-grouped_data.last()
    vec_mean_mags = grouped_data.map(calc_vec_mean_mag)
    vec_diffs = linalg.norm(dvecs, dims='v_dim') / vec_mean_mags
    
    # vec_diffs = grouped_data.map(calc_vec_relative_diff) # NOTE: this is slower than the above implementation.
    # INFO: Do your spatial and temporal indexing (e.g. .sel() or .isel()) early in the pipeline, especially before calling resample() or groupby(). Grouping and resampling triggers some computation on all the blocks, which in theory should commute with indexing, but this optimization hasn’t been implemented in Dask yet. (See Dask issue #746).
    
    offset = pd.Timedelta(tau/2, unit='s')
    vec_diffs['time'] = vec_diffs['time'] + offset
    return vec_diffs
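
The relative field jump computed by `index_diff` can be illustrated without xarray. The sketch below assumes that the background magnitude $|B_{bg}|$ is the mean of $|B|$ over the window; the notebook delegates this to `calc_vec_mean_mag`, whose definition is not shown here, so that reading is an assumption. `relative_jump` is a hypothetical helper name.

```python
import numpy as np


def relative_jump(b_window):
    """|B_first - B_last| / mean(|B|) for one window of shape (n_samples, 3)."""
    b = np.asarray(b_window, dtype=float)
    delta = b[0] - b[-1]  # field change across the window
    background = np.linalg.norm(b, axis=1).mean()  # assumed definition of |B_bg|
    return np.linalg.norm(delta) / background


# A pure rotation: |B| stays at 5 nT while the x component reverses sign.
window = [[3.0, 4.0, 0.0], [0.0, 5.0, 0.0], [-3.0, 4.0, 0.0]]
jump = relative_jump(window)
```

In this example the field magnitude is constant, so the index reflects only the change in direction.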



## Plotting

In [None]:
def get_candidate_data_xr(candidate, data, coord:str=None, neighbor:int=0) -> xr.DataArray:
    duration = candidate['tstop'] - candidate['tstart']
    offset = neighbor*duration
    temp_tstart = candidate['tstart'] - offset
    temp_tstop = candidate['tstop'] + offset
    
    return data.sel(time=slice(temp_tstart,  temp_tstop))

def get_candidate_data_pl(candidate, data, coord:str=None, neighbor:int=0) -> xr.DataArray:
    """
    Notes
    -----
    much slower than `get_candidate_data_xr`
    """
    duration = candidate['tstop'] - candidate['tstart']
    offset = neighbor*duration
    temp_tstart = candidate['tstart'] - offset
    temp_tstop = candidate['tstop'] + offset
    
    temp_data = data.filter(
        pl.col("time").is_between(temp_tstart, temp_tstop)
    )
    
    dims = ["v_dim", "time"]
    coords = {
        "time": temp_data['time'], 
        "v_dim": ["BX", "BY", "BZ"]
        }
    return xr.DataArray([ temp_data['BX'], temp_data['BY'], temp_data['BZ']], dims=dims, coords=coords)

def get_candidate_data(candidate, data, coord:str=None, neighbor:int=0) -> xr.DataArray:
    if isinstance(data, xr.DataArray):
        return get_candidate_data_xr(candidate, data, coord=coord, neighbor=neighbor)
    elif isinstance(data, pl.DataFrame):
        return get_candidate_data_pl(candidate, data, coord=coord, neighbor=neighbor)
    else:
        raise TypeError(f"Unsupported data type: {type(data)}")

def get_candidates(candidates: DataFrame, candidate_type=None, num:int=4):
    
    if candidate_type is not None:
        _candidates = candidates[candidates['type'] == candidate_type]
    else:
        _candidates = candidates
    
    # Sample a specific number of candidates if num is provided and it's less than the total number
    if num < len(_candidates):
        logger.info(f"Sampling {num} {candidate_type} candidates out of {len(_candidates)}")
        return _candidates.sample(num)
    else:
        return _candidates

In [None]:
from pyspedas.cotrans import minvar_matrix_make
from pyspedas import tvector_rotate

In [None]:
def plot_basic(
    data, tstart, tstop, tau, mva_tstart=None, mva_tstop=None, neighbor: int = 1
):
    if mva_tstart is None:
        mva_tstart = tstart
    if mva_tstop is None:
        mva_tstop = tstop

    mva_b = data.sel(time=slice(mva_tstart, mva_tstop))
    store_data("fgm", data={"x": mva_b.time, "y": mva_b})
    minvar_matrix_make("fgm")  # get the MVA matrix

    temp_tstart = pd.Timestamp(tstart) - pd.Timedelta(neighbor * tau, unit="s")
    temp_tstop = pd.Timestamp(tstop) + pd.Timedelta(neighbor * tau, unit="s")

    temp_b = data.sel(time=slice(temp_tstart, temp_tstop))
    store_data("fgm", data={"x": temp_b.time, "y": temp_b})
    temp_btotal = calc_vec_mag(temp_b)
    store_data("fgm_btotal", data={"x": temp_btotal.time, "y": temp_btotal})

    tvector_rotate("fgm_mva_mat", "fgm")
    split_vec("fgm_rot")
    pytplot.data_quants["fgm_btotal"]["time"] = pytplot.data_quants["fgm_rot"][
        "time"
    ]  # NOTE: whenever using `get_data`, we may lose precision in the time values. This is a workaround.
    join_vec(
        [
            "fgm_rot_x",
            "fgm_rot_y",
            "fgm_rot_z",
            "fgm_btotal",
        ],
        new_tvar="fgm_all",
    )

    options("fgm", "legend_names", [r"$B_x$", r"$B_y$", r"$B_z$"])
    options("fgm_all", "legend_names", [r"$B_l$", r"$B_m$", r"$B_n$", r"$B_{total}$"])
    options("fgm_all", "ysubtitle", "[nT LMN]")
    highlight(["fgm", "fgm_all"], [tstart.timestamp(), tstop.timestamp()])
    degap("fgm")
    degap("fgm_all")

def format_candidate_title(candidate: pandas.Series):
    format_float = lambda x: rf"$\bf {x:.2f} $" if isinstance(x, (float, int)) else rf"$\bf {x} $"

    base_line = rf'$\bf {candidate.get("type", "N/A")} $ candidate (time: {candidate.get("time", "N/A")}) with index '
    index_line = rf'i1: {format_float(candidate.get("index_std", "N/A"))}, i2: {format_float(candidate.get("index_fluctuation", "N/A"))}, i3: {format_float(candidate.get("index_diff", "N/A"))}'
    info_line = rf'$B_n/B$: {format_float(candidate.get("BnOverB", "N/A"))}, $dB/B$: {format_float(candidate.get("dBOverB", "N/A"))}, $(dB/B)_{{max}}$: {format_float(candidate.get("dBOverB_max", "N/A"))},  $Q_{{mva}}$: {format_float(candidate.get("Q_mva", "N/A"))}'
    title = rf"""{base_line}
    {index_line}
    {info_line}"""
    return title


def plot_candidate(candidate: pandas.Series):
    if pd.notnull(candidate.get("d_tstart")) and pd.notnull(candidate.get("d_tstop")):
        plot_basic(
            sat_fgm,
            candidate["tstart"],
            candidate["tstop"],
            tau,
            candidate["d_tstart"],
            candidate["d_tstop"],
        )
    else:
        plot_basic(sat_fgm, candidate["tstart"], candidate["tstop"], tau)

    tplot_options("title", format_candidate_title(candidate))

    if "d_time" in candidate.keys():
        timebar(candidate["d_time"].timestamp(), color="red")
    if "d_tstart" in candidate.keys() and not pd.isnull(candidate["d_tstart"]):
        timebar(candidate["d_tstart"].timestamp())
    if "d_tstop" in candidate.keys() and not pd.isnull(candidate["d_tstop"]):
        timebar(candidate["d_tstop"].timestamp())

    # tplot(['fgm','fgm_all'])
    tplot("fgm_all")


def plot_candidates(
    candidates: pandas.DataFrame, candidate_type=None, num=4, plot_func=plot_candidate
):
    """Plot a set of candidates.

    Parameters:
    - candidates (pd.DataFrame): DataFrame containing the candidates.
    - candidate_type (str, optional): Filter candidates based on a specific type.
    - num (int): Number of candidates to plot, selected randomly.
    - plot_func (callable): Function used to plot an individual candidate.

    """

    # Filter by candidate_type if provided
    candidates = get_candidates(candidates, candidate_type, num)

    # Plot each candidate using the provided plotting function
    for _, candidate in candidates.iterrows():
        plot_func(candidate)

In [None]:
# single candidate test
if test:
    temp_candidate = candidates.iloc[1].to_dict()
    plot_candidate(temp_candidate)

## ID parameters

### Duration

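The duration of an ID is defined as follows:
- Let $d^* = \max( | dB / dt | )$ be the peak rate of change of the field within the candidate interval. The ID interval runs from the last time before the peak to the first time after the peak at which $| dB/dt |$ falls below $d^*/4$.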
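
The definition can be tried on a synthetic profile. The sketch below applies it to a scalar `tanh` transition with NumPy. It uses `np.gradient` in place of the xarray `differentiate` call used in this notebook, and `duration_from_gradient` is a hypothetical name, so treat it as an illustration under those assumptions.

```python
import numpy as np


def duration_from_gradient(t, b, ratio=0.25):
    """Return (t_start, t_stop) where |dB/dt| drops below ratio * max(|dB/dt|)."""
    grad = np.abs(np.gradient(b, t))
    peak = int(np.argmax(grad))  # index of d*
    below = grad < ratio * grad[peak]

    before = np.flatnonzero(below[: peak + 1])  # candidates up to the peak
    after = np.flatnonzero(below[peak:])  # candidates from the peak onwards

    t_start = t[before[-1]] if before.size else None  # last time below, before peak
    t_stop = t[peak + after[0]] if after.size else None  # first time below, after peak
    return t_start, t_stop


t = np.arange(-30.0, 31.0, 1.0)  # seconds, 1 s resolution
b = 5.0 * np.tanh(t / 3.0)  # smooth field reversal centred on t = 0
t_start, t_stop = duration_from_gradient(t, b)
```

Either boundary is `None` when the gradient never falls below the threshold on that side of the peak, which mirrors the `None` returned by `get_time_from_condition` below.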

In [None]:
from typing import Tuple

THRESHOLD_RATIO = 1/4

def calc_duration(vec: xr.DataArray, threshold_ratio=THRESHOLD_RATIO) -> pandas.Series:
    # NOTE: gradient calculated at the edge is not reliable.
    vec_diff = vec.differentiate("time", datetime_unit="s").isel(time=slice(1,-1))
    vec_diff_mag = linalg.norm(vec_diff, dims='v_dim')

    # Guard against all-NaN input, then locate d_star, the peak of |dB/dt|
    if vec_diff_mag.isnull().all():
        raise ValueError("The differentiated vector magnitude contains only NaN values. Cannot compute duration.")
    
    d_star_index = vec_diff_mag.argmax(dim="time")
    d_star = vec_diff_mag[d_star_index]
    d_time = vec_diff_mag.time[d_star_index]
    
    threshold = d_star * threshold_ratio

    start_time, end_time = find_start_end_times(vec_diff_mag, d_time, threshold)

    return pandas.Series({
        'd_star': d_star.item(),
        'd_time': d_time.values,
        'threshold': threshold.item(),
        'd_tstart': start_time,
        'd_tstop': end_time,
    })

def calc_d_duration(vec: xr.DataArray, d_time, threshold) -> pd.Series:
    vec_diff = vec.differentiate("time", datetime_unit="s")
    vec_diff_mag = linalg.norm(vec_diff, dims='v_dim')

    start_time, end_time = find_start_end_times(vec_diff_mag, d_time, threshold)

    return pandas.Series({
        'd_tstart': start_time,
        'd_tstop': end_time,
    })
 
def find_start_end_times(vec_diff_mag: xr.DataArray, d_time, threshold) -> Tuple[pd.Timestamp, pd.Timestamp]:
    # Determine start time
    pre_vec_mag = vec_diff_mag.sel(time=slice(None, d_time))
    start_time = get_time_from_condition(pre_vec_mag, threshold, "last_below")

    # Determine stop time
    post_vec_mag = vec_diff_mag.sel(time=slice(d_time, None))
    end_time = get_time_from_condition(post_vec_mag, threshold, "first_below")

    return start_time, end_time


def get_time_from_condition(vec: xr.DataArray, threshold, condition_type) -> pd.Timestamp:
    if condition_type == "first_below":
        condition = vec < threshold
        index_choice = 0
    elif condition_type == "last_below":
        condition = vec < threshold
        index_choice = -1
    else:
        raise ValueError(f"Unknown condition_type: {condition_type}")

    where_result = np.where(condition)[0]

    if len(where_result) > 0:
        return vec.time[where_result[index_choice]].values
    return None

In [None]:
def calc_candidate_duration(candidate: pd.Series, data, get_candidate_data_fn:Callable =get_candidate_data_xr) -> pd.Series:
    try:
        candidate_data = get_candidate_data_fn(candidate, data)
        return calc_duration(candidate_data)
    except Exception as e:
        # logger.debug(f"Error for candidate {candidate} at {candidate['time']}: {str(e)}") # can not be serialized
        print(f"Error for candidate {candidate} at {candidate['time']}: {str(e)}")
        raise e

def calc_candidate_d_duration(candidate, data , get_candidate_data_fn:Callable =get_candidate_data) -> pd.Series:
    try:
        if pd.isnull(candidate['d_tstart']) or pd.isnull(candidate['d_tstop']):
            candidate_data = get_candidate_data_fn(candidate, data, neighbor=1)
            d_time = candidate['d_time']
            threshold = candidate['threshold']
            return calc_d_duration(candidate_data, d_time, threshold)
        else:
            return pandas.Series({
                'd_tstart': candidate['d_tstart'],
                'd_tstop': candidate['d_tstop'],
            })
    except Exception as e:
        # logger.debug(f"Error for candidate {candidate} at {candidate['time']}: {str(e)}")
        print(f"Error for candidate {candidate} at {candidate['time']}: {str(e)}")
        raise e

def calibrate_candidate_duration(candidate: pd.Series, data:xr.DataArray, data_resolution=data_resolution, ratio = 3/4):
    """
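    Calibrates the candidate duration.
    - If only one of 'd_tstart' or 'd_tstop' is provided, calculates the missing one by mirroring the provided one about 'd_time'.
    - Then, if there are not enough points between 'd_tstart' and 'd_tstop', returns None for both.

    Parameters
    ----------
    - candidate (pd.Series): The input candidate with potentially missing 'd_tstart' or 'd_tstop'.
    - data (xr.DataArray): The time series used to count the points between 'd_tstart' and 'd_tstop'.
    - data_resolution: The nominal sampling interval of `data`.
    - ratio: The minimum fraction of the expected number of points that must be present.

    Returns
    -------
    - pd.Series: The calibrated 'd_tstart' and 'd_tstop'.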
    """
    
    start_notnull = pd.notnull(candidate['d_tstart'])
    stop_notnull = pd.notnull(candidate['d_tstop']) 
    
    match start_notnull, stop_notnull:
        case (True, True):
            d_tstart = candidate['d_tstart']
            d_tstop = candidate['d_tstop']
        case (True, False):
            d_tstart = candidate['d_tstart']
            d_tstop = candidate['d_time'] -  candidate['d_tstart'] + candidate['d_time']
        case (False, True):
            d_tstart = candidate['d_time'] -  candidate['d_tstop'] + candidate['d_time']
            d_tstop = candidate['d_tstop']
        case (False, False):
            return pandas.Series({
                'd_tstart': None,
                'd_tstop': None,
            })
    
    duration = d_tstop - d_tstart
    num_of_points_between = data.time.sel(time=slice(d_tstart, d_tstop)).count().item()
    
    if num_of_points_between <= (duration/data_resolution) * ratio:
        d_tstart = None
        d_tstop = None
    
    return pandas.Series({
        'd_tstart': d_tstart,
        'd_tstop': d_tstop,
    })
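
The two calibration steps described in the docstring of `calibrate_candidate_duration` can be illustrated in isolation. This is a minimal sketch using only the standard library; `mirror_missing_edge` and `has_enough_points` are hypothetical helper names, not functions from this notebook, and the timestamps are made up.

```python
from datetime import datetime, timedelta

def mirror_missing_edge(d_time, d_tstart=None, d_tstop=None):
    """Fill a missing edge by reflecting the known edge about d_time."""
    if d_tstart is None and d_tstop is None:
        return None, None
    if d_tstop is None:
        d_tstop = d_time + (d_time - d_tstart)
    elif d_tstart is None:
        d_tstart = d_time - (d_tstop - d_time)
    return d_tstart, d_tstop

def has_enough_points(n_points, d_tstart, d_tstop, resolution, ratio=3 / 4):
    """Require more than `ratio` of the samples expected at the nominal resolution."""
    expected = (d_tstop - d_tstart) / resolution   # timedelta / timedelta gives a float
    return n_points > expected * ratio

d_time = datetime(2012, 5, 1, 0, 39, 18)
start, stop = mirror_missing_edge(d_time, d_tstart=datetime(2012, 5, 1, 0, 39, 14))
print(start, stop)  # stop lies 4 s after d_time, mirroring the 4 s before it
print(has_enough_points(5, start, stop, timedelta(seconds=1)))
```

With an 8 s interval at 1 s resolution, 8 samples are expected, so 5 samples fail the 3/4 coverage requirement and both edges would be discarded.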

In [None]:
pdp_calc_duration = pdp.PdPipeline([
    pdp.ApplyToRows(calc_candidate_duration, func_desc='calculating duration parameters'),
    pdp.ApplyToRows(calc_candidate_d_duration, func_desc='calculating duration parameters if needed'),
])

pdp_calibrate_duration = pdp.PdPipeline([
    pdp.ApplyToRows(calibrate_candidate_duration, func_desc='calibrating duration parameters if needed'),
])
    

In [None]:
# candidates = pdp_calc_duration(candidates)

### ID classification

In this method, TDs satisfy $\frac{|B_N|}{|B_{bg}|} < 0.4$ and $\left| \frac{\Delta |B|}{|B_{bg}|} \right| > 0.2$, while RDs satisfy $\frac{|B_N|}{|B_{bg}|} > 0.4$ and $\left| \frac{\Delta |B|}{|B_{bg}|} \right| < 0.2$. Moreover, IDs with $\frac{|B_N|}{|B_{bg}|} < 0.4$ and $\left| \frac{\Delta |B|}{|B_{bg}|} \right| < 0.2$ could be either TDs or RDs, and so are termed EDs. Similarly, NDs are defined by $\frac{|B_N|}{|B_{bg}|} > 0.4$ and $\left| \frac{\Delta |B|}{|B_{bg}|} \right| > 0.2$ because they can be neither TDs nor RDs. It is worth noting that EDs and NDs here are not physical concepts like RDs and TDs. RDs or TDs correspond to specific types of structures in the MHD framework, while EDs and NDs are introduced just to better quantify the statistical results. Note that the thresholds defined in the code below use 0.2, rather than 0.4, as the upper bound on $\frac{|B_N|}{|B_{bg}|}$ for TDs.


Criteria used to classify discontinuities on the basis of magnetic data

| Type | $\lvert B_n / B \rvert$ | $\lvert \Delta B / B \rvert$ |
|------|------|------|
| Rotational (RD) | large | small |
| Tangential (TD) | small |  large |
| Either (ED) | small | small |
| Neither (ND) | large | large |


#### Minimum variance analysis (MVA)

To ensure the accuracy of MVA, the results are used for further analysis only when the ratio of the middle to the minimum eigenvalue (labeled $Q_{MVA}$ for simplicity) is larger than 3.
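
As an illustration of this quality check, the eigenvalue ratio can be computed directly from the covariance matrix of the field samples. This is a sketch with plain NumPy on synthetic data rather than the `pyspedas` routine used in the notebook; `mva_quality` is a hypothetical helper and the synthetic field is an assumption.

```python
import numpy as np

def mva_quality(b):
    """Return (eigenvalues in descending order, eigenvectors as columns, Q_MVA)."""
    cov = np.cov(b, rowvar=False)     # 3x3 magnetic variance matrix
    w, v = np.linalg.eigh(cov)        # eigh returns eigenvalues in ascending order
    w, v = w[::-1], v[:, ::-1]        # reorder to maximum, middle, minimum variance
    q_mva = w[1] / w[2]               # middle / minimum eigenvalue
    return w, v, q_mva

rng = np.random.default_rng(0)
n = 500
t = np.linspace(-1, 1, n)
# Synthetic field: large variation in x and y, small noise along z (the normal)
b = np.column_stack([
    5 * np.tanh(5 * t),
    2 * np.cos(3 * t),
    0.05 * rng.standard_normal(n),
])
w, v, q_mva = mva_quality(b)
print(q_mva, q_mva > 3)        # the event is kept only if Q_MVA > 3
print(np.abs(v[:, 2]))         # minimum-variance direction, expected close to z
```

The normalisation of the covariance matrix does not matter here, since only the ratio of eigenvalues enters the criterion.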

In [None]:
BnOverB_RD_lower_threshold = 0.4
dBOverB_RD_upper_threshold = 0.2

BnOverB_TD_upper_threshold = 0.2
dBOverB_TD_lower_threshold = dBOverB_RD_upper_threshold

BnOverB_ED_upper_threshold = BnOverB_RD_lower_threshold
dBOverB_ED_upper_threshold = dBOverB_TD_lower_threshold

BnOverB_ND_lower_threshold = BnOverB_TD_upper_threshold
dBOverB_ND_lower_threshold = dBOverB_RD_upper_threshold


In [None]:
np.array([True, False]) | np.array([True, True])

array([ True,  True])

In [None]:
from pyspedas.cotrans.minvar import minvar
def calc_classification_index(data: xr.DataArray):
    
    vrot, v, w = minvar(data.to_numpy()) # NOTE: using `.to_numpy()` will significantly speed up the computation.
    Vl = v[:,0] # Maximum variance direction eigenvector

    B_rot = xr.DataArray(vrot, dims=['time', 'v_dim'], coords={'time': data.time})
    B = calc_vec_mag(B_rot)
    B_n = B_rot.isel(v_dim=2)
    
    B_mean = B.mean(dim="time")
    B_n_mean = B_n.mean(dim="time")
    
    BnOverB = B_n_mean / B_mean
    # BnOverB = np.abs(B_n / B).mean(dim="time")

    dB = B.isel(time=-1) - B.isel(time=0)
    dBOverB = np.abs(dB / B_mean)
    dBOverB_max = (B.max(dim="time") - B.min(dim="time")) / B_mean
    
    
    return pandas.Series({
        'Vl_x': Vl[0],
        'Vl_y': Vl[1],
        'Vl_z': Vl[2],
        'eig0': w[0],
        'eig1': w[1],
        'eig2': w[2],
        'Q_mva': w[1]/w[2],
        'B': B_mean.item(),
        'B_n': B_n_mean.item(),
        'dB': dB.item(),
        'BnOverB': BnOverB.item(), 
        'dBOverB': dBOverB.item(),
        'dBOverB_max': dBOverB_max.item(),
        })

In [None]:
def classify_id(BnOverB, dBOverB):
    BnOverB = np.abs(np.asarray(BnOverB))
    dBOverB = np.asarray(dBOverB)

    s1 = (BnOverB > BnOverB_RD_lower_threshold)
    s2 = (dBOverB > dBOverB_RD_upper_threshold)
    s3 = (BnOverB > BnOverB_TD_upper_threshold)
    s4 = s2 # note: s4 = (dBOverB > dBOverB_TD_lower_threshold)
    
    RD = s1 & ~s2
    TD = ~s3 & s4
    ED = ~s1 & ~s4
    ND = s3 & s2

    # Create an empty result array with the same shape
    result = np.empty_like(BnOverB, dtype=object)

    result[RD] = "RD"
    result[TD] = "TD"
    result[ED] = "ED"
    result[ND] = "ND"

    return result


In [None]:
pdp_classify_id = pdp.PdPipeline([
    pdp.ApplyToRows(lambda candidate: calc_classification_index(sat_fgm.sel(time = slice(candidate['d_tstart'], candidate['d_tstop']))), func_desc='calculating index "q_mva", "BnOverB" and "dBOverB"'),
    pdp.ColByFrameFunc("type", lambda df: classify_id(df["BnOverB"], df["dBOverB"]),func_desc="classifying the type of the ID")
    # pdp.ApplyToRows(lambda candidate: classify_id(candidate["BnOverB"], candidate["dBOverB"]), colname="type", func_desc="classifying the type of the ID"),
])

In [None]:
# pdp_classify_id(candidates)

### Field rotation angles
The PDF of the field rotation angles across the solar-wind IDs is well fitted by the exponential function exp(−θ/)...
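
A minimal sketch of how such an exponential form could be checked is given below. It uses synthetic angles drawn from an exponential distribution; the scale of 20° is an arbitrary assumption for illustration only and is not the fitted value from the source.

```python
import numpy as np

rng = np.random.default_rng(42)
true_scale = 20.0                                # hypothetical e-folding angle (degrees)
theta = rng.exponential(true_scale, size=100_000)
theta = theta[theta <= 180]                      # rotation angles lie in [0, 180] degrees

# Empirical PDF from a normalised histogram
pdf, edges = np.histogram(theta, bins=60, range=(0, 180), density=True)
centers = 0.5 * (edges[:-1] + edges[1:])

# Straight-line fit to log(PDF): for an exponential the slope is -1 / scale
mask = (pdf > 0) & (centers < 100)               # skip empty or poorly populated bins
slope, intercept = np.polyfit(centers[mask], np.log(pdf[mask]), 1)
fitted_scale = -1.0 / slope
print(fitted_scale)
```

On real data, `theta` would be replaced by the `rotation_angle` column, and a clearly non-linear `log(pdf)` would indicate that a single exponential is not a good description.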

In [None]:
def calc_rotation_angle(v1, v2):
    """
    Computes the rotation angle between two vectors.
    
    Parameters:
    - v1: The first vector.
    - v2: The second vector.
    """
    
    if v1.shape != v2.shape:
        raise ValueError("Vectors must have the same shape.")

    # Convert xr.DataArray inputs to numpy arrays
    if isinstance(v1, DataArray):
        v1 = v1.to_numpy()
    if isinstance(v2, DataArray):
        v2 = v2.to_numpy()
    
    # Normalize the vectors
    v1_u = v1 / np.linalg.norm(v1, axis=-1, keepdims=True)
    v2_u = v2 / np.linalg.norm(v2, axis=-1, keepdims=True)
    
    # Calculate the cosine of the angle for each time step
    cosine_angle = np.sum(v1_u * v2_u, axis=-1)
    
    # Clip the values to handle potential floating point errors
    cosine_angle = np.clip(cosine_angle, -1, 1)
    
    angle = np.arccos(cosine_angle)
    
    # Convert the angles from radians to degrees
    return np.degrees(angle)

def calc_candidate_rotation_angle(candidates, data:  xr.DataArray):
    """
    Computes the rotation angle(s) at two different time steps.
    """
    
    tstart = candidates['d_tstart']
    tstop = candidates['d_tstop']
    
    # Convert Series to numpy arrays if necessary
    if isinstance(tstart, pd.Series):
        tstart = tstart.to_numpy()
        tstop = tstop.to_numpy()
        # No need to handle NaT values here (`calibrate_candidate_duration` takes care of them)
    
    # Get the vectors at the two time steps
    vecs_before = data.sel(time=tstart, method="nearest")
    vecs_after = data.sel(time=tstop, method="nearest")
    
    # Compute the rotation angle(s)
    rotation_angles = calc_rotation_angle(vecs_before, vecs_after)
    return rotation_angles

In [None]:
pdp_calc_rotation_angle = pdp.ColByFrameFunc("rotation_angle", lambda df: calc_candidate_rotation_angle(df, data=sat_fgm), func_desc='calculating rotation angle')

### Assign satellite locations to the discontinuities

In [None]:
def get_candidate_location(candidate, location_data: DataArray):
    return location_data.sel(time = candidate['d_time']).to_series()

In [None]:
pdp_assign_coordinates = pdp.PdPipeline([
    pdp.ApplyToRows(lambda candidate: get_candidate_location(candidate, sat_state), func_desc='assigning coordinates'),
    # TODO: can we use `pdp.ColByFrameFunc` here?
])

## Processing the whole dataset

In [None]:
def get_ID_filter_condition(
    index_std_threshold = 2,
    index_fluc_threshold = 1,
    index_diff_threshold = 0.1,
    sparse_num = 15
):
    return (
        (pl.col("index_std") > index_std_threshold)
        & (pl.col("index_fluctuation") > index_fluc_threshold)
        & (pl.col("index_diff") > index_diff_threshold)
        & (
            pl.col("index_std").is_finite()
        )  # for cases where neighboring groups have std=0
        & (
            pl.col("count") > sparse_num
        )  # filter out sparse intervals, which may give unreasonable results.
    )



In [None]:
from pdpipe.util import out_of_place_col_insert

class ApplyToRows(pdp.ApplyToRows):
    """A pipeline stage that works with `modin` DataFrames.
    """
    def _transform(self, X, verbose):
        new_cols = X.apply(self._func, axis=1)
        if isinstance(new_cols, pd.Series):
            loc = len(X.columns)
            if self._follow_column:
                loc = X.columns.get_loc(self._follow_column) + 1
            return out_of_place_col_insert(
                X=X, series=new_cols, loc=loc, column_name=self._colname
            )
        if isinstance(new_cols, pd.DataFrame):
            sorted_cols = sorted(list(new_cols.columns))
            new_cols = new_cols[sorted_cols]
            if self._follow_column:
                inter_X = X
                loc = X.columns.get_loc(self._follow_column) + 1
                for colname in new_cols.columns:
                    inter_X = out_of_place_col_insert(
                        X=inter_X,
                        series=new_cols[colname],
                        loc=loc,
                        column_name=colname,
                    )
                    loc += 1
                return inter_X
            assign_map = {
                colname: new_cols[colname] for colname in new_cols.columns
            }
            return X.assign(**assign_map)
        raise TypeError(  # pragma: no cover
            "Unexpected type generated by applying a function to a DataFrame."
            " Only Series and DataFrame are allowed."
        )


In [None]:
def calc_candidate_classification_index(candidate, data):
    return calc_classification_index(
        data.sel(time=slice(candidate["d_tstart"], candidate["d_tstop"]))
    )

In [None]:
class IDsPipeline:
    def __init__(self, sat_fgm=None, sat_state=None):
        self.sat_fgm = sat_fgm
        self.sat_state = sat_state

        self.pipelines = {}

    # fmt: off
    def add_calc_duration(self):
        self.pipelines["calc_duration"] = pdp.PdPipeline([
            ApplyToRows(
                lambda candidate: calc_candidate_duration(candidate, self.sat_fgm),
                func_desc="calculating duration parameters"
            ),
            ApplyToRows(
                lambda candidate: calc_candidate_d_duration(candidate, self.sat_fgm),
                func_desc="calculating duration parameters if needed"
            ),
        ])
        return self

    def add_calibrate_duration(self):
        self.pipelines["calibrate_duration"] = \
            ApplyToRows(
                lambda candidate: calibrate_candidate_duration(candidate, self.sat_fgm),
                func_desc="calibrating duration parameters if needed"
            )
        return self

    def add_classify_id(self):
        self.pipelines["classify_id"] = pdp.PdPipeline([
            ApplyToRows(
                lambda candidate: calc_candidate_classification_index(candidate, self.sat_fgm),
                func_desc='calculating index "q_mva", "BnOverB" and "dBOverB"'
            ),
            pdp.ColByFrameFunc(
                "type",
                lambda df: classify_id(df["BnOverB"], df["dBOverB"]),
                func_desc="classifying the type of the ID"
            ),
        ])
        return self
    
    def add_calc_rotation_angle(self):
        self.pipelines["calc_rotation_angle"] = \
            pdp.ColByFrameFunc(
                "rotation_angle",
                lambda df: calc_candidate_rotation_angle(df, data=self.sat_fgm),
                func_desc="calculating rotation angle"
            )
            
        return self

    def add_assign_coordinates(self):
        self.pipelines["assign_coordinates"] = \
            ApplyToRows(
                lambda candidate: get_candidate_location(candidate, self.sat_state),
                func_desc="assigning coordinates"
            )
        # Return self so that the `add_*` calls can be chained
        return self

    # fmt: on

    # ... you can add more methods as needed

    def get_pipeline(self, name):
        return self.pipelines.get(name)
    
# def process_candidates(
#     candidates: pl.DataFrame, sat_fgm: xr.DataArray, sat_state: xr.DataArray
# ):
#     candidates = convert_to_dataframe(candidates)
#     builder = IDsPipeline(sat_fgm, sat_state)

#     # Build pipelines
#     id_pipelines = builder.add_calc_duration().add_calibrate_duration().add_classify_id().add_calc_rotation_angle().add_assign_coordinates()

#     candidates = id_pipelines.get_pipeline("calc_duration").apply(candidates)

#     # calibrate duration
#     temp_candidates = candidates.loc[
#         lambda df: df["d_tstart"].isnull() | df["d_tstop"].isnull()
#     ]
#     if not temp_candidates.empty:
#         candidates.update(
#             id_pipelines.get_pipeline("calibrate_duration").apply(temp_candidates)
#         )

#     candidates = pdp.DropNa()(candidates)  # Remove candidates with NaN values

#     # Apply remaining pipelines (you can refactor this further if needed)
#     ids = (
#         id_pipelines.get_pipeline("classify_id") + 
#         id_pipelines.get_pipeline("calc_rotation_angle") +
#         id_pipelines.get_pipeline("assign_coordinates")
#     ).apply(candidates)
#     # Add other pipelines as needed

#     return ids

#### JUNO

In [None]:
from tqdm import tqdm


In [None]:
sat = 'jno'
coord = 'se'
tau = timedelta(seconds=60)
data_resolution = timedelta(seconds=1)

# if True:
    # year = 2011
for year in tqdm(range(starting_year, ending_year+1)):
    files = f'data/{sat}_{year}.parquet'
    output = f'data/{sat}_candidates_{year}_tau_{tau.seconds}.parquet'
    
    if os.path.exists(output):
        logger.info(f"Skipping {year} as the output file already exists.")
        continue


    data = pl.scan_parquet(files).set_sorted('time').collect()
    sat_fgm = df2ts(data, ["BX", "BY", "BZ"], attrs={"coordinate_system": coord, "units": "nT"})
    sat_state = df2ts(data, ["X", "Y", "Z"], attrs={"coordinate_system": coord, "units": "km"})

    indices = compute_indices(data, tau)
    # filter condition
    sparse_num = tau / data_resolution // 3
    filter_condition = get_ID_filter_condition(sparse_num = sparse_num)

    candidates = indices.filter(filter_condition).with_columns(pl_format_time(tau))
    
    ids = process_candidates(candidates, sat_fgm, sat_state)
    df = pandas.DataFrame(ids)
    df.to_parquet(output)
    # pandas.DataFrame(ids).to_parquet(output)



  0%|          | 0/6 [00:00<?, ?it/s]
2023-09-21 02:46:32.823 | INFO     | __main__:<cell line: 9>:15 - Skipping 2011 as the output file already exists.
2023-09-21 02:46:32.826 | INFO     | __main__:<cell line: 9>:15 - Skipping 2012 as the output file already exists.
2023-09-21 02:46:32.826 | INFO     | __main__:<cell line: 9>:15 - Skipping 2013 as the output file already exists.
2023-09-21 02:46:32.827 | INFO     | __main__:<cell line: 9>:15 - Skipping 2014 as the output file already exists.
2023-09-21 02:46:32.827 | INFO     | __main__:<cell line: 9>:15 - Skipping 2015 as the output file already exists.
2023-09-21 02:46:32.828 | INFO     | __main__:<cell line: 9>:15 - Skipping 2016 

In [None]:
# Test different libraries to parallelize the computation
test = True
if test:
    pdp_test = ApplyToRows(
        lambda candidate: calc_candidate_duration(candidate, sat_fgm),  # slightly faster
        # lambda candidate: calc_duration(get_candidate_data_xr(candidate, jno_fgm)),
        # lambda candidate: calc_duration(jno_fgm.sel(time=slice(candidate['tstart'], candidate['tstop']))),
        func_desc="calculating duration parameters",
    )
    candidates_pd = candidates.to_pandas()
    candidates_modin = pd.DataFrame(candidates_pd)
    
    # ---
    # successful cases
    # ---
    # candidates_pd.apply(lambda candidate: calc_candidate_duration(candidate, jno_fgm), axis=1) # Standard case: 37+s secs
    # candidates_pd.swifter.apply(calc_candidate_duration, axis=1, data=jno_fgm) # this works with dask, 80 secs
    # candidates_pd.swifter.set_dask_scheduler(scheduler="threads").apply(calc_candidate_duration, axis=1, data=jno_fgm) # this works with dask, 60 secs
    # candidates_pd.mapply(lambda candidate: calc_candidate_duration(candidate, jno_fgm), axis=1) # this works, 8 secs # not work? `DataFrame' object has no attribute 'mapply'
    # candidates_modin.apply(lambda candidate: calc_candidate_duration(candidate, jno_fgm), axis=1) # this works with ray, 8 secs # NOTE: can not work with dask
    # pdp_test(candidates_modin) # this works, 8 secs
    
    # ---
    # failed cases
    # ---
    # candidates_modin.apply(calc_candidate_duration, axis=1, data=jno_fgm) # AttributeError: 'DataFrame' object has no attribute 'sel'
    # pdp_test(candidates_modin) # TypeError: Unexpected type generated by applying a function to a DataFrame. Only Series and DataFrame are allowed.



Distributing Dataframe: 100%██████████ Elapsed time: 00:00, estimated remaining time: 00:00


#### THEMIS

In [None]:
sat = 'thb'
coord = 'gse'
tau = timedelta(seconds=60)
data_resolution = timedelta(seconds=4)

files = f'data/{sat}_fgs_{coord}.parquet'
# NOTE: `year` is not set in this cell; it relies on the value left over from the JUNO loop above.
output = f'data/{sat}_candidates_{year}_tau_{tau.seconds}.parquet'

if os.path.exists(output):
    logger.info(f"Skipping {year} as the output file already exists.")
else:
    # `continue` is only valid inside a loop, so the processing is placed in this branch instead.
    data = pl.scan_parquet(files).set_sorted('time').collect()
    sat_fgm = df2ts(data, ["BX", "BY", "BZ"], attrs={"coordinate_system": coord, "units": "nT"})
    sat_state = pandas.read_parquet(f'data/{sat}_state.parquet')

    indices = compute_indices(data, tau)
    # filter condition
    sparse_num = tau / data_resolution // 3
    filter_condition = get_ID_filter_condition(sparse_num = sparse_num)

    candidates = indices.filter(filter_condition).with_columns(pl_format_time(tau))

    ids = process_candidates(candidates, sat_fgm, sat_state)
    df = ids.to_pandas()
    df.to_parquet(output)


Distributing Dataframe: 100%██████████ Elapsed time: 00:00, estimated remaining time: 00:00
  0%|          | 0/6 [01:39<?, ?it/s]


KeyboardInterrupt: 

### Pipelines

In [None]:
print(pdp_calc_duration)
if test:
    candidates = pdp_calc_duration.apply(candidates)

In [None]:
# Test: Inspect interesting candidates
if False:
    temp_candidates = candidates.loc[
        lambda df: df["d_tstart"].isnull() | df["d_tstop"].isnull()
    ]
    num = 28
    # temp_candidate = candidates.iloc[num]
    temp_candidate = temp_candidates.iloc[num]
    print(temp_candidate)
    plot_candidate(temp_candidate)

In [None]:
print(pdp_calibrate_duration)

if test: 
    temp_candidates = candidates.loc[
        lambda df: df["d_tstart"].isnull() | df["d_tstop"].isnull()
    ]
    if not temp_candidates.empty:
        display(temp_candidates)
        candidates.update(
            pdp_calibrate_duration.apply(temp_candidates)
        )  # This step is needed to classify the candidates

    candidates = pdp.DropNa()(candidates) # drop candidates with NaN values
    # pdp.RowDrop({"d_tstart": lambda x: pd.isnull(x), "d_tstop": lambda x: pd.isnull(x)}, reduce="all",)(candidates)
    # pdp.RowDrop([lambda x: pd.isnull(x)], reduce='all', columns=['d_tstart', 'd_tstop'])(candidates) # Notes: slower

In [None]:
print(pdp_classify_id)
if test: 
    candidates = pdp_classify_id(candidates)

In [None]:
print(pdp_calc_rotation_angle)
if test: 
    candidates = pdp_calc_rotation_angle(candidates)

In [None]:
print(pdp_assign_coordinates)
if test: 
    candidates = pdp_assign_coordinates(candidates)

In [None]:
pipelines = pdp_calc_duration + pdp_calibrate_duration + pdp_classify_id + pdp_calc_rotation_angle + pdp_assign_coordinates
print(pipelines)
# candidates = pipelines(info)

## Results

In [None]:
# read candidates from files in current directory
pattern = 'data/candidates*.parquet'
data = Path() / pattern

candidates = pl.scan_parquet(data).collect()

In [None]:
len(candidates.columns)

### Plotting candidates of different types of discontinuities

In [None]:
plot_candidates(candidates, candidate_type='TD')

In [None]:
plot_candidates(candidates, candidate_type='RD')

In [None]:
plot_candidates(candidates, candidate_type='ED')

In [None]:
plot_candidates(candidates, candidate_type='ND')

### Occurrence rates

In [None]:
# Calculate the occurrence rates of the different types of ID
def occurence_rate(candidates, candidate_type):
    return len(candidates[candidates['type'] == candidate_type]) / len(candidates)

def time_occurence_rate(candidates):
    if len(candidates) <= 1:
        return None
    else:
        return (candidates.iloc[-1]['tstop'] - candidates.iloc[0]['tstart']) / (len(candidates) -1)

CANDIDATE_TYPES = ['RD', 'TD', 'ED', 'ND']

for candidate_type in CANDIDATE_TYPES:
    logger.info(f"Occurrence rate of {candidate_type}: {occurence_rate(candidates, candidate_type)}")
    logger.info(f"Time occurrence rate of {candidate_type}: {time_occurence_rate(candidates[candidates['type'] == candidate_type])}")

In [None]:
pdp.ColByFrameFunc("R", lambda df: df[['X','Y', 'Z']].apply(np.linalg.norm, axis=1), func_desc='calculating R')(candidates)

In [None]:
candidates.plot(x="X", y="d_star")

### Duration

In [None]:
# candidates.update(pdp_calibrate_duration.apply(temp_candidates))

### Waiting time

### Amplitude

In [None]:
temp_candidates = get_candidates(candidates, 'RD')
temp_candidates = pdp_calc_duration(temp_candidates)
temp_candidates

In [None]:
plot_candidates(temp_candidates)

In [None]:
# test minvar and principal axes vectors
test_data = np.array([[1,1,0],[-1,-1,0]])
vrot, v, w = minvar(test_data)
Vi = v[:,0]
print(Vi)

In [None]:
# test minvar_matrix_make
in_var_name = "fgm"

vrot, v, w = minvar(get_data(in_var_name, xarray=True))

minvar_matrix_make(in_var_name)
tvector_rotate(f'{in_var_name}_mva_mat', in_var_name)
(get_data(f"{in_var_name}_rot").y==vrot).all()

### Cases

In [None]:
# plot_candidates(candidates.loc[lambda _: _['time']=='2012-07-10 02:31:15'])
temp_trange = ['2012-07-15 03:44', '2012-07-15 03:47']
temp_data = sat_fgm.sel(time=slice(*temp_trange))
temp_data.plot.scatter(x='time', hue='v_dim')
# temp_data.resample(time=pd.Timedelta(tau, unit='s')).map(calc_vec_std)
compute_index_std(temp_data, tau)

#### Case: neighboring data is missing, causing the calculation of the standard deviation index to be Inf
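
A small numerical sketch shows how an infinite index can arise. It assumes that the standard deviation index is the ratio of the standard deviation in the central window to the larger of the two neighbouring-window standard deviations; this definition is an assumption here, and `std_index` and the sample values are hypothetical. The sketch also uses scalar samples instead of field vectors.

```python
import numpy as np

def std_index(prev, curr, nxt):
    """Ratio of the current-window std to the larger neighbouring-window std."""
    neighbor_std = max(np.std(prev), np.std(nxt))
    with np.errstate(divide="ignore"):      # silence the divide-by-zero warning
        return np.std(curr) / neighbor_std

curr = np.array([1.0, 2.0, 5.0, 4.0])       # well-sampled central window
prev = np.array([1.1])                      # data gap: a single sample, so std = 0
nxt = np.array([3.9])                       # data gap: a single sample, so std = 0

index = std_index(prev, curr, nxt)
print(index, np.isfinite(index))            # inf False
```

This is why the filter condition defined earlier requires `index_std` to be finite in addition to exceeding its threshold.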

In [None]:
# Case: neighboring data is missing, causing the calculation of the standard deviation index to be Inf.
temp_trange = ['2012-07-10 02:30', '2012-07-10 02:32']
temp_data = sat_fgm.sel(time=slice(*temp_trange))
temp_data.plot.scatter(x='time', hue='v_dim')
# temp_data.resample(time=pd.Timedelta(tau, unit='s')).map(calc_vec_std)
compute_index_std(temp_data, tau)

#### Caveats

In [None]:
plot_candidates(candidates)

##### NOTE: Not very accurate for a wave-like (oscillating) magnetic field...

In [None]:
temp_candidate = {'time': Timestamp('2012-05-01 00:39:12'),
 'tstart': Timestamp('2012-05-01 00:38:56'),
 'tstop': Timestamp('2012-05-01 00:39:28'),
 'i1': 2.891042053414383,
 'i2': 2.389699609352786,
 'i3': 1.3916002784658887,
 'd_star': 0.27143595,
 'd_time': Timestamp('2012-05-01 00:39:18.672000'),
 'd_tstart': Timestamp('2012-05-01 00:39:14.672000'),
 'd_tstop': Timestamp('2012-05-01 00:39:19.671000'),
}

plot_candidate(temp_candidate)

In [None]:
data = get_candidate_data_xr(temp_candidate, neighbor=1)
vec_diff = data.differentiate("time", datetime_unit="s", edge_order=2).isel(time=slice(1,-1))
vec_diff_mag = linalg.norm(vec_diff, dims='v_dim')
vec_diff_mag.plot()

##### NOTE: Small threshold_ratio values will tend to make the duration longer if the duration can be determined.
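
This trend can be made concrete for an idealised profile. Assuming, purely for illustration, a field component of the form $B(t) = \tanh(t/w)$, the rate $|dB/dt| = \mathrm{sech}^2(t/w)/w$ falls to a fraction $r$ of its peak at $|t| = w\,\mathrm{arccosh}(1/\sqrt{r})$, so the duration is $2w\,\mathrm{arccosh}(1/\sqrt{r})$. The sketch below evaluates this expression for a few ratios; the `tanh` profile, the width `w` and the helper `tanh_duration` are assumptions, not properties of the data.

```python
import numpy as np

def tanh_duration(ratio, w=1.0):
    """Duration over which sech^2(t / w) stays at or above `ratio` of its peak."""
    return 2.0 * w * np.arccosh(1.0 / np.sqrt(ratio))

threshold_ratios = [1/8, 1/4, 0.3, 1/3, 1/2]
durations = [tanh_duration(r) for r in threshold_ratios]
for r, d in zip(threshold_ratios, durations):
    print(f"ratio = {r:.3f} -> duration = {d:.2f} w")
```

For this profile the duration decreases monotonically as the ratio increases, so the choice of `threshold_ratio` directly sets how much of the transition is counted as the discontinuity.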


In [None]:
# Test different threshold ratios
threshold_ratios = [1/8, 1/4, 0.3, 1/3, 1/2]
for threshold_ratio in threshold_ratios:
    temp_candidate.update(calc_duration(get_candidate_data_xr(temp_candidate), threshold_ratio=threshold_ratio).to_dict())
    plot_candidate(temp_candidate)