In [None]:
import sys, os
import importlib as imp

import xarray as xr
import pickle
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import random

import sklearn
from sklearn import metrics

import experiment_settings, xai
import file_methods, plots, data_processing
from DIRECTORIES import MODEL_DIRECTORY, DATA_DIRECTORY

import matplotlib as mpl
import cartopy as ct
from cmcrameri import cm
from cartopy.util import add_cyclic_point

# Resolution (dots per inch) intended for figures saved with fig.savefig
savefig_dpi: int = 300

In [None]:
# Output switches (all outputs go to ./figures):
#   SAVE_FILES         -> PDF figures
#   SAVE_FILES_NC      -> NetCDF file per (network, dataset) pair
#   SAVE_FILES_TAS     -> NetCDF file of the (masked) input field per dataset
#   SAVE_FILES_AVGGRAD -> NetCDF file of the mean gradient map
SAVE_FILES = SAVE_FILES_NC = SAVE_FILES_TAS = SAVE_FILES_AVGGRAD = False

# Draw a colour bar (and the title) on each gradient map
show_colorbar = True

In [None]:
# Data experiments whose test sets are evaluated. Each name ends in a
# 3-character seed suffix: "_sx" reuses the evaluating network's seed,
# "_s<digit>" fixes the seed to that digit.
DATA_NAMES = (
    'DATA_MPI_hist_rcp85_sx',
    'DATA_CanESM5_hist_ssp245_sx',
    'DATA_IPSL_hist_ssp245_sx',
    'DATA_MIROC6_hist_ssp245_sx',
    'DATA_1pctC02_sx',
)

# Trained networks, named "<experiment>_s<single-digit seed>"
NETW_NAMES = ('MultiModel_hpt0_s3',)

NUM_NETW, NUM_DATA = len(NETW_NAMES), len(DATA_NAMES)

# Pick up edits to the project modules without restarting the kernel
for _module in (data_processing, experiment_settings, file_methods):
    imp.reload(_module)

# Results container: sims_dict[network_name] holds the model and its settings;
# per-dataset results are added under the same key later on.
sims_dict = {}

for j_net, EXP_NAME_net in enumerate(NETW_NAMES):
    print("*********  " + EXP_NAME_net + "  *********")

    # Settings are looked up without the 3-char seed suffix; the last
    # character of the name is the seed.
    settings_net = experiment_settings.get_settings(EXP_NAME_net[:-3])
    settings_net["rng_seed"] = int(EXP_NAME_net[-1])

    # Seed TensorFlow, Python and NumPy with the same value (same order as before)
    for _seed_fn in (tf.random.set_seed, random.seed, np.random.seed):
        _seed_fn(settings_net["rng_seed"])

    # Fail early if the trained model is missing on disk
    model_name = file_methods.get_model_name(settings_net)
    if not os.path.exists(MODEL_DIRECTORY + model_name + "_model"):
        raise RuntimeError("No such model experiment: " + model_name)
    model = file_methods.load_tf_model(model_name, MODEL_DIRECTORY)

    sims_dict[EXP_NAME_net] = dict(model=model, settings=settings_net)

