# Imports

In [None]:
import xarray as xr
import pandas as pd
from keras.models import load_model
import numpy as np
import matplotlib.pyplot as plt

# Read in Pre-trained Models

In [None]:
# NEED TO INSERT FILEPATHS TO LOAD MODEL
# One path per ensemble member, in the order model0 .. model9.
ENSEMBLE_MODEL_PATHS = ("", "", "", "", "", "", "", "", "", "")
# Path of the model bound to `modelbase`.
BASE_MODEL_PATH = ""

# Every model is loaded with compile=False; the generator is consumed in
# order, so the members are loaded model0 first and model9 last.
(
    model0,
    model1,
    model2,
    model3,
    model4,
    model5,
    model6,
    model7,
    model8,
    model9,
) = (load_model(path, compile=False) for path in ENSEMBLE_MODEL_PATHS)
modelbase = load_model(BASE_MODEL_PATH, compile=False)

# Process Data

In [None]:
### Process the data

# Chunk length used along the stacked sample dimension "s".
batch_size = 61_440

# Latitude/longitude bounds used to crop the dataset with slice selections.
US_LOCS = dict(lat1=25, lat2=50, lon1=-150, lon2=-50)
COORDS = dict(US=US_LOCS)

# Normalisation constants. They are matched to the data variables by
# position, so their order must follow the order of the selected variables.
MEANS = [
    243.9, 0.6, 6.3, 0.013, 0.0002,
    5.04, 21.8, 0.002, 9.75e-7, 7.87e-6,
]
STDS = [
    30.3, 0.42, 16.1, 7.9, 0.05,
    20.6, 20.8, 0.0036, 7.09e-6, 2.7e-5,
]

# Variables that additionally get a "<name>_sfc" feature.
SURF_VARS = ["AIRD", "KM", "RI", "QV"]

def standardize(ds, s, m):
    """Standardize every data variable of ``ds`` as ``(value - mean) / std``.

    Means and standard deviations are matched to the variables by position,
    following the iteration order of ``ds.data_vars``.

    Args:
        ds: Dataset-like object exposing ``data_vars`` and supporting item
            get/set by variable name. It is modified in place.
        s: Sequence of standard deviations, one per data variable.
        m: Sequence of means, one per data variable.

    Returns:
        The same ``ds`` object, with each data variable standardized.

    Raises:
        ValueError: If ``m`` or ``s`` does not contain exactly one entry per
            data variable. (An explicit raise is used instead of ``assert``
            so the check still runs under ``python -O``.)
    """
    # Materialize the names first so reassigning variables below cannot
    # interfere with the iteration.
    names = list(ds.data_vars)
    if len(m) != len(names) or len(s) != len(names):
        raise ValueError(
            f"expected {len(names)} means and stds, "
            f"got {len(m)} means and {len(s)} stds"
        )

    # data_vars are ['T', 'AIRD', 'U', 'V', 'W', 'KM', 'RI', 'QV', 'QI', 'QL']
    for name, mean, std in zip(names, m, s):
        ds[name] = (ds[name] - mean) / std

    return ds

In [None]:
file_path = ""  # INSERT FILEPATH FOR G5NR DATA
global_data = xr.open_mfdataset(file_path)

# Drop the entries where lev == 0, then crop to the US bounding box.
global_data = global_data.where(global_data['lev'] != 0, drop=True)
us_box = COORDS["US"]
global_data = global_data.sel(
    lat=slice(us_box["lat1"], us_box["lat2"]),
    lon=slice(us_box["lon1"], us_box["lon2"]),
)

times = [""]  # INSERT DESIRED TIMESTAMP IN A LIST[STR] FORMAT

# Select the requested timestamps once; inputs and target both derive from
# this same selection.
selected = global_data.sel(time=times)

# Input features. The order here must match the order of MEANS and STDS,
# because standardize() pairs them with the variables by position.
data_in = selected[['T', 'AIRD', 'U', 'V', 'W', 'KM', 'RI', 'QV', 'QI', 'QL']]
data_in = xr.map_blocks(
    standardize,
    data_in,
    kwargs={"m": MEANS, "s": STDS},
    template=data_in,
)  # this is a DataSet

