---
title: Speasy utils
---

[speasy](https://speasy.readthedocs.io/en/latest/) is an open-source Python client for space physics web services such as CDAWeb or AMDA.

In [None]:
# | default_exp utils/speasy
# | export
import speasy as spz
import xarray as xr
import polars as pl

from fastcore.all import patch
from pydantic import model_validator, ConfigDict
from pydantic.dataclasses import dataclass
from functools import cached_property
from space_analysis.core import Variables as Vs
from space_analysis.core import Variable as V

from speasy.core.dataprovider import DataProvider
from speasy import SpeasyVariable
from speasy.core.inventory import DatasetIndex, ParameterIndex

import matplotlib.pyplot as plt
from matplotlib.pyplot import Axes

from humanize import naturalsize

## Converting to other data structures

In [None]:
# | export
def is_scalar(v: SpeasyVariable):
    """
    Tell whether `v` is a scalar time series, i.e. its shape is 2-D with exactly one column.

    Related issue: [Scalar timeseries dimension · SciQLop/speasy](https://github.com/SciQLop/speasy/issues/149)
    """
    shape = v.shape
    if len(shape) != 2:
        return False
    return shape[1] == 1


def to_dataarray(v: SpeasyVariable):
    """
    Convert a `SpeasyVariable` into an `xarray.DataArray` indexed by time.

    Scalar variables (see `is_scalar`) become 1-D arrays over `time`; every
    other variable keeps its values unchanged and gets `v.columns` as the
    second coordinate. The variable's metadata, plus its unit stored under the
    `units` key, is attached as the array attributes.

    Notes: scalar timeseries of `ndim==2` is a design choice to be consistent with what Pandas does.
    """
    time = xr.DataArray(v.time, dims="time")
    attrs = dict(v.meta, units=v.unit)
    if is_scalar(v):
        # Drop only the column axis. A bare `squeeze()` would also remove the
        # time axis of a single-sample series, leaving a 0-d array that no
        # longer matches the length-1 `time` coordinate.
        values, coords = v.values.squeeze(axis=1), [time]
    else:
        values, coords = v.values, [time, v.columns]
    return xr.DataArray(values, coords=coords, name=v.name, attrs=attrs)


def to_dataarrays(vs: list[SpeasyVariable]):
    """Convert each variable in `vs` with `to_dataarray`, keeping the input order."""
    return list(map(to_dataarray, vs))

In [None]:
# | export
def spzvar2pldf(var: SpeasyVariable):
    """
    Build a polars `LazyFrame` from a speasy variable.

    Fill values are replaced by NaN first; the frame has one column per entry
    of `var.columns` plus a `time` column.
    """
    # see SpeasyVariable.to_dataframe
    cleaned = var.replace_fillval_by_nan()
    frame = pl.DataFrame(cleaned.values, schema=cleaned.columns)
    frame = frame.with_columns(time=pl.Series(cleaned.time))
    # Need to `lazy` last or ShapeError: unable to add a column of length xxxx to a DataFrame of height yyyy
    return frame.lazy()


def spzvars2pldf(vars: list[SpeasyVariable]):
    """
    Convert several speasy variables into a single polars `LazyFrame`.

    A one-element list is converted directly; otherwise the per-variable
    frames are combined with `pl.concat(..., how="align")`.
    """
    if len(vars) != 1:
        # join all dataframes into a single one on the time column
        frames = [spzvar2pldf(item) for item in vars]
        return pl.concat(frames, how="align")
    return spzvar2pldf(vars[0])

## Functions

In [None]:
# | export
def get_time_resolution(data: SpeasyVariable):
    """Return summary statistics of the differences between consecutive time stamps of `data`."""
    timestamps = pl.Series(data.time)
    steps = timestamps.diff()
    return steps.describe()

In [None]:
# | export
def get_provider(v: str) -> DataProvider:
    """
    Map a provider short name to the matching speasy provider object.

    `"cda"` yields `spz.cda`; every other value (not only `"amda"`) yields
    `spz.amda`.
    """
    return spz.cda if v == "cda" else spz.amda


def get_dataset_index(v: str, provider: str = "cda") -> DatasetIndex:
    """Look up dataset `v` in the flat inventory of the provider named `provider`."""
    datasets = get_provider(provider).flat_inventory.datasets
    return datasets[v]


def get_dataset_parameters(v: str, provider: str = "cda"):
    """
    List the `ParameterIndex` members of dataset `v`.

    The dataset index is inspected through `vars()`; attributes that are not
    `ParameterIndex` instances are skipped.
    """
    members = vars(get_dataset_index(v, provider)).values()
    return list(filter(lambda member: isinstance(member, ParameterIndex), members))


def get_parameter_index(param: str, ds: str, provider: str = "cda") -> ParameterIndex:
    """
    Return the attribute named `param` of dataset `ds`.

    Parameters
    ----------
    param : attribute name of the parameter on the dataset index.
    ds : dataset key in the provider's flat inventory.
    provider : provider short name forwarded to `get_dataset_index`. It
        defaults to `"cda"`, which was previously the only provider this
        function could reach.

    Raises
    ------
    KeyError
        If the dataset index has no attribute called `param`.
    """
    ds_info = vars(get_dataset_index(ds, provider))
    return ds_info[param]

In [None]:
# | export
@patch
def preview(self: SpeasyVariable):
    """Print a short report of the variable: name, columns, unit, size, axis labels, metadata and the first three samples."""
    heavy_rule = "==========================================="
    light_rule = "-------------------------------------------"

    print(heavy_rule)
    # Header section: identification and size.
    print(f"Name:         {self.name}")
    print(f"Columns:      {self.columns}")
    print(f"Values Unit:  {self.unit}")
    print(f"Memory usage: {naturalsize(self.nbytes)}")
    print(f"Axes Labels:  {self.axes_labels}")
    print(light_rule)
    print(f"Meta-data:    {self.meta}")
    print(light_rule)
    # Only the first three entries of the time axis and values are shown.
    print(f"Time Axis:    {self.time[:3]}")
    print(light_rule)
    print(f"Values:       {self.values[:3]}")
    print(heavy_rule)

## Data Classes

Based on the `SpeasyVariable` class from speasy.

In [None]:
# | export
class Variable(V):
    """
    A single speasy product described on top of the base `V` model.

    The data is downloaded through `spz.get_data` on first access of `data`
    and cached afterwards.
    """

    parameter: str = None
    product: str = None
    """product name should be unique"""

    @cached_property
    def data(self) -> SpeasyVariable:
        """Fetch `self.product` over `self.timerange`; the result is cached on the instance."""
        product, timerange = self.product, self.timerange
        return spz.get_data(product, timerange)

    @property
    def time_resolution(self):
        """Statistics of the time step of `data`, as computed by `get_time_resolution`."""
        fetched = self.data
        return get_time_resolution(fetched)

    def to_polars(self):
        """Return `data` converted by `spzvar2pldf`."""
        fetched = self.data
        return spzvar2pldf(fetched)

    def preview(self):
        """Collect and return the head of the polars representation."""
        head = self.to_polars().head()
        return head.collect()

    @classmethod
    def load_from_file(cls, path: str):
        """Load the configuration from a file."""
        import yaml

        # NOTE(review): `FullLoader` is not PyYAML's safe loader; prefer
        # `yaml.safe_load` if these files can come from untrusted sources.
        with open(path, "r") as stream:
            config = yaml.load(stream, Loader=yaml.FullLoader)
        return cls(**config)

In [None]:
# | export
@patch
def plot(self: Variable, fig=None, ax: Axes = None):
    """
    Plot the variable's data, with fill values replaced by NaN.

    Parameters
    ----------
    fig : figure to draw into. When given without `ax`, the figure's current
        axes are used.
    ax : axes to draw on. When omitted together with `fig`, a new figure and
        axes are created.

    Returns
    -------
    (fig, ax) : the figure and axes that were drawn on.
    """
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots()
        else:
            # A figure passed without axes used to leave `ax` as None and fail
            # on `ax.set_ylabel` below.
            ax = fig.gca()
    elif fig is None:
        # Return the figure that owns the supplied axes instead of None.
        fig = ax.figure

    self.data.replace_fillval_by_nan().plot(ax=ax)

    if self.name:
        ax.set_ylabel(self.name)

    return fig, ax


@patch
def dump(self: Variable, path: str):
    """
    Dump the configuration to a file.

    Only fields that differ from their defaults are written, as YAML.
    """
    import yaml

    # Ask pydantic for JSON-compatible Python objects directly. Re-parsing the
    # JSON text with a YAML loader can change values: PyYAML reads
    # exponent-notation numbers without a dot (e.g. `1e-05`) as strings.
    config = self.model_dump(mode="json", exclude_defaults=True)

    # Serialise before opening so a failure does not leave a truncated file.
    with open(path, "w") as f:
        yaml.dump(config, f)

In [None]:
# | export
class Variables(Vs):
    """
    A group of speasy variables that share one time range.

    Products can be given explicitly, derived from `provider`, `dataset` and
    `parameters`, or, when no parameters are given, discovered from the
    dataset's inventory. One `Variable` is created per product unless
    `variables` is supplied.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    parameters: list[str] = None
    variables: list[Variable] = None
    dataset: str = None
    provider: str = "cda"
    products: list[str | ParameterIndex] = None

    # initialize products from provider and dataset if not provided
    @model_validator(mode="after")
    def check_products(self):
        """Fill `products` (and `parameters` when missing) from `provider` and `dataset`."""
        if self.products is None and self.dataset:
            if self.parameters:
                # Build "<provider>/<dataset>/<parameter>" product paths.
                self.products = [
                    f"{self.provider}/{self.dataset}/{var}" for var in self.parameters
                ]

            else:
                # No parameter list: take every parameter the inventory lists
                # for this dataset.
                self.products = get_dataset_parameters(self.dataset, self.provider)
                self.parameters = [member.spz_name() for member in self.products]
        return self

    @model_validator(mode="after")
    def check_variables(self):
        """Create one `Variable` per product if needed and propagate the time range."""
        if self.variables is None:
            self.variables = [Variable(product=product) for product in self.products]
        # set the same timerange for all variables
        for var in self.variables:
            var.timerange = self.timerange
        return self

    @property
    def data(self) -> list[SpeasyVariable]:
        """Data of every variable, in the order of `variables`."""
        return [var.data for var in self.variables]

    @property
    def time_resolutions(self):
        """Time-resolution statistics of every variable, in the order of `variables`."""
        return [var.time_resolution for var in self.variables]

    def to_polars(self):
        """Return all variables combined by `spzvars2pldf`."""
        return spzvars2pldf(self.data)

    def plot(self, gridspec_kw: dict | None = None):
        """
        Plot every variable in its own row of a shared-x figure.

        Parameters
        ----------
        gridspec_kw : keyword arguments forwarded to `plt.subplots`. `None`
            (the default) means `{"hspace": 0}`.

        Returns
        -------
        (fig, axes) : the figure and the list of axes, one per variable.
        """
        # A fresh dict per call avoids sharing one mutable default object
        # between calls.
        if gridspec_kw is None:
            gridspec_kw = {"hspace": 0}

        variables = self.variables

        fig, axes = plt.subplots(
            nrows=len(variables), sharex=True, gridspec_kw=gridspec_kw
        )
        # `plt.subplots` returns a bare Axes for a single row; normalise to a list.
        axes: list[Axes] = axes if len(variables) > 1 else [axes]

        for var, ax in zip(variables, axes):
            var.plot(ax=ax)

        return fig, axes


class SVariables:
    """
    Plain class declaring `products`, `parameters` and a private `_data` slot.

    NOTE(review): it has no base class and no methods; it looks like an
    unfinished placeholder. Confirm before relying on it.
    """

    # All three attributes default to None at class level.
    products: list[str] = None
    parameters: list[str] = None

    _data: list[Variable] = None

In [None]:
# Start and end dates of the window requested below.
earth_start = "2019-04-09"
earth_end = "2019-04-12"
timerange = [earth_start, earth_end]

# Fetch parameter "MOM.P.MAGF" of dataset "WI_PLSP_3DP" over that window.
# `get_parameter_index` resolves it in the default ("cda") provider's inventory.
d = spz.get_data(
    get_parameter_index("MOM.P.MAGF", "WI_PLSP_3DP"),
    timerange,
)

In [None]:
# | hide
from nbdev import nbdev_export

# Run nbdev's export step from inside the notebook.
nbdev_export()

### Test


In [None]:
# Build a `Variables` group for one parameter of one dataset; the provider is
# left at its default ("cda").
timerange = ["2019-04-07T01:00", "2019-04-07T12:00"]
# NOTE(review): binding `vars` here shadows the builtin `vars`, which
# `get_dataset_parameters` and `get_parameter_index` call. Within this
# notebook session those helpers will then fail.
vars = Variables(
    dataset="PSP_FLD_L2_MAG_RTN",
    parameters=["psp_fld_l2_mag_RTN"],
    timerange=timerange,
)

In [None]:
def data_provider_summary(data_provider: DataProvider = spz.cda):
    """Print the provider's name and the number of datasets, parameters and catalogs in its flat inventory."""
    inventory = data_provider.flat_inventory
    print("Data Provider:", data_provider.provider_name)

    # One line per inventory section, in a fixed order.
    sections = (
        ("Datasets:", "datasets"),
        ("Parameters:", "parameters"),
        ("Catalogs:", "catalogs"),
    )
    for label, attribute in sections:
        print(label, len(getattr(inventory, attribute)))


# data_provider_summary(spz.cda)
# data_provider_summary(spz.amda)
# data_provider_summary(spz.csa)

Data Provider: cda
Datasets: 2608
Parameters: 58510
Catalogs: 0
Data Provider: amda
Datasets: 1074
Parameters: 5397
Catalogs: 24
Data Provider: csa
Datasets: 912
Parameters: 1993
Catalogs: 0
