In [None]:
import math
import xarray as xr
import numpy as np
import warnings

import matplotlib.pyplot as plt
from matplotlib.colors import BoundaryNorm, ListedColormap
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
from matplotlib import cm
import matplotlib.colors as colors
from glob import glob
from matplotlib import rc
import matplotlib
warnings.filterwarnings("ignore",category=matplotlib.MatplotlibDeprecationWarning)

import cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cartopy.io.img_tiles as cimgt
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter

import pop_tools
from scipy.optimize import curve_fit, minimize
from joblib import Parallel, delayed

from mpl_toolkits.axes_grid1 import make_axes_locatable

import pandas as pd

In [None]:
path = "./data/"  # root directory of the input data files used below

In [None]:
# @title Helper functions (hidden)
def angle_mean(a):
  """Circular mean of the angles in `a` (degrees); the result is in degrees, in [-180, 180]."""
  theta = a*np.pi/180
  mean_cos, mean_sin = np.mean(np.cos(theta)), np.mean(np.sin(theta))
  return math.atan2(mean_sin, mean_cos)*180.0/np.pi
def deg_to_dms(deg):
  """Split decimal degrees into [degrees, minutes, seconds].

  Degrees keep the sign of `deg` (truncated toward zero); minutes (int) and
  seconds (float) are computed from the absolute fractional part.
  """
  whole = int(deg)
  minutes_f = abs(deg - whole) * 60
  minutes = int(minutes_f)
  seconds = (minutes_f - minutes) * 60
  return [whole, minutes, seconds]
def format_lat(lat):
  """Format a latitude as whole degrees plus hemisphere, e.g. 45.7 -> "45°N".

  Degrees are truncated toward zero. The hemisphere comes from the sign of
  `lat` itself (>= 0 is N), so e.g. 0.5 -> "0°N" and -0.5 -> "0°S".
  """
  return "%d°%s" % (abs(int(lat)), "N" if lat >= 0 else "S")
def format_lng(lng):
  """Format a longitude as whole degrees plus E/W, e.g. -120.3 -> "120°W".

  Degrees are truncated toward zero. E/W comes from the sign of `lng` itself
  (>= 0 is E), so e.g. 0.5 -> "0°E" and -0.5 -> "0°W".
  """
  return "%d°%s" % (abs(int(lng)), "E" if lng >= 0 else "W")
def format_latlng(lat,lng):
  """Return "<latitude> <longitude>" built from format_lat and format_lng."""
  return " ".join([format_lat(lat), format_lng(lng)])

In [None]:
# Response curves (OAE_efficiency, frac_ALK_excess_surf, ...) per region/season/polygon;
# decode_times=False keeps the time coordinate as raw numbers.
all_curves = xr.open_dataset(path + '/all_curves_global.nc', decode_times=False)
all_curves

In [None]:
# Dataset.sizes maps dimension name -> length (Dataset.dims[...] is deprecated).
REGIONS = all_curves.sizes["region"]
# Polygons per region: number of non-NaN surface excess-alkalinity entries at season 0, time 0.
NPOLYGONS = [int(np.sum(~np.isnan(all_curves.isel(region=r, season=0, time=0).frac_ALK_excess_surf.values)))
             for r in range(REGIONS)]
SEASONS=4

# Polygon masks per region; POLYGON_MASKS[r][p] is the mask of polygon p in region r.
# NOTE(review): list order presumably matches the `region` dimension of all_curves — confirm.
Pacific_polygon_masks = np.load(f'{path}/polygon_data/Pacific_final_polygon_mask.npy')
Atlantic_polygon_masks = np.load(f'{path}/polygon_data/Atlantic_final_polygon_mask.npy')
South_polygon_masks = np.load(f'{path}/polygon_data/South_final_polygon_mask_120EEZ_180openocean.npy')
Southern_polygon_masks = np.load(f'{path}/polygon_data/Southern_Ocean_final_polygon_mask.npy')

print(Pacific_polygon_masks.shape)

POLYGON_MASKS = [Pacific_polygon_masks, Atlantic_polygon_masks, South_polygon_masks, Southern_polygon_masks]
DATALEN = 180  # number of samples (months) taken from each curve

TOPLAYER_DEPTH=10 #meters

print("REGIONS: ", REGIONS)
print("SEASONS:", SEASONS)
print("NPOLYGONS: ", NPOLYGONS)

In [None]:
# Read the POP gx1v7 grid (nlat: 384, nlon: 320), keep only the fields used here,
# and save a copy under the home directory.
grid = pop_tools.get_grid('POP_gx1v7')
grid = grid[['TAREA', 'KMT', 'TLAT', 'TLONG', 'REGION_MASK']]
grid.to_netcdf(path="~/pop_grid_gx1v7.nc")
# Alternative when the saved copy is available:
#grid = xr.open_dataset(f'{path}/pop_grid_gx1v7.nc', decode_times=False)

tlat = grid.TLAT.values
tlong = grid.TLONG.values

In [None]:
%%time
# (region, polygon) -> mask, its cell indices, the cell lat/lng values and the
# polygon's mean position (arithmetic mean latitude, circular mean longitude).
polygon_mask_map = {}
for r, num_polygon in enumerate(NPOLYGONS):
    polygon_masks = POLYGON_MASKS[r]
    for p in range(num_polygon):
      mask = polygon_masks[p]
      inside = mask > 0
      _tlat, _tlong = tlat[inside], tlong[inside]
      polygon_mask_map[(r,p)] = dict(
          mask=mask,
          index=np.where(inside),
          tlat=_tlat,
          tlng=_tlong,
          mean_latlng=(np.mean(_tlat), angle_mean(_tlong)),
      )

In [None]:
# Grid map of polygon ids: cells inside polygon p of region r get mask*(r*1E6 + p)
# (equal to the id where the mask is 1); cells outside every polygon stay NaN.
indexmap = np.full(tlong.shape, np.nan)
for r, num_polygon in enumerate(NPOLYGONS):  # one entry per region
    for p in range(num_polygon):
      mask = POLYGON_MASKS[r][p]
      index = np.where(mask > 0)
      indexmap[index] = (mask * (r*1E6 + p))[index]

# Load $\eta_{max}$ data from carbonate model


In [None]:
dDIC_dALK_all = xr.open_dataset('./data/dDIC_dALK_all.nc', decode_times=False)
# Row 0: zonal-mean latitude; row 1: zonal- and time-mean of the first data variable.
lat_profile = dDIC_dALK_all.ULAT.mean('nlon')
eta_profile = dDIC_dALK_all.mean('nlon').mean('time').to_array()[0]
eta_max_data = np.vstack([lat_profile, eta_profile])
eta_max_data = eta_max_data[:, ~np.isnan(eta_max_data[1])]          # drop NaN points
eta_max_data = eta_max_data[:, np.argsort(eta_max_data[0])]         # sort by latitude for PCHIP
# Repeat the first/last sample and pin their latitudes to the poles so the data span [-90, 90].
eta_max_data = np.pad(eta_max_data, ((0, 0), (1, 1)), mode="edge")
eta_max_data[0, 0], eta_max_data[0, -1] = -90, 90

In [None]:
import  scipy.interpolate

