In [None]:
# !pip install xarray cdsapi

In [None]:
import math
from typing import Optional

import cdsapi
import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np
import pandas
import xarray

In [None]:
# @title Plotting functions

def select(
    data: xarray.Dataset,
    variable: str,
    level: Optional[int] = None,
    max_steps: Optional[int] = None
    ) -> xarray.Dataset:
  """Pick one variable out of `data` and narrow it down for plotting.

  Args:
    data: Container indexed by variable name.
    variable: Name of the variable to extract (`data[variable]`).
    level: Value to select along the "level" coordinate. Ignored when it is
      None or when the variable has no "level" coordinate.
    max_steps: Upper bound on the number of leading "time" entries to keep.
      Ignored when it is None, when there is no "time" dimension, or when the
      variable already has at most that many time entries.

  Returns:
    `data[variable]` after the applicable selections.
  """
  selected = data[variable]

  # Only the first batch entry is kept when a batch dimension is present.
  if "batch" in selected.dims:
    selected = selected.isel(batch=0)

  # Trim the time axis only if that actually shortens it.
  if max_steps is not None and "time" in selected.sizes:
    if max_steps < selected.sizes["time"]:
      selected = selected.isel(time=range(max_steps))

  if level is None or "level" not in selected.coords:
    return selected
  return selected.sel(level=level)

def scale(
    data: xarray.DataArray,
    center: Optional[float] = None,
    robust: bool = False,
    ) -> tuple[xarray.DataArray, matplotlib.colors.Normalize, str]:
  """Choose the colour limits and colormap used to draw `data`.

  Args:
    data: Values to plot. NaNs are ignored when computing the limits.
    center: If given, the limits are made symmetric around this value and the
      "RdBu_r" colormap is used; otherwise "viridis" is used.
    robust: If True, the limits are the 2nd and 98th percentiles instead of
      the minimum and maximum.

  Returns:
    A `(data, norm, cmap_name)` tuple: `data` unchanged, a
    `matplotlib.colors.Normalize` spanning the chosen limits, and the name of
    the colormap.
  """
  low_pct, high_pct = (2, 98) if robust else (0, 100)
  # A single call yields both limits, so the data is reduced only once.
  vmin, vmax = np.nanpercentile(data, [low_pct, high_pct])
  if center is not None:
    # Widen the narrower side so `center` sits in the middle of the range.
    diff = max(vmax - center, center - vmin)
    vmin = center - diff
    vmax = center + diff
  return (data, matplotlib.colors.Normalize(vmin, vmax),
          ("RdBu_r" if center is not None else "viridis"))

