## Imports and utility functions

In [None]:
!pip install xarray cdsapi

In [None]:
import xarray
from typing import Optional
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import numpy as np
import math
import pandas as pd
import cdsapi
import scipy

In [None]:
def select(
    data: xarray.Dataset,
    variable: str,
    level: Optional[int] = None,
    max_steps: Optional[int] = None
    ) -> xarray.DataArray:
  """Extract one variable from ``data`` and trim it for plotting.

  Args:
    data: dataset containing ``variable``.
    variable: name of the data variable to extract.
    level: pressure level passed to ``.sel(level=level)``; ignored if None or
      if the variable has no ``level`` coordinate.
    max_steps: keep at most this many leading entries along ``time``; ignored
      if None or if the variable has no ``time`` dimension.

  Returns:
    The selected DataArray (``data[variable]`` is a DataArray, not a
    Dataset). If a ``batch`` dimension exists, only ``batch=0`` is kept.

  Raises:
    ValueError: if ``max_steps`` is negative and the variable has a ``time``
      dimension.
  """
  selected = data[variable]
  if "batch" in selected.dims:
    # Only the first batch member is plotted.
    selected = selected.isel(batch=0)
  if max_steps is not None and "time" in selected.sizes:
    if max_steps < 0:
      raise ValueError(f"max_steps must be non-negative, got {max_steps}")
    if max_steps < selected.sizes["time"]:
      # A slice is basic indexing (a view); an integer range would go through
      # array indexing and copy the data.
      selected = selected.isel(time=slice(0, max_steps))
  if level is not None and "level" in selected.coords:
    selected = selected.sel(level=level)
  return selected

def scale(
    data: xarray.Dataset,
    center: Optional[float] = None,
    robust: bool = False,
    ) -> tuple[xarray.Dataset, colors.Normalize, str]:
  """Pick colour limits and a colormap name for plotting ``data``.

  Args:
    data: values to plot; NaNs are ignored when computing the limits.
    center: if given, the limits are widened to be symmetric around this
      value and the diverging colormap "RdBu_r" is chosen; otherwise
      "viridis" is used.
    robust: if True, use the 2nd/98th percentiles instead of the min/max so
      that a few outliers do not dominate the colour range.

  Returns:
    A ``(data, norm, cmap_name)`` tuple; ``data`` is passed through unchanged.
  """
  low_pct, high_pct = (2, 98) if robust else (0, 100)
  lo = np.nanpercentile(data, low_pct)
  hi = np.nanpercentile(data, high_pct)
  if center is None:
    return (data, colors.Normalize(lo, hi), "viridis")
  # Symmetric range around `center`, wide enough to cover both extremes.
  half_width = max(hi - center, center - lo)
  return (data, colors.Normalize(center - half_width, center + half_width),
          "RdBu_r")

def plot_data(
    data: dict[str, tuple[xarray.DataArray, colors.Normalize, str]],
    fig_title: str,
    plot_size: float = 5,
    robust: bool = False,
    cols: int = 4
    ) -> None:
  """Draw one image panel, each with its own colorbar, per entry of ``data``.

  Args:
    data: maps a panel title to a ``(values, norm, cmap)`` tuple such as the
      one returned by ``scale``. Only the first time step of ``values`` is
      drawn (``isel(time=0)``; skipped if there is no ``time`` dimension).
      Presumably ``values`` is then 2-D (latitude, longitude); ``imshow``
      needs a 2-D array, so TODO confirm for new inputs.
    fig_title: title placed above all panels.
    plot_size: height in inches of one row of panels; each panel slot is
      ``2 * plot_size`` inches wide.
    robust: if True, colorbars get pointed ends (``extend="both"``) to show
      that the colour limits may clip the data.
    cols: maximum number of panels per row.

  Raises:
    ValueError: if ``data`` is empty, ``cols`` < 1, or the entries do not all
      have the same number of time steps.
  """
  if not data:
    raise ValueError("plot_data needs at least one panel to draw")
  if cols < 1:
    raise ValueError(f"cols must be at least 1, got {cols}")

  # All panels must cover the same number of time steps.
  first_values = next(iter(data.values()))[0]
  n_steps = first_values.sizes.get("time", 1)
  if any(values.sizes.get("time", 1) != n_steps
         for values, _, _ in data.values()):
    raise ValueError(
        "all entries of data must have the same number of time steps")

  cols = min(cols, len(data))
  rows = math.ceil(len(data) / cols)
  figure = plt.figure(figsize=(plot_size * 2 * cols, plot_size * rows))
  figure.suptitle(fig_title, fontsize=16)
  figure.subplots_adjust(wspace=0, hspace=0)
  # No tight_layout() here. Called before any axes exist, it has nothing to
  # lay out.

  for i, (title, (values, norm, cmap)) in enumerate(data.items(), start=1):
    ax = figure.add_subplot(rows, cols, i)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
    # origin="lower" puts row 0 at the bottom, so data sorted by ascending
    # latitude is drawn with the highest latitude at the top.
    im = ax.imshow(
        values.isel(time=0, missing_dims="ignore"), norm=norm,
        origin="lower", cmap=cmap)
    # The colormap is taken from `im`. Colorbar ignores a separate cmap
    # argument when a mappable is given.
    figure.colorbar(
        im,
        ax=ax,
        orientation="vertical",
        pad=0.02,
        aspect=16,
        shrink=0.75,
        extend=("both" if robust else "neither"))

