In [None]:
import sys
import xarray as xr
import numpy as np
import pandas as pd
import math
import glob
import yaml
import geopandas
import cartopy
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.colorbar import Colorbar # different way to handle colorbar
import matplotlib.ticker as mticker
import cmocean.cm as cmo

# cartopy
import cartopy.crs as ccrs
from cartopy.mpl.geoaxes import GeoAxes
import cartopy.feature as cfeature
import dask

# import personal modules
# Path to modules
sys.path.append('../modules')
# Import my modules
from utils import roundPartial
from plotter import draw_basemap, plot_terrain
from colorline import colorline

# Enable dask's 'array.slicing.split_large_chunks' option so large chunks produced by array slicing get split.
dask.config.set({'array.slicing.split_large_chunks': True})

In [None]:
# HPC system this notebook runs on; it selects the root directory for the input data.
server = 'expanse'
data_roots = {
    'comet': '/data/projects/Comet/cwp140/',
    'expanse': '/expanse/lustre/scratch/dnash/temp_project/',
}
if server not in data_roots:
    # Fail here, instead of with a NameError on path_to_data in a later cell.
    raise ValueError("Unknown server {0!r}; expected one of {1}".format(server, sorted(data_roots)))
path_to_data = data_roots[server]
path_to_out  = '../out/'       # output files (numerical results, intermediate datafiles) -- read & write
path_to_figs = '../figs/'      # figures

In [None]:
config_file = '../preprocess/calculate_trajectories/config_1.yaml'
job_info = 'job_1' # this is the job name

# Read the job configuration; the context manager closes the file handle.
with open(config_file) as config_fh:
    config = yaml.safe_load(config_fh)
ddict = config[job_info] # pull the job info from the dict
HUC8_ID = ddict['HUC8_ID']

fname = path_to_data + 'preprocessed/PRISM/PRISM_HUC8_CO.nc'
ds = xr.open_dataset(fname)
# Keep only the configured HUC8 (HUC8_ID) and the dates flagged extreme == 1;
# those dates are the events whose trajectories are loaded below.
ds = ds.sel(HUC8=HUC8_ID)
ds = ds.where(ds.extreme == 1, drop=True)
event_dates = ds.date.values
nevents = len(event_dates)

In [None]:
## Build the list of per-event trajectory filenames (one file per event date, formatted YYYYMMDD).
print('Gathering filenames ...')
event_days = pd.to_datetime(event_dates).strftime("%Y%m%d")
fname_lst = [
    path_to_data + 'preprocessed/ERA5_trajectories/PRISM_HUC8_{0}_{1}.nc'.format(HUC8_ID, day)
    for day in event_days
]

In [None]:
## Open every per-event trajectory file for the current HUC8 and stack them
## along a new "start_date" dimension indexed by the event dates.
if len(fname_lst) == 0:
    # xr.concat on an empty list raises an unhelpful error; say what actually happened.
    raise ValueError('No trajectory files to open for HUC8 {0}.'.format(HUC8_ID))
ds_lst = [xr.open_dataset(fname) for fname in fname_lst]

# fname_lst was built from event_dates in order, so the index lines up with ds_lst.
final_ds = xr.concat(ds_lst, pd.Index(event_dates, name="start_date"))
final_ds

In [None]:
# Optional step (disabled): write the concatenated trajectories for this HUC8 to one netCDF file.
# NOTE(review): the output path below is hard-coded and does not use path_to_data; update before enabling.
# out_fname = '/expanse/nfs/cw3e/cwp140/preprocessed/ERA5_trajectories/PRISM_HUC8_{0}.nc'.format(HUC8_ID)
# final_ds.to_netcdf(path=out_fname, mode = 'w', format='NETCDF4')

In [None]:
# Split the events by the season of their start date.
# ds_lst is ordered DJF, MAM, JJA, SON.
season_names = ['DJF', 'MAM', 'JJA', 'SON']
event_season = final_ds.start_date.dt.season
ds_lst = [final_ds.sel(start_date=(event_season == name)) for name in season_names]
DJF, MAM, JJA, SON = ds_lst  # keep the per-season names available for interactive use

