Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restore the time series to the variable summaries #408

Merged
116 changes: 28 additions & 88 deletions src/unified_graphics/diag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from collections import namedtuple
from dataclasses import dataclass
from enum import Enum
from typing import Generator, List, Optional, Union
from pathlib import Path
from typing import Generator, List, Union
from urllib.parse import urlparse

import numpy as np
import pandas as pd
import sqlalchemy as sa
import xarray as xr
import zarr # type: ignore
Expand Down Expand Up @@ -75,44 +77,6 @@ def to_geojson(self):
}


@dataclass
class SummaryStatistics:
min: float
max: float
mean: float

@classmethod
def from_data_array(cls, array: xr.DataArray) -> "SummaryStatistics":
return cls(
min=float(array.min()),
max=float(array.max()),
mean=float(array.mean()),
)


@dataclass
class DiagSummary:
initialization_time: str
obs_minus_forecast_adjusted: SummaryStatistics
obs_minus_forecast_unadjusted: SummaryStatistics
observation: SummaryStatistics
obs_count: int

@classmethod
def from_dataset(cls, dataset: xr.Dataset) -> "DiagSummary":
return cls(
initialization_time=dataset.attrs["initialization_time"],
obs_minus_forecast_adjusted=SummaryStatistics.from_data_array(
dataset["obs_minus_forecast_adjusted"]
),
obs_minus_forecast_unadjusted=SummaryStatistics.from_data_array(
dataset["obs_minus_forecast_unadjusted"]
),
observation=SummaryStatistics.from_data_array(dataset["observation"]),
obs_count=len(dataset["nobs"]),
)


