# Save results for predicting Radiation from SSTs and its ingredients
authors: Maria Rugenstein, Elizabeth A. Barnes, and Senne Van Loon

date: March 7, 2024

## Python stuff

In [None]:
import sys, os
import gc

import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import random

from sklearn import metrics

import experiment_settings, xai
import file_methods, plots, data_processing

import matplotlib as mpl


import cartopy as ct
from cartopy.util import add_cyclic_point
from cmcrameri import cm

import matplotlib as mpl

# Plot styling: start from matplotlib's default style, then override two rcParams.
plt.style.use("default")
mpl.rcParams.update(
    {
        "savefig.facecolor": "white",
        "figure.dpi": 150,
    }
)
# NOTE(review): presumably the dpi for saved figures; not referenced in the cells visible here.
savefig_dpi = 300

In [None]:
# Report the interpreter and key library versions, one line per package.
for pkg_name, pkg_version in (
    ("python", sys.version),
    ("numpy", np.__version__),
    ("xarray", xr.__version__),
    ("tensorflow", tf.__version__),
):
    print(f"{pkg_name} version = {pkg_version}")

In [None]:
# Path prefixes for trained models, saved predictions, NetCDF output and input data,
# provided by the project-local DIRECTORIES module.
from DIRECTORIES import MODEL_DIRECTORY, PREDICTIONS_DIRECTORY, NETCDF_DIRECTORY, DATA_DIRECTORY

# When True, the datasets built in the cells below are written to NETCDF_DIRECTORY.
SAVE_NETCDF = True

## User Choices

In [None]:
# Climate model and base experiment whose test data the trained models are evaluated on.
CMIP_model = "MPI"
# base_exp = "_IV"
base_exp = "_hist_rcp85"
suffix = ""

# Prefix shared by the output files that depend on the base experiment.
savename_prefix = f"R_{CMIP_model}{base_exp}{suffix}"

# Short experiment label -> name of the trained-model experiment to load.
# NOTE(review): the linear entry does not append `suffix` -- confirm this is intended.
experiment_dict = dict(
    InternalVariability=f"{CMIP_model}_IV{suffix}",
    InternalVariability_linear=f"{CMIP_model}_IV_linear",
)
if CMIP_model == "MPI":
    experiment_dict.update(greens_function="MPI_GF")

## Load the data and saved model

In [None]:
# Settings of the base experiment (the data all models are evaluated on)
settings_base = experiment_settings.get_settings(CMIP_model + base_exp)

# Seed TensorFlow, Python's `random` and NumPy with the first seed of the list
rng_seed = settings_base["rng_seed_list"][0]
settings_base["rng_seed"] = rng_seed
for seed_fn in (tf.random.set_seed, random.seed, np.random.seed):
    seed_fn(settings_base["rng_seed"])

# Load the base data; only the test inputs/labels and the grid/shape metadata
# are kept, every other returned element is discarded.
(
    _, _, x_test,
    _, _, labels_test,
    lat, lon,
    map_shape, member_shape, time_shape,
    _, member_enrollment, _,
) = data_processing.get_cmip_data(
    DATA_DIRECTORY,
    settings_base,
    n_train_val_test=settings_base["n_train_val_test"],
)
x_test = x_test.astype("float32")

def _report_metrics(split_label, y_true, y_pred):
    """Print and return the MSE and R2 score of `y_pred` against `y_true`.

    `split_label` is the (already padded) prefix of the printed line,
    e.g. "Validation:  ". Returns the tuple `(mse, r2)`.
    """
    mse = metrics.mean_squared_error(y_true, y_pred)
    r2 = metrics.r2_score(y_true, y_pred)
    print(split_label + "mse = " + plots.num_lab(mse, 4) + ", r2 = " + plots.num_lab(r2, 4))
    return mse, r2


