# Skyborn GridFill Tutorial: Advanced Data Interpolation

This tutorial demonstrates how to use the **Skyborn gridfill** package for advanced data interpolation and gap-filling in atmospheric datasets. We'll work with real wind field data to show how gridfill can intelligently fill missing data regions using various interpolation methods.

## What is GridFill?

GridFill is a sophisticated interpolation package that:
- Fills missing data regions in gridded datasets
- Produces smooth fills that blend with the surrounding valid data
- Supports multiple interpolation algorithms
- Works seamlessly with xarray data structures
- Handles complex boundary conditions

The **gridfill** package provides multiple interpolation methods for different use cases.

## 1. Setup and Data Loading

First, let's import the necessary packages and configure Matplotlib (the wind data is loaded a few cells below):

In [None]:
import numpy as np
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.util import add_cyclic_point
import cmaps
import warnings
warnings.filterwarnings('ignore')

# Import gridfill interfaces
from skyborn.gridfill import fill  # Main gridfill function
from skyborn.gridfill.xarray import xr_fill  # xarray interface

# Import skyborn plotting utilities
from skyborn.plot import add_equal_axes, curved_quiver

# Set up matplotlib for better plots.
# Settings that ``config`` does not override are applied first. Previously
# ``font.size`` and ``axes.labelsize`` were also set here (to 11) but were
# immediately overwritten by ``config`` (to 15), so those assignments were
# dead and have been dropped; the effective rcParams are unchanged.
plt.rcParams.update({
    'figure.figsize': (14, 10),
    'axes.titlesize': 13,
})
config = {
    "font.family": 'DejaVu Sans',
    "font.size": 15,
    'font.weight': 'bold',
    "mathtext.fontset": 'stix',
    # NOTE(review): 'font.serif' has no visible effect while font.family is a
    # sans-serif face; kept to preserve the original configuration.
    "font.serif": ['cmb10'],
    "axes.unicode_minus": False,
    "axes.labelweight": "bold",
    "axes.labelsize": 15,
}
plt.rcParams.update(config)
print("Libraries imported successfully!")
print("gridfill and skyborn.plot utilities loaded!")

In [None]:
# Setup for saving documentation images
from pathlib import Path

# Define the documentation images directory.
# When the notebook is executed from its own folder (a directory named
# ``notebooks``), images go to the sibling ``images`` folder; otherwise fall
# back to a path relative to the current working directory.
notebook_dir = Path.cwd()
# Compare the directory's own name instead of searching the whole path string:
# a substring test would also match unrelated parents such as ``~/my_notebooks/``.
if notebook_dir.name == 'notebooks':
    docs_images_dir = notebook_dir.parent / 'images'
else:
    docs_images_dir = Path('docs/source/images')

docs_images_dir.mkdir(parents=True, exist_ok=True)

def save_gallery_figure(fig, filename, dpi=300):
    """Write *fig* into the documentation images directory.

    The figure is rendered at ``dpi`` with a tight bounding box on a white
    background. The full destination path (``docs_images_dir / filename``)
    is returned so callers can report or reuse it.
    """
    target = docs_images_dir.joinpath(filename)
    fig.savefig(target, dpi=dpi, facecolor='white', bbox_inches='tight')
    return target

### Load Wind Data for GridFill Demonstration

We'll use the same wind data from the windspharm tutorial and create artificial missing data regions to demonstrate gridfill capabilities:

In [None]:
# Load wind data from NetCDF files.
# NOTE(review): this relative path assumes the notebook runs from
# docs/source/notebooks — confirm when executing from another directory.
from pathlib import Path

data_path = '../../../src/skyborn/windspharm/examples/example_data/'
data_dir = Path(data_path)

# Load datasets. Joining with pathlib does not depend on ``data_path``
# ending in a separator, unlike plain string concatenation.
ds_u = xr.open_dataset(data_dir / 'uwnd_mean.nc')
ds_v = xr.open_dataset(data_dir / 'vwnd_mean.nc')

# Extract wind components as DataArrays.
# Index 0 along ``time`` is treated as January throughout this tutorial
# (assumes the files hold monthly means starting in January — TODO confirm).
u_wind = ds_u.uwnd.isel(time=0)  # January
v_wind = ds_v.vwnd.isel(time=0)