## Plot input data

#### Additional steps for Google Colab

In [None]:
# Mount Google Drive at /content/drive; the next cell changes into a folder on
# it, which presumably holds the data files read below (single-level.nc,
# predictions.csv).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
%cd drive/MyDrive/XXX

#### Plotting

In [None]:
# Height in inches of one row of panels in plot_data.
plot_size = 5

# Load the input (single-level) data fully into memory.
with open("single-level.nc", "rb") as f:
  example_batch = xarray.load_dataset(f).compute()

# Europe in the middle + map not upside down:
# in the longitude dimension, map values higher than 180 to values - 360
example_batch["longitude"] = xarray.where(example_batch['longitude'] > 180, example_batch['longitude'] - 360, example_batch['longitude'])
# sort example_batch by latitude and longitude (ascending latitude combined
# with origin="lower" in plot_data puts the highest latitude at the top)
example_batch = example_batch.sortby(["latitude", "longitude"])

# Plot settings:
# - level: only applied to variables that have a "level" coordinate;
# - max_steps: number of leading time steps kept by `select`;
# - robust: use 2nd/98th percentile colour limits instead of min/max.
plot_example_level = 50
plot_example_max_steps = 1
plot_example_robust = True

# Short variable names (2 m temperature, 10 m u-wind, mean sea-level pressure).
variables = ["t2m", "u10", "msl"]

# One figure per variable, each with a single untitled panel.
for variable in variables:
  plot_example_variable = variable

  data = {
      " ": scale(select(example_batch, plot_example_variable, plot_example_level, plot_example_max_steps),
                robust=plot_example_robust),
  }
  fig_title = plot_example_variable
  if "level" in example_batch[plot_example_variable].coords:
    fig_title += f" at {plot_example_level} hPa"

  plot_data(data, fig_title, plot_size, plot_example_robust)

## Plot predicted data

### Fetch the ground-truth (ERA5 reanalysis) target data

In [None]:
## IN GOOGLE COLAB
import os

# CDS API credentials: replace the XXX placeholders with your own URL and key.
url = 'url: XXX'
key = 'key: XXX'

# cdsapi reads its configuration from ~/.cdsapirc by default (/root/.cdsapirc
# on Colab, where the home directory is /root).
cdsapirc_path = os.path.expanduser('~/.cdsapirc')
with open(cdsapirc_path, 'w') as f:
    f.write('\n'.join([url, key]))
# The file holds a secret, so make it readable by its owner only.
os.chmod(cdsapirc_path, 0o600)

# Echo the configuration as a sanity check, but never print the key itself:
# notebook output is easily saved and shared.
with open(cdsapirc_path) as f:
    for line in f.read().splitlines():
        print('key: ****' if line.startswith('key:') else line)

In [None]:
# Client for the Copernicus Climate Data Store. By default cdsapi takes its
# URL and key from ~/.cdsapirc, which the previous cell writes.
client = cdsapi.Client()

In [None]:
# Date of the target (truth) data to download.
day = 15
month = 3
year = 2024

# ERA5 single-level fields to download (CDS API variable names).
singlelevelfields = [
    '10m_u_component_of_wind',
    '10m_v_component_of_wind',
    '2m_temperature',
    'geopotential',
    'land_sea_mask',
    'mean_sea_level_pressure',
    'toa_incident_solar_radiation',
    'total_precipitation'
]

# Hourly steps to fetch, from the time of the first prediction (first_hour)
# up to and including last_hour.
first_hour = 12
last_hour = 18

client.retrieve(
    'reanalysis-era5-single-levels',
    {
        'product_type': 'reanalysis',
        'variable': singlelevelfields,
        'grid': '1.0/1.0',
        # CDS requests express date parts as zero-padded strings.
        'year': [f'{year:04d}'],
        'month': [f'{month:02d}'],
        'day': [f'{day:02d}'],
        'time': [f'{hour:02d}:00' for hour in range(first_hour, last_hour + 1)],
        # NOTE(review): newer CDS request examples use 'data_format' for this
        # key. Confirm which one the CDS instance in use expects.
        'format': 'netcdf'
    },
    'single-level-truth.nc'
)

