# OSNAP Parcels Experiments


## Technical preamble

In [None]:
# import matplotlib.colors as colors
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import xarray as xr
from datetime import datetime, timedelta
# import seaborn as sns
# from matplotlib.colors import ListedColormap
import cmocean as co
import pandas as pd
import matplotlib.dates as mdates
import cartopy.crs as ccrs
import cartopy



In [None]:
## Parameters
# Project root: two directory levels above the notebook's working directory.
project_path = (Path.cwd() / '..' / '..').resolve()

# Parcels track data file (path relative to project_path)
path_data_tracks = Path('data/processed/tracks/osnap/') 
filename_tracks = 'tracks_osnap_forwards_201701_N19886_D365_Rnd123.nc'

# Model mesh/mask file
data_path = Path("data/external/iAtlantic/")
experiment_name = "VIKING20X.L46-KKG36107B"
mesh_mask_file = project_path.joinpath(data_path, "mask", experiment_name, "1_mesh_mask.nc")

# Section waypoint (lon/lat) files (path relative to project_path)
sectionPath = Path('data/external/sections/')
sectionFilename = 'osnap_pos_wp.txt'
sectionname = 'osnap'
gsrsectionFilename = 'gsr_pos_wp.txt'

# Kilometres per degree of latitude: 60 nautical miles of 1.852 km each.
degree2km = 60.0 * 1.852

# Colour map for track dates: tempo_r is suggested for forward runs,
# amp (co.cm.amp) for backward runs.
cmap = co.cm.tempo_r

## Plotting functions

In [None]:
# range of stations from west to east, stations 0-12. Python indexing.

def plotTracksBetweenStations(ds, stationRange, xlims, ylims, cmap=co.cm.tempo_r):
    '''Plot the tracks of particles released between two OSNAP stations.

    A trajectory is kept when its release longitude (first observation) lies
    strictly between ``lonlat.lon[stationRange[0]]`` and
    ``lonlat.lon[stationRange[1]]``. The kept track points are drawn coloured
    by date. The figure also shows bathymetry contours from the notebook-level
    ``depth`` field, the release positions of all particles and of the kept
    ones, and the notebook-level ``gsrlonlat`` section.

    Parameters
    ----------
    ds : track dataset with ``lon``, ``lat`` and ``time`` variables and an
        ``obs`` dimension.
    stationRange : pair of indices into ``lonlat.lon`` (west station, east station).
    xlims, ylims : (min, max) limits applied to the x (lon) and y (lat) axes.
    cmap : colormap used for the track dates.
    '''
    release = ds.isel(obs=0)
    west_lon = lonlat.lon[stationRange[0]]
    east_lon = lonlat.lon[stationRange[1]]
    # Mask every trajectory whose release longitude is outside (west_lon, east_lon).
    selected = ds.where(release.lon > west_lon).where(release.lon < east_lon)

    plt.figure(figsize=(16, 8))

    # All track points of the selected particles, coloured by date.
    plt.scatter(
        selected.lon.data.flatten(),
        selected.lat.data.flatten(),
        3,
        mdates.date2num(selected.time.data.flatten()),
        cmap=cmap,
        zorder=2,
    )
    colorbar = plt.colorbar(label="date")
    months = mdates.MonthLocator()
    colorbar.ax.yaxis.set_major_locator(months)
    colorbar.ax.yaxis.set_major_formatter(mdates.ConciseDateFormatter(months))

    # Bathymetry on a fixed sub-window of the model grid; the 1 m contour is
    # drawn in black (presumably to outline the coast).
    sub_depth = depth.isel(y=slice(1400, 2499), x=slice(0, 2404))
    sub_depth.plot.contour(
        x="nav_lon", y="nav_lat", colors='grey',
        levels=[200, 800, 1500, 2000, 2500, 3500], zorder=1,
    )
    sub_depth.plot.contour(
        x="nav_lon", y="nav_lat", colors='k', levels=[1], zorder=3,
    )

    # Release positions: all particles first, then the selected ones on top.
    plt.scatter(release.lon.data.flatten(), release.lat.data.flatten(), 2, zorder=4)
    selected_release = selected.isel(obs=0)
    plt.scatter(
        selected_release.lon.data.flatten(),
        selected_release.lat.data.flatten(),
        2,
        zorder=5,
    )

    # Section line, drawn without its last waypoint.
    plt.plot(gsrlonlat.lon.data[:-1], gsrlonlat.lat.data[:-1], zorder=5)

    plt.ylim(ylims[0], ylims[1])
    plt.xlim(xlims[0], xlims[1])


