# Notebook for regridding using xarray to compare WRF HIST and GridRad MESH

### Read in geometry and apply coarsen to WRF lats/lons:

In [1]:
import xarray as xr

# WRF geometry file: provides the CLAT / CLONG coordinate fields of the model grid.
geog = xr.open_dataset("/home/scratch/ahaberlie/simgeog/geo_em.d01.nc")

# Block-average latitude and longitude over 20 x 20 pixel windows, trimming any
# leftover edge pixels. 20 pixels * 3.75 km = 75 km, so every coarse value is the
# mean position of a ~75 km block.
coarse_geog = (
    geog[['CLAT', 'CLONG']]
    .coarsen(south_north=20, west_east=20, boundary='trim')
    .mean()
)

### New dimensions of lat/lon data are now 44 x 69 after averaging over 20 x 20 pixel blocks (each coarse value is the mean lat/lon, i.e. the centroid, of its block)

## Use WRF HIST data to validate that this method upscales original data to ~75 x 75 km grid

In [None]:
# Open every WRF HIST netCDF file matched by the glob as a single multi-file dataset.
wrf_hist = xr.open_mfdataset(
    '/home/scratch/ahaberlie/AFWA_HAIL/HIST/*/*.nc'
)
wrf_hist

## Open gridrad MESH dataset for validation against WRF HIST

In [2]:
# Open every GridRad MESH netCDF file matched by the glob as a single multi-file dataset.
gridrad = xr.open_mfdataset(
    '/home/scratch/gridrad_mesh/gridrad/*/*/*.nc'
)
gridrad

In [None]:
# Block-average the GridRad Latitude / Longitude coordinates so they can serve as
# the cell centres of the coarsened MESH grid. The window sizes must stay identical
# to the ones used when coarsening the MESH data itself.
# NOTE(review): the gridrad_daily_max repr saved in this notebook shows arrays of
# shape (366, 1201, 2301). If that order is (time, Latitude, Longitude), windows of
# 17 x 52 yield 70 x 44 cells rather than the 44 x 69 of the coarsened WRF grid --
# confirm the intended window sizes.
coarse_gridrad_geog = (
    gridrad[['Latitude', 'Longitude']]
    .coarsen(Latitude=17, Longitude=52, boundary='trim')
    .mean()
)

## Resample to convective daily max

In [None]:
# Daily maximum over 24-hour windows shifted by 12 hours (base=12), i.e. a
# "convective day" rather than a calendar day.
# NOTE(review): newer pandas/xarray releases deprecate `base` in favour of
# `offset`; confirm against the installed versions before upgrading.
wrf_daily_max = (
    wrf_hist
    .resample(Time='24H', base=12)
    .max()
)
wrf_daily_max  # Returns 5479 days (4 leap days)

In [4]:
# Same 12-hour-shifted daily maximum as for WRF; note that the GridRad time
# dimension is spelled lowercase `time`.
# NOTE(review): newer pandas/xarray releases deprecate `base` in favour of
# `offset`; confirm against the installed versions before upgrading.
gridrad_daily_max = (
    gridrad
    .resample(time='24H', base=12)
    .max()
)
gridrad_daily_max

Unnamed: 0,Array,Chunk
Bytes,3.77 GiB,10.54 MiB
Shape,"(366, 1201, 2301)","(1, 1201, 2301)"
Count,3653 Tasks,366 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 3.77 GiB 10.54 MiB Shape (366, 1201, 2301) (1, 1201, 2301) Count 3653 Tasks 366 Chunks Type float32 numpy.ndarray",2301  1201  366,

Unnamed: 0,Array,Chunk
Bytes,3.77 GiB,10.54 MiB
Shape,"(366, 1201, 2301)","(1, 1201, 2301)"
Count,3653 Tasks,366 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.77 GiB,10.54 MiB
Shape,"(366, 1201, 2301)","(1, 1201, 2301)"
Count,3653 Tasks,366 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 3.77 GiB 10.54 MiB Shape (366, 1201, 2301) (1, 1201, 2301) Count 3653 Tasks 366 Chunks Type float32 numpy.ndarray",2301  1201  366,