# For every (dataset, network) pair: load the test split, predict, compute
# input gradients and store everything in sims_dict[network][dataset].
for j_dat, EXP_NAME_dat in enumerate(DATA_NAMES):
    print("*********  "+EXP_NAME_dat+"  *********")

    # Get data settings & name (name without its 3-char seed suffix)
    settings_dat_base = experiment_settings.get_settings(EXP_NAME_dat[:-3])
    model_name = file_methods.get_model_name(settings_dat_base)

    # Loop over networks
    for j_net, EXP_NAME_net in enumerate(NETW_NAMES):
        settings_net = sims_dict[EXP_NAME_net]['settings']
        model_name_net = file_methods.get_model_name(settings_net)

        # Fresh copy per network: values set below (rng_seed, subtract_val)
        # must not leak into the next network's data loading.
        settings_dat = dict(settings_dat_base)

        # "_sx" -> use the network's seed, "_s<digit>" -> use that digit
        if EXP_NAME_dat[-1] == 'x':
            settings_dat["rng_seed"] = settings_net["rng_seed"]
        else:
            settings_dat["rng_seed"] = int(EXP_NAME_dat[-1])

        # Label offset: only load the network's pickle when the network actually
        # uses subtract_val. Previously a present-but-falsy setting still tried to
        # open MODEL_DIRECTORY + settings_dat["subtract_val"] (TypeError/KeyError,
        # or a stale pickle from another network).
        if settings_net.get("subtract_val"):
            settings_dat["subtract_val"] = model_name_net+".pickle"
            with open(MODEL_DIRECTORY+settings_dat["subtract_val"], 'rb') as f:
                _ = pickle.load(f)  # first pickled object is not used here
                y_subtract = pickle.load(f)
        else:
            if "subtract_val" in settings_net:
                # Mirror the network's (falsy) setting so the data loader and
                # the zero offset below agree.
                settings_dat["subtract_val"] = settings_net["subtract_val"]
            y_subtract = np.array([0.])

        # GET THE DATA
        (
            _,
            _,
            x_test,
            _,
            _,
            labels_test,
            lat,
            lon,
            map_shape,
            member_shape,
            time_shape,
            _,
            member_enrollment,
            _,
        ) = data_processing.get_cmip_data(
            DATA_DIRECTORY,
            settings_dat,
            n_train_val_test=settings_dat["n_train_val_test"],
        )
        # Number of testing members
        test_shape = settings_dat["n_train_val_test"][2]*len(settings_dat["all_models"])

        # Data years (inclusive bounds)
        years = np.arange(settings_dat["yr_bounds"][0], settings_dat["yr_bounds"][1] + 1)

        # Make predictions by network; the offset is added back to labels and predictions
        net_model = sims_dict[EXP_NAME_net]['model']
        pair_res = sims_dict[EXP_NAME_net][EXP_NAME_dat] = {}
        pair_res["labels"] = labels_test + y_subtract[np.newaxis,:]
        pair_res["pred"] = net_model.predict(x_test) + y_subtract[np.newaxis,:]

        # Reshape input data to (member, year, lat, lon, channel); astype copies
        x_eval = x_test.astype('float32').reshape(
            (test_shape, time_shape, map_shape[0], map_shape[1], 1)
        )

        # Gradients of output 0 w.r.t. the input map, one member at a time
        grads = np.empty((test_shape, time_shape, map_shape[0], map_shape[1]))
        for mm in range(test_shape):
            grads[mm,:,:,:] = xai.get_gradients(net_model, x_eval[mm,:,:,:,:], pred_idx=0)

        # Drop only the channel axis (squeeze() would also drop a length-1 member
        # axis). Always defined, so "tas" is correct even without a mask.
        tas_masked = x_eval[..., 0]

        # Mask gradients and input outside the input region (mask < 0.5 -> NaN)
        add_mask = settings_dat["input_region"]
        if add_mask:
            mask = xr.load_dataarray(DATA_DIRECTORY + add_mask).to_numpy()
            # Integer/bool masks cannot hold NaN; float masks keep their dtype
            mask = mask.astype(np.result_type(mask.dtype, np.float32))
            mask[mask<0.5] = np.nan
            grads = grads * mask
            tas_masked = tas_masked * mask
        grads = grads.astype('float32')

        # Save results, reshaped to (member, year)
        pair_res["R_truth"] = pair_res["labels"].copy().reshape(test_shape, time_shape)
        pair_res["R_pred"] = pair_res["pred"].copy().reshape(test_shape, time_shape)
        pair_res["years"] = years
        pair_res["tas"] = tas_masked
        pair_res["grads"] = grads
        pair_res["lat"] = lat
        pair_res["lon"] = lon

In [None]:
# Skill of every network on every dataset: mean squared error and R^2 over all
# test samples. Rows index networks, columns index datasets.
mse_save = np.full((NUM_NETW, NUM_DATA), np.nan)
r2_save = np.full((NUM_NETW, NUM_DATA), np.nan)

for j_net, EXP_NAME_net in enumerate(NETW_NAMES):
    print("*********  " + EXP_NAME_net + "  *********")
    for j_dat, EXP_NAME_dat in enumerate(DATA_NAMES):
        pair_result = sims_dict[EXP_NAME_net][EXP_NAME_dat]
        y_true = pair_result["labels"].squeeze()
        y_pred = pair_result["pred"].squeeze()

        mse_save[j_net, j_dat] = metrics.mean_squared_error(y_true, y_pred)
        r2_save[j_net, j_dat] = metrics.r2_score(y_true, y_pred)

        mse_txt = plots.num_lab(mse_save[j_net, j_dat], 4)
        r2_txt = plots.num_lab(r2_save[j_net, j_dat], 4)
        print(f"{EXP_NAME_dat}: mse = {mse_txt}, r2 = {r2_txt}")

        # Keep the scores next to the predictions for the plotting cells
        pair_result['mse'] = mse_save[j_net, j_dat]
        pair_result['r2'] = r2_save[j_net, j_dat]