In [None]:
# range of stations from west to east, stations 0-12. Python indexing.

def conditionalplotTracksBetweenStations(ds, condition, stationRange, xlims, ylims, cmap=co.cm.tempo_r):
    '''Plot tracks released between two OSNAP stations that also satisfy ``condition``.

    Trajectories are first restricted to those whose release longitude (first
    observation) lies strictly between ``lonlat.lon[stationRange[0]]`` and
    ``lonlat.lon[stationRange[1]]``. ``condition`` is then applied with
    ``where``, and the remaining points are drawn coloured by date. The figure
    also shows bathymetry contours from the notebook-level ``depth`` field, the
    ``lonlat`` and ``gsrlonlat`` section lines, and the release positions of
    the station-selected particles.

    Parameters
    ----------
    ds : track dataset with ``lon``, ``lat`` and ``time`` variables and an
        ``obs`` dimension.
    condition : boolean mask accepted by ``Dataset.where`` for ``ds``, for
        example a per-trajectory mask.
    stationRange : pair of indices into ``lonlat.lon`` (west station, east station).
    xlims, ylims : (min, max) limits applied to the x (lon) and y (lat) axes.
    cmap : colormap used for the track dates.
    '''
    release = ds.isel(obs=0)
    west_lon = lonlat.lon[stationRange[0]]
    east_lon = lonlat.lon[stationRange[1]]
    in_range = ds.where(release.lon > west_lon).where(release.lon < east_lon)
    shown = in_range.where(condition)

    plt.figure(figsize=(16, 8))

    # Track points that pass both selections, coloured by date.
    plt.scatter(
        shown.lon.data.flatten(),
        shown.lat.data.flatten(),
        3,
        mdates.date2num(shown.time.data.flatten()),
        cmap=cmap,
        zorder=2,
    )
    colorbar = plt.colorbar(label="date")
    months = mdates.MonthLocator()
    colorbar.ax.yaxis.set_major_locator(months)
    colorbar.ax.yaxis.set_major_formatter(mdates.ConciseDateFormatter(months))

    # Bathymetry on a fixed sub-window of the model grid; black 1 m contour on top.
    sub_depth = depth.isel(y=slice(1400, 2499), x=slice(0, 2404))
    sub_depth.plot.contour(
        x="nav_lon", y="nav_lat", colors='grey',
        levels=[200, 800, 1500, 2000, 2500, 3500], zorder=1,
    )
    sub_depth.plot.contour(
        x="nav_lon", y="nav_lat", colors='k', levels=[1], zorder=3,
    )

    # OSNAP section line (without its last waypoint), then release positions
    # of the station-selected particles.
    plt.plot(lonlat.lon.data[:-1], lonlat.lat.data[:-1], zorder=4)
    in_range_release = in_range.isel(obs=0)
    plt.scatter(
        in_range_release.lon.data.flatten(),
        in_range_release.lat.data.flatten(),
        2,
        zorder=5,
    )

    # GSR section line, drawn without its last waypoint.
    plt.plot(gsrlonlat.lon.data[:-1], gsrlonlat.lat.data[:-1], zorder=5)

    plt.ylim(ylims[0], ylims[1])
    plt.xlim(xlims[0], xlims[1])


In [None]:
# range of stations from west to east, stations 0-12. Python indexing.

