In [None]:
# | default_exp ds.tplot.plot
# | export
from pydantic import (
    BaseModel,
    field_validator,
)
from space_analysis.ds.config import PanelConfig, Config

import pyspedas
from cdasws import CdasWs

from pytplot import tplot, options
import matplotlib.pyplot as plt
from loguru import logger

from matplotlib.pyplot import Figure, Axes

In [None]:
# | export
class TplotConfig(BaseModel):
    """Settings that identify a tplot variable and the transforms applied to it.

    Both fields are optional. They are annotated as ``... | None`` so that an
    explicit ``None`` (for example when re-validating a dumped model) passes
    validation instead of being rejected as "not a string"/"not a list".
    """

    # Name of the tplot variable; ``None`` when not set.
    tvar: str | None = None
    # Names of the transforms to apply; ``None`` when not set.
    trans: list[str] | None = None


class GraphicalConfig(BaseModel):
    """Graphical settings for a panel.

    ``ylabel`` is optional; the annotation includes ``None`` so that passing
    ``None`` explicitly validates just like omitting the field.
    """

    # Y-axis label; ``None`` when not set.
    ylabel: str | None = None

In [None]:
# | export
class ProcessConfig(BaseModel):
    """Processing settings: a tplot variable name and a list of transforms."""

    # Name of the tplot variable; ``None`` when not set.
    tvar: str | None = None
    # Transform specifications, normalised by ``check_transforms`` below.
    trans: list = list()

    @field_validator("trans", mode="before")
    @classmethod
    def check_transforms(cls, v: list[dict | str] | dict | str | None):
        """Normalise the raw ``trans`` input into a list of transform specs.

        Every bare string ``s`` becomes ``{"name": s}``; any other item is
        passed through unchanged.

        Args:
            v: Raw value supplied for ``trans``. May be a list of specs, a
                single string or dict (treated as a one-element list), or
                ``None`` (treated as "no transforms").

        Returns:
            A list in which every string entry has been wrapped in a dict.
        """
        if v is None:
            return []
        # A lone string/dict would otherwise be iterated character by
        # character (or key by key), yielding nonsense specs.
        if isinstance(v, (str, dict)):
            v = [v]
        return [({"name": tran} if isinstance(tran, str) else tran) for tran in v]

In [None]:
# | export
def update_panel(ax: Axes, config: PanelConfig):
    """Hook for applying per-panel settings to an axes; currently a no-op.

    Args:
        ax: The axes of the panel.
        config: The configuration of the panel.

    Returns:
        None.
    """
    return None


def plot(
    tvars2plot: list[str],
    config: Config,
    fig: Figure = None,
    axes: list[Axes] = None,
    **kwargs,
) -> tuple[Figure, list[Axes]]:
    """Draw tplot variables as stacked panels and apply the output settings.

    Args:
        tvars2plot: tplot variables to draw, one per panel. A value that is
            not a list is wrapped into a one-element list.
        config: Configuration providing ``panels`` (one entry per panel) and
            ``output`` (figure settings).
        fig: Existing figure to draw on. Only used when ``axes`` is given too.
        axes: Existing axes to draw on. Only used when ``fig`` is given too.
        **kwargs: Forwarded to ``plt.subplots`` when a new figure is created.

    Returns:
        The figure and the axes that were drawn on.
    """
    panel_configs, output_config = config.panels, config.output

    # Accept a single variable as shorthand for a one-panel plot.
    tvars2plot = tvars2plot if isinstance(tvars2plot, list) else [tvars2plot]

    # A fresh figure is created unless BOTH fig and axes were supplied.
    if fig is None or axes is None:
        fig, axes = plt.subplots(nrows=len(tvars2plot), sharex=True, **kwargs)
        # With a single row, subplots returns a bare Axes rather than a sequence.
        if isinstance(axes, Axes):
            axes = [axes]

    # zip stops at the shortest of the three sequences.
    panels = zip(axes, tvars2plot, panel_configs)
    for ax, tvar, panel_config in panels:
        tplot(tvar, fig=fig, axis=ax, display=False)
        update_panel(ax, panel_config)

    fig.set(**output_config.figure)
    output_config.figure_extra.process(fig, axes)

    return fig, axes


def export(tvars2plot: list, config: Config, plot_kwargs: dict = None, **kwargs):
    fig, axes = plot(tvars2plot, config, **plot_kwargs)

    output_config = config.output
    path = output_config.path

    for fmt in output_config.formats:
        match fmt:
            case "csv":
                export2csv(tvars2plot, path)
            case "display":
                fig.show()
            case _:
                fig.savefig(f"{path}.{fmt}", **kwargs)

    return fig, axes

