In [None]:
import xarray as xr


In [None]:
# Optional cdncvi input; nd stays None when the rule does not provide it so the
# name is always defined for later cells.
nd = None
if snakemake.input.get('cdncvi', None) is not None:
    nd = xr.open_dataset(snakemake.input.cdncvi)
mask = xr.open_dataset(snakemake.input.mask)
config = snakemake.config

# The IPSL-specific rule uses a fixed model id instead of the `model` wildcard.
if snakemake.rule == 'make_dust_cloud_diag_file_IPSL':
    mod_id = 'IPSL-CM6A-LR-INCA'
else:
    mod_id = snakemake.wildcards.model
exp_id = snakemake.wildcards.experiment

# NOTE(review): time_slice is built here but not used by the cells below; the
# trimming cell uses a hard-coded index slice instead — confirm intent.
time_slice = config.get('time_slice', {'start': 2, 'end': None})
time_slice = slice(time_slice['start'], time_slice['end'])

# Ensemble member: per-experiment, per-model override if configured, otherwise
# the default variant. A missing `model_specific_variant` section means "no overrides".
exp_variants = config.get('model_specific_variant', {}).get(exp_id) or {}
memb_id = exp_variants.get(mod_id, config['variant_default'])

cld_def = config['cloud_def']

In [None]:
import intake
from workflow.scripts.utils import resample_time
from multiprocessing.pool import ThreadPool
import dask
dask.config.set(pool=ThreadPool(4))

# 2-D monthly diagnostics, always taken from the Amon table.
amon_diag = ['clivi', 'clwvi', 'clt', 'pr']

# 3-D fields: cloud fraction and cloud liquid water, plus the extra variable a
# model needs to describe pressure on its native levels.
_vertical_extras = {
    'GFDL-ESM4': ['ps'],
    'UKESM1-0-LL': ['pfull'],
}
amon_4d_diag = ['cl', 'clw'] + _vertical_extras.get(mod_id, [])

# UKESM searches two tables (presumably because pfull is not in Amon); all
# other models search Amon only.
tab_id = ['Amon', 'AERmon'] if mod_id == 'UKESM1-0-LL' else 'Amon'

cat = intake.open_esm_datastore(snakemake.input.catalog)
cat_exp = cat.search(experiment_id=exp_id, source_id=mod_id, member_id=memb_id)



In [None]:
def _first_dataset(dset_dict, label):
    """Return the first dataset of a ``to_dataset_dict`` result, squeezed.

    If the catalog search matched nothing the dict is empty; fail with a
    message naming the search instead of a bare ``IndexError``.
    """
    if not dset_dict:
        raise ValueError(
            f'No {label} data found in the catalog for source_id={mod_id!r}, '
            f'experiment_id={exp_id!r}, member_id={memb_id!r}'
        )
    # Only the first aggregated dataset is used; any further keys are ignored.
    return dset_dict[next(iter(dset_dict))].squeeze()


dict_amon = cat_exp.search(variable_id=amon_diag, table_id='Amon').to_dataset_dict(
    aggregate=True, preprocess=resample_time)

dict_amon_3d = cat_exp.search(variable_id=amon_4d_diag, table_id=tab_id).to_dataset_dict(
    aggregate=True, preprocess=resample_time, skip_on_error=False)

# ta (CFmon) is not loaded for these two models; later cells check ds_ta for None.
if mod_id in ['UKESM1-0-LL', 'IPSL-CM6A-LR-INCA']:
    ds_ta = None
else:
    dict_ta = cat_exp.search(variable_id='ta', table_id='CFmon').to_dataset_dict(
        aggregate=True, preprocess=resample_time, skip_on_error=False)
    ds_ta = _first_dataset(dict_ta, 'CFmon ta')
ds_amon = _first_dataset(dict_amon, 'Amon 2-D')
ds_amon_3d = _first_dataset(dict_amon_3d, '3-D cloud')


In [None]:
# When the record is longer than 31 steps, keep indices 1..30 (30 steps,
# dropping index 0) and cut every time-dependent input identically.
# NOTE(review): the reason for skipping the first step is not visible here — confirm.
if len(ds_amon.time) > 31:
    keep = slice(1, 31)
    ds_amon = ds_amon.isel(time=keep)
    ds_amon_3d = ds_amon_3d.isel(time=keep)
    if ds_ta is not None:
        ds_ta = ds_ta.isel(time=keep)
    # nd is only opened when the optional cdncvi input is provided.
    if snakemake.input.get('cdncvi', None) is not None:
        nd = nd.isel(time=keep)

