# Batch code to compute FFDI and KBDI


In [1]:
import intake
import xarray as xr
from matplotlib import pyplot as plt
import glob
import warnings
from sys import argv
from datetime import datetime
from rechunker import rechunk
from xclim.indices import (
    keetch_byram_drought_index,
    griffiths_drought_factor,
    mcarthur_forest_fire_danger_index
)
from dask.distributed import Client
import dask

In [2]:
def setup_dask_client():
    """Create a Dask distributed client after lengthening the comm timeouts.

    Both the worker-connect timeout and the TCP timeout are set to 90 seconds
    before the client is constructed with its default arguments.

    Returns:
        The newly created ``Client``.
    """
    comm_timeouts = {
        'distributed.comm.timeouts.connect': '90s',  # connecting to a worker
        'distributed.comm.timeouts.tcp': '90s',  # TCP communications
    }
    dask.config.set(comm_timeouts)
    return Client()

In [3]:
def load_catalogue(catalogue_path,type):
    """Return the sorted ``.json`` catalogue paths whose file name contains *type*.

    Args:
        catalogue_path: Directory that holds the catalogue files.
        type: Substring to look for in the file name, e.g. ssp126, hist,
            ssp370, ERA ...  It is spliced into a glob pattern as-is.

    Returns:
        Sorted list of matching paths (empty when nothing matches).
    """
    pattern = f"{catalogue_path}/*{type}*.json"
    return sorted(glob.glob(pattern))

In [4]:
def extract_model_info(file):
    """Split a catalogue path into its model name and a two-field match key.

    The model name is the file name (text after the last ``/``) up to its
    first ``.``.  The match key joins the second and third ``_``-separated
    fields of that name with an underscore.

    Returns:
        ``(model_name, match)``.

    Raises:
        IndexError: If the model name has fewer than three ``_``-separated
            fields.
    """
    base_name = file.rsplit('/', 1)[-1]
    stem = base_name.partition('.')[0]
    fields = stem.split('_')
    match = "_".join((fields[1], fields[2]))
    return stem, match

In [5]:
    
def compute_rainfall(match,qtype,dir):
    """Compute the mean annual rainfall from a model's historical zarr store.

    The first store (in sorted order) matching ``{dir}/*{match}*hist*.zarr``
    is opened, its precipitation variable is summed per year, and the yearly
    totals whose labels fall in 1985-01-01..2015-01-01 are averaged over time.

    Args:
        match: Substring identifying the model in the store name.
        qtype: ``'raw'`` reads variable ``pr``; ``'adjust'`` reads
            ``prAdjust``.
        dir: Directory searched for historical zarr stores.

    Returns:
        The time-mean of the annual rainfall totals, loaded into memory, with
        ``units`` and ``long_name`` attributes set.

    Raises:
        ValueError: If *qtype* is neither ``'raw'`` nor ``'adjust'``.
        FileNotFoundError: If no historical store matches the pattern.
    """
    pr_names = {'raw': 'pr', 'adjust': 'prAdjust'}
    # Validate up front: previously an unknown qtype only failed later with an
    # UnboundLocalError after the store had already been opened.
    if qtype not in pr_names:
        raise ValueError(
            f"qtype must be one of {sorted(pr_names)}, got {qtype!r}"
        )

    pattern = dir + '/*' + match + '*hist*.zarr'
    mHis = sorted(glob.glob(pattern))
    print('historical datalen',len(mHis))
    if not mHis:
        raise FileNotFoundError(
            f"No historical zarr store matches {pattern!r}"
        )
    print('historical file', mHis[0])
    pr_his = xr.open_zarr(mHis[0])[pr_names[qtype]]

    # This is usually computed over a specific period.
    # NOTE(review): the "mm a-1" attribute assumes the input precipitation is
    # a daily total in mm -- confirm against the source data.
    pr_annual = (
        pr_his.resample(time="Y").sum()
        .sel(time=slice('1985-01-01', '2015-01-01'))
        .mean("time")
        .assign_attrs(units="mm a-1", long_name="Annual average rainfall")
    )
    pr_annual.load()
    return pr_annual

In [6]:
def open_zarr_file(dir, model_name):
    """Open the zarr store named ``<model_name>.zarr`` under *dir*.

    The path is formed by plain concatenation, so *dir* must already end with
    a path separator.
    """
    store_path = "{}{}.zarr".format(dir, model_name)
    return xr.open_zarr(store_path)
    
