# Master's thesis - Lukas Meuris - GraphCast evaluation

This notebook contains the code to plot the evaluation results of the models.

In [None]:
import sys
# Make modules in the parent directory importable (presumably project-local code).
sys.path.append("../")

import numpy as np
import pandas as pd
import xarray as xr
import cartopy.crs as ccrs

# NOTE(review): optax does not appear to be used in the cells shown here.
import optax

import os
# NOTE(review): time does not appear to be used in the cells shown here.
import time
# Expose only GPUs 0 and 1 to CUDA libraries.
# NOTE(review): this only takes effect if no CUDA library has initialised yet;
# confirm that the imports above (e.g. optax) do not trigger initialisation.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

# NOTE(review): weatherbench2 / config do not appear to be used in the cells shown here.
import weatherbench2
from weatherbench2 import config

import matplotlib.pyplot as plt

# Deterministic results
The following cells load and plot the deterministic results (RMSE, bias, ACC, SEEPS).

In [None]:
# One deterministic-evaluation file per model variant: mse, mae, lch (log-cosh), rse, rae.
det_files = ['../evaluation/mse_det.nc',
             '../evaluation/mae_det.nc',
             '../evaluation/lch_det.nc',
             '../evaluation/rse_det.nc',
             '../evaluation/rae_det.nc']
results_mse, results_mae, results_lch, results_rse, results_rae = map(xr.open_dataset, det_files)

In [None]:
# Display the loaded 'mse' results dataset to inspect its dimensions and variables.
results_mse

In [None]:
# Add an 'rmse' entry to the `metric` coordinate of every results dataset,
# computed as the square root of the stored 'mse'.
def add_rmse(ds):
    """Return ``ds`` with an ``'rmse'`` metric equal to the square root of ``'mse'``.

    Any existing ``'rmse'`` entry is dropped first, so re-running this cell does
    not append a duplicate ``'rmse'`` label to the ``metric`` coordinate.

    Parameters
    ----------
    ds : xarray.Dataset
        Results with a ``metric`` dimension that contains ``'mse'``.

    Returns
    -------
    xarray.Dataset
        A new dataset with ``'rmse'`` appended along ``metric``.
    """
    base = ds.drop_sel(metric='rmse', errors='ignore')
    rmse = base.sel(metric=['mse']).assign_coords(metric=['rmse']) ** 0.5
    return xr.concat([base, rmse], dim='metric')


results_mse = add_rmse(results_mse)
results_mae = add_rmse(results_mae)
results_lch = add_rmse(results_lch)
results_rse = add_rmse(results_rse)
results_rae = add_rmse(results_rae)

In [None]:
# Convert the `lead_time` coordinate from timedeltas to float days.
def lead_time_in_days(ds):
    """Return ``ds`` with its ``lead_time`` coordinate expressed in (float) days.

    If ``lead_time`` is no longer a timedelta (for example because this cell was
    already run), ``ds`` is returned unchanged. Converting again would cast the
    day values to nanoseconds and silently shrink them to about zero.
    """
    if not np.issubdtype(ds['lead_time'].dtype, np.timedelta64):
        return ds
    converted = ds.copy()
    converted['lead_time'] = ds['lead_time'].astype('timedelta64[ns]') / pd.Timedelta(days=1)
    return converted


results_mse = lead_time_in_days(results_mse)
results_mae = lead_time_in_days(results_mae)
results_lch = lead_time_in_days(results_lch)
results_rse = lead_time_in_days(results_rse)
results_rae = lead_time_in_days(results_rae)

In [None]:
# Remove the initial (zero) lead time from the results.
def drop_initial_lead_time(ds, max_steps=40):
    """Return ``ds`` restricted to its first ``max_steps`` strictly positive lead times.

    Selecting by value rather than by position (``slice(1, 41)``) keeps this cell
    safe to re-run. A positional slice drops one more step on every run.
    Assumes ``lead_time`` is sorted in ascending order (TODO confirm).
    """
    positive = np.flatnonzero(ds['lead_time'].values > 0)
    return ds.isel(lead_time=positive[:max_steps])


