In [None]:
import numpy as np                # library for mathematical operations with arrays
import pandas as pd               # library for data frames; includes useful functions for date arithmetic
import xarray as xr               # library for arrays, especially tailored to weather data
import matplotlib.pyplot as plt   # library for basic plotting
import matplotlib.colors as mcolors 
import datetime                   # library for date manipulation

from functools import reduce
from scipy.stats import norm, pearsonr

from sklearn.linear_model import MultiTaskLassoCV
from sklearn.exceptions import ConvergenceWarning
from warnings import simplefilter

import os
from os import path
import sys
sys.path.append("../src/confer_wp3/")
from plotting import plot_fields                                    # function for visualizing spatial data in a map
from confer_wp3.dataloading import load_raw_data, save_anomalies, save_eofs_pcs
from confer_wp3.lasso_forecast import calculate_anomalies, quantile_mapping, compute_eofs_pcs, get_all_indices

The following variables specify the directories where the ERA5 predictor data, the CHIRPS observations, and derived products (climate indices, anomalies, EOFs, and forecasts) are stored:

In [None]:
# Data directories, built from the shared EASP project root.
# NOTE(review): some directories live under '/nr/samba/...' and others under
# '/home/michael/nr/samba/...' (presumably a user-specific mount of the same share).
# Only era5_dir and chirps_dir are used by the active cells of this notebook -- confirm the rest.
easp_root = '/nr/samba/PostClimDataNoBackup/CONFER/EASP/'
home_easp_root = '/home/michael' + easp_root

era5_dir = f'{easp_root}raw_predictors/'
chirps_dir = f'{easp_root}precip/chirps/'
anomaly_dir = f'{chirps_dir}seasonal/halfdeg_res/'
indices_dir = f'{home_easp_root}fls/predictors/'
eof_dir = f'{home_easp_root}eofs/chirps/halfdeg_res/'
fcst_dir = f'{home_easp_root}fls_pred/chirps/seasonal/halfdeg_res/'

Now, we set a number of parameters defining our forecast domain, training period, forecast year, etc.:

In [None]:
# Climatological reference period (first and last year, inclusive)
year_clm_start, year_clm_end = 1993, 2020
# Training period of the statistical model (first and last year, inclusive)
year_train_start, year_train_end = 1981, 2020

# Forecast setup: target year, initialization month (forecasts use data of the
# preceding month) and target season
year_fcst = 2020
month_init = 8
season = 'OND'

# Domain of interest in degrees
lon_bnds = [20, 53]
lat_bnds = [-15, 23]

# [start, end] pairs passed to the processing functions
period_clm = [year_clm_start, year_clm_end]
period_train = [year_train_start, year_train_end]

Load the CHIRPS data. If not already available, calculate the seasonal precipitation anomalies, EOFs, and factor loadings and save them to disk; otherwise, load the saved results.

### CHIRPS Data

##### Loading data

In [None]:
# Directory used to cache intermediate results (anomalies, EOFs, indices)
val_dir = "/nr/samba/user/ahellevik/CONFER-WP3/validation_data/"

# Load CHIRPS data for every training year, the target season and the domain;
# returns the year axis, latitudes, longitudes and the data array
train_years = list(range(year_train_start, year_train_end + 1))
year, lat, lon, prec_data = load_raw_data(chirps_dir, "chirps", train_years, season, lat_bnds, lon_bnds)


##### Get anomalies and normalized anomalies

In [None]:
# Precipitation anomalies: compute and cache on the first run, otherwise reuse the cache.
# NOTE(review): the cache file name does not encode season or periods, so after changing
# `season` or `period_clm` the stale cache is silently reused -- delete it in that case.
anomalies_path = f'{val_dir}chirps_anomalies.nc'
if path.exists(anomalies_path):
    # Context manager closes the netCDF handle once the values are read into memory
    with xr.open_dataarray(anomalies_path, engine='netcdf4') as da:
        anomalies = da.values
else:
    anomalies = calculate_anomalies(prec_data, year, period_clm)
    save_anomalies(anomalies, year, lat, lon, val_dir, normalized=False)

In [None]:
# Normalized (quantile-mapped) anomalies: compute and cache on the first run, otherwise
# reuse the cache.
# NOTE(review): like the anomaly cache, the file name does not encode season or periods.
anomalies_normal_path = f'{val_dir}chirps_anomalies_normal.nc'
if path.exists(anomalies_normal_path):
    # Context manager closes the netCDF handle once the values are read into memory
    with xr.open_dataarray(anomalies_normal_path, engine='netcdf4') as da:
        anomalies_normal = da.values
else:
    anomalies_normal = quantile_mapping(anomalies, year, period_clm)
    save_anomalies(anomalies_normal, year, lat, lon, val_dir, normalized=True)

##### Get EOFs and factor loadings

In [None]:
# EOFs, principal components and variance fractions of the normalized anomalies:
# compute and cache on the first run, otherwise reuse the cache.
n_eofs = 7  # number of EOFs to compute
if not path.exists(f'{val_dir}chirps_eofs.nc'):
    eofs, pcs, var_fracs = compute_eofs_pcs(anomalies_normal, n_eofs)
    # Reshape EOFs to 3D (n_eofs, lat, lon)
    eofs_reshaped = eofs.reshape((n_eofs, len(lat), len(lon)))
    save_eofs_pcs(eofs_reshaped, pcs, var_fracs, year, lat, lon, val_dir)
