# OSNAP Parcels Experiments


## Technical preamble

In [None]:
# import matplotlib.colors as colors
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import xarray as xr
from datetime import datetime, timedelta
import seaborn as sns
# from matplotlib.colors import ListedColormap
import cmocean as co
import pandas as pd
import matplotlib.dates as mdates
import cartopy.crs as ccrs
import cartopy
import seawater as sw


# Ask xarray to keep attributes (attrs) on the results of its operations.
xr.set_options(keep_attrs=True)


In [None]:
## Parameters
# Project root: two directory levels above the current working directory.
project_path = (Path.cwd() / '..' / '..').resolve()

# Parcels track data files (paths relative to the project root)
path_data_tracks = Path('data/processed/tracks/osnap/')
filename_tracks_1 = 'tracks_osnap_backward_201301_N19886_D730_Rnd123.nc'
filename_tracks_2 = 'tracks_osnap_backward_201307_N19886_D730_Rnd123.nc'

# Model mesh/mask file
data_path = Path("data/external/iAtlantic/")
experiment_name = "VIKING20X.L46-KKG36107B"
mesh_mask_file = project_path.joinpath(data_path, "mask", experiment_name, "1_mesh_mask.nc")

# Section lon/lat waypoint files
sectionPath = Path('data/external/sections/')
sectionFilename = 'osnap_pos_wp.txt'
sectionname = 'osnap'
gsrsectionFilename = 'gsr_pos_wp.txt'

# Kilometres per degree (1.852 km times 60)
degree2km = 1.852*60.0

# Some transport values specific to the osnap runs:
# 39995 particles were seeded randomly, 19886 of them on ocean points
# (the rest were on land).
osnap_section_length = 3594572.87839    # m
osnap_section_depth = 4000  # m over which particles launched

# Area of the full launch rectangle (length times depth).
_launch_area = osnap_section_length * osnap_section_depth

# Launch area scaled by the fraction of particles seeded on ocean points.
osnap_section_ocean_area = _launch_area * 19886.0 / 39995.0

# Launch area divided by the number of seeded particles.
particle_section_area = _launch_area / 39995.0


## plotting functions

In [None]:
# range of stations from west to east, stations 0-12. Python indexing.

def conditionalplotTracksBetweenLonsCartopy(ds,condition,lonRange,depthRange,cmap=co.cm.tempo_r):
    """Map the tracks of selected particles, coloured by date.

    A particle is kept when its first observation (``obs=0``) lies strictly
    inside ``lonRange`` and ``depthRange`` and ``condition`` holds for it.
    Bathymetry contours are drawn from the module-level ``depth`` array.

    Parameters
    ----------
    ds
        Track dataset providing ``lon``, ``lat``, ``z`` and ``time`` and the
        ``traj``/``obs`` dimensions.
    condition
        Mask handed to ``where`` to choose which tracks are drawn.
    lonRange, depthRange
        Two-element sequences ``[lower, upper]``.
    cmap
        Colormap used for the date colouring of the track points.
    """
    first_obs = ds.isel(obs=0)
    lon_low, lon_high = lonRange[0], lonRange[1]
    z_low, z_high = depthRange[0], depthRange[1]

    # particles released inside the requested longitude/depth box
    released_in_box = ds.where(first_obs.lon > lon_low).where(first_obs.lon < lon_high)
    released_in_box = released_in_box.where(first_obs.z > z_low).where(first_obs.z < z_high)

    # of those, only the ones satisfying the condition; drop fully-masked tracks
    selected = released_in_box.where(condition).dropna('traj', how='all')

    central_lon, central_lat = -30, 50
    sns.set(style="whitegrid")
    map_projection = ccrs.Orthographic(central_lon, central_lat)
    fig, ax = plt.subplots(figsize=(16, 10), subplot_kw={'projection': map_projection})
    ax.set_extent([-60, 0, 30, 70])
    ax.gridlines()
    ax.coastlines(resolution='50m')

    # every position of the selected tracks, coloured by its date
    track_points = ax.scatter(
        selected.lon.data.flatten(),
        selected.lat.data.flatten(),
        s=3,
        c=mdates.date2num(selected.time.data.flatten()),
        cmap=cmap,
        zorder=2,
        transform=ccrs.PlateCarree(),
    )
    colorbar = fig.colorbar(track_points, ax=ax, label="date")
    month_locator = mdates.MonthLocator()
    colorbar.ax.yaxis.set_major_locator(month_locator)
    colorbar.ax.yaxis.set_major_formatter(mdates.ConciseDateFormatter(month_locator))

    # depth contours underneath the tracks
    depth_subset = depth.isel(y=slice(1400, 2499), x=slice(0, 2404))
    depth_subset.plot.contour(
        ax=ax,
        transform=ccrs.PlateCarree(),
        x="nav_lon",
        y="nav_lat",
        colors='grey',
        levels=[200, 800, 1500, 2000, 2500, 3500],
        zorder=1,
    )

    # release positions: all particles first, then those released in the box on top
    for released, layer in ((ds, 4), (released_in_box, 5)):
        ax.scatter(
            released.lon.isel(obs=0).data.flatten(),
            released.lat.isel(obs=0).data.flatten(),
            s=2,
            zorder=layer,
            transform=ccrs.PlateCarree(),
        )


