In [7]:
import os
import glob

import functools
import numpy as np
import xarray as xr
import matplotlib
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import pandas as pd

import dask
from dask_jobqueue import PBSCluster
from dask.distributed import Client

from fates_calibration.train_emulators import get_pft_grids, get_pft_ensemble
from fates_calibration.FATES_calibration_constants import VAR_UNITS, FATES_INDEX, FATES_PFT_IDS, FATES_INDEX_new

In [8]:
def get_map(ds, da,
            thedir='/glade/u/home/forrest/ppe_representativeness/output_v4/',
            thefile='clusters.clm51_PPEn02ctsm51d021_2deg_GSWP3V1_leafbiomassesai_PPE3_hist.annual+sd.400.nc',
            tol=0.1):
    """Map a sparse-grid field onto the regular lat/lon grid via cluster membership.

    The cluster file provides a 2-D ``cclass`` array (1-based cluster id per
    lat/lon cell) and ``rcent_coords`` (one ``(lon, lat)`` pair per cluster).
    Each cluster centre is matched to the gridcell of ``ds`` whose
    ``grid1d_lat``/``grid1d_lon`` both lie within ``tol`` degrees of it. Every
    lat/lon cell of that cluster then takes that gridcell's value from ``da``.

    Parameters
    ----------
    ds : xarray.Dataset
        Sparse-grid history dataset providing ``grid1d_lat`` / ``grid1d_lon``
        (assumed 1-D along ``gridcell`` -- TODO confirm).
    da : xarray.DataArray
        Field with a ``gridcell`` dimension to be mapped.
    thedir, thefile : str, optional
        Location of the cluster file; the defaults are the original hard-coded paths.
    tol : float, optional
        Matching tolerance in degrees for both latitude and longitude (default 0.1).

    Returns
    -------
    xarray.DataArray
        ``da`` indexed onto (lat, lon). Cells whose cluster has no matching
        gridcell are NaN.

    Raises
    ------
    ValueError
        If more than one gridcell of ``ds`` matches a single cluster centre.
    """
    # Load the cluster file fully so it can be closed before we return.
    with xr.open_dataset(os.path.join(thedir, thefile)) as sg_file:
        sg = sg_file.load()

    grid_lat = np.asarray(ds.grid1d_lat)
    grid_lon = np.asarray(ds.grid1d_lon)
    cluster_ids = np.asarray(sg.cclass)

    # Sparse-grid gridcell index for every lat/lon cell; NaN where nothing matched.
    out = np.full(cluster_ids.shape, np.nan)
    for c, (cent_lon, cent_lat) in enumerate(np.asarray(sg.rcent_coords)):
        # flatnonzero replaces the hard-coded np.arange(400): works for any grid size.
        matches = np.flatnonzero((np.abs(grid_lat - cent_lat) < tol) &
                                 (np.abs(grid_lon - cent_lon) < tol))
        if matches.size == 0:
            continue  # centre not on this sparse grid: its cells stay NaN (masked below)
        if matches.size > 1:
            raise ValueError(
                f'{matches.size} gridcells lie within {tol} deg of cluster {c + 1} centre '
                f'(lon={cent_lon}, lat={cent_lat}); expected exactly one')
        # cclass ids are 1-based while enumerate is 0-based.
        out[cluster_ids == c + 1] = matches[0]
    notnan = ~np.isnan(out)

    sgmap = xr.Dataset()
    # NaN placeholders become 0 only so the index can be cast to int; notnan masks them.
    sgmap['cclass'] = xr.DataArray(np.where(notnan, out, 0).astype(int), dims=['lat', 'lon'])
    sgmap['notnan'] = xr.DataArray(notnan, dims=['lat', 'lon'])
    sgmap['lat'] = sg.lat
    sgmap['lon'] = sg.lon

    return da.sel(gridcell=sgmap.cclass).where(sgmap.notnan).compute()