else:
    # Context managers close each netCDF handle once its values are in memory
    with xr.open_dataarray(f'{val_dir}chirps_eofs.nc', engine='netcdf4') as da:
        eofs_reshaped = da.values
    with xr.open_dataarray(f'{val_dir}chirps_pcs.nc', engine='netcdf4') as da:
        pcs = da.values
    with xr.open_dataarray(f'{val_dir}chirps_var_fracs.nc', engine='netcdf4') as da:
        var_fracs = da.values
    # Take the EOF count from the cache: later cells use n_eofs, and it must match
    # the stored arrays rather than the value requested above
    n_eofs = eofs_reshaped.shape[0]

### ERA5 Data

If not already available, load the ERA5 data, calculate the climate indices, and save them to disk. Otherwise, load the previously saved indices.

##### Get ERA5 indices

In [None]:
# ERA5 climate indices: compute and cache on the first run, otherwise reuse the cache.
era5_indices_path = f'{val_dir}era5_indices.nc'
# All calendar months are requested from get_all_indices
months = list(range(1, 13))

if not path.exists(era5_indices_path):
    # Load years covering both the training and the climatological period
    load_years = list(range(min(year_train_start, year_clm_start), max(year_train_end, year_clm_end) + 1))
    sst_data = load_raw_data(era5_dir, "sst", load_years, season)
    uwind200_data = load_raw_data(era5_dir, "uwind200", load_years, season)
    uwind850_data = load_raw_data(era5_dir, "uwind850", load_years, season)

    era5_indices = get_all_indices(sst_data, uwind200_data, uwind850_data, period_clm, period_train, months)

    # Store as netCDF via an xarray Dataset indexed by (year, month)
    ds = era5_indices.set_index(['year', 'month']).to_xarray()
    print("Saving indices...")
    ds.to_netcdf(era5_indices_path)
    print(f"Data saved to {era5_indices_path}")
else:
    # Context manager closes the file once the DataFrame has been materialized
    with xr.open_dataset(era5_indices_path, engine='netcdf4') as ds_loaded:
        print(f"Data loaded from {era5_indices_path}")
        # Back to a flat DataFrame with 'year' and 'month' columns
        era5_indices = ds_loaded.to_dataframe().reset_index()

### ML model

In [None]:
# Precipitation targets (from the EOF cell):
#   eofs_reshaped: (n_eofs, lat, lon); pcs: (n_years, n_eofs); var_fracs: (n_eofs,)
# Predictors: one DataFrame per climate index with columns 'year', 'month' and
# 'standardized_anomaly'.
feature_names = ['n34','dmi','wvg','wsp','wpg','wp','wnp','n34_diff1','dmi_diff1','ueq850','ueq200','sji850','sji200']


def _get_feature_df(feature):
    """Return the monthly predictor DataFrame for `feature`.

    Uses a notebook variable `time_series_{feature}_df` if one exists (the names the
    original code looked up via eval); otherwise builds the frame from the
    `era5_indices` column of the same name.

    Raises
    ------
    KeyError
        If neither source provides the feature.
    """
    legacy_name = f"time_series_{feature}_df"
    if legacy_name in globals():
        return globals()[legacy_name]
    # NOTE(review): assumes get_all_indices() returns one column per feature holding
    # its standardized anomaly -- confirm against confer_wp3.lasso_forecast.
    if feature in era5_indices.columns:
        return (era5_indices[['year', 'month', feature]]
                .rename(columns={feature: 'standardized_anomaly'}))
    raise KeyError(f"No data for predictor '{feature}': neither '{legacy_name}' nor an "
                   f"era5_indices column named '{feature}' exists.")


feature_dfs = {feature: _get_feature_df(feature) for feature in feature_names}

In [None]:
# Silence scikit-learn convergence warnings emitted while MultiTaskLassoCV explores its alpha path
simplefilter(action="ignore", category=ConvergenceWarning)

In [None]:
# Cross-validation years (the training years) and verification years (the climatology years)
years_cv = [*range(year_train_start, year_train_end + 1)]
years_verif = [*range(year_clm_start, year_clm_end + 1)]
# Standardized year predictor: decades relative to 2000
df_year = pd.DataFrame({'year': (np.asarray(years_cv) - 2000) / 10}, index=years_cv)

In [None]:
# Number of EOFs used as regression targets
ntg = n_eofs

# Weight of each EOF: square root of its share of the variance explained by the retained EOFs
frac_expl_var = var_fracs / np.sum(var_fracs)
wgt_values = np.sqrt(frac_expl_var)
eof_labels = [f'eof{i}' for i in range(1, ntg + 1)]
# Same weights repeated for every verification year
wgt = pd.DataFrame([wgt_values] * len(years_verif), index=years_verif, columns=eof_labels)

In [None]:
# Interleaved cross-validation folds over the nyr training years: fold f tests on the
# consecutive year pairs starting at 2f, 2f + 2k, 2f + 4k, ...; all other years train.
k = 5
nyr = len(years_cv)
all_idx = set(range(nyr))

cv_folds = []
for fold in range(k):
    start = 2 * fold
    idx_test = set(range(start, nyr, 2 * k)) | set(range(start + 1, nyr, 2 * k))
    idx_train = all_idx - idx_test
    cv_folds.append((list(idx_train), list(idx_test)))

