# **Graph Cast Improvement Experiments**

# 1. Load and Import necessary dependencies

In [None]:
# @title Install graphcast and dependencies

%pip install --upgrade https://github.com/deepmind/graphcast/archive/master.zip

Collecting https://github.com/deepmind/graphcast/archive/master.zip
  Using cached https://github.com/deepmind/graphcast/archive/master.zip
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting shapely>=2.0 (from cartopy->graphcast==0.2.0.dev0)
  Downloading shapely-2.1.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (6.8 kB)
Downloading shapely-2.1.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m49.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: shapely
Successfully installed shapely-2.1.2


In [None]:
# @title Workaround for cartopy crashes

# Workaround for cartopy crashes due to the shapely installed by default in
# google colab kernel (https://github.com/anitagraser/movingpandas/issues/81):
!pip uninstall -y shapely
!pip install shapely --no-binary shapely

Found existing installation: shapely 2.1.2
Uninstalling shapely-2.1.2:
  Successfully uninstalled shapely-2.1.2
Collecting shapely
  Using cached shapely-2.1.2.tar.gz (315 kB)
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: shapely
  Building wheel for shapely (pyproject.toml) ... [?25l[?25hdone
  Created wheel for shapely: filename=shapely-2.1.2-cp312-cp312-linux_x86_64.whl size=1295381 sha256=f8a46ab50a20cc04bdcd04adf8dafc19103cc4986506249a8f7100d8be7fda0c
  Stored in directory: /root/.cache/pip/wheels/e9/38/c7/6f4f8e2dc4abc29e31467a46546a15559efb2b735ab460b46d
Successfully built shapely
Installing collected packages: shapely
Successfully installed shapely-2.1.2


In [None]:
# @title Imports

import dataclasses
import datetime
import functools
import math
import re
from typing import Optional

import cartopy.crs as ccrs
from google.cloud import storage
from graphcast import autoregressive
from graphcast import casting
from graphcast import checkpoint
from graphcast import data_utils
from graphcast import graphcast
from graphcast import normalization
from graphcast import rollout
from graphcast import xarray_jax
from graphcast import xarray_tree
from IPython.display import HTML
import ipywidgets as widgets
import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np
import xarray


def parse_file_parts(file_name):
  return dict(part.split("-", 1) for part in file_name.split("_"))


In [None]:
# @title Plotting functions

def select(
    data: xarray.Dataset,
    variable: str,
    level: Optional[int] = None,
    max_steps: Optional[int] = None
    ) -> xarray.Dataset:
  data = data[variable]
  if "batch" in data.dims:
    data = data.isel(batch=0)
  if max_steps is not None and "time" in data.sizes and max_steps < data.sizes["time"]:
    data = data.isel(time=range(0, max_steps))
  if level is not None and "level" in data.coords:
    data = data.sel(level=level)
  return data

def scale(
    data: xarray.Dataset,
    center: Optional[float] = None,
    robust: bool = False,
    ) -> tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:
  vmin = np.nanpercentile(data, (2 if robust else 0))
  vmax = np.nanpercentile(data, (98 if robust else 100))
  if center is not None:
    diff = max(vmax - center, center - vmin)
    vmin = center - diff
    vmax = center + diff
  return (data, matplotlib.colors.Normalize(vmin, vmax),
          ("RdBu_r" if center is not None else "viridis"))

def plot_data(
    data: dict[str, xarray.Dataset],
    fig_title: str,
    plot_size: float = 5,
    robust: bool = False,
    cols: int = 4
    ) -> tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:

  first_data = next(iter(data.values()))[0]
  max_steps = first_data.sizes.get("time", 1)
  assert all(max_steps == d.sizes.get("time", 1) for d, _, _ in data.values())

  cols = min(cols, len(data))
  rows = math.ceil(len(data) / cols)
  figure = plt.figure(figsize=(plot_size * 2 * cols,
                               plot_size * rows))
  figure.suptitle(fig_title, fontsize=16)
  figure.subplots_adjust(wspace=0, hspace=0)
  figure.tight_layout()

  images = []
  for i, (title, (plot_data, norm, cmap)) in enumerate(data.items()):
    ax = figure.add_subplot(rows, cols, i+1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
    im = ax.imshow(
        plot_data.isel(time=0, missing_dims="ignore"), norm=norm,
        origin="lower", cmap=cmap)
    plt.colorbar(
        mappable=im,
        ax=ax,
        orientation="vertical",
        pad=0.02,
        aspect=16,
        shrink=0.75,
        cmap=cmap,
        extend=("both" if robust else "neither"))
    images.append(im)

  def update(frame):
    if "time" in first_data.dims:
      td = datetime.timedelta(microseconds=first_data["time"][frame].item() / 1000)
      figure.suptitle(f"{fig_title}, {td}", fontsize=16)
    else:
      figure.suptitle(fig_title, fontsize=16)
    for im, (plot_data, norm, cmap) in zip(images, data.values()):
      im.set_data(plot_data.isel(time=frame, missing_dims="ignore"))

  ani = animation.FuncAnimation(
      fig=figure, func=update, frames=max_steps, interval=250)
  plt.close(figure.number)
  return HTML(ani.to_jshtml())

In [None]:
# @title Authenticate with Google Cloud Storage

gcs_client = storage.Client.create_anonymous_client()
gcs_bucket = gcs_client.get_bucket("dm_graphcast")
dir_prefix = "graphcast/"

# 2. Load Pre-trained GraphCast Models

GraphCast model crashes during runtime due to insufficient memory. I am now switching to GraphCast_small model.

*   GraphCast model needs around 60 GB of RAM for inferencing.
*   GraphCast_small requires around 10GB of RAM for inferencing.

**Assumption:** Whatever approaches improve the GraphCast_small model, also improve GraphCast model

In [None]:
# @title Load pre-trained GraphCast_Small Model
# checkpoint_filename = "GraphCast - ERA5 1979-2017 - resolution 0.25 - pressure levels 37 - mesh 2to6 - precipitation input and output.npz"
checkpoint_filename = "GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz"

# extract model name
model_name = checkpoint_filename.split("-")[0].strip()

# Load checkpoint from GCS bucket
with gcs_bucket.blob(f"{dir_prefix}params/{checkpoint_filename}").open("rb") as f:
    ckpt = checkpoint.load(f, graphcast.CheckPoint)

# Extract components
params = ckpt.params          # Pretrained weights
state = {}
model_config = ckpt.model_config
task_config = ckpt.task_config

# Show model info
print(f"{model_name} Model Loaded Successfully\n")
print("Model description: ", ckpt.description, "\n")
print("Model license: ", ckpt.license, "\n")


GraphCast_small Model Loaded Successfully

Model description:  
Low resolution version of the GraphCast model (1deg, smaller mesh), with 37
pressure levels. This model is trained on ERA5 data from 1979 to 2015, and can
be causally evaluated on 2016 and later years. This model takes as inputs
`total_precipitation_6hr`. This model has much lower memory requirements.
 

Model license:  
The model weights are licensed under the Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0). You
may obtain a copy of the License at:
https://creativecommons.org/licenses/by-nc-sa/4.0/.
The weights were trained on ERA5 data, see README for attribution statement.
 



# 3. Load Example Data

In [None]:
# @title Load the appropraite data

dataset_file = "source-era5_date-2022-01-01_res-1.0_levels-13_steps-01.nc"
# dataset_file = "source-era5_date-2022-01-01_res-1.0_levels-13_steps-04.nc"
# dataset_file = "source-era5_date-2022-01-01_res-1.0_levels-13_steps-12.nc"
# dataset_file = "source-era5_date-2022-01-01_res-1.0_levels-13_steps-20.nc"
# dataset_file = "source-era5_date-2022-01-01_res-1.0_levels-13_steps-40.nc"


# Load dataset from GCS
with gcs_bucket.blob(f"{dir_prefix}dataset/{dataset_file}").open("rb") as f:
    example_batch = xarray.load_dataset(f).compute()

print("Dataset loaded with dimensions:", example_batch.dims)
example_batch

To continue decoding into a timedelta64 dtype, either set `decode_timedelta=True` when opening this dataset, or add the attribute `dtype='timedelta64[ns]'` to this variable on disk.
To opt-in to future behavior, set `decode_timedelta=False`.
  example_batch = xarray.load_dataset(f).compute()




In [None]:
# @title Extract inference inputs/targets/forcings
inputs, targets, forcings = data_utils.extract_inputs_targets_forcings(
    example_batch,
    target_lead_times=slice("6h", f"{(example_batch.sizes['time']-2)*6}h"),
    **dataclasses.asdict(task_config)
)

In [None]:
# @title Load normalization data

with gcs_bucket.blob(dir_prefix+"stats/diffs_stddev_by_level.nc").open("rb") as f:
    diffs_stddev_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob(dir_prefix+"stats/mean_by_level.nc").open("rb") as f:
    mean_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob(dir_prefix+"stats/stddev_by_level.nc").open("rb") as f:
    stddev_by_level = xarray.load_dataset(f).compute()

In [None]:
# @title Build jitted functions, and possibly initialize random weights

# def construct_wrapped_graphcast(
#     model_config: graphcast.ModelConfig,
#     task_config: graphcast.TaskConfig):
#   """Constructs and wraps the GraphCast Predictor."""
#   # Deeper one-step predictor.
#   predictor = graphcast.GraphCast(model_config, task_config)

#   # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to
#   # from/to float32 to/from BFloat16.
#   predictor = casting.Bfloat16Cast(predictor)

#   # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from
#   # BFloat16 happens after applying normalization to the inputs/targets.
#   predictor = normalization.InputsAndResiduals(
#       predictor,
#       diffs_stddev_by_level=diffs_stddev_by_level,
#       mean_by_level=mean_by_level,
#       stddev_by_level=stddev_by_level)

#   # Wraps everything so the one-step model can produce trajectories.
#   predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)
#   return predictor


# @hk.transform_with_state
# def run_forward(model_config, task_config, inputs, targets_template, forcings):
#   predictor = construct_wrapped_graphcast(model_config, task_config)
#   return predictor(inputs, targets_template=targets_template, forcings=forcings)


# # @hk.transform_with_state
# # def loss_fn(model_config, task_config, inputs, targets, forcings):
# #   predictor = construct_wrapped_graphcast(model_config, task_config)
# #   loss, diagnostics = predictor.loss(inputs, targets, forcings)
# #   return xarray_tree.map_structure(
# #       lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),
# #       (loss, diagnostics))

# # def grads_fn(params, state, model_config, task_config, inputs, targets, forcings):
# #   def _aux(params, state, i, t, f):
# #     (loss, diagnostics), next_state = loss_fn.apply(
# #         params, state, jax.random.PRNGKey(0), model_config, task_config,
# #         i, t, f)
# #     return loss, (diagnostics, next_state)
# #   (loss, (diagnostics, next_state)), grads = jax.value_and_grad(
# #       _aux, has_aux=True)(params, state, inputs, targets, forcings)
# #   return loss, diagnostics, next_state, grads

# # Jax doesn't seem to like passing configs as args through the jit. Passing it
# # in via partial (instead of capture by closure) forces jax to invalidate the
# # jit cache if you change configs.
# def with_configs(fn):
#   return functools.partial(
#       fn, model_config=model_config, task_config=task_config)

# # Always pass params and state, so the usage below are simpler
# def with_params(fn):
#   return functools.partial(fn, params=params, state=state)

# # Our models aren't stateful, so the state is always empty, so just return the
# # predictions. This is requiredy by our rollout code, and generally simpler.
# def drop_state(fn):
#   return lambda **kw: fn(**kw)[0]

# # init_jitted = jax.jit(with_configs(run_forward.init))

# # if params is None:
# #   params, state = init_jitted(
# #       rng=jax.random.PRNGKey(0),
# #       inputs=train_inputs,
# #       targets_template=train_targets,
# #       forcings=train_forcings)

# # loss_fn_jitted = drop_state(with_params(jax.jit(with_configs(loss_fn.apply))))
# # grads_fn_jitted = with_params(jax.jit(with_configs(grads_fn)))
# run_forward_jitted = drop_state(with_params(jax.jit(with_configs(
#     run_forward.apply))))

In [None]:
# @title Build jitted functions, and possibly initialize random weights

def construct_wrapped_graphcast(
    model_config: graphcast.ModelConfig,
    task_config: graphcast.TaskConfig):
  """Constructs and wraps the GraphCast Predictor."""
  # Deeper one-step predictor.
  predictor = graphcast.GraphCast(model_config, task_config)

  # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to
  # from/to float32 to/from BFloat16.
  predictor = casting.Bfloat16Cast(predictor)

  # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from
  # BFloat16 happens after applying normalization to the inputs/targets.
  predictor = normalization.InputsAndResiduals(
      predictor,
      diffs_stddev_by_level=diffs_stddev_by_level,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level)

  # Wraps everything so the one-step model can produce trajectories.
  predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)
  return predictor


@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
  predictor = construct_wrapped_graphcast(model_config, task_config)
  return predictor(inputs, targets_template=targets_template, forcings=forcings)

# Jax doesn't seem to like passing configs as args through the jit. Passing it
# in via partial (instead of capture by closure) forces jax to invalidate the
# jit cache if you change configs.
def with_configs(fn):
  return functools.partial(
      fn, model_config=model_config, task_config=task_config)

# Always pass params and state, so the usage below are simpler
def with_params(fn):
  return functools.partial(fn, params=params, state=state)

# Our models aren't stateful, so the state is always empty, so just return the
# predictions. This is requiredy by our rollout code, and generally simpler.
def drop_state(fn):
  return lambda **kw: fn(**kw)[0]

run_forward_jitted = drop_state(with_params(jax.jit(with_configs(
    run_forward.apply))))

# 4. Model Inferencing / Run the Model

In [None]:
# @title Autoregressive rollout (loop in python)

assert model_config.resolution in (0, 360. / example_batch.sizes["lon"]), (
  "Model resolution doesn't match the data resolution. You likely want to "
  "re-filter the dataset list, and download the correct data.")


predictions = rollout.chunked_prediction(
    run_forward_jitted,
    rng=jax.random.PRNGKey(0),
    inputs=inputs,
    targets_template=targets * np.nan,
    forcings=forcings)
predictions

  num_target_steps = targets_template.dims["time"]
  scan_length = targets_template.dims['time']
  num_inputs = inputs.dims['time']


In [None]:
predictions["2m_temperature"]

In [None]:
# @title Choose predictions to plot

plot_pred_variable = widgets.Dropdown(
    options=predictions.data_vars.keys(),
    value="2m_temperature",
    description="Variable")
plot_pred_level = widgets.Dropdown(
    options=predictions.coords["level"].values,
    value=500,
    description="Level")
plot_pred_robust = widgets.Checkbox(value=True, description="Robust")
plot_pred_max_steps = widgets.IntSlider(
    min=1,
    max=predictions.dims["time"],
    value=predictions.dims["time"],
    description="Max steps")

widgets.VBox([
    plot_pred_variable,
    plot_pred_level,
    plot_pred_robust,
    plot_pred_max_steps,
    widgets.Label(value="Run the next cell to plot the predictions. Rerunning this cell clears your selection.")
])

  max=predictions.dims["time"],
  value=predictions.dims["time"],


VBox(children=(Dropdown(description='Variable', index=2, options=('10m_u_component_of_wind', '10m_v_component_…

In [None]:
# @title Plot predictions

plot_size = 5
plot_max_steps = min(predictions.dims["time"], plot_pred_max_steps.value)

data = {
    "Targets": scale(select(targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),
    "Predictions": scale(select(predictions, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),
    "Diff": scale((select(targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps) -
                        select(predictions, plot_pred_variable.value, plot_pred_level.value, plot_max_steps)),
                       robust=plot_pred_robust.value, center=0),
}
fig_title = plot_pred_variable.value
if "level" in predictions[plot_pred_variable.value].coords:
  fig_title += f" at {plot_pred_level.value} hPa"

plot_data(data, fig_title, plot_size, plot_pred_robust.value)


  plot_max_steps = min(predictions.dims["time"], plot_pred_max_steps.value)


# 5. Experiments with Ensemble Statistics (Mean, Median)

# 5.1. Taking Ensemble Mean

In [None]:
# @title Ensemble Rollout Functions with Gaussian Noise

ensemble_size = 10       # number of ensemble members
noise_std = 0.0000001        # standard deviation of Gaussian noise added to inputs


def add_gaussian_noise(inputs, rng, std):
    """Adds Gaussian noise to all input arrays."""
    noisy_inputs_dict = {}
    for k, v in inputs.items():
        key, rng = jax.random.split(rng)
        noise = jax.random.normal(key, v.shape) * std
        noisy_inputs_dict[k] = v + noise
    # Convert back to xarray Dataset
    return xarray.Dataset(noisy_inputs_dict), rng

def ensemble_rollout(run_fn, inputs, targets_template, forcings, ensemble_size=5, std=noise_std, rng=jax.random.PRNGKey(0)):
    ensemble_predictions_data = []
    for i in range(ensemble_size):
        rng, subkey = jax.random.split(rng)
        noisy_inputs, _ = add_gaussian_noise(inputs, subkey, std)

        # Pass the subkey as the rng argument to the run_fn
        pred = run_fn(rng=subkey, inputs=noisy_inputs, targets_template=targets_template, forcings=forcings)

        # Add ensemble dimension
        pred = pred.expand_dims({"ensemble": [i]})
        ensemble_predictions_data.append(pred)
        print(f"Ensemble member {i+1} done.")

    return xarray.concat(ensemble_predictions_data, dim="ensemble").astype(np.float32)


ensemble_preds_data = ensemble_rollout(
    run_forward_jitted,
    inputs=inputs,
    targets_template=targets * np.nan,
    forcings=forcings,
    ensemble_size=ensemble_size,
    std=noise_std,
    rng=jax.random.PRNGKey(42)
)

ensemble_preds_data

Ensemble member 1 done.
Ensemble member 2 done.
Ensemble member 3 done.
Ensemble member 4 done.
Ensemble member 5 done.
Ensemble member 6 done.
Ensemble member 7 done.
Ensemble member 8 done.
Ensemble member 9 done.
Ensemble member 10 done.


In [None]:
# @title Aggregate ensemble predictions to get ensemble mean

ensemble_mean = ensemble_preds_data.mean(dim="ensemble").compute()
ensemble_std = ensemble_preds_data.std(dim="ensemble").compute()

# Explicitly convert data variables to NumPy arrays and assign back with dimensions
for var in ensemble_mean.data_vars:
    dims = ensemble_mean[var].dims
    ensemble_mean[var] = xarray.DataArray(ensemble_mean[var].values, dims=dims)

for var in ensemble_std.data_vars:
    dims = ensemble_std[var].dims
    ensemble_std[var] = xarray.DataArray(ensemble_std[var].values, dims=dims)


print("Ensemble Mean:")
display(ensemble_mean)

print("\nEnsemble Standard Deviation:")
display(ensemble_std)

Ensemble Mean:



Ensemble Standard Deviation:


In [None]:
# @title Choose ensemble mean to plot

plot_pred_variable = widgets.Dropdown(
    options=ensemble_mean.data_vars.keys(),
    value="2m_temperature",
    description="Variable")
plot_pred_level = widgets.Dropdown(
    options=ensemble_mean.coords["level"].values,
    value=500,
    description="Level")
plot_pred_robust = widgets.Checkbox(value=True, description="Robust")
plot_pred_max_steps = widgets.IntSlider(
    min=1,
    max=ensemble_mean.dims["time"],
    value=ensemble_mean.dims["time"],
    description="Max steps")

widgets.VBox([
    plot_pred_variable,
    plot_pred_level,
    plot_pred_robust,
    plot_pred_max_steps,
    widgets.Label(value="Run the next cell to plot the predictions. Rerunning this cell clears your selection.")
])

  max=ensemble_mean.dims["time"],
  value=ensemble_mean.dims["time"],


VBox(children=(Dropdown(description='Variable', index=2, options=('10m_u_component_of_wind', '10m_v_component_…

In [None]:
# @title Plot ensemble mean predictions

plot_size = 5
plot_max_steps = min(ensemble_mean.dims["time"], plot_pred_max_steps.value)

data = {
    "Targets": scale(select(targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),
    "Predictions": scale(select(ensemble_mean, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),
    "Diff": scale((select(targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps) -
                        select(ensemble_mean, plot_pred_variable.value, plot_pred_level.value, plot_max_steps)),
                       robust=plot_pred_robust.value, center=0),
}
fig_title = plot_pred_variable.value
if "level" in ensemble_mean[plot_pred_variable.value].coords:
  fig_title += f" at {plot_pred_level.value} hPa"

plot_data(data, fig_title, plot_size, plot_pred_robust.value)


  plot_max_steps = min(ensemble_mean.dims["time"], plot_pred_max_steps.value)


# 5.2. Taking Ensemble Median

In [None]:
# @title Aggregate median and trimmed mean
median = ensemble_preds_data.median(dim="ensemble")


In [None]:
# @title Choose median of predictions to plot

plot_pred_variable = widgets.Dropdown(
    options=median.data_vars.keys(),
    value="2m_temperature",
    description="Variable")
plot_pred_level = widgets.Dropdown(
    options=median.coords["level"].values,
    value=500,
    description="Level")
plot_pred_robust = widgets.Checkbox(value=True, description="Robust")
plot_pred_max_steps = widgets.IntSlider(
    min=1,
    max=median.dims["time"],
    value=median.dims["time"],
    description="Max steps")

widgets.VBox([
    plot_pred_variable,
    plot_pred_level,
    plot_pred_robust,
    plot_pred_max_steps,
    widgets.Label(value="Run the next cell to plot the predictions. Rerunning this cell clears your selection.")
])

  max=median.dims["time"],
  value=median.dims["time"],


VBox(children=(Dropdown(description='Variable', index=2, options=('10m_u_component_of_wind', '10m_v_component_…

In [None]:
# @title Plot median of predictions

plot_size = 5
plot_max_steps = min(median.dims["time"], plot_pred_max_steps.value)

data = {
    "Targets": scale(select(targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),
    "Predictions": scale(select(median, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),
    "Diff": scale((select(targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps) -
                        select(median, plot_pred_variable.value, plot_pred_level.value, plot_max_steps)),
                       robust=plot_pred_robust.value, center=0),
}
fig_title = plot_pred_variable.value
if "level" in median[plot_pred_variable.value].coords:
  fig_title += f" at {plot_pred_level.value} hPa"

plot_data(data, fig_title, plot_size, plot_pred_robust.value)


  plot_max_steps = min(median.dims["time"], plot_pred_max_steps.value)


# 5.3. Taking Trimmed Ensemble Mean

In [None]:
# @title Calculating Trimmed Ensemble Mean
trimmed_mean = ensemble_preds_data.sortby("ensemble").isel(ensemble=slice(1, -1)).mean(dim="ensemble")

In [None]:
# @title Choose predictions to plot trimmed mean

plot_pred_variable = widgets.Dropdown(
    options=trimmed_mean.data_vars.keys(),
    value="2m_temperature",
    description="Variable")
plot_pred_level = widgets.Dropdown(
    options=trimmed_mean.coords["level"].values,
    value=500,
    description="Level")
plot_pred_robust = widgets.Checkbox(value=True, description="Robust")
plot_pred_max_steps = widgets.IntSlider(
    min=1,
    max=trimmed_mean.dims["time"],
    value=trimmed_mean.dims["time"],
    description="Max steps")

widgets.VBox([
    plot_pred_variable,
    plot_pred_level,
    plot_pred_robust,
    plot_pred_max_steps,
    widgets.Label(value="Run the next cell to plot the predictions. Rerunning this cell clears your selection.")
])

  max=trimmed_mean.dims["time"],
  value=trimmed_mean.dims["time"],


VBox(children=(Dropdown(description='Variable', index=2, options=('10m_u_component_of_wind', '10m_v_component_…

In [None]:
# @title Plot trimmed mean of predictions

plot_size = 5
plot_max_steps = min(trimmed_mean.dims["time"], plot_pred_max_steps.value)

data = {
    "Targets": scale(select(targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),
    "Predictions": scale(select(trimmed_mean, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),
    "Diff": scale((select(targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps) -
                        select(trimmed_mean, plot_pred_variable.value, plot_pred_level.value, plot_max_steps)),
                       robust=plot_pred_robust.value, center=0),
}
fig_title = plot_pred_variable.value
if "level" in trimmed_mean[plot_pred_variable.value].coords:
  fig_title += f" at {plot_pred_level.value} hPa"

plot_data(data, fig_title, plot_size, plot_pred_robust.value)


  plot_max_steps = min(trimmed_mean.dims["time"], plot_pred_max_steps.value)


# 5.4. Trying Bias Correction

In [None]:
# @title Simple Bias Correction

# Apply bias correction to the ensemble mean predictions
corrected_predictions = ensemble_mean.copy()

for var in corrected_predictions.data_vars:
    if var in targets.data_vars:
        # Flatten time/space for fitting
        em_flat = ensemble_mean[var].values.flatten()
        y_flat = targets[var].values.flatten()

        # Remove NaN values for fitting
        mask = ~np.isnan(y_flat)
        em_flat_masked = em_flat[mask]
        y_flat_masked = y_flat[mask]

        if len(em_flat_masked) > 1: # Ensure there's enough data to fit a line
            # Fit linear model (obs = a * ensemble_mean + b)
            a, b = np.polyfit(em_flat_masked, y_flat_masked, 1)

            # Apply bias correction to the current variable
            corrected_predictions[var] = a * corrected_predictions[var] + b
        else:
            print(f"Skipping bias correction for {var} due to insufficient data.")
    else:
        print(f"Skipping bias correction for {var} as it is not in targets.")

print("Bias Corrected Ensemble Mean:")
display(corrected_predictions)

  a, b = np.polyfit(em_flat_masked, y_flat_masked, 1)
  a, b = np.polyfit(em_flat_masked, y_flat_masked, 1)


Bias Corrected Ensemble Mean:


In [None]:
# @title Choose corrected_predictions to plot

plot_pred_variable = widgets.Dropdown(
    options=corrected_predictions.data_vars.keys(),
    value="2m_temperature",
    description="Variable")
plot_pred_level = widgets.Dropdown(
    options=corrected_predictions.coords["level"].values,
    value=500,
    description="Level")
plot_pred_robust = widgets.Checkbox(value=True, description="Robust")
plot_pred_max_steps = widgets.IntSlider(
    min=1,
    max=corrected_predictions.dims["time"],
    value=corrected_predictions.dims["time"],
    description="Max steps")

widgets.VBox([
    plot_pred_variable,
    plot_pred_level,
    plot_pred_robust,
    plot_pred_max_steps,
    widgets.Label(value="Run the next cell to plot the predictions. Rerunning this cell clears your selection.")
])

  max=corrected_predictions.dims["time"],
  value=corrected_predictions.dims["time"],


VBox(children=(Dropdown(description='Variable', index=2, options=('10m_u_component_of_wind', '10m_v_component_…

In [None]:
# @title Plot bias corrected_predictions

plot_size = 5
plot_max_steps = min(corrected_predictions.dims["time"], plot_pred_max_steps.value)

data = {
    "Targets": scale(select(targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),
    "Predictions": scale(select(corrected_predictions, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),
    "Diff": scale((select(targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps) -
                        select(corrected_predictions, plot_pred_variable.value, plot_pred_level.value, plot_max_steps)),
                       robust=plot_pred_robust.value, center=0),
}
fig_title = plot_pred_variable.value
if "level" in corrected_predictions[plot_pred_variable.value].coords:
  fig_title += f" at {plot_pred_level.value} hPa"

plot_data(data, fig_title, plot_size, plot_pred_robust.value)


  plot_max_steps = min(corrected_predictions.dims["time"], plot_pred_max_steps.value)


# 5.5. Comparing Results

In [None]:
# @title Functions to calculate RMSE per variable

def calculate_rmse(predictions, targets):
    """Calculates RMSE between predictions and targets, ignoring NaNs, and as a percentage."""
    # Ensure both datasets have the same variables
    common_vars = list(set(predictions.data_vars) & set(targets.data_vars))
    if not common_vars:
        raise ValueError("No common variables found between predictions and targets.")

    rmse_results = {}
    rmse_percentage_results = {}
    for var in common_vars:
        pred_values = predictions[var].values
        target_values = targets[var].values

        # Calculate squared error, ignoring NaNs
        squared_error = (pred_values - target_values)**2
        mean_squared_error = np.nanmean(squared_error)
        rmse = np.sqrt(mean_squared_error)
        rmse_results[var] = rmse

        # Calculate percentage error relative to the mean absolute target value
        mean_abs_target = np.nanmean(np.abs(target_values))
        if mean_abs_target != 0:
            rmse_percentage = (rmse / mean_abs_target) * 100
            rmse_percentage_results[var] = rmse_percentage
        else:
            rmse_percentage_results[var] = float('inf') # Handle cases where mean absolute target is zero


    return rmse_results, rmse_percentage_results


In [None]:
# @title Calculate RMSE for each variable using original predictions and targets

rmse_prediction_original, rmse_percentage_predictions_original = calculate_rmse(predictions, targets)


In [None]:
# @title Calculate RMSE for each variable using ensemble mean and targets
rmse_ensemble_mean, rmse_percentage_ensemble_mean = calculate_rmse(ensemble_mean, targets)


In [None]:
# @title Calculate RMSE for each variable using ensemble median and targets
rmse_ensemble_median, rmse_percentage_ensemble_median = calculate_rmse(median, targets)


In [None]:
# @title Calculate RMSE for each variable using ensemble trimmed mean and targets
rmse_ensemble_trimmed_mean, rmse_percentage_ensemble_trimmed_mean = calculate_rmse(trimmed_mean, targets)


In [None]:
# @title Tabulate RMSE of each ensemble aggregation techniques

import pandas as pd

# Create a dictionary to hold the RMSE results
rmse_comparison_aggregation = {
    "Variable": list(rmse_prediction_original.keys()),
    "Original Prediction RMSE": list(rmse_prediction_original.values()),
    "Ensemble Mean RMSE": list(rmse_ensemble_mean.values()),
    "Ensemble Median RMSE": list(rmse_ensemble_median.values()),
    "Ensemble Trimmed Mean RMSE": list(rmse_ensemble_trimmed_mean.values())
}

# Create a pandas DataFrame
rmse_aggregation_df = pd.DataFrame(rmse_comparison_aggregation)

# Display the DataFrame
display(rmse_aggregation_df)

Unnamed: 0,Variable,Original Prediction RMSE,Ensemble Mean RMSE,Ensemble Median RMSE,Ensemble Trimmed Mean RMSE
0,temperature,0.442355,0.442752,0.442615,0.442781
1,v_component_of_wind,1.240714,1.240377,1.240389,1.240392
2,total_precipitation_6hr,0.000641,0.000641,0.000641,0.000641
3,specific_humidity,0.000225,0.000225,0.000225,0.000225
4,mean_sea_level_pressure,42.128708,42.073589,42.073338,42.078846
5,10m_v_component_of_wind,0.64042,0.640355,0.640406,0.640368
6,geopotential,38.699047,38.693371,38.669533,38.70079
7,u_component_of_wind,1.224694,1.224574,1.224557,1.224604
8,vertical_velocity,0.09793,0.097902,0.097903,0.097899
9,10m_u_component_of_wind,0.630767,0.630644,0.630652,0.630635


In [None]:
# @title Compare Bias Corrected vs Ensemble Mean RMSE

import pandas as pd

# Calculate RMSE for original predictions
rmse_prediction_original, _ = calculate_rmse(predictions, targets)

# Calculate RMSE for ensemble mean
rmse_ensemble_mean, _ = calculate_rmse(ensemble_mean, targets)

# Calculate RMSE for bias corrected predictions
rmse_corrected_predictions, _ = calculate_rmse(corrected_predictions, targets)

# Create a dictionary to hold the RMSE results
rmse_comparison_bias_correction = {
    "Variable": list(rmse_prediction_original.keys()),
    "Original Prediction RMSE": list(rmse_prediction_original.values()),
    "Ensemble Mean RMSE": list(rmse_ensemble_mean.values()),
    "Bias Corrected Prediction RMSE": list(rmse_corrected_predictions.values())
}

# Create a pandas DataFrame
rmse_bias_correction_df = pd.DataFrame(rmse_comparison_bias_correction)

# Display the DataFrame
display(rmse_bias_correction_df)

Unnamed: 0,Variable,Original Prediction RMSE,Ensemble Mean RMSE,Bias Corrected Prediction RMSE
0,temperature,0.442355,0.442752,14.338505
1,v_component_of_wind,1.240714,1.240377,1.24008
2,total_precipitation_6hr,0.000641,0.000641,0.000631
3,specific_humidity,0.000225,0.000225,0.000225
4,mean_sea_level_pressure,42.128708,42.073589,660.583023
5,10m_v_component_of_wind,0.64042,0.640355,0.640153
6,geopotential,38.699047,38.693371,37.969866
7,u_component_of_wind,1.224694,1.224574,1.224443
8,vertical_velocity,0.09793,0.097902,0.097863
9,10m_u_component_of_wind,0.630767,0.630644,0.630066


# 6. Experiment with noise magnitude (Ensemble Size is Fixed at 10)

Due to the time taken to run an ensemble is little high ensemble member count is fixed at 10.

# 6.1. Noise Magnitude = 10^(-1)

In [None]:
noise_std=0.1
ensemble_size=10

ensemble_noise_exp_1 = ensemble_rollout(
    run_forward_jitted,
    inputs=inputs,
    targets_template=targets * np.nan,
    forcings=forcings,
    ensemble_size=ensemble_size,
    std=noise_std,
    rng=jax.random.PRNGKey(42)
)

Ensemble member 1 done.
Ensemble member 2 done.
Ensemble member 3 done.
Ensemble member 4 done.
Ensemble member 5 done.
Ensemble member 6 done.
Ensemble member 7 done.
Ensemble member 8 done.
Ensemble member 9 done.
Ensemble member 10 done.


In [None]:
# @title Calculate RMSE for ensemble run with noise 10^(-1)

ensemble_mean_noise_exp_1 = ensemble_noise_exp_1.mean(dim="ensemble").compute()
rmse_prediction_noise_exp_1, rmse_percentage_predictions_noise_exp_1 = calculate_rmse(ensemble_mean_noise_exp_1, targets)


# 6.2. Noise Magnitude = 10^(-3)

In [None]:
noise_std=0.0001
ensemble_size=10

ensemble_noise_exp_3 = ensemble_rollout(
    run_forward_jitted,
    inputs=inputs,
    targets_template=targets * np.nan,
    forcings=forcings,
    ensemble_size=ensemble_size,
    std=noise_std,
    rng=jax.random.PRNGKey(42)
)

Ensemble member 1 done.
Ensemble member 2 done.
Ensemble member 3 done.
Ensemble member 4 done.
Ensemble member 5 done.
Ensemble member 6 done.
Ensemble member 7 done.
Ensemble member 8 done.
Ensemble member 9 done.
Ensemble member 10 done.


In [None]:
# @title Calculate RMSE for ensemble run with noise 10^(-3)

ensemble_mean_noise_exp_3 = ensemble_noise_exp_3.mean(dim="ensemble").compute()
rmse_prediction_noise_exp_3, rmse_percentage_predictions_noise_exp_3 = calculate_rmse(ensemble_mean_noise_exp_3, targets)


# 6.3. Noise Magnitude = 10^(-5)

In [None]:
noise_std=0.00001
ensemble_size=10

ensemble_noise_exp_5 = ensemble_rollout(
    run_forward_jitted,
    inputs=inputs,
    targets_template=targets * np.nan,
    forcings=forcings,
    ensemble_size=ensemble_size,
    std=noise_std,
    rng=jax.random.PRNGKey(42)
)

Ensemble member 1 done.
Ensemble member 2 done.
Ensemble member 3 done.
Ensemble member 4 done.
Ensemble member 5 done.
Ensemble member 6 done.
Ensemble member 7 done.
Ensemble member 8 done.
Ensemble member 9 done.
Ensemble member 10 done.


In [None]:
# @title Calculate RMSE for ensemble run with noise 10^(-5)

ensemble_mean_noise_exp_5 = ensemble_noise_exp_5.mean(dim="ensemble").compute()
rmse_prediction_noise_exp_5, rmse_percentage_predictions_noise_exp_5 = calculate_rmse(ensemble_mean_noise_exp_5, targets)


# 6.4. Noise Magnitude = 10^(-6)

In [None]:
noise_std=0.000001
ensemble_size=10

ensemble_noise_exp_6 = ensemble_rollout(
    run_forward_jitted,
    inputs=inputs,
    targets_template=targets * np.nan,
    forcings=forcings,
    ensemble_size=ensemble_size,
    std=noise_std,
    rng=jax.random.PRNGKey(42)
)

Ensemble member 1 done.
Ensemble member 2 done.
Ensemble member 3 done.
Ensemble member 4 done.
Ensemble member 5 done.
Ensemble member 6 done.
Ensemble member 7 done.
Ensemble member 8 done.
Ensemble member 9 done.
Ensemble member 10 done.


In [None]:
# @title Calculate RMSE for ensemble run with noise 10^(-6)

ensemble_mean_noise_exp_6 = ensemble_noise_exp_6.mean(dim="ensemble").compute()
rmse_prediction_noise_exp_6, rmse_percentage_predictions_noise_exp_6 = calculate_rmse(ensemble_mean_noise_exp_6, targets)


# 6.5. Noise Magnitude = 10^(-7)

In [None]:
noise_std=0.0000001
ensemble_size=10

ensemble_noise_exp_7 = ensemble_rollout(
    run_forward_jitted,
    inputs=inputs,
    targets_template=targets * np.nan,
    forcings=forcings,
    ensemble_size=ensemble_size,
    std=noise_std,
    rng=jax.random.PRNGKey(42)
)

Ensemble member 1 done.
Ensemble member 2 done.
Ensemble member 3 done.
Ensemble member 4 done.
Ensemble member 5 done.
Ensemble member 6 done.
Ensemble member 7 done.
Ensemble member 8 done.
Ensemble member 9 done.
Ensemble member 10 done.


In [None]:
# @title Calculate RMSE for ensemble run with noise 10^(-7)

ensemble_mean_noise_exp_7 = ensemble_noise_exp_7.mean(dim="ensemble").compute()
rmse_prediction_noise_exp_7, rmse_percentage_predictions_noise_exp_7 = calculate_rmse(ensemble_mean_noise_exp_7, targets)

# 6.6. Noise Magnitude = 10^(-8)

In [None]:
noise_std=0.00000001
ensemble_size=10

ensemble_noise_exp_8 = ensemble_rollout(
    run_forward_jitted,
    inputs=inputs,
    targets_template=targets * np.nan,
    forcings=forcings,
    ensemble_size=ensemble_size,
    std=noise_std,
    rng=jax.random.PRNGKey(42)
)

Ensemble member 1 done.
Ensemble member 2 done.
Ensemble member 3 done.
Ensemble member 4 done.
Ensemble member 5 done.
Ensemble member 6 done.
Ensemble member 7 done.
Ensemble member 8 done.
Ensemble member 9 done.
Ensemble member 10 done.


In [None]:
# @title Calculate RMSE for ensemble run with noise 10^(-8)

ensemble_mean_noise_exp_8 = ensemble_noise_exp_8.mean(dim="ensemble").compute()
rmse_prediction_noise_exp_8, rmse_percentage_predictions_noise_exp_8 = calculate_rmse(ensemble_mean_noise_exp_8, targets)

# 6.7. Compare the Results

In [None]:
# @title Tabulate RMSE for each noise setting

import pandas as pd

# Create a dictionary to hold the RMSE results
rmse_comparison = {
    "Variable": list(rmse_prediction_original.keys()),
    "Original Prediction": list(rmse_prediction_original.values()),
    "Noise 10^(-1)": list(rmse_prediction_noise_exp_1.values()),
    "Noise 10^(-3)": list(rmse_prediction_noise_exp_3.values()),
    "Noise 10^(-5)": list(rmse_prediction_noise_exp_5.values()),
    "Noise 10^(-6)": list(rmse_prediction_noise_exp_6.values()),
    "Noise 10^(-7)": list(rmse_prediction_noise_exp_7.values()),
    "Noise 10^(-8)": list(rmse_prediction_noise_exp_8.values()),
}

# Create a pandas DataFrame
rmse_df = pd.DataFrame(rmse_comparison)

# Display the DataFrame
display(rmse_df)

Unnamed: 0,Variable,Original Prediction,Noise 10^(-1),Noise 10^(-3),Noise 10^(-5),Noise 10^(-6),Noise 10^(-7),Noise 10^(-8)
0,temperature,0.442355,3.990034,3.982821,2.588837,0.483672,0.44274,0.442301
1,v_component_of_wind,1.240714,6.844286,6.406958,4.61538,1.307718,1.240328,1.240383
2,total_precipitation_6hr,0.000641,0.032065,0.004314,0.001981,0.000651,0.000641,0.000641
3,specific_humidity,0.000225,0.031635,0.000811,0.000538,0.000229,0.000225,0.000225
4,mean_sea_level_pressure,42.128708,889.176392,829.365601,421.788147,47.515511,42.058968,41.976044
5,10m_v_component_of_wind,0.64042,2.915304,2.71545,1.830944,0.65585,0.640402,0.640429
6,geopotential,38.699047,727.677368,710.181946,452.629517,47.371582,38.683083,38.566456
7,u_component_of_wind,1.224694,9.393501,7.827695,5.860281,1.298496,1.224622,1.224406
8,vertical_velocity,0.09793,0.505797,0.397314,0.228283,0.09838,0.097897,0.097937
9,10m_u_component_of_wind,0.630767,3.691192,3.552326,2.108152,0.647309,0.630632,0.630739


# 7. Experiment with ensemble member count with fixed noised magnitude(10^-7).

# 7.1. Ensemble Member Count = 2

In [None]:
noise_std=0.0000001
ensemble_size=2

ensemble_mem_2 = ensemble_rollout(
    run_forward_jitted,
    inputs=inputs,
    targets_template=targets * np.nan,
    forcings=forcings,
    ensemble_size=ensemble_size,
    std=noise_std,
    rng=jax.random.PRNGKey(42)
)

Ensemble member 1 done.
Ensemble member 2 done.


In [None]:
# @title Calculate RMSE for ensemble with 2 members

ensemble_mean_mem_2 = ensemble_mem_2.mean(dim="ensemble").compute()
rmse_prediction_mem_2, rmse_percentage_predictions_mem_2 = calculate_rmse(ensemble_mean_mem_2, targets)


# 7.2. Ensemble Member Count = 5

In [None]:
noise_std=0.0000001
ensemble_size=5

ensemble_mem_5 = ensemble_rollout(
    run_forward_jitted,
    inputs=inputs,
    targets_template=targets * np.nan,
    forcings=forcings,
    ensemble_size=ensemble_size,
    std=noise_std,
    rng=jax.random.PRNGKey(42)
)

Ensemble member 1 done.
Ensemble member 2 done.
Ensemble member 3 done.
Ensemble member 4 done.
Ensemble member 5 done.


In [None]:
# @title Calculate RMSE for ensemble with 5 members

ensemble_mean_mem_5 = ensemble_mem_5.mean(dim="ensemble").compute()
rmse_prediction_mem_5, rmse_percentage_predictions_mem_5 = calculate_rmse(ensemble_mean_mem_5, targets)

# 7.3. Ensemble Member Count = 10

In [None]:
noise_std=0.0000001
ensemble_size=10

ensemble_mem_10 = ensemble_rollout(
    run_forward_jitted,
    inputs=inputs,
    targets_template=targets * np.nan,
    forcings=forcings,
    ensemble_size=ensemble_size,
    std=noise_std,
    rng=jax.random.PRNGKey(42)
)

Ensemble member 1 done.
Ensemble member 2 done.
Ensemble member 3 done.
Ensemble member 4 done.
Ensemble member 5 done.
Ensemble member 6 done.
Ensemble member 7 done.
Ensemble member 8 done.
Ensemble member 9 done.
Ensemble member 10 done.


In [None]:
# @title Calculate RMSE for ensemble with 10 members

ensemble_mean_mem_10 = ensemble_mem_10.mean(dim="ensemble").compute()
rmse_prediction_mem_10, rmse_percentage_predictions_mem_10 = calculate_rmse(ensemble_mean_mem_10, targets)

# 7.4. Ensemble Member Count = 15

In [None]:
noise_std=0.0000001
ensemble_size=15

ensemble_mem_15 = ensemble_rollout(
    run_forward_jitted,
    inputs=inputs,
    targets_template=targets * np.nan,
    forcings=forcings,
    ensemble_size=ensemble_size,
    std=noise_std,
    rng=jax.random.PRNGKey(42)
)

Ensemble member 1 done.
Ensemble member 2 done.
Ensemble member 3 done.
Ensemble member 4 done.
Ensemble member 5 done.
Ensemble member 6 done.
Ensemble member 7 done.
Ensemble member 8 done.
Ensemble member 9 done.
Ensemble member 10 done.
Ensemble member 11 done.
Ensemble member 12 done.
Ensemble member 13 done.
Ensemble member 14 done.
Ensemble member 15 done.


In [None]:
# @title Calculate RMSE for ensemble with 15 members

ensemble_mean_mem_15 = ensemble_mem_15.mean(dim="ensemble").compute()
rmse_prediction_mem_15, rmse_percentage_predictions_mem_15 = calculate_rmse(ensemble_mean_mem_15, targets)

# 7.5. Compare Results

In [None]:
# @title Tabulate RMSE for each ensemble member count

import pandas as pd

# Create a dictionary to hold the RMSE results
rmse_comparison_member_count = {
    "Variable": list(rmse_prediction_original.keys()),
    "Original Prediction RMSE": list(rmse_prediction_original.values()),
    "Ensemble Members = 2 RMSE": list(rmse_prediction_mem_2.values()),
    "Ensemble Members = 5 RMSE": list(rmse_prediction_mem_5.values()),
    "Ensemble Members = 10 RMSE": list(rmse_prediction_mem_10.values()),
    "Ensemble Members = 15 RMSE": list(rmse_prediction_mem_15.values()),
}

# Create a pandas DataFrame
rmse_member_count_df = pd.DataFrame(rmse_comparison_member_count)

# Display the DataFrame
display(rmse_member_count_df)

Unnamed: 0,Variable,Original Prediction RMSE,Ensemble Members = 2 RMSE,Ensemble Members = 5 RMSE,Ensemble Members = 10 RMSE,Ensemble Members = 15 RMSE
0,temperature,0.442355,0.443039,0.442807,0.442754,0.442707
1,v_component_of_wind,1.240714,1.240826,1.240656,1.240416,1.240319
2,total_precipitation_6hr,0.000641,0.00064,0.00064,0.00064,0.000641
3,specific_humidity,0.000225,0.000225,0.000225,0.000225,0.000225
4,mean_sea_level_pressure,42.128708,42.07925,42.071419,42.055698,42.049641
5,10m_v_component_of_wind,0.64042,0.640605,0.640384,0.640331,0.640317
6,geopotential,38.699047,38.719555,38.700905,38.678978,38.671928
7,u_component_of_wind,1.224694,1.224828,1.224776,1.224694,1.224579
8,vertical_velocity,0.09793,0.097896,0.097902,0.097898,0.097908
9,10m_u_component_of_wind,0.630767,0.630905,0.630703,0.630621,0.630688