def conditionalplotTracksBetweenStationsCartopy(ds, condition, stationRange, xlims, ylims, cmap=co.cm.tempo_r):
    '''Cartopy version of ``conditionalplotTracksBetweenStations``.

    The particle selection is the same: release longitude (first observation)
    strictly between ``lonlat.lon[stationRange[0]]`` and
    ``lonlat.lon[stationRange[1]]``, then ``condition`` applied with ``where``.
    The result is drawn on an orthographic map centred at 30°W, 55°N with a
    fixed extent of 60°W–0°, 40°N–70°N. The map also shows grey bathymetry
    contours from the notebook-level ``depth`` field, the release positions of
    all particles and of the station-selected ones, and the ``gsrlonlat``
    section line.

    Parameters
    ----------
    ds : track dataset with ``lon``, ``lat`` and ``time`` variables and an
        ``obs`` dimension.
    condition : boolean mask accepted by ``Dataset.where`` for ``ds``.
    stationRange : pair of indices into ``lonlat.lon`` (west station, east station).
    xlims, ylims : accepted for signature compatibility with the non-Cartopy
        variant but currently NOT applied; the map extent is fixed.
    cmap : colormap used for the track dates.
    '''
    release = ds.isel(obs=0)
    west_lon = lonlat.lon[stationRange[0]]
    east_lon = lonlat.lon[stationRange[1]]
    in_range = ds.where(release.lon > west_lon).where(release.lon < east_lon)
    shown = in_range.where(condition)

    # Map axes: orthographic projection centred at (-30, 55).
    projection = ccrs.Orthographic(-30, 55)
    fig, ax = plt.subplots(figsize=(16, 8), subplot_kw={'projection': projection})
    ax.set_extent([-60, 0, 40, 70])
    ax.gridlines()
    ax.coastlines(resolution='50m')

    # Data are in lon/lat, so every artist is given a PlateCarree transform.
    tracks = ax.scatter(
        shown.lon.data.flatten(),
        shown.lat.data.flatten(),
        3,
        mdates.date2num(shown.time.data.flatten()),
        cmap=cmap,
        zorder=2,
        transform=ccrs.PlateCarree(),
    )
    colorbar = fig.colorbar(tracks, ax=ax, label="date")
    months = mdates.MonthLocator()
    colorbar.ax.yaxis.set_major_locator(months)
    colorbar.ax.yaxis.set_major_formatter(mdates.ConciseDateFormatter(months))

    # Bathymetry contours on a fixed sub-window of the model grid.
    sub_depth = depth.isel(y=slice(1400, 2499), x=slice(0, 2404))
    sub_depth.plot.contour(
        ax=ax, transform=ccrs.PlateCarree(),
        x="nav_lon", y="nav_lat", colors='grey',
        levels=[200, 800, 1500, 2000, 2500, 3500], zorder=1,
    )

    # Release positions: all particles first, then the station-selected ones.
    ax.scatter(
        release.lon.data.flatten(),
        release.lat.data.flatten(),
        2,
        zorder=4,
        transform=ccrs.PlateCarree(),
    )
    in_range_release = in_range.isel(obs=0)
    ax.scatter(
        in_range_release.lon.data.flatten(),
        in_range_release.lat.data.flatten(),
        2,
        zorder=5,
        transform=ccrs.PlateCarree(),
    )

    # GSR section line, drawn without its last waypoint.
    ax.plot(
        gsrlonlat.lon.data[:-1],
        gsrlonlat.lat.data[:-1],
        zorder=5,
        transform=ccrs.PlateCarree(),
    )


### Position test function

In [None]:
def apply_left_of_line(ds, lon_1, lon_2, lat_1, lat_2):
    '''Classify positions by which side of a directed line they lie on.

    The test uses the sign of the 2-D cross product
    ``(P2 - P1) x (P - P1)``, where ``P = (ds.lon, ds.lat)``,
    ``P1 = (lon_1, lat_1)`` and ``P2 = (lon_2, lat_2)``. A positive value
    means P is to the left of the directed line from P1 to P2; a negative
    value means it is to the right. The test is planar in lon/lat
    coordinates (no great-circle geometry).

    Parameters
    ----------
    ds : object with ``lon`` and ``lat`` attributes (e.g. the track Dataset);
        any arrays supporting arithmetic and comparisons work.
    lon_1, lat_1 : start point P1 of the directed line.
    lon_2, lat_2 : end point P2 of the directed line.

    Returns
    -------
    (left, right) : pair of boolean arrays
        ``left`` is True where the point lies left of or exactly on the line.
        ``right`` is True where it lies strictly to the right.
        Where a position is missing (NaN), both comparisons are False.
        This is why both masks are returned, rather than deriving one as the
        negation of the other.
    '''
    cross = (lon_2 - lon_1) * (ds.lat - lat_1) - (ds.lon - lon_1) * (lat_2 - lat_1)
    return cross >= 0.0, cross < 0

