In [1]:
from dask.distributed import Client,LocalCluster
from dask_jobqueue import PBSCluster
import os

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.markers import MarkerStyle
import cmocean
import numpy as np
import xarray as xr
import rioxarray as rxr
import dask as da
import geopandas as gpd
import cartopy as cart
import cartopy.crs as ccrs
import pandas as pd
import matplotlib.ticker as ticker
from matplotlib.gridspec import GridSpec
import emsarray as emr
import emsarray
import calendar
from shapely.geometry import Point, Polygon, box, shape
from alphashape import alphashape
import shapely
from tqdm.notebook import tqdm_notebook
import time
import regionmask
import re as re
from windrose import WindroseAxes
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

from datetime import datetime, timedelta
import glob
import PyCO2SYS as pyco2
from IPython.display import display, HTML
import great_circle_calculator.great_circle_calculator as gcc
import gsw as sw

import math as math

In [None]:
# Notebook-wide matplotlib defaults; every later figure in this kernel inherits them.
plt.rcParams.update({
    'figure.figsize': [10, 6],     # default figure size (width, height) in inches
    'figure.dpi': 100,             # on-screen resolution
    'savefig.dpi': 300,            # resolution used by savefig
    'font.size': 12,               # base font size
    'axes.titlesize': 14,          # subplot titles
    'axes.labelsize': 12,          # axis labels
    'xtick.labelsize': 11,         # x tick labels
    'ytick.labelsize': 11,         # y tick labels
    'legend.fontsize': 10,         # legend text
    'lines.linewidth': 1.5,        # line width
    'lines.markersize': 8,         # marker size
})

In [None]:
def safe_div(numerator, denominator):
    """Return ``numerator / denominator``, or 0 when the denominator equals zero.

    Intended for scalar operands: the zero test is evaluated as a plain truth
    value, so a multi-element array denominator is not supported.
    """
    if denominator == 0:
        return 0
    return numerator / denominator

def read_shoc_nc(path, ems=True):
    """Open all NetCDF files matching a glob pattern as one dataset concatenated along ``time``.

    Parameters
    ----------
    path : str
        Glob pattern; matching files are opened in sorted (lexicographic) order.
    ems : bool, default True
        If true, open through ``emsarray.accessors.xarray.open_mfdataset``,
        otherwise through ``xr.open_mfdataset``. Both receive identical options.

    Returns
    -------
    The dataset returned by ``open_mfdataset`` (files combined with
    ``combine='nested'`` along ``concat_dim='time'``, minimal data_vars/coords,
    ``compat='override'``, opened in parallel).

    Raises
    ------
    FileNotFoundError
        If ``path`` matches no files (an ``OSError`` subclass, as before).
    """
    files = sorted(glob.glob(path))
    if not files:
        # Fail with the offending pattern rather than an opaque error from inside xarray.
        raise FileNotFoundError(f"no files match {path!r}")

    # Only the chosen branch is evaluated, so the other module is never touched.
    opener = emsarray.accessors.xarray.open_mfdataset if ems else xr.open_mfdataset

    # `parallel` is a boolean option; the previous string 'true' only worked
    # because any non-empty string is truthy.
    return opener(files, concat_dim='time', combine='nested',
                  data_vars='minimal', compat='override',
                  coords='minimal', parallel=True)

def extent(data, pad=0.005):
    """Bounding box of a dataset's horizontal coordinates, padded on every side.

    Uses ``longitude``/``latitude`` when ``longitude`` is a coordinate of
    ``data``; otherwise falls back to ``x_centre``/``y_centre``. NaNs are ignored.

    Parameters
    ----------
    data : xarray.Dataset or DataArray
        Object exposing ``.coords`` and item access to the coordinate variables.
    pad : float, default 0.005
        Margin added below each minimum and above each maximum, in coordinate units.

    Returns
    -------
    tuple
        ``(x_min - pad, x_max + pad, y_min - pad, y_max + pad)``.
    """
    if 'longitude' in data.coords:
        x_name, y_name = 'longitude', 'latitude'
    else:
        x_name, y_name = 'x_centre', 'y_centre'

    x = data[x_name].values
    y = data[y_name].values

    return (np.nanmin(x) - pad, np.nanmax(x) + pad,
            np.nanmin(y) - pad, np.nanmax(y) + pad)