In [None]:
# | export
def process(tvar: str | list[str], config: ProcessConfig):
    """Apply every transform in ``config.trans`` to ``tvar``, in order.

    Each transform is called as ``tran.transform_func(tvar, **params)`` and
    its return value becomes the input of the next transform.

    Args:
        tvar: The tplot variable (or variables) to start from.
        config: Processing configuration whose ``trans`` holds the transforms.

    Returns:
        The result of the last transform, or ``tvar`` unchanged when there
        are no transforms.
    """
    # These two fields are left out of the dump so that only the remaining
    # fields are forwarded as keyword arguments. A set is used because
    # ``model_dump`` documents ``exclude`` as a set or dict, not a list.
    non_param_fields = {"name", "transform_func"}

    for tran in config.trans:
        params = tran.model_dump(exclude=non_param_fields)
        tvar = tran.transform_func(tvar, **params)
    return tvar

    # if config.trans:
    #     if "slice-1" in config.trans:
    #         tvar = split_vec(tvar)[:1]
    #     if "slice-3" in config.trans:
    #         tvar = join_vec(split_vec(tvar)[:3])
    #     if "mva" in config.trans:
    #         minvar_matrix_make(tvar)
    #         tvar = tvector_rotate(tvar + "_mva_mat", tvar)[0]
    #         legend_names = [r"$B_l$", r"$B_m$", r"$B_n$"]

    #     if "magnitude" in trans:
    #         tvar2plot = tvectot(tvar, join_component=False)
    #         options(tvar2plot, "legend_names", None)

    #     if "magnitude_join" in trans:
    #         tvar2plot = tvectot(tvar, join_component=True)
    #         legend_names = legend_names + [r"$B_{total}$"]
    #         options(tvar2plot, "legend_names", legend_names)

    # else:
    #     tvar2plot = tvar

In [None]:
# | export
def load_data(config: PanelConfig, load_func=None):
    """Load the data described by a panel configuration.

    When ``load_func`` is not given it is resolved from the configuration:

    * ``config.satellite`` and ``config.instrument`` set: the loader is
      ``pyspedas.<satellite>.<instrument>``;
    * otherwise ``config.ds`` set: the variable is fetched with ``CdasWs``
      and stored as a tplot variable named ``config.id``;
    * otherwise an error is logged and ``None`` is returned.

    Args:
        config: A panel configuration, or a list of them (each is loaded in
            turn with the same ``load_func``).
        load_func: Optional loader called with ``trange``, ``time_clip``,
            ``varnames`` and, when configured, ``datatype`` and ``probe``.

    Returns:
        Whatever the loader returns; the variable name for the CDAS path;
        a list of such results for a list of configurations; or ``None``
        when nothing could be loaded.
    """
    if isinstance(config, list):
        # Propagate the caller's loader instead of silently dropping it.
        return [load_data(c, load_func=load_func) for c in config]

    timerange = [t.isoformat() for t in config.timerange]
    var = config.id

    if load_func is None:
        if config.satellite and config.instrument:
            mod = getattr(pyspedas, config.satellite)
            load_func = getattr(mod, config.instrument)
        elif config.ds:
            # Only ``tplot`` and ``options`` are imported at module level, so
            # the bare name ``pytplot`` is not bound; import what is needed.
            from pytplot import store_data

            cdas = CdasWs()
            status, data = cdas.get_data(config.ds, var, timerange[0], timerange[1])
            if data is None:
                logger.error(f"No data returned for {config.ds}/{var}: {status}")
                return None
            store_data(var, {"x": data[var].Epoch, "y": data[var]})
            return var
        else:
            logger.error("No load function provided")
            return None

    load_args = {
        "trange": timerange,
        "time_clip": True,
        "varnames": config.id,
    }

    # Conditionally add the 'datatype' and 'probe' argument
    if config.datatype is not None:
        load_args["datatype"] = config.datatype
    if config.probe is not None:
        load_args["probe"] = config.probe

    return load_func(**load_args)


In [None]:
#| export
def update_tvar(tvar, config: PanelConfig):
    """Apply display options from a panel configuration to a tplot variable.

    Sets the line thickness to 2, the y-title to ``config.name`` (when set)
    and the y-subtitle to ``[<units>]`` (when ``config.units`` is set). An
    empty ``units`` string clears the subtitle instead of rendering ``[]``.

    Args:
        tvar: The tplot variable to update.
        config: Panel configuration providing ``name`` and ``units``.

    Returns:
        ``tvar``, unchanged, for chaining.
    """
    options(tvar, "thick", 2)
    # options(tvar, "char_size", 16)
    if config.name is not None:
        options(tvar, "ytitle", f"{config.name}")
    if config.units is not None:
        # The empty-units case must not fall through to the bracketed form.
        ysubtitle = "" if config.units == "" else f"[{config.units}]"
        options(tvar, "ysubtitle", ysubtitle)

    return tvar


def process_panel(
    config: PanelConfig,
    process_func=process,
    load_func=load_data,
    update_func=update_tvar,
):
    """Run the load → process → update pipeline for one panel.

    Args:
        config: The panel configuration; ``config.process`` is handed to
            ``process_func``.
        process_func: Called as ``process_func(tvar, config=config.process)``.
        load_func: Called as ``load_func(config)``.
        update_func: Called as ``update_func(tvar_processed, config=config)``.

    Returns:
        Whatever ``update_func`` returns.
    """
    loaded = load_func(config)
    # When the loader returns a list, only its first entry is used.
    tvar = loaded[0] if isinstance(loaded, list) else loaded

    tvar_processed = process_func(tvar, config=config.process)
    logger.debug(f"Processed tvar: {tvar_processed}")

    return update_func(tvar_processed, config=config)