In [None]:
import sys, os
import importlib as imp
import gc

import pandas as pd
import xarray as xr
import pickle
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import random

from scipy.stats import linregress, t, norm

import experiment_settings
import file_methods, plots, data_processing
from DIRECTORIES import MODEL_DIRECTORY, DATA_DIRECTORY

# Presumably the resolution (dots per inch) for saved figures; not referenced
# in the cells visible here -- TODO confirm where it is used.
savefig_dpi = 300

In [None]:
# When True, the export cell at the end of the notebook writes the predicted
# R/F series, N and the Forster forcing to "./figures/*.dat" text files.
SAVE_FILES = False

In [None]:
# Experiment names of the observational datasets the networks are tested on
# (comment an entry out to skip that dataset).
DATA_NAMES = (
    "obs_ERA5_deepC+CERES",
    # "obs_ERA5_SST_deepC+CERES",
    "obs_COBE2_deepC+CERES",
    "obs_PCMDI_deepC+CERES",
    "obs_HadISST_deepC+CERES",
)

# Trained networks to evaluate; the trailing "_s<seed>" encodes the RNG seed.
NETW_NAMES = ("MultiModel_hpt0_s3",)

# First and last year (inclusive) of the window used for linear trend fits.
TREND_START_YEAR, TREND_END_YEAR = 2001, 2023

# Number of networks and datasets
NUM_NETW, NUM_DATA = len(NETW_NAMES), len(DATA_NAMES)

# Re-import the project helpers so that edits made on disk are picked up
# without restarting the kernel.
for _module in (data_processing, experiment_settings, file_methods):
    imp.reload(_module)

# Initialize prediction dictionary: sims_dict[network] holds the loaded model
# and its settings; per-dataset predictions are added under the dataset name
# by the dataset loop further down.
sims_dict = {}

# Load network models
for j_net, EXP_NAME_net in enumerate(NETW_NAMES):
    print("*********  "+EXP_NAME_net+"  *********")

    # Network names look like "<experiment>_s<seed>". Split on the LAST "_s" so
    # that multi-digit seeds (e.g. "_s12") are parsed correctly; for any other
    # name shape fall back to the historical fixed-width slicing.
    exp_base_name, seed_sep, seed_str = EXP_NAME_net.rpartition("_s")
    if not (seed_sep and seed_str.isdigit()):
        exp_base_name, seed_str = EXP_NAME_net[:-3], EXP_NAME_net[-1]

    # Get settings of network
    settings_net = experiment_settings.get_settings(exp_base_name)
    settings_net["rng_seed"] = int(seed_str)

    # Set seeds (TensorFlow, Python's random and NumPy) to the network's seed
    tf.random.set_seed(settings_net["rng_seed"])
    random.seed(settings_net["rng_seed"])
    np.random.seed(settings_net["rng_seed"])

    # Load network; fail early with a clear message if it was never saved
    model_name = file_methods.get_model_name(settings_net)
    if not os.path.exists(MODEL_DIRECTORY + model_name + "_model"):
        raise RuntimeError("No such model experiment: " + model_name)
    model = file_methods.load_tf_model(model_name, MODEL_DIRECTORY)

    # Save network
    sims_dict[EXP_NAME_net] = {
        "model": model,
        "settings": settings_net,
    }

