In [1]:
# | default_exp utils/speasy
# | export
import speasy as spz
import polars as pl

from pydantic import model_validator
from functools import cached_property
from space_analysis.core import Variables as Vs
from space_analysis.core import Variable as V

from speasy.core.dataprovider import DataProvider
from speasy import SpeasyVariable
from speasy.core.inventory import DatasetIndex, ParameterIndex

from fastcore.all import patch

import matplotlib.pyplot as plt
from matplotlib.pyplot import Axes

In [2]:
# | export
def spzvar2pldf(var: SpeasyVariable):
    """Convert a single `SpeasyVariable` into a polars `LazyFrame`.

    Fill values are replaced by NaN first. The resulting frame has one column
    per entry of `var.columns` plus a `time` column built from `var.time`.
    Modelled on `SpeasyVariable.to_dataframe`.
    """
    cleaned = var.replace_fillval_by_nan()
    values_frame = pl.DataFrame(cleaned.values, schema=cleaned.columns)
    with_time = values_frame.with_columns(time=pl.Series(cleaned.time))
    # Switch to lazy mode only after `time` is attached; doing it earlier gave
    # "ShapeError: unable to add a column of length xxxx to a DataFrame of height yyyy".
    return with_time.lazy()


def spzvars2pldf(vars: list[SpeasyVariable]):
    """Combine several `SpeasyVariable`s into one polars `LazyFrame`.

    A single-element list is converted directly. Otherwise every variable is
    converted with `spzvar2pldf` and the frames are merged with
    `pl.concat(..., how="align")`, i.e. lined up on the columns they share
    (intended to be the `time` column).
    """
    if len(vars) != 1:
        lazy_frames = [spzvar2pldf(item) for item in vars]
        return pl.concat(lazy_frames, how="align")
    return spzvar2pldf(vars[0])

In [3]:
# | export
@patch
def time_resolutions(self: SpeasyVariable):
    """Summarise the spacing between consecutive time stamps.

    Patched onto `SpeasyVariable`. Returns the polars `describe()` table of
    the successive differences of `self.time`.
    """
    timestamps = pl.Series(self.time)
    steps = timestamps.diff()
    return steps.describe()

In [4]:
# | export
def get_provider(v: str) -> DataProvider:
    """Map a provider short name to the matching speasy data provider.

    Parameters
    ----------
    v : str
        Provider name: "cda", "csa" or "amda".

    Returns
    -------
    DataProvider
        `spz.cda` for "cda", `spz.csa` for "csa", and `spz.amda` for
        everything else.
    """
    if v == "cda":
        return spz.cda
    # "csa" used to fall through to AMDA silently, so CSA datasets were looked
    # up in the wrong inventory.
    if v == "csa":
        return spz.csa
    # Historical behaviour kept for backward compatibility: any other name
    # (including "amda" itself) resolves to AMDA.
    return spz.amda


def get_dataset_index(v: str, provider: str = "cda") -> DatasetIndex:
    """Look up dataset `v` in the flat inventory of the named provider.

    `provider` is resolved through `get_provider`; a dataset name missing from
    that inventory propagates whatever lookup error `datasets[v]` raises.
    """
    data_provider = get_provider(provider)
    datasets = data_provider.flat_inventory.datasets
    return datasets[v]


def get_dataset_parameters(v: str, provider: str = "cda") -> list[ParameterIndex]:
    """Return every `ParameterIndex` attached to dataset `v`.

    Parameters
    ----------
    v : str
        Dataset name as it appears in the provider's flat inventory.
    provider : str
        Provider short name, resolved by `get_dataset_index`.

    Returns
    -------
    list[ParameterIndex]
        The dataset attributes that are `ParameterIndex` instances; other
        attributes of the dataset index are skipped.
    """
    # Read `__dict__` directly instead of calling `vars(...)`: a later cell of
    # this notebook rebinds the global name `vars`, which would hide the
    # builtin from this function in a live kernel.
    members = get_dataset_index(v, provider).__dict__
    return [member for member in members.values() if isinstance(member, ParameterIndex)]


def get_parameter_index(param: str, ds: str, provider: str = "cda") -> ParameterIndex:
    """Return the inventory entry stored under attribute `param` of dataset `ds`.

    Parameters
    ----------
    param : str
        Attribute name of the parameter on the dataset index.
    ds : str
        Dataset name in the provider's flat inventory.
    provider : str
        Provider short name. Defaults to "cda", which is what this function
        always used before the argument existed.

    Raises
    ------
    KeyError
        If the dataset index has no attribute named `param`.
    """
    # `__dict__` rather than `vars(...)`: the global name `vars` is rebound by a
    # later cell of this notebook, which would break the builtin call here.
    members = get_dataset_index(ds, provider).__dict__
    return members[param]

