## CESM2 - LARGE ENSEMBLE (LENS2)

- The purpose of this notebook is to visualize maps of the mixed layer depth for the winter months of each hemisphere (January–March and July–September).

### Imports

In [None]:
import xarray as xr
import xgcm
from xgcm import Grid
import pop_tools
from dask.distributed import Client, wait
from ncar_jobqueue import NCARCluster
import dask
import intake
import intake_esm
import cmocean
import dask
import numpy as np 
import matplotlib.pyplot as plt
import warnings, getpass, os
import cartopy.crs as ccrs
import numpy.ma as ma

### Dask

In [None]:
# Sizing of the Dask cluster requested through NCARCluster.
mem_per_worker = 3  # memory per worker in GB
num_workers = 20    # number of workers

# Both the worker memory limit and the scheduler resource request are derived
# from the same per-worker memory setting so they cannot drift apart.
_worker_memory = f'{mem_per_worker} GB'
_resource_spec = f'select=1:ncpus=1:mem={mem_per_worker}GB'

cluster = NCARCluster(
    cores=1,
    processes=1,
    memory=_worker_memory,
    resource_spec=_resource_spec,
    walltime='1:00:00',
)
cluster.scale(num_workers)

# Attach a client to the cluster; the trailing expression renders it in the notebook.
client = Client(cluster)
print(client)
client

### Load

In [None]:
# Open the CESM2-LE intake-esm catalog and narrow it to monthly ocean XMXL entries.
_catalog_json = '/glade/collections/cmip/catalog/intake-esm-datastore/catalogs/glade-cesm2-le.json'
catalog = intake.open_esm_datastore(_catalog_json)
cat_subset = catalog.search(
    component='ocn',
    variable=['XMXL'],
    frequency='month_1',
)

# Load catalog entries for subset into a dictionary of xarray datasets
dset_dict_raw = cat_subset.to_dataset_dict(
    zarr_kwargs={'consolidated': True},
    storage_options={'anon': True},
)
print(f'\nDataset dictionary keys:\n {dset_dict_raw.keys()}')

In [None]:
ff = ('cmip6', 'smbb')  # Forcings
fb = ['XMXL']           # Variable

# Build one dataset per variable that spans both experiments and both forcings.
ds_dict = {}
for var in fb:
    # 1- combine historical and ssp370 (concatenate in time)
    ds_dict_tmp = {}
    for scenario in ff:
        _hist_key = f'ocn.historical.pop.h.{scenario}.{var}'
        _ssp_key = f'ocn.ssp370.pop.h.{scenario}.{var}'
        ds_dict_tmp[scenario] = xr.combine_nested(
            [dset_dict_raw[_hist_key], dset_dict_raw[_ssp_key]],
            concat_dim=['time'],
        )

    # 2- combine cmip6 and smbb (concatenate in member_id)
    ds_dict[var] = xr.combine_nested(
        [ds_dict_tmp['cmip6'], ds_dict_tmp['smbb']],
        concat_dim=['member_id'],
    )
    # Drop the per-forcing scratch dictionary once it has been merged.
    del ds_dict_tmp

In [None]:
# Fetch the POP_gx1v7 grid and keep its TLAT/TLONG fields next to the data.
pop_grid = pop_tools.get_grid('POP_gx1v7')
ds_dict.update(TLAT=pop_grid['TLAT'], TLONG=pop_grid['TLONG'])

In [None]:
# NOTE(review): bare expression in the middle of the cell. With IPython's default
# display settings only a cell's last expression is rendered, so this line shows
# nothing — move it to its own cell if the preview is wanted.
ds_dict['XMXL']['XMXL']

def is_jas(month):
    """Return a boolean mask that is True for July, August and September.

    Parameters
    ----------
    month : int or array-like of int
        Month number(s), 1-12. The comparison is element-wise (``&`` is used
        rather than ``and``), so arrays are supported.

    Returns
    -------
    bool or boolean array with the same shape as ``month``.
    """
    from_july = month >= 7
    through_september = month <= 9
    return from_july & through_september
# July-September mixed layer depth: keep only the JAS time steps, average them
# within each year, then average over a 31-year window and over all members.
# NOTE(review): the selection relies on the month of the raw time stamps; if the
# model output is stamped at the end of each averaging interval the selected
# months are shifted by one — confirm against the time bounds.
_xmxl_jas = ds_dict['XMXL']['XMXL']
_xmxl_jas = _xmxl_jas.sel(time=is_jas(_xmxl_jas['time.month']))
_jas_yearly = _xmxl_jas.resample(time='1Y', closed='left').mean('time')

_jas_present_years = _jas_yearly.sel(time=slice('1990-01-01', '2020-12-31'))
mxd_JAS_present = _jas_present_years.mean(dim=['time', 'member_id'])

