In [None]:
import numpy as np                # library for mathematical operations with arrays
import pandas as pd               # library for data frames; includes useful functions for date arithmetic
import xarray as xr               # library for arrays, especially tailored to weather data
import matplotlib.pyplot as plt   # library for basic plotting
import matplotlib.colors as mcolors 
import datetime                   # library for date manipulation

from functools import reduce
from scipy.stats import pearsonr

from sklearn.linear_model import MultiTaskLassoCV
from sklearn.exceptions import ConvergenceWarning
from warnings import simplefilter

from os import path
import sys
sys.path.append("../src/confer_wp3/")
from confer_wp3.dataloading import load_raw_data, save_anomalies, save_eofs_pcs
from confer_wp3.validation import validate_anomalies1, validate_anomalies2, validate_anomalies3, validate_anomalies4, validate_eofs, validate_pcs
from confer_wp3.lasso_forecast import calculate_anomalies, compute_eofs_pcs, quantile_mapping, standardize_index, standardize_index_diff1

The following variables specify where the raw ERA5 predictors, the CHIRPS observations, the predictor indices, the precipitation anomalies, the EOFs, and the forecasts are stored:

In [None]:
# Locations on the shared network drive.
# NOTE(review): some directories are rooted at '/nr/samba/...' and others at
# '/home/michael/nr/samba/...' (presumably a user-specific mount of the same share);
# confirm which root is valid on the machine running this notebook.
easp_root = '/nr/samba/PostClimDataNoBackup/CONFER/EASP/'
easp_root_home = '/home/michael' + easp_root

era5_dir = f'{easp_root}raw_predictors/'                              # raw ERA5 predictor fields
chirps_dir = f'{easp_root}precip/chirps/'                             # CHIRPS precipitation data
indices_dir = f'{easp_root_home}fls/predictors/'                      # predictor indices
anomaly_dir = f'{easp_root}precip/chirps/seasonal/halfdeg_res/'       # seasonal CHIRPS anomalies
eof_dir = f'{easp_root_home}eofs/chirps/halfdeg_res/'                 # CHIRPS EOFs
fcst_dir = f'{easp_root_home}fls_pred/chirps/seasonal/halfdeg_res/'   # presumably forecast output

Now, we set a number of parameters defining our forecast domain, training period, forecast year, etc.:

In [None]:
# Climatological reference period (inclusive) used for anomalies and standardization
year_clm_start, year_clm_end = 1993, 2020
# Training period (inclusive)
# NOTE(review): the training period includes the forecast year below — fine for validation,
# but confirm this is intended before using the set-up for real forecasts.
year_train_start, year_train_end = 1981, 2020

# Forecast set-up: issued in `month_init` of `year_fcst` (based on data of the preceding
# month) for the target season `season`
year_fcst = 2020
month_init = 8
season = 'OND'

# Domain of interest (degrees)
lon_bnds = [20, 53]
lat_bnds = [-15, 23]

period_clm, period_train = [year_clm_start, year_clm_end], [year_train_start, year_train_end]

Load the CHIRPS data; then, if they are not already available in the validation directory, calculate the seasonal precipitation anomalies, EOFs, and factor loadings, and save them.

### CHIRPS Data

##### Loading data

In [None]:
# Directory where this notebook caches and writes its intermediate results
val_dir = "/nr/samba/user/ahellevik/CONFER-WP3/validation_data/"

# Load CHIRPS precipitation for every year of the training period (both ends inclusive)
# over the domain of interest
train_years = list(range(year_train_start, year_train_end + 1))
year, lat, lon, prec_data = load_raw_data(chirps_dir, "chirps", train_years, season, lat_bnds, lon_bnds)
print("Loaded data shape:", prec_data.shape)
print("NaN values in loaded data:", np.isnan(prec_data).sum())

##### Get/calculate anomalies and normalized anomalies

In [None]:
# Anomalies are cached in val_dir (presumably written there by save_anomalies); they are only
# computed when the cache file is missing.
# NOTE(review): the cache file name does not encode season, domain or period, so a file
# written for a different configuration is silently reused — delete it after changing them.
anomalies_file = f'{val_dir}chirps_anomalies.nc'
if not path.exists(anomalies_file):
    # Calculate anomalies
    anomalies = calculate_anomalies(prec_data, year, period_clm)
    # Save anomalies
    save_anomalies(anomalies, year, lat, lon, val_dir, normalized=False)
else:
    # Read the values eagerly and close the file right away instead of leaving the
    # dataset open for the rest of the session
    with xr.open_dataarray(anomalies_file, engine='netcdf4') as anomalies_da:
        anomalies = anomalies_da.values
print(prec_data.shape)
print(anomalies.shape)
# Verify NaNs are handled correctly
print("NaN values in prec_data:", np.isnan(prec_data).sum())
print("NaN values in anomalies:", np.isnan(anomalies).sum())

In [None]:
# Normalized (quantile-mapped) anomalies, cached in val_dir like the plain anomalies
anomalies_normal_file = f'{val_dir}chirps_anomalies_normal.nc'
if not path.exists(anomalies_normal_file):
    # Apply the transformation to the anomalies data
    anomalies_normal = quantile_mapping(anomalies, year, period_clm)
    # Save normalized anomalies
    save_anomalies(anomalies_normal, year, lat, lon, val_dir, normalized=True)
else:
    # Read the values eagerly and close the file right away
    with xr.open_dataarray(anomalies_normal_file, engine='netcdf4') as anomalies_normal_da:
        anomalies_normal = anomalies_normal_da.values
print(anomalies_normal.shape)

##### Plots to ensure calculating anomalies went well

In [None]:
# Diagnostic check (confer_wp3.validation) of the raw precipitation against its anomalies
validate_anomalies1(prec_data, anomalies, lat, lon)

In [None]:
# Diagnostic check (confer_wp3.validation) of the anomalies against the normalized anomalies
validate_anomalies2(anomalies, anomalies_normal, lat, lon)

In [None]:
# Further diagnostic check of anomalies vs. normalized anomalies (no grid coordinates passed)
validate_anomalies3(anomalies, anomalies_normal)

In [None]:
# Further diagnostic check of anomalies vs. normalized anomalies (no grid coordinates passed)
validate_anomalies4(anomalies, anomalies_normal)

##### Get/calculate EOFs and factor loadings

In [None]:
# Number of leading EOFs. It is defined outside the cache check so that it exists whether
# the EOFs are computed or loaded from val_dir.
n_eofs = 7
eofs_file = f'{val_dir}chirps_eofs.nc'
if not path.exists(eofs_file):
    # Calculate EOFs, PCs and explained-variance fractions of the normalized anomalies
    eofs, pcs, var_fracs = compute_eofs_pcs(anomalies_normal, n_eofs)
    # Reshape EOFs to 3D (n_eofs, lat, lon)
    eofs_reshaped = eofs.reshape((n_eofs, len(lat), len(lon)))
    # Save EOFs, PCs and variance fractions
    save_eofs_pcs(eofs_reshaped, pcs, var_fracs, year, lat, lon, val_dir, n_eofs)
else:
    # Read each cached array eagerly and close its file right away
    with xr.open_dataarray(eofs_file, engine='netcdf4') as cached:
        eofs_reshaped = cached.values
    with xr.open_dataarray(f'{val_dir}chirps_pcs.nc', engine='netcdf4') as cached:
        pcs = cached.values
    with xr.open_dataarray(f'{val_dir}chirps_var_fracs.nc', engine='netcdf4') as cached:
        var_fracs = cached.values

print("Normalized Anomalies EOFs Shape:", eofs_reshaped.shape)
print("Normalized Anomalies PCs Shape:", pcs.shape)
print("Normalized Anomalies Variance Fraction:", var_fracs, "sum: ", var_fracs.sum())

##### Plots to ensure calculating EOFs and factor loadings went well

In [None]:
# Plot EOFs for normalized anomalies. The number of EOFs is taken from the array
# (shape (n_eofs, lat, lon)), so it always matches what was computed or loaded.
validate_eofs(eofs_reshaped, f"Normalized Anomalies - {season}", n_eofs=eofs_reshaped.shape[0])

In [None]:
# Check PCs against the EOFs; n_eofs follows the leading dimension of the EOF array
validate_pcs(anomalies_normal, eofs_reshaped, pcs, lat, lon, year, period_train, season, n_eofs=eofs_reshaped.shape[0])

### ERA5 Data

Load the ERA5 data, calculate the climate indices, and save them.

#### Loading data

In [None]:
# Years to load: span covering both the training and the climatological period (inclusive)
first_load_year = min(year_train_start, year_clm_start)
last_load_year = max(year_train_end, year_clm_end)
load_years = list(range(first_load_year, last_load_year + 1))

In [None]:
# Load ERA5 sea surface temperature for all load_years.
# NOTE(review): unlike the CHIRPS call, no lat/lon bounds are passed — presumably so the remote
# index regions (e.g. the Pacific boxes) stay available; confirm against load_raw_data.
sst_data = load_raw_data(era5_dir, "sst", load_years, season)
sst_data

In [None]:
# Load ERA5 zonal wind (presumably at 200 hPa) for all load_years; used by the ueq200/sji200 indices
uwind200_data = load_raw_data(era5_dir, "uwind200", load_years, season)
uwind200_data

In [None]:
# Load ERA5 zonal wind (presumably at 850 hPa) for all load_years; used by the ueq850/sji850 indices
uwind850_data = load_raw_data(era5_dir, "uwind850", load_years, season)
uwind850_data

### Calculate indices

#### Functions for index calculations

In [None]:

# Bounding boxes (degrees; longitudes in the -180..180 convention) of the regions used to
# build the climate indices. Composite indices use two boxes with suffixes '_1' and '_2'.
_INDEX_REGIONS = {
    'n34': {'lat_min': -5, 'lat_max': 5, 'lon_min': -170, 'lon_max': -120},
    'n3': {'lat_min': -5, 'lat_max': 5, 'lon_min': -150, 'lon_max': -90},
    'n4_1': {'lat_min': -5, 'lat_max': 5, 'lon_min': -180, 'lon_max': -150},
    'n4_2': {'lat_min': -5, 'lat_max': 5, 'lon_min': 160, 'lon_max': 180},
    'wpg': {'lat_min': 0, 'lat_max': 20, 'lon_min': 130, 'lon_max': 150},
    'dmi_1': {'lat_min': -10, 'lat_max': 10, 'lon_min': 50, 'lon_max': 70},  # West
    'dmi_2': {'lat_min': -10, 'lat_max': 0, 'lon_min': 90, 'lon_max': 110},  # East
    'sji850': {'lat_min': 0, 'lat_max': 15, 'lon_min': 35, 'lon_max': 50},
    'sji200': {'lat_min': 0, 'lat_max': 15, 'lon_min': 35, 'lon_max': 50},
    'ueq850': {'lat_min': -4, 'lat_max': 4, 'lon_min': 60, 'lon_max': 90},
    'ueq200': {'lat_min': -4, 'lat_max': 4, 'lon_min': 60, 'lon_max': 90},
    'wp': {'lat_min': -15, 'lat_max': 20, 'lon_min': 120, 'lon_max': 160},
    'wnp_1': {'lat_min': 20, 'lat_max': 35, 'lon_min': 160, 'lon_max': 180},
    'wnp_2': {'lat_min': 20, 'lat_max': 35, 'lon_min': -180, 'lon_max': -150},
    'wsp_1': {'lat_min': -30, 'lat_max': -15, 'lon_min': 155, 'lon_max': 180},
    'wsp_2': {'lat_min': -30, 'lat_max': -15, 'lon_min': -180, 'lon_max': -150},
}


