---
title: Pipelines
subtitle: The data flow abstraction
---

We are using `Kedro` to build a data pipeline. A pipeline is a collection of nodes that are connected to each other. Each node is a function that takes inputs and produces outputs. The inputs and outputs are datasets at different layers/levels.

This notebook mainly demonstrates the concept and the common building blocks of a pipeline; see each mission notebook for implementation details.

In [None]:
#| default_exp pipelines/default

In [None]:
#| hide
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import show_doc

In [None]:
#| hide
import polars as pl
from typing import Any, Dict, List, Tuple, Union

In [None]:
# Kedro
from kedro.pipeline import Pipeline, node
from kedro.pipeline.modular_pipeline import pipeline

## Magnetic field data pipeline

### Downloading data

In [None]:
def download_mag_data(
    start: str = None,
    end: str = None,
    ts: str = None,  # time resolution
    probe: str = None,
    coord: str = None,
):
    ...


### Preprocessing data

Some common preprocessing steps are:

- Partition data by year, see `ids_finder.utils.basic.partition_data_by_year`

In [None]:
def preprocess_mag_data(
    raw_data: Any = None,
    start: str | None = None,  # wired to `params:start_date` by the pipeline
    end: str | None = None,  # wired to `params:end_date` by the pipeline
    ts: str | None = None,  # time resolution
    coord: str | None = None,
) -> pl.DataFrame:
    """
    Preprocess the raw dataset (only minor transformations)

    - Applying naming conventions for columns
    - Parsing and typing data (like from string to datetime for time columns)
    - Structuring the data (like pivoting, unpivoting, etc.)
    - Changing storing format (like from `csv` to `parquet`)
    - Dropping null columns
    - Resampling data to a given time resolution
    - ... other 'transformations' commonly performed at this stage.

    This is a placeholder that does nothing and returns `None`;
    see each mission notebook for the actual implementation.
    """


### Processing data

Note: we process the data year by year to minimize memory usage and to limit the impact of processing failures (so there is no need to process all the data again if only some years fail).

In [None]:
#| exports

def process_mag_data(
    raw_data: Any | pl.DataFrame,
    ts: str | None = None,  # time resolution
    coord: str | None = None,
) -> pl.DataFrame | Dict[str, pl.DataFrame]:
    """
    Corresponding to primary data layer, where source data models are transformed into domain data models

    - Transforming data to RTN (Radial-Tangential-Normal) coordinate system
    - Smoothing data
    - Resampling data to a given time resolution
    - Partitioning data, for the sake of memory

    This is a placeholder that does nothing and returns `None`;
    see each mission notebook for the actual implementation.
    """

def extract_features():
    pass

#### Pipeline

In [None]:
# | exports
def create_mag_data_pipeline(
    sat_id: str,  # satellite id, used for namespace
    ts: str = '1s',  # time resolution,
    tau: str = '60s',  # time window
    **kwargs,
) -> Pipeline:
    """
    Assemble the magnetic field data pipeline for one satellite

    Chains four nodes (download, preprocess, process, extract features) and
    wraps them with `pipeline`, passing `sat_id` as the namespace.
    `kwargs` is accepted but not used.
    """
    sat_label = sat_id.upper()

    # Dataset names handed over from one node to the next
    raw_ds = "raw_mag"
    inter_ds = f"inter_mag_{ts}"
    primary_ds = f"primary_mag_rtn_{ts}"
    feature_ds = f"feature_tau_{tau}"

    # The download and the preprocess steps receive the same date range
    date_range = {"start": "params:start_date", "end": "params:end_date"}

    steps = [
        node(
            download_mag_data,
            inputs=dict(date_range),
            outputs=raw_ds,
            name=f"download_{sat_label}_magnetic_field_data",
        ),
        node(
            preprocess_mag_data,
            inputs={"raw_data": raw_ds, **date_range},
            outputs=inter_ds,
            name=f"preprocess_{sat_label}_magnetic_field_data",
        ),
        node(
            process_mag_data,
            inputs=inter_ds,
            outputs=primary_ds,
            name=f"process_{sat_label}_magnetic_field_data",
        ),
        node(
            extract_features,
            inputs=[primary_ds, "params:tau", "params:extract_params"],
            outputs=feature_ds,
            # unlike the other nodes, this name uses `sat_id` as given (not upper-cased)
            name=f"extract_{sat_id}_features",
        ),
    ]

    # NOTE(review): the date parameters are always mapped to the `jno_*` ones,
    # whatever `sat_id` is, and `params:tau` is mapped to the *value* of `tau`
    # (e.g. '60s') rather than to a parameter name -- confirm both are intended.
    param_mapping = {
        "params:start_date": "params:jno_start_date",
        "params:end_date": "params:jno_end_date",
        "params:tau": tau,
    }

    return pipeline(steps, namespace=sat_id, parameters=param_mapping)