def open_zarr_files(dir, model_name,match):
    """Open a projection store with the tail of its historical run prepended.

    The end of the historical simulation (2014-12-12 to 2015-01-01) is
    concatenated ahead of the projection dataset so that FFDI values can be
    obtained for the start of the projection.

    Args:
        dir: Directory holding the zarr stores.  It is concatenated directly
            with the store name, so it must end with a path separator.
        model_name: Name of the projection store, without the ``.zarr``
            suffix.
        match: Substring identifying the model; the first store (in sorted
            order) matching ``{dir}/*{match}*hist*.zarr`` supplies the
            historical tail.

    Returns:
        The concatenated dataset, rechunked to a single chunk along ``time``.

    Raises:
        FileNotFoundError: If no historical store matches the pattern.
    """
    pattern = dir + '/*' + match + '*hist*.zarr'
    mHis = sorted(glob.glob(pattern))
    print('model_name = ', model_name)
    print('historical datalen = ',len(mHis))
    # Fail with the offending pattern instead of a bare IndexError on mHis[0].
    if not mHis:
        raise FileNotFoundError(
            f"No historical zarr store matches {pattern!r}"
        )
    print('historical file = ', mHis[0])
    hist_tail = xr.open_zarr(mHis[0]).sel(time=slice('2014-12-12', '2015-01-01'))

    projection = xr.open_zarr(f"{dir}{model_name}.zarr")

    combined = xr.concat([hist_tail, projection], dim='time')
    # One chunk spanning the whole time axis.
    return combined.chunk({"time": len(combined.time)})

In [7]:
def calculate_ffdi(pra, tasmaxa, pr_annual, hursmina, sfcWindmaxa):
    """Derive the FFDI and KBDI datasets from the supplied inputs.

    The calculation is a three-step chain:
      1. KBDI from ``pra``, ``tasmaxa`` and ``pr_annual``;
      2. the Griffiths drought factor from ``pra`` and the KBDI;
      3. the McArthur FFDI from the drought factor, ``tasmaxa``, ``hursmina``
         and ``sfcWindmaxa``.

    Returns:
        ``(ffdi_dataset, kbdi_dataset)`` whose data variables are named
        ``'FFDI'`` and ``'KBDI'`` respectively.
    """
    kbdi = keetch_byram_drought_index(pra, tasmaxa, pr_annual)
    drought_factor = griffiths_drought_factor(pra, kbdi)
    ffdi = mcarthur_forest_fire_danger_index(
        drought_factor, tasmaxa, hursmina, sfcWindmaxa
    )
    ffdi_ds = ffdi.to_dataset(name='FFDI')
    kbdi_ds = kbdi.to_dataset(name='KBDI')
    return ffdi_ds, kbdi_ds

In [8]:
def save_ffdi_dataset(ds1,ds2, model_name, out_dir='/scratch/xv83/rxm599/ffdi'):
    """Write the FFDI and KBDI datasets to zarr stores, overwriting any existing ones.

    Args:
        ds1: Dataset written to ``<out_dir>/<model_name>_FFDI.zarr``.
        ds2: Dataset written to ``<out_dir>/<model_name>_KBDI.zarr``.
        model_name: Prefix used for both store names.
        out_dir: Output directory, without a trailing slash.  Defaults to the
            location that was previously hard-coded.

    Returns:
        The path of the FFDI store.
    """
    out1 = f'{out_dir}/{model_name}_FFDI.zarr'
    out2 = f'{out_dir}/{model_name}_KBDI.zarr'
    # KBDI is written first, then FFDI (same order as before).
    ds2.to_zarr(out2, mode='w')
    ds1.to_zarr(out1, mode='w')
    return out1

In [9]:
def plot_ffdi(output_path):
    """Plot the maximum FFDI over time from the zarr store at *output_path*.

    The first 19 time steps are skipped, and the maximum is taken with
    ``skipna=False``.
    """
    ffdi = xr.open_zarr(output_path).FFDI
    after_skip = ffdi.isel(time=slice(19, None))
    after_skip.max('time', skipna=False).plot()

In [10]:
# Example code to read from command line if not interactive (just make the cell code)
def is_interactive():
    """Return True when the ``__main__`` module has no ``__file__`` attribute."""
    import __main__
    has_script_file = hasattr(__main__, '__file__')
    return not has_script_file


