# missing data interpolation

statistics is the answer to everything

### potential shenanigans

"Several techniques have been used to fill the gaps in either the UWLS or OI derived total vector maps.

These are implemented using covariance derived from normal mode analysis (Lipphardt et al. 2000), open-boundary modal analysis (OMA) (Kaplan and Lekien 2007), and empirical orthogonal function (EOF) analysis (Beckers and Rixen 2003; Alvera-Azcárate et al. 2005); and using idealized or smoothed observed covariance (Davis 1985)."

- normal mode analysis
- open-boundary modal analysis (OMA)
- empirical orthogonal function analysis (EOF)
- use idealized/smoothed observed covariance

---

### other ideas

DINEOF (could only find an implementation in R)

to be honest I don't understand any of these methods but they look cool

### currently implemented:

fill areas where data is considered missing in the high resolution data with values taken (linearly interpolated) straight from the lower resolution data

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load IPython's autoreload extension; mode 2 reloads all imported modules
# before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
import math
from pathlib import Path
import time
import matplotlib.pyplot as plt
import xarray as xr
import numpy as np
from parcels import FieldSet
from datetime import timedelta, datetime

import utils
from parcels_utils import xr_dataset_to_fieldset, HFRGrid

### target and interp_references

`target` is the data you are interpolating.

`interp_references` is a list of reference data to interpolate from. A few specifications:
- should be ordered from most accurate data to least accurate (highest to lowest resolution)
- time domain should be identical to or larger than the target's
- lat and lon domain should be bigger than the target's to prevent any out-of-bounds complications

`mask_nc` must have exactly the same lat and lon dimensions as the target

In [None]:
# Root directory that holds the current netCDF files.
files_root = Path("/Volumes/T7/Documents/Programs/scripps-cordc/parcels_westcoast/current_netcdfs")

# Data set whose gaps are going to be filled.
target = HFRGrid(files_root.joinpath("west_coast_1km_hourly/tj_plume_2020-08.nc"))

# Fallback data sets, listed in the order they should be tried
# (2 km first, then 6 km).
interp_references = [
    HFRGrid(files_root.joinpath(relative_path))
    for relative_path in (
        "west_coast_2km_hourly/tj_plume_2020-08.nc",
        "west_coast_6km_hourly/tj_plume_2020-08.nc",
    )
]

# Sample file used to derive the no-data mask.
# NOTE(review): init_fs=False presumably skips building a Parcels FieldSet
# for this grid - confirm against HFRGrid.
mask_nc = HFRGrid(files_root.joinpath("west_coast_1km_hourly/tj_sample.nc"), init_fs=False)

### check validity of interpolation references

In [None]:
def validate_grids(target, references, mask):
    """
    Check that the references and the mask are usable with the target.

    Every grid must provide get_coords() returning (times, lats, lons); the
    first and last element of each coordinate are taken as its bounds.

    Args:
        target: grid that is going to be interpolated
        references: iterable of grids to interpolate from; each one must
            cover the target's lat, lon and time ranges
        mask: grid whose lat/lon axes must have the same lengths and the
            same end points as the target's

    Raises:
        ValueError: if a reference does not cover the target, or if the mask
            differs from the target in lat/lon length or end points
    """
    t_times, t_lats, t_lons = target.get_coords()
    lat_lo, lat_hi = t_lats[0], t_lats[-1]
    lon_lo, lon_hi = t_lons[0], t_lons[-1]

    # every reference has to enclose the target in space and in time
    for ref in references:
        r_times, r_lats, r_lons = ref.get_coords()
        covers_lat = (r_lats[0] <= lat_lo) and (r_lats[-1] >= lat_hi)
        covers_lon = (r_lons[0] <= lon_lo) and (r_lons[-1] >= lon_hi)
        covers_time = (r_times[0] <= t_times[0]) and (r_times[-1] >= t_times[-1])
        if not covers_lat or not covers_lon or not covers_time:
            raise ValueError("Incorrect reference dimensions")

    # the mask needs as many lat/lon points as the target
    _, m_lats, m_lons = mask.get_coords()
    if len(t_lats) != len(m_lats) or len(t_lons) != len(m_lons):
        raise ValueError("Mask is not the same lat/lon shape as target")

    # The mask is not sliced to the target yet, so its end points have to
    # match exactly; relax these checks to >= / <= once slicing is in place.
    same_lat_ends = (m_lats[0] == lat_lo) and (m_lats[-1] == lat_hi)
    same_lon_ends = (m_lons[0] == lon_lo) and (m_lons[-1] == lon_hi)
    if not same_lat_ends or not same_lon_ends:
        raise ValueError("Incorrect mask dimensions")

In [None]:
# Fail early (ValueError) if the references or the mask do not fit the target.
validate_grids(target=target, references=interp_references, mask=mask_nc)

### interpolation type

more information can be found in the `tutorial_interpolation` notebook

EDIT: just use `linear`

## nan values and parcels

note that when this xarray Dataset is passed into parcels, all the nan values change to 0 and the mask generation won't work anymore

so the Dataset is copied for use with the FieldSet instead

In [None]:
# Repeat the mask derived from the sample file once per target time step so
# it can later be used to index the full u/v arrays.
# NOTE(review): presumably the result is shaped (time, lat, lon) - confirm
# against what utils.generate_mask_none returns.
n_time_steps = target.xrds["time"].size
no_data = np.tile(
    utils.generate_mask_none(mask_nc.xrds["u"].values),
    (n_time_steps, 1, 1),
)

# Positions in the target that currently hold no usable value.
invalid = utils.generate_mask_invalid(target.xrds["u"].values)
invalid_pos = np.nonzero(invalid)
num_invalid = invalid.sum()
print(f"total invalid values on target data: {num_invalid}")

