### Test notebook for logistic regression fit
##### In 12 hours the code completed days from 2017-01-01 to 2018-07-30 (575 days)
##### Resetup i_time loop from 575 to the end of da_thresh.time with ~15 hours run time

In [None]:
%who

In [1]:
%%time

## Fitting the model to P-E-Q by season for calculated thresholds of P-E-Q

import sys
new_path = '/home/566/ad9701/drought_probability/'
if new_path not in sys.path:
    sys.path.append(new_path)

import xarray as xr
import numpy as np
import pandas as pd
import my_glmfit_funcs as myglm
import os

main_dir = '/g/data/w97/ad9701/p_prob_analysis/temp_files/'

varname = 'PminusEQ' #'P'   # the name of the directory and file
vname = 'PminusEQ'   #'precip'  # the name of the variable inside the files
fname = varname + '_*_*_*.nc'

def create_filepath_oneTime(ds, prefix='filename', root_path="."):
    """
    Generate a filepath when given an xarray dataset
    """
    time_str = ds.time.dt.strftime("%Y-%m-%d").data
    filepath = f'{root_path}/{prefix}_{time_str}.nc'
    return filepath

# select thresholds
# load the threshold data file & select the drought period of interest
PmEQ_events_file = 'sm_droughts/PmEQ_events_*.nc'
ds_thresh = xr.open_mfdataset(main_dir + PmEQ_events_file)
drght_time_slice = slice('1911-01-01', '2016-12-31')
drght_name = 'full_record'
drght_dir = 'GLM_results_' + drght_name

# select the thresholds for the time periods of the drought
thresName = 'PminusEQ'
da_thresh = ds_thresh[thresName].sel(time = drght_time_slice)

############################################
# GET THE SST PREDICTORS
############################################

# get the sst data
sst_dir = '/g/data/w97/ad9701/p_prob_analysis/sst_data/'
pNames = ['soi', 'sami', 'dmi', 'nino34_anom', 'nino4_anom']
pFiles = ['soi_monthly.nc', 'newsam.1957.2021.nc', 'dmi.had.long.data.nc', 'nino34.long.anom.data.nc', 'nino4.long.anom.data.nc']
for p in np.arange(len(pNames)):
    ds_temp = xr.open_dataset(sst_dir+pFiles[p])
    if (p>0):
        ds_p[pNames[p]]=ds_temp[pNames[p]]
    else:
        ds_p = ds_temp
    del ds_temp

# select the predictors to include in the model
predSel = ['soi', 'dmi']
formula = 'response ~ soi + dmi'
parameter = ['Intercept']
parameter.extend(predSel)

# select the sst predictors corresponding to the dates of the thresholds data
thresh_time_bymon = np.array(pd.to_datetime(da_thresh.time).to_period('M').to_timestamp().floor('D'))
da_p1_current = ds_p[predSel[0]].rename({'time': 'time'}).sel(time = thresh_time_bymon)
da_p2_current = ds_p[predSel[1]].rename({'time': 'time'}).sel(time = thresh_time_bymon)

############################################
# START A LOCAL CLUSTER
############################################

from dask.distributed import Client, LocalCluster
cluster = LocalCluster()
client = Client(cluster)
client

############################################
# PERFORM CALCULATIONS FOR THE MAIN SET
############################################

iW = 2
sub_dir = ''

# get data
data_dir = main_dir + varname + '_week' + str(iW) + '/' + sub_dir + '/'
ds = xr.open_mfdataset(data_dir + fname, chunks = {'lat':400, 'lon':400})
da_var_temp = ds[vname].reindex(lat=ds.lat[::-1]).chunk(chunks = {'lat':40,'lon':40,'time':-1}).rename({'time':'hist_time'})
da_var = da_var_temp.groupby('hist_time.season')

# select predictors for the same time points as the P-E or P-E-Q data at multi-weekly timescale
da_time_bymon = np.array(pd.to_datetime(ds.time).to_period('M').to_timestamp().floor('D'))
ds_p_sel = ds_p.sel(time = da_time_bymon)
ds_p1_sel_gb = ds_p_sel[predSel[0]].rename({'time':'hist_time'}).groupby('hist_time.season')
ds_p2_sel_gb = ds_p_sel[predSel[1]].rename({'time':'hist_time'}).groupby('hist_time.season')

dask_gufunc_kwargs = {'output_sizes':{"glm_parameter": len(parameter)}} #, 'time':1}}

