In [None]:
from dust.process_data_dust import determine_units
import xarray as xr
import dust
from dask.distributed import Client, LocalCluster
from dask_jobqueue import SLURMCluster
import dask
import numpy as np
from netCDF4 import date2num, num2date
import pandas as pd
import time
import sys
from dask.diagnostics import ProgressBar
import numpy
import zarr
from numcodecs import Blosc

In [None]:
def setup_outda(ds, x0, x1, y0, y1, kind=None):
    """
    DESCRIPTION:
    ===========
        Set up the output DataArray used for storing processed FLEXPART
        output. The array is filled with zeros (float32) and has the
        dimensions (time, btime, lat, lon).
    ARGUMENTS:
    ==========
        ds: xarray.Dataset
            dataset containing FLEXPART model output. The function reads
            RELCOM, LAGE, spec001_mr, pointspec, lon and lat from it, as
            well as the attributes loutstep, ibdate and ibtime.
        x0: float
            longitude of lower left corner of cutout.
        x1: float
            longitude of upper right corner of cutout.
        y0: float
            latitude of lower left corner of cutout.
        y1: float
            latitude of upper right corner of cutout.
        kind: str, optional
            name assigned to the returned DataArray. The array is left
            unnamed when kind is None (or otherwise falsy).

    RETURN:
    =======
        out_data: xarray.DataArray
            zero-filled array with a numeric "time" coordinate (hours since
            the reference time, one entry per RELCOM entry) and a "btime"
            coordinate holding negative hour offsets along the back
            trajectory.
    """

    # select output domain
    dset = ds.sel(lon=slice(x0, x1), lat=slice(y0, y1))

    # RELCOM may be stored as bytes; convert it to str so it can be split below.
    if isinstance(dset.RELCOM.values[0], np.bytes_):
        dset = dset.assign(RELCOM=dset.RELCOM.astype(str))

    # Determine size of the backward time dimension.
    # LAGE is handled as nanoseconds (hence the 1e-9) and loutstep is divided
    # by 3600 to obtain the output step in hours.
    lout_step = abs(dset.attrs["loutstep"])
    btime_size = int(dset["LAGE"] / lout_step * 1e-9)
    lout_step_h = int(lout_step / (60 * 60))
    # -lout_step_h, -2*lout_step_h, ..., -btime_size*lout_step_h
    btime_array = -np.arange(
        lout_step_h, btime_size * lout_step_h + lout_step_h, lout_step_h
    )

    # Reference time of the forward time dimension: ibdate+ibtime shifted by LAGE.
    t0 = pd.to_datetime(dset.ibdate + dset.ibtime) + pd.to_timedelta(
        dset["LAGE"].values, unit="ns"
    )
    t0 = t0.to_pydatetime()[0]

    # Build the units string once so that the numeric time values and the
    # "units" attribute stored with them can never disagree.
    time_calendar = "proleptic_gregorian"
    time_units = "hours since {}".format(t0.strftime("%Y-%m-%d %H:%M:%S"))

    # Assumes that the first part of the RELCOM string contains the date.
    release_dates = pd.to_datetime(
        [date.split(" ")[0] for date in dset.RELCOM.values], format="%Y%m%d%H"
    ).to_pydatetime()
    time_a = date2num(release_dates, units=time_units, calendar=time_calendar)

    time_var = xr.Variable(
        "time",
        time_a,
        attrs=dict(units=time_units, calendar=time_calendar),
    )

    # create output DataArray
    out_data = xr.DataArray(
        np.zeros(
            (len(dset["pointspec"]), btime_size, len(dset["lat"]), len(dset["lon"])),
            dtype=np.float32,
        ),
        dims=["time", "btime", "lat", "lon"],
        coords={
            "time": time_var,
            "btime": (
                "btime",
                btime_array,
                dict(long_name="time along back trajectory", units="hours"),
            ),
            "lon": ("lon", dset["spec001_mr"].lon.data, dset.lon.attrs),
            "lat": ("lat", dset["spec001_mr"].lat.data, dset.lat.attrs),
        },
        attrs=dict(
            spec_com=dset.spec001_mr.attrs["long_name"],
        ),
    )
    if kind:
        out_data.name = kind
    return out_data

def reorder_da(da,ems_sens, groups):
    """Return the sensitivity belonging to one release, laid out along ``btime``.

    ``groups`` is indexed with ``da.time.values`` and the result is used to
    index ``ems_sens``. The selected data is cut to the time window spanned by
    ``da.time + da.btime[-1]`` and ``da.time + da.btime[0]``, its ``time``
    dimension is renamed to ``btime``, it is reversed along its first axis and
    finally given the ``btime`` coordinate of ``da``.

    Presumably ``da`` is a single ``time`` group of the output array and
    ``ems_sens`` a groupby object over ``pointspec`` -- verify against caller.
    """
    release_time = da.time.values
    # Key of the release that corresponds to this time step.
    release_key = groups[release_time]

    window_end = release_time + da.btime[0].values
    window_start = release_time + da.btime[-1].values

    selected = (
        ems_sens[release_key]
        .sel(time=slice(window_start, window_end))
        .rename(time="btime")
    )
    # Flip the order and adopt the btime coordinate of the template.
    return selected[::-1].assign_coords(btime=da.btime)

