## CESM2 - LARGE ENSEMBLE (LENS2)

- This notebook aims to compute the average of essential terms for incident solar radiation in the South Atlantic, such as cloud cover fraction and thickness. 

### Imports

In [None]:
import xarray as xr
import pandas as pd
import numpy as np 
import dask
import cf_xarray
import intake
import cftime
import nc_time_axis
import intake_esm
import matplotlib.pyplot as plt
import pop_tools
from dask.distributed import Client, wait
from ncar_jobqueue import NCARCluster
import warnings, getpass, os
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.axes_grid1 import make_axes_locatable
import cartopy.crs as ccrs
import cmocean
import dask
from matplotlib.offsetbox import AnchoredText
from matplotlib.pyplot import figure

### Dask

In [None]:
# Request a pool of single-core, single-process Dask workers via NCARCluster
# and attach a distributed client to it.
mem_per_worker = 20  # memory per worker in GB
num_workers = 40     # number of workers

cluster = NCARCluster(
    cores=1,
    processes=1,
    memory=f'{mem_per_worker} GB',
    resource_spec=f'select=1:ncpus=1:mem={mem_per_worker}GB',
    walltime='1:00:00',
    log_directory='./dask-logs',
)
cluster.scale(num_workers)

client = Client(cluster)
print(client)
# Bare expression: the notebook renders the client's rich summary.
client

### Read the data

In [None]:
# Open the intake-esm datastore described by the CESM2-LE catalog JSON.
catalog_json = '/glade/collections/cmip/catalog/intake-esm-datastore/catalogs/glade-cesm2-le.json'
catalog = intake.open_esm_datastore(catalog_json)

#### Atmosphere Component

In [None]:
%%time

# Variables requested from the catalog (monthly atmosphere-component output).
all_vars = ['TGCLDLWP', 'FSDS', 'FSNS', 'CLDTOT', 'SST', 'OCNFRAC']
cat_subset = catalog.search(
    component='atm',
    variable=all_vars,
    frequency='month_1',
)

# Materialise every matching catalog entry as an xarray dataset, keyed by entry name.
dset_dict_raw = cat_subset.to_dataset_dict(
    zarr_kwargs={'consolidated': True},
    storage_options={'anon': True},
)
print(f'\nDataset dictionary keys:\n {dset_dict_raw.keys()}')

In [None]:
%%time

# Assemble one dataset per variable that spans both experiments and both forcings.
ff = ('cmip6', 'smbb')  # forcing variants
ds_dict = {}
for var in all_vars:
    # Step 1: for each forcing, join the historical and ssp370 runs along time.
    ds_dict_tmp = {}
    for scenario in ff:
        historical_run = dset_dict_raw[f'atm.historical.cam.h0.{scenario}.{var}']
        ssp370_run = dset_dict_raw[f'atm.ssp370.cam.h0.{scenario}.{var}']
        ds_dict_tmp[scenario] = xr.combine_nested(
            [historical_run, ssp370_run],
            concat_dim=['time'],
        )

    # Step 2: stack the two forcings along the ensemble-member dimension.
    ds_dict[var] = xr.combine_nested(
        [ds_dict_tmp['cmip6'], ds_dict_tmp['smbb']],
        concat_dim=['member_id'],
    )
    # Drop the per-variable scratch dict before the next iteration.
    del ds_dict_tmp

### Mask 
- We need to mask the data of the atmospheric component over the continent

###### 1. Replace the SST values equal to 0 (continents) with NaN

In [None]:
# Quick look at SST for the first time step and first member, before masking.
(
    ds_dict['SST']['SST']
    .isel(member_id=0, time=0)
    .plot()
)

In [None]:
# Turn every value that equals exactly 0 into NaN, then re-plot the first
# time step / member to check the result.
sst_dataset = ds_dict['SST']
ds_dict['SST'] = sst_dataset.where(sst_dataset != 0.)
(
    ds_dict['SST']['SST']
    .isel(member_id=0, time=0)
    .plot()
)

###### 2. Building the mask. Coastal grid cells receive data from both the ocean model and the atmospheric model, so we keep only the cells that lie 100% on the ocean model grid (ocean fraction equal to 1) to ensure that we are not looking at data over the continent.

