In [None]:
import h5py
import numpy as np
import matplotlib.pyplot as plt
import glob
import os
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.io.shapereader import Reader, natural_earth
from cartopy.feature import ShapelyFeature
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
from matplotlib.colors import LinearSegmentedColormap, LogNorm
import cmocean
import matplotlib.ticker as mticker
import matplotlib.cm as cm
from utils import move_last_two_months_first

# Global matplotlib text styling: one font family and one size for every
# label, title, tick and legend element produced by this notebook.
font_size = 24
plt.rcParams.update({
    "font.family": "Times New Roman",
    "axes.labelsize": font_size,
    "axes.titlesize": font_size,
    "xtick.labelsize": font_size,
    "ytick.labelsize": font_size,
    "legend.fontsize": font_size,
    "legend.title_fontsize": font_size,
})

def compute_total_output_variance_from_files_iteratively(data_folder):
    """Compute the element-wise variance of the training grids in a folder.

    Every ``*.hdf5`` file directly inside ``data_folder`` is opened read-only
    and each entry of ``file['train_grids'][:, -1]`` (last index along the
    second axis, for every index along the first axis) is folded into a
    running sum and a running sum of squares, so only one file's slice is
    held in memory at a time.

    Parameters
    ----------
    data_folder : str
        Directory searched (non-recursively) for ``*.hdf5`` files.

    Returns
    -------
    numpy.ndarray
        Population variance ``E[x**2] - E[x]**2`` (divides by the number of
        grids, no Bessel correction), with the shape of a single grid.

    Raises
    ------
    ValueError
        If no grid is found (no matching files, or only empty datasets),
        instead of silently returning an array of NaNs from a 0/0 division.
    """
    # Sort so the floating-point accumulation order is reproducible;
    # glob itself returns files in arbitrary order.
    files = sorted(glob.glob(os.path.join(data_folder, "*.hdf5")))

    # Accumulators are allocated lazily from the first grid's shape so the
    # function is not tied to one particular grid size.
    s1 = None
    s2 = None
    c = 0
    for file_path in files:
        # The data is only read, so open read-only: this also works on
        # write-protected files and cannot modify them by accident.
        with h5py.File(file_path, 'r') as file:
            for grid in file['train_grids'][:, -1]:
                # Square in float64 even if the dataset is stored in a
                # narrower dtype, to limit rounding in the sum of squares.
                grid = np.asarray(grid, dtype=np.float64)
                if s1 is None:
                    s1 = np.zeros(grid.shape)
                    s2 = np.zeros(grid.shape)
                s1 += grid
                s2 += grid ** 2
                c += 1

    if c == 0:
        raise ValueError(
            f"no 'train_grids' entries found in any *.hdf5 file under {data_folder!r}"
        )

    variance = s2 / c - (s1 / c) ** 2

    return variance
    
def compute_sampling_loss_from_files_iteratively(data_folder, n_samples):
    """Compute the mean of ``std**2 / n_samples`` over all training entries.

    Every ``*.hdf5`` file directly inside ``data_folder`` is opened read-only
    and each entry of ``file['train_stds'][:, -1]`` (last index along the
    second axis, for every index along the first axis) contributes
    ``std ** 2 / n_samples`` to a running sum, which is finally divided by the
    number of entries.

    Parameters
    ----------
    data_folder : str
        Directory searched (non-recursively) for ``*.hdf5`` files.
    n_samples : int or float
        Divisor applied to every squared standard deviation.

    Returns
    -------
    numpy.ndarray
        Element-wise mean of ``std ** 2 / n_samples``, with the shape of a
        single ``train_stds`` entry.

    Raises
    ------
    ValueError
        If no entry is found (no matching files, or only empty datasets),
        instead of silently returning an array of NaNs from a 0/0 division.
    """
    # Sort so the floating-point accumulation order is reproducible;
    # glob itself returns files in arbitrary order.
    files = sorted(glob.glob(os.path.join(data_folder, "*.hdf5")))

    # Allocated lazily from the first entry's shape so the function is not
    # tied to one particular grid size.
    s1 = None
    c = 0
    for file_path in files:
        # The data is only read, so open read-only: this also works on
        # write-protected files and cannot modify them by accident.
        with h5py.File(file_path, 'r') as file:
            for std in file['train_stds'][:, -1]:
                # Square in float64 even if the dataset is stored in a
                # narrower dtype, to limit rounding error.
                std = np.asarray(std, dtype=np.float64)
                if s1 is None:
                    s1 = np.zeros(std.shape)
                s1 += std ** 2 / n_samples
                c += 1

    if c == 0:
        raise ValueError(
            f"no 'train_stds' entries found in any *.hdf5 file under {data_folder!r}"
        )

    return s1 / c

