# Interpolation of a time series

This example shows how to interpolate a time series of grids using the `pyinterp` library.

In this example, we consider the time series of MSLA maps distributed by AVISO/CMEMS.

## Initialize Dataset

Here we load the dataset from the zarr store. Note that this very large dataset initializes nearly instantly, and we can see the full list of variables and coordinates.

In [None]:
import intake
# Open the intake catalog describing the Pangeo ocean datasets (remote YAML
# file hosted on GitHub).
cat = intake.Catalog("https://raw.githubusercontent.com/pangeo-data/pangeo-datastore"
                     "/master/intake-catalogs/ocean.yaml")
# Build the dataset object for the "sea_surface_height" catalog entry.
# NOTE(review): presumably an xarray dataset backed by dask arrays, so no grid
# is actually downloaded at this point - confirm against the intake driver.
ds = cat["sea_surface_height"].to_dask()
# Last expression of the cell: displays the dataset summary in the notebook.
ds

## Handle the time series

We implement a class that handles a time series and loads, on demand, the data required to interpolate over a specific time period.

In [None]:
import datetime
import numpy as np
import pandas as pd
import pyinterp.backends.xarray


class TimeSeries:
    """Handling of MSLA AVISO maps.

    The time axis of the dataset is loaded into memory once, at construction.
    The grids themselves are only selected when :meth:`load_dataset` is
    called for a given period.

    Args:
        ds: Dataset exposing a ``time`` coordinate and the variables to be
            interpolated. The time axis must be sorted in ascending order
            and regularly sampled.

    Attributes:
        ds: The dataset given to the constructor.
        series (pandas.Series): Dates of the grids, indexed by position.
        dt (datetime.timedelta): Constant time step between two grids.
    """

    def __init__(self, ds):
        self.ds = ds
        self.series, self.dt = self._load_ts()

    @staticmethod
    def _is_sorted(array):
        """Return True if ``array`` is sorted in ascending order."""
        # Convert to a plain NumPy array first so that the comparison of
        # neighbours is purely positional, whatever the input container.
        values = np.asarray(array)
        return bool(np.all(values[:-1] <= values[1:]))

    def _load_ts(self):
        """Loading the time series into memory.

        Returns:
            tuple: The dates of the grids (pandas.Series) and the constant
            time step between two consecutive grids (datetime.timedelta).

        Raises:
            RuntimeError: If the time axis is not sorted or if the step
                between two consecutive grids is not constant.
        """
        time = self.ds.time
        # Explicit exception instead of ``assert``: assertions are removed
        # when Python runs with -O.
        if not self._is_sorted(time):
            raise RuntimeError(
                "Time series is not sorted in ascending order")

        series = pd.Series(time)
        # Distinct steps, in whole seconds, between consecutive dates.
        steps = np.diff(series.values.astype("datetime64[s]"))
        frequency = set(steps.astype("int64"))
        if len(frequency) != 1:
            raise RuntimeError(
                "Time series does not have a constant step between two "
                f"grids: {frequency} seconds")
        return series, datetime.timedelta(seconds=float(frequency.pop()))

    def load_dataset(self, varname, start, end):
        """Loading the time series into memory for the defined period.

        Args:
            varname (str): Name of the variable to be loaded into memory.
            start (datetime.datetime): Date of the first map to be loaded.
            end (datetime.datetime): Date of the last map to be loaded.

        Returns:
            pyinterp.backends.xarray.Grid3D: The interpolator handling the
            interpolation of the grid series.

        Raises:
            IndexError: If the requested period is not covered by the time
                series.
        """
        if start < self.series.min() or end > self.series.max():
            raise IndexError(
                f"period [{start}, {end}] out of range [{self.series.min()}, "
                f"{self.series.max()}]")
        # Widen the period by one time step on each side so that the
        # requested dates are bracketed by grids.
        first = start - self.dt
        last = end + self.dt

        selected = self.series[(self.series >= first) & (self.series < last)]
        print(f"fetch data from {selected.min()} to {selected.max()}")

        # Use the dataset owned by this instance, not a module-level global.
        data_array = self.ds[varname].isel(time=selected.index)
        return pyinterp.backends.xarray.Grid3D(data_array)

## Load dataset

Finally, the functions necessary to load the test dataset into memory are added. The dataset file contains several columns defining the float identifier, the date of the measurement, and the longitude and latitude of the measurement.

In [None]:
def cnes_jd_to_datetime(seconds):
    """Convert a date expressed in seconds since 1950 into a calendar
    date.

    Args:
        seconds (float): Number of seconds elapsed since
            1950-01-01T00:00:00.

    Returns:
        datetime.datetime: The corresponding naive datetime.
    """
    # Adding a timedelta to the 1950 epoch avoids the deprecated
    # ``utcfromtimestamp`` and the precision loss of converting to days and
    # back. ``float`` keeps NumPy integer scalars acceptable to timedelta.
    return datetime.datetime(1950, 1, 1) + datetime.timedelta(
        seconds=float(seconds))