print("Wind data loaded successfully!")
print(f"U wind shape: {u_wind.shape}")
print(f"V wind shape: {v_wind.shape}")
print(f"Coordinate dimensions: {u_wind.dims}")
print(f"\nLatitude range: {u_wind.latitude.min().values:.1f}° to {u_wind.latitude.max().values:.1f}°")
print(f"Longitude range: {u_wind.longitude.min().values:.1f}° to {u_wind.longitude.max().values:.1f}°")

# Calculate wind speed for visualization (magnitude of the horizontal wind)
wind_speed = np.sqrt(u_wind**2 + v_wind**2)
print(f"Wind speed range: {wind_speed.min().values:.2f} to {wind_speed.max().values:.2f} m/s")

## 2. Create Missing Data Regions

To demonstrate gridfill capabilities, we'll artificially create missing data regions that simulate real-world scenarios like satellite data gaps or sensor failures:

In [None]:
# Create copies of original data (kept as the reference for error statistics)
u_original = u_wind.copy()
v_original = v_wind.copy()
speed_original = wind_speed.copy()

# Create missing data regions to simulate realistic scenarios
u_missing = u_wind.copy()
v_missing = v_wind.copy()
speed_missing = wind_speed.copy()

# Get coordinate arrays.
# 'ij' indexing yields (lat, lon)-shaped meshes; the mask built from them is
# applied directly to the data below (assumes dims are (latitude, longitude)).
lats = u_wind.latitude.values
lons = u_wind.longitude.values
lat_mesh, lon_mesh = np.meshgrid(lats, lons, indexing='ij')
# The regions below are defined on 0–360° longitudes. Wrapping makes them work
# for data stored on a -180–180° grid too; it is a no-op for 0–360° data, and
# the 20°-periodic sine used for region 3 is unaffected by the wrap.
lon_mesh = np.mod(lon_mesh, 360)

# Create several missing data regions
missing_mask = np.zeros_like(lat_mesh, dtype=bool)

# Region 1: Large rectangular gap (simulating satellite swath gap)
lat_mask1 = (lat_mesh >= 20) & (lat_mesh <= 40)
lon_mask1 = (lon_mesh >= 120) & (lon_mesh <= 160)
region1 = lat_mask1 & lon_mask1
missing_mask |= region1

# Region 2: Circular region (simulating storm center data loss).
# Distance is measured in degrees on the lat/lon grid, not on the sphere.
center_lat, center_lon = 10, 200
radius = 15  # degrees
dist = np.sqrt((lat_mesh - center_lat)**2 + (lon_mesh - center_lon)**2)
region2 = dist <= radius
missing_mask |= region2

# Region 3: Irregular coastal region (simulating land mask issues)
lat_mask3 = (lat_mesh >= -10) & (lat_mesh <= 10)
lon_mask3 = (lon_mesh >= 280) & (lon_mesh <= 320)
# Add some irregularity
irregular_mask = np.sin(lon_mesh * np.pi / 10) * np.cos(lat_mesh * np.pi / 5) > 0.3
region3 = lat_mask3 & lon_mask3 & irregular_mask
missing_mask |= region3

# Apply missing data mask
u_missing = u_missing.where(~missing_mask, np.nan)
v_missing = v_missing.where(~missing_mask, np.nan)
speed_missing = speed_missing.where(~missing_mask, np.nan)

# Count missing points
total_points = lat_mesh.size
missing_points = missing_mask.sum()
missing_percent = (missing_points / total_points) * 100

# Count a region only if it actually removed grid points. Region 1 must be
# tested as the full lat & lon box, not just its latitude band.
n_regions = sum(int(region.any()) for region in (region1, region2, region3))

print("Missing data statistics:")
print(f"Total grid points: {total_points}")
print(f"Missing points: {missing_points}")
print(f"Missing percentage: {missing_percent:.1f}%")
print(f"Created {n_regions} missing regions")

### Visualize Missing Data Regions

Let's create an elegant visualization showing the original data and the missing data regions:

