### Test: plotting PFT-level data from CLM h1 history files

In [1]:
import os, sys
import shutil
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import xarray as xr
import xesmf as xe

# Helpful for plotting only
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import uxarray as ux  #need npl 2024a or later
import geoviews.feature as gf

#sys.path.append('/glade/u/home/wwieder/python/adf/lib/plotting_functions.py')
from plotting_functions import *

In [2]:
# Load the input datasets.
# TODO: develop a helper function for this loading step too
gppfile='/glade/derecho/scratch/wwieder/ADF/b.e30_beta04.BLT1850.ne30_t232_wgx3.121/climo/b.e30_beta04.BLT1850.ne30_t232_wgx3.121_GPP_climo.nc'
laih1file='/glade/derecho/scratch/wwieder/ctsm53n04ctsm52028_ne30pg3t232_hist.clm2.h1.TLAI.1860s.nc'
case = 'ctsm53n04ctsm52028_ne30pg3t232_hist'
mesh0 = '/glade/campaign/cesm/cesmdata/inputdata/share/meshes/ne30pg3_ESMFmesh_cdf5_c20211018.nc'

# uxarray views (paired with the ne30pg3 mesh) are used for plotting;
# the GPP file is opened first, then the h1 LAI file.
ds0, ds1 = (ux.open_dataset(mesh0, path) for path in (gppfile, laih1file))

# Plain xarray views of the same two files are used for data manipulation.
ds0b, ds = (xr.open_dataset(path, decode_times=True) for path in (gppfile, laih1file))

In [3]:
# Display the raw xarray GPP climatology (its 'lndgrid' dimension is renamed in the next cell)
ds0b

In [4]:
# The GPP climatology has no coordinate values on its 'lndgrid' dimension. Rename it to
# n_face and label the faces 1..N so they can later be aligned against the gridcell
# indices stored in the h1 file's pfts1d_ixy (assumed 1-based -- TODO confirm).
ds0b = ds0b.rename({'lndgrid': 'n_face'})
n_grid = ds0b.sizes['n_face']
max_ixy = int(ds.pfts1d_ixy.max())
if max_ixy > n_grid:
    # Fail loudly instead of with an opaque size-conflict error on assignment below
    raise ValueError(
        f"pfts1d_ixy references gridcell {max_ixy}, but the GPP file has only "
        f"{n_grid} gridcells; the h1 and climo files appear to be on different grids"
    )
# Size the labels from ds0b itself rather than from the h1 file's largest index
ds0b['n_face'] = np.arange(1, n_grid + 1)
ds

In [5]:
# Select each PFT type in turn, take its max over time, and stack the results along a
# new "npft" dimension (one entry per PFT type 1..npft-1).
# TODO: this step is a memory hog -- ds.where() masks every variable in ds
npft=16
var='TLAI'
per_pft = []
for i in range(1,npft):
    print('starting pft '+str(i))
    temp = ds.where(ds.pfts1d_itype_veg==i, drop=True).max('time')
    # TODO, pft_weights should be time evolving, but not currently done
    # Rename coord, since the pft dimension index is not meaningful on its own
    temp = temp.rename({'pft': 'n_face'})

    # Label each retained pft column with its gridcell index from pfts1d_ixy
    temp['n_face'] = temp.pfts1d_ixy.values.astype(int)
    # assign_coords returns a new object, so the result must be kept; this records
    # the PFT number as the "npft" coordinate once the pieces are concatenated.
    temp = temp.assign_coords(npft=i)
    per_pft.append(temp)

# Concatenate once rather than re-copying a growing dsOut on every iteration
dsOut = xr.concat(per_pft, dim="npft")
dsOut

starting pft 1
starting pft 2
starting pft 3
starting pft 4
starting pft 5
starting pft 6
starting pft 7
starting pft 8
starting pft 9
starting pft 10
starting pft 11
starting pft 12
starting pft 13
starting pft 14
starting pft 15


In [6]:
# Reindex the stacked per-PFT output onto the GPP climatology's n_face labels
# (join="right" keeps the target's index, i.e. the plotting mesh ordering).
target = ds0b['GPP'].isel({'time': 0})
AlignOut, target = xr.align(dsOut, target, join='right')

