In [25]:
import xarray as xr
import os
import numpy as np
import warnings
import logging

import functions.eddy_feedback as ef
import functions.data_wrangling as dw
import functions.aos_functions as aos

# Functions

In [2]:
def seasonal_mean(ds, months, cut_ends=False, take_mean=False):
    """
    Calculate seasonal means for a given list of 3 months.

    Parameters:
    - ds: xr.Dataset or xr.DataArray
        Input data, must cover full years if cut_ends=False.
    - months: list of int
        List of 3 months defining the season (e.g., [12,1,2] for DJF).
    - cut_ends: bool
        If True, removes incomplete seasonal data at the start/end.
    - take_mean: bool
        If True, also take the mean over all seasons (time dim removed).

    Returns:
    - xr.Dataset or xr.DataArray with seasonal means, one per season-year.
    """

    if not (isinstance(months, list) and all(isinstance(m, int) and 1 <= m <= 12 for m in months)):
        raise ValueError(f"`months` must be a list of 3 integers between 1–12. Got: {months}")

    if len(months) != 3:
        raise ValueError(f"`months` must have exactly 3 elements. Got: {months}")

    # Optionally cut ends to ensure complete seasons
    if cut_ends:
        first_valid_time = ds['time'].sel(time=ds['time'].dt.month.isin([months[0]])).isel(time=0).values
        last_valid_time  = ds['time'].sel(time=ds['time'].dt.month.isin([months[-1]])).isel(time=-1).values
        ds = ds.sel(time=slice(first_valid_time, last_valid_time))

    # Select months of interest
    ds_season = ds.sel(time=ds['time'].dt.month.isin(months))

    # Create a "season_year" coordinate to group by (to handle DJF & NDJ properly)
    def assign_season_year(time):
        """
        Returns an array of years adjusted for seasons that cross calendar years
        (e.g., DJF or NDJ).
        """
        year = time.dt.year
        month = time.dt.month

        # If the first month of the season is December, then Dec belongs to next year
        if months[0] == 12:
            year = xr.where(month == 12, year + 1, year)

        return year

    season_year = assign_season_year(ds_season['time'])
    ds_season.coords['season_year'] = ('time', season_year.data)

    # Compute mean over each season
    result = ds_season.groupby('season_year').mean('time')

    # Rename `season_year` to `time`
    result = result.rename({'season_year': 'time'})

    if take_mean:
        return result.mean('time')
    else:
        return result
    
# Three-letter season codes -> their calendar months, ordered by the season's
# first month from December through November.
season_month_dict = {
    'DJF': [12, 1, 2],
    'JFM': [1, 2, 3],
    'FMA': [2, 3, 4],
    'MAM': [3, 4, 5],
    'AMJ': [4, 5, 6],
    'MJJ': [5, 6, 7],
    'JJA': [6, 7, 8],
    'JAS': [7, 8, 9],
    'ASO': [8, 9, 10],
    'SON': [9, 10, 11],
    'OND': [10, 11, 12],
    'NDJ': [11, 12, 1],
}

