Imports libraries for filesystem operations, deep learning inference, geospatial reading/writing, and progress tracking.

In [None]:
import os
from pathlib import Path
import numpy as np
import torch
import rasterio
import segmentation_models_pytorch as smp
from tqdm import tqdm


Defines paths for model, input tiles, and output predictions.

Sets GPU/CPU device.

Ensures the output directory exists.

In [None]:
# --- Paths (edit these three placeholders before running) ---
# Trained PyTorch weights file (.pt)
MODEL_PATH = "PATH_TO_YOUR_MODEL/cafo_multi_patch.pt"
# Directory holding the NAIP tile images (.tif) to run inference on
INPUT_DIR = "PATH_TO_YOUR_INPUT_TILES"
# Directory that will receive the predicted mask images
OUTPUT_DIR = "PATH_TO_SAVE_PREDICTIONS"

# --- Device ---
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# --- Output folder ---
# Create the prediction directory (and any missing parents); a directory that
# already exists is left untouched.
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

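Builds a U-Net segmentation model (ResNet-18 encoder, 3 input channels, 1 output class), loads your trained weights (a saved state_dict) into it, and moves it to the GPU/CPU.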

Sets model to evaluation mode.

In [None]:
# Build the U-Net architecture the checkpoint will be loaded into: ResNet-18
# encoder, 3 input bands, 1 output class. load_state_dict() below is strict, so
# these arguments have to match how the checkpoint was trained.
# encoder_weights=None: no pretrained encoder weights are requested, because
# every weight comes from the checkpoint.
model = smp.Unet(encoder_name="resnet18", encoder_weights=None, in_channels=3, classes=1)

# MODEL_PATH is used as a state_dict (tensors only). weights_only=True restricts
# unpickling to tensors and plain containers, so a tampered checkpoint file cannot
# execute arbitrary code while it is being loaded (needs PyTorch >= 1.13).
state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()  # inference mode: freezes batch-norm statistics, disables dropout

print("✅ Model loaded and ready for inference.")


Loops over all input tiles.

Reads the first three bands of each tile with rasterio and scales the pixel values to [0, 1] by dividing by 255.

Converts to tensor for model inference.

Runs forward pass and binarizes output.

Saves prediction to a GeoTIFF, keeping CRS and transform intact.

In [None]:
# Collect every GeoTIFF tile in the input folder, in a stable (sorted) order.
tile_paths = sorted(Path(INPUT_DIR).glob("*.tif"))
if not tile_paths:
    # An empty folder would otherwise pass silently and look like a successful run.
    print(f"⚠️ No .tif tiles found in: {INPUT_DIR}")

# The ResNet encoder downsamples by a factor of 32, so the U-Net needs input
# height and width divisible by 32. Tiles are padded up to the next multiple
# and the prediction is cropped back to the original size.
ENCODER_STRIDE = 32
# Probability cut-off applied after the sigmoid.
PROB_THRESHOLD = 0.5

for tile_path in tqdm(tile_paths, desc="Running inference"):
    with rasterio.open(tile_path) as src:
        # Bands 1-3 as float32 scaled by 1/255.
        # NOTE(review): assumes 8-bit imagery — confirm for tiles with other dtypes.
        image = src.read([1, 2, 3]).astype(np.float32) / 255.0
        # Georeferencing is copied onto the output so the mask overlays the tile.
        transform = src.transform
        crs = src.crs

    # image is (bands, rows, cols) as returned by rasterio's multi-band read.
    height, width = image.shape[1:]

    # Add the batch dimension -> (1, 3, H, W) and move to the inference device.
    image_tensor = torch.from_numpy(image).unsqueeze(0).to(DEVICE)

    # Pad only on the bottom/right so pixel (0, 0) — and therefore the
    # geotransform origin — stays where it is.
    pad_h = -height % ENCODER_STRIDE
    pad_w = -width % ENCODER_STRIDE
    if pad_h or pad_w:
        image_tensor = torch.nn.functional.pad(
            image_tensor, (0, pad_w, 0, pad_h), mode="replicate"
        )

    # Forward pass. The model was built without an output activation, so it
    # returns raw logits; sigmoid converts them to probabilities before the
    # threshold is applied.
    with torch.no_grad():
        logits = model(image_tensor)
        # Explicit [0, 0] drops the batch and class axes only (a bare squeeze()
        # would also drop a size-1 spatial axis), then the padding is cropped off.
        probabilities = torch.sigmoid(logits)[0, 0, :height, :width].cpu().numpy()

    # Binarize prediction (0/1)
    prediction_binary = (probabilities > PROB_THRESHOLD).astype(np.uint8)

    # Save as a single-band uint8 GeoTIFF with the source tile's CRS and transform.
    out_path = Path(OUTPUT_DIR) / f"{tile_path.stem}_prediction.tif"
    with rasterio.open(
        out_path,
        "w",
        driver="GTiff",
        height=prediction_binary.shape[0],
        width=prediction_binary.shape[1],
        count=1,
        dtype=rasterio.uint8,
        crs=crs,
        transform=transform,
    ) as dst:
        dst.write(prediction_binary, 1)


Prints a final message indicating completion.

In [None]:
# Report where the prediction GeoTIFFs were written.
completion_message = f"✅ Inference complete. Predictions saved to: {OUTPUT_DIR}"
print(completion_message)
