# Remap Regular Lat/Lon Grid to HEALPix

This notebook demonstrates how to remap a dataset on a regular latitude/longitude grid to a HEALPix grid using `easygems.remap`.

Adapted from the unstructured-grid remapping examples to work with regular lat/lon grids, such as the MERG dataset used below.

In [None]:
import numpy as np
import xarray as xr
import healpix as hp
import easygems.remap as egr
import easygems.healpix as egh
import matplotlib.pyplot as plt
from pathlib import Path

## Define the remapping functions

In [None]:
def gen_weights_latlon(lon, lat, order, halo=2):
    """
    Generate remapping weights from regular lat/lon grid to HEALPix grid.

    The source points are the row-major flattening of the 2D (lat, lon) mesh,
    so the returned ``src_idx`` values are flat indices
    ``i_lat * len(lon) + i_lon`` into a field of shape (len(lat), len(lon)).

    Parameters:
    -----------
    lon : array-like
        1D array of longitude values in degrees (e.g. -180..180 or 0..360)
    lat : array-like
        1D array of latitude values in degrees
    order : int
        HEALPix order (zoom level)
    halo : int, optional
        For global grids, number of longitude columns copied across the
        periodic seam on each side (default 2, minimum 1).

    Returns:
    --------
    weights : Dataset
        Remapping weights (``src_idx``, ``weights``, ``valid``) for use with
        easygems.remap.apply_weights. HEALPix pixels outside the source
        domain are flagged as not valid.
    """
    lon = np.asarray(lon, dtype=float)
    lat = np.asarray(lat, dtype=float)

    nside = hp.order2nside(order)
    npix = hp.nside2npix(nside)

    # Get HEALPix pixel coordinates (NESTED ordering)
    hp_lon, hp_lat = hp.pix2ang(
        nside=nside, ipix=np.arange(npix), lonlat=True, nest=True
    )

    # Express target longitudes in the source grid's convention by wrapping
    # them into [lon_min, lon_min + 360). Without this, a regional grid given
    # in -180..180 would never match targets whose longitude is reported in
    # 0..360, leaving those pixels invalid.
    lon_min = lon.min()
    hp_lon = (hp_lon - lon_min) % 360.0 + lon_min

    # Create 2D meshgrid from 1D lat/lon arrays; flat_idx[i_lat, i_lon] is the
    # row-major flat index of each source point
    lon_2d, lat_2d = np.meshgrid(lon, lat)
    flat_idx = np.arange(lon_2d.size).reshape(lon_2d.shape)

    # A grid is global when its longitudes plus one cell width cover (almost)
    # the full circle; the cell width is the median spacing between columns.
    dlon = np.median(np.diff(np.sort(lon))) if lon.size > 1 else 0.0
    if lon.max() - lon.min() + dlon >= 359:  # Global grid
        print("Handling longitude periodicity for global grid")

        # Close the seam by copying only a few edge columns across it rather
        # than tripling the whole grid: the westernmost columns reappear at
        # +360 deg and the easternmost at -360 deg.
        cols_by_lon = np.argsort(lon)
        n_halo = max(1, min(int(halo), lon.size))
        west, east = cols_by_lon[:n_halo], cols_by_lon[-n_halo:]

        source_lon = np.concatenate([
            lon_2d.ravel(),
            lon_2d[:, west].ravel() + 360.0,
            lon_2d[:, east].ravel() - 360.0,
        ])
        source_lat = np.concatenate([
            lat_2d.ravel(), lat_2d[:, west].ravel(), lat_2d[:, east].ravel()
        ])
        # Original-grid flat index of every (possibly copied) source point
        source_map = np.concatenate([
            flat_idx.ravel(), flat_idx[:, west].ravel(), flat_idx[:, east].ravel()
        ])

        weights = egr.compute_weights_delaunay(
            points=(source_lon, source_lat),
            xi=(hp_lon, hp_lat)
        )

        # Remap source indices from the extended point set back to the original grid
        weights = weights.assign(
            src_idx=weights.src_idx.copy(data=source_map[weights.src_idx.values])
        )

    else:
        # Regional grid - no periodicity handling
        print("Regional grid detected, no periodicity handling")
        weights = egr.compute_weights_delaunay(
            points=(lon_2d.ravel(), lat_2d.ravel()),
            xi=(hp_lon, hp_lat)
        )

    return weights