## Load data

In [None]:
# Open the trajectory output file configured in the parameters cell and show its structure.
ds = xr.open_dataset(project_path / path_data_tracks / filename_tracks)

display(ds)

In [None]:
# Time of the last observation of every trajectory: a quick check of where the run ends.
ds.isel(obs=-1).time


In [None]:
# Load the model mesh/mask file, drop singleton dimensions, and promote the
# navigation arrays to coordinates so plots can use them as x/y.
mesh_mask = (
    xr.open_dataset(mesh_mask_file)
    .squeeze()
    .set_coords(["nav_lon", "nav_lat", "nav_lev"])
)

# mbathy relabelled for plotting as the number of water-filled points.
bathy = mesh_mask.mbathy.rename("number of water filled points")

# Water-column depth: e3t_0 masked by tmask and summed over z
# (presumably cell thickness x wet mask, i.e. total depth in m — confirm).
depth = (mesh_mask.e3t_0 * mesh_mask.tmask).sum("z")

### OSNAP section position data

In [None]:
# Waypoints of the OSNAP section from a whitespace-separated text file with
# 'lon' and 'lat' columns (as used below).
# sep=r'\s+' is the documented equivalent of delim_whitespace=True,
# which is deprecated since pandas 2.2.
lonlat = xr.Dataset(pd.read_csv(project_path / sectionPath / sectionFilename, sep=r'\s+'))
# Flip the longitude sign so values are degrees east (the attrs cell sets
# units='degrees_east'). Presumably the file stores positive degrees west — confirm.
lonlat['lon'] *= -1.0

lonlat

In [None]:
# CF-style metadata so xarray plots get sensible axis labels.
lonlat.lon.attrs.update(long_name='Longitude', standard_name='longitude', units='degrees_east')
lonlat.lat.attrs.update(long_name='Latitude', standard_name='latitude', units='degrees_north')

# Midpoints of consecutive waypoints. The first entry has no complete
# two-point window, which is why only [1:] is used below.
lonlat2mean = lonlat.rolling({'dim_0': 2}).mean()

# Segment lengths in km using a flat-earth approximation. Latitude differences
# are scaled by degree2km; longitude differences additionally by the cosine of
# the segment's mean latitude. Each diff() entry lines up with lonlat2mean[1:].
lonlatdiff = (
    abs(lonlat.diff('dim_0'))
    .assign(
        lat=lambda seg: seg['lat'] * degree2km,
        lon=lambda seg: seg['lon'] * degree2km * np.cos(np.radians(lonlat2mean.lat.data[1:])),
    )
    .assign(length=lambda seg: np.sqrt(seg['lon']**2 + seg['lat']**2))
)

# Total along-section length in km.
total_length = lonlatdiff.length.sum().data

In [None]:
# Quick map check: OSNAP waypoints and the midpoints between consecutive
# waypoints, on an orthographic projection centred at 30°W, 55°N.
central_lon, central_lat = -30, 55
fig, ax = plt.subplots(subplot_kw={'projection': ccrs.Orthographic(central_lon, central_lat)})
# (lon_min, lon_max, lat_min, lat_max)
extent = [-60, 0, 40, 70]
ax.set_extent(extent)
ax.gridlines()
ax.coastlines(resolution='50m')