In [None]:
# Result containers, one row per verification year
fl_columns = [f'fl{i}' for i in range(1, ntg + 1)]
# Flattened ntg x ntg covariance, row-major ('cov-ij' = row i, column j)
cov_columns = [f'cov-{i}{j}' for i in range(1, ntg + 1) for j in range(1, ntg + 1)]

# NOTE(review): df_fl_pred_mean is never filled by the fitting loop below.
df_fl_pred_mean = pd.DataFrame(index=years_verif, columns=fl_columns)
df_fl_pred_cov = pd.DataFrame(index=years_verif, columns=cov_columns)
df_hyperparameters = pd.DataFrame(index=years_verif, columns=['alpha', 'l1_ratio'])
df_selected_features = pd.DataFrame(index=years_verif, columns=['year', *feature_names], dtype=int)

In [None]:
# Fit the LASSO model; for every verification year store the hyperparameters, the
# active predictors and the error covariance of the predicted factor loadings.
# Predictors are taken from the month preceding the initialization month.
previous_month = month_init - 1 if month_init > 1 else 12

# NOTE(review): apart from the row label `iyr`, nothing in this loop depends on the
# verification year: every pass trains on all of `years_cv` (and the EOFs/PCs were
# computed from all years), so each iteration fits the same model and the later
# verification is in-sample rather than leave-one-year-out. Confirm whether `iyr`
# should be withheld from training.
for iyr in years_verif:
    # Targets: the leading `ntg` principal components, one row per training year
    df_target = pd.DataFrame(pcs[:, :ntg], index=years_cv)

    # Predictors: standardized anomaly of every index in `previous_month`, aligned on year
    df_combined_features = pd.DataFrame(index=years_cv)
    for feature in feature_names:
        df_feature = feature_dfs[feature]
        df_feature_selected = df_feature[df_feature['month'] == previous_month].set_index('year')['standardized_anomaly']
        df_combined_features[feature] = df_feature_selected

    # Years without an index value are treated as a zero anomaly
    df_combined_features = df_combined_features.fillna(0)

    y = df_target.to_numpy()
    X = df_combined_features.to_numpy()

    # Feature pre-selection: keep a predictor if its Pearson correlation with at least one
    # PC has a p-value below 0.1 times that PC's weight. Entry 0 stands for 'year' (always kept).
    feature_idx = [True] + [False] * len(df_combined_features.columns)
    for ift in range(len(feature_idx) - 1):
        pval = [pearsonr(y[:, ipc], X[:, ift])[1] for ipc in range(ntg)]
        feature_idx[1 + ift] = bool(np.any(np.array(pval) < 0.1 * wgt.iloc[0, :]))

    df_combined_features = df_combined_features.iloc[:, feature_idx[1:]]
    # Standardized year (decades relative to 2000) is always the first predictor
    df_year = pd.DataFrame((df_combined_features.index - 2000) / 10, index=df_combined_features.index, columns=['year'])
    selected_columns = ['year'] + df_combined_features.columns.to_list()

    # Add any selected column that df_selected_features does not have yet
    for col in selected_columns:
        if col not in df_selected_features.columns:
            df_selected_features[col] = 0

    if sum(feature_idx) == 1:
        df_combined_features = df_year
    else:
        df_combined_features = pd.concat([df_year, df_combined_features], axis=1)

    X = df_combined_features.to_numpy()
    # LASSO regression with the interleaved CV folds defined above
    clf = MultiTaskLassoCV(cv=cv_folds, fit_intercept=False, max_iter=5000)
    clf.fit(X, y)
    df_hyperparameters.loc[iyr, 'alpha'] = clf.alpha_
    df_hyperparameters.loc[iyr, 'l1_ratio'] = 1.0

    # One flag per column of X, i.e. already in the order of `selected_columns` with
    # 'year' first -- no extra 'year' entry may be inserted, otherwise every predictor's
    # flag lands on the neighbouring column.
    ind_active = np.all(clf.coef_ != 0, axis=0)
    n_a = int(np.sum(ind_active))
    if n_a > 0:
        df_selected_features.loc[iyr, selected_columns] = np.where(ind_active, 1, 0)
        dgf = 1 + n_a
    else:
        dgf = 1
    df_coefficients = pd.DataFrame(clf.coef_, index=[f'eof{i}' for i in range(1, ntg+1)], columns=df_combined_features.columns)

    # Residual covariance of the fitted PCs divided by (nyr - dgf) degrees of freedom;
    # stored as one flat row-major row matching the 'cov-ij' column order
    errors = y - clf.predict(X)
    df_fl_pred_cov.loc[iyr, :] = (errors.T @ errors / (nyr - dgf)).flatten()

#print(df_fl_pred_cov)
# Save results
# df_coefficients.to_csv(f'path_to_coefficients.csv')
# df_fl_pred_cov.to_csv(f'path_to_fl_pred_cov.csv')

In [None]:
# Plot the LASSO regression coefficients (rows: EOFs, columns: predictors) of the CHIRPS
# model fitted on the full training period.
# NOTE(review): this cell was laid out for a CHIRPS vs. GPCC (short/long) comparison with
# three panels, but only the CHIRPS model exists here, so a single panel is drawn instead
# of leaving two empty axes.

