In [1]:
import numpy as np
import GPy
import xarray as xr

# Function to estimate the lag between two signals
def estimate_lag(signal1, signal2):
    """
    Estimates the lag between two signals from the peak of their cross-correlation.

    Both signals are mean-centred before being correlated. Without this step a
    constant offset (e.g. a ~20 degree baseline) dominates every product in the
    correlation sum, so the peak simply tracks the overlap length and the
    estimate collapses to 0 regardless of the real shift.

    Samples where either signal is NaN are dropped before correlating, so the
    returned lag is expressed in steps of the retained (gap-compressed) samples.

    Parameters:
    - signal1: First signal (e.g., reference signal from station 1)
    - signal2: Second signal (e.g., from another station), same length as signal1

    Returns:
    - Estimated lag as an integer number of time steps. A positive value means
      a feature shows up in signal1 that many steps later than in signal2;
      a negative value means it shows up earlier. Returns 0 when fewer than two
      samples are valid in both signals, or when either signal is constant over
      the valid samples (the correlation then has no defined peak).
    """
    signal1 = np.asarray(signal1, dtype=float)
    signal2 = np.asarray(signal2, dtype=float)

    valid_indices = ~np.isnan(signal1) & ~np.isnan(signal2)
    if valid_indices.sum() < 2:  # Not enough data to compute lag
        return 0

    first = signal1[valid_indices]
    second = signal2[valid_indices]

    # A flat signal carries no timing information; after centring it would be
    # all zeros and argmax would arbitrarily pick the most negative lag.
    if first.min() == first.max() or second.min() == second.max():
        return 0

    corr = np.correlate(first - first.mean(), second - second.mean(), mode='full')

    # In 'full' mode index 0 corresponds to lag -(len(second) - 1).
    return int(corr.argmax() - (len(second) - 1))


# Function to perform GPR for gap filling using surrounding stations
def gap_fill_station(dataset, target_station_id, var_name='temperature'):
    """
    Performs gap filling for a target station using the data from surrounding stations.
    Uses station's lat/lon as additional features in the Gaussian Process Regression.

    The regression is trained on every non-NaN sample of all stations other
    than the target, with (lat, lon, time) as inputs. The target station's own
    observations are not part of the training set; they are only kept as-is in
    the returned array, and the model prediction is written into the NaN slots.

    Parameters:
    - dataset: xarray.Dataset containing the station data. It must provide
      'lat' and 'lon' variables and the data variable, all selectable by
      'station_id' label, plus a 'time' coordinate.
    - target_station_id: The station ID of the target station to fill
    - var_name: The name of the variable to be filled (e.g., 'temperature')

    Returns:
    - Filled values for the target station (with gaps filled by GPR). The
      original values are returned unchanged (as a copy) when the target has
      no gaps or when no surrounding station has any valid sample.
    """
    times = dataset['time'].values

    def station_features(station_id, time_values):
        """Builds the (n, 3) GPR input rows [lat, lon, time] for one station."""
        # Look the position up by station label rather than by `station_id - 1`,
        # so IDs need not be consecutive integers starting at 1.
        selector = {'station_id': station_id}
        station_lat = float(dataset['lat'].loc[selector].values)
        station_lon = float(dataset['lon'].loc[selector].values)
        n_rows = len(time_values)
        return np.column_stack([np.full(n_rows, station_lat),
                                np.full(n_rows, station_lon),
                                time_values])

    target_data = dataset[var_name].loc[{'station_id': target_station_id}].values

    # Prepare the results
    filled_data = target_data.copy()

    missing_indices = np.isnan(target_data)
    if not missing_indices.any():
        return filled_data  # Nothing to fill, so skip fitting the model

    surrounding_stations = [id_ for id_ in dataset['station_id'].values if id_ != target_station_id]

    # Collect training samples from the surrounding stations
    valid_coords_list = []
    valid_temp_list = []
    for station_id in surrounding_stations:
        surrounding_data = dataset[var_name].loc[{'station_id': station_id}].values

        valid_indices = ~np.isnan(surrounding_data)
        if valid_indices.any():
            valid_coords_list.append(station_features(station_id, times[valid_indices]))
            valid_temp_list.append(surrounding_data[valid_indices])

    if not valid_coords_list:
        return filled_data  # No valid data to perform gap filling

    valid_coords_all = np.vstack(valid_coords_list)
    # GPy expects the targets as a 2-D column vector
    valid_temp_all = np.concatenate(valid_temp_list)[:, None]

    # Perform Gaussian Process Regression
    kernel = GPy.kern.RBF(input_dim=3, lengthscale=1.0, variance=1.0) + GPy.kern.Bias(input_dim=3)
    gpr_model = GPy.models.GPRegression(valid_coords_all, valid_temp_all, kernel)
    gpr_model.optimize()

    # Predict only where the target station has gaps; the predictive variance is not used
    gpr_pred, _ = gpr_model.predict(station_features(target_station_id, times[missing_indices]))

    # Fill in the missing data
    filled_data[missing_indices] = gpr_pred[:, 0]

    return filled_data