def annual_mean(da, cf):
    """Collapse a monthly series to one value per calendar year.

    Each monthly value is weighted by its number of days
    (``da['time.daysinmonth']``). The weighted values are summed within each
    ``time.year`` group, and the yearly sum is scaled by ``cf``. The calls in
    this notebook pass ``cf=1/365``, which turns the day-weighted sum into a
    day-weighted annual mean (assumes 365-day years -- TODO confirm calendar).

    Parameters
    ----------
    da : xarray.DataArray
        Monthly data with a ``time`` coordinate.
    cf : float
        Multiplicative conversion factor applied to each yearly sum.

    Returns
    -------
    xarray.DataArray
        Yearly values on a ``year`` dimension, named like ``da``.
    """
    weighted = da['time.daysinmonth'] * da
    yearly = cf * weighted.groupby('time.year').sum()
    yearly.name = da.name
    return yearly

def get_colors():
    """Return the 20 Tableau colours as RGB tuples scaled to the 0-1 range.

    Returns
    -------
    list of tuple(float, float, float)
        A fresh list on every call, so callers may modify it freely.
    """
    tableau20_255 = (
        (31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),
        (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),
        (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),
        (227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),
        (188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229),
    )
    return [tuple(channel / 255. for channel in rgb) for rgb in tableau20_255]

def get_minmax_ens(oaat_file, suffix, parameter):
    """Look up the ensemble ids of the min and max runs for one OAAT parameter.

    Parameters
    ----------
    oaat_file : pandas.DataFrame
        OAAT key table with ``ensemble``, ``parameter`` and ``minmax`` columns
        (``minmax`` holding ``'min'`` / ``'max'``).
    suffix : str
        Text stripped from the ensemble name (e.g. ``'FATES_OAAT_'`` or
        ``'CLM6SPoaat'``). For ``'FATES_OAAT_'`` the remaining id is also
        zero-padded to three characters.
    parameter : str
        Parameter name to look up.

    Returns
    -------
    tuple of str
        ``(min_ens, max_ens)``, the first matching ensemble id for each bound.

    Raises
    ------
    ValueError
        If the table has no ``'min'`` or no ``'max'`` row for ``parameter``.
    """
    rows = oaat_file[oaat_file.parameter == parameter]

    def _member(bound):
        # First ensemble tagged `bound` ('min' or 'max') for this parameter.
        matches = rows[rows.minmax == bound].ensemble.values
        if len(matches) == 0:
            raise ValueError(f"no '{bound}' ensemble member for parameter '{parameter}'")
        member = str(matches[0]).replace(suffix, '')
        if suffix == 'FATES_OAAT_':
            member = member.zfill(3)
        return member

    return _member('min'), _member('max')

def get_min_max_diff(hist_dir, ensemble_dir, oaat_file, suffix, file_suffix, parameter, variable, cf):
    """Map the (max run - min run) difference of a variable's multi-year annual mean.

    Parameters
    ----------
    hist_dir, ensemble_dir : str
        History files are read from ``os.path.join(hist_dir, ensemble_dir, ...)``.
    oaat_file : pandas.DataFrame
        OAAT key table passed to ``get_minmax_ens``.
    suffix : str
        Ensemble-name text stripped by ``get_minmax_ens``.
    file_suffix : str
        Filename stem: files are named ``f'{file_suffix}{ens}.nc'``.
    parameter : str
        Parameter whose min/max members are compared.
    variable : str
        History variable to difference.
    cf : float
        Conversion factor forwarded to ``annual_mean``.

    Returns
    -------
    xarray.Dataset
        One variable named ``variable`` holding max-minus-min on the lat/lon grid.
    """
    min_ens, max_ens = get_minmax_ens(oaat_file, suffix, parameter)

    def _mean_map(ens):
        # Open one member, average annual means over years, map to lat/lon, close the file.
        path = os.path.join(hist_dir, ensemble_dir, f'{file_suffix}{ens}.nc')
        with xr.open_dataset(path) as dataset:
            mean_da = annual_mean(dataset[variable], cf).mean(dim='year')
            # get_map ends with .compute(), so the result no longer needs the open file.
            return get_map(dataset, mean_da)

    da_map_min = _mean_map(min_ens)
    da_map_max = _mean_map(max_ens)
    da_diff = da_map_max - da_map_min

    return da_diff.to_dataset(name=variable)

