In [None]:
import os
import importlib as imp

import pandas as pd
import pickle
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import random

import sklearn
from sklearn import metrics

import experiment_settings
import file_methods, plots, data_processing
from DIRECTORIES import MODEL_DIRECTORY, DATA_DIRECTORY

# Dots-per-inch value for saved figures (not referenced by any cell shown in this notebook section).
savefig_dpi = 300

In [None]:
# When True, the last cell writes every stored series to "./figures/*.dat" text files.
SAVE_FILES = False
# (first_year, last_year), both inclusive: baseline period over which the forcing and
# response series are re-referenced (their mean over these years is removed).
# A falsy value skips the re-referencing for the model data; the observational
# forcing then falls back to a 2001-2020 baseline.
RECOMPUTE_ANOMALIES = (1870,1900)

In [None]:
# Re-import the local helper modules so that edits made to them on disk are
# picked up without restarting the kernel.
for _local_module in (data_processing, experiment_settings, file_methods):
    imp.reload(_local_module)

# Experiment names of the datasets the networks are evaluated on
DATA_NAMES = (
    "DATA_MPI_hist_rcp85",
    "DATA_CanESM5_hist_ssp245",
    "DATA_IPSL_hist_ssp245",
    "DATA_MIROC6_hist_ssp245",
)

# Experiment names of the trained networks; the trailing "_s<digit>" is the seed
NETW_NAMES = ("MultiModel_hpt0_s3", )

# First / last year of the trend period (not used by the cells shown here)
TREND_START_YEAR, TREND_END_YEAR = 2001, 2023

# How many networks and datasets are being combined
NUM_NETW, NUM_DATA = len(NETW_NAMES), len(DATA_NAMES)

# Container for everything produced below, keyed first by network name and
# then by dataset name (plus the per-network "model" and "settings" entries).
sims_dict = {}

# Load network models
for j_net, EXP_NAME_net in enumerate(NETW_NAMES):
    print("*********  "+EXP_NAME_net+"  *********")

    # Network names follow "<settings name>_s<seed>". Split on the LAST "_s"
    # so that multi-digit seeds (e.g. "..._s12") are parsed correctly instead
    # of silently keeping only the final digit.
    settings_name, seed_sep, seed_text = EXP_NAME_net.rpartition("_s")
    if not (seed_sep and seed_text.isdecimal()):
        # Name does not match the convention: keep the historical behaviour
        # (last character is the seed, last three characters are stripped).
        settings_name, seed_text = EXP_NAME_net[:-3], EXP_NAME_net[-1]

    # Get settings of network
    settings_net = experiment_settings.get_settings(settings_name)
    settings_net["rng_seed"] = int(seed_text)

    # Seed every random number generator in use with the network's seed
    tf.random.set_seed(settings_net["rng_seed"])
    random.seed(settings_net["rng_seed"])
    np.random.seed(settings_net["rng_seed"])

    # Load network; fail early with a readable message if it was never saved
    model_name = file_methods.get_model_name(settings_net)
    if not os.path.exists(MODEL_DIRECTORY + model_name + "_model"):
        raise RuntimeError("No such model experiment: " + model_name)
    model = file_methods.load_tf_model(model_name, MODEL_DIRECTORY)

    # Keep the loaded model and its settings for the prediction loop
    sims_dict[EXP_NAME_net] = {
        "model": model,
        "settings": settings_net,
    }

