## Developing an open-source pipeline to characterize the spatio-temporal evolution of marine heatwaves using the output of multiple object tracking

### Analysis Workflow

#### Introduction

In this Jupyter notebook, we summarize marine heatwave (MHW) features. Our planned workflow is: (1) load the dataset; (2) apply Ocetrac to the dataset (the radius and minimum size quartile can be changed for each simulation, which generates several different sets of MHWs; we will later determine which set is most useful); (3) define and summarize features for the set of marine heatwave events returned by Ocetrac; (4) calculate feature distances; and (5) cluster the events, with (5a) visual analysis and (5b) numerical analysis (stability).

#### Note that in this notebook, we work with a smaller dataset

- Have applied Ocetrac to all ensemble members
- Collected marine heatwave events that have an imprint in the northeast Pacific Ocean, last longer than 2 months, and cover at least 25% of our predefined region.
- We now work with a dataset of 1131 heatwaves.

However, the script to load the entire dataset is still included in this notebook; it is commented out.

In [1]:
##### LOADING IN PACKAGES #--------------------------------------------------------------
import datetime as dt
import pprint
import warnings

import cartopy.crs as ccrs
import dask.array as da
import expectexception
import intake
import matplotlib.pyplot as plt
import netCDF4 as nc
import numpy as np
import ocetrac
import pandas as pd
import s3fs
import scipy
import xarray as xr

warnings.filterwarnings('ignore')

# Allow multiple lines per cell to be displayed without print (default is just last line)
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
# Enable more explicit control of DataFrame display (e.g., to omit annoying line numbers)
from IPython.display import HTML

### (1) Load and save the dataset

In [5]:
## Open original collection description file #----------------------------------------------
# cat_url_orig = '/glade/collections/cmip/catalog/intake-esm-datastore/catalogs/glade-cesm2-le.json'
# coll_orig = intake.open_esm_datastore(cat_url_orig)

In [6]:
# subset = coll_orig.search(component='atm',variable='SST',frequency='month_1',experiment='historical')
# member_id_list = subset.df.member_id.unique()
# print(member_id_list)

In [7]:
## for i in range(50,len(member_id_list)):
# for i in range(0,50):
# for i in range(5): # trying to see if this works
#     subset = coll_orig.search(component='atm',variable='SST',frequency='month_1',experiment='historical',member_id= str(member_id_list[i]))
#     dsets = subset.to_dataset_dict(zarr_kwargs={"consolidated": True}, storage_options={"anon": True})
#     ds = dsets['atm.historical.cam.h0.cmip6.SST'] # before 50
#     # ds = dsets['atm.historical.cam.h0.smbb.SST'] # after 50 # Ask Liz
#     SST = ds.SST.isel(member_id=0)
#     SST.load()
    
#     ###### DETRENDING 
#     dyr = SST.time.dt.year + (SST.time.dt.month-0.5)/12
#     # Our 6 coefficient model is composed of the mean, trend, annual sine and cosine harmonics, & semi-annual sine and cosine harmonics
#     model = np.array([np.ones(len(dyr))] + [dyr-np.mean(dyr)] + [np.sin(2*np.pi*dyr)] + [np.cos(2*np.pi*dyr)] + [np.sin(4*np.pi*dyr)] + [np.cos(4*np.pi*dyr)])
#     # Take the pseudo-inverse of model to 'solve' least-squares problem
#     pmodel = np.linalg.pinv(model)
#     model_da = xr.DataArray(model.T, dims=['time','coeff'], coords={'time':SST.time.values[-481:], 'coeff':np.arange(1,7,1)}) 
#     pmodel_da = xr.DataArray(pmodel.T, dims=['coeff','time'], coords={'coeff':np.arange(1,7,1), 'time':SST.time.values[-481:]})
#     # resulting coefficients of the model
#     sst_mod = xr.DataArray(pmodel_da.dot(SST), dims=['coeff','lat','lon'], coords={'coeff':np.arange(1,7,1), 'lat':SST.lat.values, 'lon':SST.lon.values})
#     # Construct mean, trend, and seasonal cycle
#     mean = model_da[:,0].dot(sst_mod[0,:,:])
#     trend = model_da[:,1].dot(sst_mod[1,:,:])
#     seas = model_da[:,2:].dot(sst_mod[2:,:,:])
#     # Compute anomalies by removing the full fitted model (mean, trend, and seasonal cycle)
#     ssta_notrend = SST - model_da.dot(sst_mod)  # these are the anomalies
#     detrended = ssta_notrend
#     detrended.to_netcdf('/glade/work/cassiacai/'+str(member_id_list[i])+'_detrended.nc')
    
#     ###### THRESHOLD and FEATURES
#     if detrended.chunks:
#         detrended = detrended.chunk({'time': -1})
#     threshold = detrended.groupby('time.month').quantile(0.9,dim=('time')) 
#     # can change 0.9 to another value depending on threshold
#     features_ssta = detrended.where(detrended.groupby('time.month')>=threshold, other=np.nan)
#     features_ssta = features_ssta[:,:,:].load()
#     ##### MASKING
#     full_mask_land = features_ssta
#     full_masked = full_mask_land.where(full_mask_land != 0)
#     binary_out_afterlandmask=np.isfinite(full_masked)
    
#     newmask = np.isfinite(ds.SST[0,:,:,:][:])
    
#     Tracker = ocetrac.Tracker(binary_out_afterlandmask[:,:,:], newmask, radius=3, min_size_quartile=0.75, timedim = 'time', xdim = 'lon', ydim='lat', positive=True)
#     # we define the minimum radius above as well as the minimum size quartile
#     blobs = Tracker.track()
#     blobs.attrs
#     mo = Tracker._morphological_operations()
#     blobs.to_netcdf('/glade/work/cassiacai/'+str(member_id_list[i])+'_rad3_blobs.nc')
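The detrending and thresholding steps in the commented-out loop above can be tried on synthetic data without the CESM-LE files. The sketch below is a minimal NumPy/xarray version under stated assumptions: the data are hypothetical (40 years of monthly values on a 2 x 3 grid), `harmonic_anomalies` is a hypothetical helper, and `time` is assumed to be the leading dimension. It fits the same six-coefficient model (mean, trend, annual and semi-annual sine/cosine harmonics) with a pseudo-inverse, then flags cell-months above the month-specific 90th percentile. Broadcasting the threshold with `.sel(month=...)` is an explicit alternative to the `groupby` comparison used in the loop, not the notebook's own code.

```python
import numpy as np
import pandas as pd
import xarray as xr


def harmonic_anomalies(da, dyr):
    """Remove mean, linear trend, annual and semi-annual harmonics by least squares.

    Hypothetical helper: assumes `time` is the first dimension of `da` and
    `dyr` holds the decimal year of each time step.
    """
    X = np.column_stack([
        np.ones_like(dyr), dyr - dyr.mean(),                # mean and trend
        np.sin(2 * np.pi * dyr), np.cos(2 * np.pi * dyr),   # annual harmonic
        np.sin(4 * np.pi * dyr), np.cos(4 * np.pi * dyr),   # semi-annual harmonic
    ])                                                      # shape (time, 6)
    y = da.values.reshape(da.sizes["time"], -1)             # shape (time, space)
    coeffs = np.linalg.pinv(X) @ y                          # shape (6, space)
    fitted = (X @ coeffs).reshape(da.shape)
    return da - fitted                                      # anomalies, same coords as `da`


# Hypothetical synthetic SST: 40 years of monthly data on a 2 x 3 grid
time = pd.date_range("1975-01-01", periods=480, freq="MS")
dyr = np.asarray(time.year + (time.month - 0.5) / 12, dtype=float)  # mid-month decimal year
rng = np.random.default_rng(0)
base = 15 + 0.02 * (dyr - dyr.mean()) + 2 * np.sin(2 * np.pi * dyr)
sst = xr.DataArray(
    base[:, None, None] + rng.normal(0.0, 0.3, size=(480, 2, 3)),
    dims=("time", "lat", "lon"),
    coords={"time": time, "lat": [40.0, 41.0], "lon": [200.0, 201.0, 202.0]},
)

ssta = harmonic_anomalies(sst, dyr)

# Month-specific 90th percentile, then broadcast back onto every time step
threshold = ssta.groupby("time.month").quantile(0.9, dim="time")  # (month, lat, lon)
threshold_t = threshold.sel(month=ssta.time.dt.month)             # (time, lat, lon)
features = ssta.where(ssta >= threshold_t)                        # NaN below threshold
binary = np.isfinite(features)                                    # True = exceedance
```

Because the regression includes an intercept, the anomalies average to zero at every grid cell, and with 40 values per calendar month about 10% of cell-months end up flagged.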

In [8]:
# Open original collection description file #----------------------------------------------
# cat_url_orig = '/glade/collections/cmip/catalog/intake-esm-datastore/catalogs/glade-cesm2-le.json'
# coll_orig = intake.open_esm_datastore(cat_url_orig)

# subset = coll_orig.search(component='atm',variable='SST',frequency='month_1',experiment='historical')
# filenamechange = list(subset.df.member_id.unique())

# list_of_xarrays = []
# list_of_xarrays_SSTA = []

# for i in filenamechange:
    
#     string_head = '/glade/work/cassiacai/' + str(i) + '_rad3_blobs.nc'    
#     xarray_file = xr.open_dataset(str(string_head))
#     list_of_xarrays.append(xarray_file)
    
#     string_head_SSTA = '/glade/work/cassiacai/' + str(i) + '_detrended.nc'
#     xarray_file_SSTA = xr.open_dataset(str(string_head_SSTA))
#     list_of_xarrays_SSTA.append(xarray_file_SSTA)

# concated_xarray = xr.concat(list_of_xarrays, "new_dim")
# concated_xarray_SSTA = xr.concat(list_of_xarrays_SSTA, "new_dim")

# combined_xarray = xr.combine_by_coords([concated_xarray, concated_xarray_SSTA])
# combined_xarray = combined_xarray.rename({'__xarray_dataarray_variable__': 'SSTA'})  # name the detrended anomalies

# for i in range(100):
#     member_ = combined_xarray.isel(new_dim = i)
#     member_.to_netcdf('/glade/work/cassiacai/member'+str(i)+'_events.nc')
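The `__xarray_dataarray_variable__` name handled in the cell above is the placeholder xarray uses when an unnamed DataArray is written with `to_netcdf`; the detrended anomalies are presumably unnamed because `SST - model_da.dot(sst_mod)` combines arrays with different names. A minimal sketch, assuming small hypothetical members built in memory, of naming the anomalies up front (in the loop this would be something like `detrended.rename('SSTA').to_netcdf(...)`) and stacking members along `new_dim`. `make_member` is hypothetical, and merging each member before concatenating is a swapped-in alternative to the separate `concat` plus `combine_by_coords` steps above.

```python
import numpy as np
import xarray as xr


def make_member(seed):
    """Hypothetical ensemble member: Ocetrac-style `labels` plus unnamed anomalies."""
    rng = np.random.default_rng(seed)
    coords = {"time": np.arange(4), "lat": [40.0, 41.0], "lon": [200.0, 201.0]}
    labels = xr.DataArray(rng.integers(0, 3, size=(4, 2, 2)).astype(float),
                          dims=("time", "lat", "lon"), coords=coords, name="labels")
    ssta = xr.DataArray(rng.normal(size=(4, 2, 2)),
                        dims=("time", "lat", "lon"), coords=coords)  # no name yet
    # Naming the anomalies here avoids the `__xarray_dataarray_variable__` placeholder
    return xr.merge([labels, ssta.rename("SSTA")])


members = [make_member(seed) for seed in range(2)]
combined = xr.concat(members, dim="new_dim")  # stack members along a new dimension
member_0 = combined.isel(new_dim=0)           # one member, as written to each file above
```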

##### From here, we begin using a smaller dataset.
We look at marine heatwave events from the last 40 years of the CESM-LE simulations, identified by running Ocetrac on all 100 CESM-LE simulations with a radius of 3. For now, we also work with only one ensemble member.

In [27]:
%%time

kept_heatwaves_all = []
N_REGION_CELLS = 42 * 53    # grid cells in the pre-defined region
MIN_DURATION_MONTHS = 2     # keep events lasting longer than 2 months at some grid cell
MIN_COVERAGE_PCT = 25       # keep events covering at least 25% of the region

for ens_memb in range(0, 1):  # only the first ensemble member for now
    filename = '/glade/work/cassiacai/member{}_events.nc'.format(ens_memb)
    member_ = xr.open_dataset(filename)

    # Unique event IDs, dropping the NaN/0 background
    event_ids = np.unique(member_.labels)
    event_ids = event_ids[np.isfinite(event_ids) & (event_ids > 0)]

    for mhw_id in event_ids:
        in_event = member_.labels == mhw_id                  # boolean (time, lat, lon)
        months_per_cell = in_event.sum(dim='time').values    # months each cell belongs to the event
        event_len = months_per_cell.max()

        if event_len > MIN_DURATION_MONTHS:
            # Percentage of the region's cells touched by the event at any time
            coverage_pct = 100 * np.count_nonzero(months_per_cell) / N_REGION_CELLS
            if coverage_pct >= MIN_COVERAGE_PCT:
                kept_heatwaves_all.append(member_.where(in_event, drop=False))

CPU times: user 2min 3s, sys: 1min 42s, total: 3min 46s
Wall time: 3min 53s
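The duration and coverage criteria in the cell above can be checked in isolation with plain NumPy. A minimal sketch on a tiny hypothetical label array (4 months on a 2 x 2 grid, so the "region" is 4 cells instead of 42 x 53); `select_events` is a hypothetical helper, not part of Ocetrac, and it assumes 0 or NaN marks "no event".

```python
import numpy as np


def select_events(labels, n_region_cells, min_months=2, min_coverage_pct=25.0):
    """Return event IDs lasting > min_months and covering >= min_coverage_pct of the region.

    `labels` is a (time, lat, lon) array of event IDs; 0 or NaN marks no event.
    """
    ids = np.unique(labels[np.isfinite(labels)])
    ids = ids[ids > 0]
    kept = []
    for event_id in ids:
        months_per_cell = (labels == event_id).sum(axis=0)  # months each cell is in the event
        duration = months_per_cell.max()                     # longest stay at any single cell
        coverage_pct = 100 * np.count_nonzero(months_per_cell) / n_region_cells
        if duration > min_months and coverage_pct >= min_coverage_pct:
            kept.append(int(event_id))
    return kept


labels = np.zeros((4, 2, 2))
labels[:, 0, 0] = 1        # event 1: 4 months at one cell, plus a second cell once
labels[0, 0, 1] = 1
labels[1:3, 1, 1] = 2      # event 2: only 2 months, so it is too short
labels[0:3, 1, 0] = 3      # event 3: 3 months at exactly 25% of the region
labels[3, 1, 1] = np.nan   # NaN background, as in a masked dataset

print(select_events(labels, n_region_cells=labels.shape[1] * labels.shape[2]))  # → [1, 3]
```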


In [None]:
concated_kept_heatwaves_xarray = xr.concat(kept_heatwaves_all, "heatwave_label")

In [None]:
concated_kept_heatwaves_xarray