# looping over the current times
for i_time in range(2108, 2192):
    seas = da_thresh['time.season'].values[i_time]
    da_logistReg = xr.apply_ufunc(
        myglm.fit_logistReg_2Pred_oneThres,             # first the function, this function returns a tuple (GLM params, GLM pvalues, GLM modelled probabilities)
        da_var[seas],                                # function arg
        ds_p1_sel_gb[seas].values,
        ds_p2_sel_gb[seas].values,
        predSel,                                     #      "
        da_thresh.sel(timescale = iW).isel(time = i_time),                                  #      "
        formula,                                     #      "
        [da_p1_current.values[i_time]],                    #      "
        [da_p2_current.values[i_time]],                    #      "
        input_core_dims=[["hist_time"], ["hist_time"], ["hist_time"], ["predictors"], [], [], [], []], #["sample_time"], ["sample_time"]],   # list with one entry per arg, these are the dimensions not to be broadcast
        output_core_dims=[["glm_parameter"], ["glm_parameter"], [], []],                                # dimensions of the output
        vectorize=True,                                                                                                                    # broadcast over non-core dimensions of the input object?
        dask="parallelized",                                                                                                               # enable dask?
        dask_gufunc_kwargs=dask_gufunc_kwargs,                     
        output_dtypes=[float, float, float, float]
    )

    # assign co-ordinates add metadata
    new_coords_dict = {'glm_parameter':parameter} #, 'current_time':[da_thresh['current_time'][i_time]]}    
    ds_all = da_logistReg[2].rename('glm_probability').to_dataset()
    ds_all['glm_params'] = da_logistReg[0].rename('glm_params').assign_coords(new_coords_dict)
    ds_all['glm_pvalues'] = da_logistReg[1].rename('glm_pvalues').assign_coords(new_coords_dict)
    ds_all['glm_aic'] = da_logistReg[3].rename('glm_aic')
    ds_all[predSel[0]] = da_p1_current.isel(time = i_time)
    ds_all[predSel[1]] = da_p2_current.isel(time = i_time)
    
    full_dir_path = main_dir + '/' + drght_dir + '/' + varname + '_week' + str(iW) + '/' + sub_dir + '/by_day/'
    if not os.path.exists(full_dir_path):
        os.makedirs(full_dir_path)
        
    out_file = create_filepath_oneTime(ds_all, prefix = 'GLM_results_' + '_'.join(predSel), root_path = full_dir_path)
    ds_all.to_netcdf(out_file)


#############################################
# PERFORM CALCULATIONS FOR EACH SET OF DATA
# TO DO LATER
#############################################

# nSets = (7*iWeek)-1    # number of sets in addition to the original aggregation

# progress_file = "drought_probability/logistRegr_varThresh_progress.txt"
# for i in range(22, nSets):
#     sub_dir = '/set' + str(i+2)
#     data_dir = main_dir + varname + '_week' + str(iWeek) + '/' + sub_dir + '/'
#     out_file = data_dir + 'GLM_results_' + '_'.join(predSel) + '_bySeason.nc'
    
#     check = os.path.isfile(out_file)
#     if check is True:
#         progress_text = varname + '/week' + str(iWeek) + sub_dir + ' is already done'
#         with open(progress_file, "a") as file_object:
#             file_object.write("\n")
#             file_object.write(progress_text)
#     else:    
#         progress_text = varname + '/week' + str(iWeek) + sub_dir + ' ' + str(datetime.datetime.now())
#         with open(progress_file, "a") as file_object:
#             file_object.write("\n")
#             file_object.write(progress_text)
#         ds_all = myfuncs.fit_gridded_logistReg(main_dir = main_dir, varname = varname, iWeek = iWeek, threshold = threshold, \
#                                                sub_dir = sub_dir, ds_p = ds_p, x_new = x_new)
#         ds_all.to_netcdf(out_file)
#         with open(progress_file, "a") as file_object:
#             file_object.write(' end time:' + str(datetime.datetime.now()))

    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]


CPU times: user 4min 16s, sys: 30 s, total: 4min 46s
Wall time: 12min 5s


In [None]:
from dask.distributed import Client,Scheduler
from dask_jobqueue import SLURMCluster
cluster = SLURMCluster(cores=4,memory="31GB")
client = Client(cluster)
cluster.scale(cores=4)
client

In [1]:
import sys
new_path = '/home/566/ad9701/drought_probability/'
if new_path not in sys.path:
    sys.path.append(new_path)

import xarray as xr
import numpy as np
import pandas as pd
import my_glmfit_funcs as myglm
import os

main_dir = '/g/data/w97/ad9701/p_prob_analysis/temp_files/'

varname = 'PminusEQ' #'P'   # the name of the directory and file
vname = 'PminusEQ'   #'precip'  # the name of the variable inside the files
fname = varname + '_*_*_*.nc'

def create_filepath_oneTime(ds, prefix='filename', root_path="."):
    """
    Generate a filepath when given an xarray dataset
    """
    time_str = ds.time.dt.strftime("%Y-%m-%d").data
    filepath = f'{root_path}/{prefix}_{time_str}.nc'
    return filepath

# select thresholds
# load the threshold data file & select the drought period of interest
PmEQ_events_file = 'sm_droughts/PmEQ_events_*.nc'
ds_thresh = xr.open_mfdataset(main_dir + PmEQ_events_file)
drght_time_slice = slice('2017-01-01', '2020-03-31')
drght_name = 'recent_drght'
drght_dir = 'GLM_results_' + 'recent_drght'

# select the thresholds for the time periods of the drought
thresName = 'PminusEQ'
da_thresh = ds_thresh[thresName].sel(time = drght_time_slice)

In [5]:
ds_thresh.time[2108]

Unnamed: 0,Array,Chunk
Bytes,8 B,8.0 B
Shape,(),()
Count,34 Tasks,1 Chunks
Type,int64,numpy.ndarray
Array Chunk Bytes 8 B 8.0 B Shape () () Count 34 Tasks 1 Chunks Type int64 numpy.ndarray,,

Unnamed: 0,Array,Chunk
Bytes,8 B,8.0 B
Shape,(),()
Count,34 Tasks,1 Chunks
Type,int64,numpy.ndarray


In [7]:
ds_thresh.time[2191]

Unnamed: 0,Array,Chunk
Bytes,8 B,8.0 B
Shape,(),()
Count,34 Tasks,1 Chunks
Type,int64,numpy.ndarray
Array Chunk Bytes 8 B 8.0 B Shape () () Count 34 Tasks 1 Chunks Type int64 numpy.ndarray,,

Unnamed: 0,Array,Chunk
Bytes,8 B,8.0 B
Shape,(),()
Count,34 Tasks,1 Chunks
Type,int64,numpy.ndarray


In [None]:
i_time

In [None]:
%who