def calc_source_contrib(da, ds_ems, scale_factor):
    """Combine a sensitivity array with the matching emission field.

    The emission field is selected at the times ``da.time + da.btime`` and
    multiplied by ``da`` and by ``scale_factor``.

    Parameters
    ----------
    da
        Array with ``time`` and ``btime`` coordinates
        (presumably one ``time`` group of the surface sensitivity -- verify
        against caller).
    ds_ems
        Emission data with a ``time`` coordinate that supports ``.sel``.
    scale_factor
        Factor applied to the product.

    Returns
    -------
    The product ``da * ds_ems.sel(time=da.time + da.btime) * scale_factor``.
    """
    # Pick the emission at each time along the back trajectory.
    emission_along_btime = ds_ems.sel(time=da.time+da.btime)

    return da*emission_along_btime*scale_factor

In [None]:
# Run-time configuration. `snakemake` is not defined in this notebook; it must
# already exist in the namespace when this cell runs.
params = snakemake.params
# Domain bounds: x0/x1 are used for the lon slices and y0/y1 for the lat slices below.
x0 = params.x0
x1 = params.x1
y0 = params.y0
y1 = params.y1
# Optional parameters with their fallbacks when not set.
height = params.get('height',100)
use_dask = params.get('use_dask',True)
use_Slurm = params.get('use_slurm', False)
# Name of the output variable, taken from the `kind` wildcard.
kind = snakemake.wildcards.kind
# `cluster` and `client` are only created when use_dask is true.
if use_dask:
    # NOTE(review): target (0.95) is set above spill (0.85) and pause (0.9);
    # confirm that this ordering of the worker memory thresholds is intended.
    dask.config.set({'distributed.worker.memory.spill':0.85,
                    'distributed.worker.memory.pause':0.9,
                    'distributed.worker.memory.target':0.95})
    
    if use_Slurm:
        # Single-core SLURM jobs sized from the rule's resources; the account
        # name is hard-coded here.
        res = snakemake.resources
        cluster = SLURMCluster(account="nn2806k",cores=1, memory=res.memory_per_job,
                      walltime=res.time, interface='ib0',
                      scheduler_options={'dashboard_address':None})
        # Let the cluster scale itself up to at most res.max_threads jobs.
        cluster.adapt(maximum_jobs=res.max_threads)
    else:
        # Local cluster: one single-threaded worker per Snakemake thread.
        cluster = LocalCluster(n_workers=snakemake.threads,
                       threads_per_worker=1,
                       memory_limit='16GB'
                       )
    client = Client(cluster)

# FLEXDUST output: keep the summary attributes (minus 'Corr. land/sea') and
# crop the dataset to the requested domain.
flexdust_dict = dust.read_flexdust_output(snakemake.input.flexdust_path+'/')
attrsfd = flexdust_dict['Summary']
attrsfd.pop('Corr. land/sea')
flexdust_ds = flexdust_dict['dset'].sel(lon=slice(x0,x1), lat=slice(y0,y1))

# FLEXPART output: crop, chunk by pointspec, and use lon/lat as names for the
# horizontal dimensions and their variables.
ds = (
    xr.open_dataset(snakemake.input.flexpart_path[0])
    .sel(longitude=slice(x0, x1), latitude=slice(y0, y1))
    .chunk({'pointspec': 1})
    .rename_dims({'longitude': 'lon', 'latitude': 'lat'})
    .rename_vars({'longitude': 'lon', 'latitude': 'lat'})
)

# Output dataset with two zero-filled variables: `kind` and 'surface_sensitivity'.
out_da = setup_outda(ds, x0, x1, y0, y1, kind).to_dataset().assign(
    surface_sensitivity=setup_outda(ds, x0, x1, y0, y1, kind='surface_sensitivity')
)
out_da = xr.decode_cf(out_da).chunk({'time': 1})

# Put the FLEXDUST fields on the coordinates of the output dataset.
flexdust_ds = flexdust_ds.interp_like(out_da.isel(time=0))

lout_step = abs(ds.attrs["loutstep"])
ind_receptor = ds.ind_receptor
if ind_receptor in (3, 4):
    # Deposition is accumulative
    scale_factor = (1 / (height)) * 1000
else:
    # Concentration is not accumulative; units of FLEXDUST need to be g/m^3s
    scale_factor = (1 / (height * lout_step)) * 1000

f_name, field_unit, sens_unit, field_name = determine_units(ind_receptor)

In [None]:
# Sensitivity field from the FLEXPART dataset, squeezed and sorted along time.
da = ds.spec001_mr.squeeze().sortby('time')

