In [1]:
import shap
import keras
import numpy as np
from keras.models import Sequential, Model
from keras.layers import Dense, Reshape, Flatten, Input, Dropout
from keras.optimizers import Adam, SGD, RMSprop
from sklearn.preprocessing import normalize
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

In [2]:
# Load the model inputs and the matching targets for the nonstationary-forcing run.
input_data, output_data = (
    np.load(npy_path)
    for npy_path in ('data/input_sal_temp_ns.npy', 'data/output_sal_temp_ns.npy')
)

In [3]:
# Contiguous, unshuffled split by row index:
#   rows [0, 30000)       -> training
#   rows [30000, 50000)   -> validation
#   rows [50000, 100000)  -> test
X_train, y_train = input_data[:30000], output_data[:30000]
X_val, y_val = input_data[30000:50000], output_data[30000:50000]
X_test, y_test = input_data[50000:100000], output_data[50000:100000]

In [4]:
# Fit the standardisation on the training rows only; every other split
# (and the full array) is transformed with those same statistics.
scaler = StandardScaler().fit(X_train)
# NOTE: this bare expression is evaluated but, under IPython's default
# settings, not displayed because it is not the last statement of the cell.
scaler.mean_, scaler.scale_

X_train_scaled, X_val_scaled, X_test_scaled, input_scaled = (
    scaler.transform(unscaled) for unscaled in (X_train, X_val, X_test, input_data)
)

# One-hot encode the labels of each split.
Y_train, Y_val, Y_test = (
    keras.utils.to_categorical(labels) for labels in (y_train, y_val, y_test)
)

In [5]:
# Accumulates one SHAP result per loaded model.
shap_values_list = []

# 100-row subsamples of the scaled splits: X_train100 is the KernelExplainer
# background set, X_val100 holds the points that get explained.
X_train100, X_val100 = (
    shap.sample(scaled_split, 100) for scaled_split in (X_train_scaled, X_val_scaled)
)

In [None]:
# Start from an empty list so that re-running this cell rebuilds the results
# instead of appending a second set of entries to a list left over from an
# earlier execution.
shap_values_list = []

for i in range(10):

    print(i)

    # Recreate the network architecture; the trained parameters are read from
    # the i-th saved weight file below.
    model = Sequential([
        Dense(12, input_dim=5, activation='relu'),
        Dense(6, activation='relu'),
        Dense(3, activation='relu'),
        Dense(2, activation='softmax'),
    ])

    model.load_weights('models_nonrandom/stommel_model_ts_ns_'+str(i)+'.h5')
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

    # Explain the model's predictions on the 100 validation points, using the
    # 100 training points as the background distribution.
    explainer = shap.KernelExplainer(
        model.predict,
        X_train100,
        feature_names=['time', 'salinity', 'temperature', 'salinity forcing', 'temperature forcing'],
    )
    shap_values = explainer.shap_values(X_val100)
    shap_values_list.append(shap_values)

0


  0%|          | 0/100 [00:00<?, ?it/s]



In [None]:
# Stack the per-model SHAP results into one array; axis 0 indexes the models.
shap_values_list_np = np.asarray(shap_values_list)

In [None]:
# Display the shape of the stacked SHAP array (axis 0 is the model index).
# NOTE(review): the layout of the remaining axes depends on what
# explainer.shap_values returns in the installed SHAP version — confirm.
shap_values_list_np.shape

In [None]:
# Average the SHAP values over axis 0, i.e. across the loaded models.
shap_values_mean = shap_values_list_np.mean(axis=0)

In [None]:
# Summary plot of the model-averaged SHAP values at index 1 of the first axis
# (presumably the second softmax output — TODO confirm).
shap.summary_plot(
    shap_values_mean[1],
    X_val100,
    feature_names=['time', 'salinity', 'temperature', 'salinity forcing', 'temp forcing'],
)

In [None]:
# Summary plot of the model-averaged SHAP values at index 0 of the first axis
# (presumably the first softmax output — TODO confirm).
shap.summary_plot(
    shap_values_mean[0],
    X_val100,
    feature_names=['time', 'salinity', 'temperature', 'salinity forcing', 'temp forcing'],
)

In [None]:
# Dependence plot for feature index 1 ('salinity'), using the index-0 slice
# of the model-averaged SHAP values.
shap.dependence_plot(
    1,
    shap_values_mean[0],
    X_val100,
    feature_names=['time', 'salinity', 'temperature', 'salinity forcing', 'temp forcing'],
)

In [None]:
# Dependence plot for feature index 2 ('temperature'), using the index-1 slice
# of the model-averaged SHAP values.
shap.dependence_plot(
    2,
    shap_values_mean[1],
    X_val100,
    feature_names=['time', 'salinity', 'temperature', 'salinity forcing', 'temp forcing'],
)

