In [7]:
import geopandas as gpd
import numpy as np
from shapely.geometry import box
import os

def create_grid(geojson_path, output_dir, grid_size=2560, crs='EPSG:32760',
                name_prefix='auckland_grid'):
    """
    Creates a square grid that covers a GeoJSON polygon layer and saves it twice.

    Cells of ``grid_size`` x ``grid_size`` (in units of ``crs``, metres for UTM)
    are laid out starting at the lower-left corner of the layer's bounding box.
    Only cells that intersect at least one input feature are kept. Each kept
    cell gets a sequential ``patch_id`` (0..n-1) in grid order: x outer, y inner.

    Two files are written to ``output_dir``:
      * ``{name_prefix}_{grid_size}m.geojson``       - cells in ``crs``
      * ``{name_prefix}_{grid_size}m_wgs84.geojson`` - same cells in EPSG:4326

    Args:
        geojson_path: Path to input GeoJSON file (e.g., auckland_mainland.geojson).
        output_dir: Directory to save the grid files (created if missing).
        grid_size: Size of grid cells in CRS units (metres for UTM).
                   (e.g., 2560m for 256x256 pixel patches at 10m scale)
        crs: Projected coordinate reference system for the grid
             (default: UTM 60S for Auckland).
        name_prefix: File-name prefix for both output files.

    Returns:
        GeoDataFrame in ``crs`` with only ``geometry`` and ``patch_id`` columns.

    Raises:
        ValueError: If the input layer contains no features.
    """
    os.makedirs(output_dir, exist_ok=True)

    print(f"Reading GeoJSON from {geojson_path}...")
    gdf = gpd.read_file(geojson_path)
    if gdf.empty:
        # total_bounds would be all-NaN and np.arange would fail obscurely.
        raise ValueError(f"No features found in {geojson_path}")

    # Reproject to a projected CRS so grid_size is a real distance.
    print(f"Reprojecting to {crs}...")
    gdf_utm = gdf.to_crs(crs)

    minx, miny, maxx, maxy = gdf_utm.total_bounds

    print(f"Creating {grid_size}m x {grid_size}m grid...")

    # Cell origins; the last origin is < max, so cells fully cover the bounds.
    x_coords = np.arange(minx, maxx, grid_size)
    y_coords = np.arange(miny, maxy, grid_size)

    total_patches = len(x_coords) * len(y_coords)
    print(f"Grid dimensions: {len(x_coords)} cols x {len(y_coords)} rows = {total_patches} total patches (before filtering)")

    grid_gdf = gpd.GeoDataFrame(
        geometry=[box(x, y, x + grid_size, y + grid_size)
                  for x in x_coords for y in y_coords],
        crs=crs,
    )

    print("Filtering patches that intersect with original shape...")

    # A cell overlapping several input features appears once per feature in the
    # join result. Keep each hit cell exactly once, in original grid order, and
    # do not carry over the joined attribute columns (index_right, etc.).
    hit_index = gpd.sjoin(grid_gdf, gdf_utm, how="inner").index.unique()
    grid_filtered = grid_gdf.loc[grid_gdf.index.isin(hit_index)].reset_index(drop=True)
    grid_filtered['patch_id'] = range(len(grid_filtered))

    print(f"  -> Total patches after filtering: {len(grid_filtered)}")

    # --- Save to file ---
    output_path_utm = os.path.join(output_dir, f'{name_prefix}_{grid_size}m.geojson')
    print(f"Saving UTM grid to {output_path_utm}")
    grid_filtered.to_file(output_path_utm, driver='GeoJSON')

    # Lon/lat copy, used downstream to build an Earth Engine bounding box.
    output_path_wgs84 = os.path.join(output_dir, f'{name_prefix}_{grid_size}m_wgs84.geojson')
    print(f"Saving WGS84 grid to {output_path_wgs84}")
    grid_filtered.to_crs('EPSG:4326').to_file(output_path_wgs84, driver='GeoJSON')

    print("\n✓ Grid creation complete!")
    return grid_filtered


# --- Main execution ---
if __name__ == "__main__":

    # Input area-of-interest polygon (Auckland mainland) and the folder that
    # receives both grid GeoJSON files.
    geojson_file = "aklshp/akl_mainland_only.geojson"
    grid_output_dir = "aklshp"

    # One grid cell == one training patch: 256 x 256 pixels at 10 m/pixel.
    PATCH_PIXELS = 256
    PIXEL_SIZE_METERS = 10
    PATCH_SIZE_METERS = PATCH_PIXELS * PIXEL_SIZE_METERS  # 2560 m

    create_grid(
        geojson_path=geojson_file,
        output_dir=grid_output_dir,
        grid_size=PATCH_SIZE_METERS,
        crs='EPSG:32760',  # WGS 84 / UTM zone 60S, which covers Auckland
    )