In [9]:
def get_pft_ensemble(land_mask_file, mesh_file, pfts, ensemble, fates_ind):
    """Split an ensemble into one masked copy per PFT, stacked on a new ``pft`` axis.

    NOTE: this redefinition shadows the ``get_pft_ensemble`` imported from
    ``fates_calibration.train_emulators`` in the first cell.

    Parameters
    ----------
    land_mask_file, mesh_file : str
        Paths forwarded to ``get_pft_grids``.
    pfts : sequence of str
        PFT names; their order defines positions along ``pft``.
    ensemble : xarray.Dataset or xarray.DataArray
        Data with a ``gridcell`` coordinate.
    fates_ind : mapping
        Maps each PFT name to the index passed to ``get_pft_grids``.

    Returns
    -------
    xarray object
        ``ensemble`` concatenated along ``pft``. In slice *k*, gridcells not
        returned by ``get_pft_grids`` for ``pfts[k]`` are masked with ``where``.
        The new ``pft`` dimension carries no coordinate, so address it by position.
    """
    masked_by_pft = [
        ensemble.where(
            ensemble.gridcell.isin(get_pft_grids(land_mask_file, mesh_file, fates_ind[name])))
        for name in pfts
    ]
    return xr.concat(masked_by_pft, dim='pft')

def plot_pft_ensemble(pft_ensemble, pfts):
    """Scatter latent heat flux against GPP for every PFT slice of an ensemble.

    Parameters
    ----------
    pft_ensemble : xarray.Dataset
        Output of ``get_pft_ensemble``, with ``GPP`` and ``EFLX_LH_TOT`` and a
        positional ``pft`` dimension.
    pfts : sequence of str
        Legend labels, one per position along ``pft``.
    """
    colors = get_colors()
    _, ax = plt.subplots(figsize=(7, 7), layout='compressed')
    for idx, pft in enumerate(pfts):
        dat = pft_ensemble.isel(pft=idx)
        # Cycle the palette so more than len(colors) PFTs does not raise IndexError.
        ax.scatter(dat.GPP, dat.EFLX_LH_TOT, label=pft, color=colors[idx % len(colors)])
    ax.set_ylabel('Annual Mean Latent Heat Flux (W/m2)', fontsize=11)
    ax.set_xlabel('Annual Mean GPP (kgC/m2/yr)', fontsize=11)
    ax.legend(loc='lower right')