# Use the configured initialization month and season instead of re-hard-coding them
month = month_init
month_str = {1:"January", 2:"February", 3:"March", 4:"April", 5:"May", 6:"June", 7:"July", 8:"August", 9:"September", 10:"October", 11:"November", 12:"December"}[month]

target_prod = 'chirps'
cv_period = f'{years_cv[0]}-{years_cv[-1]}'
# Symmetric colour range per season so that zero maps to white in 'bwr'
coef_lim = {'MAM': 80., 'JJAS': 160., 'OND': 230.}[season]

nft = len(df_coefficients.columns)       # number of predictors (columns)
n_eof_rows = len(df_coefficients.index)  # number of EOFs (rows)

fig, ax = plt.subplots(figsize=(7, 4))
img = ax.imshow(df_coefficients.to_numpy(dtype=float), vmin=-coef_lim, vmax=coef_lim, cmap='bwr', extent=[0, nft, 0, n_eof_rows])
ax.set_xticks(np.arange(.5, nft))
ax.set_xticklabels(df_coefficients.columns.to_list(), rotation=90, fontsize=10)
ax.set_yticks(np.arange(n_eof_rows - .5, 0, -1))
ax.set_yticklabels(df_coefficients.index.to_list(), fontsize=10)
ax.set_title(f'{target_prod.upper()}, training period: {cv_period}')

fig.subplots_adjust(right=0.85)
cbar_ax = fig.add_axes([0.9, 0.15, 0.03, 0.6])
fig.colorbar(img, cax=cbar_ax)
fig.suptitle(f'LASSO regression coefficients for {season} forecast issued in {month_str}', fontsize=14)

#plt.savefig(f'{data_dir}plots_ml_diagnostics/coef_training_period_comparison_im{month}_{season}.png')

In [None]:
def calculate_tercile_probability_forecasts(season, year_fcst, month_init, period_clm, scaling, eofs, eigenvalues, coefficients, fl_eof_cov, ts_indices):
    """Compute below- and above-normal tercile probabilities from the regression model.

    The predicted factor loadings (coefficients @ predictors) are projected onto the EOFs
    to give a Gaussian forecast in normalized-anomaly space; its mean and standard
    deviation are divided by `scaling` and converted to tercile probabilities.

    Parameters
    ----------
    season, month_init, period_clm :
        Not used by the computation; kept for interface compatibility.
    year_fcst : int
        Forecast year; enters through the standardized 'year' predictor (year_fcst - 2000) / 10.
    scaling : ndarray (lat, lon)
        Grid-point standard deviation used to standardize the forecast.
    eofs : ndarray (n_eofs, lat, lon)
    eigenvalues : ndarray (n_eofs,)
        Variance of each EOF, in the units of `scaling**2`; used to estimate the
        residual variance not captured by the EOFs.
    coefficients : DataFrame (n_eofs, n_predictors)
        Regression coefficients; columns must equal `ts_indices.index` (same order).
    fl_eof_cov : ndarray (n_eofs, n_eofs)
        Error covariance of the predicted factor loadings.
    ts_indices : Series
        Predictor values indexed like `coefficients.columns`. The 'year' entry is replaced
        by the standardized forecast year in a copy; the caller's Series is not modified.

    Returns
    -------
    prob_bn, prob_an : ndarray (lat, lon)
        Probabilities of a below-normal and an above-normal outcome.

    Raises
    ------
    ValueError
        If `coefficients.columns` and `ts_indices.index` differ.
    """
    # Compare as lists so a length mismatch also produces this message
    if list(coefficients.columns) != list(ts_indices.index):
        raise ValueError("Columns of coefficients and indices of ts_indices are not aligned.")

    # Residual variance not explained by the retained EOFs
    var_eps = scaling**2 - np.sum(eigenvalues[:, None, None] * eofs**2, axis=0)

    # Float copy: avoids mutating the caller's Series and object-dtype arithmetic
    predictors = ts_indices.astype(float)
    predictors['year'] = (year_fcst - 2000) / 10

    # Predictive mean of the factor loadings, one value per EOF
    fl_eof_mean = coefficients.astype(float).dot(predictors).to_numpy()

    # Forecast in normalized space: mean = sum_k m_k e_k, var = sum_kl C_kl e_k e_l + var_eps
    mean_ml = np.sum(fl_eof_mean[:, None, None] * eofs, axis=0)
    var_ml = np.sum(np.sum(fl_eof_cov[:, :, None, None] * eofs[None, :, :, :], axis=1) * eofs, axis=0) + var_eps

    mean_ml = np.asarray(mean_ml, dtype=np.float64)
    var_ml = np.asarray(var_ml, dtype=np.float64)

    mean_ml_stdz = mean_ml / scaling
    stdv_ml_stdz = np.sqrt(var_ml) / scaling

    # Tercile probabilities relative to standard-normal tercile boundaries
    prob_bn = norm.cdf((norm.ppf(0.333) - mean_ml_stdz) / stdv_ml_stdz)
    prob_an = 1.0 - norm.cdf((norm.ppf(0.667) - mean_ml_stdz) / stdv_ml_stdz)

    return prob_bn, prob_an