_jas_future_years = _jas_yearly.sel(time=slice('2070-01-01', '2100-12-31'))
mxd_JAS_future = _jas_future_years.mean(dim=['time', 'member_id'])

def is_jfm(month):
    """Return a boolean mask that is True for January, February and March.

    Parameters
    ----------
    month : int or array-like of int
        Month number(s), 1-12. The comparison is element-wise (``&`` is used
        rather than ``and``), so arrays are supported.

    Returns
    -------
    bool or boolean array with the same shape as ``month``.
    """
    from_january = month >= 1
    through_march = month <= 3
    return from_january & through_march
# January-March mixed layer depth: keep only the JFM time steps, average them
# within each year, then average over a 31-year window and over all members.
# NOTE(review): the selection relies on the month of the raw time stamps; if the
# model output is stamped at the end of each averaging interval the selected
# months are shifted by one — confirm against the time bounds.
_xmxl_jfm = ds_dict['XMXL']['XMXL']
_xmxl_jfm = _xmxl_jfm.sel(time=is_jfm(_xmxl_jfm['time.month']))
_jfm_yearly = _xmxl_jfm.resample(time='1Y', closed='left').mean('time')

_jfm_present_years = _jfm_yearly.sel(time=slice('1990-01-01', '2020-12-31'))
mxd_JFM_present = _jfm_present_years.mean(dim=['time', 'member_id'])

_jfm_future_years = _jfm_yearly.sel(time=slice('2070-01-01', '2100-12-31'))
mxd_JFM_future = _jfm_future_years.mean(dim=['time', 'member_id'])

In [None]:
# Replace NaNs in the 2-D TLAT/TLONG coordinates with 0 on every climatology.
# Presumably needed because the pcolormesh calls below use these coordinates as
# x/y and cannot handle non-finite cell positions — confirm.
#
# Each array fills its OWN coordinates: the previous copy-pasted version took
# the TLAT of three of the arrays from mxd_JFM_future while TLONG came from the
# array itself, coupling otherwise independent results.
for _mxd in (mxd_JFM_future, mxd_JFM_present, mxd_JAS_future, mxd_JAS_present):
    for _coord_name in ('TLAT', 'TLONG'):
        _mxd.coords[_coord_name] = _mxd.coords[_coord_name].fillna(0)

In [None]:
# Map of the present-day (1990-2020) JAS mixed layer depth on a Robinson projection.
fig = plt.figure(figsize=(10, 6))
ax = plt.axes(projection=ccrs.Robinson())

# The data are located by their 2-D TLONG/TLAT coordinates, given in plain
# longitude/latitude, hence the PlateCarree transform.
_mesh_kwargs = dict(
    ax=ax,
    vmax=500,
    vmin=0,
    transform=ccrs.PlateCarree(),
    x='TLONG',
    y='TLAT',
    cmap='gnuplot',
    add_colorbar=True,
    cbar_kwargs={"label": "Mixed Layer Depth [m]"},
)
# The 0.01 factor matches the metres on the colorbar label
# (presumably XMXL is stored in centimetres — confirm).
_mxd_scaled = mxd_JAS_present * 0.01
pc = _mxd_scaled.plot.pcolormesh(**_mesh_kwargs)

# NOTE(review): gridlines() is called twice — once labelled, once unlabelled.
ax.gridlines(draw_labels=True)
ax.coastlines()
ax.gridlines()

fig.savefig('mxd_JAS_present.png', dpi=300, bbox_inches='tight')

In [None]:
# Map of the future (2070-2100) JAS mixed layer depth on a Robinson projection.
fig = plt.figure(figsize=(10, 6))
ax = plt.axes(projection=ccrs.Robinson())

# The data are located by their 2-D TLONG/TLAT coordinates, given in plain
# longitude/latitude, hence the PlateCarree transform.
_mesh_kwargs = dict(
    ax=ax,
    vmax=500,
    vmin=0,
    transform=ccrs.PlateCarree(),
    x='TLONG',
    y='TLAT',
    cmap='gnuplot',
    add_colorbar=True,
    cbar_kwargs={"label": "Mixed Layer Depth [m]"},
)
# The 0.01 factor matches the metres on the colorbar label
# (presumably XMXL is stored in centimetres — confirm).
_mxd_scaled = mxd_JAS_future * 0.01
pc = _mxd_scaled.plot.pcolormesh(**_mesh_kwargs)

# NOTE(review): gridlines() is called twice — once labelled, once unlabelled.
ax.gridlines(draw_labels=True)
ax.coastlines()
ax.gridlines()

fig.savefig('mxd_JAS_future.png', dpi=300, bbox_inches='tight')