In [10]:
def plot_hist_and_scatter(pft, pft_ensemble_gs0, pft_ensemble_gs1, xvar, yvar, xlab, ylab,
                          pft_names=None, bins=10):
    """Scatter ``yvar`` vs ``xvar`` for one PFT with marginal histograms, gs0 vs gs1.

    Both ensembles are averaged over ``gridcell`` for the selected PFT. Members
    of the gs1 ensemble with ``BTRANMN <= 0`` are masked out.

    Parameters
    ----------
    pft : int
        Position along the ``pft`` dimension, and index into ``pft_names``.
    pft_ensemble_gs0, pft_ensemble_gs1 : xarray.Dataset
        Outputs of ``get_pft_ensemble`` for the two ensembles.
    xvar, yvar : str
        Variable names for the x and y axes.
    xlab, ylab : str
        Axis labels.
    pft_names : sequence of str, optional
        Names used for the title. Defaults to the notebook-level ``dom_pfts``,
        which the original implementation read implicitly.
    bins : int, optional
        Number of histogram bins (default 10, matplotlib's default count).
        The same edges are used for gs0 and gs1 so the overlays are comparable.
    """
    gs0_col = '#298c8c'
    gs1_col = '#800074'

    if pft_names is None:
        # Backward-compatible default: the notebook-level list (defined in a later cell).
        pft_names = dom_pfts

    def _finite(values):
        # Flatten and drop NaN/inf (e.g. from the .where masks) before binning.
        flat = np.asarray(values).ravel()
        return flat[np.isfinite(flat)]

    dat_gs0 = pft_ensemble_gs0.sel(pft=pft).mean(dim='gridcell')
    dat_gs1 = pft_ensemble_gs1.sel(pft=pft).mean(dim='gridcell')
    dat_gs1 = dat_gs1.where(dat_gs1.BTRANMN > 0.0)

    fig = plt.figure(figsize=(6, 6))
    gs = fig.add_gridspec(2, 2, width_ratios=(4, 1), height_ratios=(1, 4),
                          left=0.1, right=0.9, bottom=0.1, top=0.9,
                          wspace=0.05, hspace=0.05)

    ax = fig.add_subplot(gs[1, 0])
    ax_histx = fig.add_subplot(gs[0, 0], sharex=ax)
    ax_histy = fig.add_subplot(gs[1, 1], sharey=ax)
    ax_histx.tick_params(axis="x", labelbottom=False)
    ax_histy.tick_params(axis="y", labelleft=False)

    ax.scatter(dat_gs1[xvar], dat_gs1[yvar], label='gs1', color=gs1_col, alpha=0.3)
    ax.scatter(dat_gs0[xvar], dat_gs0[yvar], label='gs0', color=gs0_col, alpha=0.3)

    for var, hist_ax, orientation in ((xvar, ax_histx, 'vertical'),
                                      (yvar, ax_histy, 'horizontal')):
        vals_gs0 = _finite(dat_gs0[var])
        vals_gs1 = _finite(dat_gs1[var])
        # Shared edges: separate auto-binning gave each overlaid histogram its own bins.
        edges = np.histogram_bin_edges(np.concatenate([vals_gs0, vals_gs1]), bins=bins)
        hist_ax.hist(vals_gs0, bins=edges, color=gs0_col, alpha=0.5, orientation=orientation)
        hist_ax.hist(vals_gs1, bins=edges, color=gs1_col, alpha=0.5, orientation=orientation)

    ax.legend(loc='upper right')
    ax_histx.set_title(pft_names[pft].replace('_', ' '))
    ax.set_xlabel(xlab, labelpad=5)
    ax.set_ylabel(ylab, labelpad=5, rotation=90)

In [11]:
def plot_effect(dat, parameter, variable, units):
    """Map a parameter's max-minus-min effect on a Robinson projection.

    The diverging ``RdBu`` colour scale is symmetric about zero, spanning
    +/- the largest absolute value of ``dat[variable]`` (NaNs skipped).

    Parameters
    ----------
    dat : xarray.Dataset
        Output of ``get_min_max_diff``, with ``lat``/``lon`` coordinates.
    parameter : str
        Parameter name; ``'fates_'`` is removed from it for the title.
    variable : str
        Variable in ``dat`` to plot.
    units : str
        Units for the colour-bar label. When empty, the ``()`` is omitted.
    """
    # max(|min|, |max|) is the largest absolute value; float() gives a plain scalar.
    vmax = float(abs(dat[variable]).max())

    figure, ax = plt.subplots(1, 1, figsize=(13, 6),
                              subplot_kw=dict(projection=ccrs.Robinson()),
                              layout='compressed')
    ax.set_title(f"Effect of {parameter.replace('fates_', '')} on {variable}", loc='left',
                 fontsize='large', fontweight='bold')
    ax.coastlines()
    ax.add_feature(cfeature.NaturalEarthFeature('physical', 'ocean', '110m',
                                                facecolor='white'))
    pcm = ax.pcolormesh(dat.lon, dat.lat, dat[variable],
                        transform=ccrs.PlateCarree(), shading='auto',
                        cmap='RdBu', vmin=-vmax, vmax=vmax)
    cbar = figure.colorbar(pcm, ax=ax, shrink=0.5, orientation='horizontal')
    # Avoid a dangling "()" when no units are given (the notebook's calls pass '').
    label = f'{variable} Difference ({units})' if units else f'{variable} Difference'
    cbar.set_label(label, size=10, fontweight='bold')

