<a href="https://colab.research.google.com/github/adamggibbs/marine-carbonate-system-ml-prediction/blob/master/NN_Development_Framework.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ML Model Training Pipeline
This notebook trains a neural network to estimate pH from time, location, temperature, pressure, salinity, and oxygen. It automates the entire pipeline: quality control, data cleaning, data preprocessing, model training, and model evaluation.

Before running this notebook:
1. Create a folder in your Google Drive to store all data and models
2. Run the Directory Setup Colab notebook to create the necessary directory structure for this notebook to run.
3. Place all files you want to use for training in the 'data/training/' directory
4. Place all files you want to use for testing in the 'data/testing/' directory
5. Check User Defined variables at the top of this notebook

Once all these steps are complete, open the 'Runtime' tab and select 'Run all' (or press Ctrl+F9) to run the entire notebook. A full run takes approximately **1 hour**; the exact time depends on the number of training and testing files and on the size of the neural network you choose.

When complete, this notebook will have trained and saved a neural network model in the 'models/' directory, which you can use to estimate pH as described above. A separate Estimation notebook loads a saved model, reads input data, and outputs the same data with pH estimates added.

In [None]:
#@title # Set up environment.

from google.colab import drive
drive.mount('/content/drive', force_remount=True)

import os
import math
import copy
from sklearn.utils import shuffle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
# matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8*', and
# 3.8 removed the old 'seaborn' name, which now raises OSError. Fall back to
# the renamed style so this cell still runs on newer Colab images.
try:
  plt.style.use('seaborn')
except OSError:
  plt.style.use('seaborn-v0_8')
import seaborn as sns
sns.set_color_codes(palette='colorblind')

%tensorflow_version 2.x
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import metrics
from tensorflow.keras.layers.experimental import preprocessing
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error

# Necessary User Defined Variables
Specify the user-defined variables as follows:
- `root_dir_name`  Name of the folder in Google Drive to store everything
- `model_name`  Name of the model, used in figure titles and the saved model's name
- `input_vars` The input variables to be used (see vars_README.md for list)
- `output_var` The target variable to be estimated (see vars_README.md for list)
- `input_file_type` `'csv'` for comma-separated files with a header row; any other value reads tab-separated `.txt` files
- `save_txt ` Boolean whether to save copies of intermediate files as .txt files
- `save_csv`  Boolean whether to save copies of intermediate files as .csv files
- `save_figs`  Boolean whether to save evaluation figures as their own files (they also appear in the notebook, but will be lost if you reset the notebook)
- `show_figs`  Boolean whether to show the figures in the notebook output as they are made

**Note:** All values filled in are basic defaults.

In [None]:
# name of folder in Google Drive
# should be a relative path from 'My Drive' and end with trailing '/'
# exclude beginning '/'
# (it is appended directly to '/content/drive/MyDrive/' to form root_dir)
root_dir_name = 'Example/'

# desired name of trained model; used in figure titles and in the saved
# model's path (the build cell appends a "_Layers(...)" suffix)
model_name = 'model_name'

# model inputs, in column order. 'DATE' must stay first: the date rewrite in
# process_glider_file and the fractional-year conversion in prep_data both act
# on column 0. get_depths reads column 1 as latitude and column 3 as pressure.
# Every entry other than DATE/LATITUDE/LONGITUDE needs a matching '<VAR>_QF'
# quality-flag column in the data files.
input_vars = ['DATE', 'LATITUDE', 'LONGITUDE', 'PRS', 'TMP', 'SAL', 'OXYGEN']
# model target; the data files must also contain a '<output_var>_QF' column
output_var = 'PH_INSITU'

# what file type you're using,
# 'csv' for any comma separated value, 'txt' for any tab separated value
# (any value other than 'csv' is read as tab separated, skipping the first
# 7 lines and using a fixed list of column names)
input_file_type = 'csv'

# save options for intermediate data files
# can save both or either .txt and .csv files
# NOTE(review): neither flag is currently read by any cell in this notebook
# (process_glider_file accepts save_txt/save_csv but ignores them).
save_txt = False
save_csv = True

# display figures? (checked before plt.show() in the evaluation cells)
show_figs = True
# save figures? (evaluation figures are written to fig_dir as .png)
save_figs = False

### Other User Defined Variables