In [None]:
class DatasetConfig:
    """Bundle a satellite id with the functions used for its download, preprocess and process steps."""

    def __init__(self, sat_id, download_func, preprocess_func, process_func):
        # Identifier of the satellite
        self.sat_id = sat_id

        # Step functions, stored as given (nothing is called or validated here)
        self.process_func = process_func
        self.preprocess_func = preprocess_func
        self.download_func = download_func

class PipelineGenerator:
    """Build the download -> preprocess -> process -> extract pipeline from a `DatasetConfig`."""

    def __init__(self, config: DatasetConfig, ts='1s', tau='60s'):
        self.config = config
        self.ts = ts  # used as suffix of the dataset names
        self.tau = tau  # used in the name of the feature dataset

    def _node(self, func, inputs, outputs, name):
        """Thin wrapper around `node` taking every argument positionally."""
        return node(func, inputs=inputs, outputs=outputs, name=name)

    def generate_pipeline(self):
        """Create the four nodes and wrap them with `pipeline`, passing the satellite id as namespace."""
        cfg = self.config
        ts = self.ts

        # Dataset names handed over from one node to the next
        raw_ds = f"raw_data_{ts}"
        inter_ds = f"inter_data_{ts}"
        primary_ds = f"primary_data_rtn_{ts}"

        download_step = self._node(
            cfg.download_func,
            inputs={"start": "params:start_date", "end": "params:end_date"},
            outputs=raw_ds,
            name=f"download_{cfg.sat_id.upper()}_data",
        )
        preprocess_step = self._node(
            cfg.preprocess_func,
            inputs={
                "raw_data": raw_ds,
                "start": "params:start_date",
                "end": "params:end_date",
            },
            outputs=inter_ds,
            name=f"preprocess_{cfg.sat_id.upper()}_data",
        )
        process_step = self._node(
            cfg.process_func,
            inputs=inter_ds,
            outputs=primary_ds,
            name=f"process_{cfg.sat_id.upper()}_data",
        )
        extract_step = self._node(
            extract_features,
            inputs=[primary_ds, "params:tau", "params:extract_params"],
            outputs=f"feature_tau_{self.tau}",
            # unlike the other nodes, this name uses the satellite id as given (not upper-cased)
            name=f"extract_{cfg.sat_id}_features",
        )

        # NOTE(review): the date parameters are always mapped to the `jno_*` ones and
        # `params:tau` is mapped to the value of `self.tau` -- confirm both are intended.
        param_mapping = {
            "params:start_date": "params:jno_start_date",
            "params:end_date": "params:jno_end_date",
            "params:tau": self.tau,
        }

        return pipeline(
            [download_step, preprocess_step, process_step, extract_step],
            namespace=cfg.sat_id,
            parameters=param_mapping,
        )


## State data pipeline

In [None]:
def get_state_data(tstart=None, tend=None, raw_data=None, columns=None, **kwargs):
    """
    Get the state data with proper column names and types in RTN coordinates.

    Placeholder: does nothing and returns `None`.
    """
    return None