def get_region_indices(region):
    """
    Return the bounding box of a named index region.

    Parameters:
    - region (str): one of
        - 'n34': Nino 3.4 region
        - 'n3': Nino 3 region
        - 'n4_1', 'n4_2': the two parts of the Nino 4 region, split at the dateline
        - 'wpg': western-Pacific box of the West Pacific Gradient index (WPG contrasts
          the Nino 4 mean with this box)
        - 'dmi_1': Dipole Mode Index, western pole
        - 'dmi_2': Dipole Mode Index, eastern pole
        - 'sji850', 'sji200': zonal-wind box 0-15N, 35-50E at 850 / 200 hPa
          (presumably the Somali Jet index — confirm)
        - 'ueq850', 'ueq200': equatorial zonal-wind box 4S-4N, 60-90E at 850 / 200 hPa
        - 'wp': Western Pacific region
        - 'wnp_1', 'wnp_2': Western North Pacific, split at the dateline
        - 'wsp_1', 'wsp_2': Western South Pacific, split at the dateline

    Returns:
    - dict: a new dictionary with keys 'lat_min', 'lat_max', 'lon_min', 'lon_max'. A copy is
      returned, so callers may modify it without altering the shared region table.

    Raises:
    - KeyError: if `region` is not one of the names above.
    """
    try:
        bounds = _INDEX_REGIONS[region]
    except KeyError:
        raise KeyError(
            f"Unknown region {region!r}; valid regions are: {', '.join(_INDEX_REGIONS)}"
        ) from None
    return dict(bounds)


def standardize_index(data, index_name, period_clm, year_fcst, month_init, before = False):
    """
    Standardized anomaly of a climate index for the month preceding `month_init`.

    NOTE(review): this definition shadows `standardize_index` imported from
    confer_wp3.lasso_forecast at the top of the notebook.
    NOTE(review): the lat/lon slices assume ascending coordinates; with descending latitude
    (common in raw ERA5), slice(lat_min, lat_max) selects nothing — confirm load_raw_data sorts.

    Parameters:
    - data: gridded xarray object with 'year', 'month', 'lat' and 'lon' coordinates.
    - index_name (str): 'dmi' (west minus east pole) and 'wpg' (Nino 4 minus the 'wpg' box)
      are differences of two area means; 'n4', 'wnp' and 'wsp' average over their two boxes;
      any other name is a single box looked up with get_region_indices.
    - period_clm (sequence of two ints): first and last year (inclusive) of the reference period.
    - year_fcst, month_init (int): forecast year and initialization month. The index value of
      the preceding month is used (December of the previous year when month_init == 1).
    - before (bool): for 'dmi' and 'wpg', standardize both components (per calendar month)
      before differencing and return that difference without further standardization.
      For all other indices the flag does not change the result.

    Returns:
    - The standardized anomaly of the index in the month preceding `month_init`.

    Raises:
    - ValueError: if month_init == 1 in the first year of `data` (no preceding December).
    """
    # Check if trying to calculate for first month and year
    first_year = data['year'].min().item()
    if year_fcst == first_year and month_init == 1:
        raise ValueError("Cannot calculate the index for the first month in the first year of the dataset.")

    def area_mean(*region_names):
        # Mean over the grid points of one or more boxes. Several boxes are concatenated
        # along 'lat' first, so the mean is taken over the union of their grid points.
        subsets = []
        for region_name in region_names:
            region = get_region_indices(region_name)
            subsets.append(data.sel(lat=slice(region['lat_min'], region['lat_max']),
                                    lon=slice(region['lon_min'], region['lon_max'])))
        combined = subsets[0] if len(subsets) == 1 else xr.concat(subsets, dim='lat')
        return combined.mean(dim=['lat', 'lon'])

    def standardize_all_months(series):
        # Standardize every (year, month) value with that calendar month's mean and
        # sample standard deviation over the reference period
        ref = series.sel(year=slice(period_clm[0], period_clm[1]))
        return (series - ref.mean(dim='year')) / ref.std(dim='year', ddof=1)

    # Calculate the index for the entire dataset
    components_standardized = False
    if index_name == "dmi":
        index_west = area_mean("dmi_1")
        index_east = area_mean("dmi_2")
        if before:
            index = standardize_all_months(index_west) - standardize_all_months(index_east)
            components_standardized = True
        else:
            index = index_west - index_east
    elif index_name == "wpg":
        index_wp = area_mean("wpg")
        index_n4 = area_mean("n4_1", "n4_2")
        if before:
            index = standardize_all_months(index_n4) - standardize_all_months(index_wp)
            components_standardized = True
        else:
            index = index_n4 - index_wp
    elif index_name in ["n4", "wnp", "wsp"]:
        index = area_mean(f"{index_name}_1", f"{index_name}_2")
    else:
        index = area_mean(index_name)

    # The predictor month is the one preceding the initialization month
    if month_init == 1:
        year_pred, month_pred = year_fcst - 1, 12
    else:
        year_pred, month_pred = year_fcst, month_init - 1
    current_datapoint = index.sel(year=year_pred, month=month_pred)

    if components_standardized:
        # Already standardized component-wise; return the difference as is
        return current_datapoint

    # Standardize with the climatology of the predictor month over the reference period.
    # (Previously, before=True for non-composite indices returned the raw area mean here.)
    ref_data = index.sel(year=slice(period_clm[0], period_clm[1]))
    climatology = ref_data.mean(dim=['year']).sel(month=month_pred)
    climatology_std = ref_data.std(dim=['year'], ddof=1).sel(month=month_pred)
    return (current_datapoint - climatology) / climatology_std


def standardize_index_diff1(data, index_name, period_clm, year_fcst, month_init, before = False):
    """
    Standardized anomaly of the month-to-month change ("diff1") of an SST index.

    Only 'n34' and 'dmi' are supported. The change is the index value in
    (year_fcst, month_init) minus the value of the preceding month. It is standardized with
    the mean and sample standard deviation of all month-to-month changes within period_clm,
    pooled over all calendar months.

    NOTE(review): standardize_index uses data of the month *preceding* month_init, and
    prepare_time_series_data labels this function's result with month_init - 1. Here,
    however, the value of month_init itself enters the difference. Confirm whether
    index[month_init - 1] - index[month_init - 2] was intended (possible look-ahead).
    NOTE(review): unlike standardize_index, the climatology of the changes is pooled over all
    calendar months rather than computed per month. Confirm this is intended.
    NOTE(review): this definition shadows `standardize_index_diff1` imported from
    confer_wp3.lasso_forecast at the top of the notebook.

    Parameters:
    - data: gridded xarray object with 'year', 'month', 'lat' and 'lon' coordinates.
    - index_name (str): 'n34' or 'dmi'.
    - period_clm (sequence of two ints): first and last year (inclusive) of the reference period.
    - year_fcst, month_init (int): year and month whose change from the previous month is evaluated.
    - before (bool): if True, the index is first standardized per calendar month over period_clm.
      For 'dmi', the two poles are standardized before differencing and the difference is
      standardized again. The changes are then computed from the standardized series.
      If False, the changes are computed from the raw index.

    Returns:
    - The standardized anomaly of the month-to-month change.

    Raises:
    - ValueError: if month_init == 1 in the first year of `data`.
    - TypeError: if index_name is not 'n34' or 'dmi' (a message is printed first).
    """
    # Check if trying to calculate for first month and year
    first_year = data['year'].min().item()
    if year_fcst == first_year and month_init == 1:
        raise ValueError("Cannot calculate the index for the first month in the first year of the dataset.")

    # Calculate the index for the entire dataset
    if index_name == "n34":
        region = get_region_indices(index_name)
        subset = data.sel(lat=slice(region['lat_min'], region['lat_max']), lon=slice(region['lon_min'], region['lon_max']))
        index = subset.mean(dim=['lat', 'lon'])
    elif index_name == "dmi":
        region_dmi_1 = get_region_indices("dmi_1")
        region_dmi_2 = get_region_indices("dmi_2")
        subset_dmi_1 = data.sel(lat=slice(region_dmi_1['lat_min'], region_dmi_1['lat_max']), lon=slice(region_dmi_1['lon_min'], region_dmi_1['lon_max']))
        subset_dmi_2 = data.sel(lat=slice(region_dmi_2['lat_min'], region_dmi_2['lat_max']), lon=slice(region_dmi_2['lon_min'], region_dmi_2['lon_max']))
        index_dmi_1 = subset_dmi_1.mean(dim=['lat', 'lon'])
        index_dmi_2 = subset_dmi_2.mean(dim=['lat', 'lon'])

        if before:
                # Select the reference period from the indices
                ref_data_dmi_1 = index_dmi_1.sel(year=slice(period_clm[0], period_clm[1]))
                ref_data_dmi_2 = index_dmi_2.sel(year=slice(period_clm[0], period_clm[1]))

                # Calculate the climatology (mean) and standard deviation during the reference period for both indices
                climatology_dmi_1 = ref_data_dmi_1.mean(dim='year')
                climatology_std_dmi_1 = ref_data_dmi_1.std(dim='year', ddof=1)
                climatology_dmi_2 = ref_data_dmi_2.mean(dim='year')
                climatology_std_dmi_2 = ref_data_dmi_2.std(dim='year', ddof=1)

                # Standardize the entire index data
                standardized_index_dmi_1 = (index_dmi_1 - climatology_dmi_1) / climatology_std_dmi_1
                standardized_index_dmi_2 = (index_dmi_2 - climatology_dmi_2) / climatology_std_dmi_2

                # Calculate the difference between the standardized indices
                index = standardized_index_dmi_1 - standardized_index_dmi_2

        else:
            # Calculate the difference between the indices
            index = index_dmi_1 - index_dmi_2
    else:
        print(f"Diff1 not implemented for index {index_name}")
        raise TypeError(f"Diff1 not implemented for index {index_name}")

    # Select the reference period from the index
    ref_data = index.sel(year=slice(period_clm[0], period_clm[1]))

    # NOTE(review): the loops below are re-run on every call, and prepare_time_series_data
    # calls this function once per month. Caching or vectorizing them would speed things up.
    if before:
        # Calculate the climatology (mean) and standard deviation during the reference period
        # (mean over 'year' only, i.e. one value per calendar month)
        climatology = ref_data.mean(dim='year')
        climatology_std = ref_data.std(dim='year', ddof=1)

        # Standardize the entire index data
        standardized_index = (index - climatology) / climatology_std

        # Calculate the differences for all 12 months, including January using December from the previous year
        diff_list = []
        for year in range(period_clm[0], period_clm[1] + 1):
            for month in range(1, 13):
                if month == 1:
                    if year > period_clm[0]:  # Ensure we have the previous year's December data
                        current_value = standardized_index.sel(year=year, month=1)
                        previous_value = standardized_index.sel(year=year-1, month=12)
                        diff = (current_value - previous_value).assign_coords(year=year, month=month)
                        diff_list.append(diff)
                else:
                    current_value = standardized_index.sel(year=year, month=month)
                    previous_value = standardized_index.sel(year=year, month=month-1)
                    diff = (current_value - previous_value).assign_coords(year=year, month=month)
                    diff_list.append(diff)

    else:
        # Calculate the differences for all 12 months, including January using December from the previous year
        diff_list = []

        for year in range(period_clm[0], period_clm[1] + 1):
            for month in range(1, 13):
                if month == 1:
                    if year > period_clm[0]:  # Ensure we have the previous year's December data
                        current_value = ref_data.sel(year=year, month=1)
                        previous_value = ref_data.sel(year=year-1, month=12)
                        diff = (current_value - previous_value).assign_coords(year=year, month=month)
                        diff_list.append(diff)
                else:
                    current_value = ref_data.sel(year=year, month=month)
                    previous_value = ref_data.sel(year=year, month=month-1)
                    diff = (current_value - previous_value).assign_coords(year=year, month=month)
                    diff_list.append(diff)

    # Each diff carries scalar year/month coordinates; concatenating them along a new 'time'
    # dimension and reducing over it pools the changes of all calendar months together.
    diff_data = xr.concat(diff_list, dim='time')

    climatology = diff_data.mean(dim='time')
    climatology_std = diff_data.std(dim='time', ddof = 1)

    if before:
        # Standardize the current data point based on data from the previous month
        if month_init == 1:
            current_datapoint = standardized_index.sel(year = year_fcst, month = 1)
            previous_datapoint = standardized_index.sel(year = year_fcst-1, month = 12)
        else:
            current_datapoint = standardized_index.sel(year = year_fcst, month = month_init)
            previous_datapoint = standardized_index.sel(year = year_fcst, month = month_init-1)
    else:
        # Standardize the current data point based on data from the previous month
        if month_init == 1:
            current_datapoint = index.sel(year = year_fcst, month = 1)
            previous_datapoint = index.sel(year = year_fcst-1, month = 12)
        else:
            current_datapoint = index.sel(year = year_fcst, month = month_init)
            previous_datapoint = index.sel(year = year_fcst, month = month_init-1)

    # Calculate the difference between the current and previous month
    difference = current_datapoint - previous_datapoint

    anomalies = difference - climatology
    standardized_anomalies = anomalies / climatology_std

    return standardized_anomalies