eta_max_func_raw = scipy.interpolate.PchipInterpolator(eta_max_data[0], eta_max_data[1], axis=0, extrapolate=True)
# Evaluation grid extends past ±90 (presumably so the zero-padding edge effects of
# np.convolve land outside the physical latitude range).
latrange = np.linspace(-180,180)

def gaussian(x, mu, sigma):
  '''Normal probability density with mean mu and standard deviation sigma, evaluated at x.'''
  return 1/(sigma*np.sqrt(2*np.pi)) * np.exp(-0.5*((x-mu)/sigma)**2)

# Gaussian smoothing (sigma = 20 degrees) with a kernel normalised to unit sum.
kernel = gaussian(latrange, 0 , 20)
kernel /= np.sum(kernel)
smoothed =  np.convolve(kernel, eta_max_func_raw(latrange), mode="same" )
eta_max_func_smoothed = scipy.interpolate.PchipInterpolator(latrange, smoothed, axis=0, extrapolate=False)

# cos(latitude)-weighted (surface-area) average between 80S and 80N
latrange_80_80 = np.linspace(-80,80)
weights = np.cos(latrange_80_80/180*np.pi)
waverage = np.sum(eta_max_func_raw(latrange_80_80)*weights)/np.sum(weights)

plt.plot(eta_max_data[0], eta_max_data[1], label="raw data", marker="o")
plt.plot(latrange, eta_max_func_raw(latrange), label="PCHIP interpolation")
plt.plot(latrange, eta_max_func_smoothed(latrange), label="kernel smoothed", ls="dashed")
plt.axhline(waverage, ls="dotted", label="Global area weighted average")
plt.xlim(-90,90)
plt.ylim(0.75,0.95)
plt.ylabel(r'$\partial[DIC]/\partial [Alk]$');
plt.xlabel('Latitude')
plt.legend(loc="upper center")

eta_max_func = eta_max_func_raw

# Set up two-plume box model

In [None]:
def equilibration_curve(t, eta_max, ta, tl, tb):
    '''
    Two-reservoir equilibration efficiency eta(t); eta(0) == 0 and
    eta(t) -> eta_max as t -> infinity.

    t: time (same units as ta, tl, tb)
    eta_max: asymptotic ("max") efficiency, ~0.81
    ta: apparent e-folding time of gas exchange early reservoir
    tl: e-folding time of transfer to second reservoir
    tb: apparent e-folding time of gas exchange for the late reservoir

    Returns eta_max * (1 - Q*exp(-t/tal) - (1-Q)*exp(-t/tb)) with
    1/tal = 1/ta + 1/tl and Q = (tl*ta - tl*tb)/(ta*tl - tb*tl - ta*tb).
    When tb == tal the denominator of Q vanishes (a removable singularity);
    there the analytic limit eta_max * (1 - (1 + t/tl)*exp(-t/tal)) is used.
    '''
    ta, tl, tb = (np.asarray(v, dtype=float) for v in (ta, tl, tb))
    tal = 1.0/(1.0/ta + 1.0/tl)
    den = ta*tl - tb*tl - ta*tb
    # Relative test so the switch to the limiting form is scale independent.
    degenerate = np.abs(den) <= 1e-12*(np.abs(ta*tl) + np.abs(tb*tl) + np.abs(ta*tb))
    Q = (tl*ta - tl*tb)/np.where(degenerate, 1.0, den)
    early = np.exp(-t/tal)
    shape = np.where(degenerate,
                     1.0 - (1.0 + t/tl)*early,
                     1.0 - Q*early - (1.0 - Q)*np.exp(-t/tb))
    return eta_max * shape

In [None]:
t=np.linspace(0,15*12,100)  # 15 years, in months
plt.plot(t, equilibration_curve(t, 0.85, 10,  20,  20))
plt.plot(t, equilibration_curve(t, 0.85, 10,  20, 200))
plt.plot(t, equilibration_curve(t, 0.85,  3,  80,  40))
plt.title("Example equilibration curves")
plt.ylabel(r"$\eta(t)$")  # raw string: "\e" is not a valid escape sequence
plt.xlabel("t (months)" )
plt.show()

In [None]:
def dilution_curve(t, dila, tl, dilb, pulsewidth = 1.0):
  '''Fraction of the added alkalinity remaining in the surface layer, mu(t).

  For an instantaneous injection (pulsewidth == 0) mu decays from dila toward
  dilb with e-folding time tl. For pulsewidth > 0 the alkalinity is added
  uniformly over [0, pulsewidth], so that response is convolved with the box
  pulse: a ramp-up while t < pulsewidth, then the instantaneous curve with its
  excess scaled by (tl/pulsewidth)*(exp(pulsewidth/tl) - 1).
  '''
  excess = dila - dilb
  if pulsewidth == 0.0:
    return excess*np.exp(-t/tl) + dilb
  during = excess*(tl/pulsewidth)*(1-np.exp(-t/tl)) + (t/pulsewidth)*dilb
  after = excess*((tl/pulsewidth)*(np.exp(pulsewidth/tl)-1))*np.exp(-t/tl) + dilb
  return np.where(t < pulsewidth, during, after)

t=np.linspace(0.0,15*12,500)  # 15 years, in months
plt.plot(t, dilution_curve(t, dila=1/10, tl=40, dilb=1/50))
plt.plot(t, dilution_curve(t, dila=1/20, tl=40, dilb=1/50))
plt.plot(t, dilution_curve(t, dila=1/20, tl=10, dilb=1/100))
plt.title("Examples: Dilution of excess surface alkalinity")
plt.ylabel(r"$\mu(t)$")  # raw string: "\m" is not a valid escape sequence
plt.xlabel("t (months)")
plt.show()