In [None]:
# Dependence plot for feature index 0 ('time'), using the index-1 slice
# of the model-averaged SHAP values.
shap.dependence_plot(
    0,
    shap_values_mean[1],
    X_val100,
    feature_names=['time', 'salinity', 'temperature', 'salinity forcing', 'temp forcing'],
)

In [None]:
# Dependence plot for feature index 4 ('temp forcing'), using the index-1 slice
# of the model-averaged SHAP values.
shap.dependence_plot(
    4,
    shap_values_mean[1],
    X_val100,
    feature_names=['time', 'salinity', 'temperature', 'salinity forcing', 'temp forcing'],
)

In [None]:
# Undo the standardisation so the explained validation points are back in
# their original units.
transformed_X = scaler.inverse_transform(X_val100)

# NOTE(review): k is 10e9 (= 1e10) here, while the following cell sets
# k = 1e9 — confirm which value is intended.
alpha, beta, k = 0.2, 0.8, 10e9

# q = -k * (-alpha * DeltaT + beta * DeltaS) for every row, where column 2
# is read as DeltaT and column 1 as DeltaS.
qvals = []
for x in transformed_X:
    DeltaS, DeltaT = x[1], x[2]
    qvals.append(-k*(-alpha*DeltaT + beta*DeltaS))

In [None]:
# Coefficients of q = -k * (-alpha * DeltaT + beta * DeltaS) used by the cells
# below. This replaces the k = 10e9 assigned in the preceding cell.
k, alpha, beta = 1e9, 0.2, 0.8

In [None]:
def get_q_color(DeltaT, DeltaS):
    """Return a colour code for the sign of q.

    q is computed as ``-k * (-alpha * DeltaT + beta * DeltaS)`` using the
    notebook-level ``k``, ``alpha`` and ``beta``.

    Returns:
        'r' when q is strictly negative, otherwise 'b' (including q == 0).
    """
    q = -k*(-alpha*DeltaT + beta*DeltaS)
    return 'r' if q < 0 else 'b'

In [None]:
# Display the shape of the index-1 slice of the averaged SHAP values.
# The trailing `[:, ]` indexes with a one-element tuple holding a full slice,
# so it selects everything along the first axis and does not change the shape.
shap_values_mean[1][:, ].shape

In [None]:
# Display the shape of the model-averaged SHAP array. The plotting cell below
# indexes it as [0, i, feature], so it must have at least three axes.
shap_values_mean.shape

In [None]:
# Time and circulation strength of each explained validation point, in
# original (unscaled) units. With k = 1e9 the division by 1e9 cancels k.
time_pts = transformed_X[:, 0]
q_pts = -k*(-alpha*transformed_X[:, 2] + beta*transformed_X[:, 1])/1e9

# temp > salinity => green, otherwise blue.
# NOTE(review): this compares signed SHAP values (index-0 slice), not their
# magnitudes — confirm that abs() was not intended.
temp_dominates = shap_values_mean[0, :, 2] > shap_values_mean[0, :, 1]
plt.scatter(time_pts[temp_dominates], q_pts[temp_dominates], marker='o', s=18, color='#2ca02c')
plt.scatter(time_pts[~temp_dominates], q_pts[~temp_dominates], marker='o', s=18, color='#1f77b4')

plt.xlabel('Time (kyr)')
plt.ylabel('AMOC (Sv)')
plt.title('Dependency of Model Prediction on Temperature and Salinity')

q = np.load('data/q_sal_temp_ns.npy')
# Use only the time column (column 0) as the x-axis. Passing the whole 2-D
# feature matrix makes matplotlib draw one curve per feature column, plotting
# q against salinity, temperature and the forcings as well as against time.
time_val = input_data[30000:50000, 0]
plt.plot(time_val, q[30000:50000]/1e10, linestyle=':', color='grey')
# Zero reference line spanning the validation time range.
plt.plot(time_val, np.zeros_like(time_val), 'k--')

# Define custom legend labels and colors (0 corresponds to one color, 1 to another)
legend_labels = ['Temperature', 'Salinity', 'Circulation volume transport']  # Custom labels
colors = ['#2ca02c', '#1f77b4']  # Corresponding colors for the categories

# Create custom legend handles
legend_handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10) for color in colors]
grey_line_handle = plt.Line2D([0], [0], linestyle=':', color='grey')
legend_handles.append(grey_line_handle)

# Add the custom legend to the plot
plt.legend(legend_handles, legend_labels, loc='lower left');

#plt.savefig('dependency_plot_ns', dpi=300)