In [None]:
import sys
import os

### Data handling
import numpy as np
import pandas as pd
import xarray as xr
import scipy as sci

### Plotting
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cmocean.cm as cmo

### Marine heatwaves python package
from xmhw.xmhw import threshold, detect

In [None]:
from dask.distributed import Client

# Local Dask cluster with a single thread per worker, so parallelism comes from
# the number of worker processes rather than threads.
client = Client(threads_per_worker=1)
client  # last expression: the notebook renders the cluster summary here

In [None]:
# Define constants
WRKDIR = "/g/data/fp2/OFAM3"
COORDS_TO_DROP = ['st_edges_ocean', 'nv']
VARS_TO_DROP = ['Time_bounds', 'average_DT', 'average_T1', 'average_T2']
MIN_DURATION = 10


In [None]:
# Preprocessor to drop unwanted variables
def drop_stuff(ds):
    """Strip unwanted coordinates and bookkeeping variables from one input file.

    Intended as the ``preprocess`` hook of ``xr.open_mfdataset``. Names listed in
    COORDS_TO_DROP / VARS_TO_DROP that are absent from ``ds`` are skipped
    (``errors='ignore'``) rather than raising.
    """
    unwanted = [*COORDS_TO_DROP, *VARS_TO_DROP]
    return ds.drop_vars(unwanted, errors='ignore')


In [None]:
# Load and preprocess SST data
def load_sst_data(pattern="jra55_rcp8p5/surface/ocean_temp_sfc_205*.nc", workdir=None):
    """Open the multi-file OFAM3 surface-temperature record as a single Dataset.

    Parameters
    ----------
    pattern : str, optional
        Glob, relative to ``workdir``, selecting the files to open. The default
        reproduces the previously hard-coded selection.
    workdir : str, optional
        Root directory for ``pattern``. ``None`` (default) uses the module-level
        ``WRKDIR``, looked up at call time so edits in the config cell apply.

    Returns
    -------
    xarray.Dataset
        All matching files combined by ``open_mfdataset``, with ``drop_stuff``
        applied to each file and every length-1 dimension squeezed out.
    """
    root = WRKDIR if workdir is None else workdir
    sst = xr.open_mfdataset(
        os.path.join(root, pattern),
        parallel=True,
        preprocess=drop_stuff,
    ).squeeze()
    return sst

In [None]:
# Add 'year' and 'doy' coordinates
def add_time_coords(sst):
    """Return a copy of ``sst`` whose ``time`` dimension is a (year, doy) MultiIndex.

    The original implementation assigned ``sst['year']`` / ``sst['doy']`` directly,
    which silently modified the caller's object, and computed day-of-year with a
    per-element ``pd.Timestamp`` loop. ``assign_coords`` returns a new object, and
    the ``.dt`` accessor is vectorised.

    Parameters
    ----------
    sst : xarray.DataArray or xarray.Dataset
        Data with a datetime-like ``time`` coordinate.

    Returns
    -------
    Same type as ``sst``, with ``time`` indexed by ('year', 'doy').
    """
    time = sst['time']
    with_coords = sst.assign_coords(year=time.dt.year, doy=time.dt.dayofyear)
    return with_coords.set_index(time=['year', 'doy'])

In [None]:
# Detect MHWs and calculate differences
def detect_mhws_and_diff(sst, threshold):
    """Flag above-threshold days and compute the SST excess over the threshold.

    The record is processed one calendar year at a time. Each year's slice is
    compared with ``threshold`` along its ``doy`` axis, then put back on real
    dates, and the years are concatenated along ``time``.

    Parameters
    ----------
    sst : xarray.DataArray
        SST whose ``time`` is a (year, doy) MultiIndex, as built by ``add_time_coords``.
    threshold : xarray.DataArray
        Daily threshold field. NOTE(review): assumed to be indexed by ``doy`` (or to
        have only spatial dims) so it broadcasts against one year of ``sst``.
        Confirm against the threshold file.

    Returns
    -------
    (xarray.DataArray, xarray.DataArray)
        Boolean ``sst > threshold`` mask and ``sst - threshold``, both on a
        ``time`` axis of real dates.
    """
    def _on_calendar_dates(da, year):
        # Build dates from the doy values actually present in `da`. A fixed
        # Jan-1..Dec-31 range raises a length mismatch for a partial year, or when
        # alignment with `threshold` drops days.
        jan1 = pd.Timestamp(year=int(year), month=1, day=1)
        dates = jan1 + pd.to_timedelta(da['doy'].values - 1, unit='D')
        return da.rename({'doy': 'time'}).assign_coords(time=dates)

    hot_years, excess_years = [], []
    for yr in np.unique(sst['year'].values):
        one_year = sst.sel(year=yr)
        hot_years.append(_on_calendar_dates(one_year > threshold, yr))
        excess_years.append(_on_calendar_dates(one_year - threshold, yr))
    return xr.concat(hot_years, dim='time'), xr.concat(excess_years, dim='time')

