In [1]:
from datetime import datetime
import fsspec
from itertools import groupby, repeat
from multiprocessing import Pool
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os
import requests
import shutil
import warnings
import xarray as xr
import pandas as pd
from operator import itemgetter
import math

from ormhw.core import CE04,CE02,CE01, BBox, DATA_DIR, CURATED_DIR
from ormhw.oisstv2 import OISSTV2

## OISSTV2 Helper Functions

In [2]:
def group_events(subds, diff = '2D', min_duration = 5):
    """Group a single cell's threshold exceedances into discrete events.

    Exceedance days are the time steps where ``subds.gte == 1``. Days with
    no exceedance day exactly one day before or after are discarded. The
    remaining days are split into runs of consecutive days. A run with fewer
    than ``min_duration`` rows is classed as ``'spike'``, otherwise as
    ``'mhw'``. MHW runs are then merged whenever a run starts no more than
    ``diff`` after the previous run ends. With the default ``'2D'``, at most
    one missing day may sit between merged runs. Merging is repeated until a
    pass merges nothing.

    Parameters
    ----------
    subds : xarray.Dataset
        Single-cell time series with a ``gte`` variable and a ``depth``
        variable/coordinate, which is dropped before conversion.
        NOTE(review): assumed sorted by time with daily steps — confirm.
    diff : str or pandas.Timedelta, default '2D'
        Largest (start - previous end) offset at which two MHW runs are merged.
    min_duration : int, default 5
        Runs with fewer rows than this are reported as spikes.

    Returns
    -------
    mhws : pandas.DataFrame
        Columns ``Ts``, ``Te`` (first/last day), ``type`` (``'mhw'``) and
        ``D`` (inclusive duration in days).
    spikes : pandas.DataFrame
        Columns ``Ts``, ``Te``, ``type`` (``'spike'``) and ``D``.

    Raises
    ------
    ValueError
        If the cell has no run of two or more consecutive exceedance days.
        ``mhw_cell_stats`` relies on an exception here to mark the cell NaN.
    """
    one_day = pd.Timedelta('1d')
    max_gap = pd.Timedelta(diff)

    hot = subds.where(subds.gte == 1, drop = True).drop_vars('depth')
    df = hot.to_dataframe().reset_index()

    # Flag exceedance days that have an exceedance neighbour one day before
    # or after, then number the runs of consecutive flagged days.
    times = df.time
    in_block = ((times - times.shift(-1)).abs() == one_day) | (times.diff() == one_day)
    block_times = df.loc[in_block, 'time']
    # Unflagged rows get NaN via index alignment and are removed by dropna.
    df['group'] = (block_times.diff() != one_day).cumsum()
    df = df.dropna()
    if df.empty:
        raise ValueError('group_events: no run of two or more consecutive exceedance days')

    # One row per run: first day, last day and number of rows in the run.
    runs = (df.groupby('group')['time']
              .agg(Ts = 'min', Te = 'max', n = 'size')
              .reset_index())
    runs['type'] = np.where(runs['n'] < min_duration, 'spike', 'mhw')
    events = runs[['Ts', 'Te', 'group', 'type']]

    spikes = events[events.type == 'spike'].drop('group', axis = 1)
    spikes['D'] = (spikes.Te - spikes.Ts).dt.days + 1

    mhws = events[events.type == 'mhw'].reset_index(drop = True)

    # Merge MHW runs whose start is within max_gap of the previous row's end.
    # Repeat until a pass leaves the number of events unchanged.
    merged = mhws
    while True:
        starts_new = (merged.Ts - merged.Te.shift(1)) > max_gap
        regrouped = merged.groupby(starts_new.cumsum()).agg({'Ts': 'min', 'Te': 'max'})
        done = len(regrouped) == len(merged)
        merged = regrouped
        if done:
            break

    merged['type'] = 'mhw'
    merged['D'] = (merged.Te - merged.Ts).dt.days + 1
    return merged, spikes
    

## Curate OISSTV2 Data

In [6]:
%%time
oisstv2 = OISSTV2()
local_files = oisstv2.find_local_files()

ds = oisstv2.import_files(local_files, bounding_box = BBox)
ds.to_netcdf(f"{CURATED_DIR}/sst.nc")

dsc = oisstv2.build_climatology(ds, window = 11)
dsc.to_netcdf(f"{CURATED_DIR}/sst_clim_w11.nc")

dsp = oisstv2.build_percentile(ds, window = 11)
dsp.to_netcdf(f"{CURATED_DIR}/sst_90th_w11.nc")

dsgte = oisstv2.mask_gte(ds,dsp,years = range(2014, 2024))
dsgte.to_netcdf(f"{CURATED_DIR}/sst_mhw_mask.nc")

CPU times: user 5min 22s, sys: 23.4 s, total: 5min 46s
Wall time: 6min 42s


In [7]:
# Export MHW and spike event tables for each named site. The site
# location objects (CE01, CE02, CE04) come from ormhw.core, and the
# nearest mask grid cell is used for each one.
sites = ['CE01','CE02','CE04']
locs = [CE01, CE02, CE04]
for site, loc in zip(sites, locs):
    subds = dsgte.sel(latitude = loc.lat, longitude = loc.lon, method = 'nearest')
    mhws, spikes = group_events(subds)
    tag = site.lower()
    mhws.to_csv(os.path.join(CURATED_DIR, f'mhws_{tag}.csv'))
    spikes.to_csv(os.path.join(CURATED_DIR, f'spikes_{tag}.csv'))