In [5]:
# | export
class Variable(V):
    """A `space_analysis` variable whose data is fetched through speasy."""

    def to_polars(self):
        """Return this variable's data converted by `spzvar2pldf`."""
        return spzvar2pldf(self.data)

    @cached_property
    def data(self) -> SpeasyVariable:
        """Fetch `self.product` over `self.timerange` with `spz.get_data`.

        The result is cached on first access, so changing `timerange`
        afterwards does not trigger a new request.
        """
        return spz.get_data(self.product, self.timerange)

    @property
    def time_resolutions(self) -> pl.DataFrame:
        """Summary of the time-step sizes of `self.data`."""
        return self.data.time_resolutions()

    def plot(self, fig=None, ax: Axes = None):
        """Plot the variable (fill values replaced by NaN) on a matplotlib axes.

        Parameters
        ----------
        fig : optional
            Figure to draw into when `ax` is not given.
        ax : Axes, optional
            Axes to draw on. When omitted, the current axes of `fig` are used,
            or a new figure/axes pair is created if `fig` is omitted too.

        Returns
        -------
        tuple
            `(fig, ax)` actually used for drawing.
        """
        if ax is None:
            if fig is None:
                fig, ax = plt.subplots()
            else:
                # Previously `ax` stayed None here and `ax.set_ylabel` crashed.
                ax = fig.gca()
        elif fig is None:
            # Previously None was returned as the figure when only `ax` was given.
            fig = ax.figure

        self.data.replace_fillval_by_nan().plot(ax=ax)

        if self.name:
            ax.set_ylabel(self.name)

        return fig, ax

In [6]:
# | export
class Variables(Vs):
    """A group of speasy-backed `Variable`s that share one time range."""

    # Per-variable wrappers; built from `products` when not supplied.
    variables: list[Variable] = None
    # Speasy products, either path strings or inventory `ParameterIndex` entries.
    products: list[str | ParameterIndex] = None

    # Initialise `products` from the provider and dataset if not provided.
    @model_validator(mode="after")
    def check_products(self):
        """Fill in `products` (and, if missing, `parameters`) from the dataset."""
        if self.products is None and self.dataset:
            if self.parameters:
                self.products = [
                    f"{self.provider}/{self.dataset}/{var}" for var in self.parameters
                ]

            else:
                # No explicit parameters: take every parameter of the dataset.
                self.products = get_dataset_parameters(self.dataset, self.provider)
                self.parameters = [member.spz_name() for member in self.products]
        return self

    @model_validator(mode="after")
    def check_variables(self):
        """Build `variables` from `products` and give them all `self.timerange`."""
        if self.variables is None:
            if self.products is None:
                # Without this guard the comprehension below fails with an
                # opaque "'NoneType' object is not iterable".
                raise ValueError(
                    "Variables needs `variables`, `products` or a `dataset` to build from"
                )
            self.variables = [Variable(product=product) for product in self.products]
        # set the same timerange for all variables
        for var in self.variables:
            var.timerange = self.timerange
        return self

    @property
    def data(self) -> list[SpeasyVariable]:
        """Data of every variable, in the order of `self.variables`."""
        return [var.data for var in self.variables]

    def to_polars(self):
        """Return all variables' data combined by `spzvars2pldf`."""
        return spzvars2pldf(self.data)

    def plot(self, gridspec_kw: dict = None):
        """Plot each variable in its own row of a shared-x figure.

        Parameters
        ----------
        gridspec_kw : dict, optional
            Passed to `plt.subplots`; defaults to `{"hspace": 0}`.

        Returns
        -------
        tuple
            `(fig, axes)`, with `axes` holding one axes per variable.
        """
        if gridspec_kw is None:
            # Fresh dict per call instead of a shared mutable default argument.
            gridspec_kw = {"hspace": 0}
        variables = self.variables

        fig, axes = plt.subplots(
            nrows=len(variables), sharex=True, gridspec_kw=gridspec_kw
        )
        # With a single row `plt.subplots` returns a bare Axes; wrap it so the
        # loop and the return value always deal with a sequence.
        axes: list[Axes] = axes if len(variables) > 1 else [axes]

        for var, ax in zip(variables, axes):
            var.plot(ax=ax)

        return fig, axes

    @property
    def time_resolutions(self) -> pl.DataFrame:
        """Summary of the time-step sizes of the first variable's data."""
        # NOTE(review): `get_data` is not defined in this class; presumably
        # inherited from `Vs` and returning the list of SpeasyVariables - confirm.
        return self.get_data()[0].time_resolutions()

In [8]:
from rich import print

# Show the model built from a single explicit variable spec (no dataset given).
print(
    Variables(
        variables=[
            dict(
                product="cda/PSP_FLD_L3_RFS_LFR_QTN/N_elec",
                name="RFS LFR QTN\nElectron Density\n[$cm^{{-3}}$]",
            )
        ]
    )
)