# Step 1: fill 'surface_sensitivity' by applying reorder_da to every time group.
grb = out_da['surface_sensitivity'].groupby('time')
out_da = out_da.assign(
    surface_sensitivity=grb.map(
        reorder_da, args=[da.groupby('pointspec'), grb.groups]
    )
)

# Step 2: combine the sensitivity with the FLEXDUST emission, again per time group.
grb = out_da['surface_sensitivity'].groupby('time')
out_da = out_da.assign(
    {
        kind: grb.map(
            calc_source_contrib, args=[flexdust_ds['Emission'], scale_factor]
        )
    }
)
out_da = out_da.persist()

In [None]:
# Attach the release variables (first entry of each) and the summed RELPART
# to the output dataset. Note that RELZZ1/RELZZ2 are stored as RELZ1/RELZ2.
out_da = out_da.assign(
    {
        "RELLAT1": ds["RELLAT1"][0],
        "RELLNG1": ds["RELLNG1"][0],
        "RELZ1": ds["RELZZ1"][0],
        "RELZ2": ds["RELZZ2"][0],
        "RELPART": ds["RELPART"].sum(keep_attrs=True),
    }
)

# Start from a copy of the global attributes of the FLEXPART dataset.
out_da.attrs=ds.attrs.copy()
# Drop the first whitespace-separated token of the first RELCOM entry and keep
# at most the next two pieces.
relcom_str= str(ds.RELCOM[0].values.astype('U35')).strip().split(' ',2)[1:]
out_da.attrs['relcom']=[s.strip() for s in relcom_str]

out_da[kind].attrs["units"] = field_unit
out_da[kind].attrs["long_name"] = field_name
out_da["surface_sensitivity"].attrs["units"] = sens_unit

out_da.attrs["title"] = "FLEXPART/FLEXDUST model output"
out_da.attrs[
    "references"
] = "https://doi.org/10.5194/gmd-12-4955-2019, https://doi.org/10.1002/2016JD025482"
# Prepend this processing step to the history. Fall back to an empty string
# so that an input file without a "history" attribute does not raise KeyError.
out_da.attrs["history"] = (
    "{} processed by {}, ".format(time.ctime(time.time()), snakemake.rule)
    + out_da.attrs.get("history", "")
)
out_da.attrs['filename'] = snakemake.output.outpath.split('/')[-1]
# FLEXDUST summary entries override attributes with the same name.
out_da.attrs = {**out_da.attrs,**attrsfd}
out_da.attrs['varName'] = kind




In [None]:
def convert_to_py(attrs):
    """Return a new dict built from the attribute mapping ``attrs``.

    Every space in an attribute name is replaced by an underscore. Values
    that are ``numpy.float32`` or ``numpy.int32`` scalars are replaced by
    their ``str()`` representation; every other value is kept as it is.
    """
    stringified_types = (numpy.float32, numpy.int32)
    return {
        name.replace(' ', '_'): (
            str(value) if isinstance(value, stringified_types) else value
        )
        for name, value in attrs.items()
    }

In [None]:
# Run the global attributes and the attributes of the btime/lon/lat
# coordinates through convert_to_py.
out_da.attrs = convert_to_py(out_da.attrs)
for coord_name in ('btime', 'lon', 'lat'):
    out_da[coord_name].attrs = convert_to_py(out_da[coord_name].attrs)

# Per-variable settings passed as `encoding` when writing netCDF output.
encoding = dict(
    zlib=True,
    complevel=7,
    fletcher32=False,
    contiguous=False,
    shuffle=True,
)

In [None]:
# Ad-hoc check: evaluates whether the output path ends with 'nc'. The result
# is not stored or used by any statement in this cell.
snakemake.output.outpath.endswith('nc')

In [None]:
# Build the (not yet executed) write task; the format follows the file suffix.
out_path = snakemake.output.outpath
if out_path.endswith('zarr'):
    # Set the module-level default compressor of zarr before writing.
    zarr.storage.default_compressor = Blosc(
        cname='lz4', clevel=7, shuffle=1, blocksize=0
    )
    outfile = out_da.to_zarr(out_path, mode='w', consolidated=True, compute=False)
else:
    # Anything else is written as netCDF with the same encoding for both variables.
    variable_encoding = {kind: encoding, 'surface_sensitivity': encoding}
    outfile = out_da.to_netcdf(
        out_path,
        compute=False,
        encoding=variable_encoding,
        unlimited_dims=['time'],
    )

In [None]:
# Run the write task held in `outfile` (it was created with compute=False).
# NOTE(review): dask.diagnostics.ProgressBar may not report progress while a
# distributed Client is active -- confirm.
with ProgressBar():
    result=outfile.compute()

In [None]:
# `client` and `cluster` are only created when use_dask is true, so closing
# them unconditionally would raise NameError for a run without dask.
if use_dask:
    client.close()
    cluster.close()