def processs_state_data(df: pl.DataFrame) -> pl.DataFrame:
    """
    Process the state data (second step of the state data pipeline).

    Placeholder: does nothing and returns `None`.
    """
    return None

In [None]:
def create_state_data_pipeline(sat_id, **kwargs) -> Pipeline:
    """
    Assemble the state data pipeline for one satellite

    Chains two nodes (get state data, process state data) and wraps them with
    `pipeline`, passing `sat_id` as the namespace. `kwargs` is accepted but not used.
    """
    node_get_state_data = node(
        get_state_data,
        # Node inputs must be dataset/parameter names (strings), so `None` cannot be
        # listed here: `raw_data` and `columns` are left out and keep their `None` defaults.
        inputs={
            "tstart": "params:start_date",
            "tend": "params:end_date",
        },
        outputs="inter_state_rtn_1h",
        name=f"get_{sat_id.upper()}_state_data",
    )

    node_processs_state_data = node(
        processs_state_data,
        inputs="inter_state_rtn_1h",
        outputs="primary_state_rtn_1h",
        name=f"process_{sat_id.upper()}_state_data",
    )

    # NOTE(review): the date parameters are always mapped to the `jno_*` ones,
    # whatever `sat_id` is -- confirm this is intended.
    return pipeline(
        [node_get_state_data, node_processs_state_data],
        namespace=sat_id,
        parameters={
            "params:start_date": "params:jno_start_date",
            "params:end_date": "params:jno_end_date",
        },
    )

## Candidate pipeline

In [None]:
def combine_features(df: pl.DataFrame, state: pl.DataFrame) -> pl.DataFrame:
    """
    Combine the extracted features with the state data.

    Placeholder: does nothing and returns `None`.
    """
    return None

def create_candidate_pipeline(sat_id, **kwargs) -> Pipeline:
    """
    Assemble the candidate pipeline: a single node feeding `combine_features`

    The dataset names are prefixed with `sat_id` by hand; no namespace is
    passed to `pipeline`. `kwargs` is accepted but not used.
    """
    ts = "1s"  # time resolution, fixed here

    # NOTE(review): this reads `feature_rtn_*` whereas `create_mag_data_pipeline`
    # writes `feature_tau_*` -- confirm where the feature dataset comes from.
    feature_ds = f"{sat_id}.feature_rtn_{ts}"
    state_ds = f"{sat_id}.primary_state_rtn_1h"

    combine_step = node(
        combine_features,
        inputs=[feature_ds, state_ds],
        outputs=f"candidates.{sat_id}_{ts}",
    )
    return pipeline([combine_step])

In [None]:
#| exports
def combine_candidates(dict):
    """
    Combine the candidates of several datasets.

    Placeholder: does nothing and returns `None`.
    """
    # The parameter name shadows the builtin `dict`; it is kept as-is so that
    # existing keyword callers keep working.
    return None

# node_thm_extract_features = node(
#     extract_features,
#     inputs=["primary_thm_rtn_1s", "params:tau", "params:thm_1s_params"],
#     outputs="candidates_thm_rtn_1s",
#     name="extract_ARTEMIS_features",
# )

# node_combine_candidates = node(
#     combine_candidates,
#     inputs=dict(
#         sta_candidates="candidates_sta_rtn_1s",
#         jno_candidates="candidates_jno_ss_se_1s",
#         thm_candidates="candidates_thm_rtn_1s",
#     ),
#     outputs="candidates_all_1s",
#     name="combine_candidates",
# )

In [None]:
def create_pipeline(**kwargs) -> Pipeline:
    """
    Assemble the default pipeline: magnetic field data + state data + candidates

    The satellite id is fixed to `"sta"`; `kwargs` is accepted but not used.
    """
    sat_id = "sta"
    return (
        create_mag_data_pipeline(sat_id)
        # `sat_id` is a required argument of `create_state_data_pipeline`
        + create_state_data_pipeline(sat_id)
        + create_candidate_pipeline(sat_id)
    )