results_mse = drop_initial_lead_time(results_mse)
results_mae = drop_initial_lead_time(results_mae)
results_lch = drop_initial_lead_time(results_lch)
results_rse = drop_initial_lead_time(results_rse)
results_rae = drop_initial_lead_time(results_rae)

In [None]:
# Plot one deterministic metric against lead time for every region, comparing the
# five model variants. Each figure is saved to ../plots/<METRIC>/<region>_<title>.png.
var = '10m_v_component_of_wind'  # data variable to plot
metric = 'acc'                    # entry of the `metric` coordinate to plot
title = 'V10M'                    # short variable name used in the title and file name

# label -> (results dataset, line colour); the dict order is the drawing order.
loss_curves = {
    'MSE': (results_mse, 'blue'),
    'MAE': (results_mae, 'red'),
    'Log-cosh': (results_lch, 'green'),
    'RSE': (results_rse, 'purple'),
    'RAE': (results_rae, 'orange'),
}

# The y-label and output folder are derived from `metric`, so they always match the plotted data.
plot_dir = '../plots/' + metric.upper() + '/'
os.makedirs(plot_dir, exist_ok=True)

for region in results_mse['region'].values:
    fig, ax = plt.subplots()
    for label, (loss_results, color) in loss_curves.items():
        loss_results[var].sel(metric=metric, region=region).plot(ax=ax, label=label, color=color)

    # Only the T2M figure gets a legend (kept from the original setup).
    if title == 'T2M':
        ax.legend(fontsize=15)
    ax.set_title(title)
    ax.set_ylabel(metric.upper())
    ax.set_xlabel("Lead time (days)")

    fig.savefig(plot_dir + region + '_' + title + '.png', dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so looping over many regions does not accumulate open figures.
    plt.close(fig)


# Spatial results
The spatial results show each metric per grid point (longitude × latitude) on a world map.

## Prediction
Show a single 10-day forecast for 2020-01-01.

In [None]:
# Load the RAE predictions and keep only the first forecast (time index 0, 2020-01-01).
relative_path = "predictions/pred_64x32_2020_rae_eval.zarr"
project_root = os.path.dirname(os.getcwd())  # parent of the notebook's working directory
pred_path = os.path.join(project_root, relative_path)
# Load the single selected forecast into memory before plotting.
pred_data = xr.open_zarr(pred_path).isel(time=0).compute()

In [None]:
# Map the plotted variable at three lead times of the 2020-01-01 RAE forecast.
var = 'total_precipitation_6hr'
# prediction_timedelta index -> panel title (12 -> 3 days implies 6-hourly steps).
step_titles = {12: '3 days', 20: '5 days', 40: '10 days'}
# The title is derived from the plotted variable, so it cannot disagree with the data shown.
title = 'RAE_' + var
save_figure = False  # set True to write ../plots/Predictions/<title>.png

# xarray's faceted plot (col=...) creates its own figure. A preceding plt.figure()
# would only leave an empty extra figure behind.
g = pred_data[var].isel(prediction_timedelta=list(step_titles)).plot(
    x='longitude',
    y='latitude',
    col="prediction_timedelta",
    col_wrap=3,
    robust=True,
    subplot_kws={'projection': ccrs.PlateCarree()},
    aspect=1.5,
)

# Replace the default coordinate-value titles with readable lead times.
for ax, subtitle in zip(g.axs.flat, step_titles.values()):
    ax.set_title(subtitle)
    ax.coastlines()

plt.suptitle(title, fontsize=16, x=0.4, y=1)
if save_figure:
    plt.savefig('../plots/Predictions/' + title + '.png', dpi=300, bbox_inches='tight')
plt.show()

## Bias maps
Show the bias maps for 2020.

In [None]:
# Spatial (longitude x latitude) evaluation results for the 'mse' run.
results = xr.open_dataset('../evaluation/mse_spatial.nc')

In [None]:
def plot_bias_maps(dataset, var, short_name, units, label,
                   lead_days=(3, 5, 10), vlim=3):
    """Plot bias maps of ``var`` at several lead times for one spatial results file.

    Reads ``../evaluation/<dataset>_spatial.nc``, draws one map panel per lead
    time, and saves the figure to ``../plots/Bias_maps/<short_name> <label>.png``.

    Parameters
    ----------
    dataset : str
        File prefix of the spatial results file (e.g. ``'mse'``).
    var : str
        Data variable to plot.
    short_name : str
        Short variable name; the figure title is ``short_name + ' ' + label``.
    units : str
        Appended to ``var`` in the colour-bar label.
    label : str
        Model label appended to the title and file name.
    lead_days : sequence of int
        Lead times in whole days, one panel each (also used as panel titles).
    vlim : float
        Symmetric colour limits ``[-vlim, vlim]``.

    Returns
    -------
    The xarray FacetGrid of the plot.
    """
    lead_times = [np.timedelta64(days, 'D') for days in lead_days]
    # The faceted plot creates its own figure, so no plt.figure() is needed (it would
    # only leave an empty extra figure). The context manager releases the file
    # handle once plotting has read the data.
    with xr.open_dataset('../evaluation/' + dataset + '_spatial.nc') as results:
        g = results[var].sel(metric='bias', lead_time=lead_times).plot(
            x='longitude', y='latitude',
            col="lead_time", robust=True,
            cbar_kwargs={"label": var + units},
            subplot_kws={'projection': ccrs.PlateCarree()},
            aspect=1.5, cmap='coolwarm',
            vmin=-vlim, vmax=vlim,
        )

    # Panel titles come from the same list as the selected lead times, so they stay in sync.
    for ax, days in zip(g.axs.flat, lead_days):
        ax.set_title(f'{days} days')
        ax.coastlines()

    title = short_name + ' ' + label
    plt.suptitle(title, fontsize=16, x=0.4, y=1)
    os.makedirs('../plots/Bias_maps', exist_ok=True)
    plt.savefig('../plots/Bias_maps/' + title + '.png', dpi=300, bbox_inches='tight')
    plt.show()
    return g


datasets = ['mse', 'mae', 'rse', 'rae']
for dataset in datasets:
    plot_bias_maps(dataset, '10m_u_component_of_wind', 'U10M', " [m/s]", dataset.upper())


In [None]:

# Bias maps of V10M for the log-cosh ('lch') results at 3, 5 and 10 days lead time.
var = '10m_v_component_of_wind'
title = 'V10M'
units = " [m/s]"
lead_days = [3, 5, 10]
# Selected lead times and panel titles come from one list, so they cannot drift apart.
lead_times = [np.timedelta64(days, 'D') for days in lead_days]

# xarray's faceted plot (col=...) creates its own figure. A preceding plt.figure()
# would only leave an empty extra figure behind.
results = xr.open_dataset('../evaluation/lch_spatial.nc')
# NOTE(review): with explicit vmin/vmax, robust=True likely has no effect on the colour range.
g = results[var].sel(metric='bias', lead_time=lead_times).plot(
    x='longitude', y='latitude',
    col="lead_time", robust=True,
    cbar_kwargs={"label": var + units},
    subplot_kws={'projection': ccrs.PlateCarree()},
    aspect=1.5, cmap='coolwarm',
    vmin=-3, vmax=3,
)

# Replace the default coordinate-value titles with readable lead times.
for ax, days in zip(g.axs.flat, lead_days):
    ax.set_title(f'{days} days')
    ax.coastlines()

title = title + ' ' + 'log-cosh'
plt.suptitle(title, fontsize=16, x=0.4, y=1)
os.makedirs('../plots/Bias_maps', exist_ok=True)
plt.savefig('../plots/Bias_maps/' + title + '.png', dpi=300, bbox_inches='tight')
plt.show()
