# In [None]:
import numpy as np
import rasterio
from scipy.interpolate import griddata
import matplotlib.pyplot as plt

def fill_gaps_with_interpolation(image_array, method='linear', *, nodata=0,
                                 fill_outside_hull=False):
    """
    Interpolate missing pixels of a 2-D image from the surrounding valid pixels.

    A pixel counts as missing when it is NaN or equal to ``nodata``. Only the
    missing pixels are interpolated; every valid pixel is copied through
    unchanged.

    Parameters
    ----------
    image_array : array_like
        2-D image indexed as ``[row, col]``. It is not modified.
    method : {'linear', 'nearest', 'cubic'}, optional
        Interpolation method forwarded to ``scipy.interpolate.griddata``.
    nodata : scalar or None, optional
        Sentinel value that marks missing pixels, in addition to NaN.
        Defaults to 0, matching the previous hard-coded behaviour. Pass
        ``None`` to treat only NaN as missing, e.g. when 0 is a real value.
    fill_outside_hull : bool, optional
        'linear' and 'cubic' interpolation cannot extrapolate, so missing
        pixels outside the convex hull of the valid pixels come back as NaN.
        If True, those pixels are filled with nearest-neighbour values
        instead. Default False (leave them NaN).

    Returns
    -------
    numpy.ndarray
        float64 array with the same shape as ``image_array``.

    Raises
    ------
    ValueError
        If the image is not 2-D, or if it has missing pixels but no valid
        pixel to interpolate from. Qhull errors from scipy (e.g. too few or
        collinear valid pixels for 'linear'/'cubic') propagate unchanged.
    """
    src = np.asarray(image_array)
    if src.ndim != 2:
        raise ValueError(f"expected a 2-D image, got shape {src.shape}")

    missing = np.isnan(src)
    if nodata is not None:
        missing |= src == nodata
    valid = ~missing

    # astype always returns a copy, so the caller's array is never written.
    filled = src.astype(np.float64)
    if not missing.any():
        return filled
    if not valid.any():
        raise ValueError("image has no valid pixels to interpolate from")

    # griddata works on (x, y) = (column, row) coordinates.
    valid_rows, valid_cols = np.nonzero(valid)
    known_points = np.column_stack((valid_cols, valid_rows))
    known_values = filled[valid]

    # Evaluate only at the gaps; valid pixels keep their exact values.
    miss_rows, miss_cols = np.nonzero(missing)
    targets = np.column_stack((miss_cols, miss_rows))

    values = griddata(known_points, known_values, targets, method=method)

    if fill_outside_hull:
        outside = np.isnan(values)
        if outside.any():
            values[outside] = griddata(known_points, known_values,
                                       targets[outside], method='nearest')

    # Boolean-mask assignment visits pixels in row-major order, which is the
    # same order np.nonzero produced the target coordinates in.
    filled[missing] = values
    return filled

# Example usage: load the first band of a sample raster as float32, fill its
# gaps, and show the original and filled images next to each other.
with rasterio.open('datasets2/images/2.tif') as src:
    band = src.read(1).astype(np.float32)

filled_band = fill_gaps_with_interpolation(band, method='linear')

# One figure, two panels: left = input with gaps, right = interpolated result.
fig, (ax_before, ax_after) = plt.subplots(1, 2, figsize=(12, 6))

ax_before.imshow(band, cmap='gray')
ax_before.set_title("Original with Gaps")

ax_after.imshow(filled_band, cmap='gray')
ax_after.set_title("Filled with Spatial Interpolation")

fig.tight_layout()
plt.show()