In [12]:
# Directory holding the dominant-PFT land-mask and mesh files passed to get_pft_grids below.
mesh_dir = '/glade/work/afoster/FATES_calibration/mesh_files/'

In [None]:
# gs1 ensemble (single file, opened lazily); masked per PFT by get_pft_ensemble below.
ensemble_gs1 = xr.open_dataset('/glade/work/afoster/FATES_calibration/history_files/fates_lh_dominant_gs1.nc')

In [None]:
# gs0 ensemble (single file, opened lazily).
# NOTE(review): the filename reads 'gso_vcmax' (letter o), as does hist_dir_gs0 further down;
# confirm this is the intended gs0 output.
ensemble_gs0 = xr.open_dataset('/glade/work/afoster/FATES_calibration/history_files/fates_lh_dominant_gso_vcmax.nc')

In [None]:
# Dominant PFTs analysed. Each name must be a key of FATES_INDEX_new (looked up by
# get_pft_ensemble). List order defines the positional 'pft' index, and
# plot_hist_and_scatter also reads this list (implicitly) for its titles.
dom_pfts = ['broadleaf_evergreen_tropical_tree', 'needleleaf_evergreen_extratrop_tree',
            'needleleaf_colddecid_extratrop_tree', 'arctic_c3_grass', 'cool_c3_grass',
            'c4_grass']

In [None]:
# Mask the gs1 ensemble to each dominant PFT's gridcells and stack along 'pft' (order = dom_pfts).
# Uses the get_pft_ensemble redefined in this notebook, which shadows the imported one.
pft_ensemble_gs1 = get_pft_ensemble(os.path.join(mesh_dir, 'dominant_pft_grid_update.nc'),
                                    os.path.join(mesh_dir, 'dominant_pft_grid_update_mesh.nc'),
                                    dom_pfts, ensemble_gs1, FATES_INDEX_new)

In [None]:
# Same per-PFT masking for the gs0 ensemble, so slices line up with pft_ensemble_gs1.
pft_ensemble_gs0 = get_pft_ensemble(os.path.join(mesh_dir, 'dominant_pft_grid_update.nc'),
                                    os.path.join(mesh_dir, 'dominant_pft_grid_update_mesh.nc'),
                                    dom_pfts, ensemble_gs0, FATES_INDEX_new)

In [None]:
# Gridcells whose dominant PFT is needleleaf_colddecid_extratrop_tree.
# NOTE(review): pft_grids is not used by any later cell visible here.
pft_grids = get_pft_grids(os.path.join(mesh_dir, 'dominant_pft_grid_update.nc'),
                          os.path.join(mesh_dir, 'dominant_pft_grid_update_mesh.nc'), FATES_INDEX_new['needleleaf_colddecid_extratrop_tree'])

In [None]:
# pft=0 selects dom_pfts[0] (broadleaf_evergreen_tropical_tree); 'BETT' in the filename
# presumably abbreviates it. savefig writes the figure just created.
plot_hist_and_scatter(0, pft_ensemble_gs0, pft_ensemble_gs1,
                      'BTRANMN', 'GPP', 'Mean BTRAN', 'Annual GPP (kgC m$^{-2}$ yr$^{-1}$)')
plt.savefig('/glade/u/home/afoster/FATES_Calibration/AGU_figures/BTRAN_GPP_BETT.png')

In [13]:
# Setup PBSCluster
cluster = PBSCluster(
    cores=1,                                                   # The number of cores you want
    memory='25GB',                                             # Amount of memory
    processes=1,                                               # How many processes
    queue='casper',                                            # The type of queue to utilize
    local_directory='/glade/work/afoster',                     # Use your local directory
    resource_spec='select=1:ncpus=1:mem=25GB',                 # Specify resources
    log_directory='/glade/derecho/scratch/afoster/dask_logs',  # log directory
    account='P93300041',                                       # Input your project ID here
    walltime='02:00:00',                                       # Amount of wall time
    interface='ext')                                           # Interface to use