In [None]:
def remap_latlon_to_healpix(ds, order):
    """
    Remap a dataset from regular lat/lon grid to HEALPix grid.

    Every data variable that has both ``lat`` and ``lon`` dimensions is
    interpolated onto the HEALPix pixels; variables lacking them are copied
    unchanged. Pixels outside the source domain become NaN.

    Parameters:
    -----------
    ds : xr.Dataset
        Input dataset with 1D ``lat``/``lon`` coordinates (may be dask-backed)
    order : int
        HEALPix order (zoom level)

    Returns:
    --------
    ds_remap : xr.Dataset
        Dataset remapped to HEALPix grid with 'cell' dimension (float32)
    """

    # Generate remapping weights
    print(f"Generating weights for HEALPix order {order}")
    weights = gen_weights_latlon(ds.lon.values, ds.lat.values, order)

    # Get number of HEALPix pixels
    npix = weights.sizes["tgt_idx"]
    print(f"Remapping to {npix} HEALPix pixels")

    # Hand the weights over as plain numpy arrays (src_idx, weights, valid)
    weight_kwargs = {name: var.values for name, var in weights.data_vars.items()}

    def _remap_field(field, **kwargs):
        # The weights index the row-major flattened (lat, lon) grid, but
        # apply_ufunc passes each field as a 2D (lat, lon) array. Flatten it so
        # src_idx addresses individual points; indexing the 2D array directly
        # would select whole latitude rows and run out of bounds.
        return egr.apply_weights(np.ravel(field), **kwargs)

    # Apply remapping using xr.apply_ufunc
    ds_remap = xr.apply_ufunc(
        _remap_field,
        ds,
        kwargs=weight_kwargs,
        keep_attrs=True,
        input_core_dims=[["lat", "lon"]],  # Input dimensions to remap
        output_core_dims=[["cell"]],       # Output HEALPix dimension
        on_missing_core_dim="copy",        # Copy other dimensions as-is
        output_dtypes=["f4"],
        vectorize=True,
        dask="parallelized",
        dask_gufunc_kwargs={
            "output_sizes": {"cell": npix},
            # Core dimensions must be a single chunk each; if the input is
            # chunked along lat/lon, let dask merge those chunks instead of
            # raising a ValueError.
            "allow_rechunk": True,
        },
    )

    return ds_remap

## Load and examine your dataset

In [None]:
# Update this path to your actual file
input_file = Path("merg_2020080620_4km-pixel.nc")

# Open lazily, one time step per chunk. lat/lon are kept as single chunks (-1)
# because the remapping treats them as core dimensions, which
# xr.apply_ufunc(dask="parallelized") requires to be unchunked.
ds = xr.open_dataset(input_file, chunks={'time': 1, 'lat': -1, 'lon': -1})
print("Dataset loaded:")
display(ds)

print("\nGrid info:")
print(f"Longitude range: {ds.lon.min().values:.2f} to {ds.lon.max().values:.2f}")
print(f"Latitude range: {ds.lat.min().values:.2f} to {ds.lat.max().values:.2f}")
# abs(): coordinates may be stored in descending order
print(f"Grid resolution: ~{np.abs(np.diff(ds.lon.values)).mean():.3f}° lon, ~{np.abs(np.diff(ds.lat.values)).mean():.3f}° lat")

## Perform the remapping to HEALPix zoom 9

In [None]:
# Target HEALPix resolution (zoom level 9)
order = 9
nside = hp.order2nside(order)
npix = hp.nside2npix(nside)

# Report the grid size and the typical pixel size: the square root of the mean
# pixel solid angle (4*pi/npix sr), converted from radians to degrees
print(
    f"HEALPix order {order}: nside={nside}, npix={npix}\n"
    f"Approximate HEALPix resolution: {np.sqrt(4*np.pi/npix)*180/np.pi:.3f} degrees"
)

# Remap every lat/lon variable of ds onto the HEALPix pixels
print()
print("Starting remapping...")
ds_remap = remap_latlon_to_healpix(ds, order)

print()
print("Remapped dataset:")
display(ds_remap)

## Add metadata and save

In [None]:
# Add HEALPix metadata. netCDF attributes must be strings or numbers (the
# netCDF4 backend rejects Python bools), so the NESTED-ordering flag is stored
# as the integer 1. The source resolution is measured from the input grid
# instead of being hard-coded.
source_res_deg = float(np.abs(np.diff(ds.lon.values)).mean())
ds_remap.attrs.update({
    'healpix_order': order,
    'healpix_nside': nside,
    'healpix_npix': npix,
    'healpix_nest': 1,
    'original_grid': 'regular_lat_lon',
    'original_resolution': f'{source_res_deg:.4f}_degree',
    'remapping_method': 'delaunay_triangulation'
})