In [7]:
# Collapse time on the ux GPP dataset, then attach the aligned per-PFT fields to it
dsplot = ds0.max('time')
for _pft_field in (var, 'pfts1d_wtgcell'):
    dsplot[_pft_field] = AlignOut[_pft_field]
dsplot

In [8]:
# Panel titles for PFT types 1..15: entry k labels PFT k+1.
# Abbreviations presumably follow CLM naming (N/B = needleleaf/broadleaf,
# E/D = evergreen/deciduous, T/S = tree/shrub) -- verify against the parameter file.
pft_names = [
    'NET Temperate',
    'NET Boreal',
    'NDT Boreal',
    'BET Tropical',
    'BET Temperate',
    'BDT Tropical',
    'BDT Temperate',
    'BDT Boreal',
    'BES Temperate',
    'BDS Temperate',
    'BDS Boreal',
    'C3 Grass Arctic',
    'C3 Grass',
    'C4 Grass',
    'UCrop UIrr',
]

In [38]:
# Map the per-PFT maximum of `var` as a 5x3 grid of panels (one per PFT), annotating
# each panel with a weighted mean and the weighted fraction of "live" cells (>= 0.1).
transform = ccrs.PlateCarree()
proj = ccrs.PlateCarree()
# Resample to a copy first: calling set_under on plt.cm.viridis_r itself would mutate
# matplotlib's shared global colormap for every later plot in the session.
cmap = plt.cm.viridis_r.resampled(7)
cmap.set_under(color='deeppink')  # values below vmin=0.1 are drawn pink ("dead")
levels = [0.1, 1, 2, 3, 4, 5, 6, 7]

# create figure object
fig, axs = plt.subplots(5, 3,
    facecolor="w",
    constrained_layout=True,
    subplot_kw=dict(projection=proj))
axs = axs.flatten()

# Loop over pfts (npft index i corresponds to PFT type i+1)
for i in range(npft - 1):
    field = dsplot[var].isel(npft=i)
    ac = field.to_polycollection(projection=proj)
    ac.set_cmap(cmap)
    ac.set_antialiased(False)
    ac.set_transform(transform)
    ac.set_clim(vmin=0.1, vmax=6.9)
    axs[i].add_collection(ac)

    # Titles, statistics.
    # Weights = area * landfrac * pfts1d_wtgcell, normalised to sum to 1.
    wgts = dsplot.area * dsplot.landfrac * dsplot.pfts1d_wtgcell.isel(npft=i)
    wgts = wgts / wgts.sum()
    mean = str(np.round((field * wgts).sum().values, 2))
    # '>=' so a value exactly at the 0.1 threshold counts as live; with a strict '>'
    # it fell into neither class and was silently dropped from the fraction.
    dead = ((field < 0.1) * wgts).sum()
    live = ((field >= 0.1) * wgts).sum()
    livefrac = str(np.round((live / (live + dead)).values, 2))
    axs[i].set_title(pft_names[i], loc='left', size=6)
    axs[i].text(-30, -45, 'mean = ' + mean, fontsize=5)
    axs[i].text(-45, -60, 'live frac = ' + livefrac, fontsize=5)

for a in axs:
    a.coastlines()
    a.spines['geo'].set_linewidth(0.1)  # cartopy's recommended method
    # set_extent fully determines the view, so a preceding set_global() is redundant
    a.set_extent([-180, 180, -65, 86])

cbar_ax = fig.add_axes([0.92, 0.05, 0.02, 0.8])
# With an explicit cax, matplotlib ignores pad/shrink/aspect (they only size an
# auto-created colorbar axes), so they are not passed.
fig.colorbar(ac, cax=cbar_ax, extend='both')
fig.suptitle("max LAI " + case, size='medium')
fig.set_layout_engine("compressed")

fig.savefig('h1_test', bbox_inches='tight', dpi=300)
print('-- wrote pft ' + var + ' figure --')
plt.show()

-- wrote pft TLAI figure --