In [None]:
def plot_simple(fields, titles, cmap, unit, year, vmin=0.333, vmax=1):
    """Plot 2D fields side by side on a plain lon/lat grid.

    Parameters
    ----------
    fields : sequence of 2D arrays (lat, lon)
        Fields to plot; row 0 is drawn at the bottom (origin='lower').
    titles : sequence of str
        Panel titles, one per field.
    cmap : str or sequence of str
        Colormap name(s); a single name is used for every panel.
    unit : str
        Colorbar label.
    year : int
        Year shown in the figure title.
    vmin, vmax : float
        Colour limits shared by all panels (defaults suit tercile probabilities).

    Reads the notebook globals `lon`, `lat` (plot extent) and `season` (title).
    """
    n_fields = len(fields)
    # A bare string would otherwise be indexed character by character
    if isinstance(cmap, str):
        cmap = [cmap] * n_fields
    # squeeze=False keeps `axes` 2D even for a single field, so iteration always works
    fig, axes = plt.subplots(1, n_fields, figsize=(15, 5), squeeze=False)
    extent = [lon.min(), lon.max(), lat.min(), lat.max()]

    for ax, field, title, cm in zip(axes[0], fields, titles, cmap):
        im = ax.imshow(field, extent=extent, origin='lower', cmap=cm, vmin=vmin, vmax=vmax)
        ax.set_title(title)
        ax.set_xlabel('Longitude')
        ax.set_ylabel('Latitude')
        cbar = fig.colorbar(im, ax=ax, orientation='vertical')
        cbar.set_label(unit)

    fig.suptitle(f'Predicted tercile probabilities for {season} precipitation amounts, {year}', fontsize=16)
    plt.tight_layout()
    plt.show()

In [None]:
def calculate_percentiles(chirps_data, ref_period_indices):
    """Return the grid-point 33rd and 67th percentiles over the reference years.

    `ref_period_indices` selects entries along the first axis of the 3D
    `chirps_data`; NaNs are ignored. Both results have the shape of one 2D slice.
    """
    ref_slice = chirps_data[ref_period_indices, :, :]
    lower, upper = (np.nanpercentile(ref_slice, q, axis=0) for q in (33, 67))
    return lower, upper


In [None]:
def categorize_precipitation(chirps_data, percentiles_33, percentiles_67, year):
    """Classify the observed precipitation of `year` into tercile categories.

    Returns 0 where the value is below `percentiles_33`, 2 where it is above
    `percentiles_67`, and 1 otherwise (this includes NaN grid points, since every
    comparison with NaN is False). The first axis of `chirps_data` is assumed to
    start at the global `year_train_start`.
    """
    observed = chirps_data[year - year_train_start, :, :]
    above_normal = observed > percentiles_67
    below_normal = observed < percentiles_33
    return xr.where(below_normal, 0, xr.where(above_normal, 2, 1))

In [None]:
def verify_predictions(prob_bn, prob_an, actual_categories, threshold=0.4):
    """Compare the predicted tercile category with the observed one.

    A tail category is predicted when its probability exceeds `threshold`; when both
    the below- and above-normal probabilities exceed it, the more probable one is
    chosen (ties go to below normal). Otherwise normal is predicted, including where
    the probabilities are NaN.
    Categories: 0 = below normal, 1 = normal, 2 = above normal.

    Returns
    -------
    Boolean array, True where the predicted and observed categories agree.
    """
    predict_bn = (prob_bn > threshold) & (prob_bn >= prob_an)
    predict_an = (prob_an > threshold) & (prob_an > prob_bn)
    pred_categories = xr.where(predict_bn, 0, xr.where(predict_an, 2, 1))
    return pred_categories == actual_categories

In [None]:
def plot_verification(verification, prec_data, lon, lat):
    """Map where the tercile forecast for the global `year_fcst` was correct.

    Red = incorrect, green = correct, gray = no observed data (NaN in `prec_data` for
    `year_fcst`). The first axis of `prec_data` is assumed to start at the global
    `year_train_start`; the title also uses the global `season`.
    """
    # Masked arrays are drawn with the colormap's "bad" colour (transparent by default),
    # not with a colormap entry, so NaN points are encoded as class 2 to appear in gray.
    no_data = np.isnan(np.asarray(prec_data[year_fcst - year_train_start, :, :], dtype=float))
    classes = np.where(no_data, 2, np.asarray(verification, dtype=int))

    cmap = mcolors.ListedColormap(['red', 'green', 'gray'])
    # One bin per integer class: 0 incorrect, 1 correct, 2 masked
    class_norm = mcolors.BoundaryNorm([-0.5, 0.5, 1.5, 2.5], cmap.N)

    fig, ax = plt.subplots(figsize=(7, 5))
    im = ax.imshow(classes, extent=[lon.min(), lon.max(), lat.min(), lat.max()],
                   origin='lower', cmap=cmap, norm=class_norm)

    ax.set_title(f'Verification of {season} {year_fcst} Precipitation Forecast')
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')

    cbar = fig.colorbar(im, ax=ax, orientation='vertical', ticks=[0, 1, 2])
    cbar.ax.set_yticklabels(['Incorrect', 'Correct', 'Masked'])

    plt.tight_layout()
    plt.show()


In [None]:
import warnings

# Verification over the climatological period: for every year, build the tercile forecast
# from the fitted model, plot it, and compare it with the observed tercile category.

# Reference period as indices into the year axis of the CHIRPS data
ref_period_mask = (year >= period_clm[0]) & (year <= period_clm[1])
ref_period_indices = np.where(ref_period_mask)[0]