# Loop over datasets: for every (dataset, network) pair, load the data, run the
# network and store truth/prediction arrays in sims_dict[network][dataset].
for j_dat, EXP_NAME_dat in enumerate(DATA_NAMES):
    print("*********  "+EXP_NAME_dat+"  *********")

    # Get data settings & name
    settings_dat = experiment_settings.get_settings(EXP_NAME_dat)

    # Loop over networks
    for j_net, EXP_NAME_net in enumerate(NETW_NAMES):
        settings_net = sims_dict[EXP_NAME_net]['settings']
        model_name_net = file_methods.get_model_name(settings_net)
        settings_dat["rng_seed"] = settings_net["rng_seed"]

        # Use the network's input region for the data when one is set
        add_mask = settings_net['input_region']
        if add_mask:
            settings_dat["input_region"] = add_mask

        # Offset that is added back to the raw network output. The pickle is
        # only read when the network settings actually enable "subtract_val";
        # previously a present-but-falsy flag still tried to open whatever
        # path happened to be in settings_dat["subtract_val"].
        if "subtract_val" in settings_net and settings_net["subtract_val"]:
            settings_dat["subtract_val"] = model_name_net+".pickle"

            # NOTE(review): pickle.load can execute arbitrary code; only load
            # files from a trusted MODEL_DIRECTORY.
            with open(MODEL_DIRECTORY+settings_dat["subtract_val"], 'rb') as f:
                _ = pickle.load(f)  # first pickled object is not needed here
                y_subtract = pickle.load(f)
        else:
            settings_dat["subtract_val"] = False
            y_subtract = np.array([0.])

        if EXP_NAME_dat[:3] == 'obs':
            # GET THE DATA for OOS testing
            N_truth_da, tas, years = data_processing.get_obs_data(
                DATA_DIRECTORY,
                settings_dat,
            )
            test_shape = 1
            time_shape = years.size

            # Remove 2023 if no tas data
            if years[-1]==2022 and EXP_NAME_dat[-5:]=='CERES':
                N_truth_da = N_truth_da[:-1]

            # Get N on same timescale: NaN wherever `years` has no N value
            years_N = N_truth_da['year'].to_numpy()
            N_truth = np.full((test_shape,time_shape),np.nan)
            iyears = np.where((years_N[0] <= years) & (years <= years_N[-1]))[0]
            N_truth[0,iyears] = N_truth_da.to_numpy()

            # R and F are not known for observations: NaN placeholders with
            # shapes (time, 1) and (time,) respectively
            R_truth = np.full(N_truth.shape,np.nan).T
            F_truth = np.full(time_shape,np.nan)
            # Append a trailing axis of length one to the network input
            tas = tas[:,:,:,np.newaxis]
        elif EXP_NAME_dat[:4] == 'DATA':

            # GET THE DATA
            (
                _,
                _,
                tas,
                _,
                _,
                R_truth,
                lat,
                lon,
                map_shape,
                member_shape,
                time_shape,
                _,
                member_enrollment,
                _,
                F_truth,
            ) = data_processing.get_cmip_data(
                DATA_DIRECTORY,
                settings_dat,
                n_train_val_test=settings_dat["n_train_val_test"],
                get_forcing = True
            )
            test_shape = settings_dat["n_train_val_test"][2]
            N_truth = F_truth + R_truth.reshape(test_shape, time_shape)

            years = np.arange(settings_dat["yr_bounds"][0], settings_dat["yr_bounds"][1] + 1)

        else:
            raise RuntimeError("Dataformat not recognized.")

        # Make predictions, adding the subtracted offset back
        pred = sims_dict[EXP_NAME_net]['model'].predict(tas) + y_subtract[np.newaxis,:]

        # Reshape R to (test_shape, time_shape)
        R_pred = pred.reshape(test_shape, time_shape)

        # Store truth and predictions; the forcing prediction is the residual
        # F_pred = N_truth - R_pred
        sims_dict[EXP_NAME_net][EXP_NAME_dat] = {
            "labels": R_truth,
            "pred": pred,
            "R_truth": R_truth.reshape(test_shape, time_shape),
            "R_pred": R_pred,
            "F_truth": F_truth,
            "N_truth": N_truth,
            "years": years,
            "F_pred": N_truth - R_pred,
        }

# Reference forcing series read from obs/ERF_best_1750-2023.csv: the total is
# the row-wise sum of every column except 'year', expressed as an anomaly
# relative to its 2001-2020 mean.
Forster_df = pd.read_csv(DATA_DIRECTORY + "obs/ERF_best_1750-2023.csv", sep=",")
Forster_y = Forster_df['year'].to_numpy()
Forster_F = Forster_df.drop(columns='year').sum(axis=1).to_numpy()

# Indices of the baseline years 2001-2020 (inclusive)
Forster_iy = np.flatnonzero((2001 <= Forster_y) & (Forster_y <= 2020))
Forster_F = Forster_F - Forster_F[Forster_iy].mean()

