# missing data interpolation

statistics is the answer to everything

### potential shenanigans

"Several techniques have been used to fill the gaps in either the UWLS or OI derived total vector maps.

These are implemented using covariance derived from normal mode analysis (Lipphardt et al. 2000), open-boundary modal analysis (OMA) (Kaplan and Lekien 2007), and empirical orthogonal function (EOF) analysis (Beckers and Rixen 2003; Alvera-Azcárate et al. 2005); and using idealized or smoothed observed covariance (Davis 1985)."

- normal mode analysis
- open-boundary modal analysis (OMA)
- empirical orthogonal function analysis (EOF)
- use idealized/smoothed observed covariance
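
a rough sketch of the EOF idea from the list above, for intuition only: give the gaps a first guess, decompose with an SVD, rebuild the field from the leading modes, overwrite only the gaps, repeat. this is not any of the cited implementations (real ones such as DINEOF also choose the number of modes by cross-validation, which is skipped here). `eof_fill`, `n_modes`, `n_iter` and the toy data are all made up for the example.

```python
import numpy as np


def eof_fill(data, n_modes=1, n_iter=100):
    """Fill nan entries of a (time, space) matrix with a truncated-SVD (EOF) reconstruction."""
    missing = np.isnan(data)
    filled = np.where(missing, 0.0, data)  # first guess for the gaps: 0
    for _ in range(n_iter):
        # EOF decomposition of the current guess
        u, s, vt = np.linalg.svd(filled, full_matrices=False)
        # rebuild the field from the leading modes only
        recon = (u[:, :n_modes] * s[:n_modes]) @ vt[:n_modes]
        # only overwrite the gaps, observed values stay untouched
        filled[missing] = recon[missing]
    return filled


# toy data: one spatial pattern modulated in time (rank 1), with a few gaps
t = np.linspace(0, 2 * np.pi, 20)
truth = np.outer(np.sin(t) + 2, np.linspace(1, 2, 10))
gappy = truth.copy()
gappy[[3, 10, 15], [4, 2, 7]] = np.nan

filled = eof_fill(gappy, n_modes=1)
print(np.abs(filled - truth).max())
```

the toy field is exactly rank 1, so one mode is enough to recover it; real current data would need more modes and some way of deciding how many.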

---

### other ideas

DINEOF (could only find an implementation in R)

to be honest I don't understand any of these methods but they look cool

### currently implemented:

wherever the high resolution data is considered missing, take the values straight from the lower resolution data (interpolated to the missing positions)
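
a tiny 1D illustration of that idea, assuming made-up coordinates and a made-up linear "current" so the result is easy to check. `np.interp` stands in for the parcels interpolation used further down; none of these names come from the real data.

```python
import numpy as np

# hypothetical 1D stand-ins for the real grids
fine_x = np.linspace(0.0, 10.0, 11)      # "high resolution" coordinates
fine_u = 2.0 * fine_x
fine_u[[3, 4, 8]] = np.nan               # gaps in the high resolution data

coarse_x = np.linspace(-2.0, 12.0, 8)    # "low resolution", covers a bigger domain
coarse_u = 2.0 * coarse_x

gaps = np.isnan(fine_u)
filled_u = fine_u.copy()
# sample the coarse data (linearly interpolated) only at the gap positions
filled_u[gaps] = np.interp(fine_x[gaps], coarse_x, coarse_u)
print(filled_u)
```

only the gap positions are touched, so the valid high resolution values are never replaced by coarser ones.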

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import math
import time
import matplotlib.pyplot as plt
import xarray as xr
import numpy as np
from parcels import FieldSet
from datetime import timedelta, datetime

import utils
from parcels_utils import get_file_info, xr_dataset_to_fieldset

### target and interp_references

`target` is the data you are interpolating.

`interp_references` is a list of reference datasets to interpolate from. A few requirements:
- order them from most accurate to least accurate (highest to lowest resolution)
- the time range must be identical to or larger than the target's
- the lat and lon ranges should be larger than the target's to prevent any out-of-bounds complications

`mask_nc` must have exactly the same lat and lon dimensions as the target

In [None]:
target = get_file_info(utils.CURRENT_NETCDF_DIR / "west_coast_1km_hourly/tj_plume.nc", utils.DATA_1KM, name="target")

interp_references = [
    get_file_info(utils.CURRENT_NETCDF_DIR / "west_coast_2km_hourly/tj_plume.nc", utils.DATA_2KM, name="ref2km"),
    get_file_info(utils.CURRENT_NETCDF_DIR / "west_coast_6km_hourly/tj_plume.nc", utils.DATA_6KM, name="ref6km")
]

mask_nc = get_file_info(utils.CURRENT_NETCDF_DIR / "west_coast_1km_hourly/tj_sample.nc", utils.DATA_1KM, name="sample")

### check validity of interpolation references

In [None]:
targ_min = (target["lat"][0], target["lon"][0])
targ_max = (target["lat"][-1], target["lon"][-1])
for ir in interp_references:
    print(f"checking {ir['name']}")
    assert ir["lat"][0] <= targ_min[0]
    assert ir["lon"][0] <= targ_min[1]
    assert ir["lat"][-1] >= targ_max[0]
    assert ir["lon"][-1] >= targ_max[1]
    assert ir["timerng"][0] <= target["timerng"][0]
    assert ir["timerng"][1] >= target["timerng"][1]
assert target["res"] == mask_nc["res"]
assert mask_nc["lat"][0] == targ_min[0]
assert mask_nc["lon"][0] == targ_min[1]
assert mask_nc["lat"][-1] == targ_max[0]
assert mask_nc["lon"][-1] == targ_max[1]

### interpolation type

more information can be found in the `tutorial_interpolation` notebook

EDIT: just use `linear`

In [None]:
reference_interp_method = "linear"

for r in interp_references:
    r["fs_flat"].U.interp_method = reference_interp_method
    r["fs_flat"].V.interp_method = reference_interp_method

## nan values and parcels

note that when this xarray Dataset is passed into parcels, all the nan values change to 0 and the mask generation won't work anymore

so the Dataset is copied for use with the FieldSet instead
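
a small numpy illustration of why the mask has to be built before the nans disappear. `np.nan_to_num` is only a stand-in here for whatever parcels does internally, and the array is made up.

```python
import numpy as np

u = np.array([[0.1, np.nan], [0.0, -0.3]])

# build the mask while the nans are still there
missing = np.isnan(u)

# stand-in for what happens once nans are replaced by 0
u_zeroed = np.nan_to_num(u, nan=0.0)

# after the replacement, missing data and a real 0 current look the same
print(missing.sum())              # 1
print(np.isnan(u_zeroed).sum())   # 0
print((u_zeroed == 0).sum())      # 2
```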

In [None]:
no_data = utils.generate_mask_none(mask_nc["xrds"]["u"].values)[0]
no_data = np.tile(no_data, (target["xrds"]["time"].size, 1, 1))
invalid = utils.generate_mask_invalid(target["xrds"]["u"].values)
invalid_pos = np.where(invalid)
num_invalid = invalid.sum()
print(f"total invalid values on target data: {num_invalid}")

### use of Parcels Field for interpolation

indexing Field values goes [time, depth, lat, lon]

Field does linear interpolation automatically when indexing values between its coordinate values
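
for reference, this is roughly what linear interpolation between grid values means on a 2D lat/lon grid. it is a plain numpy sketch of bilinear interpolation, not the parcels implementation; `bilinear` and the sample grid are made up.

```python
import numpy as np


def bilinear(lats, lons, values, lat, lon):
    """Bilinear interpolation of values[lat, lon] on a rectilinear grid (hypothetical helper)."""
    # index of the grid cell's lower corner, clipped so i + 1 and j + 1 stay in range
    i = np.clip(np.searchsorted(lats, lat) - 1, 0, len(lats) - 2)
    j = np.clip(np.searchsorted(lons, lon) - 1, 0, len(lons) - 2)
    # fractional position inside the cell
    ty = (lat - lats[i]) / (lats[i + 1] - lats[i])
    tx = (lon - lons[j]) / (lons[j + 1] - lons[j])
    # weighted average of the 4 surrounding grid values
    return ((1 - ty) * (1 - tx) * values[i, j]
            + (1 - ty) * tx * values[i, j + 1]
            + ty * (1 - tx) * values[i + 1, j]
            + ty * tx * values[i + 1, j + 1])


lats = np.array([32.0, 32.5, 33.0])
lons = np.array([-118.0, -117.5, -117.0])
values = lats[:, None] + 2 * lons[None, :]

print(bilinear(lats, lons, values, 32.25, -117.75))  # -203.25
```

the sample field is linear in lat and lon, so the interpolated value matches the exact one; on a grid point the function just returns the stored value.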

### note

https://stackoverflow.com/questions/12923586/nearest-neighbor-search-python

in theory, the latitude and longitude values are equally spaced. however, the difference between consecutive coordinate values always fluctuates by a very small amount between two distinct values, so the grid is not perfectly equally spaced.

from testing, this causes enough error to completely change a simulation, so a kdtree must be used.
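
a toy example of how the index drifts when the spacing is assumed to be perfectly even. the spacing values are invented to exaggerate the effect, and the nearest lookup here uses `np.searchsorted` on the sorted coordinates as a lighter alternative, not the kd-tree this notebook actually uses. `nearest_index` is a made-up helper.

```python
import numpy as np


def nearest_index(coords, x):
    """Index of the value in sorted 1D coords that is closest to x."""
    j = np.clip(np.searchsorted(coords, x), 1, len(coords) - 1)
    # x lies between coords[j - 1] and coords[j]; pick the closer one
    return j - 1 if x - coords[j - 1] <= coords[j] - x else j


# spacing alternates between two slightly different values, like the note describes
steps = np.where(np.arange(199) % 2 == 0, 0.0100, 0.0104)
lats = 32.0 + np.concatenate([[0.0], np.cumsum(steps)])

x = lats[150] + 0.004
arithmetic = int(round((x - lats[0]) / steps[0]))  # assumes perfectly equal spacing
searched = nearest_index(lats, x)
print(arithmetic, searched)  # 153 150
```

the arithmetic guess is off by a few cells because the small spacing differences add up along the axis, while a real nearest search is not affected.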

### another note

I just looked at the particle trajectory that used the indexing without kd trees and it looked badly wrong, so I probably wrote something wrong with the method. TODO: come back to this later.

In [None]:
import scipy.spatial

# set up kdtrees for all references
for f in interp_references:
    f["latkdtree"] = scipy.spatial.cKDTree(np.array([f["lat"]]).T)
    f["lonkdtree"] = scipy.spatial.cKDTree(np.array([f["lon"]]).T)

In [None]:
def get_nearest_current(ref, t, lat, lon):
    index = (t, ref["latkdtree"].query([lat])[1], ref["lonkdtree"].query([lon])[1])
    return ref["xrds"]["u"].values[index], ref["xrds"]["v"].values[index]


def get_current_interp(ref, t, lat, lon):
    return ref["fs_flat"].U[t, 0, lat, lon], ref["fs_flat"].V[t, 0, lat, lon]


def get_interped(i, ref, invalid_where):
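    """
    Interpolate the current at one invalid target position from a reference dataset.

    Args:
        i (int): index into invalid_where
        ref (dict): reference dataset info, as returned by get_file_info
        invalid_where (tuple of arrays): (time, lat, lon) index arrays from np.where, one entry per invalid position

    Returns:
        (u, v): (nan, nan) if no data was found, interpolated values otherwise
    """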
    time_diff = np.diff(ref["fs_flat"].U.grid.time)[0]
    t = invalid_where[0][i]
    lat = target["lat"][invalid_where[1][i]]
    lon = target["lon"][invalid_where[2][i]]
    current_u, current_v = get_current_interp(ref, t * time_diff, lat, lon)
    current_abs = abs(current_u) + abs(current_v)
    # if both the u and v components are 0, there's probably no data there
    if np.isnan(get_nearest_current(ref, t, lat, lon)[0]) or current_abs == 0:
        return np.nan, np.nan
    return current_u, current_v

### linear interpolation using lower resolution data

In [None]:
target_interped_u = target["xrds"]["u"].values.copy()
target_interped_v = target["xrds"]["v"].values.copy()
invalid_interped = invalid.copy()
for f in interp_references:
    invalid_pos_new = np.where(invalid_interped)
    num_invalid_new = invalid_interped.sum()
    arr_u = np.zeros(num_invalid_new)
    arr_v = np.zeros(num_invalid_new)
    for i in range(num_invalid_new):
        c_u, c_v = get_interped(i, f, invalid_pos_new)
        arr_u[i] = c_u
        arr_v[i] = c_v
    target_interped_u[invalid_pos_new] = arr_u
    target_interped_v[invalid_pos_new] = arr_v
    invalid_interped = utils.generate_mask_invalid(target_interped_u)
    print(f"total invalid values after interpolation with {f['name']}: {invalid_interped.sum()}")
    print(f"    values filled: {num_invalid_new - invalid_interped.sum()}")

In [None]:
print(f"total invalid values on interpolated: {invalid_interped.sum()}")

## even more filling with PLS and smoothing with DCT shenanigans

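uses the MATLAB Engine API for Python to call smoothn.m, which fills gaps and smooths using penalized least squares (PLS) with a discrete cosine transform (DCT)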

https://www.mathworks.com/help/matlab/matlab-engine-for-python.html

https://www.mathworks.com/matlabcentral/fileexchange/25634-smoothn
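
to get a feel for the PLS + DCT idea without MATLAB, here is a rough 1D numpy sketch. it is not a port of smoothn.m: there is no robust weighting and no automatic choice of the smoothing parameter, and `dct_matrix`, `smooth_1d`, `s` and `n_iter` are made-up names. gaps are handled by repeatedly smoothing a signal where the missing points are replaced by the current estimate.

```python
import numpy as np


def dct_matrix(n):
    """Orthonormal DCT-II matrix: D @ y is the transform, D.T @ c the inverse."""
    k = np.arange(n)[:, None]
    m = np.arange(n)[None, :]
    d = np.sqrt(2.0 / n) * np.cos(np.pi * (m + 0.5) * k / n)
    d[0] /= np.sqrt(2.0)
    return d


def smooth_1d(y, s=10.0, n_iter=100):
    """Penalized least squares smoothing of a 1D signal, filling nan gaps along the way."""
    n = y.size
    d = dct_matrix(n)
    valid = ~np.isnan(y)
    y0 = np.where(valid, y, 0.0)
    # eigenvalues of the second-difference penalty in the DCT basis
    lam = -2.0 + 2.0 * np.cos(np.pi * np.arange(n) / n)
    gamma = 1.0 / (1.0 + s * lam ** 2)  # low-pass filter, bigger s = smoother
    z = np.full(n, y[valid].mean())     # first guess: mean of the valid data
    for _ in range(n_iter):
        # gaps get the current estimate, valid points keep the observation
        z = d.T @ (gamma * (d @ (valid * (y0 - z) + z)))
    return z


x = np.linspace(0, 2 * np.pi, 50)
noisy = np.sin(x) + 0.3 * (-1.0) ** np.arange(50)
noisy[20:25] = np.nan

smoothed = smooth_1d(noisy)
print(np.isnan(smoothed).sum())  # 0
```

the filter `gamma` leaves the mean untouched and damps fast wiggles the most, which is why the same pass both smooths the data and fills the gaps with something that blends into the neighbours.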

In [None]:
import matlab.engine

eng = matlab.engine.start_matlab()

u_list = target_interped_u.tolist()
v_list = target_interped_v.tolist()

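# smooth one time step at a time: u and v are stacked into a single (2, lat, lon) array
for i in range(len(u_list)):
    uv_mat = matlab.double([u_list[i], v_list[i]])
    uv_smooth = eng.smoothn(uv_mat, "robust")
    # copy the MATLAB result back into a numpy array
    uv_array = np.empty(uv_smooth.size)
    uv_array[:] = uv_smooth
    target_interped_u[i] = uv_array[0]
    target_interped_v[i] = uv_array[1]

# shut down the MATLAB engine once smoothing is done
eng.quit()

# smoothn fills every cell, so blank out the positions flagged in no_data again
target_interped_u[np.where(no_data)] = np.nan
target_interped_v[np.where(no_data)] = np.nan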

### formatting and saving

In [None]:
# re-add coordinates, dimensions, and metadata to interpolated data
darr_u = utils.conv_to_dataarray(target_interped_u, target["xrds"]["u"])
darr_v = utils.conv_to_dataarray(target_interped_v, target["xrds"]["v"])
target_interped_xrds = target["xrds"].drop_vars(["u", "v"]).assign(u=darr_u, v=darr_v)

In [None]:
save_path = str(target["path"]).split(".nc")[0] + "_interped.nc"
target_interped_xrds.to_netcdf(save_path)
print(f"saved to {save_path}")

### display field to see if interpolation worked

In [None]:
fs_interp = xr_dataset_to_fieldset(target_interped_xrds)
target["fs"].U.show()  # uninterpolated
fs_interp.U.show()  # interpolated, gapfilled, smoothed