In [None]:
# Enhanced plotting function with beautiful styling
def plot_field_elegant(ax, data, title, levels, extend='both', add_features=True,
                       cmap=None):
    """Draw a filled-contour map of ``data`` on a cartopy GeoAxes.

    Parameters
    ----------
    ax : cartopy GeoAxes
        Target axes (this notebook passes Robinson-projection subplots).
    data : xarray.DataArray
        Field with ``latitude`` and ``longitude`` coordinates. Longitude is
        assumed to be the last dimension, as ``add_cyclic_point`` appends the
        wrap-around column along the last axis — TODO confirm for new inputs.
    title : str
        Axes title.
    levels : sequence of float
        Contour levels passed to ``contourf``.
    extend : str, default 'both'
        ``contourf`` extend mode for out-of-range values.
    add_features : bool, default True
        Draw coastline, ocean, and land shading before the field.
    cmap : colormap, optional
        Colormap for the field; defaults to ``cmaps.BlueWhiteOrangeRed``.

    Returns
    -------
    The ``contourf`` result, suitable for building a colorbar.
    """
    if cmap is None:
        cmap = cmaps.BlueWhiteOrangeRed

    if add_features:
        # Map features are added before the field so the drawing order is
        # consistent across panels.
        ax.add_feature(cfeature.COASTLINE, alpha=0.8, linewidth=1.2)
        ax.add_feature(cfeature.OCEAN, color='#f0f8ff', alpha=0.3)
        ax.add_feature(cfeature.LAND, color='#f5f5dc', alpha=0.4)

    # Duplicate the first longitude column at the end to avoid a seam
    # where the grid wraps around.
    data_cyclic, lon_cyclic = add_cyclic_point(data.values, coord=data.longitude.values)

    # Data are on a regular lat/lon grid, hence the PlateCarree transform.
    im = ax.contourf(lon_cyclic, data.latitude.values, data_cyclic,
                     levels=levels, cmap=cmap,
                     transform=ccrs.PlateCarree(), extend=extend)

    ax.set_title(title, fontsize=15, fontweight='bold', pad=15)
    ax.set_global()

    return im

def _add_map_features(ax):
    """Draw the coastline/ocean/land styling used by the mask panels."""
    ax.add_feature(cfeature.COASTLINE, alpha=0.8, linewidth=1.2)
    ax.add_feature(cfeature.OCEAN, color='#f0f8ff', alpha=0.3)
    ax.add_feature(cfeature.LAND, color='#f5f5dc', alpha=0.4)


# Create comparison plot
fig = plt.figure(figsize=(20, 12))

# Original data
ax1 = plt.subplot(2, 2, 1, projection=ccrs.Robinson(central_longitude=180))
levels = np.linspace(0, speed_original.max().values, 20)
im1 = plot_field_elegant(ax1, speed_original, 'Original Wind Speed (January 850 hPa)', levels, extend='max')

# Data with missing regions (same levels as the original for a fair comparison)
ax2 = plt.subplot(2, 2, 2, projection=ccrs.Robinson(central_longitude=180))
im2 = plot_field_elegant(ax2, speed_missing, 'Wind Speed with Missing Data', levels, extend='max')

# Missing data mask: 1.0 inside the gaps, NaN elsewhere, so only gaps are shaded
ax3 = plt.subplot(2, 2, 3, projection=ccrs.Robinson(central_longitude=180))
mask_data = xr.DataArray(missing_mask.astype(float), coords=speed_original.coords, dims=speed_original.dims)
mask_data = mask_data.where(missing_mask, np.nan)
_add_map_features(ax3)
# Use the longitude returned by add_cyclic_point: ``lon_cyclic`` only exists
# inside plot_field_elegant, so referencing it here raised a NameError.
mask_cyclic, mask_lon_cyclic = add_cyclic_point(mask_data.values, coord=mask_data.longitude.values)
im3 = ax3.contourf(mask_lon_cyclic, mask_data.latitude.values, mask_cyclic,
                   levels=[0.5, 1.5], colors=['red'], alpha=0.7,
                   transform=ccrs.PlateCarree())
ax3.set_title('Missing Data Regions (Red Areas)', fontsize=15, fontweight='bold', pad=15)
ax3.set_global()

# Data availability percentage (0 inside gaps, 100 elsewhere)
ax4 = plt.subplot(2, 2, 4, projection=ccrs.Robinson(central_longitude=180))
availability = (~missing_mask).astype(float) * 100
avail_data = xr.DataArray(availability, coords=speed_original.coords, dims=speed_original.dims)
# Add map features before the field, matching the drawing order of the other panels.
_add_map_features(ax4)
avail_cyclic, avail_lon_cyclic = add_cyclic_point(avail_data.values, coord=avail_data.longitude.values)
im4 = ax4.contourf(avail_lon_cyclic, avail_data.latitude.values, avail_cyclic,
                   levels=np.linspace(0, 100, 11), cmap=cmaps.amwg256,
                   transform=ccrs.PlateCarree(), extend='neither')