# Loop over datasets: for every (dataset, network) pair, load the data, run
# the network and store truth / prediction series in sims_dict.
for j_dat, EXP_NAME_dat in enumerate(DATA_NAMES):
    print("*********  "+EXP_NAME_dat+"  *********")

    # Get data settings.
    # NOTE: this dict is shared by all networks in the inner loop, so a value
    # written for one network (e.g. "input_region") persists for the next one.
    settings_dat = experiment_settings.get_settings(EXP_NAME_dat)

    # Loop over networks
    for j_net, EXP_NAME_net in enumerate(NETW_NAMES):
        settings_net = sims_dict[EXP_NAME_net]['settings']
        model_name_net = file_methods.get_model_name(settings_net)
        settings_dat["rng_seed"] = settings_net["rng_seed"]

        # Use the network's input region for the data, if it defines one
        add_mask = settings_net['input_region']
        if add_mask:
            settings_dat["input_region"] = add_mask

        # Only read the offset file when the network actually enables
        # "subtract_val". Previously the file was opened whenever the key was
        # merely present, which failed (or read a stale path) for a falsy value.
        if "subtract_val" in settings_net and settings_net["subtract_val"]:
            settings_dat["subtract_val"] = model_name_net+".pickle"

            # NOTE(security): pickle executes arbitrary code on load; only use
            # files written by this project's own training runs.
            with open(MODEL_DIRECTORY+settings_dat["subtract_val"], 'rb') as f:
                _ = pickle.load(f)           # first pickled object is not needed here
                y_subtract = pickle.load(f)  # second pickled object: offset added back below
        else:
            settings_dat["subtract_val"] = False
            y_subtract = np.array([0.])

        # GET THE DATA (only the fields used below are given names)
        (
            _,
            _,
            tas,
            _,
            _,
            R_truth,
            lat,
            lon,
            map_shape,
            member_shape,
            time_shape,
            _,
            member_enrollment,
            _,
            F_truth,
        ) = data_processing.get_cmip_data(
            DATA_DIRECTORY,
            settings_dat,
            n_train_val_test=settings_dat["n_train_val_test"],
            get_forcing = True
        )
        # Third entry of n_train_val_test; used as the leading dimension when
        # reshaping to (test_shape, time_shape) -- presumably the number of
        # test members, TODO confirm against data_processing.
        test_shape = settings_dat["n_train_val_test"][2]
        years = np.arange(settings_dat["yr_bounds"][0], settings_dat["yr_bounds"][1] + 1)

        if RECOMPUTE_ANOMALIES:
            # Indices of the years inside the (inclusive) baseline period
            iy_anom = np.where(
                (RECOMPUTE_ANOMALIES[0] <= years) &
                (years <= RECOMPUTE_ANOMALIES[1]))[0]
            if iy_anom.size == 0:
                # An empty baseline would turn every series into NaN
                raise ValueError(
                    "RECOMPUTE_ANOMALIES period " + str(RECOMPUTE_ANOMALIES)
                    + " does not overlap the years of " + EXP_NAME_dat
                )
            # Re-reference the forcing to its baseline mean, and replace the
            # offset so that the response is re-referenced to its baseline
            # mean taken over all rows and all baseline years.
            F_truth = F_truth - np.mean(F_truth[iy_anom])
            y_subtract = - np.array([np.mean(R_truth.reshape(test_shape, time_shape)[:,iy_anom])])

        # N = F + R (+ offset), one row per entry along the first axis
        N_truth = F_truth + R_truth.reshape(test_shape, time_shape) + y_subtract[np.newaxis,:]

        # Make predictions; the same offset is added to labels and predictions
        result = {}
        sims_dict[EXP_NAME_net][EXP_NAME_dat] = result
        result["labels"] = R_truth + y_subtract[np.newaxis,:]
        result["pred"] = sims_dict[EXP_NAME_net]['model'].predict(tas) + y_subtract[np.newaxis,:]

        # Reshape and save R as (test_shape, time_shape)
        result["R_truth"] = result["labels"].reshape(test_shape, time_shape)
        result["R_pred"] = result["pred"].reshape(test_shape, time_shape)

        # Save truth
        result["F_truth"] = F_truth
        result["N_truth"] = N_truth
        result["years"] = years

        # Forcing prediction: F = N - R, using the predicted response
        result["F_pred"] = result["N_truth"] - result["R_pred"]

# Observational forcing table: one "year" column plus forcing columns that
# are summed row-wise into a single total-forcing series.
Forster_df = pd.read_csv(DATA_DIRECTORY + "obs/ERF_best_1750-2023.csv", sep=",")
Forster_y = Forster_df['year'].to_numpy()
Forster_F = Forster_df.drop(columns='year').sum(axis=1).to_numpy()

# Baseline period (inclusive): the configured anomaly window when set,
# otherwise 2001-2020.
if RECOMPUTE_ANOMALIES:
    Forster_base_first = RECOMPUTE_ANOMALIES[0]
    Forster_base_last = RECOMPUTE_ANOMALIES[1]
else:
    Forster_base_first = 2001
    Forster_base_last = 2020

Forster_in_baseline = (Forster_base_first <= Forster_y) & (Forster_y <= Forster_base_last)
Forster_iy = np.where(Forster_in_baseline)[0]

# Express the forcing as an anomaly relative to its baseline mean
Forster_F = Forster_F - np.mean(Forster_F[Forster_iy])

In [None]:
# Score tables indexed [network, dataset]; NaN until a pair has been scored
score_table_shape = (NUM_NETW, NUM_DATA)
mse_save = np.full(score_table_shape, np.nan)
r2_save = np.full(score_table_shape, np.nan)

# Mean squared error and R^2 of predictions against labels for every pair
for j_net, EXP_NAME_net in enumerate(NETW_NAMES):
    print("*********  "+EXP_NAME_net+"  *********")
    for j_dat, EXP_NAME_dat in enumerate(DATA_NAMES):
        run = sims_dict[EXP_NAME_net][EXP_NAME_dat]
        labels_flat = run["labels"].squeeze()
        pred_flat = run["pred"].squeeze()

        mse_save[j_net, j_dat] = metrics.mean_squared_error(labels_flat, pred_flat)
        r2_save[j_net, j_dat] = metrics.r2_score(labels_flat, pred_flat)

        # Report both scores with 4 passed as the second argument of num_lab
        mse_text = plots.num_lab(mse_save[j_net, j_dat], 4)
        r2_text = plots.num_lab(r2_save[j_net, j_dat], 4)
        print(EXP_NAME_dat + ": mse = " + mse_text + ", r2 = " + r2_text)

        # Keep the scores next to the series they were computed from
        run['mse'] = mse_save[j_net, j_dat]
        run['r2'] = r2_save[j_net, j_dat]