In [None]:
def fit_simultaneous(r,s,p, N_skip = 6,
                     local_eta_max = None,
                     show=False, only_data=False, figaxs=None):
  '''Simultaneously fit equilibration_curve and dilution_curve for one polygon.

  The polygon is identified by region r, season s and polygon p. Both curves
  share tl; the cost is the sum of squared residuals of both curves plus a
  penalty that grows when ta exceeds tb. It is minimised with L-BFGS-B from
  up to two starting points, and the lowest-cost attempt is kept.

  N_skip: number of leading samples excluded from the fit (still plotted).
  local_eta_max: centre of the allowed eta_max window (±0.01, kept within
      [0.79, 0.9]); defaults to eta_max_func_smoothed at the polygon's mean latitude.
  show: if False return the fit; if True print it and plot data and fit.
  only_data: when plotting, omit the fitted equilibration curve.
  figaxs: optional (fig, axs) with two axes to draw into; otherwise a new 1x2 figure.

  Returns (etamax, ta, tl, tb, dila, dilb, cost) when show is False, else (fig, axs).
  '''
  assert r<REGIONS
  assert s<SEASONS
  assert p<NPOLYGONS[r]
  curves = all_curves.isel(region=r, season=s, polygon=p)
  # Each season's series starts 3*s samples in (seasons are 3 months apart);
  # presumably this aligns t=0 with that season's injection — confirm.
  offset = 3*s
  eq_y_data_raw = curves.OAE_efficiency.values[offset:offset+DATALEN]
  eq_x_data_raw = np.arange(0,len(eq_y_data_raw)) + 0.5
  dil_y_data_raw = curves.frac_ALK_excess_surf.values[offset:offset+DATALEN]
  dil_x_data_raw = np.arange(0,len(dil_y_data_raw)) + 0.5 # t_data is relative to start of injection period
  eq_x_data, eq_y_data = eq_x_data_raw[N_skip:], eq_y_data_raw[N_skip:]
  dil_x_data, dil_y_data = dil_x_data_raw[N_skip:], dil_y_data_raw[N_skip:]

  if local_eta_max is None:
    # The smoothed latitude profile is used (not eta_max_func).
    lat = polygon_mask_map[(r,p)]["mean_latlng"][0]
    local_eta_max = eta_max_func_smoothed(lat)

  # Define the cost function
  def cost_function_eq(params, eq_x, eq_y, dil_x, dil_y):
    eta_max, ta, tl, tb, dila, dilb = params
    return (np.sum((eq_y - equilibration_curve(eq_x, eta_max, ta, tl, tb))**2)
            + 0.2*np.exp(20*np.clip(ta-tb,-1000,1)))  ## A penalty factor to prevent tb << ta

  def cost_function_dil(params, eq_x, eq_y, dil_x, dil_y):
    eta_max, ta, tl, tb, dila, dilb = params
    return np.sum((dil_y - dilution_curve(dil_x, dila,  tl,  dilb))**2)

  def cost_function(params, eq_x, eq_y, dil_x, dil_y):
    relative_cost = 1.0
    return (cost_function_eq(params, eq_x, eq_y, dil_x, dil_y) +
            cost_function_dil(params, eq_x, eq_y, dil_x, dil_y)*relative_cost )

  #          etamax,  ta,   tl,    tb,  dila,  dilb
  lower = [    0.79,   2,    6,     6, 0.001, 0.001]
  upper = [    0.88,  90, 1000, 10000,   1.0,   1.0]
  # eta_max may only move ±0.01 around the local carbonate-model value.
  lower[0] = min(0.9, max(0.79, local_eta_max-0.01))
  upper[0] = min(0.9, max(0.79, local_eta_max+0.01))
  box = list(zip(lower, upper))
  # Two starting points; the second sometimes converges to a lower cost.
  # Their eta_max entry is replaced by the (clipped) local value below.
  starts = ([0.84, 10, 15,  25, 0.2, 0.05],
            [0.88, 20, 40, 125, 0.2, 0.05])

  results = []
  for start in starts:
    x0 = list(start)
    x0[0] = float(np.clip(local_eta_max, lower[0], upper[0]))
    attempt = minimize(fun=cost_function,
                       x0=x0,
                       bounds=box,
                       method="L-BFGS-B",
                       args=(eq_x_data, eq_y_data, dil_x_data, dil_y_data))
    results.append(attempt)
    if attempt.fun < 0.05:
      break # abort if the fit is good enough

  # Keep the lowest-cost attempt and take the parameters from that same attempt.
  result = min(results, key=lambda res: res.fun)
  etamax, ta, tl, tb, dila, dilb = result.x

  if not show:
    return etamax, ta, tl, tb, dila, dilb, result.fun

  print("%1d,%1d,%3d,"%(r,s,p), ": %5f : %.3f %5.1f %5.1f %5.3f %5.3f %5.3f"%(result.fun, etamax, ta, tl, tb, dila, dilb))
  fig, axs = plt.subplots(1,2, figsize=(10, 4)) if figaxs is None else figaxs

  # Plot the dilution
  ax1 = axs[0]
  ax1.plot(eq_x_data_raw, dil_y_data_raw, c="blue", ls="dotted",
           lw=3, label="Data, loc: %d, %d, %d, %s"%(r,s,p,format_latlng(*polygon_mask_map[(r,p)]["mean_latlng"])))
  ax1.plot(eq_x_data_raw[1:], dilution_curve(eq_x_data_raw[1:], dila,  tl,  dilb), c="black",
          label="Fit, Nskip=%d, err=%0.5f \n$dil_{a}$=%.4fm, $\\tau_\\ell$=%.1fmo, $dil_{b}$=%.4fm"%(
              N_skip, result.fun, dila,  tl,  dilb))
  ax1.set_ylim(0.0001,max(0.1, np.max(dil_y_data_raw)))
  ax1.legend(loc="upper right")
  ax1.set_ylabel("Fraction Alk ($A_{0}$)")
  ax1.set_xlabel("time/mo")

  # Plot the equilibration
  ax=axs[1]
  ax.plot(eq_x_data_raw, eq_y_data_raw,  c="red", ls="dotted", lw=3, label=None)
  if not only_data:
    ax.plot(eq_x_data_raw, equilibration_curve(eq_x_data_raw, etamax, ta, tl, tb),
             label = r"%s: $\eta$=%.3f $\tau_a$=%.1fm " "\n" r"$\tau_\ell$=%.1fm,  $\tau_b$=%.1fm"%(["Jan","Apr","Jul","Oct"][s],etamax, ta, tl, tb),
             c="black")
  ax.set_ylim(0,1.0)
  ax.legend(loc='upper right')
  ax.set_ylabel(r"$\eta(t)$")
  ax.set_xlabel("time/mo")

  return fig, axs


In [None]:
def plot_location(r,s,p,  ax, no_left=False, no_right=False, label=""):
  '''Plot data and the two-box fit for one polygon on `ax` plus a twin y-axis.

  The left (blue) axis shows the surface alkalinity fraction mu(t) and its
  dilution_curve fit; the right (red) twin axis shows eta(t) and its
  equilibration_curve fit. The parameters come from fit_simultaneous(r, s, p).

  r, s, p: region, season and polygon index.
  ax: matplotlib Axes to draw into.
  no_left / no_right: suppress the left / right y-axis label.
  label: location name shown in the first legend entry.

  Returns (fig, ax), fig being the Figure that owns `ax`.
  '''
  assert r<REGIONS
  assert s<SEASONS
  assert p<NPOLYGONS[r]
  etamax, ta, tl, tb, dila, dilb, err = fit_simultaneous(r,s,p,  show=False)
  curves = all_curves.isel(region=r, season=s, polygon=p)
  eq_y_data_raw = curves.OAE_efficiency[3*s:3*s+DATALEN]
  eq_x_data_raw = np.arange(0,len(eq_y_data_raw)) + 0.5  # sample times, offset by half a month
  dil_y_data_raw = curves.frac_ALK_excess_surf.values[3*s:3*s+DATALEN]
  location = format_latlng(*polygon_mask_map[(r,p)]["mean_latlng"])

  # Plot the dilution
  ax2 = ax.twinx()
  ax1 = ax

  # A virtual line that's just there to get the location traits into the legend
  ln0 = ax1.plot(eq_x_data_raw[1:],eq_x_data_raw[1:]*0, alpha=0.0, c="w",
                 label="%s, %s, %s"%(label, ["Jan", "Apr", "Jul", "Oct"][s], location))

  ax1.plot(eq_x_data_raw[1:], dil_y_data_raw[1:], c="blue", ls="dotted",
           lw=3, label="Data, loc: %d, %d, %d, %s"%(r,s,p,location))
  ln1 = ax1.plot(eq_x_data_raw[1:], dilution_curve(eq_x_data_raw[1:], dila,  tl,  dilb), c="blue",
          label="$\\mu$(t): $\\mu_{a}$=%.3f, $\\tau_\\ell$=%.fm, $\\mu_{b}$=%.3f"%(dila,  tl,  dilb))
  ax1.set_ylim(0.0001,0.20)
  ax1.set_yticks([0.00, 0.02,0.04,0.06,0.08,0.10,0.12,0.14,0.16])

  if not no_left:
    ax1.set_ylabel("Frac. Surface Alk, $\\mu(t)$", color='b')

  ax1.spines['left'].set_color('blue')
  ax1.tick_params(axis='y', colors='blue')
  ax1.yaxis.label.set_color('blue')

  # -----------------------
  # Plot the equilibration
  ax2.plot(eq_x_data_raw, eq_y_data_raw,  c="red", ls="dotted", lw=3, label=None)
  ln2 = ax2.plot(eq_x_data_raw, equilibration_curve(eq_x_data_raw, etamax, ta, tl, tb),
             label = r"$\eta(t)$: $\eta_{max}$=%.2f $\tau_a$=%.fm  $\tau_b$=%.fm"%(etamax, ta, tb),
             c="red")
  ax2.set_ylim(0,1.4)
  ax2.set_yticks([0.0,0.2,0.4,0.6,0.8,1.0])

  lns = ln0+ln2+ln1
  labs = [l.get_label() for l in lns]
  ax.legend(lns, labs, loc="upper right")

  if not no_right:
    ax2.set_ylabel(r"$\eta(t)$  [mol/mol]", color='r')

  # The twin axis owns the RIGHT spine; coloring its left spine would paint
  # over ax1's blue left spine.
  ax2.spines['right'].set_color('red')
  ax2.tick_params(axis='y', colors='red')
  ax2.yaxis.label.set_color('red')
  ax2.set_xlabel("time/mo")
  ax2.set_xticks(np.arange(0,8)*24)
  ax2.xaxis.set_minor_locator(MultipleLocator(12))

  return ax.figure, ax