In [None]:
# Label continuous events in time
def label_events(da):
    """Number each contiguous run of True values along ``time``, per location.

    A run starts wherever ``da`` is True and the previous step was False (the
    step before the first one is treated as False). A running count of run
    starts along ``time`` gives every day of the k-th run the label k. Days where
    ``da`` is False are masked to NaN.
    """
    run_starts = da & ~da.shift(time=1, fill_value=False)
    return run_starts.cumsum(dim='time').where(da)

In [None]:
# Calculate MHW characteristics
def calculate_mhw_characteristics(event_labels, diff):
    """Summarise every labelled event at every grid cell.

    Parameters
    ----------
    event_labels : xarray.DataArray
        Per-cell event numbers along ``time`` (1, 2, ...; NaN outside events), as
        produced by ``label_events``. Needs ``yt_ocean`` / ``xt_ocean`` coordinates.
    diff : xarray.DataArray
        SST minus threshold on the same ``time`` axis.

    Returns
    -------
    xarray.Dataset
        Variables ``duration``, ``intensity_mean`` and ``intensity_max`` with dims
        (event, yt_ocean, xt_ocean). Cells that never reach event k hold NaN
        intensities for k. If no events exist anywhere, the ``event`` dimension is
        empty instead of raising.
    """
    yt = event_labels.coords['yt_ocean'].values
    xt = event_labels.coords['xt_ocean'].values

    max_label = event_labels.max().values
    # All-NaN labels (no qualifying event anywhere) give a NaN max; int(NaN) would raise.
    n_events = 0 if np.isnan(max_label) else int(max_label)
    events = np.arange(1, n_events + 1)

    def _nan_field():
        # One NaN-filled (event, yt_ocean, xt_ocean) array per output variable.
        return xr.DataArray(
            np.full((events.size, yt.size, xt.size), np.nan),
            dims=('event', 'yt_ocean', 'xt_ocean'),
            coords={'event': events, 'yt_ocean': yt, 'xt_ocean': xt},
        )

    mhws_ds = xr.Dataset(
        {name: _nan_field() for name in ('duration', 'intensity_mean', 'intensity_max')}
    )

    for event_id in events:
        event_mask = event_labels == event_id
        masked = diff.where(event_mask)
        # Evaluate all three statistics in one .compute() so a lazy `diff` is read
        # once per event instead of once per statistic.
        stats = xr.Dataset({
            'duration': event_mask.sum(dim='time'),
            'intensity_mean': masked.mean(dim='time'),
            'intensity_max': masked.max(dim='time'),
        }).compute()
        for name in ('duration', 'intensity_mean', 'intensity_max'):
            mhws_ds[name].loc[{'event': event_id}] = stats[name]

    return mhws_ds

In [None]:
# Plot variables
def plot_variable(data, variable, title, cmap, levels, lon_name='xt_ocean', lat_name='yt_ocean',
                  extent=(0, 360, -90, 0)):
    """Draw one 2-D field on a PlateCarree map centred on 180 degrees longitude.

    Parameters
    ----------
    data : xarray.DataArray
        Field to plot. After ``squeeze()`` it must have exactly the dims
        ``lat_name`` and ``lon_name``.
    variable : str
        Colour-bar label.
    title : str
        Figure title.
    cmap : str or matplotlib colormap
        Colormap passed to ``pcolormesh``.
    levels : sequence
        Only ``levels[0]`` and ``levels[-1]`` are used, as vmin / vmax.
    lon_name, lat_name : str, optional
        Names of the longitude / latitude coordinates in ``data``.
    extent : sequence of 4 numbers or None, optional
        (lon_min, lon_max, lat_min, lat_max) in PlateCarree coordinates. The default
        keeps the previous Southern-Hemisphere view. ``None`` leaves cartopy's
        automatic extent.
    """
    fig, ax = plt.subplots(figsize=(12, 6), subplot_kw={'projection': ccrs.PlateCarree(central_longitude=180)})
    fig.suptitle(title, fontsize=20)

    lon = data[lon_name]
    lat = data[lat_name]
    # pcolormesh(X, Y, C) needs C shaped (len(Y), len(X)); enforce that axis order.
    field = data.squeeze().transpose(lat_name, lon_name)

    img = ax.pcolormesh(lon, lat, field, cmap=cmap, transform=ccrs.PlateCarree(),
                        vmin=levels[0], vmax=levels[-1])
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS, linestyle=':')
    if extent is not None:
        ax.set_extent(list(extent), crs=ccrs.PlateCarree())

    cbar = fig.colorbar(img, ax=ax, orientation='vertical', shrink=0.5, pad=0.05)
    cbar.set_label(variable)

    plt.show()