In [None]:
# reduce density (by): when True, keep only every
# density_reduction_factor-th training row (applied before shuffling)
reduce_density = False
density_reduction_factor = 10

# model hyperparams
# list of the hidden layers as the number of neurons they have
model_layers = [48,48]
# list of activation functions of each layer
# if empty, all will be set to 'sigmoid'
# length of act_funcs must match length of model_layers
act_funcs = []

# model evaluation and fig settings
# depth threshold (m): test points deeper than this are excluded from the
# "shallow" metrics and plots
shallow = 200
# suffix appended to the titles of the shallow-subset figures
d_label = "\n(Depths <{})".format(shallow)
# scatter colour for the model's points
m_col = 'b'
# NOTE(review): c_col is not used by any cell in this notebook
c_col = 'g'

# Training Pipeline

In [None]:
#@title ### Initialize directory variables.
# Every directory path below is a string ending in '/', so later cells build
# file paths by plain concatenation (e.g. training_dir + file).
root_dir = '/content/drive/MyDrive/' + root_dir_name

# input data locations
data_dir = f'{root_dir}data/'
training_dir = f'{data_dir}training/'
testing_dir = f'{data_dir}testing/'

# output locations
model_dir = f'{root_dir}models/'
fig_dir = f'{root_dir}figs/'

In [None]:
#@title ### Initialize data cleaning preprocessing functions.
# FUNCTION TO READ GLIDER FILE IN A PANDAS DATAFRAME
def read_glider_file(file, file_type=None):
  """Read one glider data file into a pandas DataFrame.

  Args:
    file: path (or file-like object) of the data file.
    file_type: 'csv' reads a comma-separated file whose first row is the
      header. Any other value reads a tab-separated file, skipping its first
      7 lines and assigning the fixed column names below. When omitted, the
      notebook-level ``input_file_type`` setting is used, so existing calls
      behave as before.

  Returns:
    DataFrame with every row that contains a missing value removed and the
    index reset to 0..n-1.
  """
  if file_type is None:
    # fall back to the notebook-wide setting (previous, implicit behaviour)
    file_type = input_file_type

  # column layout of the tab-separated ('txt') export, which has no header row
  names = ['Cruise', 'Station', 'Type', 'DATE', 'TIME', 'LONGITUDE',
           'LATITUDE', 'QF', 'PRS', 'PRS_QF', 'TMP', 'TMP_QF',
           'SAL', 'SAL_QF', 'Sigma_theta', 'ST_QF', 'DEPTH', 'DEPTH_QF',
           'OXYGEN', 'OXYGEN_QF', 'SATOXY', 'SATOXY_QF', 'NITRATE',
           'NITRATE_QF', 'CHL_A', 'CHL_A_QF', 'BBP700', 'BBP700_QF', 'PH_INSITU',
           'PH_INSITU_QF', 'BBP532', 'BBP_532_QF', 'CDOM', 'CDOM_QF', 'TALK_CANYONB',
           'TALK_QF', 'DIC_CANYONB', 'DIC_QF', 'pCO2_CANYONB', 'pCO2_QF',
           'SAT_AR_CANYONB', 'SAT_AR_QF', 'pH25C_1atm', 'pH25C_1atm_QF']

  if file_type == 'csv':
    df = pd.read_csv(file, header=0, sep=',')
  else:
    df = pd.read_csv(file, skiprows=7, header=None, sep='\t', names=names)

  # NOTE(review): a row is dropped if *any* column is missing, including
  # columns that are never used as inputs/outputs; confirm that is intended.
  return df.dropna(axis=0, how='any').reset_index(drop=True)

################################################################################

# CREATE FUNCTION TO CREATE A NUMPY ARRAY OF INPUTS FROM
# GLIDER DATA FILE