In [None]:
def get_average_equilibration(r,s0,p, month=60):
  '''Seasonal statistics of the equilibration efficiency at one sample index.

  Stacks OAE_efficiency for every season of polygon (r, p), each season's
  series starting 3*s samples in, and returns at index `month`:
  (mean over seasons, population std over seasons, value for season s0).
  '''
  assert r<REGIONS
  assert p<NPOLYGONS[r]
  stacked = np.vstack([all_curves.isel(region=r, season=k, polygon=p).OAE_efficiency[3*k:3*k+DATALEN]
                       for k in range(SEASONS)])
  return (np.mean(stacked, axis=0)[month],
          np.std(stacked, axis=0)[month],
          stacked[s0][month])

In [None]:
# (mean over seasons, stddev over seasons, season-0 value) at sample index 59
# for region 3, polygon 1. Printed explicitly: only a cell's last expression is displayed.
print(get_average_equilibration(3,0,1, month=60-1))
all_curves.isel(region=3, polygon=1).OAE_efficiency

In [None]:
def fit_all_data(filename='~/simul_fit_horizontal.h5', regions=None, seasons=None, maxpoly=10000):
  '''Run fit_simultaneous and seasonal eta statistics for many polygons.

  regions / seasons: indices to process (an empty or missing selection means all).
  maxpoly: upper limit on the number of polygons fitted per region.
  filename: if truthy, the result is also written to this HDF5 file
      under key 'residence_times'.

  Returns a DataFrame with one row per (r, s, p), indexed by (r, s, p) while
  keeping those columns; empty (with the same columns) when nothing is selected.
  '''
  columns=['r','s', 'p',
           'eta_max', 'ta', 'tl', 'tb', 'dila', 'dilb', 'err', 'lat', 'lng',
           'mean24', 'stddev24', 'eta24',
           'mean60', 'stddev60', 'eta60',
           'mean180', 'stddev180', 'eta180']
  records = []  # build rows first and create the DataFrame once
  for r in (regions or range(REGIONS)):
    print("Region: ", r)
    for s in (seasons or range(SEASONS)):
      print(" Season: ", s)
      for p in range(min(maxpoly, NPOLYGONS[r])):
        if p%20==0: print("  Polygon: ", p)
        # (eta_max, ta, tl, tb, dila, dilb, err)
        fit = fit_simultaneous(r, s, p,  show=False)
        # (mean, stddev, eta) after 2, 5 and 15 years (0-based sample index = months - 1)
        stats24, stats60, stats180 = (get_average_equilibration(r, s, p, month=m-1)
                                      for m in (24, 60, 180))
        # mean lat and lng of the polygon
        lat, lng = polygon_mask_map[(r,p)]["mean_latlng"]
        records.append([r, s, p, *fit, lat, lng, *stats24, *stats60, *stats180])

  df = pd.DataFrame(records, columns=columns)
  df = df.set_index(['r','s','p'], drop=False)
  if filename: df.to_hdf(filename,'residence_times')
  return df

# Quick smoke test on a single polygon of region 1 before the full (slow) run.
qdf = fit_all_data(regions=[1], maxpoly=1, filename='~/simul_fit.h5')
qdf

# SLOW: Fit all polygons

In [None]:
%%time
from multiprocessing import Pool
import time
def run_job(input_index):
  '''Fit every region and polygon for the single season `input_index`; nothing is written to disk.'''
  return fit_all_data(filename=None, seasons=[input_index])

RECALCULATE_FITS = True  # set to False to load the saved fits instead of refitting
if RECALCULATE_FITS:
  processes_count = 4
  # One task per season. The context manager shuts the pool down after map() returns.
  # NOTE(review): workers need the notebook globals (all_curves, polygon_mask_map, ...);
  # this presumably relies on the 'fork' start method — confirm on macOS/Windows.
  with Pool(processes_count) as processes_pool:
    dfs = processes_pool.map(run_job, range(SEASONS))
  df = pd.concat(dfs)
  # NOTE(review): results are saved under ~/ but the else-branch loads from `path` — confirm intended.
  df.to_hdf('~/analysis_final.h5','two_box_model')
else:
  df = pd.read_hdf(f'{path}/analysis_final.h5')
df

In [None]:
# Mixing parameter Q and combined early-reservoir time constant tal (1/tal = 1/tl + 1/ta),
# using the same closed forms as equilibration_curve.
df["Q"] = (df.tl*(df.ta - df.tb)) / (df.ta*df.tl - df.tb*df.tl - df.ta*df.tb)
df["tal"] = 1/(1/df.tl + 1/df.ta)

In [None]:
# (column, x-label template with the median, histogram range) for each panel
hist_specs = [
  ("eta_max", r"$\eta_{max}$, median= %.3f",          None),
  ("ta",      r"$\tau_a$, months, median= %.1f",      None),
  ("tl",      r"$\tau_\ell$, months, median= %.1f",   [0,250]),
  ("tb",      r"$\tau_b$, months, median= %.1f",      [0,500]),
  # mu values lie in [0.001, 1] (fit bounds), so show three decimals
  ("dila",    r"$\mu_a$, median= %.3f",               None),
  ("dilb",    r"$\mu_b$, median= %.3f",               None),
  ("err",     "err, median= %.4f",                    None),
  ("Q",       "Q, median= %.4f",                      None),
]
fig, axs = plt.subplots(2,4, figsize=(15, 10))
for ax, (column, xlabel, hist_range) in zip(axs.flatten(), hist_specs):
  ax.hist(df[column], bins=100, range=hist_range)
  ax.set_xlabel(xlabel % np.nanmedian(df[column]))
