In [1]:
import numpy as np
import pandas as pd
import xarray as xr
import scipy as sci

### plotting
import matplotlib.pyplot as plt

### flox for GroupBy Reductions
import flox.xarray

import dask.array as da
from dask.distributed import Client


In [2]:
# Create the dask.distributed client used by the rest of the notebook, asking
# for one thread per worker.
# NOTE(review): no scheduler address is passed, so this presumably starts a
# local cluster on the current node — confirm for the target environment.
client = Client(threads_per_worker=1)

2024-06-07 13:45:41,976 - distributed.preloading - INFO - Creating preload: /g/data/hh5/public/apps/dask-optimiser/schedplugin.py
2024-06-07 13:45:41,981 - distributed.utils - INFO - Reload module schedplugin from .py file
2024-06-07 13:45:42,169 - distributed.preloading - INFO - Import preload module: /g/data/hh5/public/apps/dask-optimiser/schedplugin.py


Modifying workers


In [3]:
# Preprocessor that strips unwanted coordinates and variables
def drop_stuff(ds, coords_to_drop, vars_to_drop):
    """
    Remove the named coordinates and variables from a dataset.

    Written to be passed (via a small wrapper) as the ``preprocess`` hook of
    xr.open_mfdataset, so that every file is trimmed before being combined.

    Parameters:
        ds (xarray.Dataset): Dataset to trim.
        coords_to_drop (list of str): Coordinate names to remove.
        vars_to_drop (list of str): Variable names to remove.

    Returns:
        xarray.Dataset: The dataset without the listed names.
    """
    # errors='ignore' makes names that are absent from the dataset a no-op
    for unwanted in (coords_to_drop, vars_to_drop):
        ds = ds.drop_vars(unwanted, errors='ignore')
    return ds

In [4]:
def print_chunks(data_array):
    '''
    Print the chunk sizes for each dimension of an xarray object.

    Works for both layouts of the ``chunks`` attribute: a tuple of per-dimension
    chunk tuples aligned with ``dims`` (DataArray), or a mapping from dimension
    name to chunk tuple (Dataset).

    Parameters:
    data_array (xarray.DataArray or xarray.Dataset): The object whose chunks
        are printed.

    Returns:
    None. One line per dimension is written to stdout; if ``chunks`` is None
    a single "not chunked" message is printed instead.
    '''
    from collections.abc import Mapping

    chunks = data_array.chunks

    # chunks is None when the data is not split into chunks at all
    if chunks is None:
        print("Data is not chunked")
        return

    if isinstance(chunks, Mapping):
        # Already keyed by dimension name
        readable_chunks = dict(chunks)
    else:
        # Positional layout: pair each dimension name with its chunk tuple
        readable_chunks = dict(zip(data_array.dims, chunks))

    # Print chunk sizes for each dimension
    for dim, sizes in readable_chunks.items():
        print(f"{dim} chunks: {sizes}")

In [5]:
%%time

# Define coordinates and variables to drop
coords_to_drop = ['st_edges_ocean', 'nv']
vars_to_drop = ['Time_bounds', 'average_DT', 'average_T1', 'average_T2']

# Named data_dir rather than dir so the builtin dir() stays usable in the kernel
data_dir = '/g/data/fp2/OFAM3/jra55_historical.1/surface/'

# Load every surface-temperature file lazily, trimming each one with
# drop_stuff before they are combined (combine='by_coords' is the default),
# then squeeze out length-1 dimensions and rename Time to time.
sst = (
    xr.open_mfdataset(
        data_dir + 'ocean_temp_sfc_*.nc',
        parallel=True,
        preprocess=lambda ds: drop_stuff(ds, coords_to_drop, vars_to_drop),
    )
    .squeeze()
    .rename({'Time': 'time'})
)

sst

CPU times: user 10.4 s, sys: 1.96 s, total: 12.3 s
Wall time: 26.9 s


Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,638.58 MiB
Shape,"(13149, 1500, 3600)","(31, 1500, 3600)"
Dask graph,432 chunks in 866 graph layers,432 chunks in 866 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 264.51 GiB 638.58 MiB Shape (13149, 1500, 3600) (31, 1500, 3600) Dask graph 432 chunks in 866 graph layers Data type float32 numpy.ndarray",3600  1500  13149,

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,638.58 MiB
Shape,"(13149, 1500, 3600)","(31, 1500, 3600)"
Dask graph,432 chunks in 866 graph layers,432 chunks in 866 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [6]:
# Function to calculate monthly climatology and convert to single precision
def monthly_climatology(ds, time_dim):
    """
    Calculate the monthly climatology for a given xarray object.

    Each time step is labelled with its calendar month and the data are
    averaged per month label with flox (``method='cohorts'``). The labelling is
    done on a new object, so the input passed by the caller is not modified.

    Parameters:
        ds (xarray.Dataset or xarray.DataArray): Input data containing the
            time coordinate.
        time_dim (str): The name of the time coordinate in ``ds``.

    Returns:
        Object of the same kind as ``ds`` holding the mean for each month
        label, cast to float32.
    """
    # assign_coords returns a new object instead of writing 'month' into the
    # caller's data the way item assignment would.
    labelled = ds.assign_coords(month=ds[time_dim].dt.month)
    climatology_cohorts = flox.xarray.xarray_reduce(
        labelled,
        'month',
        func='mean',
        method='cohorts',
    )
    return climatology_cohorts.astype(np.float32)

