# Ocetrac on LENS

In [None]:
import numpy as np
import xarray as xr

import LENSfunctions as funcs

In [None]:
! pwd

## Functions

In [None]:
def get_fut_file_paths(var, directory, path_intermed_fut, index, *,
                       first_year=2015, last_year=2020, chunk_years=10):
    """Build the file paths for one entry of ``path_intermed_fut``.

    One path is produced per ``chunk_years``-long block between ``first_year``
    and ``last_year``; the final block is capped at ``last_year``. With the
    default arguments a single path covering 201501-202012 is returned.

    Parameters
    ----------
    var : str
        Variable name embedded in the file name (e.g. ``'SST'``).
    directory : str
        Directory prefix; joined to the file name with ``/``.
    path_intermed_fut : sequence of str
        File-name prefixes (one per ensemble member).
    index : int
        Position in ``path_intermed_fut`` of the prefix to use.
    first_year, last_year : int, keyword-only
        First and last calendar year to cover (inclusive).
    chunk_years : int, keyword-only
        Number of years stored per file.

    Returns
    -------
    list of str
        Paths of the form
        ``{directory}/{prefix}.{var}.{start}01-{end}12.nc``.

    Raises
    ------
    IndexError
        If ``index`` is out of range for ``path_intermed_fut``.
    """
    attrib_title = path_intermed_fut[index]

    file_paths = []
    for start_year in range(first_year, last_year + 1, chunk_years):
        # The last file may be shorter than a full chunk.
        end_year = min(start_year + chunk_years - 1, last_year)
        file_paths.append(
            f'{directory}/{attrib_title}.{var}.{start_year}01-{end_year}12.nc'
        )
    return file_paths


In [None]:
import xarray as xr

# Inputs: one example ensemble-member prefix, the variable, and its directory
path_intermed_fut = ['b.e21.BSSP370cmip6.f09_g17.LE2-1231.001']  # example
index = 0
var = 'SST'
directory = '/glade/campaign/cgd/cesm/CESM2-LE/atm/proc/tseries/month_1/SST'

# Resolve the file paths for the selected member
file_paths = get_fut_file_paths(
    var=var,
    directory=directory,
    path_intermed_fut=path_intermed_fut,
    index=index,
)

# Load as mfdataset
#ds = xr.open_mfdataset(file_paths, combine='by_coords')


In [None]:
import xarray as xr
import os

def get_paths_for_member(var, directory, member_id):
    """Return the one-element list of file paths for a single ensemble member.

    The file name is ``{member_id}.{var}.201501-202012.nc`` and is joined to
    ``directory`` with ``os.path.join``.
    """
    first_year, last_year = 2015, 2020  # the only span captured here
    basename = f"{member_id}.{var}.{first_year}01-{last_year}12.nc"
    return [os.path.join(directory, basename)]

# Example usage: gather the paths of every ensemble member into one flat list
ensemble_paths = []
for member_id in path_intermed_fut:
    member_paths = get_paths_for_member(var, directory, member_id)
    ensemble_paths += member_paths

## Now load via xarray
#ds = xr.open_mfdataset(ensemble_paths, combine="by_coords")

## Double-check that the time range is within 2015–2020
#ds = ds.sel(time=slice("2015-01-01", "2020-12-31"))


In [None]:
# removing linear trend