# Data are lon/lat, hence the PlateCarree transform.
lonlat.plot.scatter(ax=ax,transform=ccrs.PlateCarree(),x='lon',y='lat')
lonlat2mean.plot.scatter(ax=ax,transform=ccrs.PlateCarree(),x='lon',y='lat')

### GSR section position data

In [None]:
# Waypoints of the GSR section from a whitespace-separated text file with
# 'lon' and 'lat' columns (as used below).
# sep=r'\s+' is the documented equivalent of delim_whitespace=True,
# which is deprecated since pandas 2.2.
# NOTE(review): unlike the OSNAP waypoints, these longitudes are NOT
# sign-flipped. Presumably this file already stores degrees east — confirm.
gsrlonlat = xr.Dataset(pd.read_csv(project_path / sectionPath / gsrsectionFilename, sep=r'\s+'))

gsrlonlat.lat


In [None]:
# CF-style metadata so xarray plots get sensible axis labels.
gsrlonlat.lon.attrs.update(long_name='Longitude', standard_name='longitude', units='degrees_east')
gsrlonlat.lat.attrs.update(long_name='Latitude', standard_name='latitude', units='degrees_north')

# Midpoints of consecutive waypoints. The first entry has no complete
# two-point window, which is why only [1:] is used below.
gsrlonlat2mean = gsrlonlat.rolling({'dim_0': 2}).mean()

# Segment lengths in km using a flat-earth approximation. Latitude differences
# are scaled by degree2km; longitude differences additionally by the cosine of
# the segment's mean latitude.
gsrlonlatdiff = (
    abs(gsrlonlat.diff('dim_0'))
    .assign(
        lat=lambda seg: seg['lat'] * degree2km,
        lon=lambda seg: seg['lon'] * degree2km * np.cos(np.radians(gsrlonlat2mean.lat.data[1:])),
    )
    .assign(length=lambda seg: np.sqrt(seg['lon']**2 + seg['lat']**2))
)

# Total along-section length in km.
gsrtotal_length = gsrlonlatdiff.length.sum().data

In [None]:
# Quick map check: GSR waypoints and the midpoints between consecutive
# waypoints, on an orthographic projection centred at 30°W, 55°N.
central_lon, central_lat = -30, 55
fig, ax = plt.subplots(subplot_kw={'projection': ccrs.Orthographic(central_lon, central_lat)})
# (lon_min, lon_max, lat_min, lat_max); reaches further south (30°N) than the OSNAP map.
extent = [-60, 0, 30, 70]
ax.set_extent(extent)
ax.gridlines()
ax.coastlines(resolution='50m')

# Data are lon/lat, hence the PlateCarree transform.
gsrlonlat.plot.scatter(ax=ax,transform=ccrs.PlateCarree(),x='lon',y='lat')
gsrlonlat2mean.plot.scatter(ax=ax,transform=ccrs.PlateCarree(),x='lon',y='lat')

## Have a quick look

### Release positions of the particles

In [None]:
# Release positions in a longitude-depth view: the first observation of every
# particle, with the vertical axis inverted (4000 at the bottom, 0 at the top).
# Assumes z is depth in m, positive downward — TODO confirm.
plt.figure(figsize = (9,4))
plt.scatter(ds.isel(obs=0).lon.data.flatten(),ds.isel(obs=0).z.data.flatten(),s = 2)
plt.ylim(4000,0)

### Test particle positions

#### South/north of OSNAP

In [None]:
# do north and south separately because of missing values
# (apply_left_of_line returns two masks that are both False at NaN positions,
#  so "north" cannot be obtained by negating "south").

# One pair of masks per section segment i, i.e. the directed line from
# waypoint i+1 to waypoint i. The first mask ("left of or on") is stored as
# south, the second ("strictly right") as north.
south = xr.Dataset()
north = xr.Dataset()
for i in range(len(lonlat.lon)-1):
    south['subsect'+str(i)],north['subsect'+str(i)] = apply_left_of_line(ds,lonlat.lon[i+1],lonlat.lon[i],lonlat.lat[i+1],lonlat.lat[i])