ModelMetadata = namedtuple(
"ModelMetadata",
(
Expand Down Expand Up @@ -486,65 +450,41 @@ def get_model_run_list(
return group.group_keys()


def summary(
diag_zarr: str,
def history(
parquet_path: str,
model: str,
system: str,
domain: str,
background: str,
frequency: str,
initialization_time: str,
variable: Variable,
loop: MinimLoop,
filters: MultiDict,
) -> Optional[DiagSummary]:
store = get_store(diag_zarr)
path = "/".join(
[
model,
system,
domain,
background,
frequency,
variable.value,
initialization_time,
loop.value,
]
) -> pd.DataFrame:
# FIXME: This fails when diag_zarr is a file:// URL. Pandas ends up trying to use
# urlopen to read the file, but it's a directory. For now, we strip file://, but
# this is a hack.
parquet_file = (
Path(parquet_path.replace("file://", ""))
/ "_".join((model, background, system, domain, frequency))
/ variable.value
)

ds = xr.open_zarr(store, group=path, consolidated=False)
ds = apply_filters(ds, filters)
return DiagSummary.from_dataset(ds) if len(ds["nobs"]) > 0 else None

df = pd.read_parquet(
parquet_file,
columns=["initialization_time", "obs_minus_forecast_unadjusted"],
filters=(("loop", "=", loop.value), ("is_used", "=", True)),
)

def history(
diag_zarr: str,
model: str,
system: str,
domain: str,
background: str,
frequency: str,
variable: Variable,
loop: MinimLoop,
filters: MultiDict,
):
for init_time in get_model_run_list(
diag_zarr, model, system, domain, background, frequency, variable
):
result = summary(
diag_zarr,
model,
system,
domain,
background,
frequency,
init_time,
variable,
loop,
filters,
)
if df.empty:
return df

if not result:
continue
df = (
df.sort_values("initialization_time")
.groupby("initialization_time")
.describe()
.droplevel(0, axis=1) # Drop a level from the columns created by the groupby
.reset_index()
)

yield result
return df[["initialization_time", "min", "max", "mean", "count"]]
6 changes: 4 additions & 2 deletions src/unified_graphics/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def serviceworker():
@bp.route("/diag/<model>/<system>/<domain>/<background>/<frequency>/<variable>/<loop>/")
def history(model, system, domain, background, frequency, variable, loop):
data = diag.history(
current_app.config["DIAG_ZARR"],
current_app.config["DIAG_PARQUET"],
model,
system,
domain,
Expand All @@ -147,7 +147,9 @@ def history(model, system, domain, background, frequency, variable, loop):
request.args,
)

return jsonify([d for d in data])
return data.to_json(orient="records", date_format="iso"), {
"Content-Type": "application/json"
}


@bp.route(
Expand Down
11 changes: 4 additions & 7 deletions src/unified_graphics/static/js/component/ChartTimeSeries.js
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,7 @@ export default class ChartTimeSeries extends ChartElement {
}

get yScale() {
const domain = [
min(this.#data, (d) => d.obs_minus_forecast_adjusted.min),
max(this.#data, (d) => d.obs_minus_forecast_adjusted.max),
];
const domain = [min(this.#data, (d) => d.min), max(this.#data, (d) => d.max)];
const { top, bottom } = this.margin;
const height = this.height - top - bottom;

Expand All @@ -196,12 +193,12 @@ export default class ChartTimeSeries extends ChartElement {
const { xScale, yScale } = this;
const rangeArea = area()
.x((d) => xScale(d.initialization_time))
.y0((d) => yScale(d.obs_minus_forecast_adjusted.min))
.y1((d) => yScale(d.obs_minus_forecast_adjusted.max))
.y0((d) => yScale(d.min))
.y1((d) => yScale(d.max))
.curve(curveBumpX);
const meanLine = line()
.x((d) => xScale(d.initialization_time))
.y((d) => yScale(d.obs_minus_forecast_adjusted.mean))
.y((d) => yScale(d.mean))
.curve(curveBumpX);

this.#svg.attr("viewBox", `0 0 ${this.width} ${this.height}`);
Expand Down
19 changes: 14 additions & 5 deletions src/unified_graphics/templates/layouts/diag.html
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,20 @@ <h2 class="heading-2 flex-0">{% if minim_loop == "ges" %}Guess{% else %}Analysis
>Observation &minus; Forecast</color-ramp>
</chart-container>
{%- else -%}
<chart-container class="padding-2 radius-md bg-white shadow-1">
<span class="axis-y title" slot="title-y">Observation Count</span>
<chart-histogram id="distribution-{{ minim_loop }}" src="{{ dist_url[minim_loop] }}"></chart-histogram>
<span class="axis-x title" slot="title-x">Observation &minus; Forecast</span>
</chart-container>
<div class="grid">
<chart-container class="padding-2 radius-md bg-white shadow-1">
<span class="axis-y title" slot="title-y">Observation Count</span>
<chart-histogram id="distribution-{{ minim_loop }}" src="{{ dist_url[minim_loop] }}"></chart-histogram>
<span class="axis-x title" slot="title-x">Observation &minus; Forecast</span>
</chart-container>

<chart-container class="padding-2 radius-md bg-white shadow-1">
<span class="axis-y title" slot="title-y">Observation &minus; Forecast</span>
<chart-timeseries id="history-{{ minim_loop }}" src="{{ history_url[minim_loop] }}"
current="{{ form.get("initialization_time") }}"></chart-timeseries>
<span class="axis-x title" slot="title-x">Initialization Time</span>
</chart-container>
</div>
{%- endif %}

<chart-container class="padding-2 radius-md bg-white shadow-1">
Expand Down
29 changes: 27 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from pathlib import Path
from typing import Optional

import alembic.command
Expand Down Expand Up @@ -85,6 +86,7 @@ def session(engine):
s.rollback()


# FIXME: Replace diag_dataset with this fixture
@pytest.fixture(scope="class")
def test_dataset():
def factory(
Expand All @@ -99,7 +101,7 @@ def factory(
loop: str = "ges",
longitude: list[float] = [90, 91],
latitude: list[float] = [22, 23],
is_used: list[int] = [1, 0],
is_used: list[bool] = [True, False],
observation: list[float] = [1, 0],
forecast_unadjusted: list[float] = [0, 1],
forecast_adjusted: Optional[list[float]] = None,
Expand All @@ -126,7 +128,7 @@ def factory(
coords=dict(
longitude=(["nobs"], np.array(longitude, dtype=np.float64)),
latitude=(["nobs"], np.array(latitude, dtype=np.float64)),
is_used=(["nobs"], np.array(is_used, dtype=np.int8)),
is_used=(["nobs"], np.array(is_used)),
**kwargs,
),
attrs={
Expand All @@ -142,3 +144,26 @@ def factory(
)

return factory


@pytest.fixture
def diag_parquet(tmp_path):
def factory(
ds: xr.Dataset,
) -> Path:
parquet_file = (
tmp_path
/ "_".join((ds.model, ds.background, ds.system, ds.domain, ds.frequency))
/ ds.name
)
df = ds.to_dataframe()
df["loop"] = ds.loop
df["initialization_time"] = ds.initialization_time

df.to_parquet(
parquet_file, partition_cols=["loop"], index=True, engine="pyarrow"
)

return parquet_file

return factory
Loading
Loading