# For every experiment: load the trained model and its stored predictions,
# predict on the base test data, and report the skill on each data split.
sims_dict = {}
for exp_type, exp_name in experiment_dict.items():
    print("*********  " + exp_type + "  *********")

    # Get model settings
    settings = experiment_settings.get_settings(exp_name)

    # Load model (the seed is set first because the model name is built from the settings)
    settings["rng_seed"] = rng_seed
    model_name = file_methods.get_model_name(settings)
    if not os.path.exists(MODEL_DIRECTORY + model_name + "_model"):
        raise RuntimeError("No such model experiment: " + model_name)
    model = file_methods.load_tf_model(model_name, MODEL_DIRECTORY)

    # Load predictions
    predictions = file_methods.load_predictions(PREDICTIONS_DIRECTORY + model_name + "_predictions.pickle")

    # Make OOS predictions on the base experiment's test data
    predictions["labels_oos_test"] = labels_test
    predictions["pred_oos_test"] = model.predict(x_test)

    # Metrics (Validation, testing, OOS)
    for split_label, split_key in (
        ("Validation:  ", "val"),
        ("Testing:     ", "test"),
        ("Testing OOS: ", "oos_test"),
    ):
        mse, r2 = _report_metrics(
            split_label,
            predictions["labels_" + split_key],
            predictions["pred_" + split_key],
        )
    print("")

    # Save settings + predictions
    sims_dict[exp_type] = {
        "model": model,
        "predictions": predictions,
        "settings": settings,
        "type": exp_type,
    }

    # clear things out
    gc.collect()

print(sims_dict.keys())

## Save predictions

In [None]:
filename = NETCDF_DIRECTORY + savename_prefix + "_predictions_test.nc"

# Define years (both bounds inclusive)
years = np.arange(settings_base["yr_bounds"][0], settings_base["yr_bounds"][1] + 1)

# The labels and every OOS prediction are derived from the base test data
# (labels_test / x_test), so all of them are reshaped with the base
# experiment's number of test members, matching the `member` coordinate below.
n_base_members = settings_base["n_train_val_test"][2]

# Save labels
data_vars_dict = {
    "labels": (
        ["member", "year"],
        labels_test.reshape(n_base_members, time_shape),
    ),
}

# Save OOS predictions
for exp_type, sim in sims_dict.items():
    data_vars_dict[exp_type] = (
        ["member", "year"],
        sim["predictions"]["pred_oos_test"].reshape(n_base_members, time_shape),
    )

# Make Dataset
ds = xr.Dataset(
    data_vars=data_vars_dict,
    coords=dict(
        year=years,
        member=member_enrollment[2],
    ),
    attrs=dict(description="Data from experiments: " + str(list(experiment_dict.values()))),
)

# Save to file
if SAVE_NETCDF:
    ds.to_netcdf(filename, format="NETCDF4")
display(ds)


## Plot predictions

In [None]:
# Compare the labels with every experiment's prediction for the first member
member_idx = 0
plt.figure(figsize=(20, 7), dpi=250)
plt.plot(
    ds["year"],
    ds["labels"][member_idx, :],
    label="labels",
    linestyle="-",
    linewidth=10,
    alpha=0.5,
    color="gray",
)

# NOTE(review): none of the experiment labels defined in experiment_dict contains
# "nonlinear", so every prediction is currently drawn dashed -- confirm intended.
for varname, da in ds.data_vars.items():
    if varname != "labels":
        lt, alpha = ("-", 0.75) if "nonlinear" in varname else ("--", 1.0)
        plt.plot(
            ds["year"],
            da[member_idx, :],
            label=varname,
            linestyle=lt,
            alpha=alpha,
            linewidth=2.5,
        )

plt.title("member #0")
plt.legend()


In [None]:
# Guard cell: raising here halts a "run all" before the gradient and attribution
# cells when the base experiment is a 1pct run or the IV experiment.
if base_exp.startswith("_1pct") or base_exp == "_IV":
    raise ValueError(
        f"Stopping here: the gradient/attribution cells below are not run for base_exp={base_exp!r}."
    )

