In [None]:
import xarray as xr
from dask.distributed import Client
import dask.array as da
from dask import delayed
import numpy as np
import os

In [None]:
# Start a Dask distributed client. No scheduler address is passed, so this
# relies on dask.distributed's default behaviour (per its documentation, a
# local cluster is created on this machine).
client = Client()
# Show the dashboard link exposed by the client.
print("Dask dashboard available at:", client.dashboard_link)

In [None]:
# Names handed to the `drop_stuff` preprocessor.
# Note that 'st_ocean' is listed in both collections.
coords_to_drop = [
    'st_edges_ocean',
    'nv',
    'st_ocean',
]
vars_to_drop = [
    'Time_bounds',
    'average_DT',
    'average_T1',
    'average_T2',
    'st_ocean',
]

# Preprocessor that strips unwanted coordinates and variables
def drop_stuff(ds, coords_to_drop, vars_to_drop):
    """
    Remove the named coordinates and variables from a dataset.

    Written to be used as the ``preprocess`` hook of ``xr.open_mfdataset``.

    Parameters:
        ds (xarray.Dataset): Dataset to clean up.
        coords_to_drop (list of str): Coordinate names to remove.
        vars_to_drop (list of str): Variable names to remove.

    Returns:
        xarray.Dataset: The result of dropping both groups of names from ``ds``.
    """
    # Both groups go through the same call; errors='ignore' is passed so that
    # names not present in the dataset are skipped rather than raising.
    for unwanted_names in (coords_to_drop, vars_to_drop):
        ds = ds.drop_vars(unwanted_names, errors='ignore')
    return ds

In [None]:
def process_threshold(ds, time_dim, start, end, variable, period,
                      output_dir='/g/data/ia39/ncra/ocean/peacey/'):
    """
    Compute the per-calendar-month 90th percentile of a variable over a time
    window and write it to a netCDF file.

    Parameters:
        ds (xarray.Dataset): Dataset containing ``variable``.
        time_dim (str): Name of the time dimension. It is used for the time
            selection, the rechunking, the monthly grouping and the quantile
            reduction.
        start, end: Bounds of the time slice passed to ``.sel``.
        variable (str): Name of the data variable to process.
        period (str): Label for this window; used in the debug message and in
            the output file name.
        output_dir (str): Directory the netCDF file is written to. Defaults to
            the location that was previously hard-coded.

    Returns:
        None. The result is written to
        ``<output_dir>/<variable>_percentile_monthly__<period>.nc``.
    """
    # Select the window, then rechunk so the whole time axis sits in a single
    # chunk (the quantile is taken along time) with 50x50 horizontal tiles.
    data = (
        ds[variable]
        .sel(**{time_dim: slice(start, end)})
        .chunk({time_dim: -1, 'xt_ocean': 50, 'yt_ocean': 50})
        .persist()
    )

    # Debug: Print the shape and chunking of the data
    print(f"Processing period {period}, data shape: {data.shape}, chunks: {data.chunks}")

    # 90th percentile for each calendar month present in the window.
    percentile_90 = (
        data.groupby(f'{time_dim}.month')
        .quantile(0.9, dim=time_dim)
        .compute()
    )

    # The groupby result is already a labelled DataArray, so its own coordinates
    # are kept (re-wrapping it with a hand-written dims list would relabel the
    # axes by position). Reorder by *name* to the (yt_ocean, xt_ocean, month)
    # layout; any additional dimensions are left after these three.
    # The scalar 'quantile' coordinate, if present, is dropped so the file only
    # carries the spatial and month coordinates.
    percentile_90_da = (
        percentile_90
        .drop_vars('quantile', errors='ignore')
        .transpose('yt_ocean', 'xt_ocean', 'month', ...)
    )

    # Define the output file path
    file_path_90th = os.path.join(
        output_dir, f'{variable}_percentile_monthly__{period}.nc'
    )

    # Save 90th percentile as netCDF
    percentile_90_da.to_netcdf(file_path_90th, compute=True)

In [None]:
# Labelled (start, end) date windows to process, in processing order
GWL_periods = dict(
    current=('1995-01-01', '2014-12-31'),
    GW1p2=('2001-01-01', '2020-12-31'),
    GW1p5=('2015-01-01', '2034-12-31'),
    GW2p0=('2030-01-01', '2049-12-31'),
    GW3p0=('2053-01-01', '2072-12-31'),
    GW4p0=('2074-01-01', '2093-12-31'),
)

# Source directories for the SST files (historical run, then rcp8p5 run)
dir1_new = '/g/data/fp2/OFAM3/jra55_historical.1/surface/'
dir2_new = '/g/data/fp2/OFAM3/jra55_rcp8p5/surface/'

# Open every matching file in each directory as a single dataset. Each file is
# passed through drop_stuff first, and the combined result is squeezed.
# No explicit `chunks` argument is given here.
# NOTE(review): the two glob patterns are not the same shape
# ('ocean_temp_sfc_*' vs 'ocean_temp_sfc*.nc') - confirm this is intentional.
dsst1 = xr.open_mfdataset(
    os.path.join(dir1_new, 'ocean_temp_sfc_*'),
    parallel=True,
    preprocess=lambda ds: drop_stuff(ds, coords_to_drop, vars_to_drop),
).squeeze()
dsst2 = xr.open_mfdataset(
    os.path.join(dir2_new, 'ocean_temp_sfc*.nc'),
    parallel=True,
    preprocess=lambda ds: drop_stuff(ds, coords_to_drop, vars_to_drop),
).squeeze()

In [None]:
# Run the SST threshold calculation for every period: the 'current' window
# reads from dsst1, all other windows read from dsst2.
for period, (start, end) in GWL_periods.items():
    process_threshold(
        dsst1 if period == 'current' else dsst2,
        'Time',
        start,
        end,
        'temp',
        period,
    )

# Shut the Dask client down once all periods are done
client.close()