---
title: Models
subtitle: Validation of MHD data against observations for JUNO mission
format:
  html:
    code-fold: true
---

## Setup

Run the following command in a shell first, since `pipeline` is a project-specific Kedro command:

```{sh}
kedro pipeline create model
```

`kedro run --to-outputs=jno.primary_state_rtn_1h`

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

In [None]:
from loguru import logger
from typing import Union, Collection, Callable, Optional, Tuple
from typing import Any, Dict


import numpy as np
import polars as pl
import pandas

#### `Kedro`

In [None]:
from kedro.pipeline import Pipeline, node
from kedro.pipeline.modular_pipeline import pipeline


In [None]:
from ids_finder.utils.basic import load_catalog

In [None]:
catalog = load_catalog()
catalog.list()

## Process model data

In [None]:
def overview(df: pl.DataFrame):
    """Build a stacked overview of the model data as four line plots.

    The frame is converted to pandas and plotted against ``time``:
    the columns listed in ``bcols_rtn``, the columns listed in ``vcols_rtn``,
    ``rho`` (log y-axis) and ``Ti`` (log y-axis).

    Args:
        df: Frame containing ``time``, ``rho``, ``Ti`` and the columns named
            in the module-level ``bcols_rtn`` / ``vcols_rtn`` lists.

    Returns:
        The four plots combined with ``+``, laid out in a single column with
        independent (non-shared) axes.

    NOTE(review): ``bcols_rtn`` / ``vcols_rtn`` and the ``.hvplot`` pandas
    accessor are not set up in the visible part of this notebook — confirm
    they are defined/registered before this function is called.
    """
    df_pd = df.to_pandas()

    # One panel per quantity; density and temperature use a log scale.
    b_fig = df_pd.hvplot.line(x="time", y=bcols_rtn)
    v_fig = df_pd.hvplot.line(x="time", y=vcols_rtn)
    rho_fig = df_pd.hvplot.line(x="time", y="rho", logy=True)
    Ti_fig = df_pd.hvplot.line(x="time", y="Ti", logy=True)
    return (b_fig + v_fig + rho_fig + Ti_fig).cols(1).opts(shared_axes=False)


# jno_mswim2d_rtn.pipe(overview)

In [None]:
catalog.load('jno.primary_state_rtn_1h')

radial_distance,plasma_density,plasma_temperature,time,model_b_r,model_b_t,model_b_n,sw_vel_r,sw_vel_t,sw_vel_n
f64,f64,i64,datetime[ns],f64,f64,f64,f64,f64,f64
1.004,2.77562,76394,2011-08-25 00:00:00,-1.214485,2.724603,0.578,442.731905,1.438722,-0.0
1.003,2.82988,76600,2011-08-25 01:00:00,-1.101071,2.656081,0.523,440.550352,-6.464314,-0.0
1.003,2.94476,74014,2011-08-25 02:00:00,-0.721961,2.640019,0.949,441.072265,-7.503789,-0.0
1.004,2.89843,70426,2011-08-25 03:00:00,-0.69794,2.757097,1.18,439.112842,-4.326845,0.0
1.004,2.71252,68131,2011-08-25 04:00:00,-0.902173,3.002557,1.08,437.440827,0.11421,-0.0
1.004,2.89406,73654,2011-08-25 05:00:00,-0.715813,2.763938,0.807,437.643073,1.68236,-0.1
1.004,3.1065,80005,2011-08-25 06:00:00,-0.824154,2.315087,0.336,435.357101,-5.304179,-0.1
1.003,3.29236,78831,2011-08-25 07:00:00,-0.680414,2.02755,0.576,434.080668,-11.782328,0.0
1.004,3.45049,76311,2011-08-25 08:00:00,-0.418169,2.058053,1.08,434.591819,-12.90159,0.0
1.004,3.43728,75513,2011-08-25 09:00:00,-0.407522,2.487763,1.27,433.637723,-8.415747,0.0


## Compare JUNO data with model

We use JUNO 1-minute data to compare with the model data.

In [None]:
from ids_finder.utils.basic import resample
from ids_finder.utils.polars import pl_norm

In [None]:
from ids_finder.pipelines.juno.pipeline import download_juno_data, preprocess_jno

In [None]:
def create_jno_data_pipeline(**kwargs) -> Pipeline:
    """Assemble the two-step JUNO 1-minute data pipeline in the ``model`` namespace.

    The first node runs ``download_juno_data`` (no inputs) and produces
    ``raw_jno_ss_se_1min``; the second feeds that dataset to ``preprocess_jno``
    and produces ``preprocessed_jno_ss_se_1min``.

    Args:
        **kwargs: Accepted but not used by this function.

    Returns:
        The pipeline built from both nodes with ``namespace="model"``.
    """
    download_step = node(
        download_juno_data,
        inputs=None,
        outputs="raw_jno_ss_se_1min",
        name="download_JUNO_data_1min",
    )
    preprocess_step = node(
        preprocess_jno,
        inputs="raw_jno_ss_se_1min",
        outputs="preprocessed_jno_ss_se_1min",
        name="preprocess_JUNO_node_1min",
    )
    return pipeline([download_step, preprocess_step], namespace="model")