Unnamed: 0,Array,Chunk
Bytes,3.77 GiB,10.54 MiB
Shape,"(366, 1201, 2301)","(1, 1201, 2301)"
Count,3653 Tasks,366 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.77 GiB,10.54 MiB
Shape,"(366, 1201, 2301)","(1, 1201, 2301)"
Count,3653 Tasks,366 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 3.77 GiB 10.54 MiB Shape (366, 1201, 2301) (1, 1201, 2301) Count 3653 Tasks 366 Chunks Type float32 numpy.ndarray",2301  1201  366,

Unnamed: 0,Array,Chunk
Bytes,3.77 GiB,10.54 MiB
Shape,"(366, 1201, 2301)","(1, 1201, 2301)"
Count,3653 Tasks,366 Chunks
Type,float32,numpy.ndarray


## Apply severe hail threshold (can switch this out for large hail [≥ 4 cm] as well)

In [None]:
# Flag each grid point/day as 1 where the WRF daily maximum reaches the threshold, else 0.
# Multiplying the boolean mask by 1 converts it to integers so days can be summed later.
# NOTE(review): 0.0254 presumably means 2.54 cm expressed in metres -- confirm the units of the WRF hail variable.
sev_hail_days = 1 * (wrf_daily_max >= 0.0254) #Severe-hail day flag (1 = threshold met)
sev_hail_days

In [None]:
# Flag each grid point/day as 1 where the GridRad daily maximum reaches the threshold, else 0.
# NOTE(review): 25.4 presumably means 2.54 cm expressed in millimetres -- confirm the units of the MESH variable.
gridrad_sev = 1 * (gridrad_daily_max >= 25.4) #Severe-hail day flag from MESH (1 = threshold met)
gridrad_sev

### Compute hail days to work with later

In [None]:
import dask
import dask.array as da
from dask import delayed
import dask.dataframe as dd
from dask.distributed import Client
# Point dask's temporary directory at the scratch area.
dask.config.set({'temporary_directory': '/home/scratch/jgoodin'})
# Start a dask.distributed client with default settings; displaying it as the
# last expression shows its summary in the notebook.
client = Client()
client

In [None]:
# Evaluate the WRF hail-day flags and rebind the same name to the computed result.
sev_hail_days = sev_hail_days.compute()

In [None]:
# Evaluate the GridRad hail-day flags and rebind the same name to the computed result.
gridrad_sev = gridrad_sev.compute()

## Groupby season to find seasonal sev hail days

In [None]:
# Group the WRF hail-day flags by meteorological season (DJF/MAM/JJA/SON) of the
# 'Time' coordinate. Despite its name, this is a GroupBy object: no reduction has
# been applied yet.
seasonal_max = sev_hail_days.groupby('Time.season')
seasonal_max

In [None]:
# Group the GridRad hail-day flags by meteorological season (DJF/MAM/JJA/SON).
# The GridRad time dimension is lowercase 'time' (it is the name used when this
# dataset is resampled and summed elsewhere in the notebook); 'Time' only exists
# on the WRF data, so grouping on 'Time.season' fails for GridRad.
gridrad_seasonal = gridrad_sev.groupby('time.season')
gridrad_seasonal

### Select by season for plotting

In [None]:
# Pull each season's WRF hail-day flags out of the season GroupBy.
wrf_DJF, wrf_MAM, wrf_JJA, wrf_SON = (
    seasonal_max[season_key] for season_key in ('DJF', 'MAM', 'JJA', 'SON')
)

In [None]:
wrf_JJA #Expected dims: (number of days in the 3-month season) x 899 x 1399

In [None]:
# Pull each season's GridRad hail-day flags out of the season GroupBy.
gridrad_DJF, gridrad_MAM, gridrad_JJA, gridrad_SON = (
    gridrad_seasonal[season_key] for season_key in ('DJF', 'MAM', 'JJA', 'SON')
)

In [None]:
# Inspect the spring (MAM) GridRad subset.
gridrad_MAM

## Plot summed hail day counts to new grid

In [22]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from copy import deepcopy
%matplotlib inline

# NOTE: the block below is disabled code kept inside a string literal. It is never
# executed; because the string is the last expression of the cell, the notebook
# simply echoes its text as the cell output.
"""
plt.rcParams['figure.figsize'] = 10, 10

geog_test_hail = deepcopy(geog[['CLAT', 'CLONG']])

geog_test_hail['sev_hail_days'] = (('Time', 'south_north', 'west_east'), np.array([sev_hail_days]))

plt.imshow(geog_test_hail['sev_hail_days'].values[0,:,:])
"""