## COMPUTE GRADIENTS
Gradients are computed on in-sample testing members, and averaged over all members

In [None]:
savefilename = NETCDF_DIRECTORY + "R_" + CMIP_model + suffix + "_xaiGrads_test.nc"

# Define years. They are read explicitly from the last experiment's settings
# (the value the loop variable `settings` of the model-loading cell used to
# hold), so this cell no longer depends on a leaked loop variable.
# Assumes all experiments share the same yr_bounds, since they share one
# `year` coordinate -- TODO confirm.
years_settings = list(sims_dict.values())[-1]["settings"]
years = np.arange(years_settings["yr_bounds"][0], years_settings["yr_bounds"][1] + 1)

data_vars_dict = {}
for exp_type, sim in sims_dict.items():
    exp_settings = sim["settings"]
    n_exp_members = exp_settings["n_train_val_test"][2]

    # Get in-sample test data (third element returned by get_cmip_data)
    x_eval = data_processing.get_cmip_data(
        DATA_DIRECTORY,
        exp_settings,
        n_train_val_test=exp_settings["n_train_val_test"],
    )[2]
    x_eval = x_eval.astype("float32")

    # Reshape x_eval to (member, year, lat, lon, 1) so that gradients can be
    # computed one member at a time (otherwise leads to memory issues)
    x_eval = x_eval.reshape((n_exp_members, time_shape, map_shape[0], map_shape[1], 1))

    # Initialize gradients
    grads = np.empty((n_exp_members, time_shape, map_shape[0], map_shape[1]))

    # Compute gradients
    for mm in range(n_exp_members):
        grads[mm, ...] = xai.get_gradients(sim["model"], x_eval[mm], pred_idx=0)

    # Apply mask to gradient, using the input region of THIS experiment
    # rather than whichever experiment happened to be loaded last.
    mask = xr.load_dataarray(DATA_DIRECTORY + exp_settings["input_region"]).to_numpy()
    grads = grads * mask
    grads = grads.astype("float32")

    # Take ensemble mean; exact zeros (masked-out cells) are stored as NaN
    grads = grads.mean(axis=0)
    grads[grads == 0] = np.nan
    data_vars_dict[exp_type] = (["year", "lat", "lon"], grads.copy())

# ---------------------------------------------------------------------------
# Create dataset
ds = xr.Dataset(
    data_vars=data_vars_dict,
    coords=dict(
        year=years,
        lat=lat.astype("float32"),
        lon=lon.astype("float32"),
    ),
    attrs=dict(description="Data from experiments: " + str(list(experiment_dict.values()))),
)
if SAVE_NETCDF:
    ds.to_netcdf(savefilename, format="NETCDF4")
display(ds)


In [None]:
# Discrete color levels and colorbar ticks for the maps
l = np.linspace(-0.0013, 0.0013, 27)
t = np.linspace(-0.001, 0.001, 3)
norm = mpl.colors.BoundaryNorm(l, cm.vik.N)

# One map per variable of `ds`, averaged over its first axis
for varname, da in ds.data_vars.items():
    first_axis_mean = da.to_numpy().mean(axis=0)
    # Append a cyclic longitude point to the data and the longitude coordinate
    x_plot, lons_cyc = add_cyclic_point(first_axis_mean, coord=lon)

    # Setup figure
    fig, ax = plots.setup_figure(nCols=1, nRows=1, size=(5, 4), mask=True)

    cf = ax.pcolormesh(
        lons_cyc,
        lat,
        x_plot,
        norm=norm,
        transform=plots.data_crs,
        cmap=cm.vik,
    )
    cb = plt.colorbar(
        cf,
        ax=ax,
        orientation="horizontal",
        shrink=1.0,
        extend="both",
        ticks=t,
    )
    cb.set_label("Gradient [W/m$^2$/K]")
    ax.set_title(varname)

## COMPUTE ATTRIBUTION
Gradients are computed on in-sample testing members, multiplied by out-of-sample SST, and averaged over all members