#### Code for plotting time series

In [None]:
def prepare_time_series_data(data, index_name, period_clm, period_train, months, diff1 = False):
    """
    Build a time series of standardized index values over the training period.

    For every year in period_train (inclusive) and every month in `months`, the index is
    evaluated as for a forecast initialized in that month. The record is labelled with the
    preceding (year, month), i.e. the month whose data standardize_index uses.

    Parameters:
    - data: gridded xarray Dataset holding an 'uwind' variable (for the ueq*/sji* indices)
      or an 'sst' variable (all other indices).
    - index_name (str): index name understood by standardize_index / standardize_index_diff1.
    - period_clm (sequence of two ints): reference period for the standardization.
    - period_train (sequence of two ints): first and last year (inclusive) to evaluate.
    - months (iterable of int): initialization months to evaluate in each year.
    - diff1 (bool): use standardize_index_diff1 (with before=True) instead of standardize_index.

    Returns:
    - pandas.DataFrame with columns 'year', 'month' and 'standardized_anomaly' (float).
    """
    # Wind-based indices are read from the 'uwind' variable, all others from 'sst'
    variable = "uwind" if index_name in ["ueq850", "ueq200", "sji850", "sji200"] else "sst"

    records = []
    for year in range(period_train[0], period_train[1] + 1):
        # January of the first training year would be labelled with December of the previous
        # year (outside the training period); standardize_index also rejects it when this is
        # the first year of `data`. So skip January there, wherever it appears in `months`.
        if year == period_train[0]:
            months_loop = [month for month in months if month != 1]
        else:
            months_loop = months
        for month in months_loop:
            if diff1:
                result = standardize_index_diff1(data, index_name, period_clm, year, month, before = True)
            else:
                result = standardize_index(data, index_name, period_clm, year, month)
            # Store a plain float so the column gets a numeric dtype instead of
            # (possibly) object dtype holding 0-d arrays
            value = float(result[variable].values)

            if month == 1:
                year_prev, month_prev = year - 1, 12
            else:
                year_prev, month_prev = year, month - 1

            records.append({
                'year': year_prev,
                'month': month_prev,
                'standardized_anomaly': value
            })

    return pd.DataFrame(records)

In [None]:
def process_reference_index(reference_df, period_train):
    """
    Restrict a reference index table to the training period.

    Parameters:
    - reference_df (pandas.DataFrame): table with at least 'year', 'month' and 'fl' columns.
    - period_train (sequence of two ints): first and last year, both inclusive.

    Returns:
    - pandas.DataFrame with only the 'year', 'month' and 'fl' columns, the rows whose year lies
      in the training period, and duplicate rows dropped (original index labels are kept).
    """
    in_period = reference_df['year'].between(period_train[0], period_train[1])
    return reference_df.loc[in_period, ['year', 'month', 'fl']].drop_duplicates()

In [None]:
def plot_time_series(df, index_name, comparison = False):
    """
    Plot the calculated index time series, optionally against the reference series.

    Parameters:
    - df (pandas.DataFrame): needs 'year', 'month' and 'standardized_anomaly' columns, plus an
      'fl' column (reference values) when `comparison` is True. It is not modified.
    - index_name (str): index name used in the title, labels and legend.
    - comparison (bool): also plot the reference series ('fl') as a dashed line.
    """
    # Build the date column on a new frame so the caller's DataFrame is left untouched
    dates = pd.to_datetime(df[['year', 'month']].assign(day=1))
    plot_df = df.assign(date=dates).sort_values('date')

    fig, ax = plt.subplots(figsize=(12, 6))
    if comparison:
        ax.plot(plot_df['date'], plot_df['standardized_anomaly'], label=f"{index_name} Calculated", color='b')
        ax.plot(plot_df['date'], plot_df['fl'], label=f"{index_name} Reference", color='r', linestyle='--')
    else:
        ax.plot(plot_df['date'], plot_df['standardized_anomaly'], label=f"{index_name} Index", color='b')
    ax.set_title(f"Time Series of {index_name} Index")
    ax.set_xlabel("Time")
    ax.set_ylabel(f"{index_name} Index Value")
    ax.legend()
    ax.grid(True)
    plt.show()


#### Calculating and verifying indices

In [None]:
# Settings shared between all indices
months = list(range(1, 13))
# TODO: standardize_index raises for January of the first data year; handle that case
start_year = year_train_start
end_year = year_train_end
forecast_year = year_fcst
# Reference indices computed for the climatological reference period in use. The directory
# name is derived from period_clm, so the comparison always matches the standardization.
# NOTE(review): this path is rooted at '/nr/samba/...' while indices_dir uses
# '/home/michael/nr/samba/...'; confirm which root is correct.
filepath_indices = f'/nr/samba/PostClimDataNoBackup/CONFER/EASP/fls/predictors/refper_{year_clm_start}-{year_clm_end}/indices/'

In [None]:
# Compute the time series of every predictor index over the training period.
# - '*_diff1' features: month-to-month change of the base SST index ('n34_diff1' -> 'n34')
# - ueq850/sji850 and ueq200/sji200: zonal wind at the matching level
# - everything else: SST
feature_names = ['n34','n3','n4','dmi','n34_diff1','dmi_diff1','wsp','wpg','wp','wnp','ueq850','ueq200','sji850','sji200']
diff1_suffix = "_diff1"
feature_dfs = {}
for feature in feature_names:
    if feature.endswith(diff1_suffix):
        time_series_df = prepare_time_series_data(sst_data, feature[:-len(diff1_suffix)], period_clm, period_train, months, diff1=True)
    elif feature in ['ueq850', 'sji850']:
        # Previously the result was not assigned, so these features silently reused the
        # DataFrame of the preceding feature
        time_series_df = prepare_time_series_data(uwind850_data, feature, period_clm, period_train, months)
    elif feature in ['ueq200', 'sji200']:
        time_series_df = prepare_time_series_data(uwind200_data, feature, period_clm, period_train, months)
    else:
        time_series_df = prepare_time_series_data(sst_data, feature, period_clm, period_train, months)
    feature_dfs[feature] = time_series_df

# Derived index 'wvg': n4 minus the mean of the wp, wnp and wsp series (inner join on year/month)
merge_for_wvg_df = reduce(lambda left, right: pd.merge(left, right, on=['year', 'month']), (
    feature_dfs["n4"].rename(columns={'standardized_anomaly': 'n4'}),
    feature_dfs["wp"].rename(columns={'standardized_anomaly': 'wp'}),
    feature_dfs["wnp"].rename(columns={'standardized_anomaly': 'wnp'}),
    feature_dfs["wsp"].rename(columns={'standardized_anomaly': 'wsp'}))
)
merge_for_wvg_df['standardized_anomaly'] = merge_for_wvg_df['n4'] - (merge_for_wvg_df['wp'] + merge_for_wvg_df['wnp'] + merge_for_wvg_df['wsp']) / 3
feature_dfs["wvg"] = merge_for_wvg_df

In [None]:
# Display the per-feature time series DataFrames computed above
feature_dfs

In [None]:
# Convert each feature's time series to an xarray DataArray indexed by (year, month).
# Re-indexed copies are used so feature_dfs is left unchanged and the cell can be re-run
# (the previous in-place set_index made a second run fail).
dataarrays = {}
for feature, df in feature_dfs.items():
    indexed_df = df.astype({'year': int, 'month': int}).set_index(['year', 'month'])
    # Cast to float so the variable is written to netCDF with a numeric dtype
    dataarrays[feature] = indexed_df['standardized_anomaly'].astype(float).to_xarray()

# Combine DataArrays into a Dataset (aligned on year/month)
ds = xr.Dataset(dataarrays)

# Define the file path
era5_indices_path = f'{val_dir}era5_indices.nc'

# Save the Dataset to a NetCDF file
ds.to_netcdf(era5_indices_path)
print(f"Data saved to {era5_indices_path}")

In [None]:
# Optional round-trip check of the file written above (uncomment to run).
# # Load the NetCDF file into an xarray Dataset
# ds_loaded = xr.open_dataset(era5_indices_path)
# print(f"Data loaded from {era5_indices_path}")

# # Convert the xarray Dataset back to a DataFrame
# df_loaded = ds_loaded.to_dataframe().reset_index()

# # Print the first few rows of the DataFrame to verify
# print(df_loaded.head())

##### Index N34

In [None]:
# Calculated n34 series over the training period
time_series_n34_df = prepare_time_series_data(sst_data, "n34", period_clm, period_train, months)

# Reference n34 series, restricted to the training period
n34_index_reference = pd.read_csv(f"{filepath_indices}n34_full.csv")
time_series_n34_reference_df = process_reference_index(n34_index_reference, period_train)

# Align both on (year, month) and plot them together
merged_df = time_series_n34_df.merge(time_series_n34_reference_df, on=['year', 'month'], suffixes=('_calculated', '_reference'))
plot_time_series(merged_df, "n34", comparison = True)

# Spot-check the first two initialization months of the forecast year against the reference
# value of the preceding month (December of the previous year for January)
for month in range(1, 3):
    if month == 1:
        ref_year, ref_month = forecast_year - 1, 12
    else:
        ref_year, ref_month = forecast_year, month - 1
    print(f"Standardized anomaly (calculated index) for forecast year: {forecast_year}, forecast month = {month} (based on data from month before):")
    print(standardize_index(sst_data, "n34", period_clm, year_fcst=forecast_year, month_init=month).sst.values)
    print(f"Reference value:")
    is_ref_month = (n34_index_reference["year"] == ref_year) & (n34_index_reference["month"] == ref_month)
    print(n34_index_reference[is_ref_month].fl.values[0])
    print("\n")

##### Index N3

In [None]:
# Calculated n3 series over the training period
time_series_n3_df = prepare_time_series_data(sst_data, "n3", period_clm, period_train, months)

# Reference n3 series, restricted to the training period
n3_index_reference = pd.read_csv(f"{filepath_indices}n3_full.csv")
time_series_n3_reference_df = process_reference_index(n3_index_reference, period_train)

# Align both on (year, month) and plot them together
merged_df = time_series_n3_df.merge(time_series_n3_reference_df, on=['year', 'month'], suffixes=('_calculated', '_reference'))
plot_time_series(merged_df, "n3", comparison = True)

# Spot-check the first two initialization months of the forecast year against the reference
# value of the preceding month (December of the previous year for January)
for month in range(1, 3):
    if month == 1:
        ref_year, ref_month = forecast_year - 1, 12
    else:
        ref_year, ref_month = forecast_year, month - 1
    print(f"Standardized anomaly (calculated index) for forecast year: {forecast_year}, forecast month = {month} (based on data from month before):")
    print(standardize_index(sst_data, "n3", period_clm, year_fcst=forecast_year, month_init=month).sst.values)
    print(f"Reference value:")
    is_ref_month = (n3_index_reference["year"] == ref_year) & (n3_index_reference["month"] == ref_month)
    print(n3_index_reference[is_ref_month].fl.values[0])
    print("\n")

