
rasterio → read/write GeoTIFF images

numpy → array operations

Path (from pathlib) → build and manipulate filesystem paths

tqdm → progress bars


In [None]:
# Libraries for handling raster data, arrays, and filesystem operations
import rasterio
import numpy as np
from pathlib import Path
import os
from tqdm import tqdm


Keeps all paths and patch parameters in one place.

`mkdir(parents=True, exist_ok=True)` creates the output directories, including any missing parent folders, and does nothing if they already exist.

In [None]:
# Root of the project tree. This is machine-specific: point it at your own
# machine/server location before running.
BASE_PATH = Path("/kraken1nobackup/2025-CS489-509/drodriguez/ML")

# Inputs: the source tiles and the predicted masks for those tiles.
TILE_DIR = BASE_PATH / "data" / "updated_naip_tiles_2018"
MASK_DIR = BASE_PATH / "results" / "updated_chesapeake_predictions_2018" / "masks"

# Outputs: image patches and mask patches live side by side under patches/.
PATCH_DIR = BASE_PATH / "patches"
PATCH_IMG_DIR = PATCH_DIR / "images"
PATCH_MASK_DIR = PATCH_DIR / "masks"

# Make sure both output folders (and any missing parents) exist; no-op if they already do.
PATCH_IMG_DIR.mkdir(parents=True, exist_ok=True)
PATCH_MASK_DIR.mkdir(parents=True, exist_ok=True)

# Patch geometry in pixels. With STRIDE equal to PATCH_SIZE the patches tile the
# image without overlap; a smaller STRIDE produces overlapping patches.
PATCH_SIZE = 512
STRIDE = PATCH_SIZE


This function handles all patch extraction and saving.

Input: full tile image + mask.

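Output: each image patch and its matching mask patch are saved to disk as GeoTIFFs, and the function returns how many patches were written. Pixels along the right and bottom edges that do not fill a complete patch are skipped.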

In [None]:
def extract_patches(image, mask, base_name, *, transform=None, crs=None):
    """
    Cut an image and its mask into aligned square patches and save them as GeoTIFFs.

    Patches are taken on a regular grid: PATCH_SIZE x PATCH_SIZE windows whose
    top-left corners step by STRIDE pixels. Only windows that fit entirely
    inside the image are used. A right/bottom margin narrower than PATCH_SIZE
    is therefore discarded, and an image smaller than PATCH_SIZE yields no patches.

    Parameters:
    - image: np.array of shape (C, H, W); all C bands are written to each image patch
    - mask: np.array of shape (H, W), pixel-aligned with ``image`` (it must be at
      least as large as the image in both dimensions)
    - base_name: string to use for patch filenames
    - transform: optional affine transform of the full tile (e.g. ``src.transform``).
      When given, each patch is written with the transform of its own window,
      so the patch stays georeferenced.
    - crs: optional coordinate reference system of the full tile (e.g. ``src.crs``)

    Returns:
    - count: total number of patches extracted

    Raises:
    - ValueError: if ``image`` is not 3-D, or ``mask`` is not 2-D or is smaller
      than the image's height/width
    """
    from rasterio.windows import Window
    from rasterio.windows import transform as window_transform

    if image.ndim != 3:
        raise ValueError(f"image must have shape (C, H, W), got {image.shape}")
    bands, h, w = image.shape
    # Check up front. Otherwise a short mask slice only fails partway through,
    # after some patch files have already been written.
    if mask.ndim != 2 or mask.shape[0] < h or mask.shape[1] < w:
        raise ValueError(
            f"mask shape {mask.shape} does not cover image height/width {(h, w)}"
        )

    count = 0
    for i in range(0, h - PATCH_SIZE + 1, STRIDE):
        for j in range(0, w - PATCH_SIZE + 1, STRIDE):
            img_patch = image[:, i:i + PATCH_SIZE, j:j + PATCH_SIZE]
            mask_patch = mask[i:i + PATCH_SIZE, j:j + PATCH_SIZE]

            # Shift the tile's transform to this patch's top-left pixel (column j, row i).
            patch_transform = None
            if transform is not None:
                patch_transform = window_transform(
                    Window(j, i, PATCH_SIZE, PATCH_SIZE), transform
                )

            # Define output paths; `count` doubles as the per-tile patch index.
            img_out_path = PATCH_IMG_DIR / f"{base_name}_patch_{count}.tif"
            mask_out_path = PATCH_MASK_DIR / f"{base_name}_patch_{count}_mask.tif"

            # Save image patch (one GeoTIFF band per input band)
            with rasterio.open(
                img_out_path, "w",
                driver="GTiff",
                height=PATCH_SIZE,
                width=PATCH_SIZE,
                count=bands,
                dtype=img_patch.dtype,
                crs=crs,
                transform=patch_transform,
            ) as dst:
                dst.write(img_patch)

            # Save mask patch as a single-band GeoTIFF
            with rasterio.open(
                mask_out_path, "w",
                driver="GTiff",
                height=PATCH_SIZE,
                width=PATCH_SIZE,
                count=1,
                dtype=mask_patch.dtype,
                crs=crs,
                transform=patch_transform,
            ) as dst:
                dst.write(mask_patch, 1)

            count += 1
    return count


Loops over all tiles.

Checks whether a mask with the same file name exists in MASK_DIR; tiles without one are skipped.

Reads image and mask.

Calls the patch extraction function.

Prints progress and number of patches created.

In [None]:
# Collect every input tile; sorting makes the processing order reproducible.
tile_paths = sorted(TILE_DIR.glob("*.tif"))

total_patches = 0
for tile_path in tqdm(tile_paths, desc="Extracting patches"):
    base_name = tile_path.stem

    # A tile's mask is expected under MASK_DIR with the same file name.
    mask_path = MASK_DIR / f"{base_name}.tif"
    if not mask_path.exists():
        # tqdm.write prints above the progress bar instead of corrupting it (print does).
        tqdm.write(f"⚠️ Mask not found for {base_name}, skipping...")
        continue

    # Read bands 1-3 of the tile (assumed to be R, G, B per the tile's band order)
    with rasterio.open(tile_path) as img_src:
        img = img_src.read([1, 2, 3])

    # Read the single-band mask
    with rasterio.open(mask_path) as mask_src:
        mask = mask_src.read(1)

    # Extract and save patches
    patch_count = extract_patches(img, mask, base_name)
    total_patches += patch_count
    tqdm.write(f"✅ {base_name}: {patch_count} patches")

tqdm.write(f"Done: {total_patches} patches written to {PATCH_DIR}")