with warnings.catch_warnings():
    # NaN grid points trigger RuntimeWarnings in the NaN reductions and divisions below;
    # silence them only inside this cell instead of for the rest of the session
    warnings.simplefilter("ignore", category=RuntimeWarning)

    # --- Quantities that do not depend on the forecast year ---
    # Factor-loading error covariance, (n_verif_years, ntg, ntg), rows ordered like years_verif
    df_fl_pred_cov_reshaped = df_fl_pred_cov.to_numpy(dtype=float).reshape(len(df_fl_pred_cov), ntg, ntg)
    # Std. dev. of the normalized anomalies over the reference period, per grid point
    scaling = np.nanstd(anomalies_normal[ref_period_indices, :, :], axis=0, ddof=1)
    # NOTE(review): var_fracs are passed as the EOF "eigenvalues". If compute_eofs_pcs
    # returns explained-variance fractions rather than variances in units of the
    # normalized anomalies, var_eps in the forecast is misestimated -- confirm.
    eigenvalues = var_fracs
    # NOTE(review): df_coefficients is the model from the last pass of the fitting loop.
    coefficients_columns = df_coefficients.columns.to_list()
    # Observed tercile boundaries over the reference period, per grid point
    percentiles_33, percentiles_67 = calculate_percentiles(prec_data, ref_period_indices)

    # `year_fcst` stays a global loop variable on purpose: plot_verification reads it
    for year_fcst in range(year_clm_start, year_clm_end + 1):
        cov_matrix_for_year = df_fl_pred_cov_reshaped[years_verif.index(year_fcst)]

        # Predictor values from the month preceding initialization, ordered like the
        # coefficient columns; 'year' stays NaN here and is set by the forecast function
        ts_values = {}
        for feature in feature_names:
            df_feature = feature_dfs[feature]
            df_feature_selected = df_feature[df_feature['month'] == previous_month].set_index('year')['standardized_anomaly']
            ts_values[feature] = df_feature_selected.loc[year_fcst]
        ts_indices = pd.Series(ts_values, dtype=float).reindex(coefficients_columns)

        prob_bn, prob_an = calculate_tercile_probability_forecasts(season, year_fcst, month_init, period_clm, scaling, eofs_reshaped, eigenvalues, df_coefficients, cov_matrix_for_year, ts_indices)

        plot_simple(fields=[prob_bn, prob_an],
                    titles=['Below Normal', 'Above Normal'],
                    cmap=['Oranges', 'Greens'],
                    unit='Probability',
                    year=year_fcst)

        actual_categories = categorize_precipitation(prec_data, percentiles_33, percentiles_67, year_fcst)
        verification = verify_predictions(prob_bn, prob_an, actual_categories)
        plot_verification(verification, prec_data, lon, lat)


In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature

from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
from mpl_toolkits.axes_grid1 import make_axes_locatable
from memory_profiler import profile

def get_nearest_grid_index(lon_exmpl, lat_exmpl, lon_grid, lat_grid):
    """Return (ix, iy): indices of the grid longitude and latitude closest to the given point."""
    lon_dist = np.abs(lon_grid - lon_exmpl)
    lat_dist = np.abs(lat_grid - lat_exmpl)
    return np.argmin(lon_dist), np.argmin(lat_dist)

def get_xticks(x_extent, inc=1):
    """Return the longitudes -180, -180+inc, ... (below 180) lying within [x_extent[0], x_extent[1]]."""
    candidates = np.arange(-180, 180, inc)
    lo, hi = x_extent[0], x_extent[1]
    return candidates[(candidates >= lo) & (candidates <= hi)]

def get_yticks(y_extent, inc=1):
    """Return the latitudes -90, -90+inc, ... (below 90) lying within [y_extent[0], y_extent[1]]."""
    candidates = np.arange(-90, 90, inc)
    lo, hi = y_extent[0], y_extent[1]
    return candidates[(candidates >= lo) & (candidates <= hi)]

