In [1]:
# functions to get daily P, E, Q

import numpy as np
import xarray as xr
import os
import datetime
import math
from tqdm.auto import tqdm

def get_awra_var(var_name, awra_dir = '/g/data/fj8/BoM/AWRA/DATA/SCHEDULED-V6/', lat_slice = None, lon_slice = None, time_slice = None):
    """Open one AWRA variable from its netCDF files and optionally subset it.

    Parameters
    ----------
    var_name : str
        AWRA variable name (e.g. ``"etot"``, ``"qtot"``). Used both as the
        file-name prefix (``<var_name>_*.nc`` inside ``awra_dir``) and as the
        variable key in the opened dataset.
    awra_dir : str
        Directory containing the AWRA files.
    lat_slice, lon_slice, time_slice : slice or None
        Label-based selections (``.sel``) on the ``latitude``, ``longitude``
        and ``time`` coordinates. ``None`` leaves that coordinate unsubset.
        Each selector is applied independently of the others.

    Returns
    -------
    xarray.DataArray
        The selected variable with ``latitude``/``longitude`` renamed to
        ``lat``/``lon``.
    """
    # The spatial dimensions in these files are named latitude/longitude
    # (they are only renamed to lat/lon below). Chunk sizes must therefore use
    # those names: xarray only applies chunk sizes to dimensions the
    # variables actually have.
    ds = xr.open_mfdataset(os.path.join(awra_dir, var_name + '_*.nc'),
                           chunks = {'latitude': 400, 'longitude': 400})

    # Pass only the selectors that were supplied. This way lon_slice works
    # without lat_slice, and lat_slice works without lon_slice.
    selectors = {
        dim: sel
        for dim, sel in (('time', time_slice), ('latitude', lat_slice), ('longitude', lon_slice))
        if sel is not None
    }
    da_var = ds[var_name]
    if selectors:
        da_var = da_var.sel(**selectors)
    return da_var.rename({'latitude': 'lat', 'longitude': 'lon'})

def get_agcd_var(agcd_dir = '/g/data/zv2/agcd/v1/precip/total/r005/01day/', agcd_files = 'agcd_v1_precip_total_r005_daily_*.nc', 
                   lat_slice = slice(-44, -10), lon_slice = slice(112, 154), time_slice = None):
    """Open AGCD daily precipitation and subset it by lat/lon (and optionally time).

    Parameters
    ----------
    agcd_dir : str
        Directory holding the AGCD files. It is concatenated directly with
        ``agcd_files``, so it should end with a path separator.
    agcd_files : str
        Glob pattern of the files passed to ``xr.open_mfdataset``.
    lat_slice, lon_slice : slice
        Label-based selections on ``lat`` and ``lon``; always applied.
        The defaults cover lat -44..-10 and lon 112..154.
    time_slice : slice or None
        Label-based selection on ``time``. ``None`` keeps the full record.

    Returns
    -------
    xarray.DataArray
        The ``precip`` variable restricted to the requested window.
    """
    precip = xr.open_mfdataset(agcd_dir + agcd_files)['precip']

    # The spatial window is always applied; time is added only when requested.
    window = {'lat': lat_slice, 'lon': lon_slice}
    if time_slice is not None:
        window['time'] = time_slice
    return precip.sel(**window)

