

rasterio → read and write GeoTIFF rasters

numpy → array operations

pathlib.Path → build and manage filesystem paths

tqdm → progress bars for long-running loops



In [None]:
import rasterio
import numpy as np
from pathlib import Path
from tqdm import tqdm


Organizes all paths and patch settings in one place.

Creates output directories if they do not exist.

In [None]:
# ---- Patch parameters ----
PATCH_SIZE = 512
# STRIDE == PATCH_SIZE gives adjacent, non-overlapping patches;
# choose a smaller stride to make neighbouring patches overlap.
STRIDE = PATCH_SIZE

# ---- Project root (edit to match your environment) ----
BASE_PATH = Path("/content/cafo_project")

# ---- Inputs: image tiles and their mask rasters ----
TILE_DIR = BASE_PATH / "data" / "updated_naip_tiles_2018"
MASK_DIR = BASE_PATH / "results" / "updated_chesapeake_predictions_2018" / "masks"

# ---- Outputs: separate folders for image patches and mask patches ----
PATCH_DIR = BASE_PATH / "patches"
PATCH_IMG_DIR, PATCH_MASK_DIR = PATCH_DIR / "images", PATCH_DIR / "masks"

# Create the output folders up front (no error if they already exist).
PATCH_IMG_DIR.mkdir(parents=True, exist_ok=True)
PATCH_MASK_DIR.mkdir(parents=True, exist_ok=True)


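Handles all patch extraction and saving.

Skips patches whose mask contains no CAFO pixels (all zeros).

Returns the number of saved and skipped patches. Pixels along the right and bottom edges that do not fill a complete patch are not extracted.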

In [None]:
def extract_patches(image, mask, base_name, *, transform=None, crs=None):
    """
    Extract square patches from a multi-band image and its mask and save them.

    Patches are PATCH_SIZE x PATCH_SIZE pixels, taken every STRIDE pixels
    starting at the top-left corner. Rows/columns along the bottom and right
    edges that cannot fill a complete patch are not extracted. Patches whose
    mask is entirely background (all zeros) are skipped and not written.

    Parameters:
    - image: array of shape (bands, height, width).
    - mask: array of shape (height, width), aligned pixel-for-pixel with image.
    - base_name: prefix for the output filenames.
    - transform: optional affine transform of the source image. When given,
      each patch is written with the transform of its own window, so the
      saved patch stays georeferenced.
    - crs: optional coordinate reference system written to every patch.

    Returns:
    - count of patches saved
    - count of patches skipped

    Raises:
    - ValueError: if image is not 3-D, mask is not 2-D, or their spatial
      sizes differ.
    """
    # Local import keeps this cell self-contained; used only for georeferencing.
    from rasterio.windows import Window
    from rasterio.windows import transform as window_transform

    if image.ndim != 3:
        raise ValueError(f"image must have shape (C, H, W), got {image.shape}")
    if mask.ndim != 2:
        raise ValueError(f"mask must have shape (H, W), got {mask.shape}")
    n_bands, h, w = image.shape
    # A mismatched mask would otherwise be silently cropped (misaligned labels)
    # or fail later with an opaque write error.
    if mask.shape != (h, w):
        raise ValueError(
            f"mask shape {mask.shape} does not match image size {(h, w)}"
        )

    # Profile fields shared by every patch written from this image.
    base_profile = {"driver": "GTiff", "height": PATCH_SIZE, "width": PATCH_SIZE}
    if crs is not None:
        base_profile["crs"] = crs

    count = 0
    skipped = 0

    for i in range(0, h - PATCH_SIZE + 1, STRIDE):
        for j in range(0, w - PATCH_SIZE + 1, STRIDE):
            mask_patch = mask[i:i + PATCH_SIZE, j:j + PATCH_SIZE]

            if np.all(mask_patch == 0):  # Skip all-background patches
                skipped += 1
                continue

            img_patch = image[:, i:i + PATCH_SIZE, j:j + PATCH_SIZE]

            profile = dict(base_profile)
            if transform is not None:
                # Window(col_off, row_off, width, height): j is the column
                # offset and i the row offset of this patch.
                window = Window(j, i, PATCH_SIZE, PATCH_SIZE)
                profile["transform"] = window_transform(window, transform)

            # Paths to save patches
            img_out_path = PATCH_IMG_DIR / f"{base_name}_patch_{count}.tif"
            mask_out_path = PATCH_MASK_DIR / f"{base_name}_patch_{count}_mask.tif"

            # Band count comes from the data rather than a fixed 3, so images
            # with a different number of bands are written correctly.
            with rasterio.open(
                img_out_path, "w", count=n_bands, dtype=img_patch.dtype, **profile
            ) as dst:
                dst.write(img_patch)

            with rasterio.open(
                mask_out_path, "w", count=1, dtype=mask_patch.dtype, **profile
            ) as dst:
                dst.write(mask_patch, 1)

            count += 1

    return count, skipped


Loops over all tiles in the tile directory.

Matches each tile to the mask with the same filename in the mask directory; tiles without a mask are skipped.

Calls extract_patches() on each tile/mask pair and reports progress.

In [None]:
# Collect input tiles in sorted order so runs are reproducible.
tile_paths = sorted(TILE_DIR.glob("*.tif"))
if not tile_paths:
    print(f"⚠️ No .tif tiles found in {TILE_DIR}")

total_kept = 0
total_skipped = 0

# Loop through tiles and extract patches.
# tqdm.write (instead of print) prints above the progress bar without breaking it.
for tile_path in tqdm(tile_paths, desc="Extracting patches"):
    base_name = tile_path.stem
    mask_path = MASK_DIR / f"{base_name}.tif"

    if not mask_path.exists():
        tqdm.write(f"⚠️ Mask not found for {base_name}, skipping...")
        continue

    # Read bands 1-3 of the tile
    with rasterio.open(tile_path) as img_src:
        img = img_src.read([1, 2, 3])

    # Read the first band of the mask
    with rasterio.open(mask_path) as mask_src:
        mask = mask_src.read(1)

    # Patches are sliced at the same pixel offsets from both arrays, so a
    # size mismatch means the labels would not line up with the imagery.
    if mask.shape != img.shape[1:]:
        tqdm.write(
            f"⚠️ Mask size {mask.shape} does not match tile size "
            f"{img.shape[1:]} for {base_name}, skipping..."
        )
        continue

    # Extract patches
    kept, skipped = extract_patches(img, mask, base_name)
    total_kept += kept
    total_skipped += skipped
    tqdm.write(f"✅ {base_name}: {kept} patches saved, {skipped} skipped (all background)")

print(
    f"Done: {total_kept} patches saved, {total_skipped} skipped "
    f"across {len(tile_paths)} tiles"
)