Reading GeoJSON from aklshp/akl_mainland_only.geojson...
Reprojecting to EPSG:32760...
Creating 2560m x 2560m grid...
Grid dimensions: 40 cols x 50 rows = 2000 total patches (before filtering)
Filtering patches that intersect with original shape...
  -> Total patches after filtering: 870
Saving UTM grid to aklshp/auckland_grid_2560m.geojson
Saving WGS84 grid to aklshp/auckland_grid_2560m_wgs84.geojson

✓ Grid creation complete!


In [None]:
import ee
import geopandas as gpd
import numpy as np
import h5py
from pathlib import Path
import json
import datetime

# Google Cloud project passed to ee.Initialize for Earth Engine requests.
EE_PROJECT = "geog-761-experiment-1"

try:
    ee.Initialize(project=EE_PROJECT)
except Exception as e:
    print(f"Error initializing Earth Engine: {e}")
    print("Please run 'earthengine authenticate' in your terminal.")
    # Raise rather than call exit(): inside a Jupyter kernel exit() shuts the
    # kernel down, whereas an exception only stops this cell (and Run All).
    raise RuntimeError("Earth Engine initialization failed; see message above.") from e

print("✓ Earth Engine initialized successfully.")


def get_most_recent_sentinel2_auckland_ee(geometry, days_back=100, cloud_cover_max=30.0):
    """
    Builds a low-cloud Sentinel-2 surface-reflectance mosaic for an area.

    Every COPERNICUS/S2_SR_HARMONIZED image that intersects ``geometry``, was
    acquired within the last ``days_back`` days and has a per-image
    CLOUDY_PIXEL_PERCENTAGE below ``cloud_cover_max`` goes into the mosaic.
    Where images overlap, the least cloudy one is the visible pixel. The
    result can mix several acquisition dates; it is not a single scene.

    Args:
        geometry: ee.Geometry used to filter images and to clip the result.
        days_back: Length of the look-back window in days, ending now.
        cloud_cover_max: Exclusive upper bound on CLOUDY_PIXEL_PERCENTAGE.

    Returns:
        ee.Image with the 12 spectral bands below, clipped to ``geometry``,
        or None when no image matches.
    """
    now_py = datetime.datetime.now()
    start_date_py = now_py - datetime.timedelta(days=days_back)

    s2_collection = (
        ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
        .filterBounds(geometry)
        .filterDate(ee.Date(start_date_py), ee.Date(now_py))
        .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', cloud_cover_max))
        # mosaic() draws images in collection order with the LAST image on top.
        # Sorting cloudiest-first therefore leaves the clearest image visible
        # wherever images overlap.
        .sort('CLOUDY_PIXEL_PERCENTAGE', False)
    )

    if s2_collection.size().getInfo() == 0:
        print("Warning: No Sentinel-2 images found with specified criteria.")
        return None

    s2_image = ee.Image(s2_collection.mosaic())

    s2_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6',
                'B7', 'B8', 'B8A', 'B9', 'B11', 'B12']

    return s2_image.select(s2_bands).clip(geometry)


def get_most_recent_sentinel1_auckland_ee(geometry, days_back=30):
    """
    Fetches the newest dual-pol (VV + VH), IW-mode Sentinel-1 GRD scene.

    Args:
        geometry: ee.Geometry used both to filter scenes and to clip the result.
        days_back: Length of the look-back window in days, ending now.

    Returns:
        ``(image, metadata)``:
          * ``image`` is the newest matching scene, reduced to the VV and VH
            bands and clipped to ``geometry``.
          * ``metadata`` is the client-side ``getInfo()`` dict of the unclipped
            scene.
        Returns ``(None, None)`` if nothing matches.

    NOTE(review): only the single newest scene is used. If it does not cover
    all of ``geometry``, the uncovered pixels will be masked; confirm coverage
    for the area of interest.
    """
    window_end = datetime.datetime.now()
    window_start = window_end - datetime.timedelta(days=days_back)

    polarisation_key = 'transmitterReceiverPolarisation'
    candidates = ee.ImageCollection('COPERNICUS/S1_GRD')
    candidates = candidates.filterBounds(geometry)
    candidates = candidates.filterDate(ee.Date(window_start), ee.Date(window_end))
    candidates = candidates.filter(ee.Filter.listContains(polarisation_key, 'VV'))
    candidates = candidates.filter(ee.Filter.listContains(polarisation_key, 'VH'))
    candidates = candidates.filter(ee.Filter.eq('instrumentMode', 'IW'))
    # Newest acquisition first, so .first() picks the most recent scene.
    candidates = candidates.sort('system:time_start', False)

    if candidates.size().getInfo() == 0:
        print("Warning: No Sentinel-1 images found with specified criteria.")
        return None, None

    newest_scene = ee.Image(candidates.first())
    scene_info = newest_scene.getInfo()

    return newest_scene.select(['VV', 'VH']).clip(geometry), scene_info