cluster.scale(30)
dask.config.set({
    'distributed.dashboard.link': 'https://jupyterhub.hpc.ucar.edu/stable/user/{USER}/proxy/{port}/status'
})
client = Client(cluster)
client

Perhaps you already have a cluster running?
Hosting the HTTP server on port 36751 instead


0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.PBSCluster
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/afoster/proxy/36751/status,

0,1
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/afoster/proxy/36751/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://128.117.208.81:35481,Workers: 0
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/afoster/proxy/36751/status,Total threads: 0
Started: Just now,Total memory: 0 B


Task exception was never retrieved
future: <Task finished name='Task-158641' coro=<Client._gather.<locals>.wait() done, defined at /glade/work/afoster/conda-envs/fates_calibration/lib/python3.11/site-packages/distributed/client.py:2208> exception=AllExit()>
Traceback (most recent call last):
  File "/glade/work/afoster/conda-envs/fates_calibration/lib/python3.11/site-packages/distributed/client.py", line 2217, in wait
    raise AllExit()
distributed.client.AllExit
Task exception was never retrieved
future: <Task finished name='Task-158346' coro=<Client._gather.<locals>.wait() done, defined at /glade/work/afoster/conda-envs/fates_calibration/lib/python3.11/site-packages/distributed/client.py:2208> exception=AllExit()>
Traceback (most recent call last):
  File "/glade/work/afoster/conda-envs/fates_calibration/lib/python3.11/site-packages/distributed/client.py", line 2217, in wait
    raise AllExit()
distributed.client.AllExit
Task exception was never retrieved
future: <Task finished name

In [14]:
def _sorted_history_paths(directory):
    # Full paths of every entry in `directory`, in lexicographic order.
    return sorted(os.path.join(directory, name) for name in os.listdir(directory))


# Per-member history files for each configuration; sorted order becomes the 'ensemble' order.
hist_dir_gs1 = '/glade/work/afoster/FATES_calibration/history_files/fates_lh_dominant_gs1'
files_gs1 = _sorted_history_paths(hist_dir_gs1)

hist_dir_gs0 = '/glade/work/afoster/FATES_calibration/history_files/fates_lh_dominant_gso_vcmax'
files_gs0 = _sorted_history_paths(hist_dir_gs0)

In [None]:
# Stack the gs1 member files along a new 'ensemble' dimension (dask-backed, lazy).
ds_gs1 = xr.open_mfdataset(files_gs1, combine='nested', concat_dim='ensemble', parallel=True)
# Single-site subset. NOTE(review): for gs0 this index prints lat~59.68, lon=135.0 further down;
# assumes gs1 uses the same sparse grid -- confirm.
sub_gs1 = ds_gs1.isel(gridcell=364)

In [None]:
# Earlier day-weighted monthly-total variant, kept for reference only:
#sub_gs1['GPP_month'] = 24*60*60*(ds_gs1['time.daysinmonth']*sub_gs1.GPP)
# Scale GPP by 1e6 * 1000 / 12.011: if GPP is kgC m-2 s-1 (TODO confirm) this gives umol C m-2 s-1.
# No time integration is applied, so despite its name 'GPP_month' keeps GPP's per-time units.
sub_gs1['GPP_month'] = sub_gs1.GPP/1E-6/12.011*1000.0
# Water-use efficiency proxy: GPP_month per unit latent heat; non-positive LH becomes NaN.
sub_gs1['WUE'] = sub_gs1['GPP_month']/sub_gs1.EFLX_LH_TOT.where(sub_gs1.EFLX_LH_TOT > 0)