def make_years_to_relative_error_threshold_figure(variance, sample_loss_one, threshold):
    """Draw a 5 x 6 panel of maps of ``sample_loss_one / (threshold * variance)``.

    Both inputs are first reordered with ``move_last_two_months_first`` and
    are then indexed as ``[row, column, month, category]``. Rows of the figure
    show categories 1-5 (index 0 of the last axis is not plotted), columns
    show month indices 0-5, titled November through April. Each panel is a
    filled contour plot on a Plate Carree map covering longitudes 135-240 and
    latitudes -60 to -5, with array row 0 placed at latitude -5. All panels
    share one logarithmic colour scale (10 to 10000) and one colour bar.

    Entries of ``sample_loss_one`` that are exactly zero are treated as
    infinite before the ratio is formed. The caller's arrays are not modified.

    Parameters
    ----------
    variance : numpy.ndarray
        4-D array indexed ``[row, column, month, category]``.
    sample_loss_one : numpy.ndarray
        Array with the same layout as ``variance``.
        NOTE(review): presumably the sampling loss for a single sample, so
        that the plotted ratio is the number of samples needed to reach the
        threshold - confirm against the caller.
    threshold : float
        Factor multiplying ``variance`` in the denominator.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    variance = move_last_two_months_first(variance)
    sample_loss_one = move_last_two_months_first(sample_loss_one)

    month_strings = ["November", "December", "January", "February", "March", "April"]
    n_months = len(month_strings)
    n_categories = 5

    fig, axes = plt.subplots(n_categories, n_months, figsize=(20,15), layout="compressed", subplot_kw={'projection': ccrs.PlateCarree(central_longitude=180)})

    # Add transparent land by creating a custom feature
    land_shp = natural_earth(resolution='110m', category='physical', name='land')
    land_feature = ShapelyFeature(Reader(land_shp).geometries(),
                                  ccrs.PlateCarree(), facecolor='lightgray', edgecolor='face', alpha=0.5)
    # Create a feature for States/Admin 1 regions at 1:50m from Natural Earth
    states_provinces = cfeature.NaturalEarthFeature(
        category='cultural',
        name='admin_1_states_provinces_lines',
        scale='50m',
        facecolor='none')

    # Add countries boundaries
    countries_shp = natural_earth(resolution='50m', category='cultural', name='admin_0_countries')
    countries_feature = ShapelyFeature(Reader(countries_shp).geometries(), ccrs.PlateCarree(),
                                       facecolor='none', edgecolor='black', linewidth=1, alpha=1)

    # A single norm instance is handed to every contourf call and to the
    # colour bar, so all panels are guaranteed to share the same colour scale.
    norm = LogNorm(10, 10000)
    cmap = cmocean.cm.matter

    for category in range(1, n_categories + 1):

        for month in range(n_months):

            ax = axes[category - 1, month]
            ax.add_feature(land_feature)
            ax.add_feature(countries_feature)
            ax.add_feature(states_provinces, edgecolor='gray')

            ax.set_xticks([])
            ax.set_yticks([])

            # Work on a copy so the zero -> inf replacement never leaks into
            # the array owned by the caller.
            loss_panel = sample_loss_one[:, :, month, category].copy()
            loss_panel[loss_panel == 0] = np.inf

            val = loss_panel / (threshold * variance[:, :, month, category])

            # Coordinates of the data grid: columns span 135..240 degrees
            # longitude, rows span -5 (row 0) down to -60 degrees latitude.
            lonplot2 = np.linspace(135, 240, val.shape[1])
            latplot2 = np.linspace(-5, -60, val.shape[0])

            ax.contourf(lonplot2, latplot2, val, levels=10, cmap=cmap, transform=ccrs.PlateCarree(), norm=norm)

            gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False,
                              linewidth=1, color='white', alpha=0.3, linestyle='--')
            gl.top_labels = False
            gl.right_labels = False
            # Configure THIS panel's gridliner. Previously these settings were
            # applied after the loop to whichever gridliner was created last,
            # so only the bottom-right panel received the fixed meridians.
            gl.xlocator = mticker.FixedLocator([-180, -160, -140, -90, -45, 0, 45, 90, 140, 160, 180, 225, 270, 315, 360])
            # The formatter and label styles only matter if gridline labels
            # are switched on (draw_labels is False above).
            gl.xformatter = LONGITUDE_FORMATTER
            gl.yformatter = LATITUDE_FORMATTER
            gl.xlabel_style = {'size': 18, 'color': 'black'}
            gl.ylabel_style = {'size': 18, 'color': 'black'}
            ax.set_extent([135, 240, -60, -5], crs=ccrs.PlateCarree())

    # Column headers: month names on the top row only.
    for month, month_name in enumerate(month_strings):
        axes[0, month].set_title(month_name, pad=15)

    # Row headers: category labels on the left-most column only.
    for category in range(1, n_categories + 1):
        axes[category - 1, 0].set_ylabel(f"Category {category}", rotation=0, ha='right', labelpad=20)

    cbar = fig.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), ax=axes, fraction=0.046, pad=0.005, aspect=20)

    cbar.ax.tick_params(labelsize=18)

    plt.show()

# Folder passed to the compute_*_from_files_iteratively functions, which read
# every "*.hdf5" file directly inside it. The value below is a placeholder;
# point it at the actual data location before recomputing var / loss.
data_folder  = "/path/to/data"


In [None]:
# Load the variance array from "var.npy" in the current working directory
# (presumably a saved result of
# compute_total_output_variance_from_files_iteratively - see the alternative below).
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python
# objects, which is unsafe for files from untrusted sources. A plain numeric
# array does not need it - confirm how var.npy was saved and drop the flag
# if possible.
var = np.load("var.npy", allow_pickle=True)

## or if not calculated already
# var = compute_total_output_variance_from_files_iteratively(data_folder)

In [None]:
# Load the sampling-loss array from "loss.npy" in the current working directory
# (presumably a saved result of
# compute_sampling_loss_from_files_iteratively with n_samples=1 - see the
# alternative below).
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python
# objects, which is unsafe for files from untrusted sources. A plain numeric
# array does not need it - confirm how loss.npy was saved and drop the flag
# if possible.
loss = np.load("loss.npy", allow_pickle=True)

## or if not calculated already
# loss = compute_sampling_loss_from_files_iteratively(data_folder, 1)

In [None]:
# Draw the figure with threshold = 1. np.flipud reverses both arrays along
# their first axis before plotting; the plotting function places row 0 at
# latitude -5 and the last row at latitude -60.
make_years_to_relative_error_threshold_figure(np.flipud(var), np.flipud(loss), 1)