In [None]:
savefilename = NETCDF_DIRECTORY + savename_prefix + "_xaiAttribution_test.nc"

# Define years. The attribution is evaluated on the base test data (x_test),
# so the years come from the base settings, as for the saved predictions.
years = np.arange(settings_base["yr_bounds"][0], settings_base["yr_bounds"][1] + 1)

# Reshape the base test data to (member, year, lat, lon, 1)
n_base_members = settings_base["n_train_val_test"][2]
x_test_ = x_test.reshape((n_base_members, time_shape, map_shape[0], map_shape[1], 1))

data_vars_dict = {}
for exp_type, sim in sims_dict.items():

    # Initialize gradients; sized by the members of x_test_, which is the data
    # actually indexed below (not by the experiment's own test-member count).
    grads = np.empty((n_base_members, time_shape, map_shape[0], map_shape[1]))

    # Compute gradients * SST. The trailing length-1 axis is dropped by
    # indexing instead of np.squeeze, which would also remove the member or
    # year axis if either had length 1.
    for mm in range(n_base_members):
        grads[mm, ...] = (
            xai.get_gradients(sim["model"], x_test_[mm], pred_idx=0)
            * x_test_[mm, ..., 0]
        )

    # Apply mask, using the input region of THIS experiment rather than
    # whichever experiment happened to be loaded last.
    mask = xr.load_dataarray(DATA_DIRECTORY + sim["settings"]["input_region"]).to_numpy()
    grads = grads * mask
    grads = grads.astype("float32")

    # Take ensemble mean; exact zeros (masked-out cells) are stored as NaN
    grads = grads.mean(axis=0)
    grads[grads == 0] = np.nan
    data_vars_dict[exp_type] = (["year", "lat", "lon"], grads.copy())

# ---------------------------------------------------------------------------
# Create dataset
ds = xr.Dataset(
    data_vars=data_vars_dict,
    coords=dict(
        year=years,
        lat=lat,
        lon=lon,
    ),
    attrs=dict(description="Data from experiments: " + str(list(experiment_dict.values()))),
)
if SAVE_NETCDF:
    ds.to_netcdf(savefilename, format="NETCDF4")
display(ds)


In [None]:
# Discrete color levels and colorbar ticks for the maps
l = np.linspace(-0.002, 0.002, 21)
t = np.linspace(-0.002, 0.002, 5)
norm = mpl.colors.BoundaryNorm(l, cm.vik.N)

# One map per variable of `ds`, averaged over its first axis
for varname, da in ds.data_vars.items():
    first_axis_mean = da.to_numpy().mean(axis=0)
    # Append a cyclic longitude point to the data and the longitude coordinate
    x_plot, lons_cyc = add_cyclic_point(first_axis_mean, coord=lon)

    # Setup figure
    fig, ax = plots.setup_figure(nCols=1, nRows=1, size=(5, 4), mask=True)

    cf = ax.pcolormesh(
        lons_cyc,
        lat,
        x_plot,
        norm=norm,
        transform=plots.data_crs,
        cmap=cm.vik,
    )
    cb = plt.colorbar(
        cf,
        ax=ax,
        orientation="horizontal",
        shrink=1.0,
        extend="both",
        ticks=t,
    )
    cb.set_label("Attribution [W/m$^2$]")
    ax.set_title(varname)

# Time series of the attribution summed over the two map axes (NaNs ignored),
# one line per variable of `ds`.
fig, ax = plt.subplots(1, 1, figsize=(10, 6))
# Read the years from the dataset itself, so the x-limits below do not depend
# on the loop variable surviving the loop.
years_axis = ds["year"].to_numpy()
for varname, da in ds.data_vars.items():
    ax.plot(years_axis, np.nansum(da.to_numpy(), axis=(1, 2)), label=varname)
ax.set_xlabel("year")
ax.set_ylabel("Radiation [W/m$^2$]")
ax.set_xlim(years_axis[0], years_axis[-1])
ax.legend()