def extract_patch_data(image, geometry, crs, scale=10):
    """
    Samples every band of ``image`` over ``geometry`` on one common pixel grid.

    The image is first reprojected to ``crs`` at ``scale`` (metres per pixel
    for a projected CRS). This forces all bands, whatever their native
    resolution, to come back on the same grid. Masked pixels are filled with
    0 by ``sampleRectangle``.

    Args:
        image: ee.Image to sample.
        geometry: ee.Geometry rectangle to extract.
        crs: CRS string for the output pixel grid.
        scale: Output pixel size.

    Returns:
        numpy array with bands stacked on the last axis, in the order of
        ``image.bandNames()``. Returns None if anything fails: an empty band,
        a band smaller than 2x2, or any Earth Engine error. The error is
        printed, not raised.
    """
    try:
        sample = image.reproject(crs=crs, scale=scale).sampleRectangle(
            region=geometry, defaultValue=0
        )

        def _fetch_band(name):
            # One client round-trip per band.
            pixels = np.array(sample.get(name).getInfo())
            if pixels.size == 0:
                raise Exception(f"Band {name} returned no data.")
            if pixels.shape[0] < 2 or pixels.shape[1] < 2:
                raise Exception(f"Band {name} extracted only {pixels.shape} pixels.")
            return pixels

        return np.stack(
            [_fetch_band(name) for name in image.bandNames().getInfo()],
            axis=-1,
        )

    except Exception as e:
        print(f"  -> Error in extract_patch_data: {e}")
        return None


def _write_patch_h5(h5_path, patch_id, crs, scale, groups):
    """
    Writes one patch to ``h5_path`` via a temp file and a rename.

    Data go to ``<name>.tmp`` first, which is then renamed over ``h5_path``.
    A failed or interrupted write therefore never leaves a truncated
    ``patch_XXXX.h5`` behind.

    Args:
        h5_path: pathlib.Path of the final HDF5 file.
        patch_id, crs, scale: Stored as file-level attributes.
        groups: Mapping ``group_name -> (array, band_names)``. Slice
            ``array[:, :, i]`` is stored as dataset ``band_names[i]``.
            Entries whose array is None are skipped.
    """
    tmp_path = h5_path.with_name(h5_path.name + '.tmp')
    try:
        with h5py.File(tmp_path, 'w') as hf:
            for group_name, (data, band_names) in groups.items():
                if data is None:
                    continue
                group = hf.create_group(group_name)
                for i, band_name in enumerate(band_names):
                    group.create_dataset(band_name, data=data[:, :, i])
            hf.attrs['patch_id'] = patch_id
            hf.attrs['crs'] = crs
            hf.attrs['scale_meters'] = scale
        tmp_path.replace(h5_path)
    finally:
        # The temp file only still exists if something failed before the rename.
        if tmp_path.exists():
            tmp_path.unlink()