In [None]:
# One panel per dataset, stacked vertically
fig, ax = plt.subplots(NUM_DATA, 1, figsize=(10, 5 * NUM_DATA))
if NUM_DATA == 1:
    # plt.subplots returns a bare Axes for a single panel; wrap it so that
    # indexing by dataset works uniformly
    ax = np.array([ax])

for j_dat, EXP_NAME_dat in enumerate(DATA_NAMES):
    panel = ax[j_dat]

    # Observational forcing, skipping the first 120 entries of the series
    panel.plot(Forster_y[120:], Forster_F[120:], color='g', linewidth=1.5, label="Forster")

    for j_net, EXP_NAME_net in enumerate(NETW_NAMES):
        sim = sims_dict[EXP_NAME_net][EXP_NAME_dat]
        sim_years = sim["years"]
        forcing_color = plots.npcols[0]
        response_color = plots.npcols[1]

        # N (sum of forcing and response): thin line per row, thick line for the mean
        panel.plot(sim_years, sim["N_truth"].T, linewidth=0.5, color='k')
        panel.plot(sim_years, sim["N_truth"].mean(axis=0), linewidth=2, color='k')

        # Forcing: truth in black, prediction in the first palette colour
        panel.plot(sim_years, sim["F_truth"], linewidth=2, color='k')

        panel.plot(sim_years, sim["F_pred"].T, linewidth=0.5, color=forcing_color)
        panel.plot(sim_years, sim["F_pred"].mean(axis=0),
                   linewidth=2, label=EXP_NAME_net, color=forcing_color)

        # Response: truth in black, prediction in the second palette colour
        panel.plot(sim_years, sim["R_truth"].T, linewidth=0.5, color='k')
        panel.plot(sim_years, sim["R_truth"].mean(axis=0), linewidth=2, color='k')

        panel.plot(sim_years, sim["R_pred"].T, linewidth=0.5, color=response_color)
        panel.plot(sim_years, sim["R_pred"].mean(axis=0), linewidth=2, color=response_color)

    # Panel cosmetics; the title drops the first five characters of the dataset name
    panel.set_xlim(1950, 2040)
    # panel.set_ylim(-2.8,1.3)
    panel.set_xlabel("Year")
    panel.set_ylabel("Radiative flux [W/m$^2$]")
    panel.set_title(EXP_NAME_dat[5:])
    panel.legend()


In [None]:
def write_line(f, x, y):
    """Append one "x y" record to f.

    Both values are converted with str(), separated by a single space and
    terminated by a newline.

    Parameters
    ----------
    f : object with a write(str) method (e.g. an open text file)
    x, y : values to write
    """
    record = " ".join((str(x), str(y)))
    f.write(record + "\n")

# Dump every stored series to whitespace-separated "year value" text files.
# 2-D series produce one file per row ("_m<index>") plus one file holding the
# mean over rows ("_mean"); 1-D series produce a single file without suffix.
# NaN values are skipped.
if SAVE_FILES:
    # Create the output directory if needed so the first open() cannot fail
    # on a fresh checkout.
    os.makedirs("./figures", exist_ok=True)

    for j_net, EXP_NAME_net in enumerate(NETW_NAMES):
        for j_dat, EXP_NAME_dat in enumerate(DATA_NAMES):
            for val in ['R_truth', 'R_pred', 'F_truth', 'N_truth', 'F_pred']:
                years = sims_dict[EXP_NAME_net][EXP_NAME_dat]['years']
                x = sims_dict[EXP_NAME_net][EXP_NAME_dat][val]

                # Bind the suffix for every series. It used to be bound only
                # by the 2-D branch or by a reset after the previous series,
                # so a 1-D series in first position raised NameError.
                sfx = ""

                if x.ndim>1:
                    nn, _ = x.shape

                    # One file per row of the 2-D series
                    for ii in range(nn):
                        with open("./figures/"+val+'_'+EXP_NAME_net+'_'+EXP_NAME_dat\
                                    +"_m"+str(ii)+".dat","w") as f:
                            for jj in range(years.size):
                                if np.isnan(x[ii,jj]):
                                    continue
                                write_line(f, years[jj], x[ii,jj])

                    # Then fall through to write the mean over rows
                    x = x.mean(axis=0)
                    sfx="_mean"

                with open("./figures/"+val+'_'+EXP_NAME_net+'_'+EXP_NAME_dat+sfx+".dat","w") as f:
                    for jj in range(years.size):
                        if np.isnan(x[jj]):
                            continue
                        write_line(f, years[jj], x[jj])