In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import tempfile
from dask import delayed, compute
from dask.diagnostics import ProgressBar

import geocat.comp as geocomp

In [None]:
# Optional: cap dask's thread pool (kept for reference, disabled).
# from multiprocessing.pool import ThreadPool
# import dask
# dask.config.set(pool=ThreadPool(4))


def _open_input(path):
    """Open one Snakemake input with ``xr.open_dataset``.

    The zarr engine is selected per path (``*.zarr``, trailing slash tolerated), so
    the cdnc and ta inputs may use different storage formats; otherwise xarray picks
    the engine itself (``engine=None`` is the ``open_dataset`` default).
    """
    engine = 'zarr' if path.rstrip('/').endswith('.zarr') else None
    return xr.open_dataset(path, engine=engine)


cdnc = _open_input(snakemake.input.cdnc[0])
ta = _open_input(snakemake.input.ta[0])

# Long records are trimmed to a fixed 480-step window (indices 180..659).
# NOTE(review): presumably monthly data, i.e. a 40-year subset — confirm.
if len(cdnc.time) > 600:
    cdnc = cdnc.isel(time=slice(180, 660))

dvar = 'cdncvi'  # name of the output variable
# NOTE(review): reads the Snakemake param 'p1' but stores it as `p2`, which is later
# compared against `pfull` as a pressure threshold — confirm the key name.
p2 = snakemake.params.get('p1', 10000)
attrs = cdnc.attrs.copy()

# Model name = second-to-last '_'-separated token of the cdnc input's file name.
model = snakemake.input.cdnc[0].split('/')[-1].split('_')[-2]

In [None]:
# One time step per dask chunk (matching the per-time-step integration further down);
# the temperature record is then restricted to the (possibly trimmed) cdnc time axis.
cdnc = cdnc.chunk({'time': 1})
ta = ta.chunk({'time': 1}).sel(time=cdnc.time)

In [None]:
if snakemake.rule != 'column_integrate_cdnc_UKESM':
    with xr.set_options(keep_attrs=True):
        if model == 'CNRM-ESM2-1':
            # Rebuild the hybrid-coefficient bounds as explicit (lev, bnds) arrays.
            # NOTE(review): the stored arrays are flattened and reshaped, presumably because
            # their stored layout does not match (lev, bnds) — confirm.
            # `.sizes` replaces mapping-style access to `Dataset.dims` (deprecated in xarray).
            dim = cdnc.sizes
            b_bnds = cdnc['b_bnds'].values.ravel().reshape((dim['lev'], dim['bnds']))
            a_bnds = cdnc['ap_bnds'].values.ravel().reshape((dim['lev'], dim['bnds']))
            a_bnds = xr.DataArray(a_bnds, coords={'lev': cdnc.lev}, dims=['lev', 'bnds'])
            b_bnds = xr.DataArray(b_bnds, coords={'lev': cdnc.lev}, dims=['lev', 'bnds'])
            cdnc = cdnc.assign(ap_bnds=a_bnds, b_bnds=b_bnds)
        elif model == 'IPSL-CM6A-LR-INCA' and 'lev' in cdnc.dims:
            # IPSL: build pressure-level bounds from the hybrid coefficients up front ...
            cdnc = cdnc.rename({'lev': 'plev'})
            cdnc['plev_bnds'] = cdnc['ap_bnds'] + cdnc['b_bnds'] * cdnc['ps']
            # ... then remove the hybrid-coordinate helpers, which are no longer needed.
            cdnc = cdnc.drop_vars(
                ['ap', 'b', 'ps', 'p0', 'ap_bnds', 'b_bnds', 'lev_bnds', 'orog'],
                errors='ignore',
            )

    # Harmonise variable names across models: `ap*` -> `a*` (the names read by
    # `cdnc_column_integration`) and `*_bounds` -> `*_bnds`. Each rename is applied only
    # when its source variable exists, so e.g. a missing `lat_bounds` no longer breaks
    # the `lon_bounds` rename.
    _bounds_renames = {
        'ap': 'a',
        'ap_bnds': 'a_bnds',
        'lev_bounds': 'lev_bnds',
        'lon_bounds': 'lon_bnds',
        'lat_bounds': 'lat_bnds',
    }
    _present_renames = {
        old: new for old, new in _bounds_renames.items() if old in cdnc.data_vars
    }
    if _present_renames:
        cdnc = cdnc.rename_vars(_present_renames)