def calculate_anomalies_notrend(ds, threshold_quantile=0.9):
    """Fit a 6-term least-squares model along time and flag extreme values.

    The model terms are: mean, linear trend, annual sine and cosine
    harmonics, and semi-annual sine and cosine harmonics. The fit is solved
    with the pseudo-inverse of the design matrix.

    Parameters
    ----------
    ds : xarray.DataArray
        Field with ``time``, ``lat`` and ``lon`` coordinates. ``time`` must
        support the ``.dt.year`` and ``.dt.month`` accessors.
    threshold_quantile : float, optional
        Quantile (taken along ``time``) used as the exceedance threshold.
        Defaults to 0.9, i.e. the 90th percentile.

    Returns
    -------
    mean : xarray.DataArray
        Fitted mean term of the model.
    trend : xarray.DataArray
        Fitted linear-trend term of the model.
    seas : xarray.DataArray
        Fitted annual + semi-annual harmonic terms of the model.
    features_withtrend : xarray.DataArray
        ``ssta_withtrend`` where it is >= the quantile threshold, NaN elsewhere.
    ssta_withtrend : xarray.DataArray
        ``ds`` minus the fitted linear trend.
    """
    sst = ds
    # Decimal-year time axis as a plain NumPy vector, so the design matrix is
    # assembled from homogeneous arrays.
    dyr = (sst.time.dt.year + sst.time.dt.month / 12).values
    coeff = np.arange(1, 7)

    # Design matrix, shape (6, n_time): one row per model term.
    model = np.array([
        np.ones(len(dyr)),           # mean
        dyr - np.mean(dyr),          # linear trend (centred)
        np.sin(2 * np.pi * dyr),     # annual harmonics
        np.cos(2 * np.pi * dyr),
        np.sin(4 * np.pi * dyr),     # semi-annual harmonics
        np.cos(4 * np.pi * dyr),
    ])

    # Take the pseudo-inverse of model to 'solve' the least-squares problem
    pmodel = np.linalg.pinv(model)

    # Wrap model and pmodel as DataArrays so xarray can contract over 'time'
    time_values = sst.time.values
    model_da = xr.DataArray(
        model.T, dims=['time', 'coeff'],
        coords={'time': time_values, 'coeff': coeff})
    pmodel_da = xr.DataArray(
        pmodel.T, dims=['coeff', 'time'],
        coords={'coeff': coeff, 'time': time_values})

    # Resulting coefficients of the model at every grid point
    sst_mod = xr.DataArray(
        pmodel_da.dot(sst), dims=['coeff', 'lat', 'lon'],
        coords={'coeff': coeff, 'lat': sst.lat.values, 'lon': sst.lon.values})

    # Construct mean, trend, and seasonal cycle
    mean = model_da[:, 0].dot(sst_mod[0, :, :])
    trend = model_da[:, 1].dot(sst_mod[1, :, :])
    seas = model_da[:, 2:].dot(sst_mod[2:, :, :])

    # NOTE(review): only the linear trend is subtracted here; the fitted mean
    # and seasonal cycle stay in the result. Confirm this is intended.
    ssta_withtrend = sst - trend

    # quantile() needs the whole time axis in a single chunk for dask arrays
    if ssta_withtrend.chunks:
        ssta_withtrend = ssta_withtrend.chunk({'time': -1})

    # Per-grid-point threshold; keep only values that reach or exceed it
    threshold = ssta_withtrend.quantile(threshold_quantile, dim='time')
    features_withtrend = ssta_withtrend.where(
        ssta_withtrend >= threshold, other=np.nan)
    return mean, trend, seas, features_withtrend, ssta_withtrend


## Ocetrac using different detrending methods

In [None]:
# Open the ensemble-mean dataset from disk and display its SST variable
ensemble_mean = xr.open_dataset('/glade/work/cassiacai/ensemble_mean.nc')
ensemble_mean.SST

In [None]:
var = 'SST'
comp = 'atm'
directory = f'/glade/campaign/cgd/cesm/CESM2-LE/{comp}/proc/tseries/month_1/{var}/'