### Position test function

In [None]:
def apply_left_of_line(ds, lon_1, lon_2, lat_1, lat_2):
    '''Classify positions by the side of a directed line they fall on.

    The line runs from (lon_1, lat_1) to (lon_2, lat_2). For every position
    in ``ds`` (read from ``ds.lon`` and ``ds.lat``) the sign of the 2-D cross
    product between the line direction and the vector from the line start to
    the position is evaluated in the lon/lat plane.

    Returns a pair of boolean masks ``(left_or_on, right)``: the first is True
    where the cross product is >= 0, the second where it is < 0.
    '''
    d_lon = lon_2 - lon_1
    d_lat = lat_2 - lat_1

    # positive when the position lies to the left of the directed line
    cross = d_lon * (ds.lat - lat_1) - (ds.lon - lon_1) * d_lat

    left_or_on = cross >= 0.0
    right = cross < 0
    return left_or_on, right

## Load data

In [None]:
# Open both track files from the processed-tracks directory.
tracks_dir = project_path / path_data_tracks
ds_1, ds_2 = (
    xr.open_dataset(tracks_dir / track_file)
    for track_file in (filename_tracks_1, filename_tracks_2)
)

display(ds_1)
# ds.isel(obs=0).z.max()

## Velocity conversions from degrees lat/lon per second to m/s

In [None]:
def _with_velocities_in_ms(ds):
    """Return ``ds`` with ``uvel_ms`` and ``vvel_ms`` added.

    Both velocities are scaled by ``degree2km * 1000``; the zonal one is
    additionally multiplied by the cosine of the particle latitude.
    """
    zonal = ds.uvel * degree2km * 1000.0 * np.cos(np.radians(ds.lat))
    meridional = ds.vvel * degree2km * 1000.0
    return ds.assign(uvel_ms=zonal, vvel_ms=meridional)


ds_1 = _with_velocities_in_ms(ds_1)
ds_2 = _with_velocities_in_ms(ds_2)


In [None]:
# Open the mesh mask, squeeze out length-one dimensions and promote the
# navigation variables to coordinates.
mesh_mask = (
    xr.open_dataset(mesh_mask_file)
    .squeeze()
    .set_coords(["nav_lon", "nav_lat", "nav_lev"])
)

bathy = mesh_mask.mbathy.rename("number of water filled points")

# Sum over z of e3t_0 times tmask (presumably cell thickness times ocean
# mask, i.e. the water depth of each column -- confirm against the model docs).
masked_thickness = mesh_mask.e3t_0 * mesh_mask.tmask
depth = masked_thickness.sum("z")
# display(mesh_mask)

### section position data

In [None]:
# Read the whitespace-separated section waypoints into a Dataset.
# ``sep=r'\s+'`` is the documented equivalent of ``delim_whitespace=True``,
# which recent pandas versions deprecate.
section_file = project_path / sectionPath / sectionFilename
lonlat = xr.Dataset(pd.read_csv(section_file, sep=r'\s+'))

# Flip the sign of the stored longitudes (presumably the file lists degrees
# west as positive numbers -- confirm against the waypoint file).
lonlat['lon'] *= -1.0

