In [1]:
import sys
import os

### Data handling
import numpy as np
import pandas as pd
import xarray as xr
import scipy as sci

### Plotting
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cmocean.cm as cmo

### Marine heatwaves python package
from xmhw.xmhw import threshold, detect
from dask.distributed import Client
import dask.array as da
import dask.dataframe as dd
# Start a Dask distributed client with one thread per worker; with no scheduler
# address given this spins up a local cluster for the notebook session.
client = Client(threads_per_worker=1)
# Show the client summary (dashboard link, workers, threads, memory).
client

2024-06-14 15:50:16,607 - distributed.preloading - INFO - Creating preload: /g/data/hh5/public/apps/dask-optimiser/schedplugin.py
2024-06-14 15:50:16,610 - distributed.utils - INFO - Reload module schedplugin from .py file
2024-06-14 15:50:16,686 - distributed.preloading - INFO - Import preload module: /g/data/hh5/public/apps/dask-optimiser/schedplugin.py


Modifying workers


0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: /node/gadi-hmem-bdw-0009.gadi.nci.org.au/21143/proxy/8787/status,

0,1
Dashboard: /node/gadi-hmem-bdw-0009.gadi.nci.org.au/21143/proxy/8787/status,Workers: 7
Total threads: 7,Total memory: 0 B
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:44369,Workers: 7
Dashboard: /node/gadi-hmem-bdw-0009.gadi.nci.org.au/21143/proxy/8787/status,Total threads: 7
Started: Just now,Total memory: 0 B

0,1
Comm: tcp://127.0.0.1:34355,Total threads: 1
Dashboard: /node/gadi-hmem-bdw-0009.gadi.nci.org.au/21143/proxy/44405/status,Memory: 0 B
Nanny: tcp://127.0.0.1:37099,
Local directory: /jobfs/118294990.gadi-pbs/dask-worker-space/worker-vs9m3eh7,Local directory: /jobfs/118294990.gadi-pbs/dask-worker-space/worker-vs9m3eh7

0,1
Comm: tcp://127.0.0.1:43781,Total threads: 1
Dashboard: /node/gadi-hmem-bdw-0009.gadi.nci.org.au/21143/proxy/39049/status,Memory: 0 B
Nanny: tcp://127.0.0.1:43787,
Local directory: /jobfs/118294990.gadi-pbs/dask-worker-space/worker-ncduukj2,Local directory: /jobfs/118294990.gadi-pbs/dask-worker-space/worker-ncduukj2

0,1
Comm: tcp://127.0.0.1:40221,Total threads: 1
Dashboard: /node/gadi-hmem-bdw-0009.gadi.nci.org.au/21143/proxy/45873/status,Memory: 0 B
Nanny: tcp://127.0.0.1:45183,
Local directory: /jobfs/118294990.gadi-pbs/dask-worker-space/worker-dnc7o_ps,Local directory: /jobfs/118294990.gadi-pbs/dask-worker-space/worker-dnc7o_ps

0,1
Comm: tcp://127.0.0.1:44867,Total threads: 1
Dashboard: /node/gadi-hmem-bdw-0009.gadi.nci.org.au/21143/proxy/35261/status,Memory: 0 B
Nanny: tcp://127.0.0.1:38881,
Local directory: /jobfs/118294990.gadi-pbs/dask-worker-space/worker-qeiwciv5,Local directory: /jobfs/118294990.gadi-pbs/dask-worker-space/worker-qeiwciv5

0,1
Comm: tcp://127.0.0.1:35535,Total threads: 1
Dashboard: /node/gadi-hmem-bdw-0009.gadi.nci.org.au/21143/proxy/40257/status,Memory: 0 B
Nanny: tcp://127.0.0.1:46351,
Local directory: /jobfs/118294990.gadi-pbs/dask-worker-space/worker-rn0cw_yl,Local directory: /jobfs/118294990.gadi-pbs/dask-worker-space/worker-rn0cw_yl

0,1
Comm: tcp://127.0.0.1:46167,Total threads: 1
Dashboard: /node/gadi-hmem-bdw-0009.gadi.nci.org.au/21143/proxy/41099/status,Memory: 0 B
Nanny: tcp://127.0.0.1:43771,
Local directory: /jobfs/118294990.gadi-pbs/dask-worker-space/worker-l1b2af1k,Local directory: /jobfs/118294990.gadi-pbs/dask-worker-space/worker-l1b2af1k