ax4.set_title('Data Availability (%)', fontsize=15, fontweight='bold', pad=15)
ax4.set_global()

# Lay out the map panels *before* attaching colorbars. The colorbar axes are
# created from each parent's position at call time (presumably via
# ax.get_position() inside add_equal_axes — TODO confirm), so running
# tight_layout afterwards would move the maps but leave the colorbars behind.
plt.tight_layout()

# Add colorbars using add_equal_axes
cax1 = add_equal_axes(ax1, 'bottom', 0.06, 0.02)
cbar1 = plt.colorbar(im1, cax=cax1, orientation='horizontal')
cbar1.set_label('Wind Speed (m/s)', fontsize=11, fontweight='bold')
cbar1.ax.tick_params(labelsize=9)

cax2 = add_equal_axes(ax2, 'bottom', 0.06, 0.02)
cbar2 = plt.colorbar(im2, cax=cax2, orientation='horizontal')
cbar2.set_label('Wind Speed (m/s)', fontsize=11, fontweight='bold')
cbar2.ax.tick_params(labelsize=9)

cax4 = add_equal_axes(ax4, 'bottom', 0.06, 0.02)
cbar4 = plt.colorbar(im4, cax=cax4, orientation='horizontal')
cbar4.set_label('Availability (%)', fontsize=11, fontweight='bold')
cbar4.ax.tick_params(labelsize=9)

save_gallery_figure(fig, 'gridfill_missing_data_overview.png')
plt.show()

print("Missing data regions created and visualized!")
print("\nMissing data regions represent common real-world scenarios:")
print("• Large rectangular gap: Satellite swath gaps or instrument failures")
print("• Circular region: Storm center data loss or sensor blind spots")
print("• Irregular coastal region: Land-sea boundary issues or topographic effects")

## 3. GridFill Interpolation Methods

Now let's apply different gridfill interpolation methods to fill the missing data regions. We'll compare multiple approaches:

In [None]:
# Apply different gridfill methods.
# NOTE(review): the ``method``, ``eps``, ``relax`` and ``max_iter`` keywords
# passed to xr_fill below are not verifiable from this notebook — confirm them
# against the skyborn.gridfill.xarray signature before trusting the comparison.
print("Applying GridFill interpolation methods...")

# Method 1: Poisson interpolation (labelled the default in this tutorial)
print("\n1. Poisson interpolation...")
u_filled_poisson = xr_fill(u_missing, method='poisson')
v_filled_poisson = xr_fill(v_missing, method='poisson')
# u and v are filled independently; speed is derived from the filled components.
speed_filled_poisson = np.sqrt(u_filled_poisson**2 + v_filled_poisson**2)

# Method 2: Spring relaxation method
print("2. Spring relaxation method...")
u_filled_spring = xr_fill(u_missing, method='spring', eps=1e-4, relax=0.6)
v_filled_spring = xr_fill(v_missing, method='spring', eps=1e-4, relax=0.6)
speed_filled_spring = np.sqrt(u_filled_spring**2 + v_filled_spring**2)

# Method 3: Iterative improvement
print("3. Iterative method...")
u_filled_iterative = xr_fill(u_missing, method='iterative', max_iter=1000)
v_filled_iterative = xr_fill(v_missing, method='iterative', max_iter=1000)
speed_filled_iterative = np.sqrt(u_filled_iterative**2 + v_filled_iterative**2)

print("\nGridFill interpolation completed for all methods!")

# Calculate interpolation statistics
methods = ['Poisson', 'Spring', 'Iterative']
filled_speeds = [speed_filled_poisson, speed_filled_spring, speed_filled_iterative]

# The reference values inside the gaps are the same for every method,
# so mask them once instead of on every loop iteration.
original_missing = speed_original.where(missing_mask)

print("\nInterpolation quality statistics:")
for method, filled in zip(methods, filled_speeds):
    # Errors are evaluated only where data were removed: points outside the
    # mask become NaN and are skipped by xarray's NaN-aware mean/max.
    filled_missing = filled.where(missing_mask)
    error = original_missing - filled_missing

    rms_error = np.sqrt((error**2).mean()).values
    max_error = np.abs(error).max().values

    print(f"  {method}:")
    print(f"    RMS error: {rms_error:.3f} m/s")
    print(f"    Max error: {max_error:.3f} m/s")
    print(f"    Data range: {filled.min().values:.2f} to {filled.max().values:.2f} m/s")