# XMHW tests on the OFAM3 dataset

Purpose
-------
    The following investigates the capability of xmhw to parallelise the MHW analysis on a subset of temperature data from the OFAM3 0.1°-resolution global simulation spanning 1980–2100. From 1980 to 2006 the simulation is forced by the JRA55 atmospheric reanalysis; thereafter the JRA55 forcing is repeated with the RCP8.5 climate-change trend added.

    Contents:
        1. Load in Temperature Data and visualise (2D in space, 1D in time)
        2. Select the region around Australia to perform the heatwave analysis and throw rest away
        3. Calculate the climatology required for the heatwave analysis and save as a new netcdf file
            [ this will be read in later and in a new session for performing the heatwave analysis ]
        4. Perform heatwave analysis using xmhw by iterating around the subsetted grid

Thanks to John Reilly for sharing his [code](https://github.com/Thomas-Moore-Creative/shared_sandbox/blob/main/mhw-3d-scalingTests-gadiJup.ipynb)
    


 some sandbox edits here from Thomas Moore - 27 April 2024

### imports

In [1]:
import sys
import os

### data handling
import numpy as np
import pandas as pd
import xarray as xr
import scipy as sci

### plotting
import matplotlib.pyplot as plt
from matplotlib import ticker
from matplotlib.gridspec import GridSpec
import matplotlib.colors as mcolors
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cmocean.cm as cmo
from cmocean.tools import lighten

### marine heatwaves python package
from xmhw.xmhw import threshold, detect

# print versions of packages
# sys.version looks like "3.10.13 (main, ...)"; slicing a fixed number of characters
# truncates two-digit minor/patch numbers (it printed "3.10."), so take the first
# whitespace-separated token instead.
print("python version =", sys.version.split()[0])
print("numpy version =", np.__version__)
print("pandas version =", pd.__version__)
print("xarray version =", xr.__version__)
print("scipy version =", sci.__version__)
print("matplotlib version =", sys.modules[plt.__package__].__version__)
print("cmocean version =", sys.modules[cmo.__package__].__version__)
print("cartopy version =", sys.modules[ccrs.__package__].__version__)

python version = 3.10.
numpy version = 1.26.4
pandas version = 2.2.1
xarray version = 2024.3.0
scipy version = 1.12.0
matplotlib version = 3.8.3
cmocean version = v3.0.3
cartopy version = 0.22.0


### import the dask client for assessing performance

In [2]:
# Start a local dask cluster (no scheduler address given) with a single thread per worker;
# leaving `client` as the last expression renders the cluster summary and dashboard link.
from dask.distributed import Client

cluster_options = {'threads_per_worker': 1}
client = Client(**cluster_options)
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: /proxy/8787/status,

0,1
Dashboard: /proxy/8787/status,Workers: 28
Total threads: 28,Total memory: 251.18 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:39755,Workers: 28
Dashboard: /proxy/8787/status,Total threads: 28
Started: Just now,Total memory: 251.18 GiB

0,1
Comm: tcp://127.0.0.1:36981,Total threads: 1
Dashboard: /proxy/38595/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:36469,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-vauty8b0,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-vauty8b0

0,1
Comm: tcp://127.0.0.1:38997,Total threads: 1
Dashboard: /proxy/38755/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:36489,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-haja_ktz,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-haja_ktz

0,1
Comm: tcp://127.0.0.1:40087,Total threads: 1
Dashboard: /proxy/43203/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:33315,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-o21lphxn,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-o21lphxn

0,1
Comm: tcp://127.0.0.1:42537,Total threads: 1
Dashboard: /proxy/45961/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:46609,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-rvqcbyww,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-rvqcbyww

0,1
Comm: tcp://127.0.0.1:33261,Total threads: 1
Dashboard: /proxy/38155/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:33601,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-ioaohhpo,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-ioaohhpo

0,1
Comm: tcp://127.0.0.1:37187,Total threads: 1
Dashboard: /proxy/41115/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:40341,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-wq0ubmme,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-wq0ubmme

0,1
Comm: tcp://127.0.0.1:41267,Total threads: 1
Dashboard: /proxy/35123/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:42417,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-i2ysvok4,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-i2ysvok4

0,1
Comm: tcp://127.0.0.1:46249,Total threads: 1
Dashboard: /proxy/39677/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:41343,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-ld2q2khe,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-ld2q2khe

0,1
Comm: tcp://127.0.0.1:38247,Total threads: 1
Dashboard: /proxy/32801/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:33939,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-l6gcd7et,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-l6gcd7et

0,1
Comm: tcp://127.0.0.1:42357,Total threads: 1
Dashboard: /proxy/42813/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:40579,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-073kt0kv,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-073kt0kv

0,1
Comm: tcp://127.0.0.1:45035,Total threads: 1
Dashboard: /proxy/43251/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:43403,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-g6qsxvr1,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-g6qsxvr1

0,1
Comm: tcp://127.0.0.1:34253,Total threads: 1
Dashboard: /proxy/37017/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:34801,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-cacg5oys,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-cacg5oys

0,1
Comm: tcp://127.0.0.1:33897,Total threads: 1
Dashboard: /proxy/35227/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:34357,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-1yiaav0g,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-1yiaav0g

0,1
Comm: tcp://127.0.0.1:40815,Total threads: 1
Dashboard: /proxy/35929/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:39929,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-8of3yeef,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-8of3yeef

0,1
Comm: tcp://127.0.0.1:44521,Total threads: 1
Dashboard: /proxy/38845/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:46465,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-zvcejy0_,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-zvcejy0_

0,1
Comm: tcp://127.0.0.1:40117,Total threads: 1
Dashboard: /proxy/33613/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:39959,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-e0fy1phr,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-e0fy1phr

0,1
Comm: tcp://127.0.0.1:42493,Total threads: 1
Dashboard: /proxy/33983/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:34131,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-g1giertd,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-g1giertd

0,1
Comm: tcp://127.0.0.1:38615,Total threads: 1
Dashboard: /proxy/44463/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:46391,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-o6qmpn66,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-o6qmpn66

0,1
Comm: tcp://127.0.0.1:42043,Total threads: 1
Dashboard: /proxy/42777/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:34387,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-ffibxemo,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-ffibxemo

0,1
Comm: tcp://127.0.0.1:35959,Total threads: 1
Dashboard: /proxy/36857/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:37163,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-8uxhbwl6,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-8uxhbwl6

0,1
Comm: tcp://127.0.0.1:40331,Total threads: 1
Dashboard: /proxy/35523/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:35549,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-ic1xiybp,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-ic1xiybp

0,1
Comm: tcp://127.0.0.1:37503,Total threads: 1
Dashboard: /proxy/40865/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:33597,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-53jbn11d,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-53jbn11d

0,1
Comm: tcp://127.0.0.1:34171,Total threads: 1
Dashboard: /proxy/46533/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:43785,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-4gn0bj53,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-4gn0bj53

0,1
Comm: tcp://127.0.0.1:37663,Total threads: 1
Dashboard: /proxy/38629/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:33907,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-c3kztri4,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-c3kztri4

0,1
Comm: tcp://127.0.0.1:41579,Total threads: 1
Dashboard: /proxy/43371/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:38073,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-91coko39,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-91coko39

0,1
Comm: tcp://127.0.0.1:34845,Total threads: 1
Dashboard: /proxy/40163/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:46623,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-npxnhuvh,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-npxnhuvh

0,1
Comm: tcp://127.0.0.1:43721,Total threads: 1
Dashboard: /proxy/44085/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:40563,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-6ljnv4js,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-6ljnv4js

0,1
Comm: tcp://127.0.0.1:35947,Total threads: 1
Dashboard: /proxy/40879/status,Memory: 8.97 GiB
Nanny: tcp://127.0.0.1:34087,
Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-hzb0fql2,Local directory: /jobfs/114626156.gadi-pbs/dask-scratch-space/worker-hzb0fql2


## grab the future temperature data from fp2

In [7]:
wrkdir = os.path.join("/g/data/fp2", "OFAM3")  # data root; file paths are built from it instead of os.chdir()

### What is the native chunking of the files?
```
/g/data/fp2/OFAM3/jra55_rcp8p5/surface du -hs ocean_temp_sfc_2050_03.nc
320M	ocean_temp_sfc_2050_03.nc

short temp(Time, st_ocean, yt_ocean, xt_ocean) ;
		temp:long_name = "Potential temperature" ;
		temp:units = "degrees C" ;
		temp:valid_range = -32767s, 32767s ;
		temp:missing_value = -32768s ;
		temp:_FillValue = -32768s ;
		temp:packing = 4 ;
		temp:scale_factor = 0.001678518f ;
		temp:add_offset = 45.f ;
		temp:cell_methods = "time: mean" ;
		temp:time_avg_info = "average_T1,average_T2,average_DT" ;
		temp:coordinates = "geolon_t geolat_t" ;
		temp:standard_name = "sea_water_potential_temperature" ;

```
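#### `du -hs` only reports the file size, not its chunking (`ncdump -hs` shows the storage layout, e.g. `_Storage`/`_ChunkSizes` for netCDF-4 files). Is `temp` stored as a single chunk? The packed `short` variable is ~320 MB on disk, so expect roughly double that once it is unpacked to float32 in memory.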

In [21]:
# preprocessor to drop unwanted variables
def drop_stuff(ds, coords_to_drop, vars_to_drop):
    """
    Preprocessor for xr.open_mfdataset that strips unwanted coordinates and variables.

    Names that are absent from `ds` are skipped silently (errors='ignore'), so the same
    lists can be applied to every file even when some files lack some of the names.

    Parameters:
        ds (xarray.Dataset): One file's dataset as handed over by open_mfdataset.
        coords_to_drop (list of str): Coordinate names to remove.
        vars_to_drop (list of str): Data-variable names to remove.

    Returns:
        xarray.Dataset: A new dataset without the requested coordinates and variables.
    """
    unwanted = list(coords_to_drop) + list(vars_to_drop)
    return ds.drop_vars(unwanted, errors='ignore')

In [22]:
%%time
coords_to_drop =['st_edges_ocean','nv']
vars_to_drop =['Time_bounds','average_DT','average_T1','average_T2']
sst = xr.open_mfdataset(wrkdir + "/jra55_rcp8p5/surface/ocean_temp_sfc_205*.nc", parallel=True,preprocess=lambda x: drop_stuff(x, coords_to_drop,vars_to_drop)).squeeze() #combine='by_coords' is default
sst

CPU times: user 1.04 s, sys: 103 ms, total: 1.14 s
Wall time: 1.16 s


Unnamed: 0,Array,Chunk
Bytes,73.47 GiB,638.58 MiB
Shape,"(3652, 1500, 3600)","(31, 1500, 3600)"
Dask graph,120 chunks in 242 graph layers,120 chunks in 242 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 73.47 GiB 638.58 MiB Shape (3652, 1500, 3600) (31, 1500, 3600) Dask graph 120 chunks in 242 graph layers Data type float32 numpy.ndarray",3600  1500  3652,

Unnamed: 0,Array,Chunk
Bytes,73.47 GiB,638.58 MiB
Shape,"(3652, 1500, 3600)","(31, 1500, 3600)"
Dask graph,120 chunks in 242 graph layers,120 chunks in 242 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [None]:
# nbytes is the in-memory size of the (unpacked float32) array; report it in decimal
# gigabytes with one decimal place instead of truncating to an integer labelled "Gb" (gigabits).
print(f"Future SST dataset = {sst.nbytes / 1e9:.1f} GB")

## fetch the climatology and threshold that were calculated in the other notebook

In [None]:
# Read the climatology and MHW threshold (presumably the 90th percentile, per the name)
# written by the companion notebook. Full paths are used instead of os.chdir() so the
# kernel's working directory is not changed as a hidden side effect.
heatwave_dir = "/g/data/es60/pjb581/heatwaves"
climatology = xr.open_dataset(os.path.join(heatwave_dir, 'Australian_SST_daily_climatology.nc'))['temp']
threshold90 = xr.open_dataset(os.path.join(heatwave_dir, 'Australian_SST_daily_MHWthreshold.nc'))['temp']

threshold90


## detect MHWs

In [None]:
# Rechunk so every task holds the full time series of a 50 x 50 spatial tile
rechunk_plan = {'time': -1, 'yt_ocean': 50, 'xt_ocean': 50}
sst1 = sst.copy(deep=True).chunk(rechunk_plan)

print("Adding 'year' as a new coordinate") 
calendar_year = sst1['time'].dt.year
sst1['year'] = calendar_year

print("Adjust for proper day indexing in leap years")
def adjust_dayofyear(times):
    """
    Return the pandas day-of-year (1 = 1 January) for every timestamp in `times`.

    NOTE(review): despite the name, no leap-year correction is applied. In leap years
    29 February is day 60 and 31 December is day 366, while in other years day 60 is
    1 March. Confirm this matches the doy convention used to build the threshold file.
    """
    day_numbers = []
    for stamp in times:
        day_numbers.append(pd.Timestamp(stamp).dayofyear)
    return np.array(day_numbers)

print("Adding 'doy' as a new coordinate")
print(" This is essential because we need to be able to compare the threshold with dimensions ('doy', 'lat', 'lon') to the new SST dataset")
print(" i.e., at each day of the year, we see if the SST is greater than the threshold")
# NOTE(review): doy is the plain calendar day-of-year, so after 28 Feb leap and non-leap
# years are offset by one day; confirm threshold90 was built with the same convention.
doy_values = adjust_dayofyear(sst1['time'].values)
sst1['doy'] = (('time',), doy_values)

# Turn the flat time axis into a (year, doy) MultiIndex so one year can be picked with .sel(year=...)
sst1 = sst1.set_index(time=['year', 'doy'])
sst1


In [None]:
print("Detect MHWs")


def _doy_to_time(da, year):
    """Rename the 'doy' axis of one year's result to 'time' and label it with calendar dates.

    Dates are rebuilt from the doy values actually present after alignment (doy 1 = 1 January,
    the pandas day-of-year used to build sst1). A partial year, or a doy dropped when
    aligning with threshold90, therefore cannot shift or mis-size the dates.
    """
    dates = pd.Timestamp(f"{year}-01-01") + pd.to_timedelta(da['doy'].values - 1, unit='D')
    return da.rename({"doy": "time"}).assign_coords(time=dates)


# One entry per year, concatenated along time below
above_by_year = []
diff_by_year = []

for yr in np.unique(sst1['year'].values):
    print(yr)
    sst_year = sst1.sel(year=yr)
    # Leap and non-leap years need no special-casing: xarray aligns sst_year and
    # threshold90 on their shared coordinates (including 'doy') before operating.
    above_by_year.append(_doy_to_time(sst_year > threshold90, yr))
    diff_by_year.append(_doy_to_time(sst_year - threshold90, yr))

print("Concatenate the years")
mhws = xr.concat(above_by_year, dim='time')   # True where SST exceeds the threshold
diff = xr.concat(diff_by_year, dim='time')    # SST minus threshold

del above_by_year, diff_by_year

In [None]:
# Fraction of days above the threshold at each grid point
fig_frac, ax_frac = plt.subplots()
mesh_frac = ax_frac.pcolormesh(mhws.mean(dim='time'))
fig_frac.colorbar(mesh_frac, ax=ax_frac)

# SST minus threshold at time index 365
fig_diff, ax_diff = plt.subplots()
mesh_diff = ax_diff.pcolormesh(diff.isel(time=365), cmap=cmo.balance, vmin=-2, vmax=2)
fig_diff.colorbar(mesh_diff, ax=ax_diff)
plt.colorbar()



## Now apply a rolling sum over time to determine where MHWs occur
    If the sum of the daily above-threshold flags over a trailing window of x days equals x, the SST was above the threshold on each of those x consecutive days.

In [None]:
%%time

### set minimum duration of a heatwave
min_duration = 10
mhws_occurance = mhws.rolling(time=min_duration, center=False).sum() >= min_duration

plt.figure()
plt.pcolormesh(mhws_occurance.mean(dim='time'))
plt.colorbar()

plt.figure()
plt.plot(mhws.coords['time'], mhws.isel(xt_ocean=100,yt_ocean=100))
plt.plot(mhws_occurance.coords['time'], mhws_occurance.isel(xt_ocean=100,yt_ocean=100))



## create an event label data array
    all events at each point in space are numbered from 1 onwards

In [None]:
%%time
# Label continuous events in time
def label_events(da):
    shifted = da.shift(time=1, fill_value=False)
    new_event_start = da & ~shifted
    event_labels = new_event_start.cumsum(dim='time')
    return event_labels.where(da)

event_labels = label_events(mhws_occurance).compute()


In [None]:
# Event labels and threshold exceedance at the same time index, side by side
fig_lab, ax_lab = plt.subplots()
fig_lab.colorbar(ax_lab.pcolormesh(event_labels.isel(time=365)), ax=ax_lab)

fig_exc, ax_exc = plt.subplots()
fig_exc.colorbar(ax_exc.pcolormesh(diff.isel(time=365), cmap=cmo.balance, vmin=-2, vmax=2), ax=ax_exc)


In [None]:
%%time

# Find maximum number of events in the dataset
max_event_number = int(event_labels.max().values)
print("Max number of events over the grid =", max_event_number)

# Create a new Dataset with the event dimension set to 71
mhws_ds = xr.Dataset({
    'duration': xr.DataArray(
        np.nan, 
        dims=('event', 'yt_ocean', 'xt_ocean'),
        coords={
            'event': np.arange(1, max_event_number+1),  # Using 1-based indexing for events
            'yt_ocean': event_labels.coords['yt_ocean'].values,
            'xt_ocean': event_labels.coords['xt_ocean'].values
        }
    ),
    'intensity_mean': xr.DataArray(
        np.nan, 
        dims=('event', 'yt_ocean', 'xt_ocean'),
        coords={
            'event': np.arange(1, max_event_number+1),  # Using 1-based indexing for events
            'yt_ocean': event_labels.coords['yt_ocean'].values,
            'xt_ocean': event_labels.coords['xt_ocean'].values
        }
    ),
    'intensity_max': xr.DataArray(
        np.nan, 
        dims=('event', 'yt_ocean', 'xt_ocean'),
        coords={
            'event': np.arange(1, max_event_number+1),  # Using 1-based indexing for events
            'yt_ocean': event_labels.coords['yt_ocean'].values,
            'xt_ocean': event_labels.coords['xt_ocean'].values
        }
    )
})
mhws_ds


### calculate the duration of the MHWs

In [None]:
%%time

for event_id in range(1, max_event_number+1):
    if event_id%5 == 0:
        print(event_id)
    event_mask = event_labels == event_id
    event_duration = event_mask.sum(dim='time')  # Counts the days per event
    
    # Place the calculated duration into the correct position in the new dataset
    mhws_ds['duration'].loc[{'event': event_id}] = event_duration



In [None]:
# Load the lazily evaluated (dask-backed) SST-minus-threshold field into memory once, so the
# per-event loops that repeatedly call diff.where(...) do not re-evaluate the graph each time.
diff = diff.compute()


### Calculate the mean intensity of the MHWs

In [None]:
%%time

for event_id in range(1, max_event_number+1):
    if event_id%5 == 0:
        print(event_id)
    event_meanintensity = diff.where(event_labels == event_id).mean(dim='time')
    
    # Place the calculated mean intensity into the correct position in the new dataset
    mhws_ds['intensity_mean'].loc[{'event': event_id}] = event_meanintensity


### Calculate the maximum intensity of the MHWs

In [None]:
%%time

for event_id in range(1, max_event_number+1):
    if event_id%5 == 0:
        print(event_id)
    event_maxintensity = diff.where(event_labels == event_id).max(dim='time')
    
    # Place the calculated max intensity into the correct position in the new dataset
    mhws_ds['intensity_max'].loc[{'event': event_id}] = event_maxintensity


In [None]:
# Maps of the first event's summary statistics, one figure per variable
fig_dur, ax_dur = plt.subplots()
mhws_ds['duration'].isel(event=0).plot(ax=ax_dur)
fig_mean, ax_mean = plt.subplots()
mhws_ds['intensity_mean'].isel(event=0).plot(ax=ax_mean)
fig_max, ax_max = plt.subplots()
mhws_ds['intensity_max'].isel(event=0).plot(ax=ax_max)


## save to disk

In [None]:
%%time
os.chdir("/g/data/es60/pjb581/heatwaves")
os.getcwd()

print("Saving mhws_ds to disk")
mhws_ds.to_netcdf('Australian_MHWs_2050s.nc', mode='w')