In [None]:
#| code-summary: load jno data and resample to 1h to match model resolution
# Load the preprocessed 1-minute JUNO dataset from the catalog, switch to a
# lazy frame and rename the three "B* SE" columns to br/bt/bn, the names the
# rest of this notebook selects.
preprocessed_jno_ss_se_1min: pl.DataFrame = catalog.load('preprocessed_jno_ss_se_1min')
jno_ss_se_1min = preprocessed_jno_ss_se_1min.lazy().rename(
    {"BX SE": "br", "BY SE": "bt", "BZ SE": "bn"}
)

# NOTE(review): `every="1h"` with `period='2h'` presumably yields 2-hour
# windows stepped every hour (i.e. overlapping) — confirm against `resample`.
jno_ss_se_1h = jno_ss_se_1min.pipe(resample, every="1h", period='2h')

In [None]:
import plotly.graph_objects as go;
from plotly_resampler import register_plotly_resampler
from plotly_resampler import FigureResampler
import plotly.express as px

In [None]:
# NOTE(review): `processed_jno_mswim2d` is not defined in the visible part of
# this notebook — it must be created/loaded before this cell can run.
jno_mswim2d_1h = processed_jno_mswim2d.lazy()

In [None]:
def _tf(df: pl.DataFrame):
    """Temporary helper: keep the interesting columns and add the norm as ``b``.

    Selects ``time``, ``br``, ``bt`` and ``bn`` and appends a column ``b``
    computed by ``pl_norm(bcols_rtn)``.

    NOTE(review): ``bcols_rtn`` is a module-level name that is not defined in
    the visible part of this notebook; presumably it lists the br/bt/bn
    columns selected here — confirm. The callers below pass lazy frames even
    though the annotation says ``pl.DataFrame``.
    """
    selected = df.select(["time", "br", "bt", "bn"])
    return selected.with_columns(b=pl_norm(bcols_rtn))


# Wide form: observations joined with the model on `time`; model columns
# whose names clash with the observation columns get the `_model` suffix.
jno_joint_1h_wide: pl.DataFrame = (
    jno_ss_se_1h.pipe(_tf).join(
        jno_mswim2d_1h.pipe(_tf),
        on="time",
        suffix="_model",
    )
).collect()

# Long form: both datasets concatenated, with a `type` column that tags
# observation rows as "1h" and model rows as "1h_model".
jno_joint_1h_long = pl.concat(
    [
        jno_ss_se_1h.pipe(_tf).with_columns(type=pl.lit("1h")),
        jno_mswim2d_1h.pipe(_tf).with_columns(type=pl.lit("1h_model")),
    ]
).collect()

#### Data Profiling

Results are shown in the following links:

[Timeseries Report Result](jno_model_ts.html)

[Comparison Report Result](jno_model_comparison.html)

In [None]:
#| eval: false
from ydata_profiling import ProfileReport, compare

In [None]:
from fastcore.utils import threaded

In [None]:
@threaded
def get_report_t(df: pl.DataFrame, output, **kwargs):
    """Build the profile report for ``df`` and write it to ``output`` in a thread.

    Args:
        df: Frame to profile; forwarded to ``get_report``.
        output: Destination passed to the report's ``to_file``.
        **kwargs: Forwarded unchanged to ``get_report``.

    Returns:
        ``output``, once the file has been written.
    """
    report = get_report(df, **kwargs)
    report.to_file(output)
    return output

def get_report(df: pl.DataFrame, **kwargs):
    """Create a ``ProfileReport`` for ``df``, indexed by its ``time`` column.

    Args:
        df: Frame with a ``time`` column; it is converted to pandas and
            ``time`` becomes the index.
        **kwargs: Forwarded unchanged to ``ProfileReport``.

    Returns:
        The ``ProfileReport`` instance.
    """
    time_indexed = df.to_pandas().set_index("time")
    return ProfileReport(time_indexed, **kwargs)