"\nplt.rcParams['figure.figsize'] = 10, 10\n\ngeog_test_hail = deepcopy(geog[['CLAT', 'CLONG']])\n\ngeog_test_hail['sev_hail_days'] = (('Time', 'south_north', 'west_east'), np.array([sev_hail_days]))\n\nplt.imshow(geog_test_hail['sev_hail_days'].values[0,:,:])\n"

# Coarsen the data

### NOTE: MESH data must be coarsened by a different number of pixels since its resolution isn't the same; both grids must end up (44 x 69)

In [None]:
# Coarsen the WRF hail-day flags over 20 x 20 pixel windows (edge remainder trimmed).
# Taking the max marks a coarse cell as a hail day if ANY fine pixel inside it was flagged.
hail_day_coarse = sev_hail_days.coarsen(south_north=20, west_east=20, boundary='trim').max()

In [None]:
hail_day_coarse #Check that the coarsened spatial dims are (44 x 69)

In [None]:
# Apply the same 20 x 20 max-coarsening to each seasonal WRF subset: a coarse cell
# counts as a hail day when any fine pixel inside it was flagged.
wrf_DJF_coarse, wrf_MAM_coarse, wrf_JJA_coarse, wrf_SON_coarse = (
    season_flags.coarsen(south_north=20, west_east=20, boundary='trim').max()
    for season_flags in (wrf_DJF, wrf_MAM, wrf_JJA, wrf_SON)
)

### Coarsen gridrad MESH data (will require different number of pixels to be passed)

#### NOTE: might have to be an approximation for gridrad

In [None]:
# Coarsen the GridRad hail-day flags; max marks a coarse cell as a hail day if any
# fine pixel inside it was flagged. Window sizes must match those used for
# coarse_gridrad_geog so that data and coordinates stay on the same grid.
# NOTE(review): with the (366, 1201, 2301) shape shown in this notebook, windows of
# Latitude=17 / Longitude=52 would give 70 x 44 cells, not 44 x 69 -- confirm.
gridrad_coarse = gridrad_sev.coarsen(Longitude=52, Latitude=17, boundary='trim').max()

In [None]:
# Inspect the coarsened GridRad dataset.
gridrad_coarse

In [None]:
# Number of coarse cells along the Longitude dimension.
gridrad_coarse.Longitude.shape

In [None]:
# Apply the same max-coarsening to each seasonal GridRad subset, using the same
# window sizes as for the full GridRad record.
gridrad_DJF_coarse, gridrad_MAM_coarse, gridrad_JJA_coarse, gridrad_SON_coarse = (
    season_flags.coarsen(Longitude=52, Latitude=17, boundary='trim').max()
    for season_flags in (gridrad_DJF, gridrad_MAM, gridrad_JJA, gridrad_SON)
)

## Sum along the time dimension ('Time' for WRF, 'time' for GridRad; i.e. across all simulation years) to get the hail day count for each grid cell

In [None]:
# Total number of flagged hail days per coarse WRF cell over the whole record.
annual_sum = hail_day_coarse.sum(dim = 'Time')

In [None]:
# Total number of flagged hail days per coarse WRF cell, separately for each season.
wrf_DJF_sum, wrf_MAM_sum, wrf_JJA_sum, wrf_SON_sum = (
    season_coarse.sum(dim = 'Time')
    for season_coarse in (wrf_DJF_coarse, wrf_MAM_coarse, wrf_JJA_coarse, wrf_SON_coarse)
)

In [None]:
# Inspect the WRF hail-day totals.
annual_sum

In [None]:
# Total number of flagged hail days per coarse GridRad cell (GridRad's time dimension is lowercase 'time').
gridrad_sum = gridrad_coarse.sum(dim = 'time')

In [None]:
# Inspect the GridRad hail-day totals.
gridrad_sum

In [None]:
#gridrad_sum.MESH95.max()

In [None]:
# Total number of flagged hail days per coarse GridRad cell, separately for each season.
gridrad_DJF_sum, gridrad_MAM_sum, gridrad_JJA_sum, gridrad_SON_sum = (
    season_coarse.sum(dim = 'time')
    for season_coarse in (gridrad_DJF_coarse, gridrad_MAM_coarse, gridrad_JJA_coarse, gridrad_SON_coarse)
)

### Divide by 15 (23 for gridrad mesh) to get mean annual statistic