def process_glider_file(file, save_txt=False, save_csv=False):
  """Read, trim and quality-control one glider file.

  Steps:
    1. Read the file with ``read_glider_file``. Rows with any missing value
       are already removed there.
    2. Drop the first day of data: every row before the first row whose
       DATE/TIME is more than 24 hours after the first row's. If no row is
       that late, nothing is dropped (same as the previous implementation).
    3. Drop every row whose quality-control flag ``<VAR>_QF`` is > 0 for any
       variable in ``input_vars`` or for ``output_var``. DATE, LATITUDE and
       LONGITUDE have no per-variable flag and are skipped.
    4. Return the selected columns as numpy string arrays, with the DATE
       column rewritten from MM/DD/YYYY to YYYYMMDD.

  Uses the notebook-level ``input_vars`` and ``output_var`` settings.
  ``input_vars`` must start with 'DATE', because the date rewrite here (and
  ``prep_data``) operates on column 0.

  Args:
    file: path of the glider data file.
    save_txt, save_csv: accepted for backward compatibility; currently unused.

  Returns:
    (inputs, outputs): a 2-D str array with one column per ``input_vars``
    entry (in that order), and a 1-D str array of ``output_var`` values.

  Raises:
    ValueError: if ``input_vars`` does not start with 'DATE'.
  """
  if not input_vars or input_vars[0] != 'DATE':
    raise ValueError(
        "input_vars must start with 'DATE', got {}".format(input_vars))

  df = read_glider_file(file)

  # --- throw away the first day ---------------------------------------------
  # Compare full timestamps rather than day-of-month numbers, so the cut also
  # works when the file starts on the last day of a month (the old day-number
  # comparison never triggered there and silently kept the first day).
  # DATE is MM/DD/YYYY and TIME starts with HH:MM; only minutes are used, as
  # before.
  if len(df) > 0:
    stamps = pd.to_datetime(
        df['DATE'].astype(str).str[0:10] + ' ' + df['TIME'].astype(str).str[0:5],
        format='%m/%d/%Y %H:%M')
    past_first_day = (stamps > stamps.iloc[0] + pd.Timedelta(days=1)).to_numpy()
    if past_first_day.any():
      # keep everything from the first row past the 24 h mark onwards
      df = df.iloc[int(past_first_day.argmax()):].reset_index(drop=True)

  # --- drop rows with bad quality-control flags -----------------------------
  # Only the flag columns actually needed are read, so any input variable that
  # has a '<VAR>_QF' column is supported (not just a hard-coded subset).
  unflagged = {'DATE', 'LATITUDE', 'LONGITUDE'}
  flag_cols = [var + '_QF' for var in input_vars if var not in unflagged]
  flag_cols.append(output_var + '_QF')
  # float -> int truncates toward zero, like the former int(flag) per value
  bad = (df[flag_cols].astype(float).astype(int) > 0).any(axis=1)
  df = df.loc[~bad].reset_index(drop=True)

  # --- build the output arrays ----------------------------------------------
  inputs = df[input_vars].to_numpy(dtype='str')
  outputs = df[output_var].to_numpy(dtype='str')

  # MM/DD/YYYY -> YYYYMMDD in column 0 (DATE)
  dates = df['DATE'].astype(str)
  inputs[:, 0] = (dates.str[6:10] + dates.str[0:2] + dates.str[3:5]).to_numpy()

  return inputs, outputs

################################################################################

def prep_data(inputs, outputs):
  """Replace the YYYYMMDD date column of ``inputs`` with a fractional year.

  Column 0 of ``inputs`` must hold dates as 'YYYYMMDD' strings, as produced
  by ``process_glider_file``. It becomes ``year + day_of_year / 365.0``
  (e.g. Jan 1 2019 -> 2019 + 1/365). Every other column is parsed as float.

  The result is a new float array. The previous version wrote the
  fractional year back into the fixed-width numpy string array, which
  silently truncated it to that array's string width (e.g. '2019.00273' in a
  '<U10' array).

  NOTE(review): Dec 31 maps to year + 1.0, and day 366 of a leap year to more
  than year + 1. The encoding is deliberately kept unchanged, because models
  trained with it presumably rely on the same transformation at estimation
  time.

  Args:
    inputs: 2-D array of strings with the date in column 0.
    outputs: passed through unchanged.

  Returns:
    (inputs as a 2-D float array, outputs)
  """
  inputs = np.asarray(inputs)
  result = np.empty(inputs.shape, dtype=float)
  if inputs.shape[0] == 0:
    return result, outputs

  # non-date columns: plain numeric conversion
  result[:, 1:] = inputs[:, 1:].astype(float)

  # date column: year + day_of_year / 365 (day_of_year is 1-based)
  dates = pd.to_datetime(pd.Series(inputs[:, 0]), format='%Y%m%d')
  result[:, 0] = dates.dt.year.to_numpy() + dates.dt.dayofyear.to_numpy() / 365.0

  return result, outputs