In [None]:
# Load the downloaded ERA5 truth fully into memory.
# NOTE(review): files from the newer CDS may name the time dimension
# `valid_time` instead of `time`. `select` and `plot_data` only look for
# `time`, so confirm against the downloaded file.
with open("single-level-truth.nc", "rb") as f:
  eval_targets = xarray.load_dataset(f).compute()

# Europe in the middle + map not upside down (same remapping as for the input
# data): in the longitude dimension, map values higher than 180 to values - 360
eval_targets["longitude"] = xarray.where(eval_targets['longitude'] > 180, eval_targets['longitude'] - 360, eval_targets['longitude'])
# sort eval_targets by latitude and longitude
eval_targets = eval_targets.sortby(["latitude", "longitude"])

# Unused alternative, kept for reference. It reads the file with the scipy
# engine into a pandas DataFrame, renames the columns to the CDS field names,
# and replaces total_precipitation with a 6-step rolling sum per grid point
# (6 h for the hourly data requested above).
# eval_targets = xarray.open_dataset('single-level-truth.nc', engine = scipy.__name__).to_dataframe()
# eval_targets = eval_targets.rename(columns = {col:singlelevelfields[ind] for ind, col in enumerate(eval_targets.columns.values.tolist())})
# eval_targets = eval_targets.rename(columns = {'geopotential': 'geopotential_at_surface'})

# eval_targets = eval_targets.sort_index()
# eval_targets['total_precipitation_6hr'] = eval_targets.groupby(level=[0, 1])['total_precipitation'].rolling(window = 6, min_periods = 1).sum().reset_index(level=[0, 1], drop=True)
# eval_targets.pop('total_precipitation')


### Compare with predictions

In [None]:
# Load the model predictions exported to CSV and rename their columns to the
# short names used in eval_targets so the two can be compared.
predictions = pd.read_csv("predictions.csv")
predictions = predictions.rename(columns = {"lon": "longitude", "lat": "latitude", "10m_u_component_of_wind": "u10", "10m_v_component_of_wind": "v10", "2m_temperature": "t2m", "geopotential":"z", "mean_sea_level_pressure":"msl", "total_precipitation_6hr":"tp"})

# Keep a single pressure level so every (time, latitude, longitude) appears
# once.
# NOTE: 50 hPa is high in the atmosphere, NOT the level closest to the surface
# (that is the highest-pressure level, e.g. 1000 hPa). The choice matters for
# pressure-level columns: "z" here is presumably geopotential at 50 hPa,
# whereas eval_targets["z"] is surface geopotential. The surface variables
# plotted below (t2m, u10, msl) are presumably repeated on every level. TODO
# confirm.
# .copy() detaches the filtered rows from the original frame, so the column
# assignments below do not raise pandas' SettingWithCopyWarning.
predictions = predictions.loc[predictions.level == 50].copy()
predictions["time"] = pd.to_datetime(predictions["time"])
# Map longitudes above 180 to value - 360 (Europe centred), matching the
# remapping applied to eval_targets.
predictions["longitude"] = predictions["longitude"].where(
    predictions["longitude"] <= 180, predictions["longitude"] - 360)
predictions = predictions.sort_values(by=['longitude', 'latitude', 'time'])
predictions = predictions.set_index(['time', 'latitude', 'longitude'])
predictions = predictions.drop(['batch', 'level'], axis=1)
# The (time, latitude, longitude) MultiIndex becomes the dataset dimensions.
predictions = xarray.Dataset.from_dataframe(predictions)

# convert all data variables to float32
predictions = predictions.astype(np.float32)

predictions

In [None]:
# Height in inches of one row of panels in plot_data.
plot_size = 5

plot_pred_level = 50
plot_pred_robust = True
# Number of leading time steps kept by `select` (only the first is drawn).
plot_max_steps = 1

# To compare against the input data instead of the ERA5 truth:
# eval_targets = example_batch

variables = ["t2m", "u10", "msl"]

for variable in variables:
  plot_pred_variable = variable

  # Select each field once and reuse it for its own panel and for the diff.
  targets = select(eval_targets, plot_pred_variable, plot_pred_level, plot_max_steps)
  preds = select(predictions, plot_pred_variable, plot_pred_level, plot_max_steps)
  # xarray arithmetic aligns the operands on shared coordinates (inner join
  # by default). The diff therefore only covers points present in both, and
  # is empty if e.g. their time stamps do not match.
  diff = targets - preds
  if diff.size == 0:
    raise ValueError(
        f"{plot_pred_variable}: targets and predictions share no "
        "time/latitude/longitude points; check that their time stamps match")

  data = {
      "Targets": scale(targets, robust=plot_pred_robust),
      "Predictions": scale(preds, robust=plot_pred_robust),
      "Diff": scale(diff, robust=plot_pred_robust, center=0),
  }
  fig_title = plot_pred_variable
  if "level" in predictions[plot_pred_variable].coords:
    fig_title += f" at {plot_pred_level} hPa"

  plot_data(data, fig_title, plot_size, plot_pred_robust)