In [None]:
def linear_regression_with_confidence_bounds(x, y, confidence=0.1):
    """
    Compute an ordinary least-squares fit of y on x with two-sided confidence
    bounds on the slope and the intercept.

    Parameters:
        x (array-like): Independent variable data.
        y (array-like): Dependent variable data (same length as x).
        confidence (float, optional): Two-sided confidence level in (0, 1),
            e.g. 0.95 for a 95 % interval. Default is 0.1.
            NOTE(review): the default is kept unchanged for signature
            compatibility, but it yields a 10 % interval; every call visible
            in this notebook passes 0.95 explicitly.

    Returns:
        slope (float): Slope of the regression line.
        intercept (float): Intercept of the regression line.
        slope_bounds (tuple): (lower, upper) confidence bounds of the slope.
        intercept_bounds (tuple): (lower, upper) confidence bounds of the
            intercept.
        slope_stderr (float): Standard error of the slope.
        intercept_stderr (float): Standard error of the intercept.
    """
    # Compute linear regression
    res = linregress(x, y)
    slope = res.slope
    intercept = res.intercept

    # Degrees of freedom: two parameters (slope, intercept) are estimated
    df = len(x) - 2

    # Two-sided critical value of Student's t for the requested confidence
    # level, i.e. the (1 + confidence)/2 quantile. This is positive, so the
    # lower bound lies below the estimate. (The previous confidence/2 quantile
    # was negative and close to zero for confidence = 0.95, which gave
    # inverted, far too narrow bounds.)
    t_value = t.ppf(0.5 + confidence / 2, df)

    # Standard error of the slope and intercept
    slope_stderr = res.stderr
    intercept_stderr = res.intercept_stderr

    # Confidence bounds: estimate -/+ t * standard error
    slope_bounds = (slope - t_value * slope_stderr, slope + t_value * slope_stderr)
    intercept_bounds = (
        intercept - t_value * intercept_stderr,
        intercept + t_value * intercept_stderr,
    )

    return slope, intercept, slope_bounds, intercept_bounds, slope_stderr, intercept_stderr

# Years (inclusive) over which linear trends are fitted
years_trend = np.arange(TREND_START_YEAR,TREND_END_YEAR+1)

# Get N trend standard deviation
# bound_95 is treated as the half-width of a two-sided 95 % normal interval and
# converted to a standard deviation. NOTE(review): the origin and units of the
# 0.02 value are not visible here -- TODO confirm.
bound_95 = 0.02
# Student-t quantiles for a fit over the full trend window (n - 2 degrees of
# freedom); despite the "z" names these are t, not normal, quantiles.
z_25 = t.ppf(0.975,len(years_trend)-2)
z_5 = t.ppf(0.95,len(years_trend)-2)
std_N = bound_95/norm.ppf(0.975)

# Forster trend
# Regress on the years actually selected (and drop NaNs) instead of assuming
# the series covers every year of years_trend; a shorter series would
# otherwise make linregress fail with a length mismatch.
tmp_yrs = Forster_y
iy_ = np.where((TREND_START_YEAR <= tmp_yrs) & (tmp_yrs <= TREND_END_YEAR))[0]
x_ = tmp_yrs[iy_]
y_ = Forster_F[iy_].copy()
finite_ = np.logical_not(np.isnan(y_))
slope, intercept, _, _, slope_stderr, intercept_stderr = \
    linear_regression_with_confidence_bounds(x_[finite_],y_[finite_],0.95)
# Printed per decade (x10); the "+-" term is slope_stderr times the 0.95
# t-quantile.
print("Forster F = "+plots.num_lab(slope*10,4).ljust(7,"0")
    +" +- "+plots.num_lab(slope_stderr*z_5*10,4).ljust(7,"0")
)

# N trend (only evaluated for trend windows starting in 1985 or later),
# taken from the first network / first dataset
if TREND_START_YEAR>=1985:
    tmp_yrs = sims_dict[NETW_NAMES[0]][DATA_NAMES[0]]["years"].copy()
    iy_ = np.where((TREND_START_YEAR <= tmp_yrs) & (tmp_yrs <= TREND_END_YEAR))[0]
    x_ = tmp_yrs[iy_]
    y_ = sims_dict[NETW_NAMES[0]][DATA_NAMES[0]]["N_truth"][0,iy_].copy()
    finite_ = np.logical_not(np.isnan(y_))
    slope, intercept, _, _, slope_stderr, intercept_stderr = \
        linear_regression_with_confidence_bounds(x_[finite_],y_[finite_],0.95)
    print("CERES N = "+plots.num_lab(slope*10,4).ljust(7,"0")
        +" +- "+plots.num_lab(slope_stderr*z_5*10,4).ljust(7,"0")
    )

