In [1]:
# Ensemble member number; it is inserted into input/output file names as 'r<ens>i1p1f1'.
ens = 30

import numpy as np
import xarray as xr
import math
import glob
import time
import numexpr as ne
import sys
np.set_printoptions(threshold=sys.maxsize)
from netCDF4 import Dataset
from datetime import datetime
from itertools import groupby

import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from matplotlib.axes import Axes
from cartopy.mpl.geoaxes import GeoAxes
GeoAxes._pcolormesh_patched = Axes.pcolormesh

##########################################################

def load_data(ens, scenario):
    """Load daily ``tasmax`` for one ensemble member, subset to a lat/lon box.

    Files are located with ``glob`` under ``/Volumes/Data/SPEAR/<forcing>/day/tasmax``
    and concatenated along ``time`` in sorted file-name order.

    Parameters
    ----------
    ens : int or str
        Ensemble member number, inserted into the file name as ``r<ens>i1p1f1``.
    scenario : str
        ``'hist'`` selects the ``historical`` forcing; any other value selects
        ``scenarioSSP5-85``.

    Returns
    -------
    xarray.DataArray
        ``tasmax`` restricted to ``lat=slice(20, 50)``, ``lon=slice(230, 300)``.

    Raises
    ------
    FileNotFoundError
        If no file is left after globbing and slicing the file list.
    """
    forcing = 'historical' if scenario == 'hist' else 'scenarioSSP5-85'

    path  = '/Volumes/Data/SPEAR/'+str(forcing)+'/day/tasmax/tasmax_day_GFDL-SPEAR-MED_'+str(forcing)+'_r'+str(ens)+'i1p1f1'
    flist = sorted(glob.glob(path+'*31.nc'))

    if scenario == 'hist':
        flist = flist[-4:-1]  # 1981-2010
    else:
        flist = flist[:4]  # 2015-2050

    # Fail early with a clear message instead of an UnboundLocalError on `ds`.
    if not flist:
        raise FileNotFoundError('No tasmax files selected for pattern ' + path + '*31.nc')

    ds = None
    for n, fname in enumerate(flist):
        print(fname)
        start = time.time()
        new = xr.open_dataset(fname)['tasmax']
        new = new.sel(lat=slice(20, 50), lon=slice(230, 300))
        # The first file seeds the result; later files are appended along time.
        ds = new if ds is None else xr.concat([ds, new], dim='time')
        print('flist', n+1, 'computing', round((time.time() - start)/60.,2), 'min')

    return ds

##########################################################

def _percentile_by_dayofyear(ds, pct):
    """Return the ``pct``-th percentile over time for each ``time.dayofyear`` group."""
    thre = ds.groupby('time.dayofyear').reduce(np.percentile, dim='time', q=pct)

    if len(thre.dayofyear) == 366:
        # Drop the group at position 59 and relabel the remaining groups 1..365.
        # NOTE(review): on a calendar with leap years the day-of-year labels after
        # February presumably differ by one between leap and non-leap years, so
        # groups may mix adjacent calendar dates - confirm against the input calendar.
        thre = thre.drop_isel(dayofyear=[59])
        thre['dayofyear'] = np.arange(1, 365+1)
        print('day of year=', len(thre.dayofyear))

    return thre