# Target variable.
data_out = selected['Wstd']  # this is a DataArray

Xall = data_in
yall = data_out
levs = Xall.coords['lev'].values

# For each surface variable, add a "<var>_sfc" feature holding the lev == 71
# values repeated at every level, so it lines up with the per-level features.
n_levs = len(levs)
for var in SURF_VARS:
    # Selecting with a list keeps 'lev' as a length-1 dimension.
    Xs = Xall[var].sel(lev=[71])  # 1 level above surface

    # Repeat the single level once per entry in `levs` with one concat,
    # instead of growing the array by one level per loop iteration.
    Xsfc = xr.concat([Xs] * n_levs, dim='lev')

    # Relabel the repeated copies with the real level coordinates.
    Xsfc = Xsfc.assign_coords(lev=levs)
    Xall[f"{var}_sfc"] = Xsfc

# Features: turn the Dataset into one DataArray (its "variable" dimension is
# renamed "ft"), flatten time/lat/lon/lev into a single sample dimension "s",
# drop length-1 dimensions, reverse the dimension order and rechunk.
Xall = (
    Xall.unify_chunks()
    .to_array()
    .stack(s=('time', 'lat', 'lon', 'lev'))
    .rename({"variable": "ft"})
    .squeeze()
    .transpose()
    .chunk({"ft": 14, "s": batch_size})
)

# Target: flattened along the same sample dimension "s" and chunked with the
# same chunk length as the features.
yall = (
    yall.stack(s=('time', 'lat', 'lon', 'lev'))
    .squeeze()
    .transpose()
    .chunk({"s": batch_size})
)

# Batch size passed to every predict() call. The base model previously fell
# back to the Keras default; it now uses the same value as the ensemble.
PREDICT_BATCH_SIZE = 2048

# load() brings the data into memory in place and returns the same array,
# so Xall is loaded as well; X is kept as an alias for that array.
X = Xall.load()

# One prediction per ensemble member, evaluated in order model0 .. model9.
(
    y_hat0,
    y_hat1,
    y_hat2,
    y_hat3,
    y_hat4,
    y_hat5,
    y_hat6,
    y_hat7,
    y_hat8,
    y_hat9,
) = (
    member.predict(Xall, batch_size=PREDICT_BATCH_SIZE)
    for member in (
        model0, model1, model2, model3, model4,
        model5, model6, model7, model8, model9,
    )
)

# Prediction of the base model.
y_hat = modelbase.predict(Xall, batch_size=PREDICT_BATCH_SIZE)

## Prepare Confidence Intervals

In [None]:
### Prepare the data!

# Stack the ensemble predictions along a new leading axis (one entry per model).
predictions = [y_hat0, y_hat1, y_hat2, y_hat3, y_hat4, y_hat5, y_hat6, y_hat7, y_hat8, y_hat9]
predictions_stacked = np.stack(predictions)

# Interval bounds across the ensemble: the 0th and 100th percentiles are the
# element-wise minimum and maximum of the member predictions.
lb = np.percentile(predictions_stacked, 0, axis=0).squeeze()
ub = np.percentile(predictions_stacked, 100, axis=0).squeeze()

# Flatten the target to 1-D. The length is inferred from the data rather than
# hard-coded, so other regions / timestamps work without editing this line.
y = yall.values.reshape(-1)

# Fraction of target values that fall inside [lb, ub].
mask = (y >= lb) & (y <= ub)
proportion_between = np.sum(mask) / len(y)
print("Coverage Rate of Confidence Interval", proportion_between)

# Order everything by increasing target value.
sorted_indices = np.argsort(y)

sorted_lb = lb[sorted_indices]
sorted_ub = ub[sorted_indices]
sorted_y = y[sorted_indices]

# normalize: express the bounds relative to the target, so the target itself
# becomes 1 everywhere.
# NOTE(review): any target value equal to 0 produces inf/nan here — confirm
# the target is strictly positive.
sorted_lb = sorted_lb / sorted_y
sorted_ub = sorted_ub / sorted_y
sorted_y = sorted_y / sorted_y