In [None]:
#@title ### Load, clean, and preprocess data.

# create arrays to store train inputs and labels
train_input_arrays = []
train_output_arrays = []

# Process files in a fixed (sorted) order. os.listdir() order is arbitrary,
# and the concatenation order feeds both the seeded shuffle below and the
# validation split taken by model.fit, so an unsorted listing makes runs
# irreproducible. Sub-directories and hidden files (e.g. '.DS_Store') are
# skipped, since they are not data files.
train_files = sorted(
    f for f in os.listdir(training_dir)
    if not f.startswith('.') and os.path.isfile(training_dir + f))
if not train_files:
  raise FileNotFoundError('No training files found in ' + training_dir)

# for each file in our directory clean and preprocess data
print("Processing the following training files:")
for file in train_files:
  display(training_dir+file)

  curr_inputs, curr_outputs = process_glider_file(training_dir+file)
  curr_inputs, curr_outputs = prep_data(curr_inputs, curr_outputs)
  train_input_arrays.append(curr_inputs)
  train_output_arrays.append(curr_outputs)

print("Complete\n")
# combine all arrays of inputs and labels
train_inputs = np.concatenate(train_input_arrays)
train_outputs = np.concatenate(train_output_arrays)

# convert all inputs and outputs to float type
train_inputs = train_inputs.astype('float')
train_outputs = train_outputs.astype('float')

# if we're reducing density, keep every density_reduction_factor-th row
if reduce_density:
  indices = range(0,train_outputs.shape[0],density_reduction_factor)
  train_inputs = np.take(train_inputs, indices, axis=0)
  train_outputs = np.take(train_outputs, indices, axis=0)

# shuffle training data (seeded, so the validation split is repeatable)
train_inputs, train_outputs = shuffle(train_inputs, train_outputs, 
                                      random_state=101)

# display some metadata about training inputs and labels
print("Shape of training inputs: " + str(train_inputs.shape))
print("Shape of training outputs: " + str(train_outputs.shape))

num_print_rows = 5
print("First {} inputs:".format(num_print_rows))
display(train_inputs[0:num_print_rows])
print("First {} outputs:".format(num_print_rows))
display(train_outputs[0:num_print_rows])

## Create and Train Model

In [None]:
#@title ### Build and compile model.
# BUILD AND COMPILE THE MODEL

# Name of the model, used for the saved model and in figure titles.
# Strip any suffix added by an earlier run before appending the current one,
# so re-running this cell does not produce 'name_Layers([48, 48])_Layers(...)'.
model_name = (model_name.partition("_Layers(")[0]
              + "_Layers({})".format(model_layers))

# act_funcs must be empty (every hidden layer uses 'sigmoid') or provide one
# activation per hidden layer. Previously a short list raised IndexError and
# a long one was silently truncated.
if act_funcs and len(act_funcs) != len(model_layers):
  raise ValueError(
      'act_funcs has {} entries but model_layers has {}'.format(
          len(act_funcs), len(model_layers)))
layer_activations = act_funcs if act_funcs else ['sigmoid'] * len(model_layers)

# create model
model = keras.Sequential()

# create and add normalization layer; adapt() computes the per-feature
# statistics from the training inputs
normalizer = preprocessing.Normalization(axis=-1)
normalizer.adapt(train_inputs)
model.add(normalizer)

# add hidden layers, one Dense layer per entry of model_layers
for units, activation in zip(model_layers, layer_activations):
  model.add(layers.Dense(units, activation=activation))

# add final output layer: a single linear unit for the regression target
model.add(layers.Dense(1))

# compile the model
model.compile(loss='mean_squared_error',
              optimizer=tf.keras.optimizers.Adam(0.001),
              metrics=[ metrics.MeanAbsoluteError(),
                        metrics.MeanSquaredError(),
                        metrics.RootMeanSquaredError() ])

# Create model and display summary
print('Model Name: ' + model_name)
model.summary()

In [None]:
#@title ### Train model.
%%time
# Train for 100 epochs. validation_split=0.1 holds out the last 10% of the
# (already shuffled) training rows for validation.
# NOTE(review): per the Keras docs, use_multiprocessing only affects
# generator / keras.utils.Sequence inputs, so it likely has no effect with
# the numpy arrays passed here.
history = model.fit(
    train_inputs, train_outputs,
    validation_split=0.1, epochs=100,
    use_multiprocessing=True)

