# Imports

In [None]:
import xarray as xr
import pandas as pd
from keras.models import load_model
import numpy as np
import matplotlib.pyplot as plt

# Make Intervals (Functions)

In this section, we provide methods to create the appropriate confidence intervals. Each method returns an array of lower bounds and an array of upper bounds, where the entries at index $i$ of the two arrays are the lower and upper bound of the interval for prediction $i$.

In [None]:
def resid_bootstrap(y_pred, residuals, n_bs=1000):
    """
    Build 95% percentile intervals for predictions by resampling residuals.

    Each of the `n_bs` replicates draws len(residuals) residuals with replacement
    (np.random.choice) and adds them element-wise to `y_pred`. The 2.5th and 97.5th
    percentiles across replicates (per prediction) give the interval bounds.

    Inputs:
        y_pred (array-like): Model predictions; must broadcast against a vector of
            length len(residuals) (normally the same length).
        residuals (array-like): Observed minus predicted values to resample from.
        n_bs (int): Number of bootstrap replicates. Default is 1000.

    Outputs:
        (numpy.ndarray): Per-prediction lower bounds (2.5th percentile of replicates).
        (numpy.ndarray): Per-prediction upper bounds (97.5th percentile of replicates).

    Note: draws come from NumPy's global random state, so results differ between
    calls unless np.random.seed is set first.
    """
    n_resid = len(residuals)
    # One row per replicate: predictions shifted by a resampled residual vector.
    replicates = np.array([
        y_pred + np.random.choice(residuals, size=n_resid)
        for _ in range(n_bs)
    ])
    lower, upper = (np.percentile(replicates, q, axis=0) for q in (2.5, 97.5))
    return lower, upper

In [None]:
def wild_bootstrap(y_pred, residuals, n_bs=1000):
    """
    Build 95% percentile intervals for predictions with a wild bootstrap.

    Unlike the residual bootstrap, residuals are not resampled: in every replicate
    each residual stays at its own position and is multiplied by an independent
    standard-normal draw, and the result is added to `y_pred`. The 2.5th and
    97.5th percentiles across replicates (per prediction) give the bounds.

    Inputs:
        y_pred (array-like): Model predictions; must broadcast against `residuals`.
        residuals (array-like): Observed minus predicted values, aligned with
            `y_pred`.
        n_bs (int): Number of bootstrap replicates. Default is 1000.

    Outputs:
        (numpy.ndarray): Per-prediction lower bounds (2.5th percentile of replicates).
        (numpy.ndarray): Per-prediction upper bounds (97.5th percentile of replicates).

    Note: draws come from NumPy's global random state, so results differ between
    calls unless np.random.seed is set first.
    """
    n_resid = len(residuals)
    # One row per replicate: each residual rescaled by its own N(0, 1) multiplier.
    replicates = np.array([
        y_pred + np.random.normal(0, 1, size=n_resid) * residuals
        for _ in range(n_bs)
    ])
    lower, upper = (np.percentile(replicates, q, axis=0) for q in (2.5, 97.5))
    return lower, upper

# Plot Intervals (Functions)

Provides code that plots the bootstrap confidence intervals given predictions, residuals, and ground truth, and prints the coverage rate of each interval.

In [None]:
def _sorted_interval_data(predictions, observations, lower_ci, upper_ci, sorted_indices, normalize):
    """
    Reorder predictions, observations and interval bounds by `sorted_indices`.

    When `normalize` is True, the predictions and both bounds are divided
    element-wise by the sorted observations, and the observations are replaced
    by ones.

    Outputs:
        tuple of numpy.ndarray: (predictions, observations, lower bounds,
            upper bounds), all sorted and optionally normalized.
    """
    sorted_observations = observations[sorted_indices]
    sorted_pred = predictions[sorted_indices]
    sorted_lower = lower_ci[sorted_indices]
    sorted_upper = upper_ci[sorted_indices]

    if normalize:
        sorted_pred = sorted_pred / sorted_observations
        sorted_lower = sorted_lower / sorted_observations
        sorted_upper = sorted_upper / sorted_observations
        sorted_observations = np.ones_like(sorted_observations)

    return sorted_pred, sorted_observations, sorted_lower, sorted_upper


def _plot_interval(x_values, pred, obs, lower, upper, *, data_label, band_label, y_limit, save_path):
    """
    Draw one figure with predictions, observations and a filled interval band.

    The figure is saved to `save_path` (at 300 dpi) when `save_path` is truthy,
    and is then shown.
    """
    # Plot settings
    pred_color = "green"
    obs_color = "magenta"
    conf_color = "skyblue"
    linewidth = 4
    title_size = 18
    axis_size = 17
    legend_size = 16
    tick_size = 15

    plt.figure(figsize=(10, 7))
    plt.plot(x_values, pred, label='Predicted', color=pred_color)
    plt.plot(x_values, obs, label='Observed', color=obs_color, linewidth=linewidth)

    # Compare against None so an explicit upper limit of 0 is still applied.
    if y_limit is not None:
        plt.ylim(-5, y_limit)

    plt.fill_between(x_values, lower, upper, color=conf_color, alpha=1, label=band_label)
    plt.xlabel(r'Observations (sorted by increasing $\sigma_W$)', fontsize=axis_size)
    plt.ylabel('Observation Values (m/s)', fontsize=axis_size)
    plt.title(data_label, fontsize=title_size)
    plt.legend(fontsize=legend_size)
    plt.tick_params(axis='both', which='major', labelsize=tick_size)

    if save_path:
        plt.savefig(save_path, dpi=300)

    plt.show()


def _coverage_rate(observations, lower, upper):
    """Return the fraction of observations lying inside [lower, upper] (inclusive)."""
    within_ci = (observations >= lower) & (observations <= upper)
    return np.sum(within_ci) / len(observations)


def create_confidence_interval_plots(predictions, residuals, observations, data_label="", file_name=None, normalize=False, y_limit=None):
    """
    Generate and plot confidence intervals using both residual and wild bootstrapping methods, 
    and calculate coverage rates for the predictions.

    Two figures are produced: first the residual-bootstrap interval (legend "RCI",
    saved as "rci_<file_name>.png"), then the wild-bootstrap interval (legend
    "Confidence Interval", saved as "wci_<file_name>.png"). After each figure the
    coverage rate is printed.

    Inputs:
        predictions (numpy.ndarray): The predicted values from the model.
        residuals (numpy.ndarray): The residuals (observed - predicted values) used for bootstrapping.
        observations (numpy.ndarray): The true observed values used for sorting and coverage rate calculation.
        data_label (str): A label for the plot title (default is "").
        file_name (str): If provided, the plots will be saved with this filename stem (default is None).
        normalize (bool): If True, predictions and confidence intervals will be divided by the
            observations, and the observations become 1 (default is False). Coverage is then
            computed on the normalized values.
        y_limit (float): Upper y-axis limit for the plots; the lower limit is fixed at -5
            (default is None, meaning no limit is set).

    Outputs:
        None: The function plots the confidence intervals and prints the coverage rate 
              for both residual and wild bootstrapping methods.
    """
    # Draw both intervals up front, residual first, so that the global random
    # state is consumed in the same order as before.
    lower_resid_ci, upper_resid_ci = resid_bootstrap(predictions, residuals)
    lower_wild_ci, upper_wild_ci = wild_bootstrap(predictions, residuals)

    # Sort by observations for a clearer plot
    sorted_indices = np.argsort(observations)
    x_values = np.arange(len(observations))

    # (lower, upper, legend label, saved-file prefix, name used in the printout)
    panels = (
        (lower_resid_ci, upper_resid_ci, 'RCI', 'rci', 'RCI'),
        (lower_wild_ci, upper_wild_ci, 'Confidence Interval', 'wci', 'WCI'),
    )

    for lower_ci, upper_ci, band_label, file_prefix, short_name in panels:
        pred, obs, lower, upper = _sorted_interval_data(
            predictions, observations, lower_ci, upper_ci, sorted_indices, normalize
        )
        _plot_interval(
            x_values, pred, obs, lower, upper,
            data_label=data_label,
            band_label=band_label,
            y_limit=y_limit,
            save_path=f"{file_prefix}_{file_name}.png" if file_name else None,
        )
        coverage_rate = _coverage_rate(obs, lower, upper)
        print(f"Coverage rate for {short_name}: {coverage_rate:.2%}")

In [None]:
# depends on `get_data` and on the `wnet` / `wnet_prior` models loaded in the "Load the Data" cell
def get_site_vars(site_val, lev):
    '''
    Load one site's sigma_W observations at a single level, run both Wnet models
    on the matching inputs, and return residuals, predictions and observations.

    Uses the module-level `get_data` function and the Keras models `wnet` and
    `wnet_prior`, which must already be defined when this is called.

    Inputs:
        site_val (int): Index of the site in the predefined list of sites. The 
            valid indices are: 0: "asi", 1: "cor", 2: "nsa", 3: "sgp_cirrus", 4: 
            "sgp_pbl".
        lev (int): 1-based level number. Inputs are selected with `lev=lev - 1`,
            and observations with the (lev - 1)-th value of the `height` coordinate.

    Outputs:
        site_resid (numpy.ndarray): Residuals between the observed data and the Wnet 
            predictions for the specified site and level
        site_resid_prior (numpy.ndarray): Residuals between the observed data and the 
            Wnet-prior predictions for the specified site and level
        site_wnet (numpy.ndarray): The predicted sigma_W values from Wnet
        site_wnet_prior (numpy.ndarray): The predicted sigma_W values from Wnet-prior
        site_observed (numpy.ndarray): The observed sigma_W values for the specified 
            site and level.

        Only samples whose observation is non-zero are kept in every output.
    '''
    site_names = ["asi", "cor", "nsa", "sgp_cirrus", "sgp_pbl"]
    inputs_all, obs_all = get_data(site_names[site_val])

    # Pick the requested level from the stacked inputs and targets.
    level_inputs = inputs_all.sel(lev=lev - 1)
    target_height = obs_all.coords['height'].values[lev - 1]
    level_obs = obs_all.sel(height=target_height)

    # Flatten model outputs to 1-D prediction vectors.
    prior_flat = wnet_prior.predict(level_inputs, batch_size=2048).reshape(-1)
    wnet_flat = wnet.predict(level_inputs, batch_size=2048).reshape(-1)

    # get_data fills missing observations with 0, so zeros mark samples to drop.
    keep = level_obs != 0
    site_observed = level_obs[keep].values
    site_wnet = wnet_flat[keep]
    site_wnet_prior = prior_flat[keep]

    return (
        site_observed - site_wnet,
        site_observed - site_wnet_prior,
        site_wnet,
        site_wnet_prior,
        site_observed,
    )

# Process Data

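The MERRA model inputs and ASR observations are stored as NetCDF files that we open with xarray, so we preprocess them before performing inference. This section provides the functions that load and process the data; the inputs are standardized using statistics computed from G5NR.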


In [None]:
# reading in MERRA and ASR data the same way as PDF_bysite.py
def outlier(x):
    '''
    Return |z|, the absolute standardized anomaly of `x` along its 'time' dimension.

    In get_data this is applied per calendar month through
    `groupby('time.month').map(outlier)`, so `x` is then an xarray.Dataset
    group. Values with large |z| are treated as outliers by the caller.

    Inputs:
        x (xarray.Dataset or xarray.DataArray): Input data with a 'time' dimension.

    Outputs:
        Same type as `x`: abs((x - mean over time) / std over time).
    '''
    centered = x - x.mean(dim='time')
    spread = x.std(dim='time')
    # A zero or all-NaN spread yields inf/NaN here, which NumPy reports as
    # RuntimeWarnings (division by 0/nan).
    return abs(centered / spread)


def standardize(ds):
    '''
    Standardize each data variable of `ds` with fixed G5NR means and standard deviations.

    Statistics are matched to variables by POSITION in `ds.data_vars`, not by
    name. In get_data the dataset holds T, AIRD, U, V, W, KM, RI, QV, QI, QL in
    that order. More than 10 data variables raises IndexError.

    Inputs:
        ds (xarray.Dataset): Input dataset containing variables to be standardized.

    Outputs:
        xarray.Dataset: The same `ds` object, modified in place, with every
            variable replaced by (value - mean) / std.
    '''
    # Hardcoded from G5NR, in data-variable order.
    means = [243.9, 0.6, 6.3, 0.013, 0.0002, 5.04, 21.8, 0.002, 9.75e-7, 7.87e-6]
    stds = [30.3, 0.42, 16.1, 7.9, 0.05, 20.6, 20.8, 0.0036, 7.09e-6, 2.7e-5]

    for idx, name in enumerate(ds.data_vars):
        ds[name] = (ds[name] - means[idx]) / stds[idx]
    return ds
    
# sites = asi, cor, nsa, sgp_cirrus, sgp_pbl
def get_data(site='', chunk_size=512*72, file_path=""):
    '''
    Loads and processes MERRA and ASR datasets for the specified site, performs standardization, 
    outlier removal, and data alignment between observations and model inputs.

    Inputs:
        site (str): The site identifier for which data is to be retrieved (e.g., 'asi', 'cor'). Default is an empty string.
        chunk_size (int): Currently unused by this function and kept for backward compatibility.
            The stacked arrays are always re-chunked to 72*1024 samples. Default is 512*72.
        file_path (str): Directory containing the NetCDF files. Default is "", which gives
            paths rooted at "/" (the previous hard-coded behavior).

    Outputs:
        X (xarray.DataArray): Standardized and aligned model input data, with variables stacked
            along 'feature' and (time, lev) stacked into a sample dimension 's'.
        y (xarray.DataArray): Target data ('W_asr_std') aligned with the input data, with
            (time, height) stacked into a sample dimension 's'. Missing values are filled with 0.
    '''
    path_asr = f'{file_path}/Wstd_asr_resampled_stdev30min_72lv_{site}.nc' 
    path_merra = f"{file_path}/Merra_input_asr_72lv_{site}.nc"

    data_obs = xr.open_mfdataset(path_asr, parallel=True)
    data_merra = xr.open_mfdataset(path_merra, parallel=True, chunks={"time": 2560})

    # ================= process obs ====================
    # Combine the masks element-wise with `&`. Python `and` calls bool() on the
    # whole first Dataset instead of combining the masks, so the -9999 filter
    # would be silently dropped.
    data_obs = data_obs.where((data_obs != -9999.) & (data_obs < 15.))
    # Monthly |z-score|, computed only from values above 0.001.
    data_std = (data_obs.where(data_obs > 0.001)).groupby('time.month').map(outlier)
    data_obs = data_obs.where(data_std < 2.5) 
    data_obs = data_obs.dropna('time', how='all', thresh=2) 
    # Missing observations become 0; get_site_vars drops zeros later.
    data_obs = data_obs.fillna(0) 
    
    # ================= process merra ==================
    # Upsample to 5-minute steps by linear interpolation.
    data_merra = data_merra.resample(time="5min").interpolate("linear")
    data_merra = data_merra[['T', 'AIRD', 'U', 'V', 'W', 'KM', 'RI', 'QV', 'QI', 'QL']] 
    
    # ================= align merra w/ obs ===================
    data_merra, data_obs = xr.align(data_merra, data_obs, exclude = {'height', 'lev'}) 

    # ================= prep model input X (standardize, add 4 surface vars) =================
    X = xr.map_blocks(standardize, data_merra, template=data_merra) 
    
    levs = X.coords['lev'].values
    num_levs = len(levs) # should be 72
    surface_vars = ['AIRD', 'KM', 'RI', 'QV']
    for sv in surface_vars:
        # Value of `sv` at lev=71, with the size-1 lev dimension squeezed away.
        sv_row = X[sv].sel(lev=[71]).squeeze() 
        # Repeat that row once per level in a single concat, instead of growing
        # the array one level at a time, then relabel with the real level coords.
        X_sv = xr.concat([sv_row] * num_levs, dim='lev')
        X[sv + "_sfc"] = X_sv.assign_coords(lev=levs)
    
    # ==================== clean up input X ========================
    X = X.unify_chunks()
    X = X.to_array()
    X = X.rename({'variable':'feature'}) 
    X = X.stack(s=('time', 'lev'))
    X = X.squeeze() 
    X = X.transpose()
    X = X.chunk({'s': 72*1024}) 
    
    # ==================== clean up target Y =======================
    y = data_obs['W_asr_std'] 
    y = y.stack(s=('time', 'height'))
    y = y.chunk({'s': 72*1024})

    return X.load(), y.load() 
    

# Code to Create Interval Plots

## Load the Data

In [None]:
modelpath_wnet = "" # Insert path name to Wnet model
modelpath_wnet_prior = "" # Insert path name to Wnet-prior

# Load both Keras models for inference only (compile=False), Wnet first.
# get_site_vars reads these two module-level names.
wnet, wnet_prior = (
    load_model(model_path, compile=False)
    for model_path in (modelpath_wnet, modelpath_wnet_prior)
)

## Plots

In [None]:
# Site index 4 ("sgp_pbl"), level 72: Wnet and Wnet-prior intervals, normalized,
# with the y-axis capped at 98.
site_resid, site_resid_prior, site_wnet, site_wnet_prior, site_observed = get_site_vars(
    site_val=4, lev=72
)
create_confidence_interval_plots(
    site_wnet, site_resid, site_observed, data_label="72", normalize=True, y_limit=98
)
create_confidence_interval_plots(
    site_wnet_prior, site_resid_prior, site_observed,
    data_label="Wnet-prior 72", normalize=True, y_limit=98,
)

In [None]:
# Site index 3 ("sgp_cirrus"), level 48: Wnet and Wnet-prior intervals, normalized.
site_resid, site_resid_prior, site_wnet, site_wnet_prior, site_observed = get_site_vars(3, 48)
create_confidence_interval_plots(site_wnet, site_resid, site_observed, "48", normalize=True)
# The title names the level actually plotted here (48).
create_confidence_interval_plots(site_wnet_prior, site_resid_prior, site_observed, "Wnet-prior 48", normalize=True)