##### Index N4

In [None]:
# Calculated n4 series over the training period
time_series_n4_df = prepare_time_series_data(sst_data, "n4", period_clm, period_train, months)

# Reference n4 series, restricted to the training period
n4_index_reference = pd.read_csv(f"{filepath_indices}n4_full.csv")
time_series_n4_reference_df = process_reference_index(n4_index_reference, period_train)

# Align both on (year, month) and plot them together
merged_df = time_series_n4_df.merge(time_series_n4_reference_df, on=['year', 'month'], suffixes=('_calculated', '_reference'))
plot_time_series(merged_df, "n4", comparison = True)

# Spot-check the first two initialization months of the forecast year against the reference
# value of the preceding month (December of the previous year for January)
for month in range(1, 3):
    if month == 1:
        ref_year, ref_month = forecast_year - 1, 12
    else:
        ref_year, ref_month = forecast_year, month - 1
    print(f"Standardized anomaly (calculated index) for forecast year: {forecast_year}, forecast month = {month} (based on data from month before):")
    print(standardize_index(sst_data, "n4", period_clm, year_fcst=forecast_year, month_init=month).sst.values)
    print(f"Reference value:")
    is_ref_month = (n4_index_reference["year"] == ref_year) & (n4_index_reference["month"] == ref_month)
    print(n4_index_reference[is_ref_month].fl.values[0])
    print("\n")

##### Index WPG

In [None]:
# Calculated wpg series over the training period.
# NOTE(review): the plotted series uses before=False (difference standardized afterwards),
# while the spot-check below uses before=True (components standardized first); confirm
# which definition the reference follows.
time_series_wpg_df = prepare_time_series_data(sst_data, "wpg", period_clm, period_train, months)

# Reference wpg series, restricted to the training period
wpg_index_reference = pd.read_csv(f"{filepath_indices}wpg_full.csv")
time_series_wpg_reference_df = process_reference_index(wpg_index_reference, period_train)

# Align both on (year, month) and plot them together
merged_df = time_series_wpg_df.merge(time_series_wpg_reference_df, on=['year', 'month'], suffixes=('_calculated', '_reference'))
plot_time_series(merged_df, "wpg", comparison = True)

# Spot-check the first two initialization months of the forecast year against the reference
# value of the preceding month (December of the previous year for January)
for month in range(1, 3):
    if month == 1:
        ref_year, ref_month = forecast_year - 1, 12
    else:
        ref_year, ref_month = forecast_year, month - 1
    print(f"Standardized anomaly (calculated index) for forecast year: {forecast_year}, forecast month = {month} (based on data from month before):")
    print(standardize_index(sst_data, "wpg", period_clm, year_fcst=forecast_year, month_init=month, before = True).sst.values)
    print(f"Reference value:")
    is_ref_month = (wpg_index_reference["year"] == ref_year) & (wpg_index_reference["month"] == ref_month)
    print(wpg_index_reference[is_ref_month].fl.values[0])
    print("\n")

##### Index DMI

In [None]:
# Calculated dmi series over the training period.
# NOTE(review): the plotted series uses before=False (difference standardized afterwards),
# while the spot-check below uses before=True (poles standardized first); confirm which
# definition the reference follows.
time_series_dmi_df = prepare_time_series_data(sst_data, "dmi", period_clm, period_train, months)

# Reference dmi series, restricted to the training period
dmi_index_reference = pd.read_csv(f"{filepath_indices}dmi_full.csv")
time_series_dmi_reference_df = process_reference_index(dmi_index_reference, period_train)

# Align both on (year, month) and plot them together
merged_df = time_series_dmi_df.merge(time_series_dmi_reference_df, on=['year', 'month'], suffixes=('_calculated', '_reference'))
plot_time_series(merged_df, "dmi", comparison = True)

# Spot-check the first two initialization months of the forecast year against the reference
# value of the preceding month (December of the previous year for January)
for month in range(1, 3):
    if month == 1:
        ref_year, ref_month = forecast_year - 1, 12
    else:
        ref_year, ref_month = forecast_year, month - 1
    print(f"Standardized anomaly (calculated index) for forecast year: {forecast_year}, forecast month = {month} (based on data from month before):")
    print(standardize_index(sst_data, "dmi", period_clm, year_fcst=forecast_year, month_init=month, before = True).sst.values)
    print(f"Reference value:")
    is_ref_month = (dmi_index_reference["year"] == ref_year) & (dmi_index_reference["month"] == ref_month)
    print(dmi_index_reference[is_ref_month].fl.values[0])
    print("\n")

##### Index SJI850

In [None]:
# Calculated SJI850 index (from 850 hPa zonal wind) for every training year/month
time_series_sji850_df = prepare_time_series_data(uwind850_data, "sji850", period_clm, period_train, months)

# Reference SJI850 series; the monthly reference values are stored in column `fl`
sji850_index_reference = pd.read_csv(f"{filepath_indices}sji850_full.csv")
time_series_sji850_reference_df = process_reference_index(sji850_index_reference, period_train)

# Put calculated and reference values side by side on (year, month) and plot them
merged_df = time_series_sji850_df.merge(
    time_series_sji850_reference_df,
    on=['year', 'month'],
    suffixes=('_calculated', '_reference'),
)
plot_time_series(merged_df, "sji850", comparison=True)

# Spot-check initialisation months 1 and 2 against the reference value of the preceding month
for month in range(1, 3):
    print(f"Standardized anomaly (calculated index) for forecast year: {forecast_year}, forecast month = {month} (based on data from month before):")
    # NOTE(review): the WPG/DMI cells pass before=True but this call does not;
    # confirm the default of standardize_index also uses the month before.
    print(standardize_index(uwind850_data, "sji850", period_clm, year_fcst=forecast_year, month_init=month).uwind.values)
    print("Reference value:")
    # A January initialisation looks back to December of the previous year
    _ref_year, _ref_month = (forecast_year - 1, 12) if month == 1 else (forecast_year, month - 1)
    _ref_rows = (sji850_index_reference["year"] == _ref_year) & (sji850_index_reference["month"] == _ref_month)
    print(sji850_index_reference.loc[_ref_rows, "fl"].values[0])
    print("\n")

##### Index SJI200

In [None]:
# Calculated SJI200 index (from 200 hPa zonal wind) for every training year/month
time_series_sji200_df = prepare_time_series_data(uwind200_data, "sji200", period_clm, period_train, months)

# Reference SJI200 series; the monthly reference values are stored in column `fl`
sji200_index_reference = pd.read_csv(f"{filepath_indices}sji200_full.csv")
time_series_sji200_reference_df = process_reference_index(sji200_index_reference, period_train)

# Put calculated and reference values side by side on (year, month) and plot them
merged_df = time_series_sji200_df.merge(
    time_series_sji200_reference_df,
    on=['year', 'month'],
    suffixes=('_calculated', '_reference'),
)
plot_time_series(merged_df, "sji200", comparison=True)

# Spot-check initialisation months 1 and 2 against the reference value of the preceding month
for month in range(1, 3):
    print(f"Standardized anomaly (calculated index) for forecast year: {forecast_year}, forecast month = {month} (based on data from month before):")
    # NOTE(review): the WPG/DMI cells pass before=True but this call does not;
    # confirm the default of standardize_index also uses the month before.
    print(standardize_index(uwind200_data, "sji200", period_clm, year_fcst=forecast_year, month_init=month).uwind.values)
    print("Reference value:")
    # A January initialisation looks back to December of the previous year
    _ref_year, _ref_month = (forecast_year - 1, 12) if month == 1 else (forecast_year, month - 1)
    _ref_rows = (sji200_index_reference["year"] == _ref_year) & (sji200_index_reference["month"] == _ref_month)
    print(sji200_index_reference.loc[_ref_rows, "fl"].values[0])
    print("\n")

##### Index UEQ850

In [None]:
# Calculated UEQ850 index (from 850 hPa zonal wind) for every training year/month
time_series_ueq850_df = prepare_time_series_data(uwind850_data, "ueq850", period_clm, period_train, months)

# Reference UEQ850 series; the monthly reference values are stored in column `fl`
ueq850_index_reference = pd.read_csv(f"{filepath_indices}ueq850_full.csv")
time_series_ueq850_reference_df = process_reference_index(ueq850_index_reference, period_train)

# Put calculated and reference values side by side on (year, month) and plot them
merged_df = time_series_ueq850_df.merge(
    time_series_ueq850_reference_df,
    on=['year', 'month'],
    suffixes=('_calculated', '_reference'),
)
plot_time_series(merged_df, "ueq850", comparison=True)

# Spot-check initialisation months 1 and 2 against the reference value of the preceding month
for month in range(1, 3):
    print(f"Standardized anomaly (calculated index) for forecast year: {forecast_year}, forecast month = {month} (based on data from month before):")
    # NOTE(review): the WPG/DMI cells pass before=True but this call does not;
    # confirm the default of standardize_index also uses the month before.
    print(standardize_index(uwind850_data, "ueq850", period_clm, year_fcst=forecast_year, month_init=month).uwind.values)
    print("Reference value:")
    # A January initialisation looks back to December of the previous year
    _ref_year, _ref_month = (forecast_year - 1, 12) if month == 1 else (forecast_year, month - 1)
    _ref_rows = (ueq850_index_reference["year"] == _ref_year) & (ueq850_index_reference["month"] == _ref_month)
    print(ueq850_index_reference.loc[_ref_rows, "fl"].values[0])
    print("\n")

##### Index UEQ200

In [None]:
# Calculated UEQ200 index (from 200 hPa zonal wind) for every training year/month
time_series_ueq200_df = prepare_time_series_data(uwind200_data, "ueq200", period_clm, period_train, months)

# Reference UEQ200 series; the monthly reference values are stored in column `fl`
ueq200_index_reference = pd.read_csv(f"{filepath_indices}ueq200_full.csv")
time_series_ueq200_reference_df = process_reference_index(ueq200_index_reference, period_train)

# Put calculated and reference values side by side on (year, month) and plot them
merged_df = time_series_ueq200_df.merge(
    time_series_ueq200_reference_df,
    on=['year', 'month'],
    suffixes=('_calculated', '_reference'),
)
plot_time_series(merged_df, "ueq200", comparison=True)

# Spot-check initialisation months 1 and 2 against the reference value of the preceding month
for month in range(1, 3):
    print(f"Standardized anomaly (calculated index) for forecast year: {forecast_year}, forecast month = {month} (based on data from month before):")
    # NOTE(review): the WPG/DMI cells pass before=True but this call does not;
    # confirm the default of standardize_index also uses the month before.
    print(standardize_index(uwind200_data, "ueq200", period_clm, year_fcst=forecast_year, month_init=month).uwind.values)
    print("Reference value:")
    # A January initialisation looks back to December of the previous year
    _ref_year, _ref_month = (forecast_year - 1, 12) if month == 1 else (forecast_year, month - 1)
    _ref_rows = (ueq200_index_reference["year"] == _ref_year) & (ueq200_index_reference["month"] == _ref_month)
    print(ueq200_index_reference.loc[_ref_rows, "fl"].values[0])
    print("\n")

##### Index WP

In [None]:
# Calculated WP index (from SST) for every training year/month
time_series_wp_df = prepare_time_series_data(sst_data, "wp", period_clm, period_train, months)

# Reference WP series; the monthly reference values are stored in column `fl`
wp_index_reference = pd.read_csv(f"{filepath_indices}wp_full.csv")
time_series_wp_reference_df = process_reference_index(wp_index_reference, period_train)