# save the trained model under models/<model_name>
model.save(model_dir+model_name)

# training vs validation loss per epoch
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model Loss History')
plt.ylabel('Loss (mean_squared_error)')
plt.xlabel('Epoch')
plt.legend(['train', 'validation'], loc='upper right')
plt.show()


In [None]:
#@title # Visualize Accuracy on Training Set
# Predictions on the TRAINING inputs. They are stored as train_predictions,
# not test_predictions, so re-running this cell after the evaluation section
# cannot overwrite the test-set predictions used by the metric/plot cells.
train_predictions = model.predict(train_inputs).flatten()

ax = plt.axes(aspect='equal')

# x = observed pH, y = predicted pH
plt.scatter(train_outputs, train_predictions)

# square axes spanning both data ranges, with a 1:1 reference line
lims = [
    np.min([ax.get_xlim(), ax.get_ylim()]),  # min of both axes
    np.max([ax.get_xlim(), ax.get_ylim()]),  # max of both axes
]
plt.ylim(lims)
plt.xlim(lims)
plt.plot(lims,lims, color='black')

plt.title('pH Predictions vs Observations')
plt.xlabel('True Values pH')
plt.ylabel('Predictions pH')

plt.show()

# Model Evaluation

In [None]:
#@title # Get Testing Data
test_input_arrays = []
test_output_arrays = []

# Same file handling as for training: a sorted listing gives a stable order,
# and sub-directories / hidden files returned by os.listdir() are skipped.
test_files = sorted(
    f for f in os.listdir(testing_dir)
    if not f.startswith('.') and os.path.isfile(testing_dir + f))
if not test_files:
  raise FileNotFoundError('No testing files found in ' + testing_dir)

for file in test_files:
  display(testing_dir+file)

  curr_inputs, curr_outputs = process_glider_file(testing_dir+file)
  curr_inputs, curr_outputs = prep_data(curr_inputs, curr_outputs)
  test_input_arrays.append(curr_inputs)
  test_output_arrays.append(curr_outputs)

test_inputs = np.concatenate(test_input_arrays)
test_outputs = np.concatenate(test_output_arrays)

test_inputs = test_inputs.astype('float')
test_outputs = test_outputs.astype('float')

display(test_inputs.shape)
display(test_outputs.shape)

In [None]:
#@title # Depth Conversion

# FUNCTION: CONVERT PRESSURE TO DEPTH
# Python version of the following MATLAB function
'''
% SW_DPTH    Depth from pressure
%===========================================================================
% SW_DPTH   $Id: sw_dpth.m,v 1.1 2003/12/12 04:23:22 pen078 Exp $
%           Copyright (C) CSIRO, Phil Morgan 1992.
%
% USAGE:  dpth = sw_dpth(P,LAT)
%
% DESCRIPTION:
%    Calculates depth in metres from pressure in dbars.
%
% INPUT:  (all must have same dimensions)
%   P   = Pressure    [db]
%   LAT = Latitude in decimal degress north [-90..+90]
%         (lat may have dimensions 1x1 or 1xn where P(mxn).
%
% OUTPUT:
%  dpth = depth [metres]
%
% AUTHOR:  Phil Morgan 92-04-06  (morgan@ml.csiro.au)
%
% DISCLAIMER:
%   This software is provided "as is" without warranty of any kind.
%   See the file sw_copy.m for conditions of use and licence.
%
% REFERENCES:
%    Unesco 1983. Algorithms for computation of fundamental properties of
%    seawater, 1983. _Unesco Tech. Pap. in Mar. Sci._, No. 44, 53 pp.
%=========================================================================
'''
def pres_to_depth(pres, lat):
  """Convert sea pressure to depth with the UNESCO 1983 formula (sw_dpth).

  Args:
    pres: pressure in decibars.
    lat: latitude in decimal degrees; only its magnitude is used.

  Returns:
    Depth in metres.
  """
  # latitude term: sin^2(|lat|), with the angle converted to radians
  sin_lat = math.sin(abs(lat) * (math.pi / 180))
  x = sin_lat * sin_lat

  # denominator: gravity at the given latitude plus the pressure correction
  # (gam_dash = 2.184E-6)
  gravity = 9.780318 * (1.0 + (5.2788E-3 + 2.36E-5 * x) * x) + 2.184E-6 * 0.5 * pres

  # numerator: polynomial in pressure, evaluated Horner-style from the
  # highest-order coefficient (c4, c3, c2, c1)
  depth_num = 0.0
  for coeff in (-1.82E-15, 2.279E-10, -2.2512E-5, 9.72659):
    depth_num = (depth_num + coeff) * pres

  return depth_num / gravity