def plot_data(
    data: dict[str, tuple[xarray.DataArray, matplotlib.colors.Normalize, str]],
    fig_title: str,
    plot_size: float = 5,
    robust: bool = False,
    cols: int = 4
    ) -> None:
  """Draw one image panel per entry of `data` in a single figure.

  Args:
    data: Maps a panel title to a `(values, norm, cmap)` tuple, i.e. the
      value returned by `scale`. All entries must have the same number of
      "time" steps (an entry without a "time" dimension counts as 1).
    fig_title: Title placed above all panels.
    plot_size: Height of one row of panels in inches; each panel is given
      twice that as width.
    robust: If True, colorbars are drawn with extensions at both ends.
    cols: Maximum number of panels per row.

  Raises:
    ValueError: If `data` is empty.
  """
  if not data:
    raise ValueError("plot_data needs at least one entry in `data`.")

  first_data = next(iter(data.values()))[0]
  max_steps = first_data.sizes.get("time", 1)
  assert all(max_steps == d.sizes.get("time", 1) for d, _, _ in data.values()), (
      "all panels must have the same number of time steps")

  # Never use more columns than panels, and at least one column.
  cols = max(1, min(cols, len(data)))
  rows = math.ceil(len(data) / cols)
  figure = plt.figure(figsize=(plot_size * 2 * cols,
                               plot_size * rows))
  figure.suptitle(fig_title, fontsize=16)
  figure.subplots_adjust(wspace=0, hspace=0)
  figure.tight_layout()

  for i, (title, (panel, norm, cmap)) in enumerate(data.items()):
    ax = figure.add_subplot(rows, cols, i+1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
    # Only the first time step is drawn; panels without a time dimension are
    # used as they are.
    im = ax.imshow(
        panel.isel(time=0, missing_dims="ignore"), norm=norm,
        origin="lower", cmap=cmap)
    plt.colorbar(
        mappable=im,
        ax=ax,
        orientation="vertical",
        pad=0.02,
        aspect=16,
        shrink=0.75,
        cmap=cmap,
        extend=("both" if robust else "neither"))

## Plot input data

In [None]:
# Height of one row of panels in inches (plot_data doubles it for the width).
plot_size = 7

# Read the whole dataset into memory so the file can be closed right away.
with open("single-level.nc", "rb") as f:
  example_batch = xarray.load_dataset(f).compute()

plot_example_variable = "2m_temperature"
plot_example_level = 50
# Clip the colour range to the 2nd-98th percentiles (same choice as the
# prediction plot further down).
plot_example_robust = True
plot_example_max_steps = 1

data = {
    " ": scale(select(example_batch, plot_example_variable, plot_example_level, plot_example_max_steps),
              robust=plot_example_robust),
}
fig_title = plot_example_variable
# Mention the pressure level only for variables that actually have one.
if "level" in example_batch[plot_example_variable].coords:
  fig_title += f" at {plot_example_level} hPa"

plot_data(data, fig_title, plot_size, plot_example_robust)

## Plot predicted data

### Fetch the real target data

In [None]:
# Client used by the ERA5 download in the next cell.
# NOTE(review): presumably this picks up the user's CDS API credentials from
# local configuration — confirm they are set up before running.
client = cdsapi.Client()

In [None]:
# Date of the ERA5 analysis downloaded as ground truth.
day = 15
month = 3
year = 2024

# NOTE(review): `singlelevelfields` (the list of requested ERA5 variable names)
# is not defined in the cells visible here — make sure it exists before this
# cell runs.
client.retrieve(
    'reanalysis-era5-single-levels',
    {
        'product_type': 'reanalysis',
        'variable': singlelevelfields,
        'grid': '1.0/1.0',
        'year': [year],
        'month': [month],
        'day': [day],
        'time': ['18:00'], # time of first prediction
        'format': 'netcdf'
    },
    'single-level-truth.nc'
)

# Convert the downloaded file to a table and close the file again afterwards.
with xarray.open_dataset('single-level-truth.nc', engine='scipy') as truth_dataset:
  eval_targets = truth_dataset.to_dataframe()

# Rename the columns, by position, to the names that were requested.
# NOTE(review): this assumes the columns come back in the same order as
# `singlelevelfields` and that there are no extra columns — verify.
eval_targets = eval_targets.rename(
    columns={col: singlelevelfields[ind]
             for ind, col in enumerate(eval_targets.columns)})
eval_targets = eval_targets.rename(columns={'geopotential': 'geopotential_at_surface'})

### Compare with predictions

In [None]:
# @title Plot predictions

plot_size = 5

# Columns of predictions.csv that are coordinates; every other column becomes
# a data variable.
# NOTE(review): the layout of predictions.csv is not visible in this notebook —
# confirm these names against the code that wrote the file.
prediction_coord_columns = ("batch", "time", "level", "lat", "lon")

predictions_table = pandas.read_csv("predictions.csv")
prediction_index = [
    name for name in prediction_coord_columns
    if name in predictions_table.columns
]
if not prediction_index:
  raise ValueError(
      "predictions.csv has none of the expected coordinate columns: "
      + ", ".join(prediction_coord_columns))
# Index by the coordinate columns first so that each of them becomes a
# dimension of the dataset and each remaining column a variable.
predictions = xarray.Dataset.from_dataframe(
    predictions_table.set_index(prediction_index))

# eval_targets was turned into a DataFrame when it was downloaded, but select()
# needs an xarray object, so convert it back here.
eval_targets_dataset = (
    xarray.Dataset.from_dataframe(eval_targets)
    if isinstance(eval_targets, pandas.DataFrame) else eval_targets)

plot_pred_variable = "2m_temperature"
plot_pred_level = 50
plot_pred_robust = True
plot_max_steps = 1

target_field = select(
    eval_targets_dataset, plot_pred_variable, plot_pred_level, plot_max_steps)
predicted_field = select(
    predictions, plot_pred_variable, plot_pred_level, plot_max_steps)

# NOTE(review): the subtraction below matches values by coordinate labels, so
# targets and predictions must use the same dimension names and coordinate
# values — verify, otherwise the difference will not be a meaningful map.
data = {
    "Targets": scale(target_field, robust=plot_pred_robust),
    "Predictions": scale(predicted_field, robust=plot_pred_robust),
    "Diff": scale(target_field - predicted_field,
                  robust=plot_pred_robust, center=0),
}
fig_title = plot_pred_variable
# Mention the pressure level only for variables that actually have one.
if "level" in predictions[plot_pred_variable].coords:
  fig_title += f" at {plot_pred_level} hPa"

plot_data(data, fig_title, plot_size, plot_pred_robust)