# NOTE(review): get_ds_var, calculate_anomalies_trend_features and Tracker are
# neither defined nor imported in the visible part of this notebook; confirm
# they are in scope (presumably from LENSfunctions / ocetrac) before running.
for ens_memb_index in range(0, 1):
    print(ens_memb_index)
    # Historical and future SST datasets for this ensemble member
    ds_var_hist_SST, ds_var_fut_SST = get_ds_var(directory, 'SST', 'atm', ens_memb_index)
    # defining our region
    upper_lat = 65
    lower_lat = 5
    left_lon = 150
    right_lon = 250

    # selecting this region
    CESMLENS_SST_NEP_hist = ds_var_hist_SST.SST.sel(lon=slice(left_lon, right_lon),
                                                    lat=slice(lower_lat, upper_lat))
    CESMLENS_SST_NEP_fut = ds_var_fut_SST.SST.sel(lon=slice(left_lon, right_lon),
                                                  lat=slice(lower_lat, upper_lat))
    # selecting out the observational time period that overlaps with ERA5 (1979 - 2022)
    CESMLENS_SST_NEP_hist_time_slice = CESMLENS_SST_NEP_hist.sel(time=slice('1979-01-01', '2015-01-01'))
    CESMLENS_SST_NEP_fut_time_slice = CESMLENS_SST_NEP_fut.sel(time=slice('2015-02-01', '2022-12-01'))

    # concatenating the historical and future simulations together to make one dataset
    CESMLENS_SST_NEP_ds = xr.concat([CESMLENS_SST_NEP_hist_time_slice, CESMLENS_SST_NEP_fut_time_slice], dim='time')
    CESMLENS_SST_NEP_ds = CESMLENS_SST_NEP_ds.compute()
    # values equal to 0 are treated as land and set to NaN
    CESMLENS_SST_NEP_ds_no_nan = CESMLENS_SST_NEP_ds.where(CESMLENS_SST_NEP_ds != 0, np.nan)

    # using the function calculate_anomalies_trend_features to calculate inputs into ocetrac + data
    # NOTE(review): all four calls receive identical arguments, so the
    # lin/all/quad/ensmean results are identical as written; confirm whether
    # each detrending method should call a different function.
    lintrend_mean, lintrend_trend, lintrend_seas, lintrend_features_notrend, lintrend_ssta_notrend = calculate_anomalies_trend_features(ds=CESMLENS_SST_NEP_ds_no_nan)
    all_mean, all_trend, all_seas, all_features_notrend, all_ssta_notrend = calculate_anomalies_trend_features(ds=CESMLENS_SST_NEP_ds_no_nan)
    quadtrend_mean, quadtrend_trend, quadtrend_seas, quadtrend_features_notrend, quadtrend_ssta_notrend = calculate_anomalies_trend_features(ds=CESMLENS_SST_NEP_ds_no_nan)
    ensmeantrend_mean, ensmeantrend_trend, ensmean_trend_seas, ensmean_trend_features_notrend, ensmean_trend_ssta_notrend = calculate_anomalies_trend_features(ds=CESMLENS_SST_NEP_ds_no_nan)

    # (output label, thresholded features, fitted mean) per detrending method.
    # The label goes into the output file name so that each method writes its
    # own files instead of overwriting the linear-trend results.
    detrend_runs = [
        ('lintrend', lintrend_features_notrend, lintrend_mean),            # Linear
        ('quadtrend', quadtrend_features_notrend, quadtrend_mean),         # Quadratic
        ('ensmean_trend', ensmean_trend_features_notrend, ensmeantrend_mean),  # Ensemble Mean
    ]

    for label, full_mask_land, mean_field in detrend_runs:
        # Binary field: True wherever a (non-zero, finite) feature value exists
        full_masked = full_mask_land.where(full_mask_land != 0)
        binary_out_afterlandmask = np.isfinite(full_masked).compute()
        # Mask: True wherever the fitted mean is finite
        newmask = np.isfinite(mean_field).compute()

        for radius_val in range(1, 5):
            obj_Tracker = Tracker(
                binary_out_afterlandmask,
                newmask, radius=radius_val, min_size_quartile=0.75,
                timedim='time', xdim='lon', ydim='lat', positive=True)
            blobs = obj_Tracker.track()
            blobs.to_netcdf('data/ens_{}_{}_mhwobj_rad{}_mean.nc'.format(ens_memb_index, label, radius_val))

    # Save the anomaly fields that were fed into the tracker
    lintrend_ssta_notrend.to_netcdf('data/ens_{}_lintrend_ssta_notrend.nc'.format(ens_memb_index))
    quadtrend_ssta_notrend.to_netcdf('data/ens_{}_quadtrend_ssta_notrend.nc'.format(ens_memb_index))
    ensmean_trend_ssta_notrend.to_netcdf('data/ens_{}_ensmean_trend_ssta_notrend.nc'.format(ens_memb_index))