# FUNCTION: CONVERT ARRAY OF INPUTS TO AN ARRAY OF DEPTHS
def get_depths(input_arr):
  """Return a float numpy array with one depth (m) per row of ``input_arr``.

  Each row is an input vector laid out like ``input_vars``. Column 3 is used
  as pressure and column 1 as latitude, which matches the default
  ['DATE', 'LATITUDE', 'LONGITUDE', 'PRS', ...] ordering.
  """
  return np.array([pres_to_depth(row[3], row[1]) for row in input_arr],
                  dtype=float)


# get depths from inputs
test_depths = get_depths(test_inputs)

In [None]:
#@title # Make Predictions

# Test-set predictions from the trained network, one value per test row.
test_predictions = model.predict(test_inputs, verbose=1).flatten()

# SHALLOW SUBSET: keep every point that is NOT deeper than `shallow` metres.
# Written as ~(depth > shallow) so that NaN depths are kept, exactly as when
# deleting the indices of the deep points.
is_deep = test_depths > shallow
keep = ~is_deep

shallow_test_depths = test_depths[keep]
shallow_test_predictions = test_predictions[keep]
shallow_outputs = test_outputs[keep]

# sanity check: the shallow subset must contain no deep points
count = int(np.count_nonzero(shallow_test_depths > shallow))
print('Number of datapoints of depth greater than {}m: {}'.format(shallow, count))

print('Shape of each shallow data array (all should be equal):')
for shallow_arr in (shallow_test_depths, shallow_test_predictions, shallow_outputs):
  display(shallow_arr.shape)

## Error Metrics
### MAE, MSE, RMSE

In [None]:
#@title ### Error metrics table. 
# ERROR METRICS ON THE TESTING SET
# MAE / MSE / RMSE of the network over all test points, followed by the same
# three metrics over the shallow subset (points not deeper than `shallow` m).

model_metrics = np.zeros((1,6))

# all test points
mae = mean_absolute_error(test_outputs, test_predictions)
mse = mean_squared_error(test_outputs, test_predictions)
rmse = math.sqrt(mse)
# shallow subset
s_mae = mean_absolute_error(shallow_outputs, shallow_test_predictions)
s_mse = mean_squared_error(shallow_outputs, shallow_test_predictions)
s_rmse = math.sqrt(s_mse)
# round metrics to 5 decimal places; the order matches col_labels below
model_metrics[0] = np.round([mae, mse, rmse, s_mae, s_mse, s_rmse], 5)

# create table
# labels for columns and rows
col_labels = ['MAE', 'MSE', 'RMSE', 'S-MAE*', 'S-MSE*', 'S-RMSE*']
row_labels = [model_name]

# add data to table
fig, ax = plt.subplots(1, figsize=(10,3.5))
table = ax.table(cellText=model_metrics, cellLoc='center', loc='center', 
                 rowLabels=row_labels, colLabels=col_labels)
# title plot
fig.suptitle('Error Metrics',
             ha='center', va='center', fontsize=20, weight='bold')
# scale plot
table.set_fontsize(15)
table.scale(1,4)
ax.axis('off')

# footnote uses the configured `shallow` threshold instead of a hard-coded 200m
plt.text(1,0, "* \"S-\" denotes \"Shallow\" which specifies depths less than {}m".format(shallow),
         ha='right', fontsize=12)

plt.tight_layout()

# save and show plot
if save_figs:
  plt.savefig(fig_dir+'combined_err_metrics.png', bbox_inches='tight')
if show_figs:
  plt.show()

## Estimates vs Observations 1-1 Plots

In [None]:
# Hide the axes frame (white, zero-width spines) for the scatter plots that
# follow; the heatmap section later sets the spines back to black, width 2.
plt.rcParams["axes.edgecolor"] = "white"
plt.rcParams["axes.linewidth"]  = 0