plt.show()

# FIGURE 4: Latitudinal cross sections of the parameters $\tau_a$, $\tau_\ell$, $\tau_b$ and their ratios

In [None]:
xi = np.arange(-100,100,2)

def geo_mean(x):
  '''Geometric mean exp(mean(log(x))); meaningful for positive values only.'''
  return np.exp(np.mean(np.log(x)))

def average(xdata,ydata,xi,width=4):
  '''Running geometric mean of ydata in bands of `width` centred on each value of xi.

  Points with x - width/2 < xdata < x + width/2 (strict inequalities) are
  averaged; a band containing no points yields NaN (without a warning).
  '''
  xdata = np.asarray(xdata)
  ydata = np.asarray(ydata)
  half = width/2  # true division: odd or fractional widths are honoured
  result = []
  for x in xi:
    in_band = (xdata > x-half) & (xdata < x+half)
    result.append(geo_mean(ydata[in_band]) if np.any(in_band) else np.nan)
  return np.array(result)

fig, axs = plt.subplots(5,1, figsize=(4, 9.0),sharex=True)
axs = axs.flatten()
lati = np.linspace(-80,80,160)
lat=df["lat"]
lng=df["lng"]

def scatter_and_line(ax, x,y, s=0.2, label=None, xi=None, width=10):
  '''Scatter (x, y) on `ax` and overlay the running geometric mean of y.

  xi: centres of the averaging bands (defaults to the global `lati`).
  width: band width passed to `average`, in the units of x.
  '''
  centres = lati if xi is None else xi
  ax.plot(centres, average(x, y, centres, width=width), label=label)
  ax.scatter(x,y, s=s)

winter = df[df["s"]==0]
summer = df[df["s"]==2]
# (rows of one season -> plotted quantity, y-axis label), one entry per panel a-e
panels = [
  (lambda d: d["ta"],          r"$\tau_a$  (months)"),
  (lambda d: d["tl"],          r"$\tau_\ell$  (months)"),
  (lambda d: d["tb"],          r"$\tau_b$  (months)"),
  (lambda d: d["tb"]/d["ta"],  r"$\tau_b/\tau_a$"),
  (lambda d: d["ta"]/d["tl"],  r"$\tau_a/\tau_\ell$"),
]
for ax, (quantity, ylabel) in zip(axs, panels):
  scatter_and_line(ax, winter["lat"], quantity(winter), s=0.2, label="Boreal winter")
  scatter_and_line(ax, summer["lat"], quantity(summer), s=0.2, label="Boreal summer")
  ax.set_ylabel(ylabel)
axs[2].set_ylim(10,3000)
axs[4].set_ylim(0.03,10)

axs[0].legend(loc = "upper left", bbox_to_anchor=(-0.03, 1.17, 1.0, 0.1), ncol=2)
axs[-1].set_xlabel("latitude (deg)")

for i,ax in enumerate(axs[0:5]):
  ax.set_yscale("log")
  ax.set_xticks([-75,-60,-45,-30,-15,0,15,30,45,60,75])
  ax.set_xticklabels(["‐75","‐60","‐45","‐30","‐15","0","15","30","45","60","75"])
  ax.yaxis.tick_right()
  ax.yaxis.set_label_position("right")
  ax.text(-0.07,0.85,chr( ord('a')+i ), fontsize=14, weight='bold', transform=ax.transAxes)
  ax.grid(which='major', axis='x', ls='dashed')

# set_yscale installs fresh default locators, so custom ticks must come after it.
axs[1].set_yticks([10,20,30,40,50,100], ["10","20","30","40","50","100"])

plt.subplots_adjust(hspace=0.1)
fig.set_dpi(200)
plt.show()

# FIGURE S6: Multi-panel figure with individual fits

In [None]:
# (region, season, polygon, name) for panels a-h, filled row by row
atlantic_sites = [
  (1, 0,   0, "Labrador sea"),          # a
  (1, 0,  42, "Norway"),                # b
  (1, 0,  63, "Newfoundland"),          # c
  (1, 0,  16, "North sea"),             # d
  (1, 0, 142, "Subtropical atlantic"),  # e
  (1, 0, 137, "Subtropical atlantic"),  # f
  (1, 0,  56, "Coast of Brazil"),       # g
  (1, 0, 129, "Equatorial Atlantic"),   # h
]
fig, axs = plt.subplots(4,2, figsize=(10, 12), sharex=True)
n_cols = axs.shape[1]
for i, (region, season, polygon, site) in enumerate(atlantic_sites):
  row, col = divmod(i, n_cols)
  panel = axs[row, col]
  plot_location(region, season, polygon, panel,
                no_left=col>0, no_right=col<(n_cols-1), label=site)
  panel.text(0.02,0.9,chr(ord('a')+i), fontsize=14, weight='bold', transform=panel.transAxes)

fig.tight_layout()

In [None]:
# (region, season, polygon, name) for panels i-p, filled row by row
pacific_southern_sites = [
  (0, 0, 155, "Bering sea"),
  (0, 0,   7, "Hawaii"),
  (0, 0, 126, "Subtropical pacific"),
  (0, 2, 126, "Subtropical pacific"),
  (0, 0, 199, "Pacific Equatorial"),
  (2, 0, 189, "Southern Ocean"),
  (2, 0,  13, "Kerguelen"),
  (3, 2,  33, "Ross sea"),
]
fig, axs = plt.subplots(4,2, figsize=(9, 12), sharex=True)
n_cols = axs.shape[1]
for i, (region, season, polygon, site) in enumerate(pacific_southern_sites):
  row, col = divmod(i, n_cols)
  panel = axs[row, col]
  plot_location(region, season, polygon, panel,
                no_left=col>0, no_right=col<(n_cols-1), label=site)
  # lettering continues after panels a-h of the Atlantic figure
  panel.text(0.02,0.9,chr(ord('i')+i), fontsize=14, weight='bold', transform=panel.transAxes)

fig.tight_layout()

# Plot the parameters on a map

In [None]:
def make_map_from_dataset(df):
  '''Paint per-polygon results back onto the POP grid.

  For every row of `df` (indexed by (r, s, p)), cells of polygon p in region r
  get mask * value for each listed column in season slot s; "rsp" encodes
  mask * (r*1E6 + s*1E5 + p). Cells outside all polygons stay NaN.

  Returns an xarray Dataset with dims (season, nlat, nlon), TLONG/TLAT
  coordinates and season labels January/April/July/October.
  '''
  names = ["eta_max", "ta", "tl", "tb", "dila", "dilb", "err", "mean60", "stddev60", "mean180", "stddev180", "Q"]
  shape = (4, *tlong.shape)
  fields = {name: np.full(shape, np.nan) for name in names + ["rsp"]}

  for (r,s,p), row in df.iterrows():
    mask = POLYGON_MASKS[r][p]
    inside = np.where(mask > 0)
    for name in names:
      fields[name][s][inside] = (mask * row[name])[inside]
    fields["rsp"][s][inside] = (mask * (r*1E6+s*1E5+p))[inside]

  whole_ds = xr.Dataset(
    data_vars={name: (["season", "nlat", "nlon"], values) for name, values in fields.items()},
    coords=dict(
        TLONG=(["nlat", "nlon"], tlong),
        TLAT=(["nlat", "nlon"], tlat),
    ),
  )
  whole_ds['season'] = ['January', 'April', 'July', 'October']
  return whole_ds
