# HumanStreets Segmentation - Colab Version (T4) - SAM3

This notebook runs the segmentation pipeline with **SAM3** (Segment Anything Model 3, loaded from custom `sam3.pt` weights) and saves the results to a **GeoPackage (GPKG)** file with a separate layer for each class.
Its tiling logic is adapted from `load_segment_upload.py`.

In [None]:
# Install dependencies
!pip install rasterio ultralytics geopandas shapely pyproj tqdm git+https://github.com/ultralytics/CLIP.git

In [None]:
import os
# Set PyTorch memory configuration BEFORE importing torch
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import rasterio
import numpy as np
import time
import geopandas as gpd
import pandas as pd
import torch
import gc
from shapely.geometry import Polygon, MultiPolygon
from rasterio.windows import Window
from itertools import product
from pyproj import Transformer
from tqdm.notebook import tqdm

# Pick the inference device: CUDA index 0 when a GPU is visible, otherwise 'cpu'.
# NOTE(review): `device` is not referenced in the visible cells; presumably kept for
# passing to ultralytics calls - confirm before removing.
if torch.cuda.is_available():
    device = 0
else:
    device = 'cpu'
print(f"CUDA Available: {device != 'cpu'}")

In [None]:
# @title 2. Mount Drive & Paths
from google.colab import drive
import os

# Mount Google Drive. A failure is reported but not raised, so the rest of the
# cell still defines the configuration below.
print("Mounting Google Drive... Please check for an authentication popup!")
try:
    drive.mount('/content/drive', force_remount=True)
except Exception as mount_err:
    print(f"Mount failed: {mount_err}. Try running this cell again.")

# --- CONFIGURATION ---
# Drive folder holding the inputs and receiving the output; edit to match your layout.
BASE_DIR = "/content/drive/MyDrive/Spatial_Data_Bootcamp/captstone/Colab_Walkability"

# Input mosaic, SAM3 weights and street network, all inside BASE_DIR.
# (File names must match what is on Drive exactly, including their spelling.)
TIF_PATH, MODEL_PATH, STREETS_PATH = (
    os.path.join(BASE_DIR, fname)
    for fname in ("mosiac_rgb_6cmPerPixel.tif", "sam3.pt", "streets.geojson")
)

# Output GeoPackage; the pipeline writes one layer per class into it.
OUTPUT_GPKG = os.path.join(BASE_DIR, "sam3_results.gpkg")

print(
    "Checking files:\n"
    f" TIF: {os.path.exists(TIF_PATH)}\n"
    f" Model: {os.path.exists(MODEL_PATH)}\n"
    f" Streets: {os.path.exists(STREETS_PATH)}"
)

# Tile edge length in pixels. Kept a multiple of 14 (1036 = 74 * 14) to satisfy
# the stride-14 requirement of the model input.
TILE_SIZE = 1036
# Pixels shared between neighbouring tiles.
OVERLAP = 50

# Text prompts sent to SAM3; each one is also a target output layer name.
SAM3_CLASSES = ["road", "sidewalk", "car", "obstacle", "tree"]

In [None]:
# Import the SAM predictor.
# Prefer SAM3SemanticPredictor (the text-prompted predictor used in the project's
# other notebooks). If this ultralytics build does not provide it, import the
# generic SAM model instead.
# NOTE: the generic SAM model may not accept `text=` prompts the same way. If the
# fallback misbehaves, install the ultralytics version/script that ships SAM3.
try:
    from ultralytics.models.sam.predict import SAM3SemanticPredictor
except ImportError:
    print("SAM3SemanticPredictor not found, falling back to standard SAM or checking imports...")
    from ultralytics import SAM
else:
    print("Using SAM3SemanticPredictor")

def setup_transformer(src, densify_pts=21):
    """Build a native-CRS -> WGS84 transformer and the raster's WGS84 extent.

    The extent is computed from points sampled along all four edges of the
    raster's bounding box, not just two opposite corners. In a projected CRS the
    grid is generally rotated relative to lines of latitude/longitude, so the
    true min/max lat/lon can occur at any corner or along an edge.

    Args:
        src: An open raster dataset exposing ``.crs`` and ``.bounds``
            (with ``left``/``bottom``/``right``/``top``).
        densify_pts: Number of extra points sampled between the corners on each
            edge. ``0`` samples only the four corners. Negative values are
            treated as ``0``.

    Returns:
        ``(to_wgs84, (min_lat, min_lon, max_lat, max_lon))``. ``to_wgs84``
        takes and returns ``(x/lon, y/lat)`` order because of ``always_xy=True``.
    """
    to_wgs84 = Transformer.from_crs(src.crs, "EPSG:4326", always_xy=True)

    bounds = src.bounds
    left, bottom, right, top = bounds.left, bounds.bottom, bounds.right, bounds.top

    n_steps = max(0, int(densify_pts)) + 1
    fractions = [k / n_steps for k in range(n_steps + 1)]  # 0.0 .. 1.0 inclusive
    xs = [left + (right - left) * f for f in fractions]
    ys = [bottom + (top - bottom) * f for f in fractions]

    # Bottom edge, top edge, left edge, right edge.
    edge_x = xs + xs + [left] * len(ys) + [right] * len(ys)
    edge_y = [bottom] * len(xs) + [top] * len(xs) + ys + ys

    lons, lats = to_wgs84.transform(edge_x, edge_y)
    return to_wgs84, (min(lats), min(lons), max(lats), max(lons))