def calc_daily_PmEQ_roll(window, out_dir = None, lat_slice_P = None, lat_slice_EQ = None, lon_slice = None, time_slice = None):
    """Compute rolling ``window``-day sums of daily P - E - Q.

    P is AGCD precipitation (``get_agcd_var``). E and Q are the AWRA ``etot``
    and ``qtot`` variables (``get_awra_var``).

    Parameters
    ----------
    window : int
        Rolling-window length in days.
    out_dir : str or None
        If None, the result is returned. Otherwise one netCDF file per year
        (grouped on the shifted time labels) is written to this directory as
        ``PminusEQ_daily_roll_<window>days_<year>.nc``, and None is returned.
    lat_slice_P : slice or None
        Latitude selection for P. If None, ``get_agcd_var``'s default spatial
        box is used and ``lon_slice`` is NOT applied to P.
    lat_slice_EQ : slice or None
        Latitude selection for E and Q. It is kept separate from
        ``lat_slice_P``, presumably because the two datasets order latitude
        differently (TODO confirm). If None, E and Q are not subset spatially
        and ``lon_slice`` is NOT applied to them.
    lon_slice : slice or None
        Longitude selection, forwarded only alongside a non-None lat slice.
    time_slice : slice or None
        Time selection applied to P, E and Q.

    Returns
    -------
    xarray.DataArray or None
        The ``PminusEQ`` rolling sums with partial windows trimmed. Each time
        label is moved back by ``floor(window / 2) + 1`` days from the label
        of the centred window. Returns None when ``out_dir`` is given.
    """
    # --- read P (AGCD) and E, Q (AWRA) ------------------------------------
    # lon_slice is forwarded only together with the matching lat slice.
    # Otherwise the readers' own defaults apply.
    p_kwargs = {}
    if lat_slice_P is not None:
        p_kwargs.update(lat_slice = lat_slice_P, lon_slice = lon_slice)
    if time_slice is not None:
        p_kwargs['time_slice'] = time_slice
    da_P = get_agcd_var(**p_kwargs)

    eq_kwargs = {}
    if lat_slice_EQ is not None:
        eq_kwargs.update(lat_slice = lat_slice_EQ, lon_slice = lon_slice)
    if time_slice is not None:
        eq_kwargs['time_slice'] = time_slice
    da_E = get_awra_var("etot", **eq_kwargs)
    da_Q = get_awra_var("qtot", **eq_kwargs)

    # Floor P's time stamps to whole days, presumably so they line up with
    # the AWRA daily labels in the subtraction below (TODO confirm).
    da_P = da_P.assign_coords(time = da_P['time'].dt.floor('D'))

    # Cast E/Q lat/lon to float32 so their coordinate dtype matches P's and
    # the arithmetic aligns on identical labels.
    da_E = _latlon_as_float32(da_E)
    da_Q = _latlon_as_float32(da_Q)

    da_PmEQ = (da_P - da_E - da_Q).rename('PminusEQ')

    # --- rolling sum and relabelling --------------------------------------
    window_centre = math.floor(window / 2)
    # Add 1 because these are accumulated variables: the value at a date is
    # the amount accumulated since the previous date.
    daydiff = np.timedelta64(window_centre + 1, 'D')
    da_time_new = da_PmEQ.time - daydiff

    da_roll = da_PmEQ.rolling(time = window, center = True).sum().assign_coords({'time': da_time_new})

    # Trim the incomplete windows at both ends. Selecting by the time
    # dimension name avoids assuming time is the first of exactly 3 dims.
    # NOTE(review): the end index includes one step past the last full window
    # for odd `window` values; confirm whether that trailing step is NaN.
    n_time = da_roll.sizes['time']
    da_PmEQ_roll = da_roll.isel(time = slice(window_centre + 1, n_time - window_centre + 1))

    if out_dir is None:
        return da_PmEQ_roll

    for year, sample in tqdm(da_PmEQ_roll.groupby('time.year')):
        out_file = os.path.join(out_dir, f'PminusEQ_daily_roll_{window}days_{year}.nc')
        sample.to_netcdf(out_file)
    # Return only after every year has been written.
    return None


def _latlon_as_float32(da):
    """Return ``da`` with its ``lat`` and ``lon`` coordinates cast to float32."""
    return da.assign_coords(lat = np.float32(da['lat']), lon = np.float32(da['lon']))

In [None]:
# Driver: compute 6- and 12-week rolling P - E - Q over the lat/lon box below
# and write one netCDF file per year for each window length.

# P (AGCD) and E/Q (AWRA) take latitude bounds in opposite order, presumably
# because the two datasets store latitude in opposite directions (TODO confirm).
lat_slice_P = slice(-39, -32)
lat_slice_EQ = slice(-32, -39)
lon_slice = slice(139, 152)
time_slice = slice('1911-01-01', '2020-05-31')

# Output location does not depend on the window, so set it up once and make
# sure it exists before any file is written.
main_dir = '/g/data/w97/ad9701/p_prob_analysis/temp_files/'
out_dir = os.path.join(main_dir, 'GLM_results_full_record/validation/PminusEQ_week2_roll_daily/')
os.makedirs(out_dir, exist_ok = True)

for ts in [6, 12]:
    window = ts*7  # window length in days
    da_PmEQ_roll = calc_daily_PmEQ_roll(window = window, lat_slice_P = lat_slice_P, lat_slice_EQ = lat_slice_EQ, lon_slice = lon_slice, time_slice = time_slice)

    for year, sample in tqdm(da_PmEQ_roll.groupby('time.year')):
        out_file = os.path.join(out_dir, f'PminusEQ_daily_roll_{ts}weeks_{year}.nc')
        sample.to_netcdf(out_file)

    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]


  0%|          | 0/110 [00:00<?, ?it/s]

    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]


  0%|          | 0/110 [00:00<?, ?it/s]