In [None]:
# Time window used for the example request below.
earth_start, earth_end = "2019-04-09", "2019-04-12"
timerange = [earth_start, earth_end]

# Fetch parameter `MOM.P.MAGF` of dataset `WI_PLSP_3DP` over that window.
d = spz.get_data(get_parameter_index("MOM.P.MAGF", "WI_PLSP_3DP"), timerange)

In [6]:
# | hide
from nbdev import nbdev_export

# Run nbdev's export step (presumably writes the `# | export` cells to the
# module named by `default_exp` at the top of the notebook - see nbdev docs).
nbdev_export()

### Test


In [8]:
# Smoke test: build a one-parameter `Variables` for an 11-hour window and
# retrieve its data (`retrieve_data` is not defined in this notebook;
# presumably inherited from the base class - confirm).
timerange = ["2019-04-07T01:00", "2019-04-07T12:00"]
# NOTE(review): binding the result to `vars` shadows the builtin `vars()` for
# the rest of the kernel session; code in this namespace that calls the builtin
# afterwards will break. Consider renaming it, together with the cell below
# that reads `vars`.
vars = Variables(
    dataset="PSP_FLD_L2_MAG_RTN",
    parameters=["psp_fld_l2_mag_RTN"],
    timerange=timerange,
).retrieve_data()

In [12]:
# Compare the per-variable method with the collection-level property; the
# recorded output below shows two identical summary tables.
vars.data[0].time_resolutions(), vars.time_resolutions

(shape: (8, 2)
 ┌────────────┬────────────────┐
 │ statistic  ┆ value          │
 │ ---        ┆ ---            │
 │ str        ┆ str            │
 ╞════════════╪════════════════╡
 │ count      ┆ 5800765        │
 │ null_count ┆ 1              │
 │ mean       ┆ 0:00:00.006826 │
 │ min        ┆ 0:00:00.006690 │
 │ 25%        ┆ 0:00:00.006826 │
 │ 50%        ┆ 0:00:00.006826 │
 │ 75%        ┆ 0:00:00.006826 │
 │ max        ┆ 0:00:00.006935 │
 └────────────┴────────────────┘,
 shape: (8, 2)
 ┌────────────┬────────────────┐
 │ statistic  ┆ value          │
 │ ---        ┆ ---            │
 │ str        ┆ str            │
 ╞════════════╪════════════════╡
 │ count      ┆ 5800765        │
 │ null_count ┆ 1              │
 │ mean       ┆ 0:00:00.006826 │
 │ min        ┆ 0:00:00.006690 │
 │ 25%        ┆ 0:00:00.006826 │
 │ 50%        ┆ 0:00:00.006826 │
 │ 75%        ┆ 0:00:00.006826 │
 │ max        ┆ 0:00:00.006935 │
 └────────────┴────────────────┘)

In [53]:
def data_provider_summary(data_provider: DataProvider = spz.cda):
    """Print a provider's name and the sizes of its flat inventory.

    Reports how many datasets, parameters and catalogs the provider's
    `flat_inventory` holds, one line each.
    """
    flat = data_provider.flat_inventory
    print("Data Provider:", data_provider.provider_name)
    # Each count is looked up right before it is printed, in the listed order.
    for label, attr in (
        ("Datasets:", "datasets"),
        ("Parameters:", "parameters"),
        ("Catalogs:", "catalogs"),
    ):
        print(label, len(getattr(flat, attr)))


# data_provider_summary(spz.cda)
# data_provider_summary(spz.amda)
# data_provider_summary(spz.csa)

Data Provider: cda
Datasets: 2608
Parameters: 58510
Catalogs: 0
Data Provider: amda
Datasets: 1074
Parameters: 5397
Catalogs: 24
Data Provider: csa
Datasets: 912
Parameters: 1993
Catalogs: 0


In [None]:
from fastcore.utils import patch
from speasy.products import SpeasyVariable
from humanize import naturalsize

In [None]:
@patch
def preview(self: SpeasyVariable):
    """Print a short overview of this variable.

    Patched onto `SpeasyVariable`. Shows name, columns, unit, memory usage
    (formatted with `naturalsize`), axes labels, meta-data, and the first
    three entries of the time axis and of the values.
    """
    frame_rule = "==========================================="
    section_rule = "-------------------------------------------"

    print(frame_rule)
    print(f"Name:         {self.name}")
    print(f"Columns:      {self.columns}")
    print(f"Values Unit:  {self.unit}")
    print(f"Memory usage: {naturalsize(self.nbytes)}")
    print(f"Axes Labels:  {self.axes_labels}")
    print(section_rule)
    print(f"Meta-data:    {self.meta}")
    print(section_rule)
    print(f"Time Axis:    {self.time[:3]}")
    print(section_rule)
    print(f"Values:       {self.values[:3]}")
    print(frame_rule)