In [None]:
#@title ### For overall dataset, plot predictions vs observations.

fig, ax1 = plt.subplots(1, 1, figsize=(10,5), sharey=True)
ax1.set_aspect('equal')

# scatter of estimations vs observations for all test points
# (x = observed pH, y = model estimate; colour m_col, semi-transparent)
ax1.scatter(test_outputs, test_predictions, label=model_name,
            color=m_col, s=10, alpha=0.1)


ax1.set_xlabel('pH Measured')
ax1.xaxis.label.set_size(14)
ax1.set_ylabel('pH Estimated')
ax1.yaxis.label.set_size(14)

# get axis limits from the min and max of the current x and y limits so the
# plot is square and the 1:1 line spans it
lims = [
    np.min([ax1.get_xlim(), ax1.get_ylim()]),  # min of both axes
    np.max([ax1.get_xlim(), ax1.get_ylim()]),  # max of both axes
]

# 1:1 reference line (perfect estimate)
ax1.plot(lims,lims, color='k')
ax1.set_ylim(lims)
ax1.set_xlim(lims)

# figure title and axes title (the scatter has a label, but no legend is drawn)
fig.suptitle('pH Estimations vs Observations\n', y=1,
             va='center', fontsize=18, weight='bold')
ax1.set_title(model_name, fontsize=16, weight='bold')

plt.tight_layout()

# save (optional) and show plot; the second tight_layout() is redundant
if save_figs:
  plt.tight_layout()
  plt.savefig(fig_dir+'est_v_obs.png', bbox_inches='tight')
if show_figs:
  plt.show()


In [None]:
#@title ### For shallow dataset, plot predictions vs observations.
fig, ax1 = plt.subplots(1, 1, figsize=(10,5), sharey=True)
ax1.set_aspect('equal')

# scatter of estimations vs observations for the shallow subset only
# (x = observed pH, y = model estimate; colour m_col, semi-transparent)
ax1.scatter(shallow_outputs, shallow_test_predictions, label=model_name,
            color=m_col, s=10, alpha=0.1)

ax1.set_xlabel('pH Measured')
ax1.xaxis.label.set_size(14)
ax1.set_ylabel('pH Estimated')
ax1.yaxis.label.set_size(14)

# get axis limits from the min and max of the current x and y limits so the
# plot is square and the 1:1 line spans it
lims = [
    np.min([ax1.get_xlim(), ax1.get_ylim()]),  # min of both axes
    np.max([ax1.get_xlim(), ax1.get_ylim()]),  # max of both axes
]

# 1:1 reference line (perfect estimate)
ax1.plot(lims,lims, color='black')
ax1.set_ylim(lims)
ax1.set_xlim(lims)

# figure title (with the shallow-depth suffix) and axes title; no legend drawn
fig.suptitle('pH Estimations vs Observations' + d_label, y=1.025,
             va='center', fontsize=18, weight='bold')
ax1.set_title(model_name, fontsize=16, weight='bold')

plt.tight_layout()

# save (optional) and show plot
if save_figs:
  plt.savefig(fig_dir+'est_v_obs_shallow.png', bbox_inches='tight')
if show_figs:
  plt.show()


## Depth vs Error Plots

In [None]:
#@title ### Plot error vs depth for overall dataset.
# get errors (estimations - observations); `error` is reused by the
# heatmap cell below
error = test_predictions - test_outputs

fig, ax = plt.subplots(1, figsize=(10,11))

# scatter: x = error, y = depth
ax.scatter(error, test_depths, label=model_name, 
           color=m_col, s=10, alpha=0.1)

# place a vertical line at x=0 to represent error=0
ax.axvline(x=0, color='black')
# invert y-axis so depth increases downwards
ax.invert_yaxis()
ax.set_xlabel('pH Error (Est - Obs)')
ax.xaxis.label.set_size(14)
ax.set_ylabel('Depth (m)')
ax.yaxis.label.set_size(14)

# label plot
fig.suptitle('pH Error vs Depth', y=1.025,
             fontsize=20, weight='bold')
ax.set_title(model_name, fontsize=16, weight='bold')

plt.tight_layout()

# save (optional) and show plot
if save_figs:
  plt.savefig(fig_dir+'err_v_depth.png', bbox_inches='tight')
if show_figs:
  plt.show()