def running_mean(data, window_size):
    """Simple moving average over complete windows only (``np.convolve`` mode 'valid').

    Parameters
    ----------
    data : array_like, 1-D
        Input series.
    window_size : int
        Number of samples per window; must be >= 1.

    Returns
    -------
    numpy.ndarray
        ``len(data) - window_size + 1`` averages, or an empty array when the
        window is longer than the series.

    Raises
    ------
    ValueError
        If ``window_size < 1``.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    values = np.atleast_1d(np.asarray(data))
    if window_size > len(values):
        # np.convolve swaps its operands when the kernel is longer than the signal,
        # which would silently return sums over the whole series instead of means.
        return np.empty(0, dtype=float)

    kernel = np.ones(window_size) / window_size
    return np.convolve(values, kernel, mode='valid')

In [2]:
import numpy as np
from datetime import datetime
import os,sys
import time
from datetime import timedelta
from datetime import date

def shoc_days(years_since, year, month, day, hours=None, minutes=None, seconds=None):
    """Fractional days elapsed from 00:00 on 1 January ``years_since`` to the given time.

    ``hours``, ``minutes`` and ``seconds`` default to 0 when omitted (or falsy).
    The result is negative for timestamps before the reference date.

    Example: ``shoc_days(1990, 1990, 1, 2, 12)`` -> ``1.5``.
    """
    stamp = datetime(year, month, day, hours or 0, minutes or 0, seconds or 0)
    epoch = datetime(years_since, 1, 1)

    elapsed = stamp - epoch
    # Whole days plus the leftover seconds expressed as a fraction of a day.
    return elapsed.days + elapsed.seconds / 86400

def shoc_date(start_yr, days_in):
    """Convert a fractional day count since 1 January ``start_yr`` into a ``datetime``.

    ``days_in`` is a float such as 4718.236; ``start_yr`` may be an int or a
    numeric string. Inverse of ``shoc_days`` for the same reference year.
    """
    origin = datetime.strptime("%s-1-1" % (start_yr,), "%Y-%m-%d")
    return origin + timedelta(days=days_in)

In [None]:
def read_in_ts(ts_file_path):
    """Read a space-delimited time-series text file into a float DataFrame.

    Lines starting with ``#`` are header lines; those containing ``.name``
    contribute a column name, taken as the third token when the line is split
    on runs of spaces (e.g. ``## COLUMN1.name time`` -> ``time``). All other
    non-blank lines are data rows split on single spaces.

    Parameters
    ----------
    ts_file_path : str or path-like
        Path to the file.

    Returns
    -------
    pandas.DataFrame
        One column per ``.name`` header, in header order. Empty (but with those
        columns) when the file contains no data rows.
    """
    import io
    import re

    ts_file = str(ts_file_path)
    comment = '#'
    columns = []

    with open(ts_file, 'r') as td:
        for line in td:
            line = line.rstrip()
            if not line:
                # Blank lines used to raise IndexError on line[0].
                continue
            if line[0] == comment:
                if '.name' in line:
                    inf = re.split(' +', line)
                    columns.append(inf[2])
            else:
                # First data row: parse it here, then hand the rest of the file to pandas.
                first_row = pd.DataFrame([line.split(' ')], columns=columns, dtype=float)
                remainder = td.read()
                if not remainder.strip():
                    # A single data row: read_table would fail on empty input.
                    return first_row
                rest = pd.read_table(io.StringIO(remainder), sep=' ', header=None, names=columns)
                return pd.concat([first_row, rest], ignore_index=True)

    # Header-only file: previously an UnboundLocalError on `df`.
    return pd.DataFrame(columns=columns, dtype=float)

In [6]:
def _walltime_field(value):
    """Format one walltime component (hours or minutes) as a zero-padded PBS field."""
    if not value:
        return '00'
    if isinstance(value, str):
        # Strings are passed through untouched, as before.
        return value
    if float(value) != int(value):
        raise ValueError(f"walltime components must be whole numbers, got {value!r}")
    # int() first so that e.g. 2.0 becomes '02' rather than the invalid '02.0'.
    return f"{int(value):02d}"


def start_cluser(cores, hours=None, mins=None, queue='normalsr', project='ih54',
                 mem_per_core_gb=4,
                 storage='scratch/et4+gdata/et4+gdata/ew0+gdata/ih54'):
    """Create a dask-jobqueue ``PBSCluster`` for a single PBS job of ``cores`` workers.

    The name (without the 't') is kept for existing callers.

    Parameters
    ----------
    cores : int
        Cores requested per job; also used as the number of worker processes.
    hours, mins : int or str, optional
        Walltime components; ``None``/0 mean ``'00'``. Integers are zero-padded.
    queue : str, default 'normalsr'
        PBS queue (``-q``).
    project : str, default 'ih54'
        PBS project code (``-P``).
    mem_per_core_gb : int, default 4
        Memory requested per core; the job asks for ``mem_per_core_gb * cores`` GB.
    storage : str
        Value of the ``-l storage=`` directive.

    Returns
    -------
    dask_jobqueue.PBSCluster
        The (not yet scaled) cluster object.
    """
    walltime = f"{_walltime_field(hours)}:{_walltime_field(mins)}:00"
    memory = f"{mem_per_core_gb * cores}GB"

    return PBSCluster(walltime=walltime, cores=cores, memory=memory, processes=cores,
                      job_extra_directives=[f'-q {queue}',
                                            f'-P {project}',
                                            f'-l ncpus={cores}',
                                            f'-l mem={memory}',
                                            f'-l storage={storage}'],
                      # Literal '$TMPDIR'; presumably expanded by the shell inside the job script.
                      local_directory='$TMPDIR',
                      # Suppress the default '-l select=...' line; resources come from the directives above.
                      job_directives_skip=["select"])

In [None]:
def plume_size(xr_ds, input, threshold, var, k_level=18):
    """Area covered by a plume at each time step, in km^2.

    A cell belongs to the plume when ``xr_ds[var]`` at layer ``k_level``
    exceeds ``threshold``. The plume area is the sum of the areas of those cells,
    ``h1acell * h2acell`` from ``input``, divided by 1e6. NaN cell areas are ignored.

    Parameters
    ----------
    xr_ds : xarray.Dataset
        Holds ``var`` with a ``k`` dimension; presumably dims (time, k, j, i).
    input : xarray.Dataset
        Grid file providing ``h1acell`` and ``h2acell`` (presumably metres, giving m^2).
    threshold : float
        Strict lower bound for plume membership.
    var : str
        Variable name in ``xr_ds``.
    k_level : int, default 18
        Layer index passed to ``isel(k=...)``.

    Returns
    -------
    numpy.ndarray
        One area per time step.
    """
    cell_area = (input.h1acell * input.h2acell).values          # (j, i)
    in_plume = (xr_ds[var].isel(k=k_level) > threshold).values  # (time, j, i), boolean

    # Broadcast the 2-D cell areas across time and reduce each time slice at once;
    # this replaces the original per-time-step Python loop.
    return np.nansum(cell_area * in_plume / 1e6, axis=(1, 2))

In [1]:
def marine_res_mask(gpd_df, template_ds, boundary=None, lon_name='x', lat_name='y'):
    """Rasterise GeoDataFrame polygons onto a template grid as a 3-D region mask.

    The mask comes from ``regionmask.mask_3D_geopandas`` using
    ``template_ds[lon_name]`` / ``template_ds[lat_name]``. Those dimensions are
    then renamed to ``lon`` / ``lat``.

    Parameters
    ----------
    gpd_df : geopandas.GeoDataFrame
        Region polygons.
    template_ds : xarray.Dataset
        Supplies the horizontal coordinates.
    boundary : list, optional
        ``[lon_min, lon_max, lat_min, lat_max]``; when a *list* is given, the mask is
        clipped with label-based ``sel`` slices. Other types are ignored.
        NOTE(review): a descending ``lat`` coordinate would need the bounds reversed.
    lon_name, lat_name : str
        Coordinate names in ``template_ds``.
    """
    region_mask = regionmask.mask_3D_geopandas(
        gpd_df, template_ds[lon_name], template_ds[lat_name]
    )

    # Rename one dimension at a time, skipping names that are already standard.
    for source, target in ((lon_name, 'lon'), (lat_name, 'lat')):
        if source != target:
            region_mask = region_mask.rename({source: target})

    if not isinstance(boundary, list):
        return region_mask

    return region_mask.sel(
        lon=slice(boundary[0], boundary[1]),
        lat=slice(boundary[2], boundary[3]),
    )
        

In [2]:
def shoc_ssh_mask(dataset, original_in_file, var, depth_slice=False, depth=None):
    """Keep ``var`` only in layers at or below the level nearest the sea-surface height.

    For each time step and water column, the index of the ``z_grid`` value closest
    to ``eta`` is found (first index on ties). Layers ``k`` up to and including that
    index keep their value; every other entry is set to 0. With ``depth_slice=True``,
    layers must additionally satisfy ``k >=`` the index of the ``z_grid`` value closest
    to ``eta - depth``.
    NOTE(review): "below the surface" assumes ``z_grid`` increases with ``k`` — confirm.

    Parameters
    ----------
    dataset : xarray.Dataset
        Provides ``time``, ``k``, ``j``, ``i``, ``eta`` (time, j, i) and ``var``
        (time, k, j, i).
    original_in_file : xarray.Dataset
        Provides ``z_grid`` (level heights, same units as ``eta``).
    var : str
        Variable to mask.
    depth_slice : bool, default False
        Also apply the lower bound at ``eta - depth``.
    depth : float, optional
        Required when ``depth_slice`` is true.

    Returns
    -------
    xarray.Dataset
        A dataset holding only ``var`` on dims (time, k, j, i), float64, with no
        coordinates attached.

    Raises
    ------
    TypeError
        If ``depth_slice`` is true and ``depth`` is None.
    """
    if depth_slice and depth is None:
        raise TypeError("depth must be provided when depth_slice=True")

    n_time, n_k = len(dataset.time), len(dataset.k)
    n_j, n_i = len(dataset.j), len(dataset.i)

    levels = np.asarray(original_in_file.z_grid.values)[:, None, None]  # (level, 1, 1)
    layer_idx = np.arange(n_k)[:, None, None]                            # (k, 1, 1)
    eta_vals = dataset.eta.values
    var_vals = dataset[var].values

    masked = np.zeros((n_time, n_k, n_j, n_i))

    for t in tqdm_notebook(range(n_time), desc="progress"):
        eta = eta_vals[t, :n_j, :n_i]

        # Per-column index of the z_grid level nearest the surface elevation
        # (vectorised form of the former per-(j, i) argmin loop).
        surface_level = np.abs(levels - eta).argmin(axis=0)
        keep = layer_idx <= surface_level

        if depth_slice:
            bottom_level = np.abs(levels - (eta - depth)).argmin(axis=0)
            keep &= layer_idx >= bottom_level

        masked[t] = np.where(keep, var_vals[t, :n_k, :n_j, :n_i], 0)

    return xr.Dataset({var: (['time', 'k', 'j', 'i'], masked)})

def shoc_volume(input_array):
    """Per-cell volumes: horizontal cell area (``h1acell * h2acell``) times layer thickness.

    The thickness of layer ``k`` is ``z_grid[k + 1] - z_grid[k]``. The last layer,
    however, reuses the thickness of layer ``n - 2``; with a single layer, that index
    wraps to the only thickness available.
    NOTE(review): presumably deliberate because the final z_grid value is not a
    physical level — confirm.

    Returns
    -------
    numpy.ndarray
        Shape ``(len(z_grid) - 1, len(j_centre), len(i_centre))``, float64.
    """
    n_layers = len(input_array.z_grid) - 1

    thickness = np.diff(input_array.z_grid.values)
    thickness[n_layers - 1] = thickness[n_layers - 2]

    area = input_array.h1acell.values * input_array.h2acell.values

    volume = np.zeros((n_layers, len(input_array.j_centre), len(input_array.i_centre)))
    volume[:] = thickness[:, None, None] * area
    return volume

def shoc_mass_2(dataset, var, vol):
    """Sum ``dataset[var] * vol`` over the ``i``, ``j`` and ``k`` dimensions.

    ``vol`` is expected to broadcast against ``dataset[var]`` (e.g. the output of
    ``shoc_volume``). Returns the reduced values as a NumPy array, one entry per
    remaining dimension value (presumably time).
    """
    return (dataset[var] * vol).sum(dim=['i', 'j', 'k']).values

In [1]:
def longest_period_above_below_threshold(time_series, threshold, above=True):
    """Find the longest consecutive run of values strictly above (or below) a threshold.

    Parameters
    ----------
    time_series : iterable
        Values to scan, in order.
    threshold : comparable
        Strict bound (``>`` when ``above`` is true, ``<`` otherwise).
    above : bool, default True
        Direction of the comparison.

    Returns
    -------
    tuple of (int, int)
        ``(start_index, length)`` of the longest run; the earliest run wins ties.
        ``(0, 0)`` when no value satisfies the condition.
    """
    # A plain comparison replaces the former eval() of an f-string: same semantics,
    # no dynamic code execution, and far cheaper per element.
    if above:
        def passes(value):
            return value > threshold
    else:
        def passes(value):
            return value < threshold

    best_start, best_len = 0, 0
    run_start, run_len = 0, 0

    for i, value in enumerate(time_series):
        if passes(value):
            if run_len == 0:
                run_start = i
            run_len += 1
        else:
            if run_len > best_len:
                best_start, best_len = run_start, run_len
            run_len = 0

    # A run still open at the end of the series must be considered too.
    if run_len > best_len:
        best_start, best_len = run_start, run_len

    return best_start, best_len


In [None]:
def plot_with_custom_grid(fig, axes, rows, cols, projection=ccrs.PlateCarree()):
    """
    Apply gridlines to a figure with subplots such that only the leftmost plots
    have labels on the left axis, and the bottommost plots have labels on the bottom axis.

    Top and right gridline labels are switched off on every subplot.

    Parameters:
    fig : matplotlib.figure.Figure
        The figure object containing the subplots.
    axes : numpy.ndarray
        Array of axes objects returned from plt.subplots (1-D or 2-D). It is
        flattened in row-major order, so index i lies in column i % cols and
        row i // cols. Each axis must provide ``gridlines`` (e.g. cartopy GeoAxes).
    rows : int
        Number of rows in the subplot grid.
    cols : int
        Number of columns in the subplot grid.
    projection : cartopy.crs.Projection, optional
        The CRS passed to ``ax.gridlines``. Defaults to PlateCarree; note that the
        default instance is created once, when the function is defined.
    """
    
    axes = axes.flatten()  # Ensure axes are flattened in case they are a 2D array
    # NOTE(review): a lone Axes (from plt.subplots(1, 1)) has no .flatten(); wrap it
    # with np.atleast_1d first if that case needs to be supported.

    for i, ax in enumerate(axes):
        # Add gridlines with labels
        gl = ax.gridlines(crs=projection, draw_labels=True, linewidth=0.5, color='gray', alpha=0.5, linestyle='--')
        
        # Remove top and right labels for all plots
        gl.top_labels = False
        gl.right_labels = False
        
        # Only show left labels for the leftmost plots
        # (row-major flattening: i % cols == 0 is the first column)
        if (i % cols) != 0:
            gl.left_labels = False

        # Only show bottom labels for the bottommost plots
        # (indices below (rows - 1) * cols belong to every row except the last)
        if i < (rows - 1) * cols:
            gl.bottom_labels = False