In [None]:
# Wrap the gs1 site WUE as a single-variable Dataset for writing to netCDF.
wue=sub_gs1.WUE.to_dataset(name='WUE')

In [None]:
# Write the gs1 site WUE; this triggers computation of the lazy open_mfdataset graph.
wue.to_netcdf('/glade/work/afoster/FATES_calibration/history_files/gs1_sub.nc')

In [16]:
# Stack the gs0 member files along 'ensemble' and take the same site (gridcell 364) as gs1.
ds_gs0 = xr.open_mfdataset(files_gs0, combine='nested', concat_dim='ensemble', parallel=True)
sub_gs0 = ds_gs0.isel(gridcell=364)

In [17]:
# Same GPP scaling (1e6 * 1000 / 12.011; umol C m-2 s-1 if GPP is kgC m-2 s-1 -- TODO confirm)
# and LH-normalised WUE as computed for gs1; non-positive LH becomes NaN.
sub_gs0['GPP_month'] = sub_gs0.GPP/1E-6/12.011*1000.0
sub_gs0['WUE'] = sub_gs0['GPP_month']/sub_gs0.EFLX_LH_TOT.where(sub_gs0.EFLX_LH_TOT > 0)

In [18]:
# Rebinds `wue` (previously the gs1 Dataset) to the gs0 site WUE.
wue=sub_gs0.WUE.to_dataset(name='WUE')

In [23]:
# Latitude of the selected site, read from ensemble member 1.
sub_gs0.grid1d_lat.isel(ensemble=1).values

array(59.68421053)

In [24]:
# Longitude of the selected site, read from ensemble member 1.
sub_gs0.grid1d_lon.isel(ensemble=1).values

array(135.)

In [19]:
# Write the gs0 site WUE (`wue` was rebound to gs0 at In [18]).
wue.to_netcdf('/glade/work/afoster/FATES_calibration/history_files/gs0_sub.nc')

This may cause some slowdown.
Consider scattering data ahead of time and using futures.


In [None]:
# Ensemble-mean WUE time series. slice(1, 501) drops member 0 and keeps members 1-500
# (NOTE(review): member 0 is presumably the default-parameter run -- confirm).
gs1_mean = sub_gs1.isel(ensemble=slice(1, 501)).WUE.mean(dim='ensemble')
gs0_mean = sub_gs0.isel(ensemble=slice(1, 501)).WUE.mean(dim='ensemble')

In [None]:
# Shared time axis for the WUE plots, taken from gs0 (assumes gs1 has identical stamps -- TODO confirm).
time = sub_gs0.time

In [None]:
# Colours duplicated from plot_hist_and_scatter, where they are function locals and
# therefore not visible here (this cell previously failed with NameError on a fresh kernel).
gs0_col = '#298c8c'
gs1_col = '#800074'

fig, ax = plt.subplots()
ax.plot(time, gs1_mean, color=gs1_col, label='gs1')
ax.plot(time, gs0_mean, color=gs0_col, label='gs0')
ax.set_ylabel('Ensemble-mean WUE')
ax.legend();

In [None]:
# Nested grey bands of the gs1 WUE spread across ensemble members 1-500 (member 0 excluded,
# as for the ensemble means). Wider bands are drawn first in lighter grey.
fig, ax = plt.subplots(figsize=(12, 12))
quantiles = [1, 5, 25]
colors = [0.9, 0.7, 0.5]
labels = ['1st-99th percentile', '5th-95th percentile', '25th-75th percentile']

# Subset and rechunk once. quantile on dask-backed data needs 'ensemble' in a single chunk.
wue_gs1_members = sub_gs1.WUE.isel(ensemble=slice(1, 501)).chunk(dict(ensemble=-1, time=-1))
for quantile, grey, label in zip(quantiles, colors, labels):
    q1 = wue_gs1_members.quantile(quantile/100, dim='ensemble')
    q2 = wue_gs1_members.quantile(1 - quantile/100, dim='ensemble')
    ax.fill_between(time, q1, q2, color=grey*np.ones(3), label=label)