In [None]:
# Quick look at the ocean fraction for the first time step and first member.
(
    ds_dict['OCNFRAC']['OCNFRAC']
    .isel(member_id=0, time=0)
    .plot()
)

###### 2.2 Since we are not working with polar regions, we set every value different from 1 to NaN to build the mask

In [None]:
# Keep only values equal to 1; everything else becomes NaN. Then re-plot the
# first time step / member to check the result.
ocnfrac_dataset = ds_dict['OCNFRAC']
ds_dict['OCNFRAC'] = ocnfrac_dataset.where(ocnfrac_dataset == 1.)
(
    ds_dict['OCNFRAC']['OCNFRAC']
    .isel(member_id=0, time=0)
    .plot()
)

In [None]:
# Build a 2-D flag field from the first time step / member of OCNFRAC:
#   2.0 where the value is finite, 1.0 where it is NaN.
# The boolean tests are scaled directly instead of being multiplied by an
# `np.ones((lat, lon))` array: this avoids the deprecated `Dataset.dims[...]`
# length lookup and no longer assumes a (lat, lon) dimension order.
ocnfrac_first = ds_dict['OCNFRAC'].OCNFRAC.isel(time=0, member_id=0)
mask_ocean = 2.0 * np.isfinite(ocnfrac_first)
mask_land = 1.0 * np.isnan(ocnfrac_first)
mask_array = mask_ocean + mask_land
mask_array.plot()

###### 3. Applying the mask for the other variables

In [None]:
# Blank out (NaN) every cell of the cloud / radiation fields whose flag in
# mask_array is not 2.
keep_cells = mask_array == 2.
for masked_name in ('TGCLDLWP', 'FSDS', 'CLDTOT', 'FSNS'):
    ds_dict[masked_name] = ds_dict[masked_name].where(keep_cells)

In [None]:
# Check the masked FSDS field (first time step, first member).
(
    ds_dict['FSDS']['FSDS']
    .isel(time=0, member_id=0)
    .plot()
)

In [None]:
# Check the masked FSNS field (first time step, first member).
(
    ds_dict['FSNS']['FSNS']
    .isel(time=0, member_id=0)
    .plot()
)

In [None]:
# Check the masked CLDTOT field (first time step, first member).
(
    ds_dict['CLDTOT']['CLDTOT']
    .isel(time=0, member_id=0)
    .plot()
)

In [None]:
# Check the masked TGCLDLWP field (first time step, first member).
(
    ds_dict['TGCLDLWP']['TGCLDLWP']
    .isel(time=0, member_id=0)
    .plot()
)

### Cut and center the variable in the South Atlantic

In [None]:
%%time
# Cutting out and centering the variables in the South Atlantic
dask.config.set({"array.slicing.split_large_chunks": True})

# Atmosphere-component fields (the variables below come from ds_dict, which
# was built from the 'atm' catalog entries).
ilon1, flon1, ilon2, flon2 = 245, 288, 0, 17 # longitude index windows for isel (initial (i), final (f))
ilan=0    # upper latitude limit of the band (compared against the lat coordinate)
ilas=-34  # lower latitude limit of the band
fb=['TGCLDLWP','FSDS','FSNS','CLDTOT']
for var in fb:
    if var not in ds_dict:
        continue

    # Select the latitude band once and reuse it for both longitude windows
    # (previously the same where(..., drop=True) was built twice).
    full_field = ds_dict[var]
    in_lat_band = (full_field.lat >= ilas) & (full_field.lat <= ilan)
    lat_band = full_field.where(in_lat_band, drop=True)

    # Place the two longitude windows side by side along lon.
    box = xr.combine_nested(
        [[lat_band.isel(lon=slice(ilon1, flon1)),
          lat_band.isel(lon=slice(ilon2, flon2))]],
        concat_dim=['lat', 'lon'])

    # Wrap longitudes into [-180, 180) and sort so the axis increases
    # monotonically across the seam between the two windows.
    box = box.assign_coords(lon=(box.coords['lon'] + 180) % 360 - 180)
    ds_dict[var] = box.sortby('lon')

In [None]:
# FSDS after the regional cut (first time step, first member).
(
    ds_dict['FSDS']['FSDS']
    .isel(time=0, member_id=0)
    .plot()
)