In [3]:
def calculate_efp_annual_cycle(ds, months, calc_south_hemis=False, which_div1=None, time_slice=None, cut_ends=False):
    """
    Compute the Eddy Feedback Parameter (EFP) for one season of a time series.

    The data are restricted to one hemisphere, optionally sliced in time and
    reduced to one seasonal mean per season-year via `seasonal_mean`. The
    squared correlation across season-years (dim='time') between
    `ds[which_div1]` and `ds.ubar` is then averaged over levels 600-200 and
    cos(lat)-weighted over latitudes 25-75 in the chosen hemisphere.

    Parameters:
    - ds: xr.Dataset
        Must contain `ubar` and the variable named by `which_div1`, with
        `time`, `lat` and `level` coordinates.
    - months: list of int
        Three months defining the season, e.g. [12, 1, 2] for DJF.
    - calc_south_hemis: bool
        If True use the Southern Hemisphere (-75 to -25), otherwise the
        Northern Hemisphere (25 to 75).
    - which_div1: str
        Name of the divergence variable to correlate with `ubar`. Required.
    - time_slice: slice or None
        Optional time selection applied before the seasonal mean. If None,
        the full period is used and printed.
    - cut_ends: bool
        Passed to `seasonal_mean` to drop incomplete seasons at the ends.

    Returns:
    - float: the EFP rounded to 4 decimal places.

    Raises:
    - ValueError: if `which_div1` is not given.
    - RuntimeError: if the correlation / averaging step fails.
    """
    # Fail fast: without this, ds[None] only fails after the (possibly costly)
    # seasonal mean, with an uninformative message.
    if which_div1 is None:
        raise ValueError("`which_div1` must name the divergence variable to correlate with `ubar`.")

    # Apply hemisphere-specific processing
    # NOTE(review): label slices assume `lat` is ascending — confirm the
    # loaders guarantee this.
    if calc_south_hemis:
        ds = ds.sel(lat=slice(-90, 0))
        efp_lat_slice = slice(-75, -25)
    else:
        ds = ds.sel(lat=slice(0, 90))
        efp_lat_slice = slice(25, 75)

    # time slice if required, then calculate seasonal mean
    if time_slice is not None:
        ds = ds.sel(time=time_slice)
    else:
        print(f'Calculating EFP for full time period: {ds.time.min().values} to {ds.time.max().values}')
    ds = seasonal_mean(ds, months=months, cut_ends=cut_ends)

    #---------------------------------
    # Compute Eddy Feedback Parameter

    try:
        with warnings.catch_warnings():
            # suppress RuntimeWarnings emitted during the correlation
            # (presumably from NaN or zero-variance points)
            warnings.simplefilter("ignore", category=RuntimeWarning)
            corr = xr.corr(ds[which_div1], ds.ubar, dim='time').load()**2

        # NOTE(review): slice(600., 200.) assumes `level` is descending.
        corr = corr.sel(lat=efp_lat_slice, level=slice(600., 200.))
        corr = corr.mean('level')

        # area weighting by cos(latitude)
        weights = np.cos(np.deg2rad(corr.lat))
        efp = corr.weighted(weights).mean('lat')

        return round(float(efp.values), 4)

    except Exception as e:
        # chain the original exception so its traceback is kept
        raise RuntimeError(f"Error during Eddy Feedback Parameter calculation: {e}") from e

# Prepare data

In [4]:
# JRA-55 EP-flux diagnostics from the 6-hourly data directory, read from the
# `*_daily_averages.nc` files.
path = '/disca/share/sit204/jra_55/1958_2016_6hourly_data_efp'
data_path = os.path.join(path, '*_daily_averages.nc')
# Alternative input files: os.path.join(path, '*_epflux.nc')
ds6h = xr.open_mfdataset(data_path)
# zonal mean of `ucomp`, used as `ubar` in the EFP correlation
ds6h = ds6h.assign(ubar=ds6h.ucomp.mean('lon'))
ds6h = dw.data_checker1000(
    ds6h[['ubar', 'div1_QG', 'div1_QG_123', 'div1_QG_gt3']],
    check_vars=False,
)
ds6h

Unnamed: 0,Array,Chunk
Bytes,215.96 MiB,3.67 MiB
Shape,"(21550, 37, 71)","(366, 37, 71)"
Dask graph,59 chunks in 123 graph layers,59 chunks in 123 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 215.96 MiB 3.67 MiB Shape (21550, 37, 71) (366, 37, 71) Dask graph 59 chunks in 123 graph layers Data type float32 numpy.ndarray",71  37  21550,

Unnamed: 0,Array,Chunk
Bytes,215.96 MiB,3.67 MiB
Shape,"(21550, 37, 71)","(366, 37, 71)"
Dask graph,59 chunks in 123 graph layers,59 chunks in 123 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,431.91 MiB,7.34 MiB
Shape,"(21550, 37, 71)","(366, 37, 71)"
Dask graph,59 chunks in 121 graph layers,59 chunks in 121 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 431.91 MiB 7.34 MiB Shape (21550, 37, 71) (366, 37, 71) Dask graph 59 chunks in 121 graph layers Data type float64 numpy.ndarray",71  37  21550,

Unnamed: 0,Array,Chunk
Bytes,431.91 MiB,7.34 MiB
Shape,"(21550, 37, 71)","(366, 37, 71)"
Dask graph,59 chunks in 121 graph layers,59 chunks in 121 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,431.91 MiB,7.34 MiB
Shape,"(21550, 37, 71)","(366, 37, 71)"
Dask graph,59 chunks in 121 graph layers,59 chunks in 121 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 431.91 MiB 7.34 MiB Shape (21550, 37, 71) (366, 37, 71) Dask graph 59 chunks in 121 graph layers Data type float64 numpy.ndarray",71  37  21550,