In [None]:
# Report the overall range of specific humidity (q) across all trajectories, rounded for display.
q_all = final_ds.q
min_q, max_q = q_all.min().values, q_all.max().values
print(*(np.round(v) for v in (min_q, max_q)))

In [None]:
## Plot the trajectories for each season on a 2x2 grid of maps, with the points
## coloured by specific humidity and a single shared colorbar on the right.
ext = [-130., -100., 10., 45.]  # [lon_min, lon_max, lat_min, lat_max], as passed to set_extent
fmt = 'png'

# Set up projection
datacrs = ccrs.PlateCarree()  ## the projection the data is in
mapcrs = ccrs.PlateCarree()   ## the projection you want your map displayed in

# Set tick/grid locations
tx = 10
ty = 5
dx = np.arange(ext[0], ext[1] + tx, tx)
dy = np.arange(ext[2], ext[3] + ty, ty)

# Fixed colour limits for specific humidity, so every panel shares the one colorbar.
q_vmin, q_vmax = 0, 12

nrows = 2
ncols = 3

## 2 rows x 3 columns: the first two columns hold the season maps,
## and the narrow third column holds the shared colorbar.
gs = GridSpec(nrows, ncols, height_ratios=[1, 1], width_ratios=[1, 1, 0.05], wspace=0.1, hspace=0.1)

fig = plt.figure(figsize=(6.5, 7.0))
fig.dpi = 600
fname = path_to_figs + 'test_trajectory'

# Enumerate through seasons; ds_lst is expected in DJF, MAM, JJA, SON order.
ssn = ['DJF', 'MAM', 'JJA', 'SON']
row_idx = [0, 0, 1, 1]
col_idx = [0, 1, 0, 1]
cf = None  # most recent scatter, reused as the colorbar mappable
for k, season_ds in enumerate(ds_lst):
    ax = fig.add_subplot(gs[row_idx[k], col_idx[k]], projection=mapcrs)
    ax.set_title(ssn[k], loc='left', fontsize=10)
    ax = draw_basemap(ax, extent=ext, xticks=dx, yticks=dy, left_lats=True, right_lats=False)
    ax.set_extent(ext, datacrs)

    ax.add_feature(cfeature.STATES, edgecolor='0.4', linewidth=0.8)

    ## One grey track plus humidity-coloured points per event in this season.
    # Use a local count so the notebook-level `nevents` (total events) is not overwritten.
    n_season_events = len(season_ds.start_date)
    for i in range(n_season_events):
        event = season_ds.isel(start_date=i)
        y_lst = event.latitude.values
        x_lst = event.longitude.values
        z_lst = event.q.values
        ax.plot(x_lst, y_lst, c='gray', transform=datacrs, alpha=0.2)
        cf = ax.scatter(x_lst, y_lst, c=z_lst, vmin=q_vmin, vmax=q_vmax, cmap=cmo.deep,
                        marker='.', transform=datacrs, alpha=0.7, s=6)

if cf is None:
    # No season had any events, so nothing was scattered. Build a stand-in mappable
    # with the same colormap and limits so the colorbar can still be drawn.
    cf = plt.cm.ScalarMappable(norm=plt.Normalize(vmin=q_vmin, vmax=q_vmax), cmap=cmo.deep)

# Add color bar
cbax = fig.add_subplot(gs[:, 2])  # colorbar axis
cb = Colorbar(ax=cbax, mappable=cf, orientation='vertical', ticklocation='right')
cb.set_label('Specific Humidity (g kg$^{-1}$)', fontsize=11)
cb.ax.tick_params(labelsize=12)

fig.savefig('%s.%s' % (fname, fmt), bbox_inches='tight', dpi=fig.dpi, transparent=True)
plt.show()
plt.close(fig)  # release the figure; clf() alone leaves it registered with pyplot