In [None]:
# Mark every day that belongs to a run of at least `min_duration` hot days.
def _expand_to_full_runs(hot, min_duration):
    """Return a mask that is True on every day of each qualifying run.

    A rolling sum that reaches ``min_duration`` flags only the *last* day of each
    all-True window. Used directly, it drops the first ``min_duration - 1`` days of
    every event, so each window-end flag is spread back over the preceding
    ``min_duration - 1`` days.

    Parameters
    ----------
    hot : xarray.DataArray
        Boolean daily mask with a ``time`` dimension (True = above threshold).
    min_duration : int
        Minimum number of consecutive True steps that counts as an event.
    """
    window_end = hot.astype('int8').rolling(time=min_duration).sum() >= min_duration
    # A rolling max over the time-reversed series means
    # "some all-True window ends within [t, t + min_duration - 1]".
    backwards = window_end.astype('int8').isel(time=slice(None, None, -1))
    spread = backwards.rolling(time=min_duration, min_periods=1).max()
    return spread.isel(time=slice(None, None, -1)) > 0


# Main Workflow
def main():
    """Run the marine-heatwave pipeline and plot intermediate and final fields."""
    # Load data. load_sst_data() returns a Dataset, but the steps below need a
    # DataArray: pcolormesh cannot take a Dataset, and `.values` on a Dataset is a
    # method. NOTE(review): 'temp' matches the variable read from the threshold
    # file; confirm it is also the SST variable name in the OFAM3 files.
    sst = load_sst_data()['temp']
    threshold90 = xr.open_dataset('/g/data/xv83/users/ep5799/Heatwaves/Australian_SST_daily_MHWthreshold.nc')['temp']

    # Add coordinates
    sst1 = add_time_coords(sst).chunk({'time': -1, 'yt_ocean': 50, 'xt_ocean': 50})

    # Detect above-threshold days and the SST excess over the threshold
    mhws, diff = detect_mhws_and_diff(sst1, threshold90)

    # Plot initial results
    plot_variable(mhws.mean(dim='time'), 'MHWs', 'Mean MHWs', 'viridis', [0, 1])
    # isel is 0-based: index 364 is the 365th day of the record, matching the title.
    plot_variable(diff.isel(time=364), 'SST Difference', 'SST Difference on Day 365', cmo.balance, [-2, 2])

    # Keep every day of each run of at least MIN_DURATION consecutive hot days
    mhws_occurrence = _expand_to_full_runs(mhws, MIN_DURATION)
    plot_variable(mhws_occurrence.mean(dim='time'), 'MHW Occurrence', 'Mean MHW Occurrence', 'viridis', [0, 1])

    # Label events (per grid cell) and load the labels into memory
    event_labels = label_events(mhws_occurrence).compute()

    # Calculate MHW characteristics
    mhws_ds = calculate_mhw_characteristics(event_labels, diff)

    # Plot characteristics of the first event at each grid cell
    plot_variable(mhws_ds['duration'].isel(event=0), 'Duration', 'MHW Duration', 'viridis', [0, 365])
    plot_variable(mhws_ds['intensity_mean'].isel(event=0), 'Mean Intensity', 'Mean Intensity of MHW', cmo.balance, [-2, 2])
    plot_variable(mhws_ds['intensity_max'].isel(event=0), 'Max Intensity', 'Max Intensity of MHW', cmo.balance, [-2, 2])


if __name__ == "__main__":
    main()