Unnamed: 0,Array,Chunk
Bytes,431.91 MiB,7.34 MiB
Shape,"(21550, 37, 71)","(366, 37, 71)"
Dask graph,59 chunks in 121 graph layers,59 chunks in 121 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,431.91 MiB,7.34 MiB
Shape,"(21550, 37, 71)","(366, 37, 71)"
Dask graph,59 chunks in 121 graph layers,59 chunks in 121 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 431.91 MiB 7.34 MiB Shape (21550, 37, 71) (366, 37, 71) Dask graph 59 chunks in 121 graph layers Data type float64 numpy.ndarray",71  37  21550,

Unnamed: 0,Array,Chunk
Bytes,431.91 MiB,7.34 MiB
Shape,"(21550, 37, 71)","(366, 37, 71)"
Dask graph,59 chunks in 121 graph layers,59 chunks in 121 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [5]:
# Reference EFP calculation (project implementation), for comparison with
# the seasonal-mean version in the following cells.
print(
    'NH EFP:',
    ef.calculate_efp(ds6h, data_type='reanalysis', which_div1='div1_QG'),
)
print(
    'SH EFP:',
    ef.calculate_efp(
        ds6h,
        data_type='reanalysis',
        which_div1='div1_QG',
        calc_south_hemis=True,
    ),
)

NH EFP: 0.4277
SH EFP: 0.3269


In [6]:
# cut ends = False | Time slice (1979, 2016)
# DJF for the NH, JAS for the SH.
print(
    'NH EFP:',
    calculate_efp_annual_cycle(
        ds6h,
        months=[12, 1, 2],
        which_div1='div1_QG',
        time_slice=slice('1979-01-01', '2016-12-31'),
    ),
)
print(
    'SH EFP:',
    calculate_efp_annual_cycle(
        ds6h,
        months=[7, 8, 9],
        which_div1='div1_QG',
        calc_south_hemis=True,
        time_slice=slice('1979-01-01', '2016-12-31'),
    ),
)

NH EFP: 0.4057
SH EFP: 0.3269


In [7]:
# cut ends = True | Time slice (1979, 2016)
# cut_ends is passed for both hemispheres so the cell matches its header. JAS
# does not cross a year boundary, so the SH result should be unchanged.
print('NH EFP:', calculate_efp_annual_cycle(ds6h, months=[12,1,2], which_div1='div1_QG', time_slice=slice('1979-01-01', '2016-12-31'), cut_ends=True))
print('SH EFP:', calculate_efp_annual_cycle(ds6h, months=[7,8,9], which_div1='div1_QG', calc_south_hemis=True, time_slice=slice('1979-01-01', '2016-12-31'), cut_ends=True))

NH EFP: 0.4277
SH EFP: 0.3269


In [9]:
# Daily JRA-55 EP fluxes (directory name: k123_QG_epfluxes).
daily_path = '/home/links/ct715/data_storage/reanalysis/jra55_daily/k123_QG_epfluxes'
data_path = os.path.join(daily_path, '*_daily_averages.nc')
day = dw.data_checker1000(
    xr.open_mfdataset(data_path)[['ubar', 'div1_QG', 'div1_QG_123', 'div1_QG_gt3']],
    check_vars=False,
)
day


Unnamed: 0,Array,Chunk
Bytes,84.18 kiB,1.43 kiB
Shape,"(21550,)","(366,)"
Dask graph,59 chunks in 119 graph layers,59 chunks in 119 graph layers
Data type,int32 numpy.ndarray,int32 numpy.ndarray
"Array Chunk Bytes 84.18 kiB 1.43 kiB Shape (21550,) (366,) Dask graph 59 chunks in 119 graph layers Data type int32 numpy.ndarray",21550  1,

Unnamed: 0,Array,Chunk
Bytes,84.18 kiB,1.43 kiB
Shape,"(21550,)","(366,)"
Dask graph,59 chunks in 119 graph layers,59 chunks in 119 graph layers
Data type,int32 numpy.ndarray,int32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,222.04 MiB,3.77 MiB
Shape,"(21550, 37, 73)","(366, 37, 73)"
Dask graph,59 chunks in 119 graph layers,59 chunks in 119 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 222.04 MiB 3.77 MiB Shape (21550, 37, 73) (366, 37, 73) Dask graph 59 chunks in 119 graph layers Data type float32 numpy.ndarray",73  37  21550,