ax.legend();

In [None]:
# Shut down the dask scheduler and PBS workers; the visible cells below do not use the client.
client.shutdown()

In [None]:
# OAAT key tables mapping ensemble members to (parameter, min/max). Columns are relabelled so
# get_minmax_ens can read .ensemble / .parameter / .minmax from either table.
clm_oaat = (
    pd.read_csv('/glade/work/afoster/FATES_calibration/parameter_files/clm6sp_oaat.csv', header=None)
    .set_axis(['ensemble', 'parameter', 'minmax'], axis=1)
)
fates_oaat = (
    pd.read_csv('/glade/work/afoster/FATES_calibration/parameter_files/fates_param_oaat/fates_oaat_key.csv', index_col=[0])
    .set_axis(['ensemble', 'minmax', 'parameter'], axis=1)
)

In [None]:
# Root of all history output. Each pair below is (sub-directory of hist_dir, filename stem);
# get_min_max_diff reads os.path.join(hist_dir, <dir>, f'{<stem>}{ens}.nc').
hist_dir = '/glade/work/afoster/FATES_calibration/history_files'

# NOTE(review): the clm_hydro_* pair is not used by any cell visible here.
clm_hydro_ensemble_dir = 'ctsm_sp_oaat'
clm_hydro_suffix = 'ctsm60SP_bigleaf_sparsegrid_CLM6SPoaat'

clm_btran_ensemble_dir = 'ctsm_sp_oaat_btran'
clm_btran_suffix = 'ctsm60SP_bigleaf_sparsegrid_btran_CLM6SPoaat'

fates_clm_ensemble_dir = 'fates_clmpars_sp_oaat'
fates_clm_suffix = 'ctsm60SP_fates_sparsegrid_CLM6SPoaat'

fates_ensemble_dir = 'fates_sp_oaat'
fates_suffix = 'ctsm60SP_fates_sparsegrid_FATES_OAAT_'

In [None]:
# FATES OAAT: lat/lon map of BTRANMN(max run) - BTRANMN(min run) for smpsc_delta.
# cf=1/365 turns annual_mean's day-weighted yearly sum into a mean (assumes 365-day years).
fates_param = 'smpsc_delta'
fates_btran = get_min_max_diff(hist_dir, fates_ensemble_dir, fates_oaat,
                             'FATES_OAAT_', fates_suffix, fates_param, 'BTRANMN', 1/365)

In [None]:
# Map the smpsc_delta effect on BTRANMN (units passed as an empty string).
plot_effect(fates_btran, fates_param, 'BTRANMN', '')
# NOTE(review): this commented-out filename refers to a GPP / leaf-clumping figure, not this
# smpsc_delta BTRANMN map; update it before re-enabling.
#plt.savefig('/glade/u/home/afoster/FATES_Calibration/AGU_figures/gpp_rad_leaf_clumping.png', dpi=300)

In [None]:
# Effect of the CLM 'fff' parameter on BTRANMN in the fates_clmpars_sp_oaat ensemble,
# keyed by the CLM OAAT table.
param = 'fff'
fates_btran_fff = get_min_max_diff(hist_dir, fates_clm_ensemble_dir, clm_oaat,
                             'CLM6SPoaat', fates_clm_suffix, param, 'BTRANMN', 1/365)
plot_effect(fates_btran_fff, param, 'BTRANMN', '')

In [None]:
# Effect of the CLM 'smpso' parameter on BTRANMN in the ctsm_sp_oaat_btran ensemble.
# NOTE(review): the result is stored as clm_btran_fff although param is 'smpso' (name looks
# copied from the previous cell); left unchanged in case later cells reference it.
param = 'smpso'
clm_btran_fff = get_min_max_diff(hist_dir, clm_btran_ensemble_dir, clm_oaat,
                             'CLM6SPoaat', clm_btran_suffix, param, 'BTRANMN', 1/365)
plot_effect(clm_btran_fff, param, 'BTRANMN', '')