In [None]:
from workflow.scripts.utils import model_levels_to_pressure_levels
import numpy as np
from pyclim_noresm.general_util_funcs import global_avg

In [None]:
def calc_cld_lev(ds_cl,out_ds, height_def, vname='cl', ps=None, ds_pfull=None):
    """Average ``vname`` over each pressure layer in ``height_def`` and merge into ``out_ds``.

    Parameters
    ----------
    ds_cl : xr.Dataset
        Dataset holding ``vname``; its ``source_id`` attribute selects how the
        pressure layers are located (IPSL: already on pressure levels,
        UKESM1-0-LL: masked with ``ds_pfull``, others: regridded with
        ``model_levels_to_pressure_levels``).
    out_ds : xr.Dataset
        Dataset that receives one ``f'{vname}_{layer}'`` variable per layer.
    height_def : dict
        ``{layer: {'bottom': hPa, 'top': hPa}}``. A missing or falsy bound
        leaves that side of the layer open (no limit).
    vname : str, default 'cl'
        Variable to average.
    ps : xr.DataArray, optional
        Surface pressure, attached to ``ds_cl`` as a coordinate before regridding.
    ds_pfull : xr.DataArray, optional
        Full-level pressure, compared against the bounds converted to Pa
        (hPa * 100). Required for ``UKESM1-0-LL``.

    Returns
    -------
    xr.Dataset
        ``out_ds`` merged with the per-layer means.

    Raises
    ------
    ValueError
        If ``source_id`` is ``UKESM1-0-LL`` and ``ds_pfull`` is not given.
    """
    # assign_attrs works on a copy, so the caller's dataset attrs stay untouched.
    ds_cl = ds_cl.assign_attrs(variable_id=vname)
    if ps is not None:
        ds_cl = ds_cl.assign_coords(ps=ps)
    source_id = ds_cl.source_id

    if source_id == 'IPSL-CM6A-LR-INCA':
        # Already on pressure levels; only the dimension name differs.
        rg_data = ds_cl.rename(lev='plev')
    elif source_id == 'UKESM1-0-LL':
        # Layers are selected on model levels using the full-level pressure.
        if ds_pfull is None:
            raise ValueError('ds_pfull is required for UKESM1-0-LL')
        rg_data = None
    else:
        rg_data = model_levels_to_pressure_levels(ds_cl).compute()

    for layer, hdef in height_def.items():
        # Bounds are configured in hPa and compared in Pa. They are recomputed on
        # every iteration so a layer never inherits the previous layer's bound.
        bottom = hdef['bottom'] * 100 if hdef.get('bottom') else None
        top = hdef['top'] * 100 if hdef.get('top') else None

        if source_id == 'UKESM1-0-LL':
            in_layer = xr.ones_like(ds_pfull, dtype=bool)
            if bottom is not None:
                in_layer = in_layer & (ds_pfull < bottom)
            if top is not None:
                in_layer = in_layer & (ds_pfull > top)
            layer_mean = ds_cl.where(in_layer).mean(dim='lev')
        else:
            # NOTE(review): slice(bottom, top) assumes plev is ordered from high to
            # low pressure (surface first); with ascending plev the selection would
            # be empty — confirm for the regridded data.
            layer_mean = rg_data.sel(plev=slice(bottom, top)).mean(dim='plev')

        layer_mean[vname].attrs['comment'] = f'Percentage {layer} cloud cover, including both large-scale and convective cloud.'
        layer_mean = layer_mean.rename({vname: f'{vname}_{layer}'})
        out_ds = out_ds.merge(layer_mean)
    return out_ds

def derive_ice_frac(clivi, clwvi,ds_out):
    """Return ``ds_out`` with the ice cloud mass fraction ``clivi / clwvi`` added as ``clifrac``.

    Parameters
    ----------
    clivi : xr.DataArray
        Ice water path (the numerator).
    clwvi : xr.DataArray
        Total cloud water path (the denominator).
    ds_out : xr.Dataset
        Dataset the new variable is assigned to (a new dataset is returned).

    Returns
    -------
    xr.Dataset
        ``ds_out`` plus ``clifrac`` with long_name/units/comment attributes.
        No special handling is applied where ``clwvi`` is zero.
    """
    ice_frac = clivi / clwvi
    ice_frac.attrs['long_name'] = "Ice cloud mass fraction"
    ice_frac.attrs['units'] = "1"
    ice_frac.attrs['comment'] = "Ice cloud mass divided by total cloud water mass"

    return ds_out.assign(clifrac=ice_frac)