def extract_sentinel_patches(
    grid_utm_geojson_path,
    grid_wgs84_geojson_path,
    output_dir='auckland_patches',
    scale=10,
    skip_existing=False,
):
    """
    Extracts ALL Sentinel-1 and Sentinel-2 bands for each grid cell into HDF5.

    One file ``patch_{patch_id:04d}.h5`` is written per grid cell. It holds the
    groups ``sentinel2`` and/or ``sentinel1`` (one dataset per band) and the
    attributes ``patch_id``, ``crs`` and ``scale_meters``.

    Args:
        grid_utm_geojson_path: Grid in a projected CRS with a ``patch_id``
            column. Each cell's bounds define its sampling rectangle.
        grid_wgs84_geojson_path: Same grid in lon/lat. Only its total bounds
            are used, to select the Sentinel images.
        output_dir: Directory for the ``.h5`` files (created if missing).
        scale: Output pixel size for every band, in grid CRS units.
        skip_existing: If True, cells whose ``.h5`` file already exists are
            skipped, so an interrupted run can be resumed. Default False:
            always re-extract and overwrite.

    Notes:
        * Network-bound extraction runs BEFORE any file is opened, and files
          are written via temp file + rename. A failed or interrupted patch
          therefore never leaves a truncated ``.h5`` file.
        * A cell for which neither sensor returned data gets no file.
        * NOTE(review): a recorded run produced (257, 257, bands) arrays,
          not 256x256. This is possibly because grid cell edges are not
          aligned to the 10 m pixel grid; confirm before assuming a fixed size.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    print(f"Output directory: {output_path}")

    print("\nReading grids...")
    grid_utm = gpd.read_file(grid_utm_geojson_path)
    grid_wgs84 = gpd.read_file(grid_wgs84_geojson_path)

    grid_crs = grid_utm.crs.to_string()
    n_patches = len(grid_utm)
    print(f"Loaded {n_patches} patches in CRS: {grid_crs}")

    # Whole-grid footprint in lon/lat, used only to choose the Sentinel images.
    west, south, east, north = grid_wgs84.total_bounds
    grid_geometry = ee.Geometry.Rectangle([west, south, east, north])

    print("\nFetching Sentinel-2 image...")
    s2_image = get_most_recent_sentinel2_auckland_ee(grid_geometry)
    if s2_image is None:
        print("ERROR: Could not get Sentinel-2 image. Exiting.")
        return
    s2_bands = s2_image.bandNames().getInfo()
    print(f"   S2 bands: {s2_bands}")

    print("\nFetching Sentinel-1 image...")
    s1_image, _ = get_most_recent_sentinel1_auckland_ee(grid_geometry)
    if s1_image is None:
        print("ERROR: Could not get Sentinel-1 image. Exiting.")
        return
    s1_bands = s1_image.bandNames().getInfo()
    print(f"   S1 bands: {s1_bands}")
    print("\n✓ Sentinel images ready")

    print(f"\nExtracting patches for {n_patches} grid cells...")

    n_written = n_skipped = n_failed = 0
    for position, (_, row) in enumerate(grid_utm.iterrows(), start=1):
        patch_id = int(row['patch_id'])
        print(f"\n--- Processing patch {position}/{n_patches} (ID: {patch_id}) ---")

        h5_path = output_path / f'patch_{patch_id:04d}.h5'
        if skip_existing and h5_path.exists():
            print(f"  -> {h5_path.name} already exists, skipping.")
            n_skipped += 1
            continue

        try:
            # Sampling rectangle from the cell's bounds. The coordinates are
            # planar values in grid_crs, hence proj=grid_crs and geodesic=False.
            cell_minx, cell_miny, cell_maxx, cell_maxy = row.geometry.bounds
            patch_geom = ee.Geometry.Rectangle(
                coords=[cell_minx, cell_miny, cell_maxx, cell_maxy],
                proj=grid_crs,
                geodesic=False,
            )

            print("  Extracting Sentinel-2...")
            s2_data = extract_patch_data(s2_image, patch_geom, grid_crs, scale)
            if s2_data is not None:
                print(f"  -> S2 data shape: {s2_data.shape}")

            print("  Extracting Sentinel-1...")
            s1_data = extract_patch_data(s1_image, patch_geom, grid_crs, scale)
            if s1_data is not None:
                print(f"  -> S1 data shape: {s1_data.shape}")

            if s2_data is None and s1_data is None:
                print(f"  -> No data for patch {patch_id}; no file written.")
                n_failed += 1
                continue

            _write_patch_h5(
                h5_path, patch_id, grid_crs, scale,
                {'sentinel2': (s2_data, s2_bands), 'sentinel1': (s1_data, s1_bands)},
            )
            n_written += 1

        except Exception as e:
            print(f"\nError processing patch {patch_id}: {e}")
            n_failed += 1

    print(f"\n✓ Extraction complete!")
    print(f"  written: {n_written}, skipped: {n_skipped}, failed/empty: {n_failed}")


# --- Main execution ---
if __name__ == "__main__":

    # Must equal the grid_size used when the grid files were generated;
    # it is part of both grid file names.
    PATCH_SIZE_METERS = 2560

    # Projected (UTM) grid supplies per-patch rectangles; the WGS84 copy
    # supplies the overall lon/lat bounding box for scene selection.
    grid_utm_path = f'aklshp/auckland_grid_{PATCH_SIZE_METERS}m.geojson'
    grid_wgs84_path = f'aklshp/auckland_grid_{PATCH_SIZE_METERS}m_wgs84.geojson'
    patch_output_dir = 'auckland_patches'

    extract_sentinel_patches(
        grid_utm_geojson_path=grid_utm_path,
        grid_wgs84_geojson_path=grid_wgs84_path,
        output_dir=patch_output_dir,
        scale=10,  # metres per output pixel
    )

✓ Earth Engine initialized successfully.
Output directory: auckland_patches

Reading grids...
Loaded 870 patches in CRS: EPSG:32760

Fetching Sentinel-2 image...
   S2 bands: ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B9', 'B11', 'B12']

Fetching Sentinel-1 image...
   S1 bands: ['VV', 'VH']

✓ Sentinel images ready

Extracting patches for 870 grid cells...

--- Processing patch 1/870 (ID: 0) ---
  Extracting Sentinel-2...
  -> S2 data shape: (257, 257, 12)
  Extracting Sentinel-1...
  -> S1 data shape: (257, 257, 2)

--- Processing patch 2/870 (ID: 1) ---
  Extracting Sentinel-2...