# Plot Data

In [None]:
# Font sizes shared by the plots: axis labels, tick labels, legend, title.
axis_font, label_font, legend_font, title_font = 17, 15, 16, 18

In [None]:
# Set x-axis for the plot: one position per sorted observation
x = np.arange(len(sorted_y))

# Create the plot: the target as a line, the interval as a filled band
plt.figure(figsize=(10, 7))
plt.plot(x, sorted_y, label=r'G5NR', color='magenta', linewidth=4)
plt.fill_between(x, sorted_lb, sorted_ub, color='skyblue', alpha=1, label='CI')

# Add labels and legend
# Raw string: "\s" is not a valid escape sequence in a normal string literal
# and triggers a warning; the rendered label text is unchanged.
plt.xlabel(r"Sorted and Normalized $\sigma_W$", fontsize=axis_font)
plt.ylabel(r'$\sigma_W$ Values (m/s)', fontsize=axis_font)
plt.tick_params(axis='both', which='major', labelsize=label_font)
title = "Confidence Interval for Wnet-prior (US, 1 timestamp)" 
plt.title(title, fontsize=title_font)
plt.legend(fontsize = legend_font)
# plt.savefig("CI_normalized.png", dpi=300)
plt.show()

## Plotting Miscoverage

In [None]:
# Boolean mask that is True where the (normalized) target lies outside the
# interval, i.e. below the lower bound or above the upper bound.
below_lower = np.less(sorted_y, sorted_lb)
above_upper = np.greater(sorted_y, sorted_ub)
mask_not_between = np.logical_or(below_lower, above_upper)

# Positions (along the first axis) at which the mask is True.
indices_not_between = mask_not_between.nonzero()[0]

In [None]:
# Keep only the last 500 out-of-interval positions; the arrays are sorted by
# target value, so these are the ones with the largest targets.
largest_missed = indices_not_between[-500:]
filtered_y = sorted_y[largest_missed]
filtered_lb = sorted_lb[largest_missed]
filtered_ub = sorted_ub[largest_missed]

x = np.arange(len(filtered_y))

# Draw the target as a line and the interval as a filled band.
fig, ax = plt.subplots(figsize=(10, 7))
ax.plot(x, filtered_y, label='G5NR', color='magenta', linewidth=4)
ax.fill_between(x, filtered_lb, filtered_ub, color='skyblue', alpha=1, label='CI')
ax.tick_params(axis='both', which='major', labelsize=label_font)

# Axis labels, title and legend.
ax.set_xlabel(r'Sorted and Normalized $\sigma_W$', fontsize=axis_font)
ax.set_ylabel(r'$\sigma_W$ Values (m/s)', fontsize=axis_font)

title = r"Misclassified Observations - Large $\sigma_W$"
ax.set_title(title, fontsize=title_font)
ax.legend(fontsize=legend_font)
# plt.savefig("CI_large.png", dpi=300)
plt.show()

In [None]:
# Keep only the first 500 out-of-interval positions; the arrays are sorted by
# target value, so these are the ones with the smallest targets.
smallest_missed = indices_not_between[:500]
filtered_y = sorted_y[smallest_missed]
filtered_lb = sorted_lb[smallest_missed]
filtered_ub = sorted_ub[smallest_missed]

x = np.arange(len(filtered_y))

# Draw the target as a line and the interval as a filled band.
fig, ax = plt.subplots(figsize=(10, 7))
ax.plot(x, filtered_y, label='G5NR', color='magenta', linewidth=4)
ax.fill_between(x, filtered_lb, filtered_ub, color='skyblue', alpha=1, label='CI')
ax.tick_params(axis='both', which='major', labelsize=label_font)

# Axis labels, title and legend.
ax.set_xlabel(r'Sorted and Normalized $\sigma_W$', fontsize=axis_font)
ax.set_ylabel(r'$\sigma_W$ Values (m/s)', fontsize=axis_font)

title = r"Misclassified Observations - Small $\sigma_W$"
ax.set_title(title, fontsize=title_font)
ax.legend(fontsize=legend_font)
# plt.savefig("CI_small.png", dpi=300)
plt.show()