lonlat

In [None]:
# Descriptive metadata for the waypoint coordinates
lonlat.lon.attrs.update(long_name='Longitude', standard_name='longitude', units='degrees_east')
lonlat.lat.attrs.update(long_name='Latitude', standard_name='latitude', units='degrees_north')

# Two-point rolling mean: from the second element on, the mean of each
# waypoint and the one before it.
lonlat2mean = lonlat.rolling({'dim_0': 2}).mean()

# Differences between consecutive waypoints
lonlatdiff = lonlat.diff('dim_0')

# Scale the lat/lon differences by degree2km; the zonal part is also scaled
# by the cosine of the two-point mean latitude.
cos_mean_lat = np.cos(np.radians(lonlat2mean.lat.data[1:]))
lonlatdiff = lonlatdiff.assign(y=lonlatdiff['lat'] * degree2km)
lonlatdiff = lonlatdiff.assign(x=lonlatdiff['lon'] * degree2km * cos_mean_lat)
lonlatdiff = lonlatdiff.assign(length=np.sqrt(lonlatdiff['x']**2 + lonlatdiff['y']**2))

# Cosine and sine of each segment's direction
lonlatdiff = lonlatdiff.assign(
    costheta=lonlatdiff['x'] / lonlatdiff['length'],
    sintheta=lonlatdiff['y'] / lonlatdiff['length'],
)

total_length = lonlatdiff.length.sum().data
# exclude section across UK - just there for testing north/south
total_osnap_length = lonlatdiff.length[0:12].sum().data


lonlatdiff

## Find initial velocities normal to the section

In [None]:
def _first_obs_with_section_index(ds):
    """Return the first observation of ``ds`` with a ``section_index`` per trajectory.

    The index is ``np.searchsorted`` of the particle longitudes in the
    waypoint longitudes, minus one (assumes ``lonlat.lon`` is sorted in
    ascending order, as ``searchsorted`` requires).
    """
    first_obs = ds.isel(obs=0)
    segment = np.searchsorted(lonlat.lon, first_obs.lon) - 1
    return first_obs.assign(section_index=xr.DataArray(segment, dims='traj'))


ds_1_init = _first_obs_with_section_index(ds_1)
ds_2_init = _first_obs_with_section_index(ds_2)


In [None]:
def _with_section_relative_velocities(ds_init):
    """Return ``ds_init`` with ``u_normal`` and ``u_along`` added.

    The m/s velocities are rotated with the cosine/sine of the section
    segment each particle was assigned to (``section_index``).
    """
    cos_theta = lonlatdiff.costheta[ds_init.section_index].data
    sin_theta = lonlatdiff.sintheta[ds_init.section_index].data

    ds_init = ds_init.assign(u_normal=ds_init.vvel_ms * cos_theta - ds_init.uvel_ms * sin_theta)
    return ds_init.assign(u_along=ds_init.vvel_ms * sin_theta + ds_init.uvel_ms * cos_theta)


ds_1_init = _with_section_relative_velocities(ds_1_init)
ds_2_init = _with_section_relative_velocities(ds_2_init)


## Find along-section distances of initial points

In [None]:
# Cumulative segment length with a leading zero: entry i is the summed length
# of the first i segments.
_leading_zero = xr.DataArray([0], dims=("dim_0"), coords={"dim_0": [0]})
length_west = xr.concat((_leading_zero, lonlatdiff.length.cumsum()), dim='dim_0')


def _with_along_section_distance(ds_init):
    """Return ``ds_init`` with the along-section coordinate ``x`` added.

    ``x`` is the cumulative length up to the particle's segment plus the
    segment length scaled by the particle's fractional longitude offset
    within that segment.
    """
    segment = ds_init.section_index
    lon_offset = ds_init.lon - lonlat.lon[segment]
    distance = length_west[segment] + lonlatdiff.length[segment] * lon_offset / lonlatdiff.lon[segment]
    return ds_init.assign(x=xr.DataArray(distance, dims='traj'))


ds_1_init = _with_along_section_distance(ds_1_init)
ds_2_init = _with_along_section_distance(ds_2_init)

## Plot section