whole_ds = make_map_from_dataset(df)

In [None]:
def plot_mapped_data(whole_ds, name, lims, title="title", single_panel=False, cmap_label="rainbow"):
  '''Map variable `name` of `whole_ds` per season with polygon numbers and a colorbar.

  whole_ds: dataset with a season dimension and TLONG/TLAT coordinates
      (as built by make_map_from_dataset).
  name: data variable to plot; also used as the colorbar label.
  lims: (vmin, vmax) shared by every panel and the colorbar.
  title: drawn above the panels of the top row.
  single_panel: plot only the first season in one panel instead of a 2x2 grid.
  cmap_label: name of a registered matplotlib colormap.

  Returns (fig, axs), axs being the list of map axes.
  '''
  FONTSIZE = 13
  central_longitude = 260

  def modify(ax):
    # label the cells: polygon number at its mean position, shifted into the
    # frame of the rotated projection
    for (r,p), info in polygon_mask_map.items():
      lat,lng = info["mean_latlng"]
      lng = lng - central_longitude
      if lng < -180: lng += 360
      ax.text(lng,lat, f'{p}', rotation='horizontal',
              va='center', ha='center', fontsize=FONTSIZE-8)
    ax.set_extent([00, 360, -80, 80], crs=ccrs.PlateCarree())
    ax.stock_img()
    lon_formatter = LongitudeFormatter(zero_direction_label=False)
    lat_formatter = LatitudeFormatter()
    ax.xaxis.set_major_formatter(lon_formatter)
    ax.yaxis.set_major_formatter(lat_formatter)

  def add_colorbar(x0, y0, vmin, vmax, label, cmap_label="rainbow"):
    '''
    x0, y0: start location for the colorbar (figure coordinates)
    vmin, vmax: range of the colorbar
    label: label of the colorbar
    '''
    cax = fig.add_axes([x0, y0, 0.2, 0.03])  # [x0, y0, width, height]
    cmap = plt.colormaps[cmap_label]
    normalize = plt.Normalize(vmin=vmin, vmax=vmax)  # Normalize the color values
    sm = cm.ScalarMappable(cmap=cmap, norm=normalize)
    cbar = fig.colorbar(sm, cax=cax, shrink=0.9, label=label, orientation='horizontal')
    cbar.ax.tick_params(labelsize=FONTSIZE)

  seasons = 1 if single_panel else 4
  gridsize = 1 if single_panel else 2
  fig = plt.figure(figsize=(10*gridsize,8.7/2.0*gridsize))
  axs=[]
  for j in range(seasons):
        ax = fig.add_subplot(gridsize, gridsize, j+1, projection=ccrs.PlateCarree(central_longitude=central_longitude))
        ax.pcolormesh(whole_ds.TLONG, whole_ds.TLAT, whole_ds.isel(season=j)[name],
                          transform=ccrs.PlateCarree(), cmap=cmap_label, vmin=lims[0], vmax=lims[1])

        if j < 2:
          ax.set_title(title, loc='center', fontsize=FONTSIZE+2)
        if j%2 == 0:
            ax.set_yticks(np.arange(-60, 80, 30), crs=ccrs.PlateCarree())
            ax.set_yticklabels(ax.get_yticks(), fontsize=FONTSIZE)
        if j>=2:
          ax.set_xticks(np.arange(0, 360, 60), crs=ccrs.PlateCarree())
          ax.set_xticklabels(ax.get_xticks(), fontsize=FONTSIZE)

        if not single_panel:
          ax.text(0, 70, f'{whole_ds.season.values[j]}', rotation='horizontal', va='center', ha='center', fontsize=FONTSIZE)
        modify(ax)
        axs.append(ax)

  # Colorbar and spacing must be applied before returning.
  add_colorbar(0.15, 0.05, lims[0], lims[1], name, cmap_label)
  plt.subplots_adjust(wspace=0.01, hspace=0.01)
  return fig, axs



# FIGURE 1, Part 1: Mean $\eta$ and its standard deviation across seasons at 5 and 15 years

In [None]:
from matplotlib.image import imread
def plot_4panels(whole_ds, data_array, lims, cmap, datascale="linear", title="title", figscale=1.0):
  '''Draw up to four maps in a 2x2 grid with two horizontal colorbars.

  whole_ds: dataset supplying the TLONG/TLAT coordinates.
  data_array: sequence of (up to four) fields to map.
  lims: per-panel (vmin, vmax).
  cmap: per-panel colormaps.
  datascale: "log10" maps log10 of the data; anything else maps it as-is.
  title: per-panel titles, indexed by panel number.
  figscale: multiplier on the figure size.

  The colorbars are labelled Mean(eta) / Stddev(eta) and use lims[0]/cmap[0]
  and lims[1]/cmap[1] respectively.
  '''
  FONTSIZE = 13
  central_longitude = 260
  background = imread('./lightearth.jpg')  # read once and reused for every panel

  def modify(ax, numbers=False):
    # optionally label the cells with their polygon number
    if numbers:
      for (r,p), info in polygon_mask_map.items():
        lat,lng = info["mean_latlng"]
        lng = lng - central_longitude
        if lng < -180: lng += 360
        ax.text(lng,lat, f'{p}', rotation='horizontal',
                va='center', ha='center', fontsize=FONTSIZE-8)
    ax.set_extent([00, 360, -80, 80], crs=ccrs.PlateCarree())
    ax.imshow(background,
              origin='upper', transform=ccrs.PlateCarree(),
              extent=[-180, 180, -90, 90])
    lon_formatter = LongitudeFormatter(zero_direction_label=False)
    lat_formatter = LatitudeFormatter()
    ax.xaxis.set_major_formatter(lon_formatter)
    ax.yaxis.set_major_formatter(lat_formatter)

  gridsize = 2
  fig = plt.figure(figsize=(figscale*10*gridsize,figscale*8.7/2.0*gridsize))
  fig.set_dpi(200)

  for j in range(len(data_array)):
        ax = fig.add_subplot(gridsize, gridsize, j+1, projection=ccrs.PlateCarree(central_longitude=central_longitude))
        ax.pcolormesh(whole_ds.TLONG, whole_ds.TLAT,
                      np.log10(data_array[j]) if datascale == "log10" else data_array[j],
                      transform=ccrs.PlateCarree(), cmap=cmap[j], vmin=lims[j][0], vmax=lims[j][1])

        if j%2 == 0:
            ax.set_yticks(np.arange(-60, 80, 30), crs=ccrs.PlateCarree())
            ax.set_yticklabels(ax.get_yticks(), fontsize=FONTSIZE)
        if j>=2:
          ax.set_xticks(np.arange(0, 360, 60), crs=ccrs.PlateCarree())
          ax.set_xticklabels(ax.get_xticks(), fontsize=FONTSIZE)

        ax.text(-140, 70, title[j], rotation='horizontal', va='center', ha='left', fontsize=FONTSIZE)
        ax.text(0.02,0.9,['a','b','c','d'][j], fontsize=14, weight='bold', transform=ax.transAxes)
        modify(ax)

  def add_colorbar(x0, y0, vmin, vmax, label, cmap):
    '''
    x0, y0: start location for the colorbar (figure coordinates)
    vmin, vmax: range of the colorbar
    label: label of the colorbar
    '''
    cax = fig.add_axes([x0, y0, 0.32, 0.03])  # [x0, y0, width, height]
    normalize = plt.Normalize(vmin=vmin, vmax=vmax)  # Normalize the color values
    sm = cm.ScalarMappable(cmap=cmap, norm=normalize)
    cbar = fig.colorbar(sm, cax=cax, shrink=0.9, orientation='horizontal')
    cbar.set_label(label=label, size=FONTSIZE)
    cbar.ax.tick_params(labelsize=FONTSIZE)

  add_colorbar(0.16, 0.00, lims[0][0], lims[0][1], r"Mean($\eta$)", cmap[0])
  add_colorbar(0.55, 0.00, lims[1][0], lims[1][1], r"Stddev($\eta$)", cmap[1])
  plt.subplots_adjust(wspace=0.01, hspace=0.01)