# Put calculated and reference values side by side on (year, month) and plot them
merged_df = time_series_wp_df.merge(
    time_series_wp_reference_df,
    on=['year', 'month'],
    suffixes=('_calculated', '_reference'),
)
plot_time_series(merged_df, "wp", comparison=True)

# Spot-check initialisation months 1 and 2 against the reference value of the preceding month
for month in range(1, 3):
    print(f"Standardized anomaly (calculated index) for forecast year: {forecast_year}, forecast month = {month} (based on data from month before):")
    # NOTE(review): the WPG/DMI cells pass before=True but this call does not;
    # confirm the default of standardize_index also uses the month before.
    print(standardize_index(sst_data, "wp", period_clm, year_fcst=forecast_year, month_init=month).sst.values)
    print("Reference value:")
    # A January initialisation looks back to December of the previous year
    _ref_year, _ref_month = (forecast_year - 1, 12) if month == 1 else (forecast_year, month - 1)
    _ref_rows = (wp_index_reference["year"] == _ref_year) & (wp_index_reference["month"] == _ref_month)
    print(wp_index_reference.loc[_ref_rows, "fl"].values[0])
    print("\n")

##### Index WNP

In [None]:
# Calculated WNP index (from SST) for every training year/month
time_series_wnp_df = prepare_time_series_data(sst_data, "wnp", period_clm, period_train, months)

# Reference WNP series; the monthly reference values are stored in column `fl`
wnp_index_reference = pd.read_csv(f"{filepath_indices}wnp_full.csv")
time_series_wnp_reference_df = process_reference_index(wnp_index_reference, period_train)

# Put calculated and reference values side by side on (year, month) and plot them
merged_df = time_series_wnp_df.merge(
    time_series_wnp_reference_df,
    on=['year', 'month'],
    suffixes=('_calculated', '_reference'),
)
plot_time_series(merged_df, "wnp", comparison=True)

# Spot-check initialisation months 1 and 2 against the reference value of the preceding month
for month in range(1, 3):
    print(f"Standardized anomaly (calculated index) for forecast year: {forecast_year}, forecast month = {month} (based on data from month before):")
    # NOTE(review): the WPG/DMI cells pass before=True but this call does not;
    # confirm the default of standardize_index also uses the month before.
    print(standardize_index(sst_data, "wnp", period_clm, year_fcst=forecast_year, month_init=month).sst.values)
    print("Reference value:")
    # A January initialisation looks back to December of the previous year
    _ref_year, _ref_month = (forecast_year - 1, 12) if month == 1 else (forecast_year, month - 1)
    _ref_rows = (wnp_index_reference["year"] == _ref_year) & (wnp_index_reference["month"] == _ref_month)
    print(wnp_index_reference.loc[_ref_rows, "fl"].values[0])
    print("\n")

##### Index WSP

In [None]:
# Calculated WSP index (from SST) for every training year/month
time_series_wsp_df = prepare_time_series_data(sst_data, "wsp", period_clm, period_train, months)

# Reference WSP series; the monthly reference values are stored in column `fl`
wsp_index_reference = pd.read_csv(f"{filepath_indices}wsp_full.csv")
time_series_wsp_reference_df = process_reference_index(wsp_index_reference, period_train)

# Put calculated and reference values side by side on (year, month) and plot them
merged_df = time_series_wsp_df.merge(
    time_series_wsp_reference_df,
    on=['year', 'month'],
    suffixes=('_calculated', '_reference'),
)
plot_time_series(merged_df, "wsp", comparison=True)

# Spot-check initialisation months 1 and 2 against the reference value of the preceding month
for month in range(1, 3):
    print(f"Standardized anomaly (calculated index) for forecast year: {forecast_year}, forecast month = {month} (based on data from month before):")
    # NOTE(review): the WPG/DMI cells pass before=True but this call does not;
    # confirm the default of standardize_index also uses the month before.
    print(standardize_index(sst_data, "wsp", period_clm, year_fcst=forecast_year, month_init=month).sst.values)
    print("Reference value:")
    # A January initialisation looks back to December of the previous year
    _ref_year, _ref_month = (forecast_year - 1, 12) if month == 1 else (forecast_year, month - 1)
    _ref_rows = (wsp_index_reference["year"] == _ref_year) & (wsp_index_reference["month"] == _ref_month)
    print(wsp_index_reference.loc[_ref_rows, "fl"].values[0])
    print("\n")

##### Index WVG

In [None]:
# WVG is derived from other SST indices: WVG = N4 - (WP + WNP + WSP) / 3

# Join the component anomalies on (year, month), one column per component
merge_for_wvg_df = time_series_n4_df.rename(columns={'standardized_anomaly': 'n4'})
for _component, _component_df in (('wp', time_series_wp_df), ('wnp', time_series_wnp_df), ('wsp', time_series_wsp_df)):
    merge_for_wvg_df = pd.merge(
        merge_for_wvg_df,
        _component_df.rename(columns={'standardized_anomaly': _component}),
        on=['year', 'month'],
    )
merge_for_wvg_df['standardized_anomaly'] = merge_for_wvg_df['n4'] - (merge_for_wvg_df['wp'] + merge_for_wvg_df['wnp'] + merge_for_wvg_df['wsp']) / 3
time_series_wvg_df = merge_for_wvg_df

# Reference WVG series; the monthly reference values are stored in column `fl`
wvg_index_reference = pd.read_csv(f"{filepath_indices}wvg_full.csv")
time_series_wvg_reference_df = process_reference_index(wvg_index_reference, period_train)

# Put calculated and reference values side by side on (year, month) and plot them
merged_df = time_series_wvg_df.merge(
    time_series_wvg_reference_df,
    on=['year', 'month'],
    suffixes=('_calculated', '_reference'),
)
plot_time_series(merged_df, "wvg", comparison=True)

# Spot-check initialisation months 1 and 2 against the reference value of the preceding month
for month in range(1, 3):
    print(f"Standardized anomaly (calculated index) for forecast year: {forecast_year}, forecast month = {month} (based on data from month before):")
    # Component anomalies for this initialisation, in the order n4, wp, wnp, wsp
    n4, wp, wnp, wsp = (
        standardize_index(sst_data, _name, period_clm, year_fcst=forecast_year, month_init=month).sst.values
        for _name in ("n4", "wp", "wnp", "wsp")
    )
    wvg = n4 - (wp + wnp + wsp) / 3
    print(wvg)
    print("Reference value:")
    # A January initialisation looks back to December of the previous year
    _ref_year, _ref_month = (forecast_year - 1, 12) if month == 1 else (forecast_year, month - 1)
    _ref_rows = (wvg_index_reference["year"] == _ref_year) & (wvg_index_reference["month"] == _ref_month)
    print(wvg_index_reference.loc[_ref_rows, "fl"].values[0])
    print("\n")

##### Index N34_DIFF1

In [None]:
# First-difference N34 index (from SST) for every training year/month
time_series_n34_diff1_df = prepare_time_series_data(sst_data, "n34", period_clm, period_train, months, diff1 = True)

# Reference values for index n34_diff1 (monthly values in column `fl`)
n34_diff1_index_reference = pd.read_csv(f"{filepath_indices}n34_diff1_full.csv")
time_series_n34_diff1_reference_df = process_reference_index(n34_diff1_index_reference, period_train)

# Merge and plot the N34_DIFF1 series (calculated vs. reference)
merged_df = pd.merge(time_series_n34_diff1_df, time_series_n34_diff1_reference_df, on=['year', 'month'], suffixes=('_calculated', '_reference'))
plot_time_series(merged_df, "n34_diff1", comparison = True)

# Print some values for comparison
for month in range(1, 3):
    print(f"Standardized anomaly (calculated index) for forecast year: {forecast_year}, forecast month = {month} (based on data from month before):")
    print(standardize_index_diff1(sst_data, "n34", period_clm, year_fcst=forecast_year, month_init=month, before = True).sst.values)
    print(f"Reference value:")
    # A January initialisation looks back to December of the previous year
    if month == 1:
        print(n34_diff1_index_reference[(n34_diff1_index_reference["year"] == forecast_year-1) & (n34_diff1_index_reference["month"] == 12)].fl.values[0])
    else:
        print(n34_diff1_index_reference[(n34_diff1_index_reference["year"] == forecast_year) & (n34_diff1_index_reference["month"] == month-1)].fl.values[0])
    print("\n")

##### Index DMI_DIFF1

In [None]:
# First-difference DMI index (from SST) for every training year/month
time_series_dmi_diff1_df = prepare_time_series_data(sst_data, "dmi", period_clm, period_train, months, diff1 = True)

# Reference values for index dmi_diff1 (monthly values in column `fl`)
dmi_diff1_index_reference = pd.read_csv(f"{filepath_indices}dmi_diff1_full.csv")
time_series_dmi_diff1_reference_df = process_reference_index(dmi_diff1_index_reference, period_train)

# Merge and plot the DMI_DIFF1 series (calculated vs. reference)
merged_df = pd.merge(time_series_dmi_diff1_df, time_series_dmi_diff1_reference_df, on=['year', 'month'], suffixes=('_calculated', '_reference'))
plot_time_series(merged_df, "dmi_diff1", comparison = True)

# Print some values for comparison
for month in range(1, 3):
    print(f"Standardized anomaly (calculated index) for forecast year: {forecast_year}, forecast month = {month} (based on data from month before):")
    print(standardize_index_diff1(sst_data, "dmi", period_clm, year_fcst=forecast_year, month_init=month, before = True).sst.values)
    print(f"Reference value:")
    # A January initialisation looks back to December of the previous year
    if month == 1:
        print(dmi_diff1_index_reference[(dmi_diff1_index_reference["year"] == forecast_year-1) & (dmi_diff1_index_reference["month"] == 12)].fl.values[0])
    else:
        print(dmi_diff1_index_reference[(dmi_diff1_index_reference["year"] == forecast_year) & (dmi_diff1_index_reference["month"] == month-1)].fl.values[0])
    print("\n")

## ML model

In [None]:
# Inputs prepared earlier in the notebook:
# Precip data:
#   eofs_norm_anomalies_reshaped: EOF patterns, numpy array of shape (eof = n_eofs, lat, lon)
#   pcs: principal components, one row per training year
#        (year_train_end - year_train_start + 1 rows) and one column per EOF
#   var_fracs / var_frac_ano_normal: explained-variance fraction per EOF, shape (n_eofs,)
# Features/predictors: dataframes with columns (year, month, standardized_anomaly)
#
# Map each predictor name to its time series. Referencing the variables directly
# (instead of eval on a constructed name) keeps the dependency visible and fails
# with a NameError that names the missing variable.
feature_dfs = {
    'n34': time_series_n34_df,
    'dmi': time_series_dmi_df,
    'wvg': time_series_wvg_df,
    'wsp': time_series_wsp_df,
    'wpg': time_series_wpg_df,
    'wp': time_series_wp_df,
    'wnp': time_series_wnp_df,
    'n34_diff1': time_series_n34_diff1_df,
    'dmi_diff1': time_series_dmi_diff1_df,
    'ueq850': time_series_ueq850_df,
    'ueq200': time_series_ueq200_df,
    'sji850': time_series_sji850_df,
    'sji200': time_series_sji200_df,
}
# Predictor order used for the design matrix
feature_names = list(feature_dfs)

In [None]:
# The repeated MultiTaskLassoCV fits below may stop at max_iter; silence the
# resulting scikit-learn ConvergenceWarning messages.
simplefilter(action="ignore", category=ConvergenceWarning)

In [None]:
# Cross-validation setup: years used for fitting and years to be verified
years_cv = [*range(year_train_start, year_train_end + 1)]
years_verif = [*range(year_clm_start, year_clm_end + 1)]
# Year predictor expressed in decades relative to 2000 (e.g. 2010 -> 1.0)
df_year = pd.DataFrame({'year': (np.array(years_cv) - 2000) / 10}, index=years_cv)