0,1
Comm: tcp://127.0.0.1:39865,Total threads: 1
Dashboard: /node/gadi-hmem-bdw-0009.gadi.nci.org.au/21143/proxy/42431/status,Memory: 0 B
Nanny: tcp://127.0.0.1:42551,
Local directory: /jobfs/118294990.gadi-pbs/dask-worker-space/worker-wdbnfd_q,Local directory: /jobfs/118294990.gadi-pbs/dask-worker-space/worker-wdbnfd_q


In [2]:
# Notebook-wide settings
WRKDIR = "/g/data/fp2/OFAM3"  # root of the OFAM3 directory tree the SST files are read from
MIN_DURATION = 10  # NOTE(review): later cells set their own lowercase `min_duration` instead of reading this

# Names stripped from every input file by the `drop_stuff` preprocessor.
coords_to_drop = [
    'st_edges_ocean',
    'nv',
    'st_ocean',
]
vars_to_drop = [
    'Time_bounds',
    'average_DT',
    'average_T1',
    'average_T2',
    'st_ocean',
]

# Preprocessor to drop unwanted variables
def drop_stuff(ds, coords_to_drop, vars_to_drop):
    """
    Strip unwanted coordinates and variables from a dataset.

    Meant to be used as the ``preprocess`` hook of ``xr.open_mfdataset`` so
    that every file is trimmed before the files are combined.

    Parameters:
        ds (xarray.Dataset): Dataset to trim.
        coords_to_drop (list of str): Coordinate names to remove.
        vars_to_drop (list of str): Variable names to remove.

    Returns:
        xarray.Dataset: The dataset without the listed names. Names that are
        not present in ``ds`` are skipped silently.
    """
    # Same call for both lists, coordinates first; errors='ignore' tolerates
    # names that this particular dataset does not contain.
    for unwanted_names in (coords_to_drop, vars_to_drop):
        ds = ds.drop_vars(unwanted_names, errors='ignore')
    return ds

In [3]:
# Load the 'temp' field of the precomputed percentile file (90th percentile,
# per the file name) and split it into 50 x 50 spatial chunks.
threshold90 = (
    xr.open_dataset('/g/data/xv83/users/ep5799/Heatwaves/temp_90th_percentile_current.nc')['temp']
    .chunk({'yt_ocean': 50, 'xt_ocean': 50})
)