In [None]:
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; use the colormap registry.
plot_4panels(whole_ds, [whole_ds.isel(season=0)["mean60"],
                        whole_ds.isel(season=0)["stddev60"],
                        whole_ds.isel(season=0)["mean180"],
                        whole_ds.isel(season=0)["stddev180"]],
                        [(0.3,0.9),(0,0.1),
                         (0.3,0.9),(0,0.1)],
             cmap = [plt.colormaps["rainbow"],
                     plt.colormaps["viridis"],
                     plt.colormaps["rainbow"],
                     plt.colormaps["viridis"]],
             title=[r"Mean($\eta$) at 5 yrs",
                    r"Stddev($\eta$) at 5 yrs",
                    r"Mean($\eta$) at 15 yrs",
                    r"Stddev($\eta$) at 15 yrs"],
             figscale=0.6)

# FIGURE 1, Part 2: Latitudinal cross sections of $\eta$

In [None]:
fig, axs = plt.subplots(1,3, figsize=(10, 2.5), sharey=True)
axs = axs.flatten()
lati = np.linspace(-80,80,160)
lat=df["lat"]
lng=df["lng"]

winter = df[df["s"]==0]
summer = df[df["s"]==2]
# (eta column, horizon in years, panel letter) for panels e-g
horizons = [("eta24", 2, 'e'), ("eta60", 5, 'f'), ("eta180", 15, 'g')]
for ax, (column, years, letter) in zip(axs, horizons):
  scatter_and_line(ax, winter["lat"], winter[column], s=0.2, label="Boreal winter")
  scatter_and_line(ax, summer["lat"], summer[column], s=0.2, label="Boreal summer")

axs[0].set_ylabel(r"$\eta(t)$")
axs[2].legend(bbox_to_anchor=(0.5,0.5))
for ax, (column, years, letter) in zip(axs, horizons):
  ax.set_xticks([-75,-60,-45,-30,-15,0,15,30,45,60,75])
  ax.set_xticklabels(["‐75","‐60","‐45","‐30","‐15","0","15","30","45","60","75"])
  ax.text(0.02,1.05,letter, fontsize=14, weight='bold', transform=ax.transAxes)
  ax.set_title("After %s years"%years, y=0.16, pad=-14)
  ax.grid()
  ax.set_xlabel("lat")
  ax.set_ylim(0.0,0.9)

fig.tight_layout()
plt.subplots_adjust(wspace=0, hspace=0)
fig.set_dpi(200)
plt.show()

In [None]:
fig, axs = plt.subplots(1,2, figsize=(8, 3.0),sharey=True)
axs = axs.flatten()
lati = np.linspace(-80,80,160)
lat=df["lat"]
lng=df["lng"]

# Left panel: boreal winter (s=0); right panel: boreal summer (s=2).
# Each panel shows eta after 2, 5 and 15 years.
horizons = [("eta24", 2), ("eta60", 5), ("eta180", 15)]  # (column, years)
for ax, season in zip(axs, [0, 2]):
  subset = df[df["s"]==season]
  for column, years in horizons:
    scatter_and_line(ax, subset["lat"], subset[column], s=0.2, label=r"$\eta$ after %dyrs" % years)

# Previously both labels were set on axs[0], so only the second one survived.
axs[0].set_ylabel(r"$\eta(t)$, boreal winter")
axs[1].set_ylabel(r"$\eta(t)$, boreal summer")

axs[1].yaxis.tick_right()
axs[1].yaxis.set_label_position("right")

axs[1].legend(loc = "lower right")
axs[0].set_xlabel("lat")
axs[1].set_xlabel("lat")
for ax in axs:
  ax.set_xticks([-75,-60,-45,-30,-15,0,15,30,45,60,75])
  ax.set_xticklabels(["‐75","‐60","‐45","‐30","‐15","0","15","30","45","60","75"])

axs[0].text(0.02,0.9,'a', fontsize=14, weight='bold', transform=axs[0].transAxes)
axs[1].text(0.02,0.9,'b', fontsize=14, weight='bold', transform=axs[1].transAxes)
axs[0].grid()
axs[1].grid()
fig.tight_layout()

plt.show()

#Do a seperate plot just for coastal locations ?