def conver_pr_rate(out_ds, days_per_year=365):
    """Return a copy of ``out_ds`` with precipitation ``pr`` expressed in mm year-1.

    A water mass of 1 kg m-2 corresponds to a 1 mm deep layer, so converting a
    flux in ``kg m-2 s-1`` only requires multiplying by the number of seconds in
    a year (86400 * ``days_per_year``).

    Parameters
    ----------
    out_ds : xr.Dataset
        Dataset containing ``pr`` with a ``units`` attribute.
    days_per_year : int or float, default 365
        Year length used for the conversion.

    Returns
    -------
    xr.Dataset
        New dataset whose ``pr`` keeps its original attributes, with ``units``
        set to ``'mm year-1'``. The input dataset is not modified.

    Raises
    ------
    ValueError
        If ``pr`` has units other than ``'kg m-2 s-1'`` or an already-annual
        unit (``'kg m-2 year'``, ``'kg m-2 year-1'``, ``'mm year-1'``).
    """
    seconds_per_year = 86400 * days_per_year
    pr = out_ds['pr']
    units = pr.attrs.get('units')
    if units == 'kg m-2 s-1':
        # Arithmetic drops attrs, so restore pr's attributes and relabel the units.
        pr = (pr * seconds_per_year).assign_attrs(pr.attrs, units='mm year-1')
    elif units in ('kg m-2 year', 'kg m-2 year-1', 'mm year-1'):
        # Already an annual amount; only the unit label is normalised.
        pr = pr.assign_attrs(units='mm year-1')
    else:
        raise ValueError(f"Cannot convert 'pr' with units {units!r} to mm year-1")
    return out_ds.assign(pr=pr)

In [None]:
def cald_dm(a,b,a_bnds,b_bnds,p0,ps, zdim='lev', g=9.81):
    """Air mass per unit area of each hybrid-sigma model layer (dm = dp / g).

    Interface pressures are ``a_bnds * p0 + b_bnds * ps``. The layer mass is the
    absolute pressure difference across a layer's two interfaces divided by
    ``g``, so the result does not depend on the order of the two bounds. The
    result is in kg m-2 when pressures are in Pa.

    Parameters
    ----------
    a, b : array-like
        Mid-level hybrid coefficients. Not used; kept for interface compatibility.
    a_bnds, b_bnds : array-like
        Interface coefficients indexed ``[level, bound]`` (``[:, 0]`` and
        ``[:, 1]`` are the two interfaces of each layer).
    p0 : scalar or array-like
        Reference pressure multiplying ``a_bnds`` (1 when ``a`` is already a pressure).
    ps : scalar or array-like
        Surface pressure, broadcast against the per-level coefficient differences.
    zdim : str
        Not used; kept for interface compatibility.
    g : float, default 9.81
        Gravitational acceleration (m s-2).

    Returns
    -------
    array-like
        Layer mass per level, broadcast with ``ps``.
    """
    # Coefficient difference across each layer (second interface minus first).
    da_layer = a_bnds[:, 1] - a_bnds[:, 0]
    db_layer = b_bnds[:, 1] - b_bnds[:, 0]
    return abs(da_layer * p0 + db_layer * ps) / g


In [None]:

