In [None]:
# This script trains a convolutional neural network to predict y given X, where X and y form a synthetic benchmark dataset, 
# as described in Mamalakis et al. 2022. We also apply DeepSHAP to explain the predictions of the network. 

 
# citation: 
# Mamalakis, A., E.A. Barnes, I. Ebert-Uphoff (2022) “Investigating the fidelity of explainable 
# artificial intelligence methods for application of convolutional neural networks in geoscience,” 
# arXiv preprint https://arxiv.org/abs/2202.03407. 

# Editor: Dr Antonios Mamalakis (amamalak@colostate.edu)

In [None]:
#.............................................
# IMPORT STATEMENTS
#.............................................

# local env is AIgeo_new

# General Python math functions
import math
# Reading HDF5 files (used below to open the .mat file that holds the synthetic benchmark)
import h5py
# Handling data (numpy arrays) and writing netCDF output files
import numpy as np
import netCDF4 as nc
# Plotting figures
import matplotlib.pyplot as plt #Main plotting package

# Machine learning package
import tensorflow as tf
# Switch TensorFlow to its v1-style behavior before any model is built.
# NOTE(review): presumably required for shap.DeepExplainer compatibility -- confirm.
tf.compat.v1.disable_v2_behavior() 
print(tf.__version__)  # record the TensorFlow version in the cell output


# Interpreting neural networks 
import  shap


In [None]:
#.............................................
# LOAD DATA
#.............................................

# Load the matlab data file that holds the synthetic benchmark.
# This data was generated using the matlab script Gen_Synth_SHAPES

filepath = 'synth_data_shapes.mat'
DATA = {}
# Open the file read-only and inside a context manager so the handle is always
# released, even if reading one of the datasets fails. np.array(v) copies each
# dataset into memory, so DATA remains usable after the file has been closed.
with h5py.File(filepath, 'r') as f:
    for k, v in f.items():
        DATA[k] = np.array(v)

# Pull out the arrays used by the rest of the notebook
InputX = np.array(DATA['X'])      # network inputs
lats = np.array(DATA['lat'])      # latitude coordinates
lons= np.array(DATA['lon'])       # longitude coordinates
y_synth = np.array(DATA['y'])     # synthetic target series
Cnt_tr = np.array(DATA['Cnt'])    # stored under 'Cnt'; plotted in the sanity check below
print('data is loaded') # print message 'data is loaded'

In [None]:
#.............................................
# DATA MANIPULATION AND SANITY PLOT
#.............................................

# Exchange axis 1 with the last axis of both gridded arrays
InputX = InputX.swapaxes(1, -1)
Cnt_tr = Cnt_tr.swapaxes(1, -1)

# Collapse the y time series and the coordinate arrays to 1-D
y_synth = y_synth.flatten()
lons = lons.flatten()
lats = lats.flatten()

# Sanity plot: filled contours of one field (index 9), just to check that
# the data have been read correctly
X, Y = np.meshgrid(lons, lats)
cs = plt.contourf(X, Y, Cnt_tr[9], cmap="jet")
cbar = plt.colorbar(cs)
plt.title('matplotlib.pyplot.contourf() Example')
plt.show()

In [None]:
#.............................................
# PREPARE THE DATA FOR TRAINING
#.............................................

# Follow the machine learning convention: X holds the inputs and Y the labels.
# A trailing singleton axis is appended to the inputs so that each sample has
# three axes, which is the per-sample shape handed to the first Conv2D layer.
X_all = np.copy(InputX[...,np.newaxis])
Y_all = np.copy(y_synth)

# Change the Y (label) array values to 1 if the sample is above 0 and 0 if the sample is below
Y_all[Y_all > 0] = 1 # square frames cover more area 
Y_all[Y_all <= 0] = 0 # circular frames cover more area

# Convert the Y array into a categorical (one-hot) array.
# num_classes is pinned to 2 so the label width always matches the 2-unit softmax
# output of the network, even if only one class happens to be present in the data.
Y_all = tf.keras.utils.to_categorical(Y_all, num_classes=2)

# Set the fraction of samples that will be used for validation
frac_validate = 0.1

# Separate the X and Y matrices into training and validation sub-sets.
# The last frac_validate fraction of samples is the validation dataset.
# The split position is computed as an explicit non-negative index: slicing with
# int(-frac_validate*len(X_all)) silently yields an EMPTY training set (and puts
# every sample in validation) whenever that expression truncates to 0.
n_validate = int(frac_validate*len(X_all))
split_index = len(X_all) - n_validate

X_train = X_all[:split_index]
Y_train = Y_all[:split_index]

X_validation = X_all[split_index:]
Y_validation = Y_all[split_index:]

#Create class weights for training the model. If the dataset is unbalanced, this helps ensure the model
# does not simply start guessing the class that has more samples.
#class_weight = class_weight_creator(Y_train)

# Shape of a single input sample (this will be helpful later on): the last three
# axes of X_all, i.e. the two grid axes plus the singleton axis added above.
number_inputs = X_all.shape[-3:]

In [None]:
#.............................................
# BUILD THE CONVOLUTIONAL NEURAL NETWORK
#.............................................

# Architecture: three Conv2D + MaxPooling2D stages, then two dense layers and
# a 2-unit softmax output (one unit per class).
model = tf.keras.models.Sequential()