Unnamed: 0,Array,Chunk
Bytes,222.04 MiB,3.77 MiB
Shape,"(21550, 37, 73)","(366, 37, 73)"
Dask graph,59 chunks in 119 graph layers,59 chunks in 119 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,444.08 MiB,7.54 MiB
Shape,"(21550, 37, 73)","(366, 37, 73)"
Dask graph,59 chunks in 119 graph layers,59 chunks in 119 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 444.08 MiB 7.54 MiB Shape (21550, 37, 73) (366, 37, 73) Dask graph 59 chunks in 119 graph layers Data type float64 numpy.ndarray",73  37  21550,

Unnamed: 0,Array,Chunk
Bytes,444.08 MiB,7.54 MiB
Shape,"(21550, 37, 73)","(366, 37, 73)"
Dask graph,59 chunks in 119 graph layers,59 chunks in 119 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,444.08 MiB,7.54 MiB
Shape,"(21550, 37, 73)","(366, 37, 73)"
Dask graph,59 chunks in 119 graph layers,59 chunks in 119 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 444.08 MiB 7.54 MiB Shape (21550, 37, 73) (366, 37, 73) Dask graph 59 chunks in 119 graph layers Data type float64 numpy.ndarray",73  37  21550,

Unnamed: 0,Array,Chunk
Bytes,444.08 MiB,7.54 MiB
Shape,"(21550, 37, 73)","(366, 37, 73)"
Dask graph,59 chunks in 119 graph layers,59 chunks in 119 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,444.08 MiB,7.54 MiB
Shape,"(21550, 37, 73)","(366, 37, 73)"
Dask graph,59 chunks in 119 graph layers,59 chunks in 119 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 444.08 MiB 7.54 MiB Shape (21550, 37, 73) (366, 37, 73) Dask graph 59 chunks in 119 graph layers Data type float64 numpy.ndarray",73  37  21550,

Unnamed: 0,Array,Chunk
Bytes,444.08 MiB,7.54 MiB
Shape,"(21550, 37, 73)","(366, 37, 73)"
Dask graph,59 chunks in 119 graph layers,59 chunks in 119 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [10]:
# Reference EFP calculation (project implementation) on the daily data.
print(
    'NH EFP:',
    ef.calculate_efp(day, data_type='reanalysis', which_div1='div1_QG'),
)
print(
    'SH EFP:',
    ef.calculate_efp(
        day,
        data_type='reanalysis',
        which_div1='div1_QG',
        calc_south_hemis=True,
    ),
)

NH EFP: 0.3631
SH EFP: 0.2031


In [12]:
# cut ends = True | Time slice (1979, 2016)
# cut_ends is passed for both hemispheres so the cell matches its header. JAS
# does not cross a year boundary, so the SH result should be unchanged.
print('NH EFP:', calculate_efp_annual_cycle(day, months=[12,1,2], which_div1='div1_QG', time_slice=slice('1979-01-01', '2016-12-31'), cut_ends=True))
print('SH EFP:', calculate_efp_annual_cycle(day, months=[7,8,9], which_div1='div1_QG', calc_south_hemis=True, time_slice=slice('1979-01-01', '2016-12-31'), cut_ends=True))

NH EFP: 0.3631
SH EFP: 0.2031


# PAMIP

In [None]:
import json
# Module logger used by the PAMIP helper functions below.
logger = logging.getLogger(name=__name__)

def seasonal_mean(ds, months, take_mean=False):
    logger.info(f"Computing seasonal mean for months: {months}, take_mean={take_mean}")
    
    # pre-processing checks
    if not (isinstance(months, list) and all(isinstance(m, int) and 1 <= m <= 12 for m in months)):
        raise ValueError(f"`months` must be a list of 3 integers between 1–12. Got: {months}")
    if len(months) != 3:
        raise ValueError(f"`months` must have exactly 3 elements. Got: {months}")

    ds_season = ds.sel(time=ds['time'].dt.month.isin(months))

    def assign_season_year(time):
        year = time.dt.year
        month = time.dt.month
        if months[0] == 12:
            year = xr.where(month == 12, year + 1, year)
        return year

    season_year = assign_season_year(ds_season['time'])
    ds_season.coords['season_year'] = ('time', season_year.data)

    result = ds_season.groupby('season_year').mean('time')
    result = result.rename({'season_year': 'time'})

    if take_mean:
        logger.info("Returning mean over all seasons.")
        return result.mean('time')
    else:
        return result
    