In [7]:
# Process and save climatology data
def process_climatology(
    ds,
    time_dim,
    variable,
    file_path='/g/data/xv83/users/ep5799/Heatwaves/Australian_SST_monthly_climatology.nc',
):
    """
    Compute the monthly climatology of one variable and write it to netCDF.

    Parameters:
        ds (xarray.Dataset): Dataset containing ``variable``.
        time_dim (str): Name of the time coordinate, passed on to
            monthly_climatology.
        variable (str): Name of the variable in ``ds`` to process.
        file_path (str): Destination netCDF file. Defaults to the location
            this notebook has always written to.

    Returns:
        None. The climatology is written to ``file_path``.
    """
    # Index by name: attribute access (getattr) would resolve to a Dataset
    # method or property if the variable name happened to collide with one.
    data = ds[variable]
    clim = monthly_climatology(data, time_dim).persist()

    # Save as netCDF
    clim.to_netcdf(file_path, compute=True)

In [8]:
%%time
#process_climatology(sst, 'time', 'temp')

CPU times: user 3 µs, sys: 14 µs, total: 17 µs
Wall time: 40.5 µs


In [None]:
%%time
# Rechunk with the whole time axis in a single chunk (-1) and 36x36 spatial
# tiles, so the per-point quantile along 'time' never spans chunk boundaries.
sst_chunked = sst.chunk({'time': -1, 'xt_ocean': 36, 'yt_ocean': 36})

# Group by Day of Year (DOY)
doy = sst_chunked['time'].dt.dayofyear

# The grouped result takes its dimension name from the grouping coordinate.
# It must be called 'dayofyear' because the later pad / rolling / chunk / isel
# steps all address the dimension by that name.
sst_chunked = sst_chunked.assign_coords(dayofyear=doy)

def calc_90th_percentile(group):
    """Return the 0.9 quantile of ``group`` along 'time' (NaNs skipped) as float32."""
    return group.quantile(0.9, dim='time', skipna=True).astype(np.float32)


# .map is the current spelling of GroupBy.apply
result = sst_chunked['temp'].groupby('dayofyear').map(calc_90th_percentile).compute()

In [None]:
# 31-entry centred running mean along 'dayofyear'. The series is first padded
# by half a window on each side in 'wrap' mode, so windows near either end of
# the axis are filled with values taken from the opposite end.
threshold90 = (
    result
    .pad(dayofyear=(31-1)//2, mode='wrap')
    .rolling(dayofyear=31, center=True)
    .mean(skipna=True)
)

In [None]:
%%time
# Put the whole 'dayofyear' axis in one chunk with 50x50 spatial tiles, trim
# 15 entries from each end of 'dayofyear', and drop the 'quantile' coordinate.
# NOTE: threshold90 is rebound from itself, so running this cell a second time
# trims another 15 entries per side and drop_vars('quantile') no longer finds
# the coordinate; re-run the cell that builds threshold90 first.
threshold90 = (
    threshold90
    .chunk({'dayofyear':-1, 'yt_ocean':50, 'xt_ocean':50})
    .isel(dayofyear=slice(15,-15))
    .drop_vars('quantile')
)

In [None]:
# nbytes / 1e6 is a size in megabytes, so the label is MB ("Mb" reads as
# megabits). int() truncates exactly as the previous %i format did.
print(f"Size (MB) of daily threshold90 = {int(threshold90.nbytes / 1e6)}")

In [None]:
# temp = threshold90['temp']
# xt_ocean = threshold90['xt_ocean']
# yt_ocean = threshold90['yt_ocean']

# # Plotting
# plt.figure(figsize=(12, 6))
# plt.pcolormesh(xt_ocean, yt_ocean, temp, cmap='viridis', shading='auto')
# plt.colorbar(label='Temperature (°C)')
# plt.title(f'Temperature for DOY {ds.dayofyear.item()}')
# plt.xlabel('Longitude')
# plt.ylabel('Latitude')
# plt.grid(True)
# plt.show()

In [None]:
# Write the smoothed threshold to disk, overwriting any existing file (mode='w').
# NOTE(review): the file name says "monthly" but threshold90 is indexed by
# 'dayofyear' — confirm the intended name before others depend on it.
threshold90.to_netcdf(
    '/g/data/xv83/users/ep5799/Heatwaves/Australian_SST_monthly_MHWthreshold.nc',
    mode='w',
    compute=True,
)