# Save the remapped dataset
output_file = f"merg_2020080620_4km-pixel_healpix_z{order}.nc"

print(f"Saving to {output_file}...")
# Compress every numeric data variable that is actually present, rather than a
# hard-coded list (an encoding entry for a missing variable makes to_netcdf fail).
encoding = {
    name: {'zlib': True, 'complevel': 4}
    for name, var in ds_remap.data_vars.items()
    if np.issubdtype(var.dtype, np.number)
}

ds_remap.to_netcdf(output_file, encoding=encoding)
print("File saved successfully!")

## Visualize the results

In [None]:
# Plot original data
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Original grid
ds.Tb.isel(time=0).plot(ax=ax1, cmap='viridis')
ax1.set_title('Original Lat/Lon Grid')
ax1.set_aspect('equal')

# HEALPix grid. egh.healpix_show expects a cartopy GeoAxes (it reads
# ax.projection and draws coastlines), so on this plain Matplotlib axis the
# HEALPix field is sampled at its nearest pixel on a regular lon/lat raster
# covering the source domain, then drawn with imshow.
lon_range = (float(ds.lon.min()), float(ds.lon.max()))
lat_range = (float(ds.lat.min()), float(ds.lat.max()))
raster_lon, raster_lat = np.meshgrid(
    np.linspace(*lon_range, 1440), np.linspace(*lat_range, 720)
)
raster_pix = hp.ang2pix(nside, raster_lon, raster_lat, lonlat=True, nest=True)
tb_healpix = np.asarray(ds_remap.Tb.isel(time=0))
im = ax2.imshow(
    tb_healpix[raster_pix], origin='lower',
    extent=(*lon_range, *lat_range), cmap='viridis'
)
fig.colorbar(im, ax=ax2, label='Tb')
ax2.set_title(f'HEALPix Grid (order {order})')
ax2.set_aspect('equal')

plt.tight_layout()
plt.show()

In [None]:
# Compare precipitation data
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Original precipitation
ds.precipitationCal.isel(time=0).plot(ax=ax1, cmap='Blues')
ax1.set_title('Original Precipitation (mm/hr)')
ax1.set_aspect('equal')

# HEALPix precipitation. egh.healpix_show expects a cartopy GeoAxes (it reads
# ax.projection and draws coastlines), so on this plain Matplotlib axis the
# HEALPix field is sampled at its nearest pixel on a regular lon/lat raster
# covering the source domain, then drawn with imshow.
lon_range = (float(ds.lon.min()), float(ds.lon.max()))
lat_range = (float(ds.lat.min()), float(ds.lat.max()))
raster_lon, raster_lat = np.meshgrid(
    np.linspace(*lon_range, 1440), np.linspace(*lat_range, 720)
)
raster_pix = hp.ang2pix(nside, raster_lon, raster_lat, lonlat=True, nest=True)
precip_healpix = np.asarray(ds_remap.precipitationCal.isel(time=0))
im = ax2.imshow(
    precip_healpix[raster_pix], origin='lower',
    extent=(*lon_range, *lat_range), cmap='Blues'
)
fig.colorbar(im, ax=ax2, label='precipitationCal')
ax2.set_title(f'HEALPix Precipitation (order {order})')
ax2.set_aspect('equal')

plt.tight_layout()
plt.show()

## Compare statistics before and after remapping

In [None]:
# Compare statistics at the first time step to check how well the remapping
# preserved the data. The Delaunay remap is linear interpolation, not a
# conservative scheme, so small differences are expected. HEALPix pixels all
# have equal area, whereas lat/lon cells shrink with cos(latitude), so the
# original statistics are cos(lat)-weighted to make both sides comparable.
area_weights = np.cos(np.deg2rad(ds.lat))

print("Data conservation check:")
for var_name, label, digits in [
    ("Tb", "Temperature (Tb)", 2),
    ("precipitationCal", "Precipitation", 4),
]:
    original = ds[var_name].isel(time=0).weighted(area_weights)
    remapped = ds_remap[var_name].isel(time=0)
    print(f"\n{label}:")
    print(f"Original - Mean: {float(original.mean()):.{digits}f}, Std: {float(original.std()):.{digits}f}")
    print(f"Remapped - Mean: {float(remapped.mean()):.{digits}f}, Std: {float(remapped.std()):.{digits}f}")