In [None]:
# Mean hail days per year: divide the totals by the 15 years of the WRF HIST record.
annual_stat = annual_sum / 15
# The seasonal statistics must be defined here: the plotting-variable cell further
# down reads wrf_DJF_stat ... wrf_SON_stat and raises NameError if they are missing.
wrf_DJF_stat = wrf_DJF_sum / 15
wrf_MAM_stat = wrf_MAM_sum / 15
wrf_JJA_stat = wrf_JJA_sum / 15
wrf_SON_stat = wrf_SON_sum / 15

In [None]:
# Mean hail days per year for GridRad: divide the total by 23 (years of record).
# NOTE(review): the gridrad_daily_max repr saved in this notebook shows only 366
# time steps; confirm the loaded record really spans 23 years before trusting this divisor.
gridrad_annual_stat = gridrad_sum / 23
# Seasonal equivalents, currently disabled:
#gridrad_DJF_stat = gridrad_DJF_sum / 23
#gridrad_MAM_stat = gridrad_MAM_sum / 23
#gridrad_JJA_stat = gridrad_JJA_sum / 23
#gridrad_SON_stat = gridrad_SON_sum / 23

### Select HAIL_MAX2D (MESH95) variable to create plottable 2D array

In [None]:
# Extract the HAIL_MAX2D variable from each WRF statistic so that every result is
# a single 2-D field. **These are plotting variables**
annual_plot = annual_stat['HAIL_MAX2D']
wrf_DJF_stat_plot = wrf_DJF_stat['HAIL_MAX2D']
wrf_MAM_stat_plot = wrf_MAM_stat['HAIL_MAX2D']
wrf_JJA_stat_plot = wrf_JJA_stat['HAIL_MAX2D']
wrf_SON_stat_plot = wrf_SON_stat['HAIL_MAX2D']

In [None]:
# Largest mean-annual WRF hail-day value. The annual WRF plotting field is named
# `annual_plot`; `annual_stat_plot` is never defined in this notebook.
annual_plot.max()

In [None]:
# Extract the MESH95 variable so that every result is a single 2-D field.
# **These are plotting variables**
gridrad_annual_plot = gridrad_annual_stat.MESH95
# The seasonal fields are required by the seasonal WRF-minus-GridRad deltas later
# in the notebook. They are derived straight from the seasonal hail-day sums, using
# the same /23 normalisation as gridrad_annual_stat, because the intermediate
# gridrad_*_stat variables are not defined.
gridrad_DJF_stat_plot = gridrad_DJF_sum.MESH95 / 23
gridrad_MAM_stat_plot = gridrad_MAM_sum.MESH95 / 23
gridrad_JJA_stat_plot = gridrad_JJA_sum.MESH95 / 23
gridrad_SON_stat_plot = gridrad_SON_sum.MESH95 / 23

In [None]:
# Largest mean-annual GridRad hail-day value. The annual GridRad plotting field is
# named `gridrad_annual_plot`; `gridrad_stat_plot` is never defined in this notebook.
gridrad_annual_plot.max()

## Find deltas between WRF HIST and Gridrad MESH

In [None]:
# WRF-minus-GridRad differences. **These are also plotting variables**
# The annual fields are named `annual_plot` and `gridrad_annual_plot` where they are
# created; the previous `annual_stat_plot` / `gridrad_stat_plot` names did not exist.
# NOTE(review): the WRF fields are indexed by (south_north, west_east) and the
# GridRad fields by (Latitude, Longitude). xarray aligns operands by dimension name,
# so these subtractions broadcast across all four dimensions instead of pairing
# cells one-to-one. Put both products on a common grid (same dimension names and
# sizes) before interpreting the deltas.
annual_delta = annual_plot - gridrad_annual_plot
DJF_delta = wrf_DJF_stat_plot - gridrad_DJF_stat_plot
MAM_delta = wrf_MAM_stat_plot - gridrad_MAM_stat_plot
JJA_delta = wrf_JJA_stat_plot - gridrad_JJA_stat_plot
SON_delta = wrf_SON_stat_plot - gridrad_SON_stat_plot

## Plot on map of CONUS

In [None]:
import cartopy
import cartopy.crs as ccrs
import matplotlib.cm as cm
import matplotlib.gridspec as gridspec
import cartopy.feature as cfeature
import cartopy.io.shapereader as shpreader

from copy import copy  # local import: lets us modify a private copy of the colormap

