In [None]:
import sys
import os

### data handling
import numpy as np
import pandas as pd
import xarray as xr
import scipy as sci
import numpy_groupies as npg 

### plotting
import matplotlib.pyplot as plt

### flox for GroupBy Reductions
import flox.xarray


In [None]:
from dask.distributed import Client
import dask.array as da
import dask
# Start a dask distributed client for the lazy xarray/flox computations below.
# threads_per_worker=1: one thread per worker -- presumably to avoid contention
# in the netCDF/HDF5 readers; confirm this is the intended reason.
client = Client(threads_per_worker=1)
# Bare last expression: displays the client's rich repr (cluster summary) in the notebook.
client

In [None]:
# Preprocessor for xr.open_mfdataset that strips unneeded coordinates/variables.
def drop_stuff(ds, coords_to_drop, vars_to_drop):
    """
    Remove unwanted coordinates and data variables from one file's dataset.

    Intended as the ``preprocess`` hook of ``xr.open_mfdataset`` (wrapped in a
    lambda that binds the two name lists), so it runs on each file's dataset
    before the files are combined.

    Parameters:
        ds (xarray.Dataset): Dataset opened from a single file.
        coords_to_drop (list of str): Coordinate names to remove.
        vars_to_drop (list of str): Data-variable names to remove.

    Returns:
        xarray.Dataset: ``ds`` without the listed names. Names that are not
        present are skipped silently (``errors='ignore'``).
    """
    # Coordinates first, then data variables; each pass ignores missing names.
    for names in (coords_to_drop, vars_to_drop):
        ds = ds.drop_vars(names, errors='ignore')
    return ds

In [None]:
def print_chunks(data_array):
    '''
    Print the dask chunk sizes along each dimension of an xarray object.

    Parameters:
        data_array (xarray.DataArray or xarray.Dataset): Object whose chunks
            are printed. A DataArray exposes ``.chunks`` as a tuple of
            per-dimension chunk sizes ordered like ``.dims``; a Dataset exposes
            ``.chunks`` as a mapping of dimension name -> chunk sizes. Both
            forms are handled.

    Prints one line per dimension, ``"<dim> chunks: <sizes>"``, or a single
    notice when the object is not chunked (``.chunks`` is None or empty).
    '''
    chunks = data_array.chunks

    if hasattr(chunks, "items"):
        # Dataset: already keyed by dimension name. Indexing it by position
        # (as a tuple) would raise KeyError.
        readable_chunks = dict(chunks)
    elif chunks is not None:
        # DataArray: tuple of chunk sizes in the same order as .dims.
        readable_chunks = dict(zip(data_array.dims, chunks))
    else:
        # .chunks is None when the data is not dask-backed.
        readable_chunks = {}

    if not readable_chunks:
        print("No dask chunks: data is not chunked")
        return

    for dim, sizes in readable_chunks.items():
        print(f"{dim} chunks: {sizes}")

In [None]:
# Root directory of the OFAM model output. Input file paths are meant to be
# built relative to this prefix rather than calling os.chdir(), which would
# change the kernel's working directory for every later cell.
wrkdir = "/g/data/fp2/OFAM"

In [None]:
%%time

# Define coordinates and variables to drop
coords_to_drop =['st_edges_ocean','nv']
vars_to_drop =['Time_bounds','average_DT','average_T1','average_T2']

# Load the datasets with preprocessing
big_sst = xr.open_mfdataset(
    [wrkdir, "/jra55_historical.1/surface/ocean_temp_sfc_*.nc"], 
    parallel=True, 
    preprocess = lambda x: drop_stuff(x, 
                                         coords_to_drop, 
                                         vars_to_drop)).squeeze() #combine='by_coords' is default

# Rename Time to time
big_sst = big_sst.rename({'Time':'time'})

big_sst 

In [None]:
# Attach day-of-year as a coordinate along `time` rather than a data variable.
# That way it is carried through isel/subsetting, can be referred to by name as
# a grouping key, and is not itself averaged as if it were data.
# NOTE(review): dayofyear runs 1..366, so in leap years every date after Feb 28
# maps to a doy one larger than in non-leap years -- confirm this alignment is
# intended for the climatology.
big_sst = big_sst.assign_coords(doy=big_sst['time'].dt.dayofyear)

# Display the new coordinate. A bare `doy` here was an undefined name (NameError).
big_sst['doy']

In [None]:
# Cut an index-based box around the Australian continent out of the full
# dataset and drop the `st_ocean` coordinate.
# FIXME(review): y1, y2, x1, x2 are not defined in any visible cell; define the
# index bounds (ideally in a config cell near the top) before running this.
sst_reduced = (
    big_sst
    .isel(yt_ocean=slice(y1, y2), xt_ocean=slice(x1, x2))
    .drop_vars('st_ocean')
)
sst_reduced

