In [None]:
import torch
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from PIL import Image
import os
from tqdm import tqdm
import numpy as np
import rasterio
import warnings

# Rasterio raises NotGeoreferencedWarning for datasets that carry no
# georeferencing; ignore just that category so it does not flood the output.
warnings.simplefilter("ignore", category=rasterio.errors.NotGeoreferencedWarning)

# --- Configuration ---
SOURCE_IMAGE_DIR = "/content/dataset-medium/images"
OUTPUT_MASK_DIR = "/content/full_size_road_masks"
PATCH_SIZE = 512  # Side length (pixels) of the square windows fed to the model
STRIDE = 256      # Step between consecutive windows. < PATCH_SIZE creates overlap.

# Fail fast on a window layout that cannot tile an image. A non-positive stride
# makes range() raise or yield no windows at all. A stride larger than the
# patch leaves strips of pixels that no window ever predicts, and those would
# silently end up with probability 0.
if PATCH_SIZE <= 0 or not 0 < STRIDE <= PATCH_SIZE:
    raise ValueError(
        f"Invalid tiling: need PATCH_SIZE > 0 and 0 < STRIDE <= PATCH_SIZE "
        f"(got PATCH_SIZE={PATCH_SIZE}, STRIDE={STRIDE})"
    )

# Ensure the destination folder for the stitched masks exists; an
# already-existing directory is accepted so the script can be re-run.
os.makedirs(OUTPUT_MASK_DIR, mode=0o777, exist_ok=True)

# --- Load the Model ---
# The image processor and the segmentation model are both loaded from this
# single Hugging Face repository id.
MODEL_CHECKPOINT = "gmbernardi/segformer-b1-spacenet-roads"

print("Loading road detection model from Hugging Face...")
processor = SegformerImageProcessor.from_pretrained(MODEL_CHECKPOINT)
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_CHECKPOINT)

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
print(f"Using device: {device}")

# --- Find all .tif files ---
# Match the extension case-insensitively and accept both ".tif" and ".tiff"
# so such files are not silently skipped. Sort the names so the processing
# order is reproducible (os.listdir returns entries in arbitrary order).
source_files = sorted(
    f for f in os.listdir(SOURCE_IMAGE_DIR)
    if f.lower().endswith(('.tif', '.tiff'))
)
print(f"Found {len(source_files)} source .tif images to process.")

# --- Main Processing Loop ---

def _window_starts(length, patch_size, stride):
    """Return the start offsets of the sliding windows along one image axis.

    Windows are placed every ``stride`` pixels. One final window is anchored
    so that it ends exactly at ``length``. Each offset appears once, so no
    window is run through the model twice. When ``length`` is not larger than
    ``patch_size``, a single window at offset 0 is returned and the caller
    pads the patch. Full coverage assumes ``0 < stride <= patch_size``.

    >>> _window_starts(1000, 512, 256)
    [0, 256, 488]
    >>> _window_starts(300, 512, 256)
    [0]
    """
    last_start = max(length - patch_size, 0)
    starts = list(range(0, last_start, stride))
    starts.append(last_start)
    return starts


def _predict_patch_road_probs(patch):
    """Return the per-pixel softmax probability of class 1 ('road') for one patch.

    ``patch`` is a (PATCH_SIZE, PATCH_SIZE, 3) uint8 array. The result is a
    (PATCH_SIZE, PATCH_SIZE) NumPy array.
    """
    inputs = processor(images=Image.fromarray(patch), return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Resize the logits to the patch size before the softmax so they line up
    # pixel-for-pixel with the patch (the model output may be lower-resolution).
    logits = torch.nn.functional.interpolate(
        outputs.logits.cpu(), size=(PATCH_SIZE, PATCH_SIZE), mode='bilinear', align_corners=False
    )
    return torch.nn.functional.softmax(logits, dim=1)[0, 1].numpy()


def _stitch_road_probabilities(image):
    """Predict road probabilities for a full (H, W, C) image by sliding window.

    Each window's probabilities are summed into a full-size map. The map is
    then divided by how many windows covered each pixel, so overlapping
    regions are averaged. A window larger than a small image is zero-padded,
    and only its valid part is stitched back.
    """
    height, width, channels = image.shape
    prob_sum = np.zeros((height, width), dtype=np.float32)
    # Use a float32 counter rather than uint8. A small STRIDE can place a pixel
    # in more than 255 windows, which would overflow a uint8 counter.
    visits = np.zeros((height, width), dtype=np.float32)

    for y0 in _window_starts(height, PATCH_SIZE, STRIDE):
        for x0 in _window_starts(width, PATCH_SIZE, STRIDE):
            y1 = min(y0 + PATCH_SIZE, height)
            x1 = min(x0 + PATCH_SIZE, width)
            patch = image[y0:y1, x0:x1]
            h, w = patch.shape[:2]
            if (h, w) != (PATCH_SIZE, PATCH_SIZE):
                # Only reached when the image is smaller than PATCH_SIZE on an axis.
                padded = np.zeros((PATCH_SIZE, PATCH_SIZE, channels), dtype=image.dtype)
                padded[:h, :w] = patch
                patch = padded
            probs = _predict_patch_road_probs(patch)
            # Stitch back only the part of the prediction that maps onto real pixels.
            prob_sum[y0:y1, x0:x1] += probs[:h, :w]
            visits[y0:y1, x0:x1] += 1

    # Avoid division by zero for pixels that no window reached. This can only
    # happen when STRIDE > PATCH_SIZE; such pixels end up with probability 0.
    return prob_sum / np.maximum(visits, 1)


for filename in tqdm(source_files, desc="Processing Full Images"):
    source_path = os.path.join(SOURCE_IMAGE_DIR, filename)

    # The mask is written as a PNG with the same base name as the source image.
    output_filename = os.path.splitext(filename)[0] + '.png'
    output_path = os.path.join(OUTPUT_MASK_DIR, output_filename)

    if os.path.exists(output_path):
        print(f"Skipping {filename}, output already exists.")
        continue

    try:
        # Read bands 1-3 and close the file before the (slow) inference.
        # rasterio returns (C, H, W); transpose to (H, W, C) for NumPy/Pillow.
        with rasterio.open(source_path) as src:
            full_image_np = src.read([1, 2, 3]).transpose(1, 2, 0)

        # Pillow's fromarray cannot build an RGB image from wider dtypes, so
        # report that clearly instead of via an opaque Pillow error.
        if full_image_np.dtype != np.uint8:
            raise TypeError(
                f"expected 8-bit RGB bands, got dtype {full_image_np.dtype}"
            )

        final_probabilities = _stitch_road_probabilities(full_image_np)

        # Map probabilities in [0, 1] to grayscale 0..255 (rounded, not truncated).
        mask = np.clip(np.rint(final_probabilities * 255), 0, 255).astype(np.uint8)

        # Write to a temporary name and rename it into place. An interrupted
        # save then never leaves a truncated PNG that the skip check above
        # would trust on the next run.
        tmp_path = output_path + ".part"
        Image.fromarray(mask).save(tmp_path, format="PNG")
        os.replace(tmp_path, output_path)

    except Exception as e:
        # Best-effort batch job: report the failure and move on to the next image.
        print(f"!!! ERROR processing {filename}: {e}")

print("\nFull-size pseudo-label generation complete!")