In [None]:
import os

import xarray as xa
from matplotlib import pyplot as plt
from matplotlib import cm, colors
import numpy as np
import cartopy
import matplotlib.ticker as mticker
import cmaps
import matplotlib.lines as mlines
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER

In [None]:
# Load the ACE2 monsoon-metric output and take a quick look at its 'hitmap'
# variable (the cells later drawn as blue "Hit" markers in the summary figure).
ace_monsoon = xa.open_dataset(
    'demo_output/monsoon_wang/ACE2-PCMDI/r1/ACE2_AllM_wang-monsoon.nc'
)
ace_monsoon['hitmap'].plot()

In [None]:
# Raw (unscaled) range of the ACE2 model field; the plots multiply it by 86400.
for _reduce in (ace_monsoon['modmap'].min, ace_monsoon['modmap'].max):
    print(_reduce())

In [None]:
# Monsoon-metric output for the two NeuralGCM variants (run directories r100 and r56).
ngcm_precip = xa.open_dataset(
    'demo_output/monsoon_wang/NeuralGCM-precip/r100/NeuralGCM-precip_AllM_wang-monsoon.nc'
)
ngcm_evap = xa.open_dataset(
    'demo_output/monsoon_wang/NeuralGCM-evap/r56/NeuralGCM-evap_AllM_wang-monsoon.nc'
)

In [None]:
def add_latlon(ax, top=False, bottom=False, left=False, right=False):
    """Add dashed grey lat/lon gridlines to a cartopy axes and pick labelled edges.

    Gridlines sit every 90 degrees of longitude (-180..180) and every 30 degrees
    of latitude (-60..60), labelled with cartopy's degree formatters.

    Parameters
    ----------
    ax : cartopy GeoAxes
        Axes to decorate.
    top, bottom, left, right : bool
        Whether tick labels are shown on the corresponding edge.
    """
    gridliner = ax.gridlines(
        crs=cartopy.crs.PlateCarree(),
        draw_labels=True,
        linewidth=2,
        color='gray',
        alpha=0.5,
        linestyle='--',
    )
    # draw_labels=True turns labelling on; each edge is then toggled individually.
    edge_flags = (('top', top), ('left', left), ('right', right), ('bottom', bottom))
    for edge, enabled in edge_flags:
        setattr(gridliner, f'{edge}_labels', enabled)

    gridliner.xlocator = mticker.FixedLocator([-180, -90, 0, 90, 180])
    gridliner.ylocator = mticker.FixedLocator([-60, -30, 0, 30, 60])
    gridliner.xformatter = LONGITUDE_FORMATTER
    gridliner.yformatter = LATITUDE_FORMATTER
def plot_model(model, ax, s=.5):
    """Draw one model panel: its precipitation field plus hit / miss / false-alarm markers.

    Parameters
    ----------
    model : xarray.Dataset
        Monsoon-metric output with ``modmap``, ``hitmap``, ``missmap`` and
        ``falarmmap`` variables. ``modmap`` must carry ``lon``/``lat``
        coordinates. The three mask variables are used directly as indices
        into the 2-D lon/lat mesh (assumes they are boolean masks on the same
        lat x lon grid as ``modmap`` — TODO confirm).
    ax : cartopy GeoAxes
        Target axes. Data are passed in PlateCarree (plain lon/lat) coordinates.
    s : float, default 0.5
        Marker size for the scatter overlays.

    Notes
    -----
    Relies on the notebook-level ``cmap`` and ``norm`` so every panel shares
    the colour scale of the single colorbar drawn in the figure cell.
    """
    data_crs = cartopy.crs.PlateCarree()

    # Multiply by 86400 s/day so values match the mm/day colorbar
    # (presumably modmap is a per-second rate — confirm units in the file).
    precip_map = model['modmap'] * 86400
    hit = model['hitmap']
    miss = model['missmap']
    fa = model['falarmmap']
    lon, lat = np.meshgrid(precip_map.lon, precip_map.lat)

    # Shade the *model* field. This previously contoured the global ``obs_map``,
    # so every model panel silently showed the observations under its markers.
    ax.contourf(precip_map.lon, precip_map.lat, precip_map,
                transform=data_crs, cmap=cmap, norm=norm, levels=30)

    # Overlay categorical markers in the order miss, hit, false alarm, so later
    # categories are drawn on top where points overlap. Colours match the legend.
    ax.scatter(lon[miss], lat[miss], transform=data_crs, color='red', s=s)
    ax.scatter(lon[hit], lat[hit], transform=data_crs, color='blue', s=s, marker="o")
    ax.scatter(lon[fa], lat[fa], transform=data_crs, color='green', s=s, marker="o")