In [None]:
# TGCLDLWP after the regional cut (first time step, first member).
(
    ds_dict['TGCLDLWP']['TGCLDLWP']
    .isel(time=0, member_id=0)
    .plot()
)

In [None]:
# FSNS after the regional cut (first time step, first member).
(
    ds_dict['FSNS']['FSNS']
    .isel(time=0, member_id=0)
    .plot()
)

In [None]:
# CLDTOT after the regional cut (first time step, first member).
(
    ds_dict['CLDTOT']['CLDTOT']
    .isel(time=0, member_id=0)
    .plot()
)

In [None]:
def _annual_box_mean(field):
    """Collapse `field` to one value per year.

    Averages over member_id, lon and lat, then averages the result within
    '1Y' resampling bins (closed on the left) and keeps 1960-01-01 to
    2100-12-31.
    """
    # NOTE(review): the lat/lon mean is unweighted (no area weighting) - confirm
    # this is intended for a box that spans a range of latitudes.
    box_mean = field.mean(dim=['member_id','lon','lat'])
    # NOTE(review): which months fall in each yearly bin depends on how the
    # monthly time stamps are placed - verify against the time coordinate.
    yearly = box_mean.resample(time='1Y', closed='left').mean('time')
    return yearly.sel(time=slice('1960-01-01','2100-12-31'))


ds_CLDTOT = _annual_box_mean(ds_dict['CLDTOT']['CLDTOT'])
ds_FSNS = _annual_box_mean(ds_dict['FSNS']['FSNS'])
ds_TGCLDLWP = _annual_box_mean(ds_dict['TGCLDLWP']['TGCLDLWP'])
ds_FSDS = _annual_box_mean(ds_dict['FSDS']['FSDS'])

In [None]:
def calculate_ticks(ax, ticks, round_to=0, center=True, span=0.84):
    """Return `ticks` evenly spaced y-tick positions for `ax`.

    The upper y-bound of `ax` is rounded up to a multiple of `round_to`; the
    starting lower bound is placed `span` rounding steps below it. The range is
    then widened so that the spacing between ticks is a whole number of
    rounding steps.

    Parameters
    ----------
    ax : object
        Anything exposing ``get_ybound()`` returning ``(lower, upper)``
        (e.g. a matplotlib Axes).
    ticks : int
        Number of tick positions to return; must be at least 2.
    round_to : float
        Rounding step. Must be non-zero: the default of 0 is kept only for
        signature compatibility and is rejected, because it would divide by
        zero and produce NaN/inf ticks.
    center : bool
        If True, shift the lower bound down by half of the extra range
        (floored) so the padding is split between both ends.
    span : float
        Distance, in units of `round_to`, between the rounded upper bound and
        the initial lower bound. Defaults to the previously hard-coded 0.84.

    Returns
    -------
    numpy.ndarray
        Array of `ticks` positions in data units.

    Raises
    ------
    ValueError
        If `round_to` is zero or `ticks` is smaller than 2.
    """
    if round_to == 0:
        raise ValueError("round_to must be non-zero (e.g. round_to=0.1)")
    if ticks < 2:
        raise ValueError("ticks must be at least 2")

    # Work in units of `round_to` until the final multiplication.
    upperbound = np.ceil(ax.get_ybound()[1] / round_to)
    lowerbound = upperbound - span
    dy = upperbound - lowerbound
    # Smallest whole-step spacing that still covers dy with (ticks - 1) gaps.
    fit = np.floor(dy / (ticks - 1)) + 1
    dy_new = (ticks - 1) * fit
    if center:
        offset = np.floor((dy_new - dy) / 2)
        lowerbound = lowerbound - offset
    values = np.linspace(lowerbound, lowerbound + dy_new, ticks)
    return values * round_to

In [None]:
def myfunc(x):
    """Evaluate the straight line ``slope * x + intercept`` at `x`.

    `slope` and `intercept` are not parameters: they are looked up as
    module-level names at call time, so they must be defined before calling.
    """
    return intercept + slope * x

In [None]:
from scipy import stats