if model == 'GISS-E2-1-G':
    print(model)
    # Scale GISS cdnc by 1e6 (NOTE(review): presumably cm-3 -> m-3 — confirm), keeping attrs.
    with xr.set_options(keep_attrs=True):
        cdnc = cdnc.assign(cdnc=cdnc['cdnc'] * 1e6)



In [None]:
# Identify the vertical coordinate through the `.cf` accessor and read its `formula`
# attribute; both fall back to "not found" instead of raising.
# NOTE(review): cf_xarray is not imported directly in this notebook; the accessor is
# presumably registered as a side effect of importing geocat.comp — confirm.
try:
    z_axis = cdnc.cf['Z']
except KeyError:
    z_axis = None

# `.attrs.get` (not attribute access) so a Z axis without a `formula` attribute yields ''
# instead of raising AttributeError.
formula = z_axis.attrs.get('formula', '') if z_axis is not None else ''

if formula == 'z = a + b*orog':
    # Height-based hybrid levels: use the time-mean full-level pressure to decide which
    # levels have pressure above the `p2` threshold.
    with xr.open_dataset(snakemake.input.pfull[0]) as pfull_src:
        # .load() so the values remain available after the file is closed.
        pfull = pfull_src.mean(dim='time', keep_attrs=True).load()

    valid_pressures = pfull.where(pfull['pfull'] > p2)
    # NOTE(review): `x.where(cond, x)` returns `x` whatever `cond` is, so this assignment
    # leaves the cdnc values unchanged. If masking to valid levels was intended it would be
    # `cdnc['cdnc'].where(valid_pressures['pfull'].notnull())` — confirm before changing.
    cdnc = cdnc.assign(cdnc=cdnc['cdnc'].where(np.isnan(valid_pressures['pfull']), cdnc['cdnc']))
    # Level heights z = lev + b * orog, masked to the valid-pressure levels; dz is the
    # largest valid z minus z at the first `lev` index (NaN if that level is masked).
    z = (cdnc['lev'] + cdnc['orog'] * cdnc['b']).copy()
    dz = z.where(valid_pressures['pfull'].notnull(), np.nan)
    dz = dz.max(dim='lev') - dz.isel(lev=0)

elif model == 'IPSL-CM6A-LR-INCA':
    # IPSL: integrate directly in pressure. Layer thickness dp is the difference of the
    # two columns of `plev_bnds` (built in the model-normalisation cell).
    dp = cdnc['plev_bnds'][:, 1] - cdnc['plev_bnds'][:, 0]
    # Label-based slice of klevp1 intended to leave one value per plev, then relabel.
    # NOTE(review): depends on how the klevp1 labels are numbered — confirm.
    dp = dp.sel(klevp1=slice(0, len(cdnc['klevp1']) - 1)).assign_coords({'klevp1': cdnc.plev.data})
    dp = dp.rename({'klevp1': 'lev'})
    cdnc = cdnc.rename({'plev': 'lev'})
    # Mass per area dp / g times the number concentration.
    _cdnc = dp / 9.81 * cdnc['cdnc']

    # NOTE: squeeze() also drops any other length-1 dimension (e.g. a single time step).
    cdncvi = np.sum(_cdnc, axis=_cdnc.get_axis_num('lev'), keepdims=True).squeeze()
    