def load_positions():
    """Read the positions file and return it indexed by measurement date.

    Returns:
        pandas.DataFrame: Table sorted by its ``time`` index, holding the
        columns ``id``, ``lon`` and ``lat`` read from the file, plus an
        ``sla`` column filled with NaN.
    """
    column_types = {
        "id": np.uint32,
        "time": np.float64,
        "lon": np.float64,
        "lat": np.float64,
    }
    positions = pd.read_csv("../tests/dataset/positions.csv",
                            sep=";",
                            header=None,
                            names=list(column_types),
                            usecols=[0, 1, 2, 3],
                            dtype=column_types)

    # NOTE(review): 1.8446744073709552e+19 (2**64) is presumably the fill
    # value used in the file for undefined measurements - confirm.
    undefined = 1.8446744073709552e+19
    positions.mask(positions == undefined, np.nan, inplace=True)

    # The dates are stored as seconds since 1950: convert them to calendar
    # dates before using them as index.
    positions["time"] = positions["time"].apply(cnes_jd_to_datetime)
    positions = positions.set_index("time")

    # Placeholder column for the values to be interpolated.
    positions["sla"] = np.nan
    return positions.sort_index()

# Table of positions, indexed and sorted by measurement date; its "sla"
# column is NaN at this point.
df = load_positions()

## Implementation of interpolation

We create the object that will handle the download of data for the periods required for the interpolation.

In [None]:
# Wrap the dataset opened above; the constructor reads the time axis and
# checks that the step between two grids is constant.
time_series = TimeSeries(ds)

The function below splits the processing period into sub-periods so that the grids can be loaded in blocks.

In [None]:
def periods(df, time_series, frequency='W'):
    """Return the list of periods covering the time series loaded in
    memory.

    Args:
        df (pandas.DataFrame): Measurements indexed by date.
        time_series (TimeSeries): Object providing ``series`` (dates of the
            grids) and ``dt`` (time step between two grids).
        frequency (str): pandas period alias used to split the index of
            ``df`` into sub-periods.

    Yields:
        tuple: ``(start, end)`` of each sub-period, in chronological order.
        ``start`` is never earlier than the first grid of the time series.
        The end of the last sub-period is the last date of ``df`` plus one
        time step. Nothing is yielded if ``df`` has no dated rows.
    """
    period_start = df.index.to_period(
        frequency).unique().dropna().sort_values()
    if len(period_start) == 0:
        return

    first_grid = time_series.series.iloc[0]
    # Start of every sub-period, clamped to the first available grid.
    starts = []
    for item in period_start:
        start = item.to_timestamp()
        starts.append(first_grid if start < first_grid else start)
    # A sub-period ends where the next one begins (unclamped).
    ends = [item.to_timestamp() for item in period_start[1:]]

    yield from zip(starts, ends)
    # The last sub-period is handled separately so that data fitting in a
    # single period still produces one interval.
    yield starts[-1], df.index.max() + time_series.dt

Finally, the interpolation function is written for one of the sub-periods selected by the function `periods`.

In [None]:
def interpolate(df, time_series, start, end):
    """Interpolate the time series over the defined period."""
    interpolator = time_series.load_dataset("sla", start, end)
    mask = (df.index >= start) & (df.index < end)
    selected = df.loc[mask, ["lon", "lat"]]
    df.loc[mask, ["sla"]] = interpolator.trivariate(dict(
        longitude=selected["lon"].values,
        latitude=selected["lat"].values,
        time=selected.index.values),
        interpolator="inverse_distance_weighting",
        num_threads=0)

In [None]:
# Process the positions one sub-period at a time ('M' frequency), so that
# only the grids needed for that sub-period are requested on each pass.
for start, end in periods(df, time_series, frequency='M'):
    interpolate(df, time_series, start, end)

Visualization of the SLA for a float.

In [None]:
# Identifier of the float to be displayed.
float_id = 62423050
# Rows of the table belonging to this float.
selected_float = df[df.id == float_id]
# Dates of the first and last measurements of the float.
first = selected_float.index.min()
last = selected_float.index.max()
# Fraction of the float's time span elapsed at each measurement: 0 for the
# first date, 1 for the last one. Used below to scale the markers.
size = (selected_float.index - first) / (last - first)

In [None]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
%matplotlib inline

# Map in a Plate Carree projection centred on longitude 180.
fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(111, projection=ccrs.PlateCarree(central_longitude=180))
# One marker per measurement: colour is the interpolated SLA, marker area
# grows with the time elapsed since the first measurement.
sc = ax.scatter(
    selected_float.lon,
    selected_float.lat,
    s=size*100,
    c=selected_float.sla,
    # The positions are given as plain longitudes/latitudes.
    transform=ccrs.PlateCarree(),
    cmap='jet')
ax.coastlines()
ax.set_title("Time series of SLA "
             "(larger points are closer to the last date)")
ax.add_feature(cfeature.LAND)
ax.add_feature(cfeature.COASTLINE)
# Restrict the view to longitudes 80 to 100 and latitudes 13.5 to 25.
ax.set_extent([80, 100, 13.5, 25], crs=ccrs.PlateCarree())
# Colour scale of the scatter plot.
fig.colorbar(sc)