def calc_percentile(ds, pct=95, moving_yr=None, attrs_period='(1961-2080)', fixed=True, removing_leap=True):
    """Compute a per-calendar-day percentile threshold from a daily series.

    Parameters
    ----------
    ds : xarray.DataArray
        Daily data whose first dimension is ``time``.
    pct : float
        Percentile passed to ``np.percentile`` as ``q``.
    moving_yr : int, optional
        Window length in years. Required (>= 1) when ``fixed`` is false.
    attrs_period : str
        Text appended to the ``calculation`` attribute in the moving-window case.
    fixed : bool
        True: one threshold per day of year from the whole input.
        False: one threshold per day of year for each ``moving_yr``-year window,
        stacked along a new ``year`` dimension.
    removing_leap : bool
        If true, ``remove_leap`` is applied to ``ds`` first.

    Returns
    -------
    xarray.DataArray
        Threshold with a ``dayofyear`` dimension (plus ``year`` when not fixed).

    Raises
    ------
    ValueError
        If ``fixed`` is false and ``moving_yr`` is missing/non-positive, or the
        series is too short to hold a single window.
    """
    # Validate before any expensive work so a bad call fails immediately.
    if not fixed and (moving_yr is None or moving_yr < 1):
        raise ValueError('moving_yr must be a positive integer when fixed=False')

    if removing_leap:
        ds = remove_leap(ds)

    if fixed:
        thre = _percentile_by_dayofyear(ds, pct)
        # Report the percentile actually used instead of a hard-coded "95th".
        thre.attrs['calculation'] = 'The '+str(pct)+'th percentile of each calendar day during the reference period (1981-2010)'
        return thre

    n_windows = len(ds)//365 - moving_yr
    if n_windows < 1:
        raise ValueError('time series is too short for a '+str(moving_yr)+'-year moving window')

    start = time.time()
    yearly = []
    for i in range(n_windows):
        # Positional slice of `moving_yr` consecutive 365-day years.
        ds_win = ds[i*365:(i+moving_yr)*365]
        print(ds_win.time[0].dt.year.data,'-',ds_win.time[-1].dt.year.data)
        yearly.append(_percentile_by_dayofyear(ds_win, pct))

    # One concat at the end avoids re-copying the accumulated result every pass.
    comb = xr.concat(yearly, dim='year')
    print(round((time.time() - start)/60.,2), 'min')

    # Each window is labelled with the year following its last year.
    first_label = ds.time[0].dt.year.data + moving_yr
    last_label = ds_win.time[-1].dt.year.data + 1
    ds_thre = comb.assign_coords({'year': np.arange(first_label, last_label+1)})
    ds_thre.attrs['calculation'] = 'The '+str(pct)+'th percentile of each calendar day in a year after a '+str(moving_yr)+'-year moving window during the reference period '+str(attrs_period)

    return ds_thre


def remove_leap(ds):
    """Return ``ds`` without its 29 February time steps and print the remaining span.

    The printed year count is the number of remaining time steps divided by 365
    (integer division), followed by the first and last year on the time axis.
    """
    months = ds.time.dt.month
    days = ds.time.dt.day
    is_feb29 = (months == 2) & (days == 29)

    trimmed = ds.sel(time=~is_feb29)

    n_years = len(trimmed.time)//365
    first_year = trimmed.time[0].dt.year.data
    last_year = trimmed.time[-1].dt.year.data
    print('Total length of year=', n_years, '(',
          first_year,'-',last_year,')')

    return trimmed


def pct_reshape(ds, reference_ds=None, fixed=True, rolling_t=31, save=False, title=None):
    """Smooth a threshold with a centred rolling mean and put it on ``reference_ds``'s time axis.

    Parameters
    ----------
    ds : xarray.DataArray
        Threshold. With ``fixed`` true it is indexed by ``dayofyear`` first;
        otherwise it must have ``year`` and ``dayofyear`` dimensions.
    reference_ds : xarray.DataArray
        Supplies the ``time``, ``lat`` and ``lon`` coordinates of the result.
    fixed : bool
        True: smooth the day-of-year cycle with wrap-around padding, then repeat
        it ``len(reference_ds)//365`` times.
        False: stack ``(year, dayofyear)`` into one time axis and smooth that.
    rolling_t : int
        Rolling-window length in time steps.
    save : bool
        If true, write the smoothed threshold to ``title`` with ``to_netcdf``.
    title : str, optional
        Output path used when ``save`` is true.

    Returns
    -------
    xarray.DataArray
        Smoothed threshold with dimensions ``(time, lat, lon)``.
    """
    if fixed:
        # Pad by half the window (was a hard-coded 15, only right for
        # rolling_t=31) so the centred mean wraps around the year boundary.
        half = rolling_t // 2
        if half > 0:
            extended_data = xr.concat([ds[-half:], ds, ds[:half]], dim='dayofyear')
        else:
            # A one-step window needs no padding (ds[-0:] would duplicate ds).
            extended_data = ds
        extended_data = extended_data.rename({'dayofyear': 'time'})

        rolling_mean = extended_data.rolling(time=rolling_t, center=True).mean()

        # Keep only the values that belong to the original, unpadded range.
        smoothed_data = rolling_mean[half:-half] if half > 0 else rolling_mean

        if save:
            print('pct saving at',title)
            smoothed_data.to_netcdf(title)

        # Repeat the smoothed annual cycle once per 365-day year of the reference.
        n_years = len(reference_ds)//365
        extended_pct_reshaped = np.tile(smoothed_data, (n_years, 1, 1))

        # Re-wrap the tiled array with the reference coordinates.
        extended_pct_reshaped_da = xr.DataArray(extended_pct_reshaped,
                                                dims=["time", "lat", "lon"],
                                                coords={"time": reference_ds.coords["time"],
                                                        "lat": reference_ds.coords["lat"],
                                                        "lon": reference_ds.coords["lon"]})

        return extended_pct_reshaped_da

    # Moving threshold: flatten (year, dayofyear) into a single time axis.
    pct_reshaped = ds.stack(point=["year", "dayofyear"])
    pct_reshaped = pct_reshaped.rename({'point': 'time'}).transpose("time", "lat", "lon")
    pct_reshaped['time'] = reference_ds.time

    rolling_mean = pct_reshaped.rolling(time=rolling_t, center=True).mean()

    if save:
        print('pct saving at',title)
        rolling_mean.to_netcdf(title)

    return rolling_mean