In [None]:
# Number of EOFs used as regression targets
ntg = n_eofs

# Weight of each EOF: square root of its share of the total explained variance
frac_expl_var = var_fracs / np.sum(var_fracs)
wgt_values = np.sqrt(frac_expl_var)
_eof_labels = [f'eof{i}' for i in range(1, ntg + 1)]
# The same row of weights is repeated for every verification year
wgt = pd.DataFrame(np.tile(wgt_values, (len(years_verif), 1)), index=years_verif, columns=_eof_labels)

In [None]:
# Cross-validation folds: fold f holds out consecutive pairs of years at
# positions (2f, 2f+1), (2f+2k, 2f+2k+1), ... and trains on the rest.

k = 5                    # number of folds
nyr = len(years_cv)      # number of training years

cv_folds = []
for fold in range(k):
    held_out = set(range(2 * fold, nyr, 2 * k)) | set(range(2 * fold + 1, nyr, 2 * k))
    kept = set(range(nyr)) - held_out
    cv_folds.append((list(kept), list(held_out)))

In [None]:
# Result containers, one row per verification year
# NOTE(review): df_fl_pred_mean is not filled by the fit loop below; confirm it is still needed.
# NOTE(review): 'cov-{i}{j}' names collide for ntg >= 10 (e.g. cov-111); the columns are only
# accessed positionally below, but consider a separator if they are ever looked up by name.
_fl_columns = [f'fl{i}' for i in range(1, ntg+1)]
_cov_columns = [f'cov-{i}{j}' for i in range(1, ntg+1) for j in range(1, ntg+1)]

df_fl_pred_mean = pd.DataFrame(index=years_verif, columns=_fl_columns)
df_fl_pred_cov = pd.DataFrame(index=years_verif, columns=_cov_columns)
df_hyperparameters = pd.DataFrame(index=years_verif, columns=['alpha', 'l1_ratio'])
df_selected_features = pd.DataFrame(index=years_verif, columns=['year'] + feature_names, dtype=int)

In [None]:
# Fit models and make predictions
# Predictors are taken from the month preceding the initialisation month (December for January).
previous_month = month_init - 1 if month_init > 1 else 12

for iyr in years_verif:
    # NOTE(review): neither the target nor the predictors below depend on `iyr`, so every
    # iteration fits the same model on all training years; confirm whether a
    # leave-one-year-out split (excluding `iyr`) was intended.
    df_target = pd.DataFrame(pcs[:, :ntg], index=years_cv).reindex(years_cv)

    df_combined_features = pd.DataFrame(index = years_cv)

    # Select the data for the month before month_init and combine features
    for feature in feature_names:
        df_feature = feature_dfs[feature]
        df_feature_selected = df_feature[df_feature['month'] == previous_month].set_index('year')['standardized_anomaly']
        df_combined_features[feature] = df_feature_selected

    # Years without a predictor value are treated as a zero anomaly
    df_combined_features.fillna(0, inplace=True)

    y = df_target.to_numpy()
    X = df_combined_features.to_numpy()

    # Feature pre-selection: keep a predictor if its correlation with at least one PC has a
    # p-value below 0.1 times that PC's weight. feature_idx[0] stands for 'year' (always kept).
    feature_idx = [True] + [False] * len(df_combined_features.columns)
    for ift in range(len(feature_idx) - 1):
        pval = [pearsonr(y[:, ipc], X[:, ift])[1] for ipc in range(ntg)]
        feature_idx[1 + ift] = np.any(np.array(pval) < 0.1 * wgt.iloc[0, :])

    df_combined_features = df_combined_features.iloc[:, feature_idx[1:]]
    df_year = pd.DataFrame((df_combined_features.index - 2000) / 10, index=df_combined_features.index, columns=['year'])
    # Column order of the design matrix built below: 'year' first, then the kept predictors
    selected_columns = ['year'] + df_combined_features.columns.to_list()

    # Ensure df_selected_features has the correct columns
    for col in selected_columns:
        if col not in df_selected_features.columns:
            df_selected_features[col] = 0

    if sum(feature_idx) == 1:
        df_combined_features = df_year
    else:
        df_combined_features = pd.concat([df_year, df_combined_features], axis=1)

    X = df_combined_features.to_numpy()
    # Lasso regression (MultiTaskLassoCV is a pure L1 penalty, recorded as l1_ratio = 1.0)
    clf = MultiTaskLassoCV(cv=cv_folds, fit_intercept=False, max_iter=5000)
    clf.fit(X, y)
    df_hyperparameters.loc[iyr, 'alpha'] = clf.alpha_
    df_hyperparameters.loc[iyr, 'l1_ratio'] = 1.0

    # clf.coef_ is (n_targets, n_columns) with columns in `selected_columns` order, so
    # ind_active already includes the 'year' column at position 0 — no insertion needed.
    ind_active = np.all(clf.coef_ != 0, axis=0)
    assert len(ind_active) == len(selected_columns), "coefficient columns out of sync with selected_columns"
    n_a = sum(ind_active)
    if n_a > 0:
        df_selected_features.loc[iyr, selected_columns] = ind_active.astype(int)
        dgf = 1 + n_a
    else:
        dgf = 1
    df_coefficients = pd.DataFrame(clf.coef_, index=[f'eof{i}' for i in range(1,ntg+1)], columns=df_combined_features.columns)

    # Residual covariance of the PC predictions, flattened row-major into the cov-ij columns
    errors = y - clf.predict(X)
    df_fl_pred_cov.loc[iyr, :] = np.broadcast_to((np.dot(errors.T, errors) / (nyr - dgf)).flatten(), (1, ntg ** 2))

In [None]:
# Plot 1: LASSO regression coefficients (rows: EOFs, columns: predictors).
# The layout was meant to compare CHIRPS with GPCC (short/long training periods);
# only the CHIRPS coefficients are available here, so the figure has one panel
# per entry in `coef_panels` instead of two empty extra axes.

# NOTE(review): month/season are hard-coded here and `season` is reused by later cells;
# confirm they match month_init and the season the EOFs/coefficients were computed for.
month = 8
season = 'OND'

month_str = {1:"January", 2:"February", 3:"March", 4:"April", 5:"May", 6:"June", 7:"July", 8:"August", 9:"September", 10:"October", 11:"November", 12:"December"}[month]

# (target product, training-period label, coefficient table) for each panel
# NOTE(review): '1981-2020' is hard-coded; confirm it matches year_train_start-year_train_end.
coef_panels = [('chirps', '1981-2020', df_coefficients)]
# Symmetric colour limit per season
coef_vlim = {'MAM': 80., 'JJAS': 160., 'OND': 230.}[season]

fig, ax = plt.subplots(1, len(coef_panels), figsize=(5.5 * len(coef_panels), 4), squeeze=False)
for i, (target_prod, cv_period, coefs) in enumerate(coef_panels):
    n_feat = len(coefs.columns)
    n_eof = len(coefs.index)
    panel = ax[0, i]
    img = panel.imshow(coefs, vmin=-coef_vlim, vmax=coef_vlim, cmap='bwr', extent=[0, n_feat, 0, n_eof])
    panel.set_xticks(np.arange(.5, n_feat))
    panel.set_xticklabels(coefs.columns.to_list(), rotation=90, fontsize=10)
    panel.set_yticks(np.arange(n_eof - .5, 0, -1))
    panel.set_yticklabels(coefs.index.to_list(), fontsize=10)
    panel.set_title(f'{target_prod.upper()}, training period: {cv_period}')

fig.subplots_adjust(right=0.85)
cbar_ax = fig.add_axes([0.9, 0.15, 0.03, 0.6])
fig.colorbar(img, cax=cbar_ax)

In [None]:
def calculate_tercile_probability_forecasts(season, year_fcst, month_init, period_clm, scaling, eofs, eigenvalues, coefficients, fl_eof_cov, ts_indices):
    """Compute below-/above-normal tercile probabilities from the EOF-based LASSO model.

    The factor loadings are predicted as ``coefficients @ ts_indices``. They are
    projected back onto the EOF patterns, and the resulting Gaussian forecast is
    standardized by ``scaling`` and compared with the 1/3 and 2/3 quantiles of a
    standard normal distribution.

    Parameters
    ----------
    season, month_init, period_clm : unused; kept for interface compatibility.
    year_fcst : int
        Forecast year. If ``ts_indices`` has a 'year' entry, it is set to
        ``(year_fcst - 2000) / 10``.
    scaling : ndarray
        Spatial field used to standardize the forecast; broadcast against ``eofs[0]``.
    eofs : ndarray, shape (n_eofs, ...spatial dims)
        EOF patterns; the first axis matches ``eigenvalues`` and the rows of ``coefficients``.
    eigenvalues : ndarray, shape (n_eofs,)
    coefficients : DataFrame, shape (n_eofs, n_predictors)
    fl_eof_cov : ndarray, shape (n_eofs, n_eofs)
        Predictive covariance of the factor loadings.
    ts_indices : Series indexed exactly like ``coefficients.columns``.
        It is not modified; a copy is used.

    Returns
    -------
    (prob_bn, prob_an) : ndarrays with the spatial shape of ``eofs[0]``.

    Raises
    ------
    ValueError
        If ``coefficients.columns`` and ``ts_indices.index`` differ.
    """
    # Local import: `norm` is not among the notebook's top-level imports
    from scipy.stats import norm

    # Ensure columns and indices alignment (same labels, same order)
    if not coefficients.columns.equals(ts_indices.index):
        raise ValueError("Columns of coefficients and indices of ts_indices are not aligned.")

    # Residual variance: total variance minus the part represented by the retained EOFs
    var_eps = scaling**2 - np.sum(eigenvalues[:, None, None] * eofs**2, axis=0)

    # Work on a copy so the caller's Series is left untouched
    predictors = ts_indices.copy()
    if 'year' in predictors.index:
        # Year predictor in decades relative to 2000, as used when fitting
        predictors['year'] = (year_fcst - 2000) / 10

    # Calculate predictive mean of factor loadings
    fl_eof_mean = coefficients.dot(predictors).to_numpy()

    # Calculate mean and variance of the probabilistic forecast in normal space
    mean_ml = np.sum(fl_eof_mean[:, None, None] * eofs, axis=0)
    var_ml = np.sum(np.sum(fl_eof_cov[:, :, None, None] * eofs[None, :, :, :], axis=1) * eofs, axis=0) + var_eps

    mean_ml = np.array(mean_ml, dtype=np.float64)
    var_ml = np.array(var_ml, dtype=np.float64)

    mean_ml_stdz = mean_ml / scaling
    stdv_ml_stdz = np.sqrt(var_ml) / scaling

    # Calculate tercile forecasts
    prob_bn = norm.cdf((norm.ppf(0.333) - mean_ml_stdz) / stdv_ml_stdz)
    prob_an = 1.0 - norm.cdf((norm.ppf(0.667) - mean_ml_stdz) / stdv_ml_stdz)

    return prob_bn, prob_an