In [4]:
def plot_variable(data, variable, title, cmap, levels, lon_name='xt_ocean', lat_name='yt_ocean'):
    """
    Draw a colour-mesh map of a field on a Plate Carree projection centred at 180°.

    Parameters:
        data: Field to plot, indexable by coordinate name and providing
            ``.squeeze()`` (e.g. an xarray.DataArray).
        variable (str): Label for the colour bar.
        title (str): Figure title.
        cmap: Colormap passed to ``pcolormesh``.
        levels (sequence): Only the first and last entries are used, as the
            lower and upper colour limits.
        lon_name (str): Name of the longitude coordinate in ``data``.
        lat_name (str): Name of the latitude coordinate in ``data``.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    data_crs = ccrs.PlateCarree()                      # CRS the data and ticks are given in
    map_crs = ccrs.PlateCarree(central_longitude=180)  # CRS of the drawn map

    fig, ax = plt.subplots(figsize=(12, 6), subplot_kw={'projection': map_crs})
    fig.suptitle(title, fontsize=20)

    mesh = ax.pcolormesh(
        data[lon_name],
        data[lat_name],
        data.squeeze(),  # drop length-1 dimensions before plotting
        cmap=cmap,
        transform=data_crs,
        vmin=levels[0],
        vmax=levels[-1],
    )
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS, linestyle=':')
    ax.set_extent([0, 360, -90, 0], crs=data_crs)

    colorbar = plt.colorbar(mesh, ax=ax, orientation='vertical', shrink=0.5, pad=0.05)
    colorbar.set_label(variable)

    # Axis limits are set through pyplot, i.e. on the current axes.
    plt.xlim([-180, 180])
    plt.ylim([-90, 0])

    # Longitude / latitude tick positions and their labels
    lon_ticks = [0, 60, 120, 180, 240, 300, 360]
    lon_labels = ['', '60°E', '120°E', '180°', '120°W', '60°W', '']
    lat_ticks = np.arange(-90, 1, 30)
    lat_labels = ['90°S', '60°S', '30°S', '0°']

    ax.set_xticks(lon_ticks, crs=data_crs)
    ax.set_xticklabels(lon_labels)
    ax.set_yticks(lat_ticks, crs=data_crs)
    ax.set_yticklabels(lat_labels)

    plt.show()

In [5]:
min_duration = 10

# Open all surface-temperature files matching the pattern as one lazily
# evaluated object. Each file is trimmed by `drop_stuff` before combining, and
# length-1 dimensions are squeezed away afterwards.
sst_file_pattern = "/g/data/fp2/OFAM3/jra55_rcp8p5/surface/ocean_temp_sfc_205*.nc"
sst_chunks = {'Time': -1, 'yt_ocean': 50, 'xt_ocean': 50}

sst = xr.open_mfdataset(
    sst_file_pattern,
    parallel=True,
    preprocess=lambda ds: drop_stuff(ds, coords_to_drop, vars_to_drop),
    chunks=sst_chunks,
).squeeze()


In [14]:
# Rename the time dimension to lower case. The guard keeps the cell safe to
# re-run: once 'Time' has been renamed, a second rename would raise because
# the old name no longer exists.
if "Time" in sst.dims:
    sst = sst.rename({"Time": "time"})

In [15]:
# Display the lazily loaded SST object to check its shape and chunking.
sst

Unnamed: 0,Array,Chunk
Bytes,73.47 GiB,302.73 kiB
Shape,"(3652, 1500, 3600)","(31, 50, 50)"
Dask graph,259200 chunks in 242 graph layers,259200 chunks in 242 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 73.47 GiB 302.73 kiB Shape (3652, 1500, 3600) (31, 50, 50) Dask graph 259200 chunks in 242 graph layers Data type float32 numpy.ndarray",3600  1500  3652,

Unnamed: 0,Array,Chunk
Bytes,73.47 GiB,302.73 kiB
Shape,"(3652, 1500, 3600)","(31, 50, 50)"
Dask graph,259200 chunks in 242 graph layers,259200 chunks in 242 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [16]:
# Scatter the threshold90 dataset to all workers.
# NOTE(review): calling .result() on the returned future brings the object
# back to the client, so `scattered_threshold90` may simply be a local copy of
# `threshold90` — confirm this round trip is intended.
scattered_threshold90 = client.scatter(threshold90, broadcast=True).result()

# Calendar helpers taken from the time axis with the vectorised .dt accessor
# (replaces a Python loop that built one pd.Timestamp per time step).
sst['year'] = sst['time'].dt.year
sst['doy'] = sst['time'].dt.dayofyear

# Years to process, derived from the data that was actually loaded. The
# previous hard-coded range(1995, 2015) does not overlap the `205*` files
# opened for `sst`, so every yearly slice selected from it would be empty.
years = [int(y) for y in np.unique(sst['time'].dt.year.values)]


In [17]:
def label_events(da):
    """
    Number consecutive runs of True values along the 'time' dimension.

    A run starts wherever ``da`` is True and the previous time step is not
    (the step before the first one counts as False). Run starts are counted
    cumulatively along 'time', so every step of the first run gets 1, the
    second run 2, and so on, independently for each position along the other
    dimensions.

    Parameters:
        da: Boolean array with a 'time' dimension (e.g. an xarray.DataArray).
            Note that inside this function the name shadows the module-level
            ``da`` alias for ``dask.array``.

    Returns:
        Array of run numbers, masked wherever ``da`` is False.
    """
    previous_step = da.shift(time=1, fill_value=False)
    run_starts = da & ~previous_step
    return run_starts.cumsum(dim='time').where(da)

In [None]:
# Initialize lists to store metrics
duration_list = []
intensity_mean_list = []

# Temperature excess over the threshold, computed one year at a time to bound
# the amount of work submitted at once. Each yearly slice is evaluated a single
# time; previously it was computed twice, once for `>` and once for `-`.
diff_by_year = []
for year in years:
    sst_year = sst.sel(time=slice(f'{year}-01-01', f'{year}-12-31'))
    diff_by_year.append((sst_year - scattered_threshold90).compute())

# Concatenate the years
diff = xr.concat(diff_by_year, dim='time')

# Exceedance mask: the temperature is above the threshold exactly where the
# excess is positive (missing values compare False in both formulations).
mhws = diff > 0

In [None]:
# Minimum number of consecutive exceedance steps that make up a heatwave
min_duration = 10

# Trailing-window count of exceedance steps. A step is flagged only when it and
# the min_duration - 1 steps before it all exceed the threshold, so the first
# min_duration - 1 steps of each run are not flagged.
exceedance_run_count = mhws.rolling(time=min_duration, center=False).sum()
mhws_occurance = exceedance_run_count >= min_duration

In [None]:
# Label events: consecutive flagged time steps share one integer id, counted up
# along 'time' separately at each grid position; unflagged steps are masked by
# label_events.
event_labels = label_events(mhws_occurance).compute()
# Display the labelled array.
event_labels

In [None]:
# Largest event id found anywhere on the grid; it sizes the 'event' dimension.
max_event_number = int(event_labels.max().values)

print("Max number of events over the grid =", max_event_number)

# One NaN-filled (event, yt_ocean, xt_ocean) array per metric, all sharing the
# same layout; the values are filled in by later cells.
metric_dims = ('event', 'yt_ocean', 'xt_ocean')

mhws_ds = xr.Dataset({
    metric_name: xr.DataArray(
        np.nan,
        dims=metric_dims,
        coords={
            'event': np.arange(1, max_event_number + 1),
            'yt_ocean': event_labels.coords['yt_ocean'].values,
            'xt_ocean': event_labels.coords['xt_ocean'].values,
        },
    )
    for metric_name in ('duration', 'intensity_mean', 'intensity_max')
})

In [None]:
# Fill the per-event metrics for every grid cell in a single pass.
for event_id in range(1, max_event_number + 1):
    if event_id % 5 == 0:
        print(event_id)

    # Mask of the time steps belonging to this event, built once and reused for
    # all metrics (previously rebuilt in two separate loops).
    in_event = event_labels == event_id
    event_diff = diff.where(in_event)

    # The time dimension was renamed from 'Time' to 'time' after loading, so
    # reductions must use the lower-case name; dim='Time' raises here.
    mhws_ds['intensity_mean'].loc[{'event': event_id}] = event_diff.mean(dim='time')
    mhws_ds['intensity_max'].loc[{'event': event_id}] = event_diff.max(dim='time')

    # Duration. `mhws_occurance` only flags a step once `min_duration`
    # consecutive exceedance steps have accumulated, so the underlying run is
    # min_duration - 1 steps longer than the number of flagged steps. Cells
    # without this event stay NaN.
    # NOTE(review): the intensity statistics above likewise cover only the
    # flagged steps, not the first min_duration - 1 steps of each run.
    n_flagged = in_event.sum(dim='time')
    mhws_ds['duration'].loc[{'event': event_id}] = (
        (n_flagged + (min_duration - 1)).where(n_flagged > 0)
    )

In [None]:
# # Calculate metrics in parallel
# duration_df = event_df.groupby('event').size().astype(np.float32)
# mean_intensity_df = diff.to_dask_dataframe().groupby('event').mean().astype(np.float32)

In [None]:
# # Append metrics to lists
# duration_list.append(duration_df)
# intensity_mean_list.append(mean_intensity_df)

In [None]:
# # Compute the lists
# duration_list = client.compute(duration_list)
# intensity_mean_list = client.compute(intensity_mean_list)

In [None]:
# Save the duration and mean-intensity fields to disk, one NetCDF file each.
# `duration_ds` / `intensity_mean_ds` are not defined in any visible cell; the
# metrics are held in `mhws_ds`, so the matching variables are written from it.
mhws_ds[['duration']].to_netcdf('/g/data/ia39/ncra/ocean/peacey/mhw/MHWs_duration_current.nc', mode='w', compute=True)
mhws_ds[['intensity_mean']].to_netcdf('/g/data/ia39/ncra/ocean/peacey/mhw/MHWs_intensity_current.nc', mode='w', compute=True)

In [None]:
# NOTE(review): `duration_concat` / `intensity_mean_concat` are not defined in
# any visible cell. The per-event fields in `mhws_ds` are averaged over 'event'
# here to obtain the 2-D maps that plot_variable draws — confirm that the
# event mean is the intended summary.

# Plot duration of MHWs
plot_variable(mhws_ds['duration'].mean(dim='event'),
              'Duration (days)',
              'Duration of Marine Heatwaves',
              'viridis',
              levels=np.arange(0, 100, 5))

# Plot intensity mean of MHWs
plot_variable(mhws_ds['intensity_mean'].mean(dim='event'),
              'Intensity Mean (°C)',
              'Intensity Mean of Marine Heatwaves',
              'plasma',
              levels=np.arange(0, 1.6, 0.1))

In [None]:
# Close the Dask client now that all computation and plotting is finished.
client.close()