# Example: Create synthetic buoy (station) data with lat/lon features
n_time = 100
n_stations = 5
time = np.linspace(0, 10, n_time)
station_ids = np.arange(1, n_stations + 1)

# Station positions: latitudes drawn from [30, 50), longitudes from [-120, -80)
latitudes = np.random.uniform(30, 50, n_stations)
longitudes = np.random.uniform(-120, -80, n_stations)

# Dataset skeleton: coordinates plus per-station lat/lon
dataset = xr.Dataset(coords={"station_id": station_ids, "time": time})
dataset['lat'] = ('station_id', latitudes)
dataset['lon'] = ('station_id', longitudes)

# Decaying baseline shared by every synthetic station
baseline = 20 + 5 * np.exp(-0.2 * time)

# Station 1 is the un-lagged reference: baseline + seasonal cycle + noise
data_station1 = baseline + 1.5 * np.sin(2 * np.pi * time / 5) + np.random.normal(0, 0.5, n_time)

# Start from an all-NaN grid and fill it one station at a time
dataset['temperature'] = (('station_id', 'time'), np.full((n_stations, n_time), np.nan))
dataset['temperature'].loc[{'station_id': 1}] = data_station1

# Remaining stations: amplitude and noise scale with the station number,
# even stations get a positive lag, odd stations a negative one
for i in range(2, n_stations + 1):
    if i % 2 == 0:
        lag = i // 2 * 0.1
    else:
        lag = -((i + 1) // 2) * 0.1

    seasonal = i * 1.5 * np.sin(2 * np.pi * (time - lag) / 5)
    data_with_lag = baseline + seasonal + np.random.normal(0, i * 0.5, n_time)

    # Knock out 10 distinct time steps to simulate gaps
    gap_positions = np.random.choice(n_time, size=10, replace=False)
    data_with_missing = data_with_lag.copy()
    data_with_missing[gap_positions] = np.nan

    dataset['temperature'].loc[{'station_id': i}] = data_with_missing

# Fill the gaps of one station from its neighbours (lat/lon are model inputs)
target_station_id = 2
filled_data = gap_fill_station(dataset, target_station_id, var_name='temperature')

# Show the series before and after gap filling
print(f"Original data for station {target_station_id}:")
print(dataset['temperature'].loc[{'station_id': target_station_id}].values)

print(f"Gap-filled data for station {target_station_id}:")
print(filled_data)


Original data for station 2:
[        nan 25.69232121 24.93001293 27.42584209 24.62787981 25.0627165
 25.88565924 25.97081476 26.5166489  26.43451355 27.36131416 28.74388377
 26.97052614 26.43429853 27.22667206 26.31489939 27.89841905 25.51735009
 25.33620906 25.56864268         nan 23.96786821 24.3138378  25.34109662
 26.191475   23.84777865 21.78125223         nan 23.02269022 21.60378558
 21.61173986         nan 20.12711629 20.68839563 20.701761   19.07722139
         nan 20.43609435 17.60084334 19.89442522 19.27070782 19.09918067
 20.10287729 20.35724592         nan 19.20103984 20.62096454 21.42854321
         nan 22.86132938 22.64895147 21.36500318 20.62698366 23.84748392
         nan 23.97393404 25.97372542 23.78754727 24.08493461 24.15012245
 24.49691293 25.88574356 22.66849423 24.34218023 25.09503859 25.73618422
 24.18365012         nan 23.49947198 24.19137692 23.57925739 23.68245038
 22.10406175 21.08848438 21.34777769 20.65848609 20.62227524 18.5215238
 19.20342759 20.00437049