In [1]:
import numpy as np
import pandas as pd
import xarray as xr
import scipy as sci

### plotting
import matplotlib.pyplot as plt

### flox for GroupBy Reductions
import flox.xarray

import dask.array as da
from dask.distributed import Client


KeyboardInterrupt: 

In [None]:
# Create the dask.distributed client used by the computations below,
# configured with a single thread per worker.
# NOTE(review): no scheduler address is passed, so this presumably starts a
# local cluster on the current machine — confirm that is the intent.
client = Client(threads_per_worker=1)

In [None]:
# Preprocessor to drop unwanted coordinates and variables
def drop_stuff(ds, coords_to_drop,vars_to_drop):
    """
    Strip unwanted coordinates and variables from a dataset.

    Intended for use as the ``preprocess`` callback of ``xr.open_mfdataset``,
    so that every file is cleaned before the files are combined.

    Parameters:
        ds (xarray.Dataset): Dataset to clean.
        coords_to_drop (list of str): Names of coordinates to remove.
        vars_to_drop (list of str): Names of variables to remove.

    Returns:
        xarray.Dataset: The dataset without the listed names. Names that are
        not present are skipped silently (``errors='ignore'``).
    """
    # First pass removes the coordinates, second pass removes the variables;
    # both passes tolerate names that are missing from this dataset.
    without_coords = ds.drop_vars(coords_to_drop, errors='ignore')
    return without_coords.drop_vars(vars_to_drop, errors='ignore')

In [None]:
def print_chunks(data_array):
    '''
    Print the chunk sizes along each dimension of an xarray object.

    Parameters:
        data_array (xarray.DataArray or xarray.Dataset): Object whose
            ``.chunks`` and ``.dims`` attributes are inspected.

    Notes:
        Three shapes of ``.chunks`` are handled:

        * a sequence ordered like ``.dims`` (the DataArray form) - each entry
          is paired with the dimension name at the same position;
        * a mapping keyed by dimension name (the Dataset form) - used as is;
        * ``None`` (data that is not chunked) - a short notice is printed
          instead of raising ``TypeError``.
    '''
    from collections.abc import Mapping

    chunks = data_array.chunks

    # Nothing to report when the data is not split into chunks.
    if chunks is None:
        print("No chunks: the data is not chunked")
        return

    if isinstance(chunks, Mapping):
        # Already keyed by dimension name; integer indexing would fail here.
        readable_chunks = dict(chunks)
    else:
        # Positional form: pair each dimension name with its chunk sizes.
        readable_chunks = dict(zip(data_array.dims, chunks))

    # Print chunk sizes for each dimension
    for dim, sizes in readable_chunks.items():
        print(f"{dim} chunks: {sizes}")

In [None]:
%%time

# Names removed from every file by the preprocessor before combining
vars_to_drop = ['Time_bounds', 'average_DT', 'average_T1', 'average_T2']
coords_to_drop = ['st_edges_ocean', 'nv']

# Directory holding the surface output files.
# NOTE(review): `dir` shadows the Python builtin of the same name; it is kept
# because other cells may refer to it.
dir = '/g/data/fp2/OFAM3/jra55_historical.1/surface/'

# Open all surface-temperature files as one dataset, cleaning each file with
# drop_stuff first (combine='by_coords' is the default).
sst = xr.open_mfdataset(
    dir + 'ocean_temp_sfc_*.nc',
    preprocess=lambda x: drop_stuff(x, coords_to_drop, vars_to_drop),
    parallel=True,
)

# Remove length-one dimensions
sst = sst.squeeze()

# Rename Time to time
sst = sst.rename({'Time': 'time'})

In [None]:
# Function to calculate daily climatology and convert to single precision
def monthly_climatology(ds, time_dim):
    """
    Calculate the day-of-year climatology of an xarray object.

    Despite its name, this function groups the data by day of year
    (``ds[time_dim].dt.dayofyear``), not by calendar month, and averages each
    group with flox using the 'cohorts' method. The grouping labels are stored
    under the name 'doy'.

    The input object is not modified: the 'doy' labels are attached to a new
    object returned by ``assign_coords``.

    Parameters:
        ds (xarray.Dataset or xarray.DataArray): Input data containing the
            time coordinate.
        time_dim (str): Name of the time coordinate in ``ds``; it must support
            the ``.dt`` datetime accessor.

    Returns:
        The result of the grouped mean over 'doy', cast to float32.
    """
    # Attach the grouping labels to a new object so the caller's data is left
    # untouched (item assignment on `ds` would add 'doy' to it in place).
    labelled = ds.assign_coords(doy=ds[time_dim].dt.dayofyear)
    climatology_cohorts = flox.xarray.xarray_reduce(
        labelled,
        'doy',
        func='mean',
        method='cohorts',
    )
    # Single precision for the returned climatology
    return climatology_cohorts.astype(np.float32)

In [None]:
# Process and save climatology data
def process_climatology(
    ds,
    time_dim,
    variable,
    file_path='/g/data/xv83/users/ep5799/Heatwaves/Australian_SST_monthly_climatology.nc',
):
    """
    Compute the climatology of one variable and write it to a netCDF file.

    The variable is selected from ``ds``, passed to ``monthly_climatology``,
    persisted, and then written to ``file_path``.

    Parameters:
        ds (xarray.Dataset): Dataset containing ``variable``.
        time_dim (str): Name of the time coordinate, forwarded to
            ``monthly_climatology``.
        variable (str): Name of the variable in ``ds`` to process.
        file_path (str, optional): Destination netCDF file. Defaults to the
            path this notebook has always written to.

    Returns:
        None
    """
    # Index by name rather than getattr: attribute lookup returns the wrong
    # object when the variable name matches a method or attribute of `ds`.
    data = ds[variable]
    clim = monthly_climatology(data, time_dim).persist()

    # Save as netCDF
    clim.to_netcdf(file_path, compute=True)

In [None]:
%%time
# Compute the climatology of the 'temp' variable along 'time' and write it to disk.
# NOTE(review): `sst_southern` is not defined in any cell visible here (only
# `sst` is) — presumably a spatial subset of `sst`; confirm the cell that
# creates it still exists, otherwise this fails on a fresh kernel.
process_climatology(sst_southern, 'time', 'temp')

In [None]:
# Write `threshold90` to netCDF, overwriting any existing file (mode='w').
# NOTE(review): `threshold90` is not defined in any cell visible here —
# confirm the cell that computes it still exists, otherwise this fails on a
# fresh kernel.
threshold90.to_netcdf('/g/data/ia39/ncra/ocean/peacey/mhw/Australian_SST_monthly_MHWthreshold.nc', mode='w', compute=True)