# Meteor cruises vs. FOCI NEMO test data

_**ATTENTION:** This will download more than 5 GB from the internet, store it on disk, and consume O(50 GB) of memory when running. So make sure you are fine with this kind of resource use before hitting "Run all cells"._

## Description

We use an [Intake driver](https://github.com/ESM-VFC/intake_pangaeapy) for [`pangaeapy`](https://github.com/pangaea-data-publisher/pangaeapy) to load hydrographic observational data from Meteor cruises M85/1, M85/2, M90, M106, and M120, and then
- plot positions on a map
- create a [temperature-salinity diagram](https://en.wikipedia.org/wiki/Temperature%E2%80%93salinity_diagram) of all cruises.

We then load a NEMO test dataset that covers the same calendar period and
- plot surface temperature on a map together with the observation positions
- select model data at the same locations and timestamps as the observations and repeat the temperature-salinity diagrams.
 
Along the way, there are a few obstacles:
- Selecting NEMO data on a curvilinear horizontal grid is not directly implemented in xarray, so we use [`xorca_lonlat2ij`](https://git.geomar.de/python/xorca_lonlat2ij) to find the closest grid indices on the sphere.
- Because the mask information lives in a different file (the mesh mask) than the actual data, we mask somewhat inelegantly, relying on the fact that over land the values are always exactly `0`.
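
The first obstacle boils down to a nearest-neighbour search on the sphere. The sketch below is *not* the `xorca_lonlat2ij` API; it is a brute-force NumPy illustration of the idea using the haversine great-circle distance. The helpers `haversine_km` and `nearest_ij` and the small tilted grid are hypothetical, made up for illustration.

```python
import numpy as np

EARTH_RADIUS_KM = 6371.0  # mean Earth radius


def haversine_km(lat1, lon1, lat2, lon2):
    """Great-circle distance in km between points given in degrees (broadcasts)."""
    lat1, lon1, lat2, lon2 = map(np.radians, (lat1, lon1, lat2, lon2))
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2) ** 2
    return 2 * EARTH_RADIUS_KM * np.arcsin(np.sqrt(a))


def nearest_ij(nav_lat, nav_lon, lat, lon):
    """Return (j, i) of the grid point closest to (lat, lon) on a 2-D grid.

    Brute force: evaluates the distance to every grid point for each query.
    """
    dist = haversine_km(nav_lat, nav_lon, lat, lon)
    j, i = np.unravel_index(np.argmin(dist), dist.shape)
    return int(j), int(i)


# Hypothetical 3x4 "curvilinear" grid: rows are not lines of constant latitude.
jj, ii = np.meshgrid(np.arange(3), np.arange(4), indexing="ij")
nav_lon = -30.0 + 2.0 * ii + 0.5 * jj
nav_lat = 40.0 + 2.0 * jj - 0.5 * ii

print(nearest_ij(nav_lat, nav_lon, lat=42.1, lon=-27.4))  # (1, 1)
```

Because the haversine formula only uses sines and cosines of angle differences, longitudes on either side of the dateline need no special casing. The brute-force cost (one distance per grid point and query) is why a dedicated tool is preferable for the roughly 800 000 bottle positions handled further down.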

_**Note** that we cannot expect much similarity between the in-situ observational data and a free-running climate model._

## Parameters

In [None]:
# parameters

esm_vfc_data_dir = "../esm-vfc-data/"
nemo_catalog_url = "https://raw.githubusercontent.com/ESM-VFC/esm-vfc-catalogs/master/catalogs/NEMO_ORCA05_FOCI_Test_Full.yaml"
meteor_catalog_url = "https://raw.githubusercontent.com/ESM-VFC/esm-vfc-catalogs/master/catalogs/METEOR_cruises.yaml"
host = !hostname
host = host[0]
dask_cluster_args = dict(n_workers=4, threads_per_worker=2, memory_limit=12e9, host=host)
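
The memory warning at the top follows from these parameters: when `Client(**dask_cluster_args)` starts a local cluster, `memory_limit` applies per worker. A quick check, copying the values into a plain dict without the IPython `!hostname` magic (so `host` is omitted here):

```python
# Copy of the cluster parameters above, minus the `host` entry.
dask_cluster_args = dict(n_workers=4, threads_per_worker=2, memory_limit=12e9)

# Threads and memory summed over all workers (memory_limit is per worker).
total_threads = dask_cluster_args["n_workers"] * dask_cluster_args["threads_per_worker"]
total_memory_gb = dask_cluster_args["n_workers"] * dask_cluster_args["memory_limit"] / 1e9

print(total_threads, total_memory_gb)  # 8 48.0
```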

## Tech preamble

In [None]:
import numpy as np
import pandas as pd
import xarray as xr

In [None]:
# set up intake catalog
import intake
from esmvfc_cattools import download_zenodo_files_for_entry
import os

os.environ["ESM_VFC_DATA_DIR"] = esm_vfc_data_dir

In [None]:
# set up plotting
import hvplot.pandas
import hvplot.xarray
import geoviews.feature as gf
from holoviews.operation import decimate
from cartopy import crs

In [None]:
import xorca_lonlat2ij as xll2ij

In [None]:
# set up Dask cluster
from dask.distributed import Client
client = Client(**dask_cluster_args)
client

## Get obs data, extract near-surface measurements, plot positions

In [None]:
meteor_catalog = intake.open_catalog(meteor_catalog_url)
list(meteor_catalog)

In [None]:
# OPTIMIZE: Here we could go for dask dataframes partitioned across the cruises.
obs_df = pd.concat(
    (
        meteor_catalog["M85_1_bottles"].read(),
        meteor_catalog["M85_2_bottles"].read(),
        meteor_catalog["M90_bottles"].read(),
        meteor_catalog["M106_bottles"].read(),
        meteor_catalog["M120_bottles"].read()
    ),
    ignore_index=True
)

In [None]:
# Construct an "Event" column that contains the profile number.
# Some cruises already have merged station and profile events
# (their "Profile" is missing); keep those as they are.
obs_df["Event"] = obs_df["Event"].where(
    obs_df["Profile"].isnull(),
    obs_df["Event"] + "-" + obs_df["Profile"].fillna(-99).astype(int).astype(str)
)
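
To see what the `where` call does, here is the same logic on a tiny hypothetical table: cruise `A` stores station and profile separately, while cruise `B` already has a merged label and no `Profile` value (all values are invented for illustration).

```python
import numpy as np
import pandas as pd

# Hypothetical bottle rows from two cruises.
obs_df = pd.DataFrame({
    "Event": ["A_1", "A_1", "A_2", "B_7-3"],
    "Profile": [1.0, 2.0, 1.0, np.nan],
})

# Keep "Event" where "Profile" is missing, otherwise append "-<profile>".
obs_df["Event"] = obs_df["Event"].where(
    obs_df["Profile"].isnull(),
    obs_df["Event"] + "-" + obs_df["Profile"].fillna(-99).astype(int).astype(str),
)
obs_df = obs_df.drop(columns=["Profile"])

print(obs_df["Event"].tolist())  # ['A_1-1', 'A_1-2', 'A_2-1', 'B_7-3']
```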
obs_df = obs_df.drop(columns=["Profile"])

In [None]:
obs_df

In [None]:
# restrict to the measurement at minimal depth per Event (= station/profile)
near_surface_obs = obs_df.loc[
    obs_df.groupby("Event")["Depth water"].idxmin()
]
near_surface_obs = near_surface_obs.set_index("Event")
near_surface_obs
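
A toy version of the selection above, with invented values. `groupby(...).idxmin()` returns the row label of the shallowest sample per `Event`, and `.loc` then pulls those rows. This relies on unique row labels, which `ignore_index=True` in the `pd.concat` call guarantees.

```python
import pandas as pd

# Hypothetical bottle data: two stations, several depths each.
obs = pd.DataFrame({
    "Event": ["S1", "S1", "S1", "S2", "S2"],
    "Depth water": [50.0, 2.5, 10.0, 4.0, 1000.0],
    "Temp": [10.1, 15.2, 14.0, 18.3, 4.2],
})

# Row label of the shallowest sample per Event, then select those rows.
shallowest_rows = obs.groupby("Event")["Depth water"].idxmin()
shallowest = obs.loc[shallowest_rows].set_index("Event")

print(shallowest)
```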

In [None]:
(
    near_surface_obs.hvplot(
        "Longitude", "Latitude", geo=True, kind="points", hover=False)
    * gf.coastline
)

_**FIXME:** Hover tool shows wrong values ("Latitude: 7945355th"???)._

In [None]:
len(obs_df), len(near_surface_obs)

In [None]:
(
    decimate(obs_df.hvplot.scatter("Temp", "Sal", alpha=0.2, label="all data", hover=False), max_samples=2_000)
    * decimate(
        near_surface_obs.hvplot.scatter("Temp", "Sal", alpha=0.8, label="surface data", hover=False),
        max_samples=int(2_000 * len(near_surface_obs)**0.5 / len(obs_df)**0.5)  # sqrt scaling: thin surface points less than the full set
    )
)
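
The `max_samples` for the surface points scales with the square root of the size ratio rather than with the ratio itself. A quick comparison of both choices, using hypothetical set sizes (the real ones are shown by `len(obs_df), len(near_surface_obs)`); `sample_budgets` is an illustrative helper:

```python
def sample_budgets(n_all, n_surface, max_samples=2_000):
    """Return (sqrt-scaled, proportional) decimate budgets for the surface subset."""
    # Same expression as in the plotting cell above.
    sqrt_scaled = int(max_samples * n_surface**0.5 / n_all**0.5)
    # Alternative: keep exactly the same fraction of points as for the full set.
    proportional = int(max_samples * n_surface / n_all)
    return sqrt_scaled, proportional


print(sample_budgets(640_000, 10_000))  # (250, 31)
```

With these numbers, square-root scaling keeps 250 of 10 000 surface points (2.5 %) while the full set keeps 2 000 of 640 000 (about 0.3 %), so the surface points stay visible; strictly proportional thinning would leave only 31 of them.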

## Load catalog and fetch data

In [None]:
model_data_cat = intake.open_catalog(nemo_catalog_url)
download_zenodo_files_for_entry(
    model_data_cat["NEMO_ORCA05_FOCI_Test_grid_T"]
)
download_zenodo_files_for_entry(
    model_data_cat["NEMO_ORCA05_FOCI_Test_mesh_mask"]
)

## Restrict to North Atlantic, calc mean SST, plot with obs positions

In [None]:
# hydrographic data
model_dataset = model_data_cat["NEMO_ORCA05_FOCI_Test_grid_T"](
    chunks={"time_counter": 1, "deptht": 23}
).to_dask()
model_dataset = model_dataset.set_coords(["nav_lat", "nav_lon"])
model_dataset["nav_lat"] = model_dataset["nav_lat"].isel(time_counter=0).squeeze()
model_dataset["nav_lon"] = model_dataset["nav_lon"].isel(time_counter=0).squeeze()
model_dataset = xr.decode_cf(model_dataset)

# Need the grid definitions
model_meshmask = model_data_cat["NEMO_ORCA05_FOCI_Test_mesh_mask"](
    chunks={"z": 23}
).to_dask()
model_meshmask = model_meshmask.squeeze()
model_meshmask = xr.decode_cf(model_meshmask)

In [None]:
display(model_dataset)
display(f"{model_dataset.nbytes / 1e9} GB")

In [None]:
display(model_meshmask)
display(f"{model_meshmask.nbytes / 1e9} GB")

In [None]:
# need compute / cast to numpy array here in order for datashade to work
# (see https://datashader.org/user_guide/Performance.html)
model_mean_sst = model_dataset.sosstsst.mean("time_counter").compute()
model_mean_sst = model_mean_sst.where(model_mean_sst != 0)
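
The zero-based masking in the cell above can be contrasted with masking via an explicit land/sea mask. A minimal NumPy sketch with made-up values; the name `tmask` mirrors the land/sea mask variable typically found in NEMO mesh-mask files, but here it is just a hypothetical array:

```python
import numpy as np

# Hypothetical 2x3 SST field: land points hold an exact 0 in the model output.
sst = np.array([[12.3, 0.0, 14.1],
                [0.0, 0.0, 15.7]])

# 1) Heuristic used in this notebook: treat exact zeros as land.
sst_heuristic = np.where(sst != 0, sst, np.nan)

# 2) Mask-based alternative (1 = ocean, 0 = land), values invented here.
tmask = np.array([[1, 0, 1],
                  [0, 0, 1]])
sst_masked = np.where(tmask == 1, sst, np.nan)

# Both agree in this example.
print(np.array_equal(sst_heuristic, sst_masked, equal_nan=True))  # True
```

The two approaches would diverge only where a genuine ocean value is exactly `0`, which the heuristic drops as if it were land.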

In [None]:
(
    model_mean_sst.hvplot.quadmesh(
        "nav_lon", "nav_lat",
        geo=True, datashade=True, hover=False)
    * near_surface_obs.hvplot(
        "Longitude", "Latitude",
        geo=True, kind="points", color="red", hover=False)
    * gf.land * gf.coastline
)

## Extract model data along ship track (surface positions)

In [None]:
xll2ij.get_ij?

In [None]:
positions = list(zip(
    near_surface_obs["Latitude"],
    near_surface_obs["Longitude"],
))

depths = near_surface_obs["Depth water"].to_xarray()
depths

times = near_surface_obs["Date/Time"].to_xarray()

lat_ind, lon_ind = xll2ij.get_ij(
    model_meshmask, positions, 't', xgcm=False, xarray_out=True)
lat_ind = lat_ind.rename({"location": "Event"})
lon_ind = lon_ind.rename({"location": "Event"})

In [None]:
model_dataset["deptht"] = model_dataset["deptht"].compute()
model_dataset["time_counter"] = model_dataset["time_counter"].compute()
# integer coordinates, so that grid indices can also be used with .sel()
model_dataset["y"] = np.arange(model_dataset.sizes["y"])
model_dataset["x"] = np.arange(model_dataset.sizes["x"])

In [None]:
# select
ship_track_data = model_dataset.isel(y=lat_ind, x=lon_ind)
ship_track_data = ship_track_data.sel(deptht=depths, method="nearest")
ship_track_data = ship_track_data.sel(time_counter=times, method="nearest")

# mask
ship_track_data = ship_track_data.where(ship_track_data.votemper != 0)

display(ship_track_data)
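
Passing `lat_ind` and `lon_ind` as DataArrays that share the `Event` dimension makes xarray index pointwise (one grid cell per station) instead of combining every selected `y` with every selected `x`. The same distinction in plain NumPy, with a hypothetical 4×5 field and made-up indices:

```python
import numpy as np

# Hypothetical 4x5 field standing in for one depth/time slice of votemper.
field = np.arange(20).reshape(4, 5)

# One (j, i) pair per station, as a nearest-neighbour search would return.
j_idx = np.array([0, 2, 3])
i_idx = np.array([1, 4, 0])

# Pointwise ("vectorized") indexing: one value per station -> shape (3,)
pointwise = field[j_idx, i_idx]

# Outer ("orthogonal") indexing: every j with every i -> shape (3, 3)
outer = field[np.ix_(j_idx, i_idx)]

print(pointwise)    # [ 1 14 15]
print(outer.shape)  # (3, 3)
```

Had the two indexers carried different dimension names (or been plain arrays), `isel` would instead return the outer combination, most of which are grid cells no ship visited.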

In [None]:
%%time

ship_track_data = ship_track_data.compute()

In [None]:
(
    ship_track_data.to_dataframe().hvplot.scatter("votemper", "vosaline", label="surface data, model", hover=False)
    * near_surface_obs.hvplot.scatter("Temp", "Sal", alpha=0.8, label="surface data, obs", hover=False)
)

## Extract model data along ship track (all depths)

In [None]:
positions = list(zip(
    obs_df["Latitude"],
    obs_df["Longitude"],
))

depths = obs_df["Depth water"].to_xarray()
depths

times = obs_df["Date/Time"].to_xarray()

lat_ind, lon_ind = xll2ij.get_ij(
    model_meshmask, positions, 't', xgcm=False, xarray_out=True)
lat_ind = lat_ind.rename({"location": "index"})
lon_ind = lon_ind.rename({"location": "index"})

In [None]:
len(lon_ind)

In [None]:
%%time

# OPTIMIZE: Here, we have ~ 800_000 positions that we select.
#           If we don't fully load the dataset before computing,
#           we're left with an enormous Dask graph.
model_dataset = model_dataset.compute()

# select
ship_track_data = model_dataset.sel(
    y=lat_ind, x=lon_ind, deptht=depths, time_counter=times, method="nearest"
)

# mask
ship_track_data = ship_track_data.where(ship_track_data.votemper != 0)

display(ship_track_data)
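
`sel(..., method="nearest")` snaps each observed depth and time to the closest model coordinate value; for `x` and `y`, which were set to `np.arange` earlier, it simply reproduces the integer indices from `get_ij`. A minimal NumPy sketch of nearest-value lookup on a sorted coordinate, with hypothetical depths. `nearest_index` is an illustrative helper, not what xarray runs internally, and its tie-breaking may differ from pandas'.

```python
import numpy as np


def nearest_index(coord, values):
    """Index of the nearest entry of a sorted 1-D coordinate (length >= 2) per value."""
    coord = np.asarray(coord)
    values = np.asarray(values)
    # Insertion points, clipped so that a left and a right neighbour always exist.
    right = np.clip(np.searchsorted(coord, values), 1, len(coord) - 1)
    left = right - 1
    # Pick whichever neighbour is closer (ties go to the left one here).
    choose_right = np.abs(coord[right] - values) < np.abs(values - coord[left])
    return np.where(choose_right, right, left)


deptht = np.array([0.5, 1.5, 2.6, 3.8, 5.1])  # hypothetical model levels [m]
obs_depth = np.array([0.0, 2.0, 4.9, 30.0])   # hypothetical bottle depths [m]

print(nearest_index(deptht, obs_depth))  # [0 1 4 4]
```

Note that the 30 m sample is matched to the deepest level even though it lies far below it: a nearest lookup never rejects a value, and the same holds for timestamps outside the model period. xarray's `sel` also accepts a `tolerance` argument together with `method="nearest"` if such matches should fail instead of being matched silently.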

In [None]:
%%time

all_depths_plot = (
    decimate(
        ship_track_data.to_dataframe().hvplot.scatter(
            "votemper", "vosaline", alpha=0.4, label="all depths, model", hover=False
        )
    )
    * decimate(
        obs_df.hvplot.scatter(
            "Tpot", "Sal", alpha=0.4, label="all depths, obs", hover=False
        )
    )
)

In [None]:
all_depths_plot