# In [2]:
import h5py
import numpy as np
import matplotlib.pyplot as plt
import glob
import os
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.io.shapereader import Reader, natural_earth
from cartopy.feature import ShapelyFeature
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.ticker as mticker

def plot_examples_w_sites(all_outputs, sites=None, save_path=None, n_examples=5):
    """Plot a grid of example prediction fields for several models on maps.

    One row is drawn per model and one column per example. Every panel uses
    the same contour levels, so the single shared colorbar is valid for all
    of them.

    Parameters
    ----------
    all_outputs : dict
        Maps a model name (used as the row label) to an indexable collection
        of 2-D fields; ``outputs[i]`` is drawn in column ``i``. Each field is
        placed on a regular grid whose columns span longitudes 135..240 and
        whose rows span latitudes -5..-60 (first row is -5).
    sites : array-like, optional
        Site locations to overlay. Assumed to be an ``(N, 2)`` array of
        ``(lon, lat)`` pairs in degrees -- TODO confirm with callers. When
        given, the map is zoomed to lon 140..200, lat -30..-5.
    save_path : str or path-like, optional
        If given, the figure is saved there and closed; otherwise it is
        displayed with ``plt.show()``.
    n_examples : int, default 5
        Number of leading fields plotted for each model.

    Raises
    ------
    ValueError
        If ``all_outputs`` is empty, ``n_examples`` is less than 1, or a
        model provides fewer than ``n_examples`` fields.
    """
    if not all_outputs:
        raise ValueError("all_outputs must contain at least one model")
    if n_examples < 1:
        raise ValueError("n_examples must be at least 1")
    for model, outputs in all_outputs.items():
        if len(outputs) < n_examples:
            raise ValueError(
                f"model {model!r} has {len(outputs)} outputs, "
                f"need at least {n_examples}"
            )

    land_feature, countries_feature, states_provinces = _map_features()
    levels = _shared_contour_levels(all_outputs, n_examples)

    if sites is not None:
        sites = np.asarray(sites)
        extent = [140, 200, -30, -5]
    else:
        extent = [135, 240, -60, -5]

    data_crs = ccrs.PlateCarree()
    # The data domain crosses the antimeridian (135E..120W), so centre the
    # map on 180 degrees to keep each field contiguous.
    projection = ccrs.PlateCarree(central_longitude=180)

    # squeeze=False keeps ``axes`` 2-D even when there is only one model.
    fig, axes = plt.subplots(
        len(all_outputs), n_examples,
        figsize=(20, 15),
        layout="compressed",
        squeeze=False,
        subplot_kw={"projection": projection},
    )

    contour = None
    for row, (model, outputs) in enumerate(all_outputs.items()):
        for col in range(n_examples):
            ax = axes[row, col]

            ax.add_feature(land_feature)
            ax.add_feature(countries_feature)
            ax.add_feature(states_provinces, edgecolor='gray')

            # The field must be fetched first: its shape defines the grid.
            val = np.asarray(outputs[col])
            lons = np.linspace(135, 240, val.shape[1])
            lats = np.linspace(-5, -60, val.shape[0])

            # Shared levels make every panel comparable against one colorbar.
            contour = ax.contourf(lons, lats, val, levels=levels,
                                  cmap='summer', transform=data_crs)

            ax.gridlines(crs=data_crs, draw_labels=False, linewidth=1,
                         color='white', alpha=0.3, linestyle='--')

            if sites is not None:
                ax.scatter(sites[:, 0], sites[:, 1], color='red', s=15,
                           zorder=5, transform=data_crs)
            ax.set_extent(extent, crs=data_crs)

        # Cartopy GeoAxes do not reliably display axis labels, so draw the
        # row label as axes-relative text to the left of the first panel.
        first_ax = axes[row, 0]
        first_ax.text(-0.05, 0.5, model, transform=first_ax.transAxes,
                      ha='right', va='center')

    fig.colorbar(contour, ax=axes)

    if save_path is not None:
        fig.savefig(save_path)
        # Close rather than clear so repeated calls don't accumulate figures.
        plt.close(fig)
    else:
        plt.show()


def _map_features():
    """Return the (land, countries, states/provinces) Natural Earth features."""
    # Semi-transparent land so the prediction field stays readable.
    land_shp = natural_earth(resolution='110m', category='physical', name='land')
    land_feature = ShapelyFeature(Reader(land_shp).geometries(), ccrs.PlateCarree(),
                                  facecolor='lightgray', edgecolor='face', alpha=0.5)

    countries_shp = natural_earth(resolution='50m', category='cultural',
                                  name='admin_0_countries')
    countries_feature = ShapelyFeature(Reader(countries_shp).geometries(), ccrs.PlateCarree(),
                                       facecolor='none', edgecolor='black', linewidth=1, alpha=1)

    # State/province (admin-1) boundary lines at 1:50m.
    states_provinces = cfeature.NaturalEarthFeature(
        category='cultural',
        name='admin_1_states_provinces_lines',
        scale='50m',
        facecolor='none')

    return land_feature, countries_feature, states_provinces


def _shared_contour_levels(all_outputs, n_examples, n_levels=100):
    """Return ``n_levels`` contour intervals covering every plotted field."""
    fields = [np.asarray(outputs[i], dtype=float)
              for outputs in all_outputs.values()
              for i in range(n_examples)]
    vmin = min(np.nanmin(field) for field in fields)
    vmax = max(np.nanmax(field) for field in fields)
    if not vmax > vmin:
        # Constant data: contourf requires strictly increasing levels.
        vmax = vmin + 1.0
    return np.linspace(vmin, vmax, n_levels + 1)