In [None]:
def plot_simple(fields, titles, cmap, unit, year, vmin=0.333, vmax=1):
    """Show fields side by side as simple lat/lon images with one colorbar each.

    Parameters
    ----------
    fields : sequence of 2-D arrays, one per panel (drawn with origin='lower').
    titles, cmap : per-panel titles and colormap names (same length as ``fields``).
    unit : colorbar label shared by all panels.
    year : year shown in the figure title.
    vmin, vmax : colour limits shared by all panels; the defaults match the
        tercile-probability range used in this notebook.

    Notes
    -----
    Reads the notebook globals ``lon``/``lat`` (image extent) and ``season`` (title).
    """
    n_fields = len(fields)
    # squeeze=False always returns a 2-D axes array, so a single field also works
    fig, axes = plt.subplots(1, n_fields, figsize=(15, 5), squeeze=False)

    for i, ax in enumerate(axes[0]):
        im = ax.imshow(fields[i], extent=[lon.min(), lon.max(), lat.min(), lat.max()],
                       origin='lower', cmap=cmap[i], vmin=vmin, vmax=vmax)
        ax.set_title(titles[i])
        ax.set_xlabel('Longitude')
        ax.set_ylabel('Latitude')
        cbar = fig.colorbar(im, ax=ax, orientation='vertical')
        cbar.set_label(unit)

    fig.suptitle(f'Predicted tercile probabilities for {season} precipitation amounts, {year}', fontsize=16)
    plt.tight_layout()
    plt.show()

In [None]:
def calculate_percentiles(chirps_data, ref_period_indices):
    """Return the 33rd and 67th percentile fields over the reference period.

    ``chirps_data`` is indexed along its first axis with ``ref_period_indices``;
    the percentiles are taken over that axis, ignoring NaNs, and returned as
    ``(percentile_33, percentile_67)``.
    """
    reference_slice = chirps_data[ref_period_indices, :, :]
    lower, upper = (np.nanpercentile(reference_slice, q, axis=0) for q in (33, 67))
    return lower, upper


In [None]:
def categorize_precipitation(chirps_data, percentiles_33, percentiles_67, year, year_start=None):
    """Classify one year's precipitation into tercile categories per grid point.

    Parameters
    ----------
    chirps_data : 3-D array indexed as (year, ...spatial dims); element 0 is
        assumed to be ``year_start`` and years consecutive (TODO confirm).
    percentiles_33, percentiles_67 : lower/upper tercile thresholds.
    year : calendar year to classify.
    year_start : calendar year of ``chirps_data[0]``; defaults to the notebook
        global ``year_train_start``.

    Returns
    -------
    Array with 0 where precipitation < 33rd percentile, 2 where > 67th
    percentile and 1 otherwise. Grid points with NaN precipitation fail both
    comparisons and are therefore labelled 1.
    """
    if year_start is None:
        year_start = year_train_start
    actual_precip = chirps_data[year - year_start, :, :]

    below_normal = actual_precip < percentiles_33
    above_normal = actual_precip > percentiles_67

    # 0 = below normal, 1 = normal, 2 = above normal
    return xr.where(below_normal, 0, xr.where(above_normal, 2, 1))

In [None]:
def verify_predictions(prob_bn, prob_an, actual_categories, threshold=0.4):
    """Compare a categorical forecast derived from tercile probabilities with observations.

    The predicted category is 0 (below normal) where ``prob_bn > threshold``,
    otherwise 2 (above normal) where ``prob_an > threshold``, otherwise 1 (normal).
    Below normal takes precedence when both probabilities exceed the threshold.
    This is a threshold rule, not an argmax over the three categories.

    Returns
    -------
    Boolean array, True where the predicted category equals ``actual_categories``.
    """
    pred_categories = xr.where(prob_bn > threshold, 0, xr.where(prob_an > threshold, 2, 1))

    return pred_categories == actual_categories

In [None]:
def plot_verification(verification, prec_data, lon, lat):
    """Map where the categorical forecast was correct (green) or incorrect (red).

    Grid points where ``prec_data`` is NaN in the forecast year are masked and
    drawn gray, matching the "Masked" colorbar entry.

    Notes
    -----
    Reads the notebook globals ``year_fcst``, ``year_train_start`` (row of
    ``prec_data`` used for the mask) and ``season`` (title).
    """
    # Reintroduce NaNs based on the original prec_data
    masked_verification = np.ma.masked_where(np.isnan(prec_data[year_fcst - year_train_start, :, :]), verification)

    # Discrete colormap: [0, 0.5) incorrect, [0.5, 1.5) correct, [1.5, 2] legend slot for masked.
    # Masked cells are drawn with the colormap's "bad" colour (transparent by default),
    # so set it to gray to agree with the legend.
    cmap = mcolors.ListedColormap(['red', 'green', 'gray'])
    cmap.set_bad('gray')
    bounds = [0, 0.5, 1.5, 2]
    boundary_norm = mcolors.BoundaryNorm(bounds, cmap.N)

    fig, ax = plt.subplots(figsize=(7, 5))

    im = ax.imshow(masked_verification, extent=[lon.min(), lon.max(), lat.min(), lat.max()],
                   origin='lower', cmap=cmap, norm=boundary_norm)

    ax.set_title(f'Verification of {season} {year_fcst} Precipitation Forecast')
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')

    # Colorbar ticks at the centre of each colour bin
    cbar = fig.colorbar(im, ax=ax, orientation='vertical', ticks=[0.25, 1, 1.75])
    cbar.ax.set_yticklabels(['Incorrect', 'Correct', 'Masked'])

    plt.tight_layout()
    plt.show()


In [None]:
simplefilter("ignore", category=RuntimeWarning)
# Uses from earlier cells: season, month_init, previous_month, period_clm, ntg,
# df_coefficients (model from the last fit), df_fl_pred_cov, feature_dfs, feature_names,
# anomalies_normal, prec_data, ref_period_indices, eofs_norm_anomalies_reshaped,
# var_frac_ano_normal, lon, lat.


def reshape_covariance_matrix(cov_df, ntg):
    """Turn each flattened ntg*ntg row of ``cov_df`` back into an (ntg, ntg) matrix."""
    reshaped_cov = np.zeros((len(cov_df), ntg, ntg))
    for i, year in enumerate(cov_df.index):
        reshaped_cov[i] = cov_df.loc[year].values.reshape(ntg, ntg)
    return reshaped_cov


# These inputs are the same for every forecast year, so compute them once.
df_fl_pred_cov_reshaped = reshape_covariance_matrix(df_fl_pred_cov, ntg)
scaling = np.nanstd(anomalies_normal[ref_period_indices, :, :], axis = 0, ddof = 1)
eigenvalues = var_frac_ano_normal  # The variance fraction for each EOF
coefficients_columns = df_coefficients.columns.to_list()
percentiles_33, percentiles_67 = calculate_percentiles(prec_data, ref_period_indices)

for year_fcst in range(year_clm_start, year_clm_end+1):
    # Predictive covariance of the factor loadings for this year
    cov_matrix_for_year = df_fl_pred_cov_reshaped[years_verif.index(year_fcst)]

    # Predictor values (previous month) for this year, aligned with the coefficient columns.
    # 'year' stays NaN here; calculate_tercile_probability_forecasts fills it in.
    ts_indices = pd.Series(index=df_coefficients.columns)
    for feature in feature_names:
        df_feature = feature_dfs[feature]
        df_feature_selected = df_feature[df_feature['month'] == previous_month].set_index('year')['standardized_anomaly']
        ts_indices[feature] = df_feature_selected.loc[year_fcst]
    ts_indices = ts_indices.reindex(coefficients_columns)

    prob_bn, prob_an = calculate_tercile_probability_forecasts(season, year_fcst, month_init, period_clm, scaling, eofs_norm_anomalies_reshaped, eigenvalues, df_coefficients, cov_matrix_for_year, ts_indices)

    # Plot the probabilities
    plot_simple(fields=[prob_bn, prob_an],
                titles=['Below Normal', 'Above Normal'],
                cmap=['Oranges', 'Greens'],
                unit='Probability',
                year=year_fcst)

    # Observed tercile category and hit/miss map for this year
    actual_categories = categorize_precipitation(prec_data, percentiles_33, percentiles_67, year_fcst)
    verification = verify_predictions(prob_bn, prob_an, actual_categories)

    # plot_verification reads the global `year_fcst` (this loop variable) for its mask and title
    plot_verification(verification, prec_data, lon, lat)


In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature

from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
from mpl_toolkits.axes_grid1 import make_axes_locatable
from memory_profiler import profile

def get_nearest_grid_index(lon_exmpl, lat_exmpl, lon_grid, lat_grid):
    """Return ``(ix, iy)``: positions of the grid longitude/latitude closest to the given point."""
    lon_distance = np.abs(lon_grid - lon_exmpl)
    lat_distance = np.abs(lat_grid - lat_exmpl)
    return np.argmin(lon_distance), np.argmin(lat_distance)

def get_xticks(x_extent, inc = 1):
    """Return the multiples of ``inc`` (counted from -180) lying within ``[x_extent[0], x_extent[1]]``."""
    candidates = np.arange(-180, 180, inc)
    inside = (candidates >= x_extent[0]) & (candidates <= x_extent[1])
    return candidates[inside]

def get_yticks(y_extent, inc = 1):
    """Return the multiples of ``inc`` (counted from -90) lying within ``[y_extent[0], y_extent[1]]``."""
    candidates = np.arange(-90, 90, inc)
    inside = (candidates >= y_extent[0]) & (candidates <= y_extent[1])
    return candidates[inside]

