# SAM3 GIS Segmentation Workflow

This notebook provides an interactive workflow for:
1. **Loading raster imagery** (GeoTIFF, etc.)
2. **Viewing in an interactive Leaflet map** with `ipyleaflet`
3. **Selecting segmentation points/boxes** interactively
4. **Running SAM3 segmentation** on selected regions
5. **Post-processing masks to vector geometries**
6. **Exporting to File Geodatabase (.gdb)**

---
**Prerequisites:**
- SAM3 installed (`pip install -e .` from sam3 repo)
- HuggingFace authentication for model checkpoints
- GDAL/rasterio for raster handling
- Fiona/geopandas for vector export

## 1. Setup and Imports

In [None]:
pip install "segment-geospatial[samgeo3]"

In [None]:
pip install -e .

In [None]:
!pip install "transformers" "huggingface-hub<1.0"

In [None]:
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128

In [None]:
# Interactive Hugging Face Hub login. The model is later built with
# load_from_HF=True, so the checkpoint download needs these credentials.
from huggingface_hub import login
login()

In [None]:
# Core imports
import os
import numpy as np
from pathlib import Path
from PIL import Image
import torch

# Raster handling
import rasterio
from rasterio.warp import calculate_default_transform, reproject, Resampling
from rasterio.transform import rowcol, xy
from rasterio.features import shapes
from rasterio.crs import CRS

# Vector handling
import geopandas as gpd
from shapely.geometry import shape, mapping, Polygon, MultiPolygon
from shapely.ops import unary_union
import fiona

# Interactive mapping
import ipyleaflet
from ipyleaflet import Map, TileLayer, ImageOverlay, Marker, DrawControl, LayerGroup, GeoJSON
from ipywidgets import widgets, Output, VBox, HBox, Label, Button, Text, Dropdown, FloatSlider
from IPython.display import display, clear_output

# Visualization
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# SAM3 imports
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor

# Report the PyTorch build and whether a CUDA device is usable for SAM3.
print(f"PyTorch version: {torch.__version__}\nCUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Initialize SAM3 Model

In [None]:
# --- SAM3 tokenizer path (FIX for FileNotFoundError) ---
# build_sam3_image_model needs SAM3's BPE vocabulary file. Set the SAM3_BPE_PATH
# environment variable to point elsewhere without editing this cell.
import os
from pathlib import Path
import torch

BPE_PATH = Path(os.environ.get(
    "SAM3_BPE_PATH",
    "/home/john/DevDrive/sam3/sam3/assets/bpe_simple_vocab_16e6.txt.gz",
))

# ✅ Only load once per kernel: re-running this cell reuses the model already in memory
if "model" not in globals():
    # Fail early with an actionable message instead of deep inside the model builder
    if not BPE_PATH.is_file():
        raise FileNotFoundError(
            f"SAM3 BPE vocabulary not found at {BPE_PATH}; set SAM3_BPE_PATH to the "
            "bpe_simple_vocab_16e6.txt.gz file in your sam3 checkout."
        )

    print("Loading SAM3 model (this may take a moment)...")

    model = build_sam3_image_model(
        bpe_path=str(BPE_PATH),
        load_from_HF=True,
        device="cuda" if torch.cuda.is_available() else "cpu"
    )

    model.eval()  # ✅ inference mode (disables training-only behaviour such as dropout)
    processor = Sam3Processor(model)

    print("SAM3 model loaded successfully!")

else:
    print("SAM3 model already loaded — reusing existing GPU instance.")
    # A `model` global without a matching `processor` (e.g. after `del processor`)
    # would otherwise make the later cells fail with NameError.
    if "processor" not in globals():
        processor = Sam3Processor(model)


## 3. Raster Loading Utilities

In [None]:
class GeoRasterLoader:
    """
    Handles loading and georeferencing of raster imagery for SAM3 segmentation.

    On construction the raster is opened with rasterio. The dataset stays open
    until ``close()`` is called; the loader can also be used as a context
    manager. Its first three bands, or its single band replicated three times,
    are read into ``image_array`` as a (height, width, 3) uint8 array. The same
    pixels are exposed as the PIL image ``image_pil`` for SAM3.
    """

    def __init__(self, raster_path: str):
        self.raster_path = Path(raster_path)
        self.src = None
        self.transform = None
        self.crs = None
        self.bounds = None
        self.width = None
        self.height = None
        self.image_array = None
        self.image_pil = None
        self._load_raster()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions

    @staticmethod
    def _to_uint8(array: np.ndarray) -> np.ndarray:
        """Linearly stretch *array* to 0-255 uint8 using its global min/max.

        uint8 input is returned unchanged. The range is computed over finite
        values only, and non-finite values (NaN/inf) map to 0. A constant or
        all-NaN array maps to all zeros instead of dividing by zero.
        """
        if array.dtype == np.uint8:
            return array
        data = array.astype(np.float64)
        finite = np.isfinite(data)
        if not finite.any():
            return np.zeros(array.shape, dtype=np.uint8)
        lo = data[finite].min()
        hi = data[finite].max()
        if hi <= lo:
            return np.zeros(array.shape, dtype=np.uint8)
        scaled = (data - lo) / (hi - lo) * 255
        scaled[~finite] = 0
        return np.clip(scaled, 0, 255).astype(np.uint8)

    def _load_raster(self):
        """Open the raster, record its georeferencing and read it as RGB uint8.

        Raises:
            ValueError: if the raster has neither 1 nor at least 3 bands.
        """
        self.src = rasterio.open(self.raster_path)
        self.transform = self.src.transform
        self.crs = self.src.crs
        self.bounds = self.src.bounds
        self.width = self.src.width
        self.height = self.src.height

        # Read image data (handle different band configurations)
        if self.src.count >= 3:
            # RGB or multi-band: first 3 bands, reordered band-last for PIL
            rgb = self.src.read([1, 2, 3]).transpose(1, 2, 0)
        elif self.src.count == 1:
            # Single band: replicate into a grayscale RGB image
            band = self.src.read(1)
            rgb = np.stack([band, band, band], axis=-1)
        else:
            count = self.src.count
            self.src.close()  # don't leak the dataset when rejecting it
            raise ValueError(f"Unsupported band count: {count}")

        self.image_array = self._to_uint8(rgb)
        self.image_pil = Image.fromarray(self.image_array)

        print(f"Loaded raster: {self.raster_path.name}")
        print(f"  Size: {self.width} x {self.height}")
        print(f"  CRS: {self.crs}")
        print(f"  Bounds: {self.bounds}")

    def pixel_to_geo(self, col: int, row: int) -> tuple:
        """Convert pixel (col, row) to (x, y) in the raster CRS via rasterio.transform.xy."""
        return xy(self.transform, row, col)

    def geo_to_pixel(self, x: float, y: float) -> tuple:
        """Convert (x, y) in the raster CRS to integer pixel (col, row) via rasterio.transform.rowcol."""
        row, col = rowcol(self.transform, x, y)
        return col, row

    def get_bounds_wgs84(self) -> list:
        """Get bounds as [[south, west], [north, east]] in WGS84 for Leaflet.

        For projected rasters the bounds are transformed with edge
        densification, so the WGS84 box encloses the whole raster rather than
        only its two transformed corners.

        Raises:
            ValueError: if the raster has no CRS.
        """
        if self.crs is None:
            raise ValueError(f"{self.raster_path.name} has no CRS; cannot place it on a web map")

        if self.crs.to_epsg() == 4326:
            return [
                [self.bounds.bottom, self.bounds.left],
                [self.bounds.top, self.bounds.right]
            ]

        from rasterio.warp import transform_bounds

        min_lon, min_lat, max_lon, max_lat = transform_bounds(
            self.crs, "EPSG:4326",
            self.bounds.left, self.bounds.bottom, self.bounds.right, self.bounds.top,
            densify_pts=21,
        )
        return [[min_lat, min_lon], [max_lat, max_lon]]

    def get_center_wgs84(self) -> tuple:
        """Get the centre of the WGS84 bounds as (lat, lon) for Leaflet."""
        bounds = self.get_bounds_wgs84()
        lat = (bounds[0][0] + bounds[1][0]) / 2
        lon = (bounds[0][1] + bounds[1][1]) / 2
        return (lat, lon)

    def save_png_for_overlay(self, output_path: str = None) -> str:
        """Save image_pil as PNG (default: next to the raster) and return the path."""
        if output_path is None:
            output_path = str(self.raster_path.with_suffix('.png'))
        self.image_pil.save(output_path)
        return output_path

    def close(self):
        """Close the underlying rasterio dataset, if one was opened."""
        if self.src is not None:
            self.src.close()

## 4. Load Your Raster

In [None]:
# === CONFIGURE YOUR RASTER PATH HERE ===
# Any raster rasterio can open (e.g. a GeoTIFF) with 1 or >= 3 bands.
RASTER_PATH = "/home/john/DevDrive/sam3/sam3/assets/gis_segmentation_test_data/10259550.tif"  # <-- Update this path

# Open the raster; the loader prints its size, CRS and bounds.
raster = GeoRasterLoader(raster_path=RASTER_PATH)

## 5. Interactive Map with Selection Tools

In [None]:
class SAM3MapInterface:
    """
    Interactive map interface for SAM3 segmentation with ipyleaflet.

    Accepts a text prompt, clicked points and drawn boxes as segmentation
    prompts. Point and box prompts are stored in raster pixel coordinates.
    Resulting masks are vectorised into polygons in the raster CRS and drawn
    on the map reprojected to WGS84. They can be exported to a File
    Geodatabase, with a GeoPackage fallback.
    """

    # Attribute columns (in order) of every result table built from mask_geometries.
    _RESULT_COLUMNS = ('geometry', 'score', 'mask_id', 'prompt')

    def __init__(self, raster_loader: GeoRasterLoader, processor: Sam3Processor):
        self.raster = raster_loader
        self.processor = processor
        self.inference_state = None

        # Storage for prompts and results
        self.point_prompts = []  # List of (col, row, label) in pixel coords
        self.box_prompts = []    # List of [col1, row1, col2, row2] in pixel coords
        self.text_prompt = ""    # Not read by this class; the Text widget's value is used directly
        self.current_masks = None
        self.current_scores = None
        self.mask_geometries = []  # dicts with geometry (raster CRS), score, mask_id, prompt

        # Raster-CRS -> WGS84 transformer for drawing results, built on first use
        self._wgs84_transformer = None

        # Initialize map and widgets
        self._setup_map()
        self._setup_controls()
        self._setup_sam3_image()

    def _setup_sam3_image(self):
        """Pre-compute image embeddings for SAM3."""
        print("Computing image embeddings...")
        self.inference_state = self.processor.set_image(self.raster.image_pil)
        print("Image embeddings ready!")

    def _setup_map(self):
        """Initialize the Leaflet map with the raster overlay, prompt layers and draw control."""
        center = self.raster.get_center_wgs84()
        bounds = self.raster.get_bounds_wgs84()

        self.map = Map(
            center=center,
            zoom=15,
            scroll_wheel_zoom=True,
            layout=widgets.Layout(width='100%', height='600px')
        )

        # Base layer options (only the satellite layer is added to the map)
        self.basemap_osm = TileLayer(url='https://{s}.tile.openstreetmap.org/{z}/{x}/{y}.png', name='OpenStreetMap')
        self.basemap_satellite = TileLayer(
            url='https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
            name='Satellite'
        )
        self.map.add_layer(self.basemap_satellite)

        # Save raster as PNG and embed it as a base64 data URL, so the overlay
        # works without serving local files
        self.png_path = self.raster.save_png_for_overlay()

        import base64
        with open(self.png_path, 'rb') as f:
            img_data = base64.b64encode(f.read()).decode()
        img_url = f'data:image/png;base64,{img_data}'

        self.image_overlay = ImageOverlay(
            url=img_url,
            bounds=bounds,
            name='Raster'
        )
        self.map.add_layer(self.image_overlay)

        # Layer for markers (point prompts)
        self.marker_layer = LayerGroup(name='Points')
        self.map.add_layer(self.marker_layer)

        # Layer for segmentation results
        self.result_layer = LayerGroup(name='Segmentation')
        self.map.add_layer(self.result_layer)

        # Draw control: rectangles only (used as box prompts)
        self.draw_control = DrawControl(
            rectangle={'shapeOptions': {'color': '#00ff00', 'fillOpacity': 0.2}},
            polygon={},
            polyline={},
            circle={},
            circlemarker={},
            marker={}
        )
        self.draw_control.on_draw(self._handle_draw)
        self.map.add_control(self.draw_control)

        # Handle map clicks for point prompts
        self.map.on_interaction(self._handle_click)

        self.map.fit_bounds(bounds)

    def _setup_controls(self):
        """Build the prompt, threshold and action widgets into self.controls."""
        self.output = Output(layout=widgets.Layout(width='100%', height='200px', overflow='auto'))

        self.text_input = Text(
            placeholder='Enter text prompt (e.g., "building", "road", "vegetation")',
            description='Text Prompt:',
            layout=widgets.Layout(width='400px'),
            style={'description_width': 'initial'}
        )

        # Point label selector: 1 = positive, 0 = negative
        self.point_label = Dropdown(
            options=[('Positive (include)', 1), ('Negative (exclude)', 0)],
            value=1,
            description='Click Mode:',
            style={'description_width': 'initial'}
        )

        # Masks scoring below this are dropped in _process_masks
        self.confidence_slider = FloatSlider(
            value=0.5,
            min=0.0,
            max=1.0,
            step=0.05,
            description='Min Confidence:',
            style={'description_width': 'initial'}
        )

        self.segment_btn = Button(description='Run Segmentation', button_style='success')
        self.segment_btn.on_click(self._run_segmentation)

        self.clear_btn = Button(description='Clear Prompts', button_style='warning')
        self.clear_btn.on_click(self._clear_prompts)

        self.export_btn = Button(description='Export to GDB', button_style='info')
        self.export_btn.on_click(self._export_to_gdb)

        controls_row1 = HBox([self.text_input, self.point_label])
        controls_row2 = HBox([self.confidence_slider, self.segment_btn, self.clear_btn, self.export_btn])

        self.controls = VBox([controls_row1, controls_row2])

    def _wgs84_to_raster_crs(self, lons, lats):
        """Project WGS84 longitude/latitude sequences into the raster CRS.

        Returns two lists (x, y). The coordinates pass through unchanged when
        the raster is already EPSG:4326.
        """
        if self.raster.crs.to_epsg() == 4326:
            return list(lons), list(lats)
        from pyproj import Transformer
        transformer = Transformer.from_crs("EPSG:4326", self.raster.crs, always_xy=True)
        xs, ys = transformer.transform(list(lons), list(lats))
        return list(xs), list(ys)

    def _handle_click(self, **kwargs):
        """Handle map click events: add a positive/negative point prompt inside the raster."""
        if kwargs.get('type') != 'click':
            return
        coords = kwargs.get('coordinates')
        if not coords:
            return
        lat, lon = coords  # Leaflet reports (lat, lon)

        (x,), (y,) = self._wgs84_to_raster_crs([lon], [lat])
        col, row = self.raster.geo_to_pixel(x, y)

        # Ignore clicks outside the raster footprint
        if not (0 <= col < self.raster.width and 0 <= row < self.raster.height):
            return

        label = self.point_label.value
        self.point_prompts.append((col, row, label))

        # Colour-coded marker: green = positive (include), red = negative (exclude)
        from ipyleaflet import CircleMarker
        color = 'green' if label == 1 else 'red'
        marker = CircleMarker(
            location=(lat, lon),
            radius=6,
            color=color,
            fill_color=color,
            fill_opacity=0.8,
            weight=2,
        )
        self.marker_layer.add_layer(marker)

        with self.output:
            print(f"Added {'positive' if label == 1 else 'negative'} point at pixel ({col}, {row})")

    def _handle_draw(self, target, action, geo_json):
        """Handle DrawControl events: each newly drawn rectangle becomes a box prompt.

        The rectangle's extent is projected to the raster CRS and stored as a
        pixel box [col1, row1, col2, row2], clamped to the image.
        NOTE: deleting a rectangle with the draw tool does not remove its box
        prompt; only "Clear Prompts" does.
        """
        if action != 'created' or geo_json['geometry']['type'] != 'Polygon':
            return

        ring = geo_json['geometry']['coordinates'][0]  # GeoJSON order: [lon, lat]
        x_coords, y_coords = self._wgs84_to_raster_crs(
            [c[0] for c in ring], [c[1] for c in ring]
        )

        min_x, max_x = min(x_coords), max(x_coords)
        min_y, max_y = min(y_coords), max(y_coords)

        col1, row1 = self.raster.geo_to_pixel(min_x, max_y)  # top-left
        col2, row2 = self.raster.geo_to_pixel(max_x, min_y)  # bottom-right

        # Clamp to image bounds
        col1 = max(0, min(col1, self.raster.width - 1))
        col2 = max(0, min(col2, self.raster.width - 1))
        row1 = max(0, min(row1, self.raster.height - 1))
        row2 = max(0, min(row2, self.raster.height - 1))

        self.box_prompts.append([col1, row1, col2, row2])

        with self.output:
            print(f"Added box prompt: ({col1}, {row1}) to ({col2}, {row2})")

    @staticmethod
    def _scores_to_floats(scores) -> list:
        """Return *scores* (tensor, array, list or None) as a flat list of Python floats."""
        if scores is None:
            return []
        if isinstance(scores, torch.Tensor):
            return [float(s) for s in scores.detach().float().cpu().reshape(-1).tolist()]
        return [s.item() if isinstance(s, torch.Tensor) else float(s) for s in scores]

    def _run_segmentation(self, btn=None):
        """Execute SAM3 segmentation with the current prompts.

        A non-empty text prompt takes precedence over visual prompts. Among
        visual prompts, points take precedence over boxes, and only the first
        box is used. Errors are reported in the output panel; exceptions raised
        inside widget callbacks are otherwise not shown there.
        """
        with self.output:
            clear_output(wait=True)
            print("Running SAM3 segmentation...")

            try:
                # Reset inference state for fresh prompts
                self.inference_state = self.processor.set_image(self.raster.image_pil)

                text = self.text_input.value.strip()

                if text:
                    # Text prompt segmentation (Promptable Concept Segmentation)
                    print(f"Using text prompt: '{text}'")

                    output = self.processor.set_text_prompt(
                        state=self.inference_state,
                        prompt=text
                    )

                    self.current_masks = output["masks"]
                    self.current_scores = output["scores"]

                    # ---- Debug: raw score summary (text mode) ----
                    min_conf = self.confidence_slider.value

                    print("\n=== SAM3 SCORE DEBUG ===")
                    print("Prompt:", text)
                    print("Min confidence slider:", min_conf)

                    # Convert first: truth-testing a multi-element tensor raises RuntimeError
                    scores = self._scores_to_floats(self.current_scores)
                    if scores:
                        print("Total regions:", len(scores))
                        print("Max score:", max(scores))
                        print("Mean score:", sum(scores) / len(scores))
                        print("Top 10 scores:", sorted(scores, reverse=True)[:10])
                    else:
                        print("No scores returned at all")

                    print("========================\n")

                elif self.point_prompts or self.box_prompts:
                    # Visual prompt segmentation (Promptable Visual Segmentation)
                    print(f"Using {len(self.point_prompts)} point(s) and {len(self.box_prompts)} box(es)")

                    # NOTE(review): set_point_prompt / set_box_prompt are assumed to exist
                    # on Sam3Processor; confirm against the installed sam3 version.
                    if self.point_prompts:
                        if self.box_prompts:
                            print("Note: point prompts take precedence; drawn boxes are ignored.")
                        points = [[col, row] for col, row, _ in self.point_prompts]
                        labels = [label for _, _, label in self.point_prompts]
                        output = self.processor.set_point_prompt(
                            state=self.inference_state,
                            points=points,
                            labels=labels
                        )
                    else:
                        if len(self.box_prompts) > 1:
                            print(f"Note: {len(self.box_prompts)} boxes drawn; only the first is used.")
                        output = self.processor.set_box_prompt(
                            state=self.inference_state,
                            box=self.box_prompts[0]
                        )

                    self.current_masks = output["masks"]
                    self.current_scores = output.get("scores", [1.0] * len(output["masks"]))

                else:
                    print("No prompts specified! Click on map to add points, draw boxes, or enter text.")
                    return

                # Filter by confidence, convert masks to geometries and display
                min_conf = self.confidence_slider.value
                self._process_masks(min_conf)

                print(f"Segmentation complete! Found {len(self.mask_geometries)} region(s)")
            except Exception as exc:
                import traceback
                print(f"Segmentation failed: {exc!r}")
                traceback.print_exc()

    def _process_masks(self, min_confidence: float):
        """Vectorise current masks scoring >= min_confidence into georeferenced polygons.

        Previous results are cleared first. Each polygon, in the raster CRS, is
        appended to mask_geometries and drawn on the result layer.
        """
        self.mask_geometries = []
        self.result_layer.clear_layers()

        if self.current_masks is None:
            return

        # Same label rule as _run_segmentation: whitespace-only text means a visual prompt
        prompt_label = self.text_input.value.strip() or 'visual_prompt'
        scores = self._scores_to_floats(self.current_scores)

        for i, (mask, score) in enumerate(zip(self.current_masks, scores)):
            if score < min_confidence:
                continue

            if isinstance(mask, torch.Tensor):
                mask = mask.detach().float().cpu().numpy()
            mask = np.asarray(mask)

            # Ensure mask is 2D binary
            if mask.ndim > 2:
                mask = mask.squeeze()
            mask_binary = (mask > 0.5).astype(np.uint8)

            # mask= limits tracing to foreground pixels (background polygons are never built)
            for geom, value in shapes(mask_binary, mask=mask_binary.astype(bool),
                                      transform=self.raster.transform):
                if value != 1:
                    continue
                poly = shape(geom)
                self.mask_geometries.append({
                    'geometry': poly,
                    'score': score,
                    'mask_id': i,
                    'prompt': prompt_label
                })
                self._add_geometry_to_map(poly, score)

        with self.output:
            print(f"Processed {len(self.mask_geometries)} geometries")

    def _add_geometry_to_map(self, geometry, score: float):
        """Draw one result polygon (raster CRS) on the map, reprojected to WGS84.

        The colour runs from red (score 0) to green (score 1).
        """
        if self.raster.crs.to_epsg() != 4326:
            # Reuse one transformer instead of rebuilding it for every polygon
            if self._wgs84_transformer is None:
                from pyproj import Transformer
                self._wgs84_transformer = Transformer.from_crs(
                    self.raster.crs, "EPSG:4326", always_xy=True
                )
            from shapely.ops import transform as shapely_transform
            geometry_wgs84 = shapely_transform(self._wgs84_transformer.transform, geometry)
        else:
            geometry_wgs84 = geometry

        geojson_data = {
            'type': 'Feature',
            'properties': {'score': score},
            'geometry': mapping(geometry_wgs84)
        }

        # Clamp so an out-of-range score cannot produce an invalid hex colour
        s = min(max(float(score), 0.0), 1.0)
        color = f'#{int(255 * (1 - s)):02x}{int(255 * s):02x}00'

        geojson_layer = GeoJSON(
            data=geojson_data,
            style={'color': color, 'fillColor': color, 'fillOpacity': 0.4, 'weight': 2}
        )
        self.result_layer.add_layer(geojson_layer)

    def _clear_prompts(self, btn=None):
        """Clear all prompts, results and their map layers (the text prompt is kept)."""
        self.point_prompts = []
        self.box_prompts = []
        self.current_masks = None
        self.current_scores = None
        self.mask_geometries = []

        self.marker_layer.clear_layers()
        self.result_layer.clear_layers()
        self.draw_control.clear()

        with self.output:
            clear_output(wait=True)
            print("Cleared all prompts and results")

    def _export_to_gdb(self, btn=None):
        """Export results to ``<raster stem>_segments.gdb`` next to the raster.

        If writing the File Geodatabase fails (for example when the GDAL build
        cannot write OpenFileGDB), a GeoPackage with the same stem is written
        instead. Either failure is reported with its error message.
        """
        with self.output:
            if not self.mask_geometries:
                print("No geometries to export! Run segmentation first.")
                return

            output_gdb = self.raster.raster_path.parent / f"{self.raster.raster_path.stem}_segments.gdb"
            print(f"Exporting to: {output_gdb}")

            gdf = self.get_geodataframe()

            try:
                gdf.to_file(str(output_gdb), driver='OpenFileGDB', layer='segments')
                print(f"Successfully exported {len(gdf)} features to {output_gdb}")
                return
            except Exception as e:
                print(f"OpenFileGDB export failed: {e}")

            output_gpkg = output_gdb.with_suffix('.gpkg')
            print(f"Exporting to GeoPackage instead: {output_gpkg}")
            try:
                gdf.to_file(str(output_gpkg), driver='GPKG', layer='segments')
                print(f"Successfully exported {len(gdf)} features to {output_gpkg}")
            except Exception as e:
                print(f"GeoPackage export failed as well: {e}")

    def display(self):
        """Return the complete interface (controls, map, output panel) as one widget."""
        return VBox([self.controls, self.map, self.output])

    def get_geodataframe(self) -> gpd.GeoDataFrame:
        """Return current results as a GeoDataFrame in the raster CRS.

        Columns are geometry, score, mask_id and prompt. With no results the
        frame is empty but keeps these columns and the raster CRS.
        """
        columns = list(self._RESULT_COLUMNS)
        records = [{col: g[col] for col in columns} for g in self.mask_geometries]
        return gpd.GeoDataFrame(records, columns=columns, geometry='geometry', crs=self.raster.crs)

## 6. Launch Interactive Map Interface

In [None]:
# Build the interface (its constructor also computes the SAM3 image embeddings) and show it.
map_interface = SAM3MapInterface(raster_loader=raster, processor=processor)
display(map_interface.display())

### How to Use:

1. **Text Prompt**: Enter a concept (e.g., "building", "road", "tree") and click "Run Segmentation"
2. **Point Prompts**: 
   - Select "Positive" or "Negative" mode
   - Click on the map to add points
   - Green = include, Red = exclude
3. **Box Prompts**: Draw rectangles on the map using the draw tool (only the first rectangle is used, and rectangles are ignored when point prompts are present)
4. **Run Segmentation**: Click button to process
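5. **Export**: Click "Export to GDB" to save results next to the raster as `<raster name>_segments.gdb` (a GeoPackage is written instead if the File Geodatabase export fails)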

### 6.1 Normalize the segmentations into squared GIS geometries before export (optional)

In [None]:
from shapely import minimum_rotated_rectangle
from shapely.geometry import box

# Output File Geodatabase for the squared geometries (relative to the notebook's working directory)
SQUARED_OUTPUT_GDB = "output.gdb"

# Get raw results
gdf_raw = map_interface.get_geodataframe()

if gdf_raw.empty:
    # Nothing to square; writing an empty layer would fail anyway
    print("No segmentation results to normalize - run segmentation first.")
else:
    # Square them (pick one method)
    gdf_raw['geometry'] = gdf_raw['geometry'].apply(minimum_rotated_rectangle)  # oriented box
    # OR
    # gdf_raw['geometry'] = gdf_raw['geometry'].apply(lambda g: box(*g.bounds))  # axis-aligned box

    # Export, falling back to GeoPackage when the OpenFileGDB write fails
    try:
        gdf_raw.to_file(SQUARED_OUTPUT_GDB, driver='OpenFileGDB', layer='segments')
        print(f"Exported {len(gdf_raw)} squared features to {SQUARED_OUTPUT_GDB}")
    except Exception as e:
        squared_output_gpkg = str(Path(SQUARED_OUTPUT_GDB).with_suffix('.gpkg'))
        print(f"OpenFileGDB export failed ({e}); writing GeoPackage: {squared_output_gpkg}")
        gdf_raw.to_file(squared_output_gpkg, driver='GPKG', layer='segments')

## 7. Manual Export Options

In [None]:
# Pull the current segmentation results into a GeoDataFrame for further work
gdf = map_interface.get_geodataframe()

if not gdf.empty:
    # One print with an embedded newline gives the same two summary lines
    print(f"GeoDataFrame with {len(gdf)} features\nCRS: {gdf.crs}")
    display(gdf.head())

In [None]:
# Export to various formats
def export_results(gdf: gpd.GeoDataFrame, output_dir: str, base_name: str = "segments"):
    """
    Export GeoDataFrame to multiple formats.

    Writes ``<base_name>.gdb`` (File Geodatabase, best effort),
    ``<base_name>.gpkg``, ``<base_name>.shp`` and ``<base_name>.geojson``
    into *output_dir*. The directory is created, including missing parents,
    if needed.

    Args:
        gdf: GeoDataFrame with segmentation results
        output_dir: Output directory path
        base_name: Base name for output files (also the GDB/GPKG layer name)

    Returns:
        dict mapping 'gdb', 'gpkg', 'shp' and 'geojson' to the written paths.
        'gdb' is None when the geodatabase could not be written. Returns None
        when *gdf* is empty.
    """
    if gdf.empty:
        print("No features to export")
        return None

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # File Geodatabase: optional, as OpenFileGDB write support depends on the GDAL build.
    # Success is tracked explicitly: a failed write can leave a partial .gdb directory behind.
    gdb_path = output_dir / f"{base_name}.gdb"
    gdb_written = None
    try:
        gdf.to_file(str(gdb_path), driver='OpenFileGDB', layer=base_name)
    except Exception as e:
        print(f"✗ File Geodatabase failed: {e}")
    else:
        gdb_written = gdb_path
        print(f"✓ File Geodatabase: {gdb_path}")

    # GeoPackage
    gpkg_path = output_dir / f"{base_name}.gpkg"
    gdf.to_file(str(gpkg_path), driver='GPKG', layer=base_name)
    print(f"✓ GeoPackage: {gpkg_path}")

    # Shapefile
    shp_path = output_dir / f"{base_name}.shp"
    gdf.to_file(str(shp_path), driver='ESRI Shapefile')
    print(f"✓ Shapefile: {shp_path}")

    # GeoJSON
    geojson_path = output_dir / f"{base_name}.geojson"
    gdf.to_file(str(geojson_path), driver='GeoJSON')
    print(f"✓ GeoJSON: {geojson_path}")

    return {
        'gdb': gdb_written,
        'gpkg': gpkg_path,
        'shp': shp_path,
        'geojson': geojson_path
    }

# Example usage:
# export_results(gdf, "./outputs", "my_segments")

## 8. Batch Processing (Optional)

In [None]:
def batch_segment_with_text(raster_path: str, text_prompts: list, output_gdb: str,
                            min_confidence: float = 0.5, processor=None, bpe_path=None):
    """
    Batch process a raster with multiple text prompts.

    Args:
        raster_path: Path to input raster
        text_prompts: List of text prompts (e.g., ["building", "road", "tree"])
        output_gdb: Output geodatabase path. If writing it fails, a GeoPackage
            with the same stem and a ``.gpkg`` suffix is written instead.
        min_confidence: Minimum confidence threshold
        processor: Optional initialised Sam3Processor to reuse, e.g. the
            notebook's global ``processor``. If None, a new model is built.
        bpe_path: Optional BPE vocabulary path, passed to
            build_sam3_image_model only when a new model is built.

    Returns:
        GeoDataFrame (raster CRS) with columns geometry, class, score and
        mask_id. It may be empty.
    """
    raster = GeoRasterLoader(raster_path)
    try:
        if processor is None:
            build_kwargs = {} if bpe_path is None else {'bpe_path': str(bpe_path)}
            model = build_sam3_image_model(**build_kwargs)
            model.eval()
            processor = Sam3Processor(model)

        all_results = []

        for prompt in text_prompts:
            print(f"Processing: '{prompt}'")

            # Fresh inference state per prompt, as in the interactive interface
            inference_state = processor.set_image(raster.image_pil)
            output = processor.set_text_prompt(state=inference_state, prompt=prompt)

            masks = output["masks"]
            scores = output["scores"]

            # Convert each sufficiently confident mask to polygons in the raster CRS
            for i, (mask, score) in enumerate(zip(masks, scores)):
                if isinstance(score, torch.Tensor):
                    score = score.item()
                if score < min_confidence:
                    continue

                if isinstance(mask, torch.Tensor):
                    mask = mask.detach().float().cpu().numpy()
                mask = np.asarray(mask)
                if mask.ndim > 2:
                    mask = mask.squeeze()

                mask_binary = (mask > 0.5).astype(np.uint8)

                # mask= limits tracing to foreground pixels
                for geom, value in shapes(mask_binary, mask=mask_binary.astype(bool),
                                          transform=raster.transform):
                    if value == 1:
                        all_results.append({
                            'geometry': shape(geom),
                            'class': prompt,
                            'score': score,
                            'mask_id': i
                        })

        # Explicit columns + geometry keep an empty result valid (a crs without a
        # geometry column is rejected by GeoDataFrame)
        gdf = gpd.GeoDataFrame(
            all_results,
            columns=['geometry', 'class', 'score', 'mask_id'],
            geometry='geometry',
            crs=raster.crs,
        )

        if not gdf.empty:
            try:
                gdf.to_file(output_gdb, driver='OpenFileGDB', layer='segments')
            except Exception as e:
                output_gpkg = str(Path(output_gdb).with_suffix('.gpkg'))
                print(f"OpenFileGDB export failed ({e}); writing GeoPackage: {output_gpkg}")
                gdf.to_file(output_gpkg, driver='GPKG', layer='segments')

            print(f"Exported {len(gdf)} features")

        return gdf
    finally:
        raster.close()

# Example batch processing:
# gdf = batch_segment_with_text(
#     "path/to/raster.tif",
#     ["building", "road", "vegetation", "water"],
#     "./outputs/segments.gdb"
# )

## 9. Visualization of Results

In [None]:
def _geo_ring_to_pixels(raster, coords):
    """Map an (N, >=2) array of x/y raster-CRS coordinates to fractional (cols, rows) for imshow.

    Uses the inverse affine transform (no integer snapping). It subtracts 0.5
    because imshow centres pixel (c, r) on (c, r), whereas the affine maps
    pixel corners to integers.
    """
    inv = ~raster.transform
    xs, ys = coords[:, 0], coords[:, 1]
    cols = inv.a * xs + inv.b * ys + inv.c - 0.5
    rows = inv.d * xs + inv.e * ys + inv.f - 0.5
    return cols, rows


def visualize_segmentation(raster: GeoRasterLoader, gdf: gpd.GeoDataFrame, figsize=(15, 10)):
    """
    Visualize segmentation results overlaid on the raster.

    Left panel: the raster image. Right panel: the same image with each
    polygon's exterior filled, coloured by its 'prompt' value, or by one
    colour when that column is absent. If *gdf* has a CRS different from the
    raster's, it is reprojected first. Interior holes are not drawn.
    """
    fig, axes = plt.subplots(1, 2, figsize=figsize)

    # Original image
    axes[0].imshow(raster.image_array)
    axes[0].set_title("Original Raster")
    axes[0].axis('off')

    # Image with segmentation overlay
    axes[1].imshow(raster.image_array)

    if not gdf.empty:
        # Geometries must be in the raster CRS for the pixel mapping to be valid
        if gdf.crs is not None and raster.crs is not None and gdf.crs != raster.crs:
            gdf = gdf.to_crs(raster.crs)

        has_prompt = 'prompt' in gdf.columns
        unique_prompts = list(gdf['prompt'].unique()) if has_prompt else ['segment']
        colors = plt.cm.tab10(np.linspace(0, 1, len(unique_prompts)))
        color_map = dict(zip(unique_prompts, colors))

        for _, row in gdf.iterrows():
            geom = row.geometry
            if geom is None or geom.is_empty:
                continue
            if geom.geom_type == 'Polygon':
                polygons = [geom]
            elif geom.geom_type == 'MultiPolygon':
                polygons = list(geom.geoms)
            else:
                continue

            color = color_map[row['prompt'] if has_prompt else 'segment']
            for poly in polygons:
                cols, rows = _geo_ring_to_pixels(raster, np.asarray(poly.exterior.coords))
                # facecolor/edgecolor rather than color=, which conflicts with edgecolor
                axes[1].fill(cols, rows, facecolor=color, edgecolor=color, alpha=0.4, linewidth=2)

        legend_patches = [plt.Rectangle((0, 0), 1, 1, fc=color_map[p], alpha=0.6) for p in unique_prompts]
        axes[1].legend(legend_patches, unique_prompts, loc='upper right')

    axes[1].set_title("Segmentation Results")
    axes[1].axis('off')

    plt.tight_layout()
    plt.show()

# Visualize results
# visualize_segmentation(raster, gdf)

## 10. Cleanup

In [None]:
# Close raster file when done.
# This closes the rasterio dataset held by `raster`. Its image_array and
# image_pil were already read into memory when the loader was created.
raster.close()
print("Cleanup complete!")