##########################################################

def calc_spell(series, min_duration=5):
    """Return the length of each qualifying hot spell at the spell's first time step.

    ``series`` is a 1-D time series for a single grid point that holds the
    temperature on days above the threshold and 0 otherwise. The steps are:

    1) mark every strictly positive value as a hot day (NaN is not a hot day);
    2) measure each run of consecutive hot days;
    3) keep only runs of at least ``min_duration`` days.

    Parameters
    ----------
    series : array-like
        1-D sequence; anything ``np.asarray`` accepts.
    min_duration : int
        Minimum run length that counts as a heatwave. Shorter runs are set to 0.

    Returns
    -------
    numpy.ndarray
        Integer array of the same length as ``series``. The first day of each
        qualifying run holds the run length; every other entry is 0.
    """
    values = np.asarray(series)

    # `> 0` is False for NaN, so missing data never starts or extends a spell.
    hot = values > 0

    spell_hw = np.zeros((hot.shape[0],), dtype=int)
    if hot.shape[0] == 0:
        return spell_hw

    # Pad with non-hot days so runs touching either end still produce
    # both a rising (+1) and a falling (-1) edge.
    padded = np.concatenate(([0], hot.astype(np.int8), [0]))
    edges = np.diff(padded)
    run_starts = np.flatnonzero(edges == 1)
    run_ends = np.flatnonzero(edges == -1)

    spell_hw[run_starts] = run_ends - run_starts
    spell_hw[spell_hw < min_duration] = 0

    return spell_hw