# --- Figure and map axes -----------------------------------------------------
fig = plt.figure(figsize=(14, 10))
# Two-row layout; only the thin bottom row (gs1[1, 0]) is used, for the colorbar.
gs1 = gridspec.GridSpec(2, 1, height_ratios=[1, .04], bottom=.05, top=.95, wspace=.1)
map_proj = ccrs.LambertConformal(central_longitude=-100, central_latitude=35)
ax = plt.subplot(projection=map_proj)
ax.set_extent([240, 287, 22, 50])  # lon_min, lon_max, lat_min, lat_max

# --- Background geography ----------------------------------------------------
ax.add_feature(cfeature.LAND.with_scale('10m'))

# Grey out every country except the United States. Each shapefile record carries
# its own geometry, so a single pass over the records is enough.
countries_shp = shpreader.natural_earth(resolution='50m',
                                        category='cultural',
                                        name='admin_0_countries')
for country_rec in shpreader.Reader(countries_shp).records():
    if country_rec.attributes['NAME_LONG'] != 'United States':
        ax.add_geometries([country_rec.geometry], ccrs.PlateCarree(),
                          facecolor='lightgrey', edgecolor='k', zorder=6)

ax.add_feature(cfeature.NaturalEarthFeature('physical', 'coastline', '50m', edgecolor='k',
                                            facecolor='None'), zorder=8)
ax.add_feature(cfeature.BORDERS.with_scale('10m'), linewidth=2)  # country borders
ax.add_feature(cfeature.STATES.with_scale('10m'), facecolor='none', linewidth=2)  # US states
ax.add_feature(cfeature.NaturalEarthFeature('physical', 'ocean', '50m', edgecolor='face',
                                            facecolor='lightsteelblue'), zorder=6)

# Draw the five Great Lakes on top of the land.
great_lakes = {'Lake Superior', 'Lake Michigan', 'Lake Huron', 'Lake Erie', 'Lake Ontario'}
lakes_shp = shpreader.natural_earth(resolution='50m',
                                    category='physical',
                                    name='lakes')
for lake_rec in shpreader.Reader(lakes_shp).records():
    if lake_rec.attributes['name'] in great_lakes:
        ax.add_geometries([lake_rec.geometry], ccrs.PlateCarree(),
                          facecolor='lightsteelblue', edgecolor='k', zorder=6)

# --- Colour scale --------------------------------------------------------------
# Work on a copy so set_over() does not alter the colormap shared with other plots.
cmap = copy(plt.get_cmap('viridis'))
cmap.set_over('lemonchiffon')  # colour for values above the last level
levels = np.arange(0, 110, 10)  # bin edges 0, 10, ..., 100
# The BoundaryNorm alone defines the colour mapping. vmin/vmax must not be passed
# together with a norm: current Matplotlib raises an error for that combination,
# and with a BoundaryNorm (clip=False) they never influenced the binning anyway.
norm = mpl.colors.BoundaryNorm(levels, ncolors=cmap.N, clip=False)

# --- Data ----------------------------------------------------------------------
mmp = ax.pcolormesh(coarse_gridrad_geog['Longitude'].values,
                    coarse_gridrad_geog['Latitude'].values,
                    gridrad_annual_plot,
                    shading='nearest', transform=ccrs.PlateCarree(),
                    cmap=cmap, norm=norm)

# --- Gridlines, title and colorbar -----------------------------------------------
gls = ax.gridlines(draw_labels=True, dms=True, x_inline=False, y_inline=False,
                   color="black", linestyle="dashed", zorder=10)
gls.top_labels = False
gls.right_labels = False
gls.xlabel_style = {'size': 25, 'rotation': 0}
gls.xpadding = 15.0
gls.ylabel_style = {'size': 25}
plt.title("Mean Annual 75 km Grid Days ≥ 2.54 cm \n (GridRad)", fontsize=40)

ax2 = plt.subplot(gs1[1, 0])  # thin bottom row holds the colorbar
cb = plt.colorbar(mmp, cax=ax2, ticks=levels, orientation='horizontal', extend='max')
#cb.set_label('Days', fontsize = 30)
cb.ax.tick_params(labelsize=35)
plt.subplots_adjust(bottom=0.05)
#plt.tight_layout()
#plt.savefig('/home/scratch/jgoodin/compare_wrf_mesh/wrf_75km_4cm_hail_days_gridlines_USE.png') #**Make sure to save results to png!**