def liqud_water_path(ds,model,vname, plev=None,t_mask=None, ps=None, pfull=None):
    """Vertically integrate the mixing ratio ``ds[vname]`` into a water path.

    A layer air mass per unit area ``dm`` is derived from the model's vertical
    coordinate, multiplied with ``ds[vname]`` and summed over the vertical
    dimension.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset containing ``vname`` and its vertical-coordinate variables.
    model : str
        ``source_id`` selecting the vertical treatment:
        ``'IPSL-CM6A-LR-INCA'`` - ``lev`` is renamed to ``plev`` and ``dm`` is
        taken from plev bounds; ``'UKESM1-0-LL'`` - ``dm`` from differences of
        ``pfull`` between neighbouring levels; otherwise hybrid-sigma
        coefficients according to the Z axis' ``formula`` attribute.
    vname : str
        Variable to integrate (e.g. ``'clw'``).
    plev : optional
        Not used; kept for backward compatibility.
    t_mask : xr.DataArray, optional
        Boolean mask; cells where it is False become NaN before summing
        (xarray's sum skips NaN by default).
    ps : xr.DataArray, optional
        Surface pressure, used only when ``ds`` has no ``ps`` variable.
    pfull : xr.DataArray, optional
        Full-level pressure; required for UKESM1-0-LL.

    Returns
    -------
    xr.DataArray
        Column integral named ``'lwp'`` with liquid-water-path attributes.

    Raises
    ------
    ValueError
        If ``pfull`` is missing for UKESM1-0-LL, the hybrid formula is missing
        or unsupported, or no surface pressure is available for hybrid levels.
    """
    ds = ds.copy()
    g = 9.81  # gravitational acceleration, m s-2

    if model == 'IPSL-CM6A-LR-INCA':
        # Pressure-level data: dm is derived from plev bounds further down.
        ds = ds.rename(lev='plev')
        if t_mask is not None:
            t_mask = t_mask.rename(lev='plev')
    elif model == 'UKESM1-0-LL':
        if pfull is None:
            raise ValueError('pfull is required to compute the layer mass for UKESM1-0-LL')
        # pfull is 4-D with lev on axis 1. Each layer gets the pressure difference
        # to the next level, labelled with the lower-index level.
        # NOTE(review): positive dm assumes pressure decreases with lev index, and
        # the unmatched last level presumably drops out when multiplied with
        # ds[vname] (xarray inner alignment) — confirm.
        pf_top = pfull[:, :-1, :, :]
        pf_bot = pfull[:, 1:, :, :]
        pf_bot = pf_bot.assign_coords(lev=pf_top.lev)
        dm = (pf_top - pf_bot) / g
    else:
        try:
            formula = ds.cf['Z'].attrs['formula']
        except KeyError as err:
            raise ValueError('Could not find formula') from err

        # Normalise alternative variable names to the ones used below.
        if 'ap' in ds.coords:
            ds = ds.rename_vars({'ap': 'a'})
        if 'ap_bnds' in ds.coords:
            ds = ds.rename_vars({'ap_bnds': 'a_bnds'})
        if 'lev_bounds' in ds.coords:
            ds = ds.rename_vars({'lev_bounds': 'lev_bnds'})

        if formula == 'p = a*p0 + b*ps':
            p0 = ds.get('p0', 100000.0)
        elif formula in ('p = ap + b*ps', 'p(n,k,j,i) = ap(k) + b(k)*ps(n,j,i)'):
            # ap is already a pressure, so a default p0 of 1 leaves it unscaled.
            p0 = ds.get('p0', 1)
        else:
            raise ValueError(f'Unsupported vertical coordinate formula: {formula!r}')

        ps_surf = ds.get('ps', ps)
        if ps_surf is None:
            raise ValueError('Surface pressure (ps) is required for hybrid-sigma levels')
        dm = cald_dm(a=ds['a'], b=ds['b'], a_bnds=ds['a_bnds'], b_bnds=ds['b_bnds'], p0=p0, ps=ps_surf)
        dm = dm.drop_vars(['a', 'member_id', 'ps', 'b'], errors='ignore')

    if 'plev' in ds.dims:
        # Pressure levels: layer mass from the thickness of each plev cell.
        ds = ds.cf.add_bounds('plev')
        dp = ds['plev_bounds'][:, 0] - ds['plev_bounds'][:, 1]
        dm = dp / g

    weighted = ds[vname] * dm
    if t_mask is not None:
        weighted = weighted.where(t_mask)
    zdim = 'plev' if model == 'IPSL-CM6A-LR-INCA' else 'lev'
    lwp_sum = np.sum(weighted, axis=weighted.get_axis_num(zdim), keepdims=True).squeeze()

    lwp_sum = lwp_sum.assign_attrs(ds[vname].attrs)
    lwp_sum = lwp_sum.assign_attrs({
            'long_name': 'Liquid Water Path',
            'units': 'kg m-2',
            'mipTable': '',
            'out_name': 'lwp',
            'standard_name': 'atmosphere_mass_content_of_cloud_liquid_water',
            'title': 'Liquid Water Path',
            'variable_id': 'lwp',
            'original_units': 'kg/kg',
            'history': "Column integration"
        })
    lwp = lwp_sum.rename('lwp')
    # Remove a leftover bounds dimension/coordinate that may be carried along by dm.
    if 'bounds' in lwp.dims:
        lwp = lwp.isel(bounds=0)
    if 'bnds' in lwp.dims:
        lwp = lwp.isel(bnds=0)
    lwp = lwp.drop_vars(['bounds', 'bnds'], errors='ignore')
    return lwp