In [None]:
def _align_to_trend_years(series_years, series_values, trend_years):
    """Return series_values placed on the trend_years axis.

    Years of trend_years that are absent from series_years are filled with
    NaN, so the result always has the same length as trend_years regardless of
    where the series starts or ends.
    """
    aligned = np.full(trend_years.shape, np.nan)
    _, i_trend, i_series = np.intersect1d(
        trend_years, series_years, return_indices=True
    )
    aligned[i_trend] = series_values[i_series]
    return aligned


def _fit_dataset_trend(var, data_name, y_, trend_years, store, axis, color):
    """Fit, print, store and plot the linear trend of one dataset series.

    var is the key prefix ("F" or "R"); results go into store under var,
    var+"_slope", var+"_slope_std", var+"_intercept" and var+"_intercept_std".
    NaN entries of y_ are excluded from the fit.
    """
    finite = np.logical_not(np.isnan(y_))
    slope, intercept, _, _, slope_stderr, intercept_stderr = \
        linear_regression_with_confidence_bounds(trend_years[finite],y_[finite],0.95)

    # Slope and its standard error are reported per decade (x10); the intercept
    # and its standard error are reported unscaled.
    print(data_name+", "+var+" Slope = "+plots.num_lab(slope*10,4).ljust(7,"0")
        +" +- "+plots.num_lab(slope_stderr*10,4).ljust(7,"0")
        +", Intersect = "+plots.num_lab(intercept,4).ljust(5,"0")
        +" +- "+plots.num_lab(intercept_stderr,4).ljust(7,"0")
    )

    store[var] = y_
    store[var+'_slope'] = slope
    store[var+'_slope_std'] = slope_stderr
    store[var+'_intercept'] = intercept
    store[var+'_intercept_std'] = intercept_stderr

    # Mean-removed series (solid) and mean-removed fitted line (dashed)
    axis.plot(trend_years,y_-np.nanmean(y_),color=color)
    line = slope * trend_years
    axis.plot(trend_years, line-np.mean(line),'--',color=color)


def _fit_pooled_trend(var, stacked, trend_years, store):
    """Fit one regression through all datasets' series at once.

    stacked has one row per dataset on the trend_years axis. Prints the fit,
    stores it in store under the var-prefixed keys and returns the tuple
    produced by linear_regression_with_confidence_bounds.
    """
    x_all = np.broadcast_to(trend_years,stacked.shape).flatten()
    y_all = stacked.flatten()
    finite = np.logical_not(np.isnan(y_all))

    fit = linear_regression_with_confidence_bounds(x_all[finite],y_all[finite],0.95)
    slope, intercept, slope_bounds, intercept_bounds, slope_stderr, intercept_stderr = fit

    print("ALL "+var+" Slope = "+plots.num_lab(slope*10,4).ljust(7,"0")
            +" +- "+plots.num_lab(slope_stderr*10,4).ljust(7,"0")
            +" ("+plots.num_lab(slope_bounds[0]*10,4).ljust(7,"0")
            +","+plots.num_lab(slope_bounds[1]*10,4).ljust(7,"0")+")"
            +", Intersect = "+plots.num_lab(intercept,4).ljust(5,"0")
            +" ("+plots.num_lab(intercept_bounds[0],4).ljust(5,"0")
            +","+plots.num_lab(intercept_bounds[1],4).ljust(5,"0")+")"
        )

    store[var] = stacked
    store[var+'_slope'] = slope
    store[var+'_slope_std'] = slope_stderr
    store[var+'_intercept'] = intercept
    store[var+'_intercept_std'] = intercept_stderr
    return fit