def calculate_efp_annual_cycle(ds, months, calc_south_hemis=False, which_div1=None):
    """
    Eddy Feedback Parameter for one season of an ensemble (PAMIP variant).

    The data are restricted to one hemisphere and reduced to seasonal means via
    `seasonal_mean`. The squared correlation across ensemble members
    (dim='ens_ax') between `ds[which_div1]` and `ds.ubar` is then averaged over
    levels 600-200 and cos(lat)-weighted over latitudes 25-75 in that
    hemisphere.

    Parameters:
    - ds: xr.Dataset with `ubar` and `which_div1`, and `time`, `lat`, `level`,
      `ens_ax` dimensions.
    - months: list of 3 ints defining the season, e.g. [12, 1, 2] for DJF.
    - calc_south_hemis: bool; use the Southern Hemisphere if True.
    - which_div1: str; name of the divergence variable. Required.

    Returns:
    - float: EFP of the first season-year, rounded to 4 decimal places. A
      warning is logged if more than one season-year is present.

    Raises:
    - ValueError: if `which_div1` is not given.
    - RuntimeError: if the correlation / averaging step fails.
    """
    # Fail fast instead of failing on ds[None] after the seasonal mean.
    if which_div1 is None:
        raise ValueError("`which_div1` must name the divergence variable to correlate with `ubar`.")

    hemi = 'Southern' if calc_south_hemis else 'Northern'
    logger.info(f"Calculating EFP annual cycle for {hemi} Hemisphere, div1: {which_div1}, months: {months}")

    # NOTE(review): label slices assume `lat` ascending — confirm for each model.
    if calc_south_hemis:
        ds = ds.sel(lat=slice(-90, 0))
        efp_lat_slice = slice(-75, -25)
    else:
        ds = ds.sel(lat=slice(0, 90))
        efp_lat_slice = slice(25, 75)

    # calc seasonal means
    ds = seasonal_mean(ds, months=months)
    logger.info(f"Seasonal means calculated. Dataset shape: {ds.sizes}")

    logger.debug("Computing correlation and EFP.")
    try:
        with warnings.catch_warnings():
            # suppress RuntimeWarnings emitted during the correlation
            # (presumably from NaN or zero-variance points)
            warnings.simplefilter("ignore", category=RuntimeWarning)
            corr = xr.corr(ds[which_div1], ds.ubar, dim='ens_ax').load()**2

        # NOTE(review): slice(600., 200.) assumes `level` is descending.
        corr = corr.sel(lat=efp_lat_slice, level=slice(600., 200.))
        corr = corr.mean('level')

        # area weighting by cos(latitude)
        weights = np.cos(np.deg2rad(corr.lat))
        efp = corr.weighted(weights).mean('lat')

        # `efp` keeps the season-year axis; only the first entry is returned,
        # so surface any additional season-years instead of dropping them
        # silently.
        n_seasons = efp.size
        if n_seasons > 1:
            logger.warning(f"EFP has {n_seasons} season-years; returning only the first.")

        efp_value = round(float(efp.values[0]), 4)
        logger.info(f"EFP = {efp_value}")
        return efp_value

    except Exception as e:
        logger.error(f"Error during EFP calculation: {e}")
        # chain the original exception so its traceback is kept
        raise RuntimeError(f"Error during Eddy Feedback Parameter calculation: {e}") from e

In [38]:
# CanESM5 PAMIP EP-flux diagnostics (directory names suggest monthly averages
# of daily data); ensemble members are concatenated along `ens_ax`.
pamip_path = '/home/links/ct715/data_storage/PAMIP/processed_daily'

k123_path = os.path.join(pamip_path, 'k123_daily_efp_mon-avg', 'CanESM5', '*.nc')
daily_path = os.path.join(pamip_path, 'daily_efp_mon-avg', 'CanESM5', '*.nc')