In [None]:
sns.set(style="whitegrid")
central_lon, central_lat = -30, 55
extent = [-60, 0, 40, 70]

section_projection = ccrs.Orthographic(central_lon, central_lat)
fig, ax = plt.subplots(subplot_kw={'projection': section_projection})
ax.set_extent(extent)
ax.gridlines()
ax.coastlines(resolution='50m')

# Waypoints first, then the two-point rolling means of the waypoints.
section_scatter_kwargs = dict(ax=ax, transform=ccrs.PlateCarree(), x='lon', y='lat')
lonlat.plot.scatter(**section_scatter_kwargs)
lonlat2mean.plot.scatter(**section_scatter_kwargs)

## Have a quick look

### Release positions of the particles

In [None]:
sns.set(style="darkgrid")
fig,ax = plt.subplots(3,2,figsize = (18,12))


def _release_panel(panel, positions, values, limits, cmap, label):
    """Scatter ``values`` at the release positions (x, z) of ``positions`` on ``panel``."""
    points = panel.scatter(positions.x.data.flatten(),
                           positions.z.data.flatten(),
                           2,
                           values.data.flatten(),
                           vmin=limits[0], vmax=limits[1],
                           cmap=cmap)
    panel.set_ylim(0,4000)
    panel.invert_yaxis()  # depth increases downwards
    fig.colorbar(points,ax=panel,label = label,extend='both')
    return points


# Second experiment minus first experiment at release.
ds_init_diff = ds_2_init - ds_1_init

# Raw strings keep the mathtext ``\degree`` while avoiding the invalid
# ``\d`` escape sequence of a normal string literal.
_theta_label = r"potential temperature [$\degree$C]"
_theta_diff_label = r"temperature difference [$\degree$C]"

# Rows: experiment 1, experiment 2, difference (drawn at experiment-1 positions).
# Columns: salinity, temperature.
_release_panels = [
    (ax[0,0], ds_1_init, ds_1_init.salt, (34.5, 35.5), co.cm.haline, "Salinity [PSU]"),
    (ax[1,0], ds_2_init, ds_2_init.salt, (34.5, 35.5), co.cm.haline, "Salinity [PSU]"),
    (ax[2,0], ds_1_init, ds_init_diff.salt, (-0.3, 0.3), co.cm.delta, "Salinity difference [PSU]"),
    (ax[0,1], ds_1_init, ds_1_init.temp, (-1, 11), co.cm.thermal, _theta_label),
    (ax[1,1], ds_2_init, ds_2_init.temp, (-1, 11), co.cm.thermal, _theta_label),
    (ax[2,1], ds_1_init, ds_init_diff.temp, (-3, 3), co.cm.diff, _theta_diff_label),
]

for _panel_spec in _release_panels:
    pcm = _release_panel(*_panel_spec)





In [None]:
# Release box used by the plots and transports below: [lower, upper] bounds
# compared against the particles' initial lon and z.
lonRange, depthRange = [-30, -20], [0, 500]

### Plot tracks conditionally

#### From the Labrador Sea

In [None]:
# Directed line from (lon -75, lat 40) to (lon -40, lat 65); the "in" mask is
# True for positions on or to the left of it.
lab_sea_line = dict(lon_1=-75, lon_2=-40, lat_1=40, lat_2=65)
ds_1_lab_sea_in, ds_1_lab_sea_notin = apply_left_of_line(ds_1, **lab_sea_line)
ds_2_lab_sea_in, ds_2_lab_sea_notin = apply_left_of_line(ds_2, **lab_sea_line)

In [None]:
# Directed line along lon -60 from lat 40 to lat 65; the "in" mask is True for
# positions with lon <= -60.
line_60w = dict(lon_1=-60, lon_2=-60, lat_1=40, lat_2=65)
ds_1_60w_in, ds_1_60w_notin = apply_left_of_line(ds_1, **line_60w)
ds_2_60w_in, ds_2_60w_notin = apply_left_of_line(ds_2, **line_60w)

In [None]:

# Experiment 1: tracks that are on the "in" side of the Labrador Sea line at
# one or more observations.
conditionalplotTracksBetweenLonsCartopy(
    ds_1,
    ds_1_lab_sea_in.max("obs"),
    lonRange=lonRange,
    depthRange=depthRange,
    cmap=co.cm.tempo_r,
)