In [None]:
def _chaching_calculation(ds, vname, temp_tag=''):
    """Compute ``ds[vname]`` one time chunk at a time, staging results in a temporary zarr store.

    Fallback for when computing the whole dataset at once raises ``MemoryError``: each
    time chunk is computed and written to disk before the next one is started.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset holding ``vname`` along a ``time`` dimension (normally dask-backed).
    vname : str
        Name of the variable to compute.
    temp_tag : str, optional
        Suffix appended to the temporary directory name.

    Returns
    -------
    xr.Dataset
        The temporary store re-opened with ``xr.open_dataset(..., engine='zarr')``. The
        directory is deliberately not deleted because the returned dataset reads from it.
    """
    import tempfile

    # Fall back to a single chunk when `time` is not dask-chunked.
    time_chunks = ds.chunksizes.get('time', (ds.sizes['time'],))
    tempdir = tempfile.mkdtemp(suffix=f'{temp_tag}')
    # Lay out the whole store (metadata only, compute=False); data is filled in by region.
    xr.zeros_like(ds).to_zarr(tempdir, mode='w', compute=False)

    start = 0
    for i, size in enumerate(time_chunks):
        stop = start + size
        print(f'chunk {i + 1}/{len(time_chunks)}: time[{start}:{stop}]')
        slab = ds[vname].isel(time=slice(start, stop))
        # Region writes only accept variables that share the `time` dimension, so drop every
        # coordinate without it (dimension coords such as lat/lon, scalar coords, ...).
        slab = slab.drop_vars([c for c in slab.coords if 'time' not in slab[c].dims])
        slab.to_dataset().compute().to_zarr(
            tempdir, region={'time': slice(start, stop)}, mode='r+'
        )
        start = stop
    return xr.open_dataset(tempdir, engine='zarr')

In [None]:
def cdnc_column_integration(ta: xr.Dataset, cdnc: xr.Dataset, p0=1, pressure_lim=None, model_name=None):
    """Vertically integrate cloud droplet number concentration on hybrid sigma-pressure levels.

    Each level's concentration is weighted by its thickness ``dm / rho``, where
    ``dm = |dp| / g`` comes from the hybrid-coefficient bounds and
    ``rho = p / (R_air * T)`` from the full-level pressure and air temperature.

    Parameters
    ----------
    ta : xr.Dataset
        Air temperature in variable ``ta``. Its ``lat``/``lev`` labels are replaced by
        cdnc's when the values differ.
    cdnc : xr.Dataset
        Must contain ``cdnc``, ``ps``, ``a``, ``b``, ``a_bnds`` and ``b_bnds``.
    p0 : float or xr.DataArray, optional
        Reference pressure multiplying the ``a`` coefficients.
    pressure_lim : float, optional
        If given, only levels with pressure greater than this value are integrated.
    model_name : str, optional
        Source model. Defaults to the notebook-level ``model`` global.
        For ``'GISS-E2-1-G'`` the layer mass is scaled by 100.

    Returns
    -------
    xr.DataArray
        ``cdnc * dm / rho`` summed over the vertical dimension.
    """
    gravity = 9.81    # m s-2
    r_air = 287.058   # J kg-1 K-1, dry-air gas constant

    # Use the CF 'Z' name when it resolves to an actual dimension, otherwise 'lev'.
    # (Testing a multi-element DataArray for truthiness would raise.)
    z_dim = getattr(cdnc.cf.get('Z'), 'name', None)
    if z_dim not in cdnc.dims:
        z_dim = 'lev'

    ps = cdnc['ps']
    hya_bnds = cdnc['a_bnds']
    hyb_bnds = cdnc['b_bnds']
    pressure = cdnc['a'] * p0 + cdnc['b'] * ps
    # Layer mass per unit area, dm = |dp| / g, from the bound (interface) coefficients.
    dm = abs((hya_bnds[:, 1] - hya_bnds[:, 0]) * p0 + (hyb_bnds[:, 1] - hyb_bnds[:, 0]) * ps) / gravity

    # Relabel ta when its lat/lev values differ from cdnc's so the arithmetic below aligns
    # (NOTE(review): presumably only float round-off differences — confirm).
    if not np.array_equal(pressure.lat.values, ta.lat.values):
        ta = ta.assign_coords(lat=pressure.lat)
    if not np.array_equal(pressure.lev.values, ta.lev.values):
        ta = ta.assign_coords(lev=pressure.lev)

    density = pressure / r_air / ta['ta']  # ideal gas: rho = p / (R T)

    if pressure_lim is not None:
        pressure_mask = pressure > pressure_lim
        density = density.where(pressure_mask, drop=True)
        cdnc = cdnc.where(pressure_mask, drop=True)
        dm = dm.where(pressure_mask, drop=True)

    if model_name is None:
        model_name = model  # notebook-level global, kept for existing callers
    if model_name == 'GISS-E2-1-G':
        # NOTE(review): only dm is scaled, not `pressure`/`density` — confirm the GISS
        # coefficient units make this consistent.
        dm = dm * 100

    col_sum = cdnc['cdnc'] * (dm / density)
    return col_sum.sum(dim=z_dim)