def calc_hw_index(count, ens, period, save=True):
    """Compute heatwave duration, magnitude and amplitude for every grid point.

    Parameters
    ----------
    count : xarray.DataArray
        Array indexed as ``(time, lat, lon)``. Each grid-point series is passed
        to ``calc_spell`` and its values are used as the event temperatures.
    ens : int or str
        Ensemble member number, used in log output and output file names.
    period : str
        Period label inserted into the output file names.
    save : bool
        If true, write the three results to NetCDF files.

    Returns
    -------
    tuple of xarray.DataArray
        ``(HWD, HWM, HWA)``:
        HWD holds the ``calc_spell`` output;
        HWM holds the event-mean value at each event's first day, NaN elsewhere;
        HWA holds the event-maximum value at the day it occurs, NaN elsewhere.
    """
    print('ensemble',ens)

    syear  = count.time.dt.year[0].values
    nyears = count.shape[0]//365
    ndays  = count.shape[0]
    nlat   = count.shape[-2]
    nlon   = count.shape[-1]
    print('syear, nyears, ndays, nlat, nlon =',syear,',',nyears,',',ndays,',',nlat,',',nlon)

    spell_all     = np.full(count.shape, np.nan, dtype=float)
    heatwave_avg  = np.full(count.shape, np.nan, dtype=float)
    heatwave_peak = np.full(count.shape, np.nan, dtype=float)

    stime = time.time()
    for ilat in range(nlat):
        subtime = time.time()
        print('******** {0:02d}th latitude ********'.format(ilat+1))
        for ilon in range(nlon):
            start = time.time()
            cell = count[:,ilat,ilon]
            spell_hw = calc_spell(cell)

            # Pull the numpy values once; slicing the xarray object for every
            # event is far slower and gives the same numbers.
            temps = cell.values

            # Indices where a heatwave event is recorded (non-zero spell length).
            for first in np.flatnonzero(spell_hw):
                event = temps[first:first+spell_hw[first]]
                # Peak goes on the day of the event maximum, mean on the first day.
                heatwave_peak[first+np.argmax(event),ilat,ilon] = np.nanmax(event)
                heatwave_avg[first,ilat,ilon]                   = np.nanmean(event)

            spell_all[:,ilat,ilon] = spell_hw
            print('{0:03d}th longitude'.format(ilon+1), round((time.time() - start),2), 'sec')

        print('Subtotal computing time:', round((time.time() - subtime)/60.,2), 'min')

    print('==== Total computing time:', round((time.time() - stime)/3600.,2), 'hours ====')

    spell_all_xr = convert_xr(arr=spell_all, var_name='HWD', coords_base=count,
                              long_name='Heatwave duration',unit='days',
                              method='Consecutive days where daily maximum temperature exceeds threshold')
    heatwave_avg_xr = convert_xr(arr=heatwave_avg, var_name='HWM', coords_base=count,
                                 long_name='Heatwave magnitude',unit='K',
                                 method='Average daily maximum temperature during each heatwave event')
    heatwave_peak_xr = convert_xr(arr=heatwave_peak, var_name='HWA', coords_base=count,
                                  long_name='Heatwave amplitude',unit='K',
                                  method='Highest daily maximum temperature during each heatwave event')

    if save:
        path = '/Volumes/Data/SPEAR/scenarioSSP5-85/day/tasmax/tasmax_day_GFDL-SPEAR-MED_scenarioSSP5-85_r'
        # Same three files, same order and same names as before, written in one loop.
        outputs = (('HWD', spell_all_xr), ('HWM', heatwave_avg_xr), ('HWA', heatwave_peak_xr))
        for tag, data in outputs:
            start = time.time()
            fname = path+str(ens)+'i1p1f1_gr3_'+str(period)+'_us_'+tag+'_fixed_95pct_smoothed.nc'
            print(fname)
            data.to_netcdf(fname)
            print('Subtotal saving time:', round((time.time() - start)/60.,2), 'min')

    return spell_all_xr, heatwave_avg_xr, heatwave_peak_xr


def convert_xr(arr,var_name,coords_base,long_name,unit,method):
    """Wrap ``arr`` in a named ``(time, lat, lon)`` DataArray on ``coords_base``'s coordinates.

    ``long_name``, ``unit`` and ``method`` are stored as the ``long_name``,
    ``units`` and ``calculation_methods`` attributes.
    """
    coordinates = {'time': coords_base.time,
                   'lat': coords_base.lat,
                   'lon': coords_base.lon}

    # The last four attributes are fixed constants attached to every variable.
    # NOTE(review): they do not vary with `unit`/`var_name` - confirm they are
    # intended for all outputs.
    attributes = {'long_name': long_name,
                  'units': unit,
                  'calculation_methods': method,
                  'valid_range': 'array([100., 400.], dtype=float)',
                  'cell_methods': 'time: mean',
                  'interp_method': 'conserve_order2',
                  'cell_measures': 'area: areacella'}

    return xr.DataArray(arr,
                        name=var_name,
                        dims=['time','lat','lon'],
                        coords=coordinates,
                        attrs=attributes)


##########################################################
##########################################################

# Driver: build a fixed, smoothed 95th-percentile threshold from the historical
# run, mark the scenario days above it, and derive the heatwave indices.