In [None]:
# Experiment 2: tracks that are on the "in" side of the Labrador Sea line at
# one or more observations.
conditionalplotTracksBetweenLonsCartopy(
    ds_2,
    ds_2_lab_sea_in.max("obs"),
    lonRange=lonRange,
    depthRange=depthRange,
    cmap=co.cm.tempo_r,
)

In [None]:
# Experiment 1: tracks that are never on the "in" side of the Labrador Sea line.
conditionalplotTracksBetweenLonsCartopy(
    ds_1,
    ~ds_1_lab_sea_in.max("obs"),
    lonRange=lonRange,
    depthRange=depthRange,
    cmap=co.cm.tempo_r,
)

In [None]:
# Experiment 2: tracks that are never on the "in" side of the Labrador Sea line.
conditionalplotTracksBetweenLonsCartopy(
    ds_2,
    ~ds_2_lab_sea_in.max("obs"),
    lonRange=lonRange,
    depthRange=depthRange,
    cmap=co.cm.tempo_r,
)

In [None]:
# Experiment 1: never on the "in" side of the Labrador Sea line, but on the
# "in" side of the 60W line at one or more observations.
conditionalplotTracksBetweenLonsCartopy(
    ds_1,
    ~ds_1_lab_sea_in.max("obs") & ds_1_60w_in.max("obs"),
    lonRange=lonRange,
    depthRange=depthRange,
    cmap=co.cm.tempo_r,
)

In [None]:
# Experiment 2: never on the "in" side of the Labrador Sea line, but on the
# "in" side of the 60W line at one or more observations.
conditionalplotTracksBetweenLonsCartopy(
    ds_2,
    ~ds_2_lab_sea_in.max("obs") & ds_2_60w_in.max("obs"),
    lonRange=lonRange,
    depthRange=depthRange,
    cmap=co.cm.tempo_r,
)

In [None]:
# Experiment 1: tracks that are never on the "in" side of either line.
conditionalplotTracksBetweenLonsCartopy(
    ds_1,
    ~ds_1_lab_sea_in.max("obs") & ~ds_1_60w_in.max("obs"),
    lonRange=lonRange,
    depthRange=depthRange,
    cmap=co.cm.tempo_r,
)

In [None]:
# Experiment 2: tracks that are never on the "in" side of either line.
conditionalplotTracksBetweenLonsCartopy(
    ds_2,
    ~ds_2_lab_sea_in.max("obs") & ~ds_2_60w_in.max("obs"),
    lonRange=lonRange,
    depthRange=depthRange,
    cmap=co.cm.tempo_r,
)

In [None]:
# there are neater ways to do this with xarray if the datasets are large
def get_TS(ds):
    """Flatten salinity, temperature, time and depth of ``ds`` for T-S plotting.

    Entries whose salinity equals exactly 0.0 are masked with ``where`` before
    flattening. Axis limits are the NaN-ignoring data range widened by 1/20 of
    that range on each side.

    Returns
    -------
    S, T, time, z : flattened arrays of ``salt``, ``temp``, ``time`` and ``z``
    S_lim, T_lim : ``[lower, upper]`` padded limits for salinity and temperature
    """
    nonzero_salt = ds.where(ds.salt != 0.0)

    S = nonzero_salt.salt.data.flatten()
    T = nonzero_salt.temp.data.flatten()
    time = nonzero_salt.time.data.flatten()
    z = nonzero_salt.z.data.flatten()

    T_low, T_high = np.nanmin(T), np.nanmax(T)
    S_low, S_high = np.nanmin(S), np.nanmax(S)
    T_pad = (T_high - T_low)/20.0
    S_pad = (S_high - S_low)/20.0

    T_lim = [T_low - T_pad, T_high + T_pad]
    S_lim = [S_low - S_pad, S_high + S_pad]

    return S,T,time,z,S_lim,T_lim