In [None]:
if snakemake.rule == 'column_integrate_cdnc_UKESM':
    # Height-level models: relies on `dz` from the `z = a + b*orog` branch of the
    # vertical-coordinate cell.
    cdnc = cdnc.sum(dim='lev') * dz
    cdncvi = cdnc['cdnc']
elif model == 'IPSL-CM6A-LR-INCA':
    # `cdncvi` was already integrated over pressure in the vertical-coordinate cell.
    pass
else:
    # Pick the reference pressure that matches the hybrid-pressure formula.
    if formula == 'p = a*p0 + b*ps':
        p0 = cdnc.get('p0', 100000.0)
    elif formula in ['p = ap + b*ps', 'p(n,k,j,i) = ap(k) + b(k)*ps(n,j,i)']:
        p0 = cdnc.get('p0', 1)
    else:
        # Fail loudly: continuing would use an undefined (or stale, from an earlier
        # kernel run) `p0`.
        raise ValueError(f'formula not recognized: {formula!r}')

    cdnc_column_integration_delayed = delayed(cdnc_column_integration)

    # One delayed task per time step, each integrating a single (ta, cdnc) slice.
    delayed_objects = [
        cdnc_column_integration_delayed(ta.isel(time=i), cdnc.isel(time=i), p0=p0,
                                        pressure_lim=None)
        for i in range(len(ta.time))
    ]

    # Compute the tasks in parallel, then stitch the per-step results back along time.
    with ProgressBar():
        results = compute(*delayed_objects)
    cdncvi = xr.concat(results, dim='time')

In [None]:
# CF metadata for the integrated field; the units are 'm-2' for every model.
cdncvi.attrs.update(
    units='m-2',
    standard_name='vertically_integrated_number_concentration_of_cloud_liquid_water_particles_in_air',
    long_name='Vertically Integrated Cloud Liquid Droplet Number Concentration',
)

In [None]:
out_ds = cdncvi.to_dataset(name=dvar)
# The vertical dimension has been integrated out; drop leftover lev/plev coordinates.
# (`drop_vars` replaces the deprecated `Dataset.drop`.)
out_ds = out_ds.drop_vars(['plev', 'lev'], errors='ignore')
# Carry over the input file's global attributes and append a provenance entry.
out_ds.attrs = attrs
_history_entry = f'{snakemake.rule} vertical intragration'
_previous_history = out_ds.attrs.get('history', '')
# Only add the ',' separator when there is earlier history (avoids a leading comma).
out_ds.attrs['history'] = (
    f'{_previous_history},{_history_entry}' if _previous_history else _history_entry
)
out_ds.attrs['variable_id'] = dvar
# Yearly means of the per-time-step column integrals.
# NOTE(review): newer pandas/xarray deprecate the 'Y' alias in favour of 'YE'.
with xr.set_options(keep_attrs=True):
    out_ds = out_ds.resample(time='Y').mean()

In [None]:

# Compute eagerly; if that exhausts memory, fall back to the chunk-by-chunk computation
# that stages results in a temporary zarr store.
try:
    with ProgressBar():
        results = out_ds.compute()
except MemoryError:
    print("OutOf memmory using Caching option")
    _cache_tag = snakemake.output.outpath.split('/')[-1]
    results = _chaching_calculation(out_ds.copy(), dvar, _cache_tag)
    

In [None]:
# Persist the final dataset to the Snakemake-declared output path.
output_path = snakemake.output.outpath
results.to_netcdf(output_path)