In [None]:
def _fit_linear_trend(series, start='2015-01-01', end='2100-12-31'):
    """Least-squares straight-line fit of `series` between `start` and `end`.

    The regression is run against the sample index (0, 1, 2, ...), not against
    the time values themselves. The fit parameters stay local to this function,
    so the result does not depend on module-level `slope` / `intercept`.

    Parameters
    ----------
    series : object
        1-D series supporting ``.sel(time=slice(start, end))`` and ``len()``.
    start, end : str
        Bounds of the time window passed to ``slice``.

    Returns
    -------
    tuple
        ``(subset, fitted, slope_x10, p_value, r_squared)`` where `subset` is
        the selected window, `fitted` is the list of fitted line values (one
        per sample), `slope_x10` is the slope multiplied by 10 (i.e. per
        decade when the series has one sample per year), `p_value` is the
        p-value returned by ``linregress`` and `r_squared` is the squared
        correlation coefficient.
    """
    subset = series.sel(time=slice(start, end))
    steps = np.arange(len(subset))
    slope, intercept, r, p, std_err = stats.linregress(steps, subset)
    fitted = list(slope * steps + intercept)
    return subset, fitted, slope * 10, p, r * r


# One fit per variable. Note that the r_*_sts names hold r squared, not r.
(CLDTOT_sts, mymodel_CLDTOT_sts,
 m_CLDTOT_sts, p_CLDTOT_sts, r_CLDTOT_sts) = _fit_linear_trend(ds_CLDTOT)

(FSNS_sts, mymodel_FSNS_sts,
 m_FSNS_sts, p_FSNS_sts, r_FSNS_sts) = _fit_linear_trend(ds_FSNS)

(TGCLDLWP_sts, mymodel_TGCLDLWP_sts,
 m_TGCLDLWP_sts, p_TGCLDLWP_sts, r_TGCLDLWP_sts) = _fit_linear_trend(ds_TGCLDLWP)

(FSDS_sts, mymodel_FSDS_sts,
 m_FSDS_sts, p_FSDS_sts, r_FSDS_sts) = _fit_linear_trend(ds_FSDS)

In [None]:
# Display the p-value from the linear regression of the FSNS series.
p_FSNS_sts

In [None]:
# Four-panel figure: one time series per variable, with its dashed
# 2015-2100 trend line and the per-decade slope annotated.
letts=['A','B','C','D']
fig, axs = plt.subplots(1,4, figsize=(25, 7))

trend_window = slice('2015-01-01','2100-12-31')
# One row per panel:
# (series, fitted trend values, slope per decade, variable name, unit, colour, annotation anchor)
panels = [
    (ds_CLDTOT, mymodel_CLDTOT_sts, m_CLDTOT_sts, 'CLDTOT', 'Fraction', 'blue', (0.009, 0.95)),
    (ds_FSNS, mymodel_FSNS_sts, m_FSNS_sts, 'FSNS', 'W/m2', 'red', (0.1, 0.95)),
    (ds_TGCLDLWP, mymodel_TGCLDLWP_sts, m_TGCLDLWP_sts, 'TGCLDLWP', 'kg/m2', 'maroon', (0.06, 0.95)),
    (ds_FSDS, mymodel_FSDS_sts, m_FSDS_sts, 'FSDS', 'W/m2', 'green', (0.1, 0.95)),
]

for ax, letter, panel in zip(axs, letts, panels):
    series, fitted, per_decade, var_name, unit, colour, anchor = panel

    # Full series, then the fitted line over the trend window only.
    series.plot(ax=ax, label=None, linewidth=1, color=colour)
    ax.plot(series.sel(time=trend_window).coords['time'], fitted,
            color=colour, linestyle='dashed')
    ax.annotate(f'{per_decade:.4f} {unit} per decade', xy=anchor,
                color=colour, fontsize=20, xycoords=ax.transAxes)

    # Shared cosmetics.
    ax.grid(color='gray', linestyle='-', linewidth=0.7)
    ax.set_xlabel('Time [Years]', fontsize=16)
    ax.tick_params(axis='x', labelsize=16)
    ax.tick_params(axis='y', labelsize=16)

    # Panel letter in a rounded box at the lower-left corner.
    at = AnchoredText(letter, prop=dict(size=20), frameon=True, loc='lower left')
    at.patch.set_boxstyle("round,pad=0.,rounding_size=0.2")
    ax.add_artist(at)

    ax.set_ylabel(f'{var_name} [{unit}]', fontsize=16)

plt.subplots_adjust(wspace=0.3)
plt.savefig('clauds.png',dpi=300,bbox_inches='tight')
plt.show()