In [None]:
def TSplot_colourbytime(S_1,T_1,time_1,S_2,T_2,time_2,S_lim,T_lim):
    """Draw two side-by-side T-S diagrams with the points coloured by date.

    Parameters
    ----------
    S_1, T_1, time_1
        Salinity, temperature and time samples for the left panel.
    S_2, T_2, time_2
        The same for the right panel.
    S_lim, T_lim
        ``[lower, upper]`` axis limits, also used to span the grid on which
        the dashed density contours are computed.
    """
    # 100 x 100 grid over the axis limits for the background contours of
    # sw.pden(S, T, 0, 0) - 1000.
    Tgrid = np.zeros((100,100)) + np.linspace(T_lim[0],T_lim[1],100)
    Sgrid = np.zeros((100,100)) + np.linspace(S_lim[0],S_lim[1],100)
    sig0grid = sw.pden(Sgrid,Tgrid.T,0,0) - 1000.0

    sns.set(style="darkgrid")
    fig,ax = plt.subplots(1,2,figsize=(14,9),sharey=True,gridspec_kw={'wspace': 0.1})

    def draw_panel(axis, S, T, time, ylabel):
        """Contours, date-coloured scatter, colorbar, limits and labels for one panel."""
        cs = axis.contour(Sgrid,Tgrid.T,sig0grid,
                          colors = 'grey', linestyles = 'dashed')
        axis.clabel(cs, inline=1, fmt='%1.1f', fontsize=10)

        points = axis.scatter(S, T,
                              s = 2,
                              c = mdates.date2num(time),
                              cmap = co.cm.tempo_r,
                              zorder=10,
                              alpha=0.3)

        cbar = fig.colorbar(points,ax=axis,orientation='horizontal',pad=0.1)
        # Show the colorbar fully opaque although the points are translucent.
        # Colorbar.draw_all() is deprecated/removed in newer Matplotlib, so the
        # alpha is set directly on the drawn colour patch instead.
        cbar.set_alpha(1.0)
        if cbar.solids is not None:
            cbar.solids.set_alpha(1.0)

        # format the dates on the colorbar
        loc = mdates.AutoDateLocator()
        cbar.ax.xaxis.set_major_locator(loc)
        cbar.ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(loc))

        axis.set_ylim(T_lim[0],T_lim[1])
        axis.set_xlim(S_lim[0],S_lim[1])
        axis.set_ylabel(ylabel)
        axis.set_xlabel('practical salinity [PSU]')

    # Raw string: keeps the mathtext ``\degree`` without an invalid ``\d`` escape.
    draw_panel(ax[0], S_1, T_1, time_1, r'potential temperature [$\degree$C]')
    draw_panel(ax[1], S_2, T_2, time_2, '')

In [None]:
def TSplot(S_1,T_1,z_1,S_2,T_3,z_3,zlabel,S_lim,T_lim):
    """Draw two side-by-side T-S diagrams with the points coloured by ``z``.

    Parameters
    ----------
    S_1, T_1, z_1
        Salinity, temperature and colour values for the left panel.
    S_2, T_3, z_3
        Salinity, temperature and colour values for the right panel. (The
        parameter names are kept as they were for compatibility; the body
        previously ignored ``T_3``/``z_3`` and read the globals ``T_2``/``z_2``.)
    zlabel
        Label for both colorbars.
    S_lim, T_lim
        ``[lower, upper]`` axis limits, also used to span the grid on which
        the dashed density contours are computed.
    """
    # 100 x 100 grid over the axis limits for the background contours of
    # sw.pden(S, T, 0, 0) - 1000.
    Tgrid = np.zeros((100,100)) + np.linspace(T_lim[0],T_lim[1],100)
    Sgrid = np.zeros((100,100)) + np.linspace(S_lim[0],S_lim[1],100)
    sig0grid = sw.pden(Sgrid,Tgrid.T,0,0) - 1000.0

    sns.set(style="darkgrid")
    fig,ax = plt.subplots(1,2,figsize=(14,9),sharey=True,gridspec_kw={'wspace': 0.1})

    def draw_panel(axis, S, T, z, ylabel):
        """Contours, z-coloured scatter, colorbar, labels and limits for one panel."""
        cs = axis.contour(Sgrid,Tgrid.T,sig0grid,
                          colors = 'grey', linestyles = 'dashed')
        axis.clabel(cs, inline=1, fmt='%1.1f', fontsize=10)

        points = axis.scatter(S, T,
                              s = 2,
                              c = z,
                              cmap = co.cm.deep,
                              zorder=10,
                              alpha=0.3)

        cbar = fig.colorbar(points,label = zlabel,ax=axis,orientation='horizontal',pad=0.1)
        # Show the colorbar fully opaque although the points are translucent.
        # Colorbar.draw_all() is deprecated/removed in newer Matplotlib, so the
        # alpha is set directly on the drawn colour patch instead.
        cbar.set_alpha(1.0)
        if cbar.solids is not None:
            cbar.solids.set_alpha(1.0)

        axis.set_ylabel(ylabel)
        axis.set_xlabel('practical salinity [PSU]')
        axis.set_ylim(T_lim[0],T_lim[1])
        axis.set_xlim(S_lim[0],S_lim[1])

    # Raw string: keeps the mathtext ``\degree`` without an invalid ``\d`` escape.
    draw_panel(ax[0], S_1, T_1, z_1, r'potential temperature [$\degree$C]')
    # Use the function's own arguments for the right panel, not module globals.
    draw_panel(ax[1], S_2, T_3, z_3, '')