### use of Parcels Field for interpolation

indexing Field values goes [time, depth, lat, lon]

Field does linear interpolation automatically when indexing values between its coordinate values

In [None]:
def get_interped(i, ref, invalid_where):
    """
    Sample a reference grid at one invalid position of the target.

    The lat/lon indices are turned into coordinates through the module-level
    `target` grid.

    Args:
        i (int): which entry of invalid_where to look up
        ref (HFRGrid): reference grid to sample from
        invalid_where (array-like): (3, n) dimensional array holding the
            (time, lat, lon) indices of all invalid positions

    Returns:
        (u, v): (nan, nan) if the reference has no data there, the values
            returned by ref.get_fs_current otherwise
    """
    # NOTE(review): the time passed to the fieldset is the target's time
    # index multiplied by the reference's first time step; this assumes both
    # time axes start together and are evenly spaced - confirm.
    step = np.diff(ref.fieldset_flat.U.grid.time)[0]
    time_idx = invalid_where[0][i]
    lat = target.lats[invalid_where[1][i]]
    lon = target.lons[invalid_where[2][i]]

    u, v = ref.get_fs_current(time_idx * step, lat, lon)
    magnitude = abs(u) + abs(v)
    closest_u = ref.get_closest_current(time_idx, lat, lon)[0]

    # a nan closest value, or u and v both being 0, most likely means that
    # the reference has no data at this position
    has_data = not np.isnan(closest_u) and magnitude != 0
    if has_data:
        return u, v
    return np.nan, np.nan

### linear interpolation using lower resolution data

In [None]:
# Work on copies so the original target arrays stay untouched.
target_interped_u = target.xrds["u"].values.copy()
target_interped_v = target.xrds["v"].values.copy()
invalid_interped = invalid.copy()

# Try each reference in turn; whatever is still invalid after one reference
# is retried with the next one.
for f in interp_references:
    invalid_pos_new = np.nonzero(invalid_interped)
    num_invalid_new = invalid_interped.sum()
    print(f"Attempting to interpolate {num_invalid_new} points...")

    arr_u = np.zeros(num_invalid_new)
    arr_v = np.zeros(num_invalid_new)
    for idx in range(num_invalid_new):
        arr_u[idx], arr_v[idx] = get_interped(idx, f, invalid_pos_new)

    # write the looked-up values (nan where nothing was found) back in place
    target_interped_u[invalid_pos_new] = arr_u
    target_interped_v[invalid_pos_new] = arr_v

    # the remaining gaps are recomputed from the u component only
    invalid_interped = utils.generate_mask_invalid(target_interped_u)
    print(f"total invalid values after interpolation with {f}: {invalid_interped.sum()}")
    print(f"    values filled: {num_invalid_new - invalid_interped.sum()}")
print(f"total invalid values on interpolated: {invalid_interped.sum()}")

## even more filling with PLS and smoothing with DCT shenanigans

uses the matlab engine and smoothn.m

https://www.mathworks.com/help/matlab/matlab-engine-for-python.html

https://www.mathworks.com/matlabcentral/fileexchange/25634-smoothn

In [None]:
import matlab.engine

# Smooth copies so the purely interpolated arrays stay available.
target_smoothed_u = target_interped_u.copy()
target_smoothed_v = target_interped_v.copy()
# matlab.double is built from nested Python lists, one 2-D field per time step
u_list = target_smoothed_u.tolist()
v_list = target_smoothed_v.tolist()

eng = matlab.engine.start_matlab()
try:
    print(f"Filling {len(u_list)} fields...")
    for i, (u_field, v_field) in enumerate(zip(u_list, v_list)):
        u_mat = matlab.double(u_field)
        v_mat = matlab.double(v_field)
        # u and v are passed together as one list so smoothn processes both
        # components in a single call
        uv_smooth = eng.smoothn([u_mat, v_mat], "robust")
        # copy the MATLAB results into numpy arrays of the same size
        u_array = np.empty(uv_smooth[0].size)
        v_array = np.empty(uv_smooth[1].size)
        u_array[:] = uv_smooth[0]
        v_array[:] = uv_smooth[1]
        target_smoothed_u[i] = u_array
        target_smoothed_v[i] = v_array
finally:
    # always shut the MATLAB process down, even if a field fails to smooth
    eng.quit()

# mask the filled data
target_smoothed_u[no_data] = np.nan
target_smoothed_v[no_data] = np.nan

### formatting and saving

In [None]:
# re-add coordinates, dimensions, and metadata to interpolated data
darr_u = utils.conv_to_dataarray(target_smoothed_u, target.xrds["u"])
darr_v = utils.conv_to_dataarray(target_smoothed_v, target.xrds["v"])
target_interped_xrds = target.xrds.drop_vars(["u", "v"]).assign(u=darr_u, v=darr_v)

In [None]:
# Save next to the source file as "<original stem>_interped.nc". Only the
# final suffix is replaced, so a ".nc" occurring earlier in the path (in a
# directory name or inside the file name) cannot truncate the output path.
source_path = Path(target.path)
save_path = str(source_path.with_name(f"{source_path.stem}_interped.nc"))
target_interped_xrds.to_netcdf(save_path)
print(f"saved to {save_path}")

### display field to see if interpolation worked

In [None]:
fs_interp = xr_dataset_to_fieldset(target_interped_xrds)
# Plot the U field twice for a visual comparison:
# first uninterpolated, then interpolated, gapfilled, smoothed.
for u_field in (target.fieldset.U, fs_interp.U):
    u_field.show()