def validate_and_correct_poly(poly):
    """Return the usable polygon parts of *poly* as a list.

    Invalid input is first repaired with a zero-width ``buffer(0)``. An empty
    result yields ``[]``. A MultiPolygon is split into its members, keeping only
    those that are valid and non-empty. Any other geometry is returned as a
    one-element list.
    """
    fixed = poly if poly.is_valid else poly.buffer(0)
    if fixed.is_empty:
        return []
    if not isinstance(fixed, MultiPolygon):
        return [fixed]
    return [part for part in fixed.geoms if part.is_valid and not part.is_empty]

def _load_sam3_predictor(model_path, imgsz):
    """Build the SAM3 text-prompt predictor, falling back to a plain ``SAM`` model.

    ``SAM3SemanticPredictor`` / ``SAM`` are imported conditionally at module
    level. Referencing the one that was not imported raises ``NameError``.
    """
    try:
        # overrides logic taken from Run_SAM3_Colab.ipynb
        return SAM3SemanticPredictor(
            overrides=dict(conf=0.25, task="segment", mode="predict", model=model_path, imgsz=imgsz)
        )
    except NameError:
        # Fallback if the SAM3 predictor class was not imported.
        return SAM(model_path)


def _tile_anchors(length, tile_size, overlap):
    """Return tile start offsets along one axis so tiles cover ``[0, length)``.

    Consecutive tiles overlap by at least ``overlap`` pixels. The last tile is
    shifted back to end exactly at ``length``, so every tile is full-size when
    ``length >= tile_size``. This avoids thin edge slivers that lie entirely
    inside the previous tile and would only duplicate detections.
    """
    stride = tile_size - overlap
    if stride <= 0:
        raise ValueError(f"OVERLAP ({overlap}) must be smaller than TILE_SIZE ({tile_size})")
    if length <= tile_size:
        return [0]
    anchors = list(range(0, length - tile_size, stride))
    anchors.append(length - tile_size)  # never already present: range() excludes its stop
    return anchors


def _class_name(names_map, class_id):
    """Look up a label in a dict- or list-shaped ``names`` map; ``"unknown"`` if absent."""
    if hasattr(names_map, 'get'):
        return names_map.get(class_id, "unknown")
    if isinstance(names_map, list) and 0 <= class_id < len(names_map):
        return names_map[class_id]
    return "unknown"


def _target_class(class_name_raw, targets):
    """Return the first target contained (case-insensitively) in the label, else ``None``."""
    lowered = str(class_name_raw).lower()
    return next((target for target in targets if target in lowered), None)


def _mask_to_wgs84_polygons(poly_coords, col, row, transform_affine, transformer):
    """Convert one tile-local mask outline into valid WGS84 polygon parts (may be ``[]``)."""
    # 1. Tile -> global pixel
    global_x = poly_coords[:, 0] + col
    global_y = poly_coords[:, 1] + row

    # 2. Global pixel -> native CRS (rasterio.transform.xy takes rows, then cols)
    native_x, native_y = rasterio.transform.xy(transform_affine, global_y, global_x)

    # 3. Native -> WGS84 (lon, lat order because the transformer uses always_xy)
    wgs_lon, wgs_lat = transformer.transform(native_x, native_y)

    # Drop outlines that failed to project or fall outside valid lon/lat ranges.
    if not (np.all(np.isfinite(wgs_lon)) and np.all(np.isfinite(wgs_lat))):
        return []
    if np.any(np.abs(wgs_lat) > 90) or np.any(np.abs(wgs_lon) > 180):
        return []

    return validate_and_correct_poly(Polygon(zip(wgs_lon, wgs_lat)))