# trend_dict[network][dataset] holds the per-dataset F and R series and fits;
# trend_dict[network] itself holds the pooled (all datasets) fits.
trend_dict = {}
for j_net, EXP_NAME_net in enumerate(NETW_NAMES):
    fig, ax = plt.subplots(1,2,figsize=(12,4))
    ax[0].set_ylabel("F")
    ax[1].set_ylabel("R")
    ax[0].set_title(EXP_NAME_net)
    F_ALL = []
    R_ALL = []
    trend_dict[EXP_NAME_net] = {}
    for j_dat, EXP_NAME_dat in enumerate(DATA_NAMES):
        trend_dict[EXP_NAME_net][EXP_NAME_dat] = {}
        dat_sims = sims_dict[EXP_NAME_net][EXP_NAME_dat]
        dat_color = plots.npcols[np.mod(j_dat,6)]

        # Same treatment for the forcing (left panel) and response (right panel)
        for axis, var, collected in ((ax[0], "F", F_ALL), (ax[1], "R", R_ALL)):
            y_ = _align_to_trend_years(
                dat_sims["years"], dat_sims[var+"_pred"][0,:], years_trend
            )
            collected.append(y_)
            _fit_dataset_trend(
                var, EXP_NAME_dat, y_, years_trend,
                trend_dict[EXP_NAME_net][EXP_NAME_dat], axis, dat_color,
            )

    F_ALL = np.array(F_ALL)
    R_ALL = np.array(R_ALL)

    (
        F_all_slope, F_all_intercept, F_all_slope_bounds,
        F_all_intercept_bounds, F_all_slope_stderr, F_all_intercept_stderr,
    ) = _fit_pooled_trend("F", F_ALL, years_trend, trend_dict[EXP_NAME_net])

    # Pooled F trend as a thick black line (mean removed)
    line = F_all_slope * years_trend
    ax[0].plot(years_trend, line-np.mean(line),color='k',linewidth=3)

    (
        R_all_slope, R_all_intercept, R_all_slope_bounds,
        R_all_intercept_bounds, R_all_slope_stderr, R_all_intercept_stderr,
    ) = _fit_pooled_trend("R", R_ALL, years_trend, trend_dict[EXP_NAME_net])

    print(trend_dict[EXP_NAME_net]['F_slope'],trend_dict[EXP_NAME_net]['F_intercept'])



In [None]:
# One panel per network; squeeze=False keeps `ax` indexable even for a single
# network, and the single column is then flattened to a 1-D array of axes.
fig, ax = plt.subplots(NUM_NETW, 1, figsize=(10, 5*NUM_NETW), squeeze=False)
ax = ax[:, 0]

for j_net, EXP_NAME_net in enumerate(NETW_NAMES):
    panel = ax[j_net]
    net_trends = trend_dict[EXP_NAME_net]

    # Reference forcing in black
    panel.plot(Forster_y, Forster_F, color='k', linewidth=1.5, label="Forster")

    for j_dat, EXP_NAME_dat in enumerate(DATA_NAMES):
        dat_sims = sims_dict[EXP_NAME_net][EXP_NAME_dat]
        dat_trends = net_trends[EXP_NAME_dat]
        dat_color = plots.npcols[j_dat]

        # Predicted forcing of this dataset (solid) ...
        panel.plot(
            dat_sims["years"],
            dat_sims["F_pred"][0,:],
            linewidth=1.5, label=EXP_NAME_dat[4:-6], color=dat_color,
        )

        # ... and its fitted trend line over the trend window (dashed)
        panel.plot(
            years_trend,
            dat_trends['F_slope']*years_trend + dat_trends['F_intercept'],
            '--', linewidth=1.5, color=dat_color,
        )

    # Variance of the R slope for every dataset
    var_R_all = np.array(
        [net_trends[data_name]['R_slope_std']**2 for data_name in DATA_NAMES]
    )

    # Total trend: pooled fit over all datasets
    pooled_slope = net_trends['F_slope']
    pooled_intercept = net_trends['F_intercept']

    # Slope uncertainty of F: RMS of the per-dataset R slope errors combined
    # in quadrature with the N trend error
    std_R = np.sqrt(np.mean(var_R_all))
    std_F = np.sqrt(std_N**2 + std_R**2)

    panel.plot(
        years_trend, pooled_slope*years_trend + pooled_intercept,
        '--', linewidth=2.5, color='red',
    )
    print(
        EXP_NAME_net + ": dF = "
        + plots.num_lab(pooled_slope*10,4).ljust(6,'0') + " +- "
        + plots.num_lab(std_F*z_5*10,4).ljust(6,"0")
        + " W/m^2/decade"
    )

    panel.set_xlim(1985,2023)
    panel.set_ylim(-2.8,1.3)
    panel.set_xlabel("Year")
    panel.set_ylabel("F [W/m$^2$]")
    panel.set_title(EXP_NAME_net)
    panel.legend()