In [None]:
#@title ### Plot Error vs depth for shallow dataset.
# get errors (estimations - observations) for the shallow subset;
# `shallow_error` is reused by the shallow heatmap cell below
shallow_error = shallow_test_predictions - shallow_outputs

fig, ax = plt.subplots(figsize=(10,10))

# scatter: x = error, y = depth
ax.scatter(shallow_error, shallow_test_depths, label=model_name, 
           color=m_col, s=10, alpha=0.1)

# place a vertical line at x=0 to represent error=0
ax.axvline(x=0, color='black')
# invert y-axis so depth increases downwards
ax.invert_yaxis()
ax.set_xlabel('pH Error (Est - Obs)')
ax.xaxis.label.set_size(14)
ax.set_ylabel('Depth (m)')
ax.yaxis.label.set_size(14)

# label plot (title carries the shallow-depth suffix)
fig.suptitle('pH Error vs Depth' + d_label, y=1.025,
             va='center', fontsize=20, weight='bold')
ax.set_title(model_name, fontsize=16, weight='bold')

plt.tight_layout()

# save and show plot
if save_figs:
  plt.savefig(fig_dir+'err_v_depth_shallow.png', bbox_inches='tight')
if show_figs:
  plt.show()


## Heatmaps

In [None]:
# Draw a visible frame (black, width 2) around the heatmaps below.
plt.rcParams["axes.edgecolor"] = "black"
plt.rcParams["axes.linewidth"]  = 2

In [None]:
#@title ### Plot heatmap for error vs depth for overall dataset.

fig, ax1 = plt.subplots(1, figsize=(10,10))

# histogram extent: error in [-0.1, 0.1] pH, depth in [0, 1000] m
# NOTE(review): points outside these ranges are not counted in the heatmap
ranges = [[-0.1, 0.1],
          [0, 1000]]

# 2-D histogram of (error, depth), 150 error bins x 100 depth bins
hh1 = ax1.hist2d(error, test_depths, 
                 range=ranges, bins=(150,100), cmap=plt.cm.plasma)
ax1.invert_yaxis()
ax1.axvline(x=0, color='w')
# hh1[3] is the image artist returned by hist2d, used for the colour bar
fig.colorbar(hh1[3], ax=ax1)

fig.suptitle('Heatmap of Error vs Depth', y=1.025,
             ha='center', va='center', fontsize=20, weight='bold')
ax1.set_title(model_name, fontsize=18, weight='bold')

ax1.set_xlabel('pH Error (Est - Obs)')
ax1.xaxis.label.set_size(16)
ax1.set_ylabel('Depth (m)')
ax1.yaxis.label.set_size(16)

plt.tight_layout()

if save_figs:
  plt.savefig(fig_dir + 'err_v_depth_heatmap.png', bbox_inches='tight')
if show_figs:
  plt.show()

In [None]:
#@title ### Plot heatmap for error vs depth for shallow dataset.
fig, ax1 = plt.subplots(1, figsize=(10,10))

# histogram extent: error in [-0.1, 0.1] pH, depth in [0, shallow] m
# NOTE(review): points with |error| > 0.1 are not counted in the heatmap
ranges = [[-0.1, 0.1],
          [0,shallow]]

# 2-D histogram of (error, depth) for the shallow subset
hh1 = ax1.hist2d(shallow_error, shallow_test_depths, 
                 range=ranges, bins=(150,100), cmap=plt.cm.plasma)
ax1.invert_yaxis()
ax1.axvline(x=0, color='w')
# hh1[3] is the image artist returned by hist2d, used for the colour bar
fig.colorbar(hh1[3], ax=ax1)

fig.suptitle('Heatmap of Error vs Depth' + d_label, y=1.025,
             ha='center', va='center', fontsize=20, weight='bold')
ax1.set_title(model_name, fontsize=18, weight='bold')

# NOTE(review): unlike the full-depth heatmap, label sizes are not set here
ax1.set_xlabel('pH Error (Est - Obs)')
ax1.set_ylabel('Depth (m)')

plt.tight_layout()

if save_figs:
  plt.savefig(fig_dir+'err_v_depth_heatmap_shallow.png', bbox_inches='tight')
if show_figs:
  plt.show()

# Conclusion

In [None]:
#@title # Finish notebook.
# Reached only after every earlier cell has run without raising.
print("Notebook complete.")