In [None]:
# Optional vertical-coordinate helpers; .get yields None when a model lacks them.
surface_pressure = ds_amon_3d.get('ps', None)
full_level_pressure = ds_amon_3d.get('pfull', None)

ds_amon = calc_cld_lev(
    ds_amon_3d[['cl']],
    ds_amon,
    cld_def,
    ps=surface_pressure,
    ds_pfull=full_level_pressure,
)

ice_path, total_water_path = ds_amon['clivi'], ds_amon['clwvi']
ds_amon = derive_ice_frac(ice_path, total_water_path, ds_amon)





In [None]:
# Column-integrated cloud liquid water on each model's own vertical grid.
lwp = liqud_water_path(
    ds_amon_3d,
    mod_id,
    'clw',
    ps=ds_amon_3d.get('ps'),
    pfull=ds_amon_3d.get('pfull', None),
)
ds_amon = ds_amon.assign({'lwp': lwp})

In [None]:
# ds_ta is None for UKESM1-0-LL and IPSL-CM6A-LR-INCA (see the loading cell),
# so the supercooled diagnostics are only computed for the other models.
if mod_id not in ['UKESM1-0-LL', 'IPSL-CM6A-LR-INCA']:
    # Integrate clw only where the air temperature is below 273 K.
    # NOTE(review): assumes ta and clw share the same vertical grid — confirm.
    supercooled = ds_ta['ta'] < 273
    lwp_sl = liqud_water_path(ds_amon_3d, mod_id, 'clw',
                              ps=ds_amon_3d.get('ps'),
                              t_mask=supercooled,
                              pfull=ds_amon_3d.get('pfull', None))
    ds_amon = ds_amon.assign(lwp_sl=lwp_sl)

    sclf = ds_amon['lwp_sl'] / ds_amon['clwvi']
    sclf.attrs['units'] = '1'
    sclf.attrs['long_name'] = 'Supercooled liquid fraction'
    sclf.attrs['variable_id'] = 'sclf'
    sclf.attrs['comment'] = "Liquid water path integrated over temperatures below 273 K divided by the vertically integrated cloud water content"

    ds_amon = ds_amon.assign(sclf=sclf)




In [None]:
# Merge the cdncvi dataset (nd) into the cloud diagnostics. The IPSL rule skips
# this, and so does any run without the optional cdncvi input (nd is never
# opened in that case).
if (snakemake.rule != 'make_dust_cloud_diag_file_IPSL'
        and snakemake.input.get('cdncvi', None) is not None):
    # Adopt ds_amon's horizontal coordinates when they differ (e.g. round-off)
    # so the merge aligns on identical labels. array_equal is also safe when
    # the shapes differ.
    if not np.array_equal(nd.lon.values, ds_amon.lon.values):
        nd = nd.assign(lon=ds_amon.lon)
    if not np.array_equal(nd.lat.values, ds_amon.lat.values):
        nd = nd.assign(lat=ds_amon.lat)
    ds_amon = ds_amon.merge(nd)

In [None]:
from dask.diagnostics import ProgressBar
# Evaluate all lazy (dask-backed) variables now, with a console progress bar.
progress_bar = ProgressBar()
with progress_bar:
    ds_amon = ds_amon.compute()


In [None]:
# Strip bookkeeping coordinates before writing; the result is assigned back to
# ds_amon because that is the dataset the next cell writes to disk.
ds_amon = ds_amon.drop_vars(['member_id', 'lev', 'ps'], errors='ignore')

In [None]:

# Persist the combined diagnostics to the path declared by the Snakemake rule.
output_path = snakemake.output.dust_cloud_diag_exp
ds_amon.to_netcdf(output_path)