In [None]:
def _release_box_early_obs(ds):
    """Keep trajectories released inside ``lonRange``/``depthRange``, observations 1 to 9."""
    first_obs = ds.isel(obs=0)
    boxed = ds.where(first_obs.lon > lonRange[0]).where(first_obs.lon < lonRange[1])
    boxed = boxed.where(first_obs.z > depthRange[0]).where(first_obs.z < depthRange[1])
    return boxed.dropna('traj', how='all').isel(obs=slice(1,10,1))


dsmask_1 = _release_box_early_obs(ds_1)
S_1,T_1,time_1,z_1,S_lim,T_lim = get_TS(dsmask_1)

dsmask_2 = _release_box_early_obs(ds_2)
# S_lim/T_lim are overwritten here with the limits of the second experiment.
S_2,T_2,time_2,z_2,S_lim,T_lim = get_TS(dsmask_2)

# Fixed axis limits used for the T-S plots below
S_lim_fix = [32,37]
T_lim_fix = [-1,24]

In [None]:
# T-S diagrams of both experiments, coloured by date, on the fixed axis limits.
TSplot_colourbytime(
    S_1=S_1, T_1=T_1, time_1=time_1,
    S_2=S_2, T_2=T_2, time_2=time_2,
    S_lim=S_lim_fix, T_lim=T_lim_fix,
)

In [None]:
# T-S diagrams of both experiments, coloured by z, on the fixed axis limits.
TSplot(
    S_1, T_1, z_1,
    S_2, T_2, z_2,
    'depth [m]',
    S_lim_fix, T_lim_fix,
)

## Transports

In [None]:
sns.set(style="darkgrid")
fig, ax = plt.subplots(2, figsize=(14, 10))

# Shared styling for both experiments
u_normal_style = dict(cmap=co.cm.balance, vmin=-0.3, vmax=0.3)
u_normal_colorbar = dict(label="u_normal", extend='both')

# Experiment 1 (top)
cm = ax[0].scatter(ds_1_init.lon, ds_1_init.z, s=2, c=ds_1_init.u_normal, **u_normal_style)
ax[0].invert_yaxis()
fig.colorbar(cm, ax=ax[0], **u_normal_colorbar)

# Experiment 2 (bottom)
cm = ax[1].scatter(ds_2_init.lon, ds_2_init.z, s=2, c=ds_2_init.u_normal, **u_normal_style)
ax[1].invert_yaxis()
fig.colorbar(cm, ax=ax[1], **u_normal_colorbar)

In [None]:
sns.set(style="darkgrid")
fig, ax = plt.subplots(2, figsize=(14, 10))

# Same scatter as for the whole section, with larger markers and the axes
# restricted to lonRange/depthRange.
for _axis, _ds_init in zip(ax, (ds_1_init, ds_2_init)):
    cm = _axis.scatter(_ds_init.lon, _ds_init.z, s=50, c=_ds_init.u_normal,
                       cmap=co.cm.balance, vmin=-0.3, vmax=0.3)
    fig.colorbar(cm, ax=_axis, label="u_normal", extend='both')