In [None]:
# Truth vs. predicted R for one test member; one panel per dataset, one line per network.
# Index of the test member (first axis of R_truth / R_pred) that is plotted
m = 0

# Vertical offset (axes fraction) between stacked per-network skill annotations
r2_shift = 0.06
imp.reload(plots)
# squeeze=False always yields a 2-D array of axes, so ax is 1-D even for NUM_DATA == 1
fig, ax = plt.subplots(NUM_DATA, 1, figsize=(10, 5*NUM_DATA), squeeze=False)
ax = ax[:, 0]
for j_dat, EXP_NAME_dat in enumerate(DATA_NAMES):
    for j_net, EXP_NAME_net in enumerate(NETW_NAMES):
        da = sims_dict[EXP_NAME_net][EXP_NAME_dat]
        net_color = plots.npcols[j_net]

        # Plot truth once per panel, taken from the first network's entry
        if j_net == 0:
            ax[j_dat].plot(da['years'], da["R_truth"][m, :],
                           label="Truth", linestyle="-",
                           linewidth=3, color="black")

        # Plot predictions in the same colour as this network's skill annotation
        # (previously the default colour cycle was used, so line and label could differ)
        ax[j_dat].plot(da["years"], da["R_pred"][m, :], label=EXP_NAME_net,
                       linewidth=1.5, color=net_color)

        # Add R2 and RMSE values (computed over all test samples in the metrics cell)
        rmse = np.sqrt(da['mse'])
        r2 = da['r2']
        ax[j_dat].text(0.97, 0.98-r2_shift*j_net,
                       plots.num_lab(r2,2).ljust(4,'0')+" ("+plots.num_lab(rmse,2).ljust(4,'0')+"W/m$^2$)",
                       transform=ax[j_dat].transAxes, fontsize=12, verticalalignment='top',
                       horizontalalignment='right', color=net_color)

    ax[j_dat].set_xticks(np.arange(1880,2100,40))
    ax[j_dat].set_xlim(da['years'][0],da['years'][-1])
    ax[j_dat].set_xlabel("Time (yrs)")
    ax[j_dat].set_ylabel("$R = N-F$ (W/m$^2$)")
    ax[j_dat].minorticks_on()

    ax[j_dat].set_title(EXP_NAME_dat)

ax[0].legend(loc='lower left')


if SAVE_FILES:
    fig.savefig("./figures/testingCNN.pdf", bbox_inches='tight', pad_inches=0, dpi=savefig_dpi)

In [None]:
imp.reload(plots)

# Period (inclusive) over which gradients are averaged before mapping
PLOT_YEAR_MIN = 1870
PLOT_YEAR_MAX = 2100
period_str = str(PLOT_YEAR_MIN) + "-" + str(PLOT_YEAR_MAX)

# Colour levels (l), colour-bar ticks (t) and the matching discrete norm
l = np.linspace(-4, 4, 41)
t = np.linspace(-4, 4, 5)
norm = mpl.colors.BoundaryNorm(l, cm.vik.N)

for j_dat, EXP_NAME_dat in enumerate(DATA_NAMES):
    for j_net, EXP_NAME_net in enumerate(NETW_NAMES):
        grad_dict = sims_dict[EXP_NAME_net][EXP_NAME_dat]

        # Indices of the years inside [PLOT_YEAR_MIN, PLOT_YEAR_MAX]
        years = grad_dict['years']
        in_period = (years >= PLOT_YEAR_MIN) & (years <= PLOT_YEAR_MAX)
        itime = np.nonzero(in_period)[0]

        # Mean over members and selected years, times 1e3 to plot in mW/m^2/K
        grads = grad_dict["grads"]
        x_plot = 1e3 * grads[:, itime].mean(axis=(0, 1))
        x_cyc, lons_cyc = add_cyclic_point(x_plot, coord=grad_dict["lon"])

        # One map per (network, dataset) pair on a grey background
        fig, ax = plots.setup_figure(nCols=1, nRows=1, size=(5, 4), mask=False)
        ax.set_facecolor((0.8, 0.8, 0.8))
        cf = ax.pcolormesh(lons_cyc, grad_dict["lat"], x_cyc, norm=norm,
                           transform=plots.data_crs, cmap=cm.vik)

        # Colour bar and title are drawn together, only when show_colorbar is set
        if show_colorbar:
            cb = plt.colorbar(cf, ax=ax, orientation="horizontal", shrink=1.0,
                              extend='both', ticks=t)
            cb.set_label("Gradient [mW/m$^2$/K]")
            ax.set_title("Network: " + EXP_NAME_net + ", data: " + EXP_NAME_dat
                         + ", " + period_str)

        if SAVE_FILES:
            fig.savefig("./figures/grad_" + EXP_NAME_net + "_" + EXP_NAME_dat + "_"
                        + period_str + ".pdf", bbox_inches='tight', pad_inches=0)