def plot_fields(fields_list, lon, lat, lon_bounds, lat_bounds, main_title, subtitle_list, unit, vmin=None, vmax=None, cmap='BuPu', water_bodies=False, ticks=True, tick_labels=None):
    """Plot one or more gridded fields side by side on PlateCarree maps.

    Each panel gets coastlines, country borders, optional lakes/rivers and its own colorbar.

    Parameters
    ----------
    fields_list : list of 2-D arrays, presumably indexed as (lat, lon) to match ``lat``/``lon`` (TODO confirm).
    lon, lat : 1-D grid-cell centre coordinates. Both use the spacing
        ``|lon[1] - lon[0]|``, and latitudes are assumed to be ascending.
    lon_bounds, lat_bounds : ``[min, max]`` lists defining the map extent.
    main_title : figure title.
    subtitle_list : one title per panel.
    unit : colorbar label, or a list with one label per panel.
    vmin, vmax : colour limits (scalar or per-panel list); default is each field's nanmin/nanmax.
    cmap : colormap name, or a per-panel list.
    water_bodies : if True, also draw lakes and rivers.
    ticks : True for automatic colorbar ticks, a list of tick values for all panels,
        or a list of such lists (one per panel).
    tick_labels : None, or labels matching ``ticks`` (flat list or one list per panel).

    Raises
    ------
    TypeError
        If ``ticks`` or ``tick_labels`` is not of one of the accepted forms.
    """
    n_img = len(fields_list)
    img_extent = lon_bounds + lat_bounds

    if not isinstance(unit, list):
        unit = [unit] * n_img

    if not isinstance(cmap, list):
        cmap = [cmap] * n_img

    # Default colour limits: per-field data range
    if vmin is None:
        vmin = [np.nanmin(field) for field in fields_list]
    if vmax is None:
        vmax = [np.nanmax(field) for field in fields_list]

    if not isinstance(vmin, list):
        vmin = [vmin] * n_img
    if not isinstance(vmax, list):
        vmax = [vmax] * n_img

    def _is_number(value):
        # Python and numpy scalars both count as tick values
        return isinstance(value, (int, float, np.integer, np.floating))

    # Normalise `ticks` to one entry per panel: True (automatic) or a list of tick values
    if ticks is True:
        ticks = [True] * n_img
    elif not isinstance(ticks, list):
        raise TypeError("Argument 'ticks' must be True, a list or a list of lists.")
    elif all(_is_number(tt) for tt in ticks):
        ticks = [ticks] * n_img

    # Normalise `tick_labels` the same way (None means: keep the default labels)
    if tick_labels is None:
        tick_labels = [None] * n_img
    elif not isinstance(tick_labels, list):
        raise TypeError("Argument 'tick_labels' must be None, a list or a list of lists.")
    elif all(_is_number(tt) or isinstance(tt, str) for tt in tick_labels):
        tick_labels = [tick_labels] * n_img

    # Cell-edge coordinates for pcolormesh, obtained by shifting the centres by half a cell
    r = abs(lon[1]-lon[0])
    lons_mat, lats_mat = np.meshgrid(lon, lat)
    lons_matplot = np.hstack((lons_mat - r/2, lons_mat[:,[-1]] + r/2))
    lons_matplot = np.vstack((lons_matplot, lons_matplot[[-1],:]))
    lats_matplot = np.hstack((lats_mat, lats_mat[:,[-1]]))
    lats_matplot = np.vstack((lats_matplot - r/2, lats_matplot[[-1],:] + r/2))     # assumes latitudes in ascending order

    # Axis tick spacing: about eight intervals across the map, but at least 1 degree
    # (a zero step would make np.arange in get_xticks/get_yticks fail for small domains)
    dlon = max((lon_bounds[1]-lon_bounds[0]) // 8, 1)
    dlat = max((lat_bounds[1]-lat_bounds[0]) // 8, 1)

    fig_height = 7.
    fig_width = (n_img*1.15)*(fig_height/1.1)*np.diff(lon_bounds)[0]/np.diff(lat_bounds)[0]

    fig = plt.figure(figsize=(fig_width,fig_height))
    for i_img in range(n_img):
        ax = fig.add_subplot(1, n_img, i_img + 1, projection=ccrs.PlateCarree())
        cmesh = ax.pcolormesh(lons_matplot, lats_matplot, fields_list[i_img], cmap=cmap[i_img], vmin=vmin[i_img], vmax=vmax[i_img])
        ax.set_extent(img_extent, crs=ccrs.PlateCarree())
        ax.set_yticks(get_yticks(img_extent[2:4],dlat), crs=ccrs.PlateCarree())
        ax.yaxis.set_major_formatter(LatitudeFormatter())
        ax.set_xticks(get_xticks(img_extent[0:2],dlon), crs=ccrs.PlateCarree())
        ax.xaxis.set_major_formatter(LongitudeFormatter(zero_direction_label=True))
        ax.add_feature(cfeature.COASTLINE, linewidth=2)
        ax.add_feature(cfeature.BORDERS, linewidth=2, linestyle='-', alpha=.9)
        if water_bodies:
            ax.add_feature(cfeature.LAKES, alpha=0.95)
            ax.add_feature(cfeature.RIVERS)

        ax.set_title(subtitle_list[i_img], fontsize=14)
        # Colorbar in its own axes to the right of the map, same height as the map
        divider = make_axes_locatable(ax)
        ax_cb = divider.new_horizontal(size="5%", pad=0.1, axes_class=plt.Axes)
        fig.add_axes(ax_cb)
        cbar = fig.colorbar(cmesh, cax=ax_cb)
        cbar.set_label(unit[i_img])
        if ticks[i_img] is not True:
            cbar.set_ticks(ticks[i_img])
            if tick_labels[i_img] is not None:
                cbar.set_ticklabels(tick_labels[i_img])

    fig.canvas.draw()
    fig.tight_layout(rect=[0,0,1,0.95])
    fig.suptitle(main_title, fontsize=16)
    plt.show()



In [None]:
# plot_fields(fields_list = [prob_bn, prob_an],
#           lon = lon,
#           lat = lat,
#           lon_bounds = lon_bnds,
#           lat_bounds = lat_bnds,
#           main_title = f'Predicted tercile probabilities for {season} precipitation amounts',
#           subtitle_list = ['below normal','above normal'],
#           vmin = 0.333,
#           vmax = 1,
#           cmap = ['Oranges','Greens'],
#           unit = '')

In [None]:


# Disabled prototype (kept for reference): fits a MultiTaskLassoCV model that maps pre-selected climate
# indices to the precipitation EOF targets, then writes the fitted coefficients and the residual covariance
# of the predicted EOFs to CSV.
# NOTE(review): before re-enabling, confirm that the names this cell reads are defined by earlier cells:
#   month_init, season, years_verif, years_cv, data_prec, data_indices, ntg, wgt, df_year, cv_folds, nyr,
#   predictor, df_selected_features, df_hyperparameters, df_fl_pred_cov.
# NOTE(review): data_dir + 'eofs/chirps/halfdeg_res/' equals eof_dir and data_dir + 'fls_pred/chirps/seasonal/halfdeg_res/'
#   equals fcst_dir from the configuration cell; consider reusing those instead of a second hard-coded absolute path.
#data_dir = '/home/michael/nr/samba/PostClimDataNoBackup/CONFER/EASP/'

# month = month_init                # month in which the forecast is issued; predictors are taken from previous month
# season_name = season
# ref_period = f'{year_clm_start}-{year_clm_end}'


# print(f'\n Generating {season} forecasts at halfdeg resolution')

# # -- Calculate explained variance by the respective EOFs-----------------------------------------------------------------------------------------------------------------------

# refper_start = int(ref_period[:4])
# refper_end = int(ref_period[-4:])

# nc = xr.open_dataset(f'{data_dir}eofs/chirps/halfdeg_res/refper_{ref_period}/prec_loyo_seasonal_{season}.nc')
# nc_subset = nc.sel(loy=slice(years_verif[0],years_verif[-1]))
# NOTE(review): dividing by (refper_end-refper_start) is presumably the n-1 normalisation for one sample per
#   reference year — confirm. `eigenvalues` is not used anywhere else in this cell.
# eigenvalues = (nc_subset.d.values**2) / (refper_end-refper_start)
# nc.close()

# for iyr in years_verif:
# #    print(iyr)
#     df_target = data_prec.loc[(slice(None),iyr,range(1,ntg+1)),:].unstack().droplevel(level=1).droplevel(level=0, axis=1).fillna(0.).reindex(years_cv)
#     data_indices_yr = []
#     for i in range(len(data_indices)):
# NOTE(review): month-1 is 0 when month == 1; presumably this should wrap to December of the previous year — confirm.
#         data_indices_yr.append(data_indices[i].loc[(slice(None),month-1,iyr),:].droplevel(level=(1,2)))
#     df_features = pd.concat(data_indices_yr, axis=1).fillna(0.).reindex(years_cv)
#     y = df_target.to_numpy()
#     X = df_features.to_numpy()
#    # Feature pre-selection
#    # feature_idx[0] stands for the df_year column (always kept); entry 1+ift is True when index ift has
#    # p < 0.1*wgt.loc[iyr] in its Pearson correlation with at least one of the ntg targets.
#     feature_idx = [True]+[False]*len(df_features.columns)
#     for ift in range(len(feature_idx)-1):
#         pval = [pearsonr(y[:,ipc], X[:,ift])[1] for ipc in range(ntg)]
#         feature_idx[1+ift] = np.any(np.array(pval)<0.1*wgt.loc[iyr])
#     df_features = df_features.iloc[:,feature_idx[1:]]
#     df_selected_features.loc[iyr,['year']+df_features.columns.to_list()] = 0
#     print(f'{iyr}: {sum(feature_idx)} features selected')
#    # Interactions with b/n/a dummy variables of features
#    # (sum(feature_idx) == 1 means no index passed pre-selection, so only df_year is used)
#     if sum(feature_idx) == 1:
#         df_features = df_year
#     else:
#         df_features = pd.concat([df_year, df_features], axis=1)

#     X = df_features.to_numpy()

#     # Lasso regression of pre-selected features
#     clf = MultiTaskLassoCV(cv=cv_folds, fit_intercept=False, max_iter=5000)
#     clf.fit(X, y)
#     df_hyperparameters.loc[iyr,'alpha'] = clf.alpha_
#     df_hyperparameters.loc[iyr,'l1_ratio'] = 1.
#     ind_active = np.all(clf.coef_!=0, axis=0)
#     n_a = sum(ind_active)
#     if n_a > 0:
#         df_selected_features.loc[iyr,df_features.columns.to_list()] = np.where(ind_active, 1, 0)
#         dgf = 1 + n_a
#     else:
#         dgf = 1
#    # In 'full' mode, we mainly care about the estimated coefficients
#     df_coefficients = pd.DataFrame(clf.coef_, index=[f'eof{i}' for i in range(1,ntg+1)], columns=df_features.columns)
#     errors = y - clf.predict(X)
#    # Residual covariance (ntg x ntg, flattened) is copied into every row of df_fl_pred_cov.
#     df_fl_pred_cov.loc[:,:] = np.broadcast_to((np.dot(errors.T,errors)/(nyr-dgf)).flatten(), (len(years_verif),ntg**2))
#    # NOTE: this break ends the loop after the first year in years_verif, so the model is fitted only once.
#     break


# NOTE(review): `os` is not imported in the first cell (only `from os import path`); add `import os`
#   (or use os.makedirs(..., exist_ok=True)) if these directory-creation lines are re-enabled.
# # if not path.exists(f'{data_dir}fls_pred/chirps/seasonal/halfdeg_res'):
# #     os.mkdir(f'{data_dir}fls_pred/chirps/seasonal/halfdeg_res')

# # if not path.exists(f'{data_dir}fls_pred/chirps/seasonal/halfdeg_res/refper_{ref_period}_cvper_{years_cv[0]}-{years_cv[-1]}'):
# #     os.mkdir(f'{data_dir}fls_pred/chirps/seasonal/halfdeg_res/refper_{ref_period}_cvper_{years_cv[0]}-{years_cv[-1]}')

# df_coefficients.to_csv(f'{data_dir}fls_pred/chirps/seasonal/halfdeg_res/refper_{ref_period}_cvper_{years_cv[0]}-{years_cv[-1]}/coefficients_{predictor}_lasso_full_im{month}_{season}.csv')
# df_fl_pred_cov.to_csv(f'{data_dir}fls_pred/chirps/seasonal/halfdeg_res/refper_{ref_period}_cvper_{years_cv[0]}-{years_cv[-1]}/fls_cov_{predictor}_lasso_full_im{month}_{season}.csv')



Fit the LASSO regression model that predicts the precipitation EOFs from the climate indices.

In [None]:
# TODO: fit the LASSO model (e.g. MultiTaskLassoCV, imported in the first cell) that predicts the
#   precipitation EOFs from the climate indices — not yet implemented.

Visualize the fitted LASSO coefficients.

In [None]:
# TODO: plot the fitted LASSO coefficients (one row per EOF, one column per selected index) — not yet implemented.

Load the indices for the forecast year and use the previously fitted model to make a forecast.

In [None]:
# Disabled: compute below-/above-normal tercile probability forecasts for year_fcst.
# NOTE(review): calculate_tercile_probability_forecasts is not among the confer_wp3.lasso_forecast imports in the
#   first cell; unless a notebook cell defines it, add it to that import. Also confirm that year_fcst, month_init,
#   period_train and period_clm are defined before enabling.
#prob_fcst_below, prob_fcst_above = calculate_tercile_probability_forecasts(season, year_fcst, month_init, period_train, period_clm, indices_dir, anomaly_dir, eof_dir, fcst_dir)

Display the forecast below- and above-normal tercile probabilities as maps.

In [None]:
# Disabled: map the forecast below-/above-normal tercile probabilities with plot_fields.
# NOTE(review): plot_fields indexes cmap, vmin, vmax and unit per panel (cmap[i_img], vmin[i_img], unit[i_img]);
#   unless it broadcasts scalars earlier in its body, scalar vmin/vmax and unit='' will fail — pass lists such as
#   vmin=[0.333, 0.333], vmax=[1, 1], unit=['', ''] (confirm against the full function definition).
# plot_fields (fields_list = [prob_fcst_below, prob_fcst_above],
#           lon = lon,
#           lat = lat,
#           lon_bounds = lon_bnds,
#           lat_bounds = lat_bnds,
#           main_title = f'Predicted tercile probabilities for {season} precipitation amounts',
#           subtitle_list = ['below normal','above normal'],
#           vmin = 0.333,
#           vmax = 1,
#           cmap = ['Oranges','Greens'],
#           unit = '')