In [None]:
def main():
    """Compute and save FFDI/KBDI for a range of catalogue runs.

    When run interactively the settings are hard-coded (bias-corrected data,
    kind ``'NSW-G'``, run indices 24..49).  Otherwise they come from the
    command line::

        <script> QTYPE KIND START STOP STEP

    where QTYPE is ``raw`` (uncorrected data) or ``adjust`` (bias-corrected
    data), KIND is the substring used to select catalogue files, and
    START/STOP/STEP define the ``range`` of catalogue indices to process.

    For each selected run the mean annual rainfall is (re)computed whenever
    the model key changes, the input variables are cut to 1950-01-01 ..
    2100-01-01, and the resulting FFDI and KBDI datasets are saved.

    Raises:
        SystemExit: If fewer than five command-line arguments are given.
        ValueError: If QTYPE is neither ``raw`` nor ``adjust``.
    """
    warnings.filterwarnings('ignore')

    # Per-qtype inputs: catalogue directory, zarr data directory, and the
    # variable names for (pr, tasmax, hursmin, sfcWindmax).
    qtype_settings = {
        'raw': (
            '/g/data/ia39/catalogues/bias-input',
            '/scratch/xv83/rxm599/nobiascor/',
            ('pr', 'tasmax', 'hursmin', 'sfcWindmax'),
        ),
        'adjust': (
            '/g/data/ia39/catalogues/bias-output',
            '/scratch/xv83/rxm599/biascor/',
            ('prAdjust', 'tasmaxAdjust', 'hursminAdjust', 'sfcWindmaxAdjust'),
        ),
    }
    t1, t2 = '1950-01-01', '2100-01-01'

    client = setup_dask_client()
    try:
        if is_interactive():
            # 'adjust' = bias corrected data (only applied to QME at present);
            # switch to 'raw' for the uncorrected data.
            qtype = 'adjust'
            kind = 'NSW-G'
            runs = range(24, 50, 1)
            print(f"Dashboard available at: {client.dashboard_link}")
        else:
            print(argv)
            if len(argv) < 6:
                raise SystemExit(
                    f"usage: {argv[0]} QTYPE KIND START STOP STEP"
                )
            # raw or adjust
            qtype = argv[1]
            kind = argv[2]
            ss = argv[3:]
            print(ss)
            params = [int(i) for i in ss]
            runs = range(params[0], params[1], params[2])
        print(runs)

        if qtype not in qtype_settings:
            raise ValueError(
                f"qtype must be one of {sorted(qtype_settings)}, got {qtype!r}"
            )
        catalogue_path, data_dir, var_names = qtype_settings[qtype]

        mRuns = load_catalogue(catalogue_path, kind)
        print(mRuns)
        print('Runs to process:', len(mRuns))
        model_old = ''
        for run_index in runs:
            # Valid indices stop at len(mRuns) - 1; skip anything beyond that
            # instead of falling through to an IndexError.
            if run_index >= len(mRuns):
                print(f"Run index {run_index} is out of range. There are only {len(mRuns)} runs available.")
                continue
            print(run_index, mRuns[run_index])
            file = mRuns[run_index]

            model_name, match = extract_model_info(file)
            print('main model-name, match =', model_name, match)

            # The annual rainfall depends only on the model key, so reuse it
            # while consecutive runs belong to the same model.
            if match != model_old:
                pr_annual = compute_rainfall(match, qtype, data_dir)
            else:
                print('Model match = ', match)
            print('Finished pr_annual')
            model_old = match

            if "hist" in model_name:
                ds0 = open_zarr_file(data_dir, model_name)
                print("hist length = ", len(ds0.time))
            else:
                ds0 = open_zarr_files(data_dir, model_name, match)
                print("ssp length = ", len(ds0.time))
                print(ds0)

            pra, tasmaxa, hursmina, sfcWindmaxa = (
                ds0[name].sel(time=slice(t1, t2)) for name in var_names
            )

            ds1, ds2 = calculate_ffdi(pra, tasmaxa, pr_annual, hursmina, sfcWindmaxa)
            save_ffdi_dataset(ds1, ds2, model_name)

            print(f"complete processing for model {run_index} {model_name}")

        print("Processing of all catalogues complete.")
    finally:
        # Release the Dask client even when a run fails part-way through.
        client.close()

# Run the batch driver only when this file is executed as a script.
if __name__ == "__main__":
    main()

## Extra plotting checks: this cell must be markdown (not code) if the batch version is to work.

# Directory holding the saved FFDI/KBDI zarr stores.
out1='/scratch/xv83/rxm599/ffdi/'
# Two candidate FFDI stores; the second assignment overrides the first, so
# the ssp370 store is the one that gets plotted.
out2='AGCD-05i_BOM_ACCESS-CM2_historical_r4i1p1f1_BARPA-R_v1-r1-ACS-QME-BARRA-R2-1980-2022_day_FFDI.zarr'
out2='AGCD-05i_BOM_ACCESS-CM2_ssp370_r4i1p1f1_BARPA-R_v1-r1-ACS-QME-BARRA-R2-1980-2022_day_FFDI.zarr'
plot_ffdi(out1+out2)

# Open the matching KBDI store for the same ssp370 run.
out2='AGCD-05i_BOM_ACCESS-CM2_ssp370_r4i1p1f1_BARPA-R_v1-r1-ACS-QME-BARRA-R2-1980-2022_day_KBDI.zarr'
ds1 = xr.open_zarr(out1+out2)

#ds1.KBDI.mean('time').plot()
# Plot the time-mean KBDI, skipping the first 19 time steps (skipna=False).
ds1.KBDI.isel(time=slice(19, None)).mean('time', skipna=False).plot()