# Output prefix for scenario products; `hpath` is the prefix for the saved threshold.
path  = '/Volumes/Data/SPEAR/scenarioSSP5-85/day/tasmax/tasmax_day_GFDL-SPEAR-MED_scenarioSSP5-85_r'
hpath = '/Volumes/Data/SPEAR/historical/day/tasmax/tasmax_day_GFDL-SPEAR-MED_historical_r'+str(ens)+'i1p1f1'
# Period label written into every output file name below.
# NOTE(review): the first date has 7 digits ('2015101'), unlike the 8-digit dates
# in the commented alternative - possibly meant '20150101'; confirm before
# changing, since it alters the output file names.
period = '2015101-20801231' #'19610101-20801231'

print('------- Load hist -------')
hist_us   = load_data(ens=ens, scenario='hist')
print('------- Load ssp  -------')
tasmax_us    = load_data(ens=ens, scenario='ssp')

########################################################

# Fixed threshold from the historical data, smoothed with a 31-step rolling mean
# and repeated along the scenario time axis (leap days removed).
thre = calc_percentile(hist_us, pct=95, moving_yr=None, attrs_period=None, fixed=True, removing_leap=True)
tasmax_noleap_us = remove_leap(tasmax_us)
pct_reshaped = pct_reshape(thre, tasmax_noleap_us, fixed=True, rolling_t=31, save=True, 
                           title=hpath+'_gr3_19810101-20101231_us_thre95pct_fixed_smoothed.nc')
print('threshold shape', pct_reshaped.shape)

# Release the inputs that are no longer referenced below.
del hist_us, tasmax_us, thre

##########################################################

start = time.time()
# The expression string names two module-level variables; numexpr presumably
# resolves them from the calling scope, so they must keep these exact names.
boolean_index = ne.evaluate('tasmax_noleap_us > pct_reshaped')
# Multiplying by the boolean mask keeps tasmax on exceedance days and gives 0 elsewhere.
EHFsig = tasmax_noleap_us * boolean_index
print('Boolean index computing time:', round(time.time() - start,2), 'sec')
# Spot check: print 50 time steps at one grid point.
print(EHFsig[750:800,20,20].values)

start = time.time()
EHFsig.to_netcdf(path+str(ens)+'i1p1f1_gr3_'+str(period)+'_us_EHFsig_fixed_95pct_smoothed.nc')
print('EHFsig saving time:', round((time.time() - start)/60.,2), 'min')

##########################################################

# Heatwave duration / magnitude / amplitude; also written to NetCDF (save=True).
spell_all, heatwave_avg, heatwave_peak = calc_hw_index(count=EHFsig, ens=ens, period=period, save=True)

------- Load hist -------
/Volumes/Data/SPEAR/historical/day/tasmax/tasmax_day_GFDL-SPEAR-MED_historical_r30i1p1f1_gr3_19810101-19901231.nc
flist 1 computing 2.06 min
/Volumes/Data/SPEAR/historical/day/tasmax/tasmax_day_GFDL-SPEAR-MED_historical_r30i1p1f1_gr3_19910101-20001231.nc
flist 2 computing 9.05 min
/Volumes/Data/SPEAR/historical/day/tasmax/tasmax_day_GFDL-SPEAR-MED_historical_r30i1p1f1_gr3_20010101-20101231.nc
flist 3 computing 5.43 min
------- Load ssp  -------
/Volumes/Data/SPEAR/scenarioSSP5-85/day/tasmax/tasmax_day_GFDL-SPEAR-MED_scenarioSSP5-85_r30i1p1f1_gr3_20150101-20201231.nc
flist 1 computing 1.33 min
/Volumes/Data/SPEAR/scenarioSSP5-85/day/tasmax/tasmax_day_GFDL-SPEAR-MED_scenarioSSP5-85_r30i1p1f1_gr3_20210101-20301231.nc
flist 2 computing 8.54 min
/Volumes/Data/SPEAR/scenarioSSP5-85/day/tasmax/tasmax_day_GFDL-SPEAR-MED_scenarioSSP5-85_r30i1p1f1_gr3_20310101-20401231.nc
flist 3 computing 6.25 min
/Volumes/Data/SPEAR/scenarioSSP5-85/day/tasmax/tasmax_day_GFDL-SPEAR-MED