# Combine the per-segment half-planes into the region south of the whole polyline.
# On boolean arrays "+" acts as logical OR and "*" as logical AND.
# The grouping presumably encodes the shape of the section (OR where it bends
# one way, AND where it bends the other) — verify against a map.
# NOTE(review): subsect12 requires at least 14 waypoints in lonlat.
# NOTE(review): subsect8 appears in both south_c and south_d — confirm this is intended.
south_a = south.subsect0 + south.subsect1 + south.subsect2 
south_b = south.subsect3 * south.subsect4 * south.subsect5
south_c = south.subsect6 + south.subsect7 + south.subsect8 
south_d = south.subsect8 * south.subsect9 * south.subsect10 * south.subsect11 
south_e = south.subsect12
south_all = south_a * south_c * south_e * (south_b + south_d)

# north_all mirrors south_all with every OR/AND swapped (the De Morgan
# complement), built from the "strictly right" masks.
north_a = north.subsect0 * north.subsect1 * north.subsect2 
north_b = north.subsect3 + north.subsect4 + north.subsect5
north_c = north.subsect6 * north.subsect7 * north.subsect8 
north_d = north.subsect8 + north.subsect9 + north.subsect10 + north.subsect11
north_e = north.subsect12
north_all = north_a + north_c + north_e + (north_b * north_d)


#### South/north of GSR

In [None]:
# do north and south separately because of missing values
# (apply_left_of_line returns two masks that are both False at NaN positions,
#  so "north" cannot be obtained by negating "south").

# One pair of masks per GSR segment i, i.e. the directed line from waypoint
# i+1 to waypoint i. The first mask ("left of or on") is stored as south,
# the second ("strictly right") as north.
gsrsouth = xr.Dataset()
gsrnorth = xr.Dataset()
for i in range(len(gsrlonlat.lon)-1):
    gsrsouth['subsect'+str(i)],gsrnorth['subsect'+str(i)] = apply_left_of_line(ds,gsrlonlat.lon[i+1],gsrlonlat.lon[i],gsrlonlat.lat[i+1],gsrlonlat.lat[i])

# Region south of the GSR polyline: "+" is logical OR, "*" is logical AND on
# boolean arrays. Only subsect0-subsect7 are combined, which implies at least
# 9 waypoints; any further segments are ignored.
gsrsouth_a = gsrsouth.subsect0 + gsrsouth.subsect1 + gsrsouth.subsect2 
gsrsouth_b = gsrsouth.subsect3 + gsrsouth.subsect4
gsrsouth_c = gsrsouth.subsect5 + gsrsouth.subsect6 + gsrsouth.subsect7
gsrsouth_all = gsrsouth_a * gsrsouth_b * gsrsouth_c

# gsrnorth_all is the De Morgan complement of gsrsouth_all, built from the
# "strictly right" masks.
gsrnorth_a = gsrnorth.subsect0 * gsrnorth.subsect1 * gsrnorth.subsect2 
gsrnorth_b = gsrnorth.subsect3 * gsrnorth.subsect4
gsrnorth_c = gsrnorth.subsect5 * gsrnorth.subsect6 * gsrnorth.subsect7
gsrnorth_all = gsrnorth_a + gsrnorth_b + gsrnorth_c

display(gsrsouth_all)


In [None]:
# # plot tracks which go through an area
# conditionalplotTracksBetweenStations(ds,gsrnorth_all.max("obs"),stationRange=[0,12],xlims=[-75,15],ylims=[36,70],cmap=cmap)

# Cartopy plot tracks which go through an area.
# gsrnorth_all.max("obs") reduces the per-observation mask to one value per
# trajectory: True if the particle was north of the GSR line at any observation.
# NOTE: the Cartopy variant uses a fixed map extent; xlims/ylims are not applied there.
conditionalplotTracksBetweenStationsCartopy(ds,gsrnorth_all.max("obs"),stationRange=[11,12],xlims=[-75,15],ylims=[36,70],cmap=cmap)

# plot points in an area
# conditionalplotTracksBetweenStations(ds,north_all,stationRange=[0,12],xlims=[-75,15],ylims=[36,70],cmap=cmap)