In [None]:
def plot_4panels_logscale(whole_ds, data_array, lims, cmap, title="title", figscale=1.0):
  """Plot up to four global maps (2x2 grid) of log10(data) with one shared colorbar.

  Parameters
  ----------
  whole_ds : object exposing ``TLONG`` and ``TLAT``, used as the pcolormesh grid
      for every panel.
  data_array : sequence of at most 4 fields on that grid; panel ``j`` shows
      ``np.log10(data_array[j])``.
  lims : sequence of ``(vmin, vmax)`` pairs in linear units, one per panel; the
      colour limits are ``log10(vmin)`` and ``log10(vmax)``.
  cmap : sequence of colormaps, one per panel.
  title : sequence of per-panel titles, or a single string reused for every panel.
  figscale : multiplier applied to the figure size.

  Only ``lims[0]`` and ``cmap[0]`` drive the single horizontal colorbar (labelled
  "time (months)"), so all panels are expected to share limits and colormap.
  Nothing is returned: returning the figure would make Jupyter display it twice.
  """
  FONTSIZE = 13
  central_longitude = 260
  gridsize = 2

  # A bare string would otherwise be indexed character by character ("t", "i", ...).
  if isinstance(title, str):
    title = [title] * len(data_array)

  # The background is identical for every panel, so read it from disk only once.
  # plt.imread is used because the notebook's import cell imports no bare `imread`.
  background = plt.imread('./lightearth.jpg')

  def modify(ax, numbers=False):
    """Set map extent, draw the background image and format lat/lon ticks on ``ax``.

    With ``numbers=True`` also writes each polygon id at its mean position; this
    relies on a notebook-level ``polygon_mask_map`` defined elsewhere.
    """
    if numbers:
      for (r, p), poly in polygon_mask_map.items():
        lat, lng = poly["mean_latlng"]
        # text uses the axes' native coordinates, which are shifted by central_longitude
        lng = lng - central_longitude
        if lng < -180: lng += 360
        ax.text(lng, lat, f'{p}', rotation='horizontal',
                va='center', ha='center', fontsize=FONTSIZE-8)
    ax.set_extent([0, 360, -80, 80], crs=ccrs.PlateCarree())
    ax.imshow(background,
              origin='upper', transform=ccrs.PlateCarree(),
              extent=[-180, 180, -90, 90])
    ax.xaxis.set_major_formatter(LongitudeFormatter(zero_direction_label=False))
    ax.yaxis.set_major_formatter(LatitudeFormatter())

  fig = plt.figure(figsize=(figscale*10*gridsize, figscale*8.7/2.0*gridsize))
  fig.set_dpi(200)

  for j, data in enumerate(data_array):
    ax = fig.add_subplot(gridsize, gridsize, j+1,
                         projection=ccrs.PlateCarree(central_longitude=central_longitude))
    ax.pcolormesh(whole_ds.TLONG, whole_ds.TLAT, np.log10(data),
                  transform=ccrs.PlateCarree(), cmap=cmap[j],
                  vmin=np.log10(lims[j][0]), vmax=np.log10(lims[j][1]))

    # latitude ticks only on the left column, longitude ticks only on the bottom row
    if j % 2 == 0:
      ax.set_yticks(np.arange(-60, 80, 30), crs=ccrs.PlateCarree())
      ax.set_yticklabels(ax.get_yticks(), fontsize=FONTSIZE)
    if j >= 2:
      ax.set_xticks(np.arange(0, 360, 60), crs=ccrs.PlateCarree())
      ax.set_xticklabels(ax.get_xticks(), fontsize=FONTSIZE)

    ax.text(-140, 70, title[j], rotation='horizontal', va='center', ha='left', fontsize=FONTSIZE)
    ax.text(0.02, 0.9, 'abcd'[j], fontsize=14, weight='bold', transform=ax.transAxes)
    modify(ax)

  def add_colorbar(x0, y0, vmin, vmax, label, cmap):
    """Add a horizontal log10-scaled colorbar to ``fig``.

    x0, y0: lower-left corner of the colorbar axes, in figure fractions
    vmin, vmax: colorbar range in linear units (log10 is applied here)
    label: text shown under the colorbar
    cmap: colormap to display
    """
    cax = fig.add_axes([x0, y0, 0.62, 0.03])  # [x0, y0, width, height]
    normalize = plt.Normalize(vmin=np.log10(vmin), vmax=np.log10(vmax))
    sm = cm.ScalarMappable(cmap=cmap, norm=normalize)
    cbar = fig.colorbar(sm, cax=cax, shrink=0.9, orientation='horizontal')
    cbar.set_label(label=label, size=FONTSIZE)
    cbar.ax.tick_params(labelsize=FONTSIZE)

    # ticks sit at log10 positions but are labelled with the linear values in range
    ticks = [0.1, 0.3, 1, 3, 10, 30, 100, 300, 1000, 3000, 10000]
    visible = [t for t in ticks if vmin <= t <= vmax]
    cbar.ax.set_xticks([np.log10(t) for t in visible])
    cbar.ax.set_xticklabels(visible)

  add_colorbar(0.16, 0.00, lims[0][0], lims[0][1], "time (months)", cmap[0])
  fig.subplots_adjust(wspace=0.01, hspace=0.01)

In [None]:
# tau_a for the four release seasons (season index 0..3 = January..October).
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
plot_4panels_logscale(whole_ds, [whole_ds.isel(season=i)["ta"] for i in range(0,4)],
                      [(3,100)]*4,
                      cmap=[plt.get_cmap('rainbow')]*4,
                      title=[r"$\tau_a$ January Release",
                             r"$\tau_a$ April Release",
                             r"$\tau_a$ July Release",
                             r"$\tau_a$ October Release"],
                      figscale=0.6)

In [None]:
# tau_l for the four release seasons. Raw strings avoid the invalid "\e" escape,
# and plt.get_cmap replaces matplotlib.cm.get_cmap (removed in Matplotlib 3.9).
plot_4panels_logscale(whole_ds, [whole_ds.isel(season=i)["tl"] for i in range(0,4)],
                      [(3,301)]*4,
                      cmap=[plt.get_cmap('rainbow')]*4,
                      title=[r"$\tau_\ell$ January Release",
                             r"$\tau_\ell$ April Release",
                             r"$\tau_\ell$ July Release",
                             r"$\tau_\ell$ October Release"],
                      figscale=0.6)

In [None]:
# tau_b for the four release seasons.
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
plot_4panels_logscale(whole_ds, [whole_ds.isel(season=i)["tb"] for i in range(0,4)],
                      [(3,10001)]*4,
                      cmap=[plt.get_cmap('rainbow')]*4,
                      title=[r"$\tau_b$ January Release",
                             r"$\tau_b$ April Release",
                             r"$\tau_b$ July Release",
                             r"$\tau_b$ October Release"],
                      figscale=0.6)

In [None]:
# eta_max for the four release seasons. Raw strings avoid the invalid "\e" escape,
# and plt.get_cmap replaces matplotlib.cm.get_cmap (removed in Matplotlib 3.9).
plot_4panels(whole_ds, [whole_ds.isel(season=i)["eta_max"] for i in range(0,4)],
             [(0.70,1.0)]*4,
             cmap=[plt.get_cmap('rainbow')]*4,
             title=[r"$\eta_{max}$ January Release",
                    r"$\eta_{max}$ April Release",
                    r"$\eta_{max}$ July Release",
                    r"$\eta_{max}$ October Release"],
             figscale=0.6)

In [None]:
# Residual fitting error for the four release seasons.
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
plot_4panels(whole_ds, [whole_ds.isel(season=i)["err"] for i in range(0,4)],
             [(0.0,0.1)]*4,
             cmap=[plt.get_cmap('rainbow')]*4,
             title=["Residual error, January",
                    "Residual error, April",
                    "Residual error, July",
                    "Residual error, October"],
             figscale=0.6)


# Examples of complex alkalinity dynamics beyond this model

These locations have a noticeable residual fitting error (see the residual-error map above).

In [None]:
# Twelve example locations (panels q..ab), one per subplot in row-major order.
# Each entry holds the first three positional arguments of plot_location
# (presumably region, season and polygon indices; TODO confirm against its
# definition) followed by the panel letter. The letters are spelled out because
# chr(ord('q') + i) yields '{' and '|' instead of "aa" and "ab" for the last two.
example_locations = [
  (0, 0, 126, "q"),
  (2, 1, 225, "r"),
  (2, 0, 225, "s"),
  (0, 2,  34, "t"),
  (0, 0, 178, "u"),
  (0, 3, 156, "v"),
  (0, 3,  83, "w"),
  (0, 2,  83, "x"),
  (0, 1,  83, "y"),
  (0, 2, 161, "z"),
  (1, 2, 128, "aa"),
  (2, 0, 200, "ab"),
]

fig, axs = plt.subplots(6,2, figsize=(9, 18), sharex=True)
n_cols = axs.shape[1]
assert len(example_locations) == axs.size, "need exactly one example location per subplot"
for i, (*loc_args, letter) in enumerate(example_locations):
  y, x = divmod(i, n_cols)
  ax = axs[y, x]
  # no_left: panel is not in the left column; no_right: panel is not in the right column
  plot_location(*loc_args, ax, no_left=x > 0, no_right=x < n_cols - 1, label="")
  ax.text(0.02, 0.9, letter, fontsize=14, weight='bold', transform=ax.transAxes)

fig.tight_layout()