# NOTE(review): this definition shadows `plot_fields` imported from `plotting` in the first
# cell. The memory_profiler `@profile` decorator (debugging instrumentation that prints a
# line-by-line memory report on every call) has been removed.
def plot_fields(fields_list, lon, lat, lon_bounds, lat_bounds, main_title, subtitle_list, unit, vmin=None, vmax=None, cmap='BuPu', water_bodies=False, ticks=True, tick_labels=None):
    """Plot one or more 2D fields side by side on PlateCarree map panels.

    Parameters
    ----------
    fields_list : list of 2D arrays (lat, lon)
    lon, lat : 1D arrays of grid-cell centres. A regular spacing |lon[1] - lon[0]| is
        used for both directions, and `lat` is assumed ascending.
    lon_bounds, lat_bounds : [min, max] map extent in degrees.
    main_title : str
    subtitle_list : list of str, one per field.
    unit : str or list of str
        Colorbar label(s).
    vmin, vmax : scalar, list, or None
        Colour limits; None uses each field's nanmin / nanmax.
    cmap : str or list of str
    water_bodies : bool
        Also draw lakes and rivers.
    ticks : True, list of numbers, or list of lists
        Colorbar ticks; True keeps matplotlib's default ticks.
    tick_labels : None, list, or list of lists
        Labels for `ticks`.

    Raises
    ------
    TypeError
        If `ticks` or `tick_labels` is not one of the accepted forms.
    """
    import numbers

    n_img = len(fields_list)
    # list() so tuple/array bounds concatenate instead of adding element-wise
    img_extent = list(lon_bounds) + list(lat_bounds)

    # Broadcast scalar options to one entry per panel
    if not isinstance(unit, list):
        unit = [unit] * n_img
    if not isinstance(cmap, list):
        cmap = [cmap] * n_img

    if vmin is None:
        vmin = [np.nanmin(field) for field in fields_list]
    if vmax is None:
        vmax = [np.nanmax(field) for field in fields_list]
    if not isinstance(vmin, list):
        vmin = [vmin] * n_img
    if not isinstance(vmax, list):
        vmax = [vmax] * n_img

    if ticks is True:
        ticks = [True] * n_img
    elif not isinstance(ticks, list):
        raise TypeError("Argument 'ticks' must be True, a list or a list of lists.")
    elif all(isinstance(tt, numbers.Number) for tt in ticks):
        # A flat list of tick positions applies to every panel
        ticks = [ticks] * n_img

    if tick_labels is None:
        tick_labels = [None] * n_img
    elif not isinstance(tick_labels, list):
        raise TypeError("Argument 'tick_labels' must be a list or a list of lists.")
    elif all(isinstance(tt, (numbers.Number, str)) for tt in tick_labels):
        tick_labels = [tick_labels] * n_img

    # Cell-edge coordinates for pcolormesh: centres shifted by half a grid spacing
    r = abs(lon[1]-lon[0])
    lons_mat, lats_mat = np.meshgrid(lon, lat)
    lons_matplot = np.hstack((lons_mat - r/2, lons_mat[:,[-1]] + r/2))
    lons_matplot = np.vstack((lons_matplot, lons_matplot[[-1],:]))
    lats_matplot = np.hstack((lats_mat, lats_mat[:,[-1]]))
    lats_matplot = np.vstack((lats_matplot - r/2, lats_matplot[[-1],:] + r/2))     # assumes latitudes in ascending order

    # About 8 axis ticks per direction; at least 1 degree apart so np.arange never gets a zero step
    dlon = max((lon_bounds[1]-lon_bounds[0]) // 8, 1)
    dlat = max((lat_bounds[1]-lat_bounds[0]) // 8, 1)

    fig_height = 7.
    fig_width = (n_img*1.15)*(fig_height/1.1)*np.diff(lon_bounds)[0]/np.diff(lat_bounds)[0]

    fig = plt.figure(figsize=(fig_width,fig_height))
    for i_img in range(n_img):
        # (nrows, ncols, index) form also works for 10+ panels, unlike the 3-digit shorthand
        ax = fig.add_subplot(1, n_img, i_img + 1, projection=ccrs.PlateCarree())
        cmesh = ax.pcolormesh(lons_matplot, lats_matplot, fields_list[i_img], cmap=cmap[i_img], vmin=vmin[i_img], vmax=vmax[i_img])
        ax.set_extent(img_extent, crs=ccrs.PlateCarree())
        ax.set_yticks(get_yticks(img_extent[2:4],dlat), crs=ccrs.PlateCarree())
        ax.yaxis.set_major_formatter(LatitudeFormatter())
        ax.set_xticks(get_xticks(img_extent[0:2],dlon), crs=ccrs.PlateCarree())
        ax.xaxis.set_major_formatter(LongitudeFormatter(zero_direction_label=True))
        ax.add_feature(cfeature.COASTLINE, linewidth=2)
        ax.add_feature(cfeature.BORDERS, linewidth=2, linestyle='-', alpha=.9)
        if water_bodies:
            ax.add_feature(cfeature.LAKES, alpha=0.95)
            ax.add_feature(cfeature.RIVERS)

        ax.set_title(subtitle_list[i_img], fontsize=14)
        # Colorbar in its own axes attached to the right of the map
        divider = make_axes_locatable(ax)
        ax_cb = divider.new_horizontal(size="5%", pad=0.1, axes_class=plt.Axes)
        fig.add_axes(ax_cb)
        cbar = fig.colorbar(cmesh, cax=ax_cb)
        cbar.set_label(unit[i_img])
        if ticks[i_img] is not True:
            cbar.set_ticks(ticks[i_img])
            if tick_labels[i_img] is not None:
                cbar.set_ticklabels(tick_labels[i_img])

    fig.canvas.draw()
    plt.tight_layout(rect=[0,0,1,0.95])
    fig.suptitle(main_title, fontsize=16)
    plt.show()



In [None]:
# plot_fields(fields_list = [prob_bn, prob_an],
#           lon = lon,
#           lat = lat,
#           lon_bounds = lon_bnds,
#           lat_bounds = lat_bnds,
#           main_title = f'Predicted tercile probabilities for {season} precipitation amounts',
#           subtitle_list = ['below normal','above normal'],
#           vmin = 0.333,
#           vmax = 1,
#           cmap = ['Oranges','Greens'],
#           unit = '')

In [None]:


#data_dir = '/home/michael/nr/samba/PostClimDataNoBackup/CONFER/EASP/'

# month = month_init                # month in which the forecast is issued; predictors are taken from previous month
# season_name = season
# ref_period = f'{year_clm_start}-{year_clm_end}'


# print(f'\n Generating {season} forecasts at halfdeg resolution')

# # -- Calculate explained variance by the respective EOFs-----------------------------------------------------------------------------------------------------------------------

# refper_start = int(ref_period[:4])
# refper_end = int(ref_period[-4:])

# nc = xr.open_dataset(f'{data_dir}eofs/chirps/halfdeg_res/refper_{ref_period}/prec_loyo_seasonal_{season}.nc')
# nc_subset = nc.sel(loy=slice(years_verif[0],years_verif[-1]))
# eigenvalues = (nc_subset.d.values**2) / (refper_end-refper_start)
# nc.close()

# for iyr in years_verif:
# #    print(iyr)
#     df_target = data_prec.loc[(slice(None),iyr,range(1,ntg+1)),:].unstack().droplevel(level=1).droplevel(level=0, axis=1).fillna(0.).reindex(years_cv)
#     data_indices_yr = []
#     for i in range(len(data_indices)):
#         data_indices_yr.append(data_indices[i].loc[(slice(None),month-1,iyr),:].droplevel(level=(1,2)))
#     df_features = pd.concat(data_indices_yr, axis=1).fillna(0.).reindex(years_cv)
#     y = df_target.to_numpy()
#     X = df_features.to_numpy()
#    # Feature pre-selection
#     feature_idx = [True]+[False]*len(df_features.columns)
#     for ift in range(len(feature_idx)-1):
#         pval = [pearsonr(y[:,ipc], X[:,ift])[1] for ipc in range(ntg)]
#         feature_idx[1+ift] = np.any(np.array(pval)<0.1*wgt.loc[iyr])
#     df_features = df_features.iloc[:,feature_idx[1:]]
#     df_selected_features.loc[iyr,['year']+df_features.columns.to_list()] = 0
#     print(f'{iyr}: {sum(feature_idx)} features selected')
#    # Interactions with b/n/a dummy variables of features
#     if sum(feature_idx) == 1:
#         df_features = df_year
#     else:
#         df_features = pd.concat([df_year, df_features], axis=1)

#     X = df_features.to_numpy()

#     # Lasso regression of pre-selected features
#     clf = MultiTaskLassoCV(cv=cv_folds, fit_intercept=False, max_iter=5000)
#     clf.fit(X, y)
#     df_hyperparameters.loc[iyr,'alpha'] = clf.alpha_
#     df_hyperparameters.loc[iyr,'l1_ratio'] = 1.
#     ind_active = np.all(clf.coef_!=0, axis=0)
#     n_a = sum(ind_active)
#     if n_a > 0:
#         df_selected_features.loc[iyr,df_features.columns.to_list()] = np.where(ind_active, 1, 0)
#         dgf = 1 + n_a
#     else:
#         dgf = 1
#    # In 'full' mode, we mainly care about the estimated coefficients
#     df_coefficients = pd.DataFrame(clf.coef_, index=[f'eof{i}' for i in range(1,ntg+1)], columns=df_features.columns)
#     errors = y - clf.predict(X)
#     df_fl_pred_cov.loc[:,:] = np.broadcast_to((np.dot(errors.T,errors)/(nyr-dgf)).flatten(), (len(years_verif),ntg**2))
#     break


# # if not path.exists(f'{data_dir}fls_pred/chirps/seasonal/halfdeg_res'):
# #     os.mkdir(f'{data_dir}fls_pred/chirps/seasonal/halfdeg_res')

# # if not path.exists(f'{data_dir}fls_pred/chirps/seasonal/halfdeg_res/refper_{ref_period}_cvper_{years_cv[0]}-{years_cv[-1]}'):
#     os.mkdir(f'{data_dir}fls_pred/chirps/seasonal/halfdeg_res/refper_{ref_period}_cvper_{years_cv[0]}-{years_cv[-1]}')

# df_coefficients.to_csv(f'{data_dir}fls_pred/chirps/seasonal/halfdeg_res/refper_{ref_period}_cvper_{years_cv[0]}-{years_cv[-1]}/coefficients_{predictor}_lasso_full_im{month}_{season}.csv')
# df_fl_pred_cov.to_csv(f'{data_dir}fls_pred/chirps/seasonal/halfdeg_res/refper_{ref_period}_cvper_{years_cv[0]}-{years_cv[-1]}/fls_cov_{predictor}_lasso_full_im{month}_{season}.csv')



Run the LASSO model to predict the factor loadings of the precipitation EOFs from the climate indices.

In [None]:
# do it

Visualize fitted coefficients

In [None]:
# do it

Load indices for the forecast year and use the previously fitted model to make a forecast

In [None]:
#prob_fcst_below, prob_fcst_above = calculate_tercile_probability_forecasts(season, year_fcst, month_init, period_train, period_clm, indices_dir, anomaly_dir, eof_dir, fcst_dir)

Depict as a map.

In [None]:
# plot_fields (fields_list = [prob_fcst_below, prob_fcst_above],
#           lon = lon,
#           lat = lat,
#           lon_bounds = lon_bnds,
#           lat_bounds = lat_bnds,
#           main_title = f'Predicted tercile probabilities for {season} precipitation amounts',
#           subtitle_list = ['below normal','above normal'],
#           vmin = 0.333,
#           vmax = 1,
#           cmap = ['Oranges','Greens'],
#           unit = '')