k123 = xr.open_mfdataset(k123_path, combine='nested', concat_dim='ens_ax')
daily = xr.open_mfdataset(daily_path, combine='nested', concat_dim='ens_ax')

# NOTE(review): assumes both globs list members in the same order so that
# `ens_ax` positions line up between `k123` and `daily` — confirm.
daily['divFy_k123'] = k123.divFy_k123
# remainder after removing the k123 component (presumably wavenumbers > 3)
daily['divFy_gt3'] = daily.divFy - daily.divFy_k123
pamip = daily[['ubar', 'divFy', 'divFy_k123', 'divFy_gt3']]
pamip

Unnamed: 0,Array,Chunk
Bytes,5.57 MiB,57.00 kiB
Shape,"(100, 12, 19, 64)","(1, 12, 19, 64)"
Dask graph,100 chunks in 301 graph layers,100 chunks in 301 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 5.57 MiB 57.00 kiB Shape (100, 12, 19, 64) (1, 12, 19, 64) Dask graph 100 chunks in 301 graph layers Data type float32 numpy.ndarray",100  1  64  19  12,

Unnamed: 0,Array,Chunk
Bytes,5.57 MiB,57.00 kiB
Shape,"(100, 12, 19, 64)","(1, 12, 19, 64)"
Dask graph,100 chunks in 301 graph layers,100 chunks in 301 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,11.13 MiB,114.00 kiB
Shape,"(100, 12, 19, 64)","(1, 12, 19, 64)"
Dask graph,100 chunks in 301 graph layers,100 chunks in 301 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 11.13 MiB 114.00 kiB Shape (100, 12, 19, 64) (1, 12, 19, 64) Dask graph 100 chunks in 301 graph layers Data type float64 numpy.ndarray",100  1  64  19  12,

Unnamed: 0,Array,Chunk
Bytes,11.13 MiB,114.00 kiB
Shape,"(100, 12, 19, 64)","(1, 12, 19, 64)"
Dask graph,100 chunks in 301 graph layers,100 chunks in 301 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,11.13 MiB,114.00 kiB
Shape,"(100, 12, 19, 64)","(1, 12, 19, 64)"
Dask graph,100 chunks in 301 graph layers,100 chunks in 301 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 11.13 MiB 114.00 kiB Shape (100, 12, 19, 64) (1, 12, 19, 64) Dask graph 100 chunks in 301 graph layers Data type float64 numpy.ndarray",100  1  64  19  12,

Unnamed: 0,Array,Chunk
Bytes,11.13 MiB,114.00 kiB
Shape,"(100, 12, 19, 64)","(1, 12, 19, 64)"
Dask graph,100 chunks in 301 graph layers,100 chunks in 301 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,11.13 MiB,114.00 kiB
Shape,"(100, 12, 19, 64)","(1, 12, 19, 64)"
Dask graph,100 chunks in 603 graph layers,100 chunks in 603 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 11.13 MiB 114.00 kiB Shape (100, 12, 19, 64) (1, 12, 19, 64) Dask graph 100 chunks in 603 graph layers Data type float64 numpy.ndarray",100  1  64  19  12,

Unnamed: 0,Array,Chunk
Bytes,11.13 MiB,114.00 kiB
Shape,"(100, 12, 19, 64)","(1, 12, 19, 64)"
Dask graph,100 chunks in 603 graph layers,100 chunks in 603 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [39]:
# Reference EFP calculation (project implementation) for the PAMIP ensemble.
print(
    'NH EFP:',
    ef.calculate_efp(pamip, data_type='pamip', which_div1='divFy'),
)
print(
    'SH EFP:',
    ef.calculate_efp(
        pamip,
        data_type='pamip',
        which_div1='divFy',
        calc_south_hemis=True,
    ),
)

NH EFP: 0.3133
SH EFP: 0.3271


In [40]:
# New EFP calculation with seasonal means: DJF for the NH, JAS for the SH,
# correlating across ensemble members.
print(
    'NH EFP:',
    calculate_efp_annual_cycle(pamip, months=[12, 1, 2], which_div1='divFy'),
)
print(
    'SH EFP:',
    calculate_efp_annual_cycle(
        pamip,
        months=[7, 8, 9],
        which_div1='divFy',
        calc_south_hemis=True,
    ),
)

[0.31333789]
NH EFP: 0.3133
[0.32709595]
SH EFP: 0.3271