In [None]:
# Optional NetCDF exports of the results stored in sims_dict (written to ./figures):
#   SAVE_FILES_NC      -> per (network, dataset): gradients, truth, prediction,
#                         R2, mse and the network output for an all-zero input
#   SAVE_FILES_TAS     -> per dataset: the (masked) input field stored as "tas",
#                         taken from the first network's entry
#   SAVE_FILES_AVGGRAD -> member- and year-mean gradient map, only for pairs whose
#                         network and data names are identical
if SAVE_FILES_NC or SAVE_FILES_TAS or SAVE_FILES_AVGGRAD:
    # Writing fails if the output directory does not exist
    os.makedirs("./figures", exist_ok=True)

if SAVE_FILES_NC:
    for j_dat, EXP_NAME_dat in enumerate(DATA_NAMES):
        for j_net, EXP_NAME_net in enumerate(NETW_NAMES):
            da = sims_dict[EXP_NAME_net][EXP_NAME_dat]

            # Network output for an all-zero input map. The grid size comes from this
            # pair's gradients (member, year, lat, lon) instead of the global
            # map_shape, which is left over from whichever dataset was loaded last.
            n_lat, n_lon = da['grads'].shape[-2:]
            r0 = sims_dict[EXP_NAME_net]['model'].predict(np.zeros((1, n_lat, n_lon, 1)))[0,0]
            print(EXP_NAME_net + " / " + EXP_NAME_dat + ": prediction for zero input = " + str(r0))

            grad_da = xr.DataArray(da['grads'], coords={
                    'year': da['years'],
                    'lat': da['lat'],
                    'lon': da['lon'],
                },
                dims=('mem','year','lat','lon'),
                name="grad"
            )
            true_da = xr.DataArray(da['R_truth'], coords={
                    'year': da['years'],
                },
                dims=('mem','year'),
                name="truth"
            )
            pred_da = xr.DataArray(da['R_pred'], coords={
                    'year': da['years'],
                },
                dims=('mem','year'),
                name="prediction"
            )
            r2_da = xr.DataArray(da['r2'], dims=(), name="R2")
            mse_da = xr.DataArray(da['mse'], dims=(), name="mse")
            r0_da = xr.DataArray(r0, dims=(), name="pred_T0")

            ds = xr.merge([grad_da,true_da,pred_da,r2_da,mse_da,r0_da])

            filename_save = "./figures/CNN_"+EXP_NAME_net+"__Data_"+EXP_NAME_dat
            ds.to_netcdf(filename_save+'.nc')

if SAVE_FILES_TAS:
    for j_dat, EXP_NAME_dat in enumerate(DATA_NAMES):
        EXP_NAME_net = NETW_NAMES[0]
        da = sims_dict[EXP_NAME_net][EXP_NAME_dat]

        tas_da = xr.DataArray(da['tas'], coords={
                'year': da['years'],
                'lat': da['lat'],
                'lon': da['lon'],
            },
            dims=('mem','year','lat','lon'),
            name="tas"
        )

        filename_save = "./figures/TAS__Data_"+EXP_NAME_dat
        tas_da.to_netcdf(filename_save+'.nc')

if SAVE_FILES_AVGGRAD:
    n_avggrad_saved = 0
    for j_dat, EXP_NAME_dat in enumerate(DATA_NAMES):
        for j_net, EXP_NAME_net in enumerate(NETW_NAMES):
            # Only pairs with identical network and data names are exported; the
            # file name carries only the data name.
            if EXP_NAME_net != EXP_NAME_dat:
                continue
            da = sims_dict[EXP_NAME_net][EXP_NAME_dat]

            # Mean over members and years; masked (NaN) grid points stay NaN
            avggrad_da = xr.DataArray(da['grads'].mean(axis=(0,1)), coords={
                    'lat': da['lat'],
                    'lon': da['lon'],
                },
                dims=('lat','lon'),
                name="grad"
            )

            filename_save = "./figures/GRAD_"+EXP_NAME_dat
            avggrad_da.to_netcdf(filename_save+'.nc')
            n_avggrad_saved += 1
    if n_avggrad_saved == 0:
        # Make the otherwise silent no-op visible
        print("SAVE_FILES_AVGGRAD: no network name equals a data name; "
              "no average-gradient files were written.")