In [8]:

def mhw_cell_stats(subds):
    """Per-year summer marine-heatwave statistics for a single grid cell.

    For each year 2015-2023, the window from 9 July 00:00 to
    8 October 23:59:59 is compared against the MHW events returned by
    ``group_events(subds)``. The following are stored along a ``year`` dim:

    * ``mhw_days``   - window days that are a daily timestamp of any event
    * ``total_days`` - number of daily timestamps in the window
    * ``num_events`` - number of MHW events overlapping the window
    * ``ratio``      - ``mhw_days / total_days``

    Parameters
    ----------
    subds : xarray.Dataset
        Single-cell slice of the mask, carrying ``latitude``/``longitude``
        coordinates and whatever ``group_events`` requires.

    Returns
    -------
    xarray.Dataset
        Per-year statistics. ``latitude``/``longitude`` are set to
        length-1 coordinates so cells can be merged with
        ``xr.combine_by_coords``. If any step raises an ``Exception``,
        every statistic is NaN for this cell.
    """
    years = range(2015, 2024)
    agg_ds = xr.Dataset()
    agg_ds = agg_ds.assign_coords({'latitude': subds.latitude, 'longitude': subds.longitude,'year': years})
    try:
        mhws, _spikes = group_events(subds)
        mhw_days_list = []
        num_events_list = []
        total_days_list = []
        ratio_list = []
        for year in years:
            bdt = datetime(year, 7, 9)
            edt = datetime(year, 10, 8, 23, 59, 59)
            # MHW events overlapping the window, even partially.
            df = mhws[(mhws.Te >= bdt) & (mhws.Ts <= edt)].reset_index(drop = True)

            # Daily timestamps starting at bdt (midnight).
            window = pd.date_range(bdt, edt)
            # A window day is an MHW day if it exactly equals one of the daily
            # timestamps pd.date_range(Ts, Te) of an overlapping event.
            # NOTE(review): window days are midnight-aligned; if the dataset's
            # time axis carries a time-of-day, nothing matches and mhw_days is
            # 0 — confirm the time axis is at 00:00.
            covered = np.zeros(len(window), dtype = bool)
            for ts, te in zip(df.Ts, df.Te):
                covered |= window.isin(pd.date_range(ts, te))

            total_days = len(window)
            mhw_days = int(covered.sum())

            mhw_days_list.append(mhw_days)
            num_events_list.append(len(df))
            total_days_list.append(total_days)
            ratio_list.append(mhw_days / total_days)
    except Exception:
        # Best effort: any failure marks the whole cell as missing. Catching
        # Exception (not a bare except) lets KeyboardInterrupt stop the long
        # spatial loop instead of silently NaN-ing one cell.
        # NOTE(review): group_events raises when the cell has no run of two or
        # more consecutive exceedance days, so event-free cells are reported
        # as NaN here rather than 0 — confirm intended.
        mhw_days_list = [np.nan] * len(years)
        num_events_list = [np.nan] * len(years)
        total_days_list = [np.nan] * len(years)
        ratio_list = [np.nan] * len(years)
    agg_ds['mhw_days'] = (['year'], mhw_days_list)
    agg_ds['total_days'] = (['year'], total_days_list)
    agg_ds['num_events'] = (['year'], num_events_list)
    agg_ds['ratio'] = (['year'], ratio_list)
    # Wrap the scalar coordinates in lists to make them length-1 coordinates.
    agg_ds = agg_ds.assign_coords({'latitude': [agg_ds.latitude], 'longitude': [agg_ds.longitude]})
    return agg_ds

In [9]:
def _open_curated(name):
    """Open the NetCDF file ``name`` from CURATED_DIR with xarray."""
    return xr.open_dataset(os.path.join(CURATED_DIR, name))


# Reload the products written by the curation cell above (same file names).
ds = _open_curated('sst.nc')
dsc = _open_curated('sst_clim_w11.nc')
dsp = _open_curated('sst_90th_w11.nc')
dsgte = _open_curated('sst_mhw_mask.nc')

In [10]:
# Coordinate values of the mask grid (dsgte), iterated over by the spatial loop below.
latitudes, longitudes = (dsgte[axis].values for axis in ('latitude', 'longitude'))

In [11]:
%%time
ds_list = []
for latitude in latitudes:
    for longitude in longitudes:
        _ds = mhw_cell_stats(dsgte.sel(latitude = latitude, longitude = longitude, method ='nearest'))
        _ds = _ds.assign_coords({'latitude':_ds.latitude,'longitude':_ds.longitude})
        ds_list.append(_ds)
mhw_spatial = xr.combine_by_coords(ds_list)
mhw_spatial.to_netcdf(os.path.join(CURATED_DIR, 'mhw_spatial_stats.nc'))

CPU times: user 33min 52s, sys: 1min 43s, total: 35min 36s
Wall time: 35min 33s