In [None]:
def write_line(f, x, y):
    """Write ``x`` and ``y`` to the open text file ``f`` as one line.

    The two values are converted with ``str`` and separated by a single
    space; the line is terminated with a newline.
    """
    f.write(" ".join((str(x), str(y))) + "\n")

def _write_series(path, years, values, min_year=None, max_year=None):
    """Write (year, value) pairs to the text file ``path``, one pair per line.

    NaN values are skipped. When ``min_year`` / ``max_year`` are given, only
    years inside that inclusive range are written.
    """
    with open(path,"w") as f:
        for year, value in zip(years, values):
            if np.isnan(value):
                continue
            if min_year is not None and year < min_year:
                continue
            if max_year is not None and year > max_year:
                continue
            write_line(f, year, value)


if SAVE_FILES:
    # Series are split into a "*_deepC.dat" file (years <= 2001) and a
    # "*_CERES.dat" file (years >= 2001); 2001 is written to both.
    split_year = 2001

    # Make sure the output directory exists before writing into it
    os.makedirs("./figures", exist_ok=True)

    # Save observational mean
    for j_net, EXP_NAME_net in enumerate(NETW_NAMES):
        # Common year axis on which the datasets are averaged
        years_all = np.arange(1850,2024)
        R_ALL = np.full((NUM_DATA,years_all.size),np.nan)
        F_ALL = np.full((NUM_DATA,years_all.size),np.nan)

        for j_dat, EXP_NAME_dat in enumerate(DATA_NAMES):
            # Strip the leading "obs_" and the trailing 12 characters
            # ("_deepC+CERES" for the names configured above)
            DAT_NAME = EXP_NAME_dat[4:-12]
            yrs = sims_dict[EXP_NAME_net][EXP_NAME_dat]["years"]
            iyears = np.where((yrs[0] <= years_all) & (years_all <= yrs[-1]))[0]

            # Save R
            val = sims_dict[EXP_NAME_net][EXP_NAME_dat]["R_pred"][0,:]
            _write_series("./figures/R_"+EXP_NAME_net+'_'+DAT_NAME+".dat", yrs, val)
            R_ALL[j_dat,iyears] = val.copy()

            # Save F
            val = sims_dict[EXP_NAME_net][EXP_NAME_dat]["F_pred"][0,:]
            _write_series(
                "./figures/F_"+EXP_NAME_net+'_'+DAT_NAME+"_deepC.dat",
                yrs, val, max_year=split_year,
            )
            _write_series(
                "./figures/F_"+EXP_NAME_net+'_'+DAT_NAME+"_CERES.dat",
                yrs, val, min_year=split_year,
            )
            F_ALL[j_dat,iyears] = val.copy()

        # Mean over datasets, ignoring datasets without a value in a year
        R_OBSMEAN = np.nanmean(R_ALL,axis=0)
        _write_series("./figures/R_"+EXP_NAME_net+"_OBSmean.dat", years_all, R_OBSMEAN)

        F_OBSMEAN = np.nanmean(F_ALL,axis=0)
        _write_series(
            "./figures/F_"+EXP_NAME_net+"_OBSmean_deepC.dat",
            years_all, F_OBSMEAN, max_year=split_year,
        )
        _write_series(
            "./figures/F_"+EXP_NAME_net+"_OBSmean_CERES.dat",
            years_all, F_OBSMEAN, min_year=split_year,
        )

    # N is taken from the ERA5 entry of the last network (the entry the loop
    # variable used to point at after the loop), now referenced explicitly.
    n_source = sims_dict[NETW_NAMES[-1]]["obs_ERA5_deepC+CERES"]
    yrs = n_source["years"]
    val = n_source["N_truth"][0,:]
    _write_series("./figures/N_deepC.dat", yrs, val, max_year=split_year)
    _write_series("./figures/N_CERES.dat", yrs, val, min_year=split_year)

    # Forster_y, Forster_F (written in full, including any NaN values)
    with open("./figures/F_Forster.dat","w") as f:
        for jj in range(Forster_y.size):
            write_line(f, Forster_y[jj], Forster_F[jj])