In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import BoundaryNorm, ListedColormap
import matplotlib.cm as cm
from matplotlib import gridspec

import datetime
from datetime import datetime, timedelta
import os
import random
import seaborn as sns


%matplotlib inline

from pprint import pprint

import cartopy
import cartopy.crs as ccrs
import cartopy.io.img_tiles as cimgt
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
from cartopy.feature import NaturalEarthFeature
import cartopy.feature as cfeature


import cmocean

from mpl_toolkits.mplot3d import Axes3D

from timeit import default_timer as timer
from sklearn.cluster import KMeans
from scipy.spatial import ConvexHull, Delaunay
from scipy.interpolate import interp2d
from scipy import spatial
from glob import glob

import gsw


# SHB map

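Obtain bathymetric data from a public global relief model such as NOAA's ETOPO. The cell below reads the file `ETOPO_2022_v1_60s_N90W180_bed.nc` from `./data/`. The older "ETOPO1 Global Relief Model" dataset can be downloaded from https://www.ngdc.noaa.gov/mgg/global/; if you use it instead, change the file name in the code below accordingly.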

In [None]:
# Global relief data read from the local data directory
# (the file name indicates ETOPO 2022, 60 arc-second, bedrock surface).
da_map = xr.open_dataset('./data/ETOPO_2022_v1_60s_N90W180_bed.nc')

# Bounding box of the region of interest (degrees)
lon_min, lon_max = 14, 20
lat_min, lat_max = -35, -29

# Label-based subset; slicing by value assumes both coordinates are stored in ascending order
map_bounds = {'lat': slice(lat_min, lat_max), 'lon': slice(lon_min, lon_max)}
da_sub = da_map.sel(**map_bounds)

In [None]:
# Model grid from Fearon et al 2023, plus the pickled SHBML station table
da = xr.open_dataset('./data/grid.nc')
earlySHBML = pd.read_pickle('./data/earlySHBML.pkl')

# 2-D coordinate arrays of the rho points
tlong, tlat = da.lon_rho.values, da.lat_rho.values

# rho mask at the first time index, and a 0/1 flag for cells with h <= 100
mask = da.mask_rho.isel(time=0).values
inshore_mask = xr.where(da.isel(time=0).h <= 100, 1, 0)  # inshore of 100m
mask_sel = mask * inshore_mask

# coordinates of the cells where both masks equal 1
ocean_indices = np.where(mask_sel == 1)
tlong_sel, tlat_sel = tlong[ocean_indices], tlat[ocean_indices]

In [None]:
def plot_map(ax):
    '''Draw the SHB overview map on an existing cartopy axes.

    Layers, in drawing order: land, filled bathymetry from ``da_sub.z``, the
    dashed outline of the model grid, the selected inshore grid points, the
    SHBML stations, the 20 m and 70 m moorings, and labelled contours of the
    model depth ``h``.

    Parameters
    ----------
    ax : cartopy GeoAxes
        Axes to draw on. It is modified in place; nothing is returned.

    Notes
    -----
    Reads the notebook-level names ``da_sub``, ``da``, ``tlong``, ``tlat``,
    ``tlong_sel``, ``tlat_sel`` and ``earlySHBML``, so it must be called while
    those still hold the model-grid values.
    '''

    ax.set_extent([17, 18.5, -33.5, -31], crs=ccrs.PlateCarree())

    ax.set_xticks(np.arange(16.1, 19.5, 0.7), crs=ccrs.PlateCarree())
    ax.set_yticks(np.arange(-33.5, -30.5, 0.5), crs=ccrs.PlateCarree())
    lon_formatter = LongitudeFormatter(zero_direction_label=False)
    lat_formatter = LatitudeFormatter()
    ax.xaxis.set_major_formatter(lon_formatter)
    ax.yaxis.set_major_formatter(lat_formatter)

    ax.add_feature(cfeature.LAND)

    # water depth
    da_sub.z.plot.contourf(ax=ax, levels=20, transform=ccrs.PlateCarree(), add_colorbar=False, alpha=0.5)

    # edges: the four corners of the model grid, closed back onto the first one
    corners_long = [tlong[0, 0], tlong[0, -1], tlong[-1, -1], tlong[-1, 0], tlong[0, 0]]
    corners_lat = [tlat[0, 0], tlat[0, -1], tlat[-1, -1], tlat[-1, 0], tlat[0, 0]]

    # Plot the corners and connect them (no transform is passed, so the
    # coordinates are interpreted in the projection of ``ax``)
    ax.plot(corners_long, corners_lat, '--', color='C3', alpha=0.7, )

    ax.scatter(tlong_sel, tlat_sel, c='gray', s=1, alpha=0.1, transform=ccrs.PlateCarree())

    # drop the stations that lie both south of 32.6S and east of 17.5E
    keep = ~( (earlySHBML.Latitude < -32.6) & (earlySHBML.Longitude > 17.5) )
    ax.scatter(earlySHBML.Longitude[keep], earlySHBML.Latitude[keep], marker='o', s=20, color= 'k', transform=ccrs.PlateCarree(), label='SHBML')

    # coordinate of 20 m mooring, as [lon, lat]
    moor20m = [18.318, -32.292]
    ax.scatter(moor20m[0], moor20m[1], marker='^', color= 'r', s=80, transform=ccrs.PlateCarree(), label='20m mooring')

    # coordinate of 70 m mooring, as [lon, lat]
    moor70m = [18.183, -32.329]
    ax.scatter(moor70m[0], moor70m[1], marker='*', color= 'r', s=80, transform=ccrs.PlateCarree(), label='70m mooring')

    # contour of the model depth; labels are attached through the axes itself
    # rather than through the pyplot "current axes"
    contours = ax.contour(da.lon_rho, da.lat_rho, da.isel(time=0).h,
                          levels=5, colors='black', linewidths=1.5, alpha=0.8)
    ax.clabel(contours, inline=True, fontsize=9)

    ax.legend(loc='lower left', fontsize=10)

    ax.text(18.35, -31.8, 'South \nAfrica')
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')

plt.rcParams['font.size'] = 14

# Stand-alone version of the map on a single PlateCarree axes
fig, ax = plt.subplots(figsize=(5, 6), subplot_kw={'projection': ccrs.PlateCarree()})
plot_map(ax)

In [None]:
# Find the water depth of 20m and 70m moorings

def find_point(lat_1 = 73.297, lon_1 = -145.456):
    '''Find the model grid point closest to a target position.

    The distance from every rho point of the notebook-level grid ``da`` to the
    target is computed with ``gsw.distance`` and the smallest one is selected.
    The indices and coordinates of the match are printed.

    Parameters
    ----------
    lat_1 : float
        Target latitude in degrees north.
    lon_1 : float
        Target longitude in degrees east.

    Returns
    -------
    i, j : int
        Indices of the closest point along the first and second axis of
        ``da.lat_rho``.

    Notes
    -----
    NOTE(review): the default position lies far outside the region mapped in
    this notebook, so callers should pass both arguments explicitly.
    '''

    lat_grid = da.lat_rho.values
    lon_grid = da.lon_rho.values

    # Flatten the 2D lat/lon grids
    lat_flat = lat_grid.ravel()
    lon_flat = lon_grid.ravel()

    dist = np.empty(shape=lat_flat.shape)

    for p in range(len(dist)):
        # gsw.distance returns an array holding the single separation between
        # the two points; take the element explicitly, because assigning a
        # size-1 array to a scalar slot is deprecated since NumPy 1.25
        separation = gsw.distance([lon_flat[p], lon_1], [lat_flat[p], lat_1])
        dist[p] = np.ravel(separation)[0]

    # Find index of minimum distance
    idx_flat = np.nanargmin(dist)

    # Convert back to 2D indices
    i, j = np.unravel_index(idx_flat, lat_grid.shape)

    print(f"Closest point index: nlat={i}, nlon={j}")
    print(f"Coordinates: lat={lat_grid[i,j]}, lon={lon_grid[i,j]}")

    return i, j

In [None]:
# Mooring positions as [longitude, latitude]
moor20m = [18.318, -32.292]
moor70m = [18.183, -32.329]

# Nearest model grid indices for each mooring
lon_20, lat_20 = moor20m
lon_70, lat_70 = moor70m
i_20, j_20 = find_point(lat_1=lat_20, lon_1=lon_20)
i_70, j_70 = find_point(lat_1=lat_70, lon_1=lon_70)

# Dissolved oxygen climatology

Pre-processing climatology data

In [None]:
# The SHBML climatology data (de Villiers, 2017) are available from: https://doi.pangaea.de/10.1594/PANGAEA.882218
fpath = './data/Climatology_SBUS.csv'
df = pd.read_csv(fpath, encoding='latin1')

# names of all the columns in the file
l_vars = df.columns.tolist()

# name fragments used to pick the columns of each variable and of each month
var = ['Temp', 'Sal', 'O2', '[PO4]3-', '[NO3]- + [NO2]-', 'Si(OH)4', 'Chl a']
months = "JAN FEB MAR APR MAY JUN JUL AUG SEP OCT NOV DEC".split()

In [None]:
def get_ds_stn(stn):
    '''Build a (station, depth, month) dataset with every variable in ``var`` for one station.

    Rows of the notebook-level ``df`` whose 'Station' equals ``stn`` are used.
    For each variable and month the first column of ``l_vars`` whose name
    contains both the variable fragment and the month tag is taken.

    NOTE(review): the match is by substring, and 'O2' is also a substring of
    '[NO3]- + [NO2]-', so the column picked for 'O2' depends on column order.
    '''
    rows = df[df['Station'] == stn]
    depth = rows['Depth water [m]'].values  # depth levels of this station
    n_levels = np.size(depth)

    def _monthly_table(name):
        # one column per month, filled from the matching source column
        table = np.empty((n_levels, 12))
        for k in range(12):
            matches = [col for col in l_vars if (name in col) and (months[k] in col)]
            table[:, k] = rows[matches[0]].values
        return table

    # assemble the dataset, then add a length-1 'station' dimension
    data_vars = {f'{name}': (['depth', 'month'], _monthly_table(name)) for name in var}
    month_numbers = np.arange(1, 13, 1)
    ds_ = xr.Dataset(data_vars, coords={'month': month_numbers, 'depth': depth})

    return ds_.expand_dims({'station': [stn]})

In [None]:
# Stitch all stations together along a new 'station' dimension.
stations = df['Station'].unique()

# Build every per-station dataset first and concatenate once. The previous
# try / bare-except bootstrap (concatenating onto None) would also have
# swallowed a genuine concat error and silently restarted `ds` from the
# current station, dropping all earlier ones.
station_datasets = []
for stn in stations:
    stn = int(stn)  # station ids are passed on as plain ints
    station_datasets.append(get_ds_stn(stn))
ds = xr.concat(station_datasets, 'station')

# Short variable names used in the rest of the notebook
ds = ds.rename({'Temp': 'temp',
                'Sal': 'sal',
                'O2': 'oxy',
                '[PO4]3-': 'phos',
                '[NO3]- + [NO2]-': 'nitrate',
                'Si(OH)4': 'silicate',
                'Chl a': 'chl',
               })

# add units for each variable
units_dict = {'temp': 'degC',
              'depth': 'm',
              'sal': 'psu',
              'oxy': 'µmol/l',
              'phos': 'µmol/l',
              'nitrate': 'µmol/l',
              'silicate': 'µmol/l',
              'chl': 'µg/l'
             }

# Loop over variable names and set units
for var_name, units in units_dict.items():
    ds[var_name].attrs['units'] = units

## Check some plots

In [None]:
# coords for stations: one (lat, lon) pair per station, taken from the first
# row of each station. Calling .unique() on the two columns separately would
# misalign them as soon as two stations share a latitude or a longitude.
stn_coords = df.drop_duplicates(subset='Station')
lat_ = stn_coords['Latitude'].to_numpy()
lon_ = stn_coords['Longitude'].to_numpy()

# distance from station 1: cumulative along-track distance between
# consecutive stations, with a leading 0 for the first station
dist = np.cumsum(gsw.distance(lon_, lat_))/1000 # km
dist = np.insert(dist, 0, 0)

In [None]:
def detect_sharp_gradients(arr, threshold):
    '''Blank out a profile from its first sharp jump onwards.

    The first position where the absolute difference between consecutive
    elements exceeds ``threshold`` is located; that element and everything
    after it are set to NaN. ``arr`` is modified in place and returned. When
    no difference exceeds the threshold, ``arr`` is returned untouched.
    '''
    # absolute difference between neighbouring elements
    steps = np.abs(np.diff(arr))

    # positions whose step to the next element is too large
    jump_positions = np.nonzero(steps > threshold)[0]

    if jump_positions.size != 0:
        first_jump = jump_positions[0]
        arr[first_jump:] = np.nan

    return arr


## Interpolate into a transect with profiles every 1 km

In [None]:
# Horizontal interpolation of the oxygen climatology onto a 1 km grid.
#
# scipy.interpolate.interp2d raises on SciPy >= 1.14. It was only ever
# evaluated on the original depth levels, so the bilinear interpolation
# reduces to a 1-D linear interpolation along the transect for every depth
# level. np.interp holds the end values beyond the first/last station, which
# matches the nearest-neighbour extrapolation interp2d applied by default.
x_original = dist              # station positions along the transect (km)
y = ds.depth.values            # depth levels (unchanged by the interpolation)
x_new = np.arange(0, 150)      # New x-axis: one profile every 1 km

A_new_nan = np.empty((y.size, x_new.size, 12))

for m in range(12):

    # Original data for this month; assumes ds.oxy is ordered
    # (station, depth, month) -- TODO confirm
    A_ = ds.oxy.values[:,:,m]
    # missing values become -999 so that the sharp jump they create can be
    # detected (and blanked) after the interpolation
    A = np.nan_to_num(A_, nan=-999)

    # Interpolate A to the new grid, one depth level at a time
    A_new = np.empty((y.size, x_new.size))
    for k in range(y.size):
        A_new[k, :] = np.interp(x_new, x_original, A[:, k])

    print(m, A_new.shape)

    # blank each interpolated profile from its first sharp vertical jump down
    for i in range(A_new.shape[1]):

        A_new_nan[:, i, m] = detect_sharp_gradients(A_new[:, i], 20)

In [None]:
# Wrap the interpolated oxygen array in a dataset, labelled by depth, distance index and month
DO_transect_interpolated = xr.Dataset(
    data_vars={'oxy': (['depth', 'dist2shore', 'month'], A_new_nan)},
    coords={
        'dist2shore': np.arange(0, 150),
        'depth': ds.depth.values,
        'month': ds.month.values,
    },
)

In [None]:
# Side-by-side check for the first month: station data (left) vs. 1 km transect (right)
fig_check, (ax_orig, ax_interp) = plt.subplots(1, 2, figsize=(12, 6))
do_levels = np.arange(14,302, 20)

# --- left panel: original station profiles ---
filled = ax_orig.contourf(dist, ds.depth, ds.oxy.isel(month=0).T, levels=do_levels)
ax_orig.set_ylim(366,0)
fig_check.colorbar(filled, ax=ax_orig)
ax_orig.set_xlabel('Distance from shore (km)')
ax_orig.set_ylabel('Depth (m)')
ax_orig.set_title('Original')

# grey dots mark the depths that hold a value at each station
for stn in range(ds.station.size):
    has_value = ~np.isnan(ds.oxy.isel(month=0, station=stn).values)
    ax_orig.scatter(np.full(has_value.sum(), dist[stn]), ds.depth.values[has_value], c='gray', s=2)

# --- right panel: interpolated transect ---
transect = DO_transect_interpolated
filled = ax_interp.contourf(transect.dist2shore, transect.depth, transect.oxy.isel(month=0), levels=do_levels)
ax_interp.set_ylim(366,0)
fig_check.colorbar(filled, ax=ax_interp)
ax_interp.set_xlabel('Distance from shore (km)')
ax_interp.set_ylabel('Depth (m)')
ax_interp.set_title('Interpolated for every 1 km')

# same markers, drawn for every second profile only
for stn in range(0, transect.dist2shore.size, 2):
    has_value = ~np.isnan(transect.oxy.isel(month=0, dist2shore=stn).values)
    position = transect.dist2shore.values[stn]
    ax_interp.scatter(np.full(has_value.sum(), position), transect.depth.values[has_value], c='gray', s=1, alpha=0.1)

## Interpolate coordinates

In [None]:
# Station coordinates: one (lon, lat) pair per station, taken from the first
# row of each station. Separate .unique() calls on the two columns would
# misalign them if two stations shared a longitude or a latitude.
stn_positions = df.drop_duplicates(subset='Station')
longitudes = stn_positions['Longitude'].to_numpy()
latitudes = stn_positions['Latitude'].to_numpy()

# Create an array with 150 elements for the new grid, expressed as a
# fractional station index running from the first to the last station
new_x = np.linspace(0, len(longitudes) - 1, 150)
station_index = np.arange(len(longitudes))

# NOTE(review): the coordinates are interpolated evenly in station index,
# whereas the oxygen transect is interpolated at 1 km steps along `dist`.
# The two only line up if the stations are evenly spaced -- confirm.
interpolated_longitudes = np.interp(new_x, station_index, longitudes)
interpolated_latitudes = np.interp(new_x, station_index, latitudes)

Check they match

In [None]:
# Overlay the interpolated coordinates (solid) on the station values (dashed).
# The station x-positions are derived from the data rather than from a
# hard-coded station count.
fractional_index = np.linspace(0, len(longitudes) - 1, 150)
station_index = np.arange(len(longitudes))

plt.plot(fractional_index, interpolated_longitudes)
plt.plot(station_index, longitudes, '--')

plt.figure()

plt.plot(fractional_index, interpolated_latitudes)
plt.plot(station_index, latitudes, '--')

Read model bathymetry

In [None]:
# Model grid file providing the bathymetry
da_giles = xr.open_mfdataset('./data/croco_grd.nc')

# coordinates of the rho points and the depth h, as plain arrays
tlong_bth, tlat_bth = da_giles.lon_rho.values, da_giles.lat_rho.values
h_bth = da_giles.h.values

In [None]:
# Target points: the 150 interpolated transect positions
target_lons = interpolated_longitudes
target_lats = interpolated_latitudes

# For every target, the 2-D index of the grid point with the smallest squared
# coordinate difference (squared, so no square root is needed to find the minimum)
closest_indices = [
    np.unravel_index(
        np.argmin((tlong_bth - lon_point)**2 + (tlat_bth - lat_point)**2),
        tlong_bth.shape,
    )
    for lon_point, lat_point in zip(target_lons, target_lats)
]

In [None]:
# Model depth at the grid point matched to each transect position
h_sel = [h_bth[grid_idx] for grid_idx in closest_indices]

In [None]:
# Quick look at the model depth extracted along the 150 transect positions
plt.plot(np.arange(150), h_sel)

In [None]:
# Map of the matched grid points (black) against the station positions (red)
fig, ax = plt.subplots(figsize=(5, 6), subplot_kw={'projection': ccrs.PlateCarree()})
ax.set_extent([17, 18.5, -33.5, -31], crs=ccrs.PlateCarree())

ax.set_xticks(np.arange(16.3, 19, 0.5), crs=ccrs.PlateCarree())
ax.set_yticks(np.arange(-33.5, -30.5, 0.5), crs=ccrs.PlateCarree())
lon_formatter = LongitudeFormatter(zero_direction_label=False)
lat_formatter = LatitudeFormatter()
ax.xaxis.set_major_formatter(lon_formatter)
ax.yaxis.set_major_formatter(lat_formatter)

ax.add_feature(cfeature.LAND)

# edges: corners of the bathymetry grid, closed back onto the first corner
corner_idx = [(0, 0), (0, -1), (-1, -1), (-1, 0), (0, 0)]
corners_long = [tlong_bth[row, col] for row, col in corner_idx]
corners_lat = [tlat_bth[row, col] for row, col in corner_idx]
ax.plot(corners_long, corners_lat, '--', color='C3', alpha=0.7, )

# one marker per matched grid point, then the stations on top
for grid_idx in closest_indices:
    ax.scatter(tlong_bth[grid_idx], tlat_bth[grid_idx], s=5, c='k', transform=ccrs.PlateCarree())
ax.scatter(longitudes, latitudes, s=10, c='r', transform=ccrs.PlateCarree())

ax.set_xlabel('Longitude', fontsize=14)
ax.set_ylabel('Latitude', fontsize=14)

## Now that we have the bathymetry h, modify DO_transect_interpolated to match it

In [None]:
# Display the interpolated dataset before it is adjusted to the bathymetry
DO_transect_interpolated

In [None]:
# Depth levels of the transect; the first one anchors the depth bookkeeping in the next cell
arr_depth = DO_transect_interpolated.depth.values
print('first depth is ', arr_depth[0], ' m')

# Work on a copy so the interpolated dataset itself stays unchanged
new_oxy = np.copy(DO_transect_interpolated.oxy.values)
num_d, num_i, num_t = new_oxy.shape
new_oxy.shape

In [None]:
# Make every profile end at the model depth of its position: profiles that
# stop short are extended with their last value, profiles that run deeper are
# trimmed with NaN.
# NOTE(review): the index arithmetic treats consecutive depth levels as 1
# apart and the non-NaN values as one block starting at the top -- confirm.
top_depth = arr_depth[0]

for t in range(num_t):
    for i in range(num_i):
        DO_sel = new_oxy[:, i, t]

        # number of levels holding a value, and the depth of the last of them
        count_non_nan = np.sum(~np.isnan(DO_sel))
        depth_with_value = top_depth + count_non_nan - 1

        # distance between the model depth and the deepest value:
        # >= 0 means the profile has to be extended, < 0 means it overshoots
        diff_d = int(h_sel[i]) - depth_with_value

        if diff_d >= 0:
            # repeat the deepest value down to the model depth
            new_oxy[count_non_nan : count_non_nan+diff_d, i, t] = DO_sel[count_non_nan-1]
        else:
            # blank the levels below the model depth
            new_oxy[count_non_nan + diff_d : count_non_nan, i, t] = np.nan


In [None]:
# Dataset holding the bathymetry-adjusted oxygen, on the same coordinates as before
DO_transect_interpolated_final = xr.Dataset(
    data_vars={'oxy': (['depth', 'dist2shore', 'month'], new_oxy)},
    coords={
        'dist2shore': np.arange(0, 150),
        'depth': ds.depth.values,
        'month': ds.month.values,
    },
)

In [None]:
# First month, before and after fitting the profiles to the bathymetry
# (adjusted transect on the left, horizontally interpolated one on the right)
fig_cmp, axes_cmp = plt.subplots(1, 2, figsize=(12, 6))
panels = [
    (DO_transect_interpolated_final, 'Interpolated horizontal and vertical to h'),
    (DO_transect_interpolated, 'Interpolated for every 1 km'),
]

for panel_ax, (transect, panel_title) in zip(axes_cmp, panels):
    filled = panel_ax.contourf(transect.dist2shore, transect.depth, transect.oxy.isel(month=0), levels=np.arange(14,302, 20))
    panel_ax.set_ylim(366,0)
    fig_cmp.colorbar(filled, ax=panel_ax)
    panel_ax.set_xlabel('Distance from shore (km)')
    panel_ax.set_ylabel('Depth (m)')
    panel_ax.set_title(panel_title)

    # model depth along the transect
    panel_ax.plot(np.arange(150), h_sel, 'r')

    # grey dots mark the levels holding a value, for every second profile
    for stn in range(0, transect.dist2shore.size, 2):
        has_value = ~np.isnan(transect.oxy.isel(month=0, dist2shore=stn).values)
        position = transect.dist2shore.values[stn]
        panel_ax.scatter(np.full(has_value.sum(), position), transect.depth.values[has_value], c='gray', s=1, alpha=0.1)

# Figure 1

In [None]:
# Month numbers in plotting order (September first), and the same as tick labels
xticks = np.r_[9:13, 1:9]
string_numbers = list(map(str, xticks))

def plot_DO_20m(ax):
    '''Monthly box-and-whisker plot of dissolved oxygen from the 20 m mooring.

    Reads './data/WQM20m_daily.pkl', groups the 'Oxygen' column by calendar
    month and draws one box per month on ``ax``, ordered as in the
    notebook-level ``xticks``, with the monthly medians joined by a line and
    a dashed reference line at 60.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on. It is modified in place; nothing is returned.

    Notes
    -----
    Uses the notebook-level ``xticks`` and ``string_numbers``. The median
    line is drawn against 12 positions, so every month must be present.
    '''

    ns_mooring = pd.read_pickle('./data/WQM20m_daily.pkl')

    var_name = 'Oxygen'
    ns_mooring['Month'] = ns_mooring.index.month

    # reorder months: replace each calendar month by its 1-based position in
    # `xticks`. One dictionary lookup per value replaces the previous trick
    # of writing strings into the integer column (deprecated in recent pandas).
    position_of = {int(month): pos for pos, month in enumerate(xticks, start=1)}
    ns_mooring_copy = ns_mooring.copy()
    ns_mooring_copy['Month'] = ns_mooring_copy['Month'].map(lambda m: position_of.get(m, m)).astype(int)

    # box-and-whisker plot
    flierprops = dict(marker='o', markerfacecolor='none', markersize=6, linestyle='none', markeredgecolor='black', markeredgewidth=0.8)
    sns.boxplot(ax=ax, x='Month', y=var_name, data=ns_mooring_copy, width=0.5, color='C2', flierprops=flierprops)

    # plot medians (boxes sit at positions 0..11)
    monthly_median = ns_mooring_copy.groupby('Month')[var_name].median().reset_index()
    ax.plot(np.arange(0,12), (monthly_median[var_name]), marker='o', color='black')

    ax.axhline(y=60, linestyle='--',color='gray')
    ax.set_xticks(np.arange(12), string_numbers)
    ax.set_xlabel('Month')
    ax.set_ylabel('DO (µmol kg$^{-1}$)')
    ax.set_ylim(0, 310)
    ax.set_yticks(np.arange(0,310,50));

def plot_DO_70m(ax):
    '''Monthly box-and-whisker plot of dissolved oxygen from the 70 m mooring.

    Reads './data/WQM70m_daily.pkl' and draws one box per month of the
    'Oxygen' column on ``ax``, in the same September-first order as the 20 m
    panel. Months 12 and 1 are not part of the remapping table; their slots
    (positions 4 and 5) are kept open by two all-NaN rows.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on. It is modified in place; nothing is returned.

    Notes
    -----
    Uses the notebook-level ``string_numbers`` for the tick labels.
    '''

    mooring_70m = pd.read_pickle('./data/WQM70m_daily.pkl')
    var_name = 'Oxygen'
    mooring_70m['Month'] = mooring_70m.index.month

    # reorder months: the first three listed months go to positions 1-3, the
    # remaining ones to positions 6-12. One dictionary lookup per value
    # replaces the previous trick of writing strings into the integer column.
    xticks_ = np.array([ 9, 10, 11, 2,  3,  4,  5,  6,  7,  8])
    position_of = {int(month): (k + 1 if k <= 2 else k + 3) for k, month in enumerate(xticks_)}
    mooring_70m_copy = mooring_70m.copy()
    mooring_70m_copy['Month'] = mooring_70m_copy['Month'].map(lambda m: position_of.get(m, m)).astype(int)

    # placeholder rows for the two empty slots; DataFrame.append was removed
    # in pandas 2.0, so the rows are added with pd.concat
    placeholders = pd.DataFrame({var_name: [np.nan, np.nan], 'Month': [4, 5]})
    mooring_70m_copy = pd.concat([mooring_70m_copy, placeholders], ignore_index=True)

    # box-and-whisker plot
    flierprops = dict(marker='o', markerfacecolor='none', markersize=6, linestyle='none', markeredgecolor='black', markeredgewidth=0.8)
    sns.boxplot(ax=ax, x='Month', y=var_name, data=mooring_70m_copy, width=0.5, color='C2', flierprops=flierprops)

    # plot medians (NaN for the two placeholder months)
    monthly_median = mooring_70m_copy.groupby('Month')[var_name].median().reset_index()
    ax.plot(np.arange(12), (monthly_median[var_name]), marker='o', color='black')

    ax.axhline(y=60, linestyle='--',color='gray')
    ax.set_xticks(np.arange(12), string_numbers)
    ax.set_xlabel('Month')
    ax.set_ylabel('DO (µmol kg$^{-1}$)');
    ax.set_ylim(0, 310)
    ax.set_yticks(np.arange(0,310,50));

In [None]:
plt.rcParams['font.size'] = 12.3

# Upper half of Figure 1: map (a) and the two mooring box plots (b, c)
fig = plt.figure(figsize=(13, 3.4), constrained_layout=True)
gs = gridspec.GridSpec(1, 3, figure=fig, wspace=0.03)

ax1 = fig.add_subplot(gs[0, 0], projection=ccrs.PlateCarree())
ax2 = fig.add_subplot(gs[0, 1])
ax3 = fig.add_subplot(gs[0, 2])

plot_map(ax1)
plot_DO_20m(ax2)
plot_DO_70m(ax3)

# panel letters, positioned in the data coordinates of each axes
for panel_ax, x_pos, y_pos, letter in ((ax1, 15.3, -31, 'a'), (ax2, -3.2, 310, 'b'), (ax3, -3.2, 310, 'c')):
    panel_ax.text(x_pos, y_pos, letter, fontsize=18)

# The cell that combines the two halves opens './figures/figure1_half.png';
# enable this line to (re)generate that file.
#fig.savefig('./figures/figure1_half.png', dpi=200)


In [None]:
# empty container (not filled in this cell)
arr_area_hypoxia = []

plt.rcParams['font.size'] = 12.3

def add_colorbar(x0, y0, vmin, vmax, label, cmap_label='rainbow_r', levels=np.arange(0, 300, 30)):
    '''
    Add a stand-alone vertical colorbar with discrete colours to the notebook-level ``fig``.

    x0, y0: lower-left corner of the colorbar axes, in figure coordinates
            (the axes is 0.01 wide and 0.7 high)
    vmin, vmax: accepted but not used by the body; the range follows ``levels``
    label: label of the colorbar
    cmap_label: name of the colormap to discretise
    levels: boundaries of the colour bins, also used as tick positions
    '''
    # one colour per interval between consecutive levels
    n_bins = len(levels) - 1
    discrete_cmap = plt.get_cmap(cmap_label, n_bins)
    level_norm = BoundaryNorm(levels, ncolors=discrete_cmap.N, clip=True)

    # a mappable that carries only the colormap and norm (no data attached)
    mappable = cm.ScalarMappable(norm=level_norm, cmap=discrete_cmap)
    mappable.set_array([])

    # dedicated axes for the bar: [x0, y0, width, height]
    bar_axes = fig.add_axes([x0, y0, 0.01, 0.7])
    bar = fig.colorbar(mappable, cax=bar_axes, shrink=0.9, label=label, orientation='vertical')
    bar.set_ticks(levels)  # ticks sit on the level boundaries

# Lower half of Figure 1: one oxygen transect panel per month, September first
order = [9,10,11,12,1,2,3,4,5,6,7,8]

fig = plt.figure(figsize=(13, 3.4))

# transect coordinates shared by all panels
x = DO_transect_interpolated_final.dist2shore.values
y = DO_transect_interpolated_final.depth.values

for t, month in enumerate(order):
    ax = fig.add_subplot(2, 6, t+1)

    data = DO_transect_interpolated_final.oxy.sel(month=month)
    ax.contourf(x, y, data, levels=np.arange(0,300, 10), cmap='rainbow_r') # cmap=cmocean.cm.oxy
    contours = ax.contour(x, y, data, levels=np.arange(0,300, 60), colors='gray', linestyles='solid', linewidths=1) # cmap=cmocean.cm.oxy
    ax.clabel(contours, inline=True, fontsize=10, fmt='%1.0f')

    # hypoxia line
    contour_line = ax.contour(x, y, data, levels=[60], colors='black', linestyles='dashed', linewidths=2)
    ax.clabel(contour_line, inline=True, fontsize=10, fmt='%1.0f')

    ax.text(78, 280,  f"Month {month}")

    ax.set_ylim(370,0)

    # depth ticks and label only on the first panel of each row
    if t in (0, 6):
        ax.set_yticks(np.arange(300, -40, -100))
        ax.set_ylabel('Depth (m)')
    else:
        ax.set_yticks([])

    # distance ticks only on the bottom row
    if t >= 6:
        ax.set_xticks(np.arange(0, 150, 50))
    else:
        ax.set_xticks([])

    ax.invert_xaxis()

# shared x label and panel letter, positioned in the data coordinates of the last panel
ax.text(550, 500, 'Offshore distance (km)')
ax.text(1010, -370, 'd', fontsize=18)

add_colorbar(0.91, 0.15, 0, 300, 'DO (µmol kg$^{-1}$)')
plt.subplots_adjust(wspace=0.06, hspace=0.05)

#fig.savefig(f'./figures/figure1_half2.png', bbox_inches='tight', dpi=200)

In [None]:
# combine the two halves of Figure 1 by stacking them vertically
from PIL import Image

img1 = Image.open('./figures/figure1_half.png')
img2 = Image.open('./figures/figure1_half2.png')

# bring the lower half to the width of the upper half, keeping its aspect ratio
if img2.width != img1.width:
    scaled_height = int(img2.height * img1.width / img2.width)
    img2 = img2.resize((img1.width, scaled_height))

# white canvas tall enough for both images
combined_height = img1.height + img2.height
combined_img = Image.new('RGB', (img1.width, combined_height), (255, 255, 255))

# upper half at the top, lower half directly underneath
for image, top in ((img1, 0), (img2, img1.height)):
    combined_img.paste(image, (0, top))

combined_img.save('./figures/figure1_combined.png')

# Figure 3

In [None]:
# residence time datasets ("sub" and "whole" files)
whole_ds_1poly100m = xr.open_dataset('./data/res_time_sub.nc')
whole_ds_1poly100m_noMLD = xr.open_dataset('./data/res_time_whole.nc')

# mixed layer depth based on N2 metric, attached to the "sub" dataset as (time, eta_rho, xi_rho)
N2_mld_all = np.load('./data/mld_5years_N2.npy')
mld_dims = ('time', 'eta_rho', 'xi_rho')
whole_ds_1poly100m['MLD_N2'] = (mld_dims, N2_mld_all)

# stratification and slope datasets
ds_stra = xr.open_dataset('./data/ds_stratification.nc')
ds_slope_adjusted_APG = xr.open_dataset('./data/ds_slope_APG_3km_domain_adjusted_closer2shore.nc')

# model outputs
ns_mdl10m = pd.read_pickle('./data/nearshore_10mextractionsfrom1kmmdl.pkl')

In [None]:
def get_monthly(whole_ds_var):
    '''Return the calendar-month mean and standard deviation of a variable at one grid point.

    The variable is grouped by 'time.month', reduced over 'time', and then
    sampled at the fixed grid indices eta_rho=109, xi_rho=71.
    '''
    grid_point = dict(eta_rho=109, xi_rho=71)
    by_month = whole_ds_var.groupby('time.month')

    mean = by_month.mean(dim='time').isel(**grid_point)
    std = by_month.std(dim='time').isel(**grid_point)

    return mean, std

def reorder(mean):
    '''Move the last four entries to the front (January-first becomes September-first for 12 monthly values).'''
    last_four = mean[-4:]
    the_rest = mean[:-4]
    return np.concatenate((last_four, the_rest))

# Month numbers in plotting order (September first), and the same as tick labels
xticks = np.r_[9:13, 1:9]
string_numbers = list(map(str, xticks))

In [None]:
# mean and std of Tres, Tres_noMLD for inshore of 100m
mean_100m, std_100m = get_monthly(whole_ds_1poly100m.efold_time_int)
mean_noMLD_100m, std_noMLD_100m = get_monthly(whole_ds_1poly100m_noMLD.efold_time_int)

# same statistics for the difference between the two residence times
tres_difference = whole_ds_1poly100m_noMLD - whole_ds_1poly100m
mean_diff_100m, std_diff_100m = get_monthly(tres_difference.efold_time_int)

In [None]:
plt.rcParams['font.size'] = 12
FONTSIZE = 15

# 2 x 2 grid of panels: ax0 and ax1 on the top row, ax2 and ax3 on the bottom row
fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2, figsize=(10,7.4))
plt.subplots_adjust(wspace=0.4, hspace=0.15)

####### T_res
ax = ax0

# (monthly mean, monthly std, legend label, colour), drawn in this order
tres_series = [
    (mean_noMLD_100m, std_noMLD_100m, r"$\tau_{whole}$", 'C1'),
    (mean_100m, std_100m, r"$\tau_{sub}$", 'C0'),
]
for monthly_mean, monthly_std, series_label, series_color in tres_series:
    positions = monthly_mean['month'] - 1          # 0..11 along the x axis
    centre, spread = reorder(monthly_mean), reorder(monthly_std)
    ax.plot(positions, centre, marker='o', label=series_label, color=series_color)
    # shaded band of one standard deviation either side of the mean
    ax.fill_between(positions, centre - spread, centre + spread, alpha=0.2, color=series_color)

ax.set_xticks(np.arange(12), string_numbers);
ax.set_xlabel('')
ax.set_ylabel("Residence times (days)")
ax.set_ylim(2,22)

ax.legend(loc='upper left', fontsize=12)


######### wind stress
ax = ax1
var_name = 'svstr'
ns_mdl10m['Month'] = ns_mdl10m.index.month

# reorder months: replace each calendar month by its 1-based position in
# `xticks`. One dictionary lookup per value replaces the previous trick of
# writing strings into the integer column (deprecated in recent pandas).
position_of = {int(month): pos for pos, month in enumerate(xticks, start=1)}
ns_mdl10m_copy = ns_mdl10m.copy()
ns_mdl10m_copy['Month'] = ns_mdl10m_copy['Month'].map(lambda m: position_of.get(m, m)).astype(int)

# Create a box-and-whisker plot using seaborn
flierprops = dict(marker='o', markerfacecolor='none', markersize=6, linestyle='none', markeredgecolor='black', markeredgewidth=0.8)
sns.boxplot(ax=ax, x='Month', y=var_name, data=ns_mdl10m_copy, width=0.5, color='C3', flierprops=flierprops)
ax.axhline(y=0, color='r', linestyle='--', label='y=0')

# plot medians (boxes sit at positions 0..11)
monthly_median = ns_mdl10m_copy.groupby('Month')[var_name].median().reset_index()
ax.plot(np.arange(0,12), monthly_median[var_name], marker='o', color='black')
ax.set_ylabel("Alongshore wind stress (N m$^{-2}$)")
ax.set_xlabel("")
ax.set_xticks(np.arange(12), string_numbers);
# symmetric log axis, linear within +/- 0.01
ax.set_yscale('symlog', linthresh=0.01)
ax.set_ylim(-0.7, 0.4)

tau_monthly_median = monthly_median


######## stratification
ax= ax2
var_name = 'N2_int'
var_data = ds_stra[var_name]
df = var_data.to_dataframe()
df['Month'] = df.index.month

# reorder months: replace each calendar month by its 1-based position in
# `xticks`. One dictionary lookup per value replaces the previous trick of
# writing strings into the integer column (deprecated in recent pandas).
position_of = {int(month): pos for pos, month in enumerate(xticks, start=1)}
df_copy = df.copy()
df_copy['Month'] = df_copy['Month'].map(lambda m: position_of.get(m, m)).astype(int)

flierprops = dict(marker='o', markerfacecolor='none', markersize=6, linestyle='none', markeredgecolor='black', markeredgewidth=0.8)
sns.boxplot(ax=ax, x='Month', y=var_name, data=df_copy, width=0.5, color='C4', flierprops=flierprops)

# plot medians (calendar-month medians rotated to the September-first order)
monthly_median = var_data.groupby('time.month').median('time')
ax.plot(np.arange(0,12), reorder(monthly_median), marker='o', color='black')

N_monthly_median = monthly_median
ax.set_xticks(np.arange(12), string_numbers);
ax.set_xlabel('Month')
# NOTE(review): the unit in the label is kept as written -- confirm it
# matches the quantity stored in 'N2_int'
ax.set_ylabel('Buoyancy frequency $N^2$ (rad s$^{-1}$)')


######## APG
ax = ax3
var_name = 'slope'
var_data = ds_slope_adjusted_APG[var_name]
df = var_data.to_dataframe()
df['Month'] = df.index.month

# reorder months: replace each calendar month by its 1-based position in
# `xticks`. One dictionary lookup per value replaces the previous trick of
# writing strings into the integer column (deprecated in recent pandas).
position_of = {int(month): pos for pos, month in enumerate(xticks, start=1)}
df_copy = df.copy()
df_copy['Month'] = df_copy['Month'].map(lambda m: position_of.get(m, m)).astype(int)

flierprops = dict(marker='o', markerfacecolor='none', markersize=6, linestyle='none', markeredgecolor='black', markeredgewidth=0.8)
sns.boxplot(ax=ax, x='Month', y=var_name, data=df_copy, width=0.5, color='C5', flierprops=flierprops)

# plot median (calendar-month medians rotated to the September-first order)
monthly_median = var_data.groupby('time.month').median('time')
ax.plot(np.arange(0,12), reorder(monthly_median), marker='o', color='black')
# zero line spanning the full width of the boxes
ax.plot([-0.5,11.5],[0,0], 'r--')
ax.text(3.7, 0.025, 'Poleward')
ax.text(3.7, -0.028, 'Equatorward')
ax.set_ylim(-0.03, 0.03)
# Customize the plot labels and title
ax.set_xlabel('Month')
ax.set_ylabel('APG slope (m deg$^{-1}$)')
ax.set_xticks(np.arange(12), string_numbers);

# Panel letters. Every position is given in the axes coordinates of `ax`
# (the axes most recently assigned to that name), which is why the top-row
# letters use y values above 1 and the left-column letters negative x values.
panel_letters = [
    (ax0, -1.6, 2.15, 'a'),
    (ax1, -0.21, 2.15, 'b'),
    (ax2, -1.6, 1.05, 'c'),
    (ax3, -0.21, 1.05, 'd'),
]
for panel_ax, x_pos, y_pos, letter in panel_letters:
    panel_ax.text(x_pos, y_pos, letter, fontsize=18, transform=ax.transAxes)

fig.canvas.draw()

fig.savefig('./figures/figure_3.png', bbox_inches='tight', dpi=200)

# Figure S1

In [None]:
def line_equation_from_two_points(point1, point2):
    '''Return (slope, intercept) of the straight line y = slope * x + intercept through two (x, y) points.

    A vertical line (equal x values) makes the slope division fail with
    ZeroDivisionError for plain Python numbers.
    '''
    (xa, ya), (xb, yb) = point1, point2

    # rise over run
    slope = (yb - ya) / (xb - xa)

    # the line passes through the first point: ya = slope * xa + intercept
    return slope, ya - slope * xa

In [None]:
# transect coords: end points as (longitude, latitude)
point1 = (16.9, -29.8)
point2 = (18.3, -32.5)

# straight line through the two end points
k, b = line_equation_from_two_points(point1, point2)
print(k,b)

# sample the line every 0.05 in x between the end points
x_start, x_stop = point1[0], point2[0] + 0.01
xtransect = np.arange(x_start, x_stop, 0.05)
ytransect = k*xtransect + b

In [None]:
# 3 km model output. This cell reassigns tlong, tlat, mask, ocean_indices,
# tlong_sel and tlat_sel, which held the values of the first model grid until now.
da_3km = xr.open_mfdataset('./data/zeta_avg_3km_Y2008M1.nc')

tlong, tlat = da_3km.lon_rho.values, da_3km.lat_rho.values
mask = da_3km.mask_rho.values

# coordinates of the cells whose mask equals 1
ocean_indices = np.where(mask == 1)
tlong_sel, tlat_sel = tlong[ocean_indices], tlat[ocean_indices]

In [None]:
plt.rcParams['font.size'] = 14

# Figure S1: sea surface height at the first time index, with the transect line on top
fig, ax = plt.subplots(figsize=(5, 6), subplot_kw={'projection': ccrs.PlateCarree()})

# map frame: extent, ticks and degree formatting
ax.set_extent([15.8, 19.5, -35.0, -29.5], crs=ccrs.PlateCarree())
ax.set_xticks(np.arange(16, 19, 1), crs=ccrs.PlateCarree())
ax.set_yticks(np.arange(-35, -29.5, 1), crs=ccrs.PlateCarree())
lon_formatter = LongitudeFormatter(zero_direction_label=False)
lat_formatter = LatitudeFormatter()
ax.xaxis.set_major_formatter(lon_formatter)
ax.yaxis.set_major_formatter(lat_formatter)

ax.add_feature(cfeature.LAND)

# colour helpers and container (not used further down in this cell)
colors = list(mcolors.TABLEAU_COLORS.values())
ind_color = np.arange(len(colors)) # 0- 9
edges_coord = []

ax.text(18.35, -31.8, 'South \nAfrica')
ax.set_xlabel('Longitude')
ax.set_ylabel('Latitude')

# transect defined by the two end points
ax.plot(xtransect, ytransect, 'k--', linewidth=3)

# blank the cells whose mask is 0 by multiplying with NaN there
msk = da_3km.mask_rho.values
msk[msk == 0] = np.nan
ssh_masked = da_3km.zeta.isel(time=0) * msk
ca = ax.contourf(tlong, tlat, ssh_masked, levels=15, cmap=cm.coolwarm)
fig.colorbar(ca, ax=ax, label='Sea surface height (m)', orientation='vertical', shrink=0.7)

#fig.savefig(f'./figures/figure_S1.png', bbox_inches='tight', dpi=200)

# Figure S5

In [None]:
# mean and std of MLD
var_name = 'MLD_N2'

# keep only the cells flagged by inshore_mask (zeros become NaN), then average over the grid
masked_mld = whole_ds_1poly100m[var_name] * inshore_mask
var_data = xr.where(masked_mld == 0, np.nan, masked_mld).mean(dim=['eta_rho', 'xi_rho'])

# calendar-month statistics of the area mean
mld_by_month = var_data.groupby('time.month')
MLD_mean = mld_by_month.mean(dim='time')
MLD_std = mld_by_month.std(dim='time')

# residence-time difference at the fixed grid point eta_rho=109, xi_rho=71
tres_difference = whole_ds_1poly100m_noMLD - whole_ds_1poly100m
Tres_diff_100m = tres_difference.efold_time_int.isel(eta_rho=109, xi_rho=71)
mld = var_data

# plain arrays for the scatter plot
x, y = mld.values, Tres_diff_100m.values

In [None]:
from scipy.stats import spearmanr, kendalltau

# rank correlations between mixed layer depth and the residence-time difference
rho, pval = spearmanr(mld, Tres_diff_100m)
tau, pval_tau = kendalltau(mld, Tres_diff_100m)

for stat_name, statistic, p_value in (("Spearman rho:", rho, pval), ("Kendall tau:", tau, pval_tau)):
    print(stat_name, statistic, "p-value:", p_value)

In [None]:
plt.rcParams.update({'font.size': 11})
fig, axes = plt.subplots(1,3, figsize=(11, 2.7))

# (a) monthly mean +/- one std of the residence-time difference.
# The label is a raw string so that "\D" is not read as an (invalid) escape sequence.
axes[0].plot(mean_diff_100m['month'], reorder(mean_diff_100m), marker='o', label=r'$\Delta_{Whole-bottom}$', color='C0')
axes[0].fill_between(mean_diff_100m['month'], reorder(mean_diff_100m) - reorder(std_diff_100m), reorder(mean_diff_100m) + reorder(std_diff_100m), alpha=0.2, color='C0')

# (b) monthly mean +/- one std of the mixed layer depth
axes[1].plot(MLD_mean['month'], reorder(MLD_mean), marker='o', label='MLD', color='gray')
axes[1].fill_between(MLD_mean['month'], reorder(MLD_mean) - reorder(MLD_std), reorder(MLD_mean) + reorder(MLD_std), alpha=0.2, color='gray')


axes[0].set_xlabel('Month')
axes[1].set_xlabel('Month')
axes[0].set_ylabel(r"$\tau_{whole}$ - $\tau_{sub}$ (days)", fontsize=12)
axes[1].set_ylabel('Mixed layer depth (m)')

axes[0].set_ylim(2,16)
# NOTE(review): the y ticks set below for this panel run up to 35, beyond
# this limit -- confirm which range is intended
axes[1].set_ylim(0,16)
axes[0].set_xticks(np.arange(1,13), string_numbers)
axes[1].set_xticks(np.arange(1,13),  string_numbers);
axes[0].set_yticks(np.arange(2,17,3));
axes[1].set_yticks(np.arange(0,36,5));


# (c) scatter of the residence-time difference against mixed layer depth
axes[2].scatter(x, y, label='Data Points', s=5, c='lightgray')

# Optional: add a LOWESS smoother, drawn explicitly on the third panel
# instead of on whichever axes happens to be current
sns.regplot(x=mld, y=Tres_diff_100m, scatter=False, lowess=True, color='k', ax=axes[2])

# Add labels and legend
axes[2].set_xlabel('Mixed layer depth (m)')
axes[2].set_ylabel(r"$\tau_{whole}$ - $\tau_{sub}$ (days)", fontsize=12)

# annotation text derived from the computed statistics, so the p-value
# statement cannot drift away from the data
Spearman_rho = f'Spearman $\\rho$ = {rho:.2f}'
p_val = 'p value < 1e-5' if pval < 1e-5 else f'p value = {pval:.2g}'

axes[2].text(22, 2.5, Spearman_rho, fontsize=10)
axes[2].text(25, 1, p_val, fontsize=10)

axes[2].set_xlim(-3, 58)
axes[2].set_ylim(0, 18)

plt.subplots_adjust(wspace=0.3, hspace=0.0)

# panel letters, in the data coordinates of each panel
axes[0].text(-1, 16.5, 'a' , fontsize=15)
axes[1].text(-2.5, 36, 'b', fontsize=15)
axes[2].text(-10, 18.8, 'c', fontsize=15)

#fig.savefig(f'./figures/figure_S5.png', bbox_inches='tight', dpi=200)

# Figure S6

In [None]:
# 'svstr' column of `ns_mdl10m` (defined in an earlier cell); plotted below as the
# alongshore wind stress.
wind_series = ns_mdl10m['svstr']

In [None]:
# Monthly climatology of the wind-stress series, ordered September -> August.

# Work on a renamed copy with a datetime index, so the object obtained from
# ns_mdl10m['svstr'] is not modified in place.
wind_series = wind_series.rename('svstr')
wind_series.index = pd.to_datetime(wind_series.index)

# One row per sample, tagged with its calendar month
df = wind_series.to_frame()
df['month'] = df.index.month
df['month_name'] = df.index.strftime('%b')

# Group by month (across all years)
grouped = df.groupby('month')

# Monthly statistics
median = grouped['svstr'].median()
iqr = grouped['svstr'].quantile(0.75) - grouped['svstr'].quantile(0.25)
std = grouped['svstr'].std()

# Share of samples per month with positive / negative stress, in percent
upwelling_pct = grouped['svstr'].apply(lambda x: (x > 0).mean() * 100)    # positive = upwelling
downwelling_pct = grouped['svstr'].apply(lambda x: (x < 0).mean() * 100)  # negative = downwelling

# Combine into summary DataFrame (indexed by month number)
summary = pd.DataFrame({
    'Median': median,
    'IQR': iqr,
    'StdDev': std,
    '% Upwelling-Favorable': upwelling_pct,
    '% Downwelling-Favorable': downwelling_pct
})

# Reorder to start in September (Southern Hemisphere spring); .loc raises a
# KeyError if any month has no data at all.
month_order = list(range(9, 13)) + list(range(1, 9))  # [9, 10, 11, 12, 1, ..., 8]
summary = summary.rename_axis('Month').loc[month_order].reset_index()

# Abbreviated month name, computed once after the reordering
summary['Month Name'] = summary['Month'].apply(lambda x: pd.Timestamp(f'2024-{x:02d}-01').strftime('%b'))

# Display result
print(summary[['Month Name', 'Median', 'IQR', 'StdDev', '% Upwelling-Favorable', '% Downwelling-Favorable']])


In [None]:
plt.rcParams.update({'font.size': 13})

# Put the rows in September -> August order (a no-op if they already are)
month_order = [*range(9, 13), *range(1, 9)]
summary = summary.set_index('Month').loc[month_order].reset_index()

# Month numbers as strings give a categorical x-axis: '9', '10', ..., '8'
x_labels = [str(month_number) for month_number in summary['Month']]

fig, ax1 = plt.subplots(figsize=(10, 5))

# Primary axis: monthly median with error bars of half the inter-quartile range
# on each side.
# NOTE(review): the legend text reads '± IQR' while the bars are ± IQR/2 - confirm wording.
ax1.errorbar(
    x_labels,
    summary['Median'],
    yerr=summary['IQR'] / 2,
    fmt='o-',
    label='Median ± IQR',
    color='C0',
)
ax1.axhline(0, color='gray', linestyle='--')
ax1.set_ylim(-0.04, 0.04)
ax1.set_xlabel('Month')
ax1.set_ylabel('Alongshore wind stress (N m⁻²)')

# Secondary axis: percentage of upwelling- and downwelling-favourable samples
ax2 = ax1.twinx()
for column, line_fmt, colour, legend_text in (
    ('% Upwelling-Favorable', 's--', 'C3', '% Upwelling-favorable'),
    ('% Downwelling-Favorable', 'd--', 'C4', '% Downwelling-favorable'),
):
    ax2.plot(x_labels, summary[column], line_fmt, color=colour, label=legend_text)
ax2.set_ylabel('Percentage of days (%)')

# One legend holding the entries of both axes
legend_handles, legend_labels = [], []
for axis in (ax1, ax2):
    axis_handles, axis_labels = axis.get_legend_handles_labels()
    legend_handles += axis_handles
    legend_labels += axis_labels
ax1.legend(legend_handles, legend_labels, loc='upper right', fontsize=12)

plt.tight_layout()
plt.show()

#fig.savefig(f'./figures/figure_S6.png', bbox_inches='tight', dpi=200)

# Figure 2

In [None]:
# Model grid from Fearon et al 2023
da = xr.open_dataset('./data/grid.nc')

# Coordinates of the rho points and the mask at the first time index
tlong, tlat = da.lon_rho.values, da.lat_rho.values
mask = da.mask_rho.isel(time=0).values

# 1 where h <= 100 (inshore of 100m), 0 elsewhere; combined with the rho mask
inshore_mask = xr.where(da.isel(time=0).h <= 100, 1, 0)
mask_sel = mask * inshore_mask

# Longitudes / latitudes of the grid points selected by the combined mask
ocean_indices = np.where(mask_sel == 1)
tlong_sel, tlat_sel = tlong[ocean_indices], tlat[ocean_indices]

In [None]:
# Back-tracking results; later cells select from it by `num_days_back` and `time`
# and read its `lon`, `lat` and `z` variables.
ds_back_30days = xr.open_dataset('./data/back_everyday_result_70m.nc')

In [None]:
def reorder(mean):
    """Move the last four entries of `mean` to the front.

    For a 12-element January..December sequence this yields the
    September..August ordering used on the month axes.
    """
    tail, head = mean[-4:], mean[:-4]
    return np.concatenate([tail, head])

# Month numbers in plotting order (9..12 then 1..8) and their tick labels
xticks = np.r_[9:13, 1:9]
string_numbers = list(map(str, xticks))

In [None]:
# compute trends of depth for each day back in time

def plot_depth_mean(index, ax, ax2):
    """Plot the monthly distribution of particle depth `index` days back in time.

    Parameters
    ----------
    index : value passed to ``.sel(num_days_back=...)`` on ``ds_back_30days``.
    ax : matplotlib Axes receiving the monthly box-and-whisker plot with the
        monthly medians overlaid.
    ax2 : matplotlib Axes receiving the monthly mean with a +/- one standard
        deviation band.

    Months are shown in the order given by the notebook-level ``xticks``
    (September first) and labelled with ``string_numbers``.
    """
    # Use the dataset already opened at notebook level rather than re-opening the
    # file from a machine-specific absolute path on every call.
    ds_back = ds_back_30days.sel(num_days_back=index)

    # Flatten (time, traj) into one sample axis. The stacking below is time-major,
    # so every time stamp is repeated once per trajectory (previously a fixed 100).
    n_traj = ds_back.sizes['traj']
    new_time = np.repeat(ds_back.time.values, n_traj)

    def flatten(name):
        return ds_back[name].stack(combined=('time', 'traj')).values

    # Every variable lives on the same flattened 'time' dimension
    # (the 'lat' entry previously carried a mistyped dimension name).
    new_ds = xr.Dataset(
        {name: (['time'], flatten(name)) for name in ('z', 'lon', 'lat', 'h')},
        coords={'time': new_time,},
    )

    # flip the sign of z so that the plotted depth is -z
    new_ds['z'] = -new_ds.z

    # calendar month -> 1-based position in the plotting order
    month_position = {int(m): i + 1 for i, m in enumerate(xticks)}

    def plot_var(new_ds, ax, var_name, label):
        var_data = new_ds[var_name]
        df = var_data.to_dataframe()
        df['Month'] = df.index.month.map(month_position)

        # Box-and-whisker plot; `order` fixes 12 slots so a month without data
        # cannot shift the boxes relative to the tick labels.
        sns.boxplot(ax=ax, x='Month', y=var_name, data=df, order=list(range(1, 13)),
                    width=0.5, color='C1', flierprops=dict(marker='d', markersize=2))

        # Monthly medians, arranged by month label in plotting order (NaN if a month is absent)
        monthly_median = var_data.groupby('time.month').median('time')
        ax.plot(np.arange(0,12), monthly_median.reindex(month=xticks).values, marker='o', color='black')
        ax.set_xticks(np.arange(12), string_numbers);
        ax.set_ylabel(label);

    plot_var(new_ds, ax, "z", 'Depth (m)')

    ############# second panel: monthly mean +/- std
    z_by_month = new_ds.z.groupby('time.month')
    mean = z_by_month.mean(dim='time').reindex(month=xticks).values
    std = z_by_month.std(dim='time').reindex(month=xticks).values

    order = np.arange(1, 13)  # x positions; labels are set from string_numbers below
    ax2.plot(order, mean, marker='o', label='Surface', color='C0')
    ax2.fill_between(order, mean - std, mean + std, alpha=0.2, color='C0')

    ax.set_xlabel('Month')
    ax2.set_xlabel('Month')
    ax2.set_ylabel('Depth (m)')

    ax2.set_ylim(-10, 200)
    ax.set_ylim(-10,325)
    ax2.set_xticks(np.arange(1,13), string_numbers);


In [None]:
def movie_back_water(index, close=False, save=False):
    """Draw one frame: 12 monthly maps of back-tracked particles plus two depth panels.

    Parameters
    ----------
    index : value passed to ``.sel(num_days_back=...)`` on ``ds_back_30days``; also
        used in the annotation and in the output file name (formatted as an integer).
    close : if true, close the figure at the end (useful when looping over frames).
    save : if true, write the figure as a PNG under
        ``./figures/figs2movie/backward_watermass_70m/``.

    Panel "a" is the 2 x 6 grid of maps (one per month, September first), coloured
    by -z on a 0-300 scale; panels "b" and "c" are filled by ``plot_depth_mean``.
    """
    plt.rcParams.update({'font.size': 13})

    order = [9,10,11,12,1,2,3,4,5,6,7,8]

    # Loop-invariant inputs
    bathymetry = da.isel(time=0).h
    moor70m = [18.183, -32.329]  # coordinate of 70 m mooring

    fig = plt.figure(figsize=(14, 8))
    for t, month in enumerate(order):
        ax = fig.add_subplot(2, 6, t+1, projection=ccrs.PlateCarree())
        ax.set_extent([17, 18.5, -33.5, -31], crs=ccrs.PlateCarree())
        ax.add_feature(cfeature.LAND)
        if t in [0,6]:
            # latitude ticks only on the first panel of each row
            ax.set_yticks(np.arange(-33.5, -30.5, 0.5), crs=ccrs.PlateCarree())
            ax.yaxis.set_major_formatter(LatitudeFormatter())
        if t >= 6:
            # longitude ticks only on the second row
            ax.set_xticks(np.arange(17, 18.5, 0.5), crs=ccrs.PlateCarree())
            ax.xaxis.set_major_formatter(LongitudeFormatter(zero_direction_label=False))

        # contours of h from the model grid
        contours = ax.contour(da.lon_rho, da.lat_rho, bathymetry,
                              levels=5, colors='black', linewidths=2, alpha=0.6, transform=ccrs.PlateCarree())
        ax.clabel(contours, inline=True, fontsize=12)

        # particles of this calendar month, coloured by -z
        ds_back_ = ds_back_30days.sel(num_days_back= index, time=ds_back_30days['time.month'] == month)
        ax.scatter(ds_back_.lon, ds_back_.lat, s=3, c=-ds_back_.z, vmin=0, vmax=300, cmap='viridis_r')
        ax.set_title(f'{month:02d}')

        ax.scatter(moor70m[0], moor70m[1], s=30, marker='*',c='r')

        if t==1:
            # frame annotation, positioned in the data coordinates of the second panel
            ax.text(19.2, -30.6, f'{index:02d} days back-in-time', fontsize=18)

    def add_colorbar(x0, y0, vmin, vmax, label, cmap_label="viridis_r"):
        '''
        x0, y0: start location for the colorbar
        vmin, vmax: range of the colorbar
        label: label of the colorbar
        '''
        cax = fig.add_axes([x0, y0, 0.015, 0.5])  # [x0, y0, width, height]
        cmap = plt.colormaps[cmap_label]
        normalize = plt.Normalize(vmin=vmin, vmax=vmax)  # Normalize the color values
        sm = cm.ScalarMappable(cmap=cmap, norm=normalize)
        cbar = fig.colorbar(sm, cax=cax, shrink=0.9, label=label, orientation='vertical')
        cbar.ax.tick_params(labelsize=14)

    add_colorbar(0.92, 0.25, 0, 300, 'Depth (m)')
    plt.subplots_adjust(wspace=0.06, hspace=0.04)

    # Two extra panels below the maps
    ax1 = fig.add_axes([0.13, -0.35, 0.34, 0.4]) # [left, bottom, width, height]
    ax2 = fig.add_axes([0.56, -0.35, 0.34, 0.4])

    # Panel letters: first map (upper-left of the 12), then the bottom row
    fig.axes[0].text(
        -0.15, 1.15, "a", transform=fig.axes[0].transAxes,
        fontsize=20, fontweight="bold", va="top", ha="left"
    )
    ax1.text(
        -0.15, 1.05, "b", transform=ax1.transAxes,
        fontsize=20, fontweight="bold",va="top", ha="left"
    )
    ax2.text(
        -0.15, 1.05, "c", transform=ax2.transAxes,
        fontsize=20,  fontweight="bold",va="top", ha="left"
    )

    plot_depth_mean(index, ax1, ax2)

    # Save before closing, and close this specific figure rather than "the current one"
    if save:
        out_path = f'./figures/figs2movie/backward_watermass_70m/backward_watermass_70m_{index:02}d_clim_together.png'
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        fig.savefig(out_path, bbox_inches='tight', dpi=200)

    if close:
        plt.close(fig)

In [None]:
%%time
# Single frame for num_days_back=30, left open for inline display and not saved.
movie_back_water(30, close=False, save=False)

In [None]:
%%time
# One saved frame per available `num_days_back` label. Iterating over the coordinate
# values themselves avoids assuming that they are exactly 1..N.
for t in ds_back_30days.num_days_back.values:
    print(t)
    movie_back_water(int(t), close=True, save=True)

# Figure S8

In [None]:
# Quick look at the raw `N2_int` values of `ds_stra` (defined in an earlier cell);
# this displays the whole array.
ds_stra.N2_int.values

In [None]:
# Upwelling source depth from wind stress and stratification:
#   d = 4.83 + 9.13 * sqrt(tau / (rho0 * N * f))
# NOTE(review): `tau` is also the name bound by the Kendall test in the Figure S5
# cells; it is rebound here to the wind stress.
tau = ns_mdl10m.svstr.values
# NOTE(review): the variable is called N2_int but is used as N in the formula - confirm
# that it holds the quantity the scaling expects.
N = ds_stra.N2_int.values
rho0 = 1025

omega = 7.292115e-5  # (Groten, 2004)                  [ radians/s ]
lat = -32            # degrees
# np.sin works in radians, so convert the latitude before taking the sine
f = np.abs(2*omega*np.sin(np.deg2rad(lat)))


# Negative tau gives NaN here (square root of a negative number)
d = 4.83 + 9.13*np.sqrt(tau/(rho0*N*f))

In [None]:
# Store the source depth and the wind stress in `ds_stra` along its "time" dimension,
# so that they can be grouped by month together with the existing variables.
ds_stra["source_depth"] = ("time", d)
ds_stra["wind"] = ("time", ns_mdl10m.svstr.values)

ds_stra

In [None]:
# Monthly median, mean and standard deviation of every variable in ds_stra
stra_by_month = ds_stra.groupby('time.month')
ds_stra_median = stra_by_month.median(dim='time')
ds_stra_mean = stra_by_month.mean(dim='time')
ds_stra_std = stra_by_month.std(dim='time')

In [None]:
# Figure S8: monthly box-and-whisker plot of the upwelling source depth
plt.rcParams.update({'font.size': 12})
var_name = 'source_depth'
var_data = ds_stra[var_name]
df = var_data.to_dataframe()
df['Month'] = df.index.month

# reorder months: map each calendar month to its 1-based position in `xticks`
# (one vectorised lookup instead of writing strings into the integer column)
month_position = {int(m): i + 1 for i, m in enumerate(xticks)}
df_copy = df.copy()
df_copy['Month'] = df_copy['Month'].map(month_position)

# Create a box-and-whisker plot using seaborn; `order` fixes 12 slots so a month
# without data cannot shift the boxes relative to the tick labels
fig = plt.figure(figsize=(5, 4))
ax = sns.boxplot(x='Month', y=var_name, data=df_copy, order=list(range(1, 13)), width=0.5, color='gray')

# plot medians, arranged by month label in plotting order (NaN if a month is absent)
monthly_median = var_data.groupby('time.month').median('time')
plt.plot(np.arange(0,12), monthly_median.reindex(month=xticks).values, marker='o', color='black')

N_monthly_median = monthly_median

# Customize the plot labels
plt.xlabel('Month')
plt.ylabel('Upwelling source depth (m)')
plt.xticks(np.arange(12), string_numbers);
#fig.savefig('./figures/figure_S8.png',bbox_inches='tight',dpi=200)

In [None]:
# Monthly wind: median in blue, mean in red
ds_stra_median.wind.plot(color='b', marker='o')
ds_stra_mean.wind.plot(color='r', marker='o')

In [None]:
# Monthly N2_int: median in blue (triangles), mean in red (circles)
ds_stra_median.N2_int.plot(color='b', marker='^')
ds_stra_mean.N2_int.plot(color='r', marker='o')

In [None]:
def _source_depth(wind, stratification):
    """Source-depth scaling evaluated with the notebook-level `rho0` and `f`."""
    return 4.83 + 9.13*np.sqrt(wind/(rho0 * stratification * f))

# Source depth computed from the monthly mean / monthly median wind and N2_int
d_mean = _source_depth(ds_stra_mean.wind, ds_stra_mean.N2_int)
d_median = _source_depth(ds_stra_median.wind, ds_stra_median.N2_int)

In [None]:
# Source depth from monthly medians (blue) and monthly means (red), September first
month_positions = np.arange(0, 12)
for monthly_depth, colour in ((d_median, 'b'), (d_mean, 'r')):
    plt.plot(month_positions, reorder(monthly_depth), marker='o', color=colour)

plt.xticks(np.arange(12), string_numbers);


# Figure S7

In [None]:
# Load the 'cleaned' sheet and turn it into an xarray Dataset indexed by time.
df_sheet = pd.read_excel('./data/Grant_2022_data.xlsx', sheet_name='cleaned')

# fill in the missing entries of the Time column with the last timestamp above them
# (a forward fill; unlike the manual loop it never wraps around to the last row
# when the very first entry is missing)
df_sheet['Time'] = df_sheet['Time'].ffill()

df_sheet.columns = ['time', 'Depth', 'Resp', 'NCP', 'GCP'] # rename

# Relabel depth 14 as 15. Done on the whole column at once: element-wise
# `df['Depth'][i] = ...` is chained assignment, which pandas warns about and
# which does not write through under copy-on-write.
df_sheet['Depth'] = df_sheet['Depth'].replace(14, 15)

df_sheet = df_sheet.set_index('time')

# Convert the DataFrame to an xarray Dataset
ds = xr.Dataset.from_dataframe(df_sheet)

In [None]:
# subsurface: depth > 8 or depth > 15
# subsurface: keep only the samples with Depth > 5
# NOTE(review): an earlier version of this comment mentioned thresholds of 8 or 15 -
# confirm that 5 is the intended cut-off.
ds_sub = ds.where(ds.Depth>5, drop=True)
# Monthly mean / std of the selected samples
mean = ds_sub.groupby('time.month').mean(dim='time')
std = ds_sub.groupby('time.month').std(dim='time')

# One entry per distinct sampling time
uni_date = np.unique(ds_sub.time)

# Output arrays, filled per sampling time in the next cell
Resp_int = np.zeros(uni_date.shape)
NCP_int = np.zeros(uni_date.shape)
GCP_int = np.zeros(uni_date.shape)

In [None]:
# average over different depth levels
def _depth_average(values, depth):
    """Depth-weighted mean of a profile: trapezoidal integral / total thickness.

    `values` and `depth` are 1-D sequences of equal length with at least two
    distinct depths, ordered along the profile.
    """
    values = np.asarray(values, dtype=float)
    depth = np.asarray(depth, dtype=float)
    layer_integrals = np.diff(depth) * (values[1:] + values[:-1]) / 2
    # Divide by the full extent of the profile. The previous expression divided
    # by the thickness of the last layer only, which for evenly spaced levels
    # doubles the result (a constant profile r came out as 2r).
    return layer_integrals.sum() / (depth[-1] - depth[0])


# average over different depth levels
for i, date in enumerate(uni_date):

    ds_ = ds_sub.sel(time=date) # select a profile

    if ds_.Depth.size <= 2:
        # one or two levels: plain mean
        Resp_int[i] = ds_.Resp.mean()
        NCP_int[i] = ds_.NCP.mean()
        GCP_int[i] = ds_.GCP.mean()
    else:
        # three or more levels: depth-weighted mean over all of them
        # (identical level set to before when a profile has exactly three levels)
        depth = ds_.Depth.values

        Resp_int[i] = _depth_average(ds_.Resp.values, depth)
        NCP_int[i] = _depth_average(ds_.NCP.values, depth)
        GCP_int[i] = _depth_average(ds_.GCP.values, depth)

# NOTE(review): 44.7 looks like a unit-conversion factor (the Figure S7 axis is in
# µmol O2 per litre per day) - confirm its origin.
ds_int = xr.Dataset(
    {
        'Resp': (['time'], Resp_int*44.7),
        'NCP': (['time'], NCP_int*44.7),
        'GCP': (['time'], GCP_int*44.7),
    },
    coords={'time': uni_date,},
)

In [None]:
def reorder(mean):
    """Insert a NaN at position 5 of `mean`, then move the last four entries to the front.

    For an 11-element monthly sequence with no sixth-month entry this yields
    12 values in September..August order with a gap at that month.
    Note: this rebinds the notebook-level name `reorder` used by earlier cells.
    """
    padded = np.insert(mean, 5, np.nan)
    tail, head = padded[-4:], padded[:-4]
    return np.concatenate([tail, head])

# Month numbers in plotting order (9..12 then 1..8) and their tick labels
xticks = np.r_[9:13, 1:9]
string_numbers = list(map(str, xticks))

In [None]:
# Figure S7: individual respiration rates by month (September first) with
# seasonal mean +/- std bars computed from the monthly means.
mean = ds_int.groupby('time.month').mean(dim='time')
std = ds_int.groupby('time.month').std(dim='time')

plt.rcParams.update({'font.size': 13})
fig = plt.figure(figsize=(5,4))
variable = ds_int.Resp

# Extract months from the time dimension
months = pd.to_datetime(variable['time'].values).month

# Position of each sample on a September-first axis: Sep -> 1, ..., Dec -> 4, Jan -> 5, ..., Aug -> 12
adjusted_months = ((months - 9) % 12 + 1).tolist()

# Plot scatter
plt.scatter(adjusted_months, variable.values, s=20,c='k')
plt.xlabel('Month')
plt.ylabel('Respiration rate (µmol O$_2$ L$^{-1}$ day$^{-1}$)')

# Customize x-axis ticks and labels
plt.xticks(ticks=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
           labels=[str(i) for i in [9,10,11,12,1,2,3,4,5,6,7,8]]);

def get_mean_std(months):
    '''
    Mean and standard deviation of the monthly-mean respiration over `months`.

    months: list of month numbers, e.g. [9,10,11]
    returns: (mean, std) of the selected monthly means
    '''
    monthly_resp = mean.Resp.sel(month=months).values
    return monthly_resp.mean(), monthly_resp.std()

# Seasons as month lists and the x position of each seasonal bar.
# NOTE(review): the last season uses only [7, 8]; presumably there is no month-6
# data in this dataset - confirm.
seasons = [[9,10,11], [12,1,2], [3,4,5], [7,8]]
season_labels = [2,5,8,11]

# Evaluate each season once and split into means and stds
season_stats = [get_mean_std(season) for season in seasons]
mean_values = [season_mean for season_mean, _ in season_stats]
std_values = [season_std for _, season_std in season_stats]
plt.bar(season_labels, mean_values, yerr=std_values, capsize=5, color='skyblue', ecolor='blue', \
        alpha=0.7, width=2, linewidth=3,)

fig.savefig('./figures/figure_S7.png', bbox_inches='tight', dpi=100)