def _save_class_layers(class_polygons, output_gpkg):
    """Write each non-empty class bucket as its own layer in a fresh GeoPackage."""
    print(f"Saving results to {output_gpkg}...")

    non_empty = {name: polys for name, polys in class_polygons.items() if polys}
    if not non_empty:
        print("No polygons found to save.")
        return

    # Start from a fresh file. Otherwise a layer written by an earlier run for a
    # class that found nothing this time would remain next to the new results.
    if os.path.exists(output_gpkg):
        os.remove(output_gpkg)

    for cls_name, polys in non_empty.items():
        print(f"  - Layer '{cls_name}': {len(polys)} polygons")
        gdf = gpd.GeoDataFrame({'geometry': polys}, crs="EPSG:4326")
        # Use layer=cls_name to create separate layers
        gdf.to_file(output_gpkg, layer=cls_name, driver="GPKG")

    print(f"Success! Saved {len(non_empty)} layers to {output_gpkg}")


def process_image_sam3(tif_path, model_path, output_gpkg):
    """Run tiled SAM3 inference over a raster and save one GPKG layer per class.

    The raster is read in overlapping windows of ``TILE_SIZE`` pixels (overlap
    ``OVERLAP``). Bands 1-3 of each window are passed to the model as an HWC
    image, prompted with ``SAM3_CLASSES``. Mask outlines are reprojected to
    EPSG:4326 and grouped by the first ``SAM3_CLASSES`` entry found in the
    predicted label. Labels matching no entry are dropped.

    Args:
        tif_path: Path to the input raster.
        model_path: Path to the SAM3 weights.
        output_gpkg: Destination GeoPackage. If there is anything to save, an
            existing file at this path is replaced.

    Returns:
        None. A missing raster or a model-load failure is printed and the
        function returns early.
    """
    if not os.path.exists(tif_path):
        print(f"Error: TIF file not found at {tif_path}")
        return

    print(f"Loading SAM3 model from {model_path}...")
    try:
        predictor = _load_sam3_predictor(model_path, TILE_SIZE)
    except Exception as e:
        print(f"Model Load Error: {e}")
        return

    class_polygons = {name: [] for name in SAM3_CLASSES}

    with rasterio.open(tif_path) as src:
        W, H = src.width, src.height
        transform_affine = src.transform
        transformer, _wgs_bounds = setup_transformer(src)

        print(f"Processing Image: {W}x{H} | CRS: {src.crs}")

        tiles = list(product(_tile_anchors(W, TILE_SIZE, OVERLAP), _tile_anchors(H, TILE_SIZE, OVERLAP)))
        print(f"Total Tiles: {len(tiles)} (Size: {TILE_SIZE}, Overlap: {OVERLAP})")

        for i, (col, row) in enumerate(tqdm(tiles, desc="Processing Tiles")):
            # Memory cleanup every 10 tiles
            if i % 10 == 0:
                torch.cuda.empty_cache()
                gc.collect()

            # Window(col, row, width, height); clipped for rasters smaller than a tile.
            window = Window(col, row, min(TILE_SIZE, W - col), min(TILE_SIZE, H - row))
            img_data = src.read([1, 2, 3], window=window)

            # Skip all-zero tiles and degenerate slivers.
            if img_data.shape[0] < 3 or img_data.max() == 0:
                continue
            if img_data.shape[1] < 10 or img_data.shape[2] < 10:
                continue

            # (bands, rows, cols) -> (rows, cols, bands), contiguous for the model.
            img = np.ascontiguousarray(np.transpose(img_data, (1, 2, 0)))

            # Inference without autograd to save memory
            with torch.no_grad():
                if hasattr(predictor, 'set_image'):
                    predictor.set_image(img)
                    results = predictor(text=SAM3_CLASSES, save=False, verbose=False)
                else:
                    # Fallback usage (plain SAM model)
                    results = predictor(img, imgsz=TILE_SIZE, verbose=False)

            res = results[0] if results else None
            if res is None or res.masks is None or res.boxes is None:
                del img, img_data, results, res
                continue

            names_map = res.names
            class_ids = res.boxes.cls

            for j, poly_coords in enumerate(res.masks.xy):
                if len(poly_coords) < 3:
                    continue

                # Resolve the output bucket once per mask. Unmatched labels are
                # dropped, so skip the (discarded) reprojection work for them.
                target = _target_class(_class_name(names_map, int(class_ids[j])), SAM3_CLASSES)
                if target is None:
                    continue

                class_polygons[target].extend(
                    _mask_to_wgs84_polygons(poly_coords, col, row, transform_affine, transformer)
                )

            # Clean up per loop to free memory
            del img, img_data, results, res

    _save_class_layers(class_polygons, output_gpkg)

def main():
    """Run the SAM3 segmentation pipeline on the configured paths and report elapsed time."""
    print("--- Starting SAM3 Segmentation Pipeline ---")
    # perf_counter is a monotonic clock, so the duration is unaffected by
    # system clock adjustments made during a long run.
    start_time = time.perf_counter()

    process_image_sam3(TIF_PATH, MODEL_PATH, OUTPUT_GPKG)

    print(f"Total Time: {time.perf_counter() - start_time:.2f}s")


if __name__ == "__main__":
    main()