for _axis in ax:
    _axis.set_xlim(lonRange)
    _axis.set_ylim(depthRange)
    _axis.invert_yaxis()

In [None]:
def transports(ds_1_init,ds_2_init):
    """Print transports through the release box for two experiments.

    Particles outside the module-level ``lonRange``/``depthRange`` are masked
    out. Five lines are printed, each with the value for experiment 1 followed
    by experiment 2: the sum of ``u_normal`` times ``particle_section_area``,
    the same sum weighted by ``temp``, the same sum weighted by ``salt``, and
    the temperature- and salt-weighted sums divided by the unweighted one.
    """
    def release_box(ds_init):
        # keep only particles strictly inside the lon/depth box
        boxed = ds_init.where(ds_init.lon > lonRange[0]).where(ds_init.lon < lonRange[1])
        return boxed.where(ds_init.z > depthRange[0]).where(ds_init.z < depthRange[1])

    def integrate(per_particle):
        # sum over particles, each weighted by particle_section_area
        return per_particle.sum(dim='traj')*particle_section_area

    boxed_runs = [release_box(ds_1_init), release_box(ds_2_init)]

    volume = [integrate(run.u_normal).data for run in boxed_runs]
    temperature = [integrate(run.u_normal * run.temp).data for run in boxed_runs]
    salt = [integrate(run.u_normal * run.salt).data for run in boxed_runs]

    print(volume[0], volume[1])
    print(temperature[0], temperature[1])
    print(salt[0], salt[1])
    print(temperature[0]/volume[0], temperature[1]/volume[1])
    print(salt[0]/volume[0], salt[1]/volume[1])

In [None]:
# Transports through the release box for all particles of both experiments.
transports(ds_1_init=ds_1_init, ds_2_init=ds_2_init)

In [None]:
# conda list

In [None]:
# Per-trajectory flag: True if the "in" mask is True at any observation.
ds_1_lab_sea_in.max(dim="obs")

In [None]:
def section_scatter_plot(ds_1_init,ds_2_init):
    """Scatter ``u_normal`` against lon and z for two experiments, one panel each.

    Both panels are restricted to the module-level ``lonRange`` and
    ``depthRange`` and drawn with the y axis inverted.
    """
    sns.set(style="darkgrid")
    fig, axes = plt.subplots(2, figsize=(14, 10))

    for axis, ds_init in zip(axes, (ds_1_init, ds_2_init)):
        points = axis.scatter(
            ds_init.lon,
            ds_init.z,
            s=50,
            c=ds_init.u_normal,
            cmap=co.cm.balance,
            vmin=-0.3,
            vmax=0.3,
        )
        fig.colorbar(points, ax=axis, label="u_normal", extend='both')

        axis.set_xlim(lonRange)
        axis.set_ylim(depthRange)
        axis.invert_yaxis()

In [None]:
# Release-box scatter restricted to particles that are on the "in" side of the
# Labrador Sea line at one or more observations.
section_scatter_plot(
    ds_1_init.where(ds_1_lab_sea_in.max("obs")),
    ds_2_init.where(ds_2_lab_sea_in.max("obs")),
)

In [None]:
# Transports of particles that are on the "in" side of the Labrador Sea line
# at one or more observations.
transports(
    ds_1_init.where(ds_1_lab_sea_in.max("obs")),
    ds_2_init.where(ds_2_lab_sea_in.max("obs")),
)

In [None]:
# Release-box scatter of particles never on the "in" side of the Labrador Sea
# line but on the "in" side of the 60W line at one or more observations.
section_scatter_plot(
    ds_1_init.where(~ds_1_lab_sea_in.max("obs")).where(ds_1_60w_in.max("obs")),
    ds_2_init.where(~ds_2_lab_sea_in.max("obs")).where(ds_2_60w_in.max("obs")),
)

In [None]:
# Transports of particles never on the "in" side of the Labrador Sea line but
# on the "in" side of the 60W line at one or more observations.
transports(
    ds_1_init.where(~ds_1_lab_sea_in.max("obs")).where(ds_1_60w_in.max("obs")),
    ds_2_init.where(~ds_2_lab_sea_in.max("obs")).where(ds_2_60w_in.max("obs")),
)