In [None]:
# Lazily build the day-of-year mean of the subset with flox's "cohorts"
# strategy (no rechunking beforehand). The grouping variable is passed by
# NAME, so flox looks it up on `sst_reduced` itself; a bare Python name `doy`
# is not defined by any cell in this notebook and would raise NameError.
climatology_cohorts = flox.xarray.xarray_reduce(
    sst_reduced,
    "doy",
    func="mean",
    method="cohorts",
)

In [None]:
%%time
# Execute the lazy flox reduction and keep the result in memory.
#XXL-mem normal queue -- presumably the job resources this cell required; confirm.
climatology_cohorts = climatology_cohorts.compute()

In [None]:
# Per-tile statistics used by the dask.delayed loop below.
def calculate_stats(chunk):
    """Return the day-of-year mean and 90th percentile of ``chunk`` over time.

    Parameters:
        chunk: xarray object with a ``time`` dimension and a ``doy`` variable
            or coordinate to group by.

    Returns:
        tuple: ``(mean, thresh)`` where ``mean`` is the groupby-doy mean over
        ``time`` and ``thresh`` is the groupby-doy 0.9 quantile over ``time``
        (NaNs skipped).
    """
    by_doy = chunk.groupby('doy')
    return by_doy.mean(dim='time'), by_doy.quantile(0.9, dim='time', skipna=True)

# Map calculate_stats over spatial tiles lazily with dask.delayed, then run all
# tiles in a single dask.compute call.
# FIXME(review): xt_chunks, yt_chunks, di, dj and sst_chunked are not defined in
# any visible cell (tile start indices, tile sizes, and the chunked SST data);
# define them before running this cell.

# nout=2 declares that calculate_stats returns a 2-tuple, so each Delayed result
# can be unpacked into (mean, thresh). Without it, unpacking a Delayed raises
# TypeError ("Delayed objects of unspecified length are not iterable").
delayed_stats = dask.delayed(calculate_stats, nout=2)

seas_list = []
thresh_list = []
for ii in xt_chunks:
    for jj in yt_chunks:
        chunk = sst_chunked.isel(xt_ocean=slice(ii, ii + di), yt_ocean=slice(jj, jj + dj))
        mean, thresh = delayed_stats(chunk)
        seas_list.append(mean)
        thresh_list.append(thresh)

# Compute every tile's statistics together.
seas_results, thresh_results = dask.compute(seas_list, thresh_list)

# Merge the per-tile results into single xarray objects.
seas_new = xr.merge(seas_results)
thresh_new = xr.merge(thresh_results)

In [None]:
# Smooth the daily climatology and 90th-percentile threshold with a centred
# running mean over day-of-year. The doy axis is padded with mode='wrap'
# (values taken from the opposite end of the year), so windows near the first
# and last days are complete. The padding is trimmed off again below.
# FIXME(review): threshold_cohorts is not created in any visible cell (perhaps
# it should come from thresh_new) -- define it before running this cell.
SMOOTH_WINDOW = 31                       # window length in days; keep odd so it is centred
HALF_WINDOW = (SMOOTH_WINDOW - 1) // 2   # pad width on each side of the doy axis

climatology_smooth = (
    climatology_cohorts
    .pad(doy=HALF_WINDOW, mode='wrap')
    .rolling(doy=SMOOTH_WINDOW, center=True)
    .mean()
)
threshold90_smooth = (
    threshold_cohorts
    .pad(doy=HALF_WINDOW, mode='wrap')
    .rolling(doy=SMOOTH_WINDOW, center=True)
    .mean(skipna=True)
)

# Rechunk (whole doy axis in one chunk, 50x50 spatial tiles) and strip the wrap
# padding. Both lines must start from the *smoothed* arrays; rechunking the raw
# *_cohorts results here would silently discard the running mean.
# `-HALF_WINDOW or None` keeps the full axis when HALF_WINDOW is 0, because
# slice(0, -0) would select nothing.
trim_padding = slice(HALF_WINDOW, -HALF_WINDOW or None)
output_chunks = {'doy': -1, 'yt_ocean': 50, 'xt_ocean': 50}
climatology = climatology_smooth.chunk(output_chunks).isel(doy=trim_padding)
threshold90 = threshold90_smooth.chunk(output_chunks).isel(doy=trim_padding).drop_vars('quantile')

In [None]:
%%time
os.chdir("/g/data/xv83/users/ep5799/heatwaves")
os.getcwd()

print("Saving climatology and threshold to disk")
climatology.to_netcdf('Australian_SST_daily_climatology.nc', mode='w')
threshold90.to_netcdf('Australian_SST_daily_MHWthreshold.nc', mode='w')