# number_inputs is the per-sample input shape computed when the data were prepared
model.add(tf.keras.layers.Conv2D(32,(5,5),strides=(2,2),activation='relu',padding='same',input_shape=number_inputs))
#model.add(tf.keras.layers.Conv2D(32,(5,5),strides=(1,1),activation='relu',padding='same'))
model.add(tf.keras.layers.MaxPooling2D(2))
model.add(tf.keras.layers.Conv2D(32,(5,5),strides=(1,1),activation='relu',padding='same'))
#model.add(tf.keras.layers.Conv2D(32,(5,5),strides=(1,1),activation='relu',padding='same'))
model.add(tf.keras.layers.MaxPooling2D(2))
model.add(tf.keras.layers.Conv2D(64,(3,3),strides=(1,1),activation='relu',padding='same'))
model.add(tf.keras.layers.MaxPooling2D(2))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(units=128,activation='relu'))
model.add(tf.keras.layers.Dense(units=64,activation='relu'))
model.add(tf.keras.layers.Dense(units=2,activation='softmax'))
#model.add(tf.keras.layers.Dense(1,activation='linear',use_bias=False))

#Define the learning rate of the neural network
learning_rate = 0.01

# We will use the stochastic gradient descent (SGD) optimizer, because we have control over
# the learning rate and it is effective for our problem.
# The optimizer argument is spelled `learning_rate`: `lr` is a deprecated alias
# that newer Keras releases no longer accept.

model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate),
              loss = 'categorical_crossentropy', 
              metrics=['accuracy'] )

model.summary()    

In [None]:
#.............................................
# UNCOMMENT TO TRAIN THE NEURAL NETWORK
#.............................................

#batch_size = 128 #The number of samples the network sees before it backpropagates (batch size)
#epochs =  10 #The number of times the network will loop through the entire dataset (epochs)
#shuffle = True #Set whether to shuffle the training data so the model doesn't see it sequentially 
#verbose = 2 #Set whether the model will output information when trained (0 = no output; 2 = output accuracy every epoch)

###Train the neural network!
#model.fit(X_train, Y_train, validation_data=(X_validation, Y_validation), 
#          batch_size=batch_size, epochs=epochs, shuffle=shuffle, verbose=verbose) #, class_weight=class_weight)

In [None]:
#.............................................
# LOAD ALREADY TRAINED MODEL
#.............................................

# Restore the saved model, including its weights and the optimizer
model = tf.keras.models.load_model('my_model_shapes.h5')
model.summary()  # show the architecture of the restored model

# Score the restored model on the validation sub-set and report both numbers
loss, acc = model.evaluate(x=X_validation, y=Y_validation, verbose=2)
for label, score in (('Restored model, categorical crossentropy: ', loss),
                     ('Restored model, categorical accuracy: ', acc)):
    print(label, score)


In [None]:
#.............................................
# GET EXPLANATIONS FROM SHAP
#.............................................

# Re-import is a no-op when the import cell has already run; kept so this cell
# can also be executed on its own.
import shap

# Select a set of background examples to take an expectation over.
# - The requested count is capped at the number of training samples, because
#   choice(..., replace=False) raises ValueError when asked for more items than exist.
# - A dedicated seeded generator makes the background set (and therefore the
#   resulting explanations) reproducible across kernel restarts.
n_background = min(5000, X_train.shape[0])
background_rng = np.random.RandomState(0)
background = X_train[background_rng.choice(X_train.shape[0], n_background, replace=False)]
#background=np.zeros((1,X_train.shape[1]*X_train.shape[2],1))

# Build the DeepSHAP explainer for the model from the background examples
e = shap.DeepExplainer(model, background)
# ...or pass tensors directly
# e = shap.DeepExplainer((model.layers[0].input, model.layers[-1].output), background)

In [None]:
# Validation samples to explain, selected by row index
sample_idx = [344, 3566]
samples = X_validation[sample_idx, :, :]

# DeepSHAP attributions for the selected samples
shap_values = e.shap_values(samples)

# Plot the feature attributions; the sign-flipped inputs are passed as the pixel values
shap.image_plot(shap_values, -samples)

In [None]:
# Stack the SHAP output into one ndarray, keep index 0 of the fifth axis
# (which drops that axis), and store the result as an independent copy.
shap_values1 = np.copy(np.array(shap_values)[:, :, :, :, 0])

In [None]:
#.............................................
# SAVE SHAP RESULTS
#.............................................

In [None]:
fn = 'SHAP.nc'

# NOTE(review): shap_values1 is assumed to be laid out as (classes, samples, lat, lon),
# i.e. stacked from a per-class list of SHAP arrays -- confirm against the installed
# shap version. The 'SHAP' variable below is declared as (time, classes, lat, lon), so
# the first two axes are swapped before writing; without the swap, sample and class are
# silently exchanged on disk whenever the two axis lengths happen to be equal.
shap_to_write = np.swapaxes(shap_values1, 0, 1)
n_samples, n_classes = shap_to_write.shape[:2]

# The context manager closes the file even if one of the writes below raises
with nc.Dataset(fn, 'w', format='NETCDF4') as ds:

    # Dimension sizes are taken from the data instead of being hard-coded,
    # so the cell keeps working when the number of explained samples or the grid changes
    time = ds.createDimension('time', n_samples) # this is essentially number of samples 
    lat = ds.createDimension('lat', lats.size)
    lon = ds.createDimension('lon', lons.size)
    classes = ds.createDimension('classes', n_classes)

    times = ds.createVariable('time', 'f4', ('time',))
    latss = ds.createVariable('lat', 'f4', ('lat',))
    lonss = ds.createVariable('lon', 'f4', ('lon',))
    value = ds.createVariable('SHAP', 'f4', ('time', 'classes','lat', 'lon'))
    value.units = 'unitless'

    # The 'time' coordinate was created but never filled; store the running sample index
    times[:] = np.arange(n_samples)
    latss[:] = np.copy(lats)
    lonss[:] = np.copy(lons)
    value[:] = np.copy(shap_to_write)

    print('var size ', value.shape)