def get_comparison_report(df: pl.DataFrame, compare_col=None, tsmode=False, **kwargs):
    """Profile each group of ``df`` separately and merge the profiles into one report.

    ``df`` is partitioned by ``compare_col``; every partition is profiled with
    ``get_report`` (titled with its partition key) and the resulting reports
    are combined with ``ydata_profiling``'s ``compare``.

    Args:
        df: Frame to profile. It needs a ``time`` column (used as the index by
            ``get_report``) and the ``compare_col`` column.
        compare_col: Column whose values define the groups to compare.
            Required, despite the ``None`` default.
        tsmode: Time-series mode. Not supported for comparisons yet; a truthy
            value raises ``NotImplementedError``.
        **kwargs: Forwarded to ``get_report`` for every group.

    Returns:
        The comparison report returned by ``compare``.

    Raises:
        ValueError: If ``compare_col`` is not given.
        NotImplementedError: If ``tsmode`` is truthy.
    """
    if compare_col is None:
        raise ValueError("`compare_col` is required to split the data into groups to compare")

    if tsmode:
        # Fail fast, before any partitioning or profiling work is done.
        # Upstream currently fails with:
        #   UnionMatchError: can not match type "list" to any type of
        #   "time_index_analysis.period" union
        # NOTE: once this is supported, the groups must first be restricted to
        # their common timestamps so that the time indexes match.
        raise NotImplementedError("tsmode for comparison is not implemented yet in `ydata_profiling`")

    groups: Dict[str, pl.DataFrame] = df.partition_by(compare_col, as_dict=True)

    comparison_report = compare(
        [get_report(group, title=key, **kwargs) for key, group in groups.items()]
    )

    # Obtain merged statistics
    comparison_report.get_description()

    return comparison_report

In [None]:
# `get_report_t` is wrapped in `threaded`, so the report is written to
# `jno_model_ts.html` from a thread; `tsmode` and `title` are forwarded
# through `get_report` to `ProfileReport`.
get_report_t(
    jno_joint_1h_wide,
    output="jno_model_ts.html",
    tsmode=True,
    title="JUNO Model Timeseries Report",
)

In [None]:
# Compare the groups of the long frame's `type` column and save the merged
# report next to the notebook.
get_comparison_report(jno_joint_1h_long, compare_col="type").to_file(
    "jno_model_comparison.html"
)

## Validation

#### Connect `python` with `R` kernel

In [None]:
%load_ext rpy2.ipython

from ids_finder.utils.r import py2rpy_polars
conv_pl = py2rpy_polars()

In [None]:
%%R
library(ggplot2)
library(ggpubr)
library(viridis)

library(glue)
library(arrow)

In [None]:
import warnings

In [None]:
#| column: screen
# Line plot of `b` over time, one trace per `type`; the frame is sorted by
# `time` before plotting.
fig = px.line(
    jno_joint_1h_long.sort("time"),
    x="time",
    y="b",
    color="type",
)
fig

### Compare directly with scatter plot

In [None]:
def bb_jointplot(data):
    """Draw a histogram-style joint plot of ``b`` (x) against ``b_model`` (y).

    Args:
        data: Table holding the ``b`` and ``b_model`` columns.

    Returns:
        The object returned by ``sns.jointplot``.

    NOTE(review): ``sns`` is presumably seaborn, but no import for it is
    visible in this notebook — confirm it is imported before this is called.
    """
    return sns.jointplot(data=data, x="b", y="b_model", kind="hist")

In [None]:
# Draw the joint histogram, then set the marginal axes' limits to 0-5.
g = bb_jointplot(jno_joint_1h_wide)
g.ax_marg_x.set_xlim(0, 5)
g.ax_marg_y.set_ylim(0, 5)

In [None]:
%%R -i jno_joint_1h_wide -c conv_pl
# 2D-binned comparison of `b` (x) against `b_model` (y) with white density
# contours on top; the bin fill uses a log-transformed viridis scale.
p1 <- ggplot(jno_joint_1h_wide, aes(x=b, y=b_model) ) +
  geom_bin2d() +
  geom_density_2d( colour="white" ) +
  scale_fill_continuous(trans="log", type = "viridis") +
  stat_regline_equation() + 
  xlim(-0.1, 10) +  # Set x-axis limits
  ylim(-0.1, 10) +  # Set y-axis limits
  theme_pubr(legend = 'right')
  # theme(legend.position = c(0.8,0.8))
  # stat_density_2d(aes(fill = ..level..), geom = "polygon", colour="white")

p1

#### Test: remove outliers

In [None]:
#| eval: false
from pyod.models.ecod import ECOD


In [None]:
# Fit an ECOD detector on the `b` / `b_model` column pair of the wide frame.
data = jno_joint_1h_wide[['b', 'b_model']]

clf = ECOD()
clf.fit(data)

In [None]:
y_train_scores = clf.decision_scores_  # raw outlier scores on the train data
y_train_pred = clf.labels_  # binary labels (0: inliers, 1: outliers)

In [None]:
# Rows labelled 1 (outliers) by the fitted detector.
bb_jointplot(jno_joint_1h_wide.filter(y_train_pred == 1))

# Rows labelled 0 (inliers): drawn through the same helper so both plots stay
# configured identically, then overlaid with KDE contours.
data = jno_joint_1h_wide.filter(y_train_pred == 0)
g = bb_jointplot(data)

g.plot_joint(sns.kdeplot, color="r", zorder=0, levels=6)


In [None]:
# def create_pipeline(**kwargs) -> Pipeline:
    # return create_jno_model_pipeline(**kwargs) + create_jno_data_pipeline(**kwargs)