In [None]:
# Four-panel summary: observed precipitation on top, then each model with
# hit (blue) / miss (red) / false-alarm (green) markers. All panels share one
# colour scale (0-13 mm/day, extended above) via the global ``cmap``/``norm``.
fig = plt.figure(figsize=(7, 8))
cmap = cmaps.WhiteBlueGreenYellowRed
norm = colors.Normalize(vmin=0, vmax=13)

# Panel 1: observations. ``obs_map`` stays a notebook-level name; later cells use it.
ax = fig.add_subplot(411, projection=cartopy.crs.PlateCarree(central_longitude=180))
obs_map = ngcm_evap['obsmap']*86400
ax.contourf(obs_map.lon, obs_map.lat, obs_map, transform=cartopy.crs.PlateCarree(), cmap=cmap, norm=norm, levels=30)
add_latlon(ax, left=True)
ax.add_feature(cartopy.feature.COASTLINE)
ax.set_title('Observation')

# Panels 2-4: (title, dataset, marker size), top to bottom.
# The NeuralGCM-precip panel previously plotted ngcm_evap, duplicating panel 3.
model_panels = [
    ('ACE2', ace_monsoon, .05),
    ('NeuralGCM-evap', ngcm_evap, 1),
    ('NeuralGCM-precip', ngcm_precip, 1),
]
n_rows = len(model_panels) + 1
for row, (title, dataset, marker_size) in enumerate(model_panels, start=2):
    ax = fig.add_subplot(n_rows, 1, row, projection=cartopy.crs.PlateCarree(central_longitude=180))
    # Only the bottom panel gets longitude labels.
    add_latlon(ax, left=True, bottom=(row == n_rows))
    ax.add_feature(cartopy.feature.COASTLINE)
    ax.set_title(title)
    plot_model(dataset, ax, s=marker_size)

plt.tight_layout()
fig.subplots_adjust(right=0.95)
cbar_ax = fig.add_axes([0.98, 0.2, 0.015, 0.59])
cbar = fig.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), cax=cbar_ax, extend='max', label='Precipitation (mm/day)')
cbar.ax.tick_params(labelsize=12)

# Proxy artists for the figure legend; colours match the scatter calls in plot_model.
miss_handle = mlines.Line2D([], [], color='red', marker='o', linestyle='None', markersize=6, label='Miss')
hit_handle = mlines.Line2D([], [], color='blue', marker='o', linestyle='None', markersize=6, label='Hit')
falarm_handle = mlines.Line2D([], [], color='green', marker='o', linestyle='None', markersize=6, label='False Alarm')
fig.legend(handles=[miss_handle, hit_handle, falarm_handle], loc='lower center', ncol=3, frameon=True, fontsize=12, bbox_to_anchor=(0.55, -0.03))

# savefig does not create missing directories; make sure Figs/ exists on a fresh checkout.
os.makedirs('Figs', exist_ok=True)
plt.savefig('Figs/Monsoon_wang_all.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Peak observed precipitation (mm/day); the figure's colour scale is capped at vmax=13.
obs_map.max()

In [None]:
# Inspect a model's miss mask directly. The bare name `miss` only exists as a
# local inside plot_model, so it raised NameError on a fresh kernel.
ace_monsoon['missmap']