In [None]:
# Install the third-party packages imported further down: rasterio/geopandas for
# georeferenced I/O, ultralytics for YOLO, segment-anything for SAM, and OpenCV.
# NOTE(review): mediapipe is not imported in any cell visible here — confirm it is needed.
!pip install rasterio geopandas ultralytics
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install opencv-python mediapipe

Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-opc59tia
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-opc59tia
  Resolved https://github.com/facebookresearch/segment-anything.git to commit dca509fe793f601edb92606367a655c15ac00fdf
  Preparing metadata (setup.py) ... done


In [None]:
import os
import cv2
import time
import json
import torch
import rasterio
import numpy as np
import geopandas as gpd
import matplotlib.pyplot as plt

# Dask import
try:
    import dask.bag as db
    from dask.diagnostics import ProgressBar
    DASK_AVAILABLE = True
except ImportError:
    DASK_AVAILABLE = False
    print("⚠ Dask not available, using sequential processing")

from tqdm import tqdm
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon as MPLPolygon
from shapely.geometry import Point, Polygon, MultiPolygon

from segment_anything import sam_model_registry, SamPredictor

# Set up logging
import time
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [None]:
class DetectionState:
    """Mutable record handed from one pipeline stage to the next.

    The constructor stores the three inputs and resets every derived field;
    the transitions fill those fields in as the pipeline advances.
    """

    def __init__(self, image_path, model, config):
        """Create a fresh state.

        Args:
            image_path: Path of the image to process.
            model: Detector object; the YOLO stage calls ``model.predict`` on it.
            config: Dict of settings read by the individual transitions.
        """
        # Inputs supplied by the caller
        self.image_path, self.model, self.config = image_path, model, config

        # Image data and georeferencing, filled in by the load/resize stages
        self.full_image = self.resized_image = None
        self.original_dims = self.scaled_dims = None
        self.transform = self.crs = None

        # Detection / segmentation products (each list is a distinct object)
        self.boxes_list = []
        self.segmentation_results = []
        self.detection_info = []
        self.gdf = None

        # SAM predictor, created by the SAM init stage
        self.sam_predictor = None

        # CPU-time stamp (time.process_time) used by the save stage to report duration
        self.start_time = time.process_time()

    def __repr__(self):
        """Summarise how many boxes, segments and GeoDataFrame rows are held."""
        n_boxes = len(self.boxes_list)
        n_segments = len(self.segmentation_results)
        n_rows = 0 if self.gdf is None else len(self.gdf)
        return f"DetectionState(boxes={n_boxes}, segments={n_segments}, gdf_size={n_rows})"


class StateTransition:
    """Interface for a pipeline stage: a callable that receives a state object."""

    def __call__(self, state):
        """Apply the stage to ``state``; the base class has no implementation."""
        raise NotImplementedError()


class LoadImageTransition(StateTransition):
    """Transition: Load and prepare image.

    Tries to open ``state.image_path`` with rasterio first so the affine
    transform and CRS are preserved; if that fails for any reason the file is
    read with OpenCV instead and a pixel-space transform is synthesised.

    On success the state holds ``full_image`` (uint8, 3 channels, RGB order),
    ``original_dims`` as ``(width, height)``, ``transform`` and ``crs``
    (``None`` on the OpenCV path).
    """

    def __call__(self, state):
        """Populate the image fields of ``state`` and return it.

        Raises:
            ValueError: If neither rasterio nor OpenCV can read the file.
        """
        try:
            self._load_with_rasterio(state)
        except Exception as exc:
            # Catch Exception (not a bare except) so Ctrl-C / SystemExit still
            # propagate, and show why the georeferenced read was abandoned.
            print("Loading as regular image (not GeoTIFF)...")
            print(f"  rasterio could not read the file: {exc}")
            self._load_with_opencv(state)

        return state

    @staticmethod
    def _load_with_rasterio(state):
        """Read the raster, its transform and its CRS through rasterio."""
        with rasterio.open(state.image_path) as src:
            state.original_dims = (src.width, src.height)
            state.transform = src.transform
            state.crs = src.crs

            print(f"Original image size: {state.original_dims[0]}x{state.original_dims[1]}")

            if src.count >= 3:
                # Bands 1-3 arrive as (band, row, col); reorder to (row, col, band).
                bands = np.transpose(src.read([1, 2, 3]), (1, 2, 0))
                state.full_image = np.clip(bands, 0, 255).astype(np.uint8)
            else:
                # Convert to uint8 *before* cvtColor: OpenCV rejects several
                # raster dtypes, which previously forced the fallback path.
                gray = np.clip(src.read(1), 0, 255).astype(np.uint8)
                state.full_image = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)

    @staticmethod
    def _load_with_opencv(state):
        """Read a plain image through OpenCV and attach a pixel-space transform."""
        # Imported here because the notebook's import cell does not provide
        # from_bounds; without it this path raised NameError.
        from rasterio.transform import from_bounds

        bgr = cv2.imread(state.image_path)
        if bgr is None:
            raise ValueError(f"Could not load image from {state.image_path}")

        state.full_image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        h, w = state.full_image.shape[:2]
        state.original_dims = (w, h)

        print(f"Original image size: {w}x{h}")

        # Bounds (0, 0, w, h) over w x h pixels: one unit per pixel.
        state.transform = from_bounds(0, 0, w, h, w, h)
        state.crs = None


class ResizeImageTransition(StateTransition):
    """Transition: scale the loaded image by the configured ``resolution`` factor."""

    def __call__(self, state):
        """Write ``scaled_dims`` and ``resized_image`` on ``state`` and return it."""
        resolution = state.config['resolution']

        # original_dims is (width, height); each side is scaled and truncated to int
        src_w, src_h = state.original_dims
        new_w, new_h = int(src_w * resolution), int(src_h * resolution)
        state.scaled_dims = (new_w, new_h)

        print(f"Processing at: {new_w}x{new_h} (resolution={resolution})")

        target_size = (new_w, new_h)
        state.resized_image = cv2.resize(
            state.full_image, target_size, interpolation=cv2.INTER_LINEAR
        )
        return state


class YOLODetectionTransition(StateTransition):
    """Transition: Perform YOLO detection on tiles.

    Slides a ``tile_size`` window over ``state.resized_image`` with the
    configured overlap, runs ``state.model.predict`` on each tile and stores
    the detections in ``state.boxes_list`` as dicts with ``'bbox'``
    (``[x1, y1, x2, y2]`` mapped back to original-image pixels) and
    ``'confidence'``.
    """

    def __call__(self, state):
        """Tile the resized image, detect on every tile and record the boxes.

        Raises:
            ValueError: If ``overlap`` is not smaller than ``tile_size``
                (the window would never advance).
        """
        tile_size = state.config['tile_size']
        overlap = state.config['overlap']
        conf_threshold = state.config['conf_threshold']
        iou_threshold = state.config['iou_threshold']
        max_det = state.config['max_det']
        resolution = state.config['resolution']

        new_w, new_h = state.scaled_dims
        step_size = tile_size - overlap
        if step_size <= 0:
            # Previously surfaced as an opaque "range() arg 3 must not be zero".
            raise ValueError(
                f"overlap ({overlap}) must be smaller than tile_size ({tile_size})"
            )

        tile_positions = self._tile_positions(new_w, new_h, tile_size, step_size)

        print(f"Processing {len(tile_positions)} tiles with YOLO...")

        # Always use sequential for Colab - Dask causes issues with GPU memory
        state.boxes_list = self._process_tiles_sequential(
            state, tile_positions, tile_size,
            conf_threshold, iou_threshold, max_det, resolution
        )

        print(f"YOLO detected {len(state.boxes_list)} boxes before NMS")
        return state

    @staticmethod
    def _tile_positions(width, height, tile_size, step_size):
        """Return ``(x, y, x_end, y_end)`` windows in row-major order, clipped to the image."""
        return [
            (x, y, min(x + tile_size, width), min(y + tile_size, height))
            for y in range(0, height, step_size)
            for x in range(0, width, step_size)
        ]

    def _process_tiles_sequential(self, state, tile_positions, tile_size, conf, iou, max_det, resolution):
        """Run the detector on each tile in turn and collect boxes in original-image pixels.

        Args:
            state: Pipeline state providing ``resized_image`` and ``model``.
            tile_positions: ``(x, y, x_end, y_end)`` windows in resized-image pixels.
            tile_size: Side length every tile is brought to before prediction.
            conf, iou, max_det: Forwarded unchanged to ``model.predict``.
            resolution: Factor that was used to resize the image; divided out again here.

        Returns:
            List of ``{'bbox': [x1, y1, x2, y2], 'confidence': float}`` dicts.
        """
        all_boxes = []
        if not tile_positions:
            # An empty tile list used to yield batch_size == 0 and a ValueError from range().
            return all_boxes

        # Tiles are still predicted one at a time; the grouping only chunks the progress bars.
        batch_size = min(8, len(tile_positions))

        for batch_idx in range(0, len(tile_positions), batch_size):
            batch = tile_positions[batch_idx:batch_idx + batch_size]

            for x, y, x_end, y_end in tqdm(batch, desc=f"YOLO Batch {batch_idx//batch_size + 1}"):
                win_w = x_end - x
                win_h = y_end - y
                tile = state.resized_image[y:y_end, x:x_end]

                if tile.shape[:2] != (tile_size, tile_size):
                    # Edge tiles are stretched to the full tile size; remember the
                    # per-axis factor so boxes can be mapped back afterwards.
                    tile = cv2.resize(tile, (tile_size, tile_size))
                    scale_x = win_w / tile_size
                    scale_y = win_h / tile_size
                else:
                    scale_x = scale_y = 1.0

                tile = np.ascontiguousarray(tile)

                results = state.model.predict(
                    source=tile,
                    conf=conf,
                    iou=iou,
                    max_det=max_det,
                    save_txt=False,
                    save_conf=True,
                    verbose=False
                )

                detections = results[0].boxes
                if detections is None or len(detections) == 0:
                    continue

                boxes = detections.xyxy.cpu().numpy()
                confidences = detections.conf.cpu().numpy()

                for (x1, y1, x2, y2), score in zip(boxes, confidences):
                    # tile pixels -> resized-image pixels -> original-image pixels
                    all_boxes.append({
                        'bbox': [
                            int((x1 * scale_x + x) / resolution),
                            int((y1 * scale_y + y) / resolution),
                            int((x2 * scale_x + x) / resolution),
                            int((y2 * scale_y + y) / resolution),
                        ],
                        'confidence': float(score)
                    })

        return all_boxes


class NMSTransition(StateTransition):
    """Transition: greedy Non-Maximum Suppression over ``state.boxes_list``."""

    def __call__(self, state):
        """Filter overlapping boxes unless NMS is disabled or there is nothing to filter."""
        nms_enabled = state.config.get('apply_nms', True)
        if not nms_enabled or len(state.boxes_list) == 0:
            return state

        original_count = len(state.boxes_list)
        threshold = state.config.get('nms_iou', 0.3)
        state.boxes_list = self._apply_nms(state.boxes_list, threshold)
        print(f"After NMS: {original_count} → {len(state.boxes_list)} boxes")

        return state

    @staticmethod
    def _apply_nms(boxes_list, iou_threshold):
        """Keep the highest-confidence box of every group overlapping above the threshold.

        Boxes are visited in descending confidence; each kept box suppresses
        every later box whose IoU with it exceeds ``iou_threshold``.
        """
        if len(boxes_list) == 0:
            return boxes_list

        ranked = sorted(boxes_list, key=lambda det: det['confidence'], reverse=True)
        suppressed = [False] * len(ranked)
        keep = []

        for i, winner in enumerate(ranked):
            if suppressed[i]:
                continue

            keep.append(winner)

            for j in range(i + 1, len(ranked)):
                if suppressed[j]:
                    continue
                overlap = NMSTransition._calculate_iou(winner['bbox'], ranked[j]['bbox'])
                if overlap > iou_threshold:
                    suppressed[j] = True

        return keep

    @staticmethod
    def _calculate_iou(box1, box2):
        """Intersection-over-union of two ``[x1, y1, x2, y2]`` boxes (0.0 when disjoint)."""
        left = max(box1[0], box2[0])
        top = max(box1[1], box2[1])
        right = min(box1[2], box2[2])
        bottom = min(box1[3], box2[3])

        if right < left or bottom < top:
            return 0.0

        inter_area = (right - left) * (bottom - top)
        area_a = (box1[2] - box1[0]) * (box1[3] - box1[1])
        area_b = (box2[2] - box2[0]) * (box2[3] - box2[1])
        union_area = area_a + area_b - inter_area

        if union_area > 0:
            return inter_area / union_area
        return 0.0


class SAMInitTransition(StateTransition):
    """Transition: build the SAM model and attach a predictor to the state."""

    def __call__(self, state):
        """Load the checkpoint named in the config and store a ``SamPredictor``.

        Raises:
            FileNotFoundError: If the configured checkpoint path does not exist.
        """
        cfg = state.config
        sam_checkpoint = cfg['sam_checkpoint']
        model_type = cfg.get('sam_model_type', 'vit_h')
        device = cfg.get('device', 'cuda')

        checkpoint_present = os.path.exists(sam_checkpoint)
        if not checkpoint_present:
            raise FileNotFoundError(f"SAM checkpoint not found at {sam_checkpoint}")

        print(f"\nLoading SAM model ({model_type})...")

        # Fall back to CPU when CUDA was requested but is not available
        if device == "cuda":
            if not torch.cuda.is_available():
                print("CUDA not available, using CPU")
                device = "cpu"

        build_sam = sam_model_registry[model_type]
        sam = build_sam(checkpoint=sam_checkpoint)
        sam.to(device=device)
        state.sam_predictor = SamPredictor(sam)

        print(f"SAM model loaded on {device}")
        return state


class SAMSegmentationTransition(StateTransition):
    """Transition: Perform SAM segmentation with optimization.

    Every box in ``state.boxes_list`` is used as a box prompt for
    ``state.sam_predictor``. Images whose longer side exceeds
    ``config['sam_max_image_size']`` (default 4096) are processed crop by
    crop; smaller images are embedded once in full. Results are stored in
    ``state.segmentation_results`` as dicts with the keys ``'mask'``,
    ``'coords'``, ``'bbox'``, ``'confidence'`` and ``'sam_score'``.
    """

    def __call__(self, state):
        """Segment all detected boxes and store the results on ``state``."""
        print(f"Segmenting {len(state.boxes_list)} objects with SAM...")

        if not state.boxes_list:
            # No prompts: skip the image embedding entirely. The full-image path
            # used to crash here with "range() arg 3 must not be zero".
            state.segmentation_results = []
        else:
            h, w = state.full_image.shape[:2]
            max_image_size = state.config.get('sam_max_image_size', 4096)
            use_chunked = max(h, w) > max_image_size

            if use_chunked:
                print(f"⚠ Large image ({w}x{h}), using optimized chunked processing")
                state.segmentation_results = self._segment_chunked_optimized(state, max_image_size)
            else:
                print(f"Setting full image for SAM ({w}x{h})")
                state.sam_predictor.set_image(state.full_image)
                state.segmentation_results = self._segment_full_optimized(state)

        print(f"SAM segmentation complete: {len(state.segmentation_results)} masks generated")
        return state

    def _segment_full_optimized(self, state):
        """Prompt SAM with every box against the already-embedded full image.

        Expects ``set_image`` to have been called by the caller. Boxes are
        visited in ascending ``y1`` order, so results come back in that order.
        """
        results = []

        sorted_boxes = sorted(state.boxes_list, key=lambda x: x['bbox'][1])

        # max(1, ...) keeps range() valid when the list is empty.
        # The grouping only chunks the progress bars; prediction is per box.
        batch_size = max(1, min(50, len(sorted_boxes)))

        for i in range(0, len(sorted_boxes), batch_size):
            batch = sorted_boxes[i:i + batch_size]

            for box_info in tqdm(batch, desc=f"SAM Batch {i//batch_size + 1}"):
                bbox = box_info['bbox']
                confidence = box_info['confidence']
                input_box = np.array(bbox)

                masks, scores, _ = state.sam_predictor.predict(
                    point_coords=None,
                    point_labels=None,
                    box=input_box[None, :],
                    multimask_output=False,
                )

                mask = masks[0]
                mask_coords = self._mask_to_polygon(mask)

                if mask_coords is not None and len(mask_coords) >= 3:
                    # NOTE(review): the full-size mask is kept per detection; no
                    # stage in this file reads it — confirm before dropping it.
                    results.append({
                        'mask': mask,
                        'coords': mask_coords,
                        'bbox': bbox,
                        'confidence': confidence,
                        'sam_score': float(scores[0])
                    })

        return results

    def _segment_chunked_optimized(self, state, max_size):
        """Segment boxes crop by crop so SAM never embeds the whole image.

        Boxes are bucketed by the grid cell of their top-left corner; each
        bucket shares one padded crop and one ``set_image`` call. The cell size
        comes from the optional ``config['sam_group_grid_size']`` (default 200
        pixels). Polygon coordinates are shifted back into full-image pixels.

        Args:
            state: Pipeline state providing ``full_image``, ``boxes_list``,
                ``config`` and ``sam_predictor``.
            max_size: Crop side length above which the padding is reduced.

        Returns:
            List of result dicts; ``'mask'`` is always ``None`` on this path.
        """
        h, w = state.full_image.shape[:2]
        results = []
        failed = 0
        last_error = None

        grid_size = max(1, int(state.config.get('sam_group_grid_size', 200)))

        # Bucket boxes by the grid cell of their top-left corner
        box_groups = {}
        for box_info in state.boxes_list:
            bbox = box_info['bbox']
            key = (bbox[0] // grid_size, bbox[1] // grid_size)
            box_groups.setdefault(key, []).append(box_info)

        print(f"Grouped {len(state.boxes_list)} boxes into {len(box_groups)} spatial groups")

        for group_boxes in tqdm(box_groups.values(), desc="SAM Groups"):
            # Extent of all boxes in the group
            min_x = min(b['bbox'][0] for b in group_boxes)
            min_y = min(b['bbox'][1] for b in group_boxes)
            max_x = max(b['bbox'][2] for b in group_boxes)
            max_y = max(b['bbox'][3] for b in group_boxes)

            # Pad the extent by 25% per side (factor 1.5 overall)
            box_w = max_x - min_x
            box_h = max_y - min_y
            pad_factor = 1.5
            pad_w = int(box_w * (pad_factor - 1) / 2)
            pad_h = int(box_h * (pad_factor - 1) / 2)

            crop_x1 = max(0, min_x - pad_w)
            crop_y1 = max(0, min_y - pad_h)
            crop_x2 = min(w, max_x + pad_w)
            crop_y2 = min(h, max_y + pad_h)

            if (crop_x2 - crop_x1) > max_size or (crop_y2 - crop_y1) > max_size:
                # Shrink the padding, but never below zero: a negative pad would
                # cut the crop inside the boxes it is supposed to contain.
                pad_w = max(0, min(pad_w, (max_size - box_w) // 2))
                pad_h = max(0, min(pad_h, (max_size - box_h) // 2))
                crop_x1 = max(0, min_x - pad_w)
                crop_y1 = max(0, min_y - pad_h)
                crop_x2 = min(w, max_x + pad_w)
                crop_y2 = min(h, max_y + pad_h)

            crop = state.full_image[crop_y1:crop_y2, crop_x1:crop_x2]
            if crop.size == 0:
                # Zero-area group (degenerate boxes): nothing to embed.
                failed += len(group_boxes)
                continue

            # One embedding per group, shared by all of its boxes
            state.sam_predictor.set_image(crop)

            for box_info in group_boxes:
                bbox = box_info['bbox']
                confidence = box_info['confidence']

                # Box prompt expressed in crop coordinates
                input_box = np.array([
                    bbox[0] - crop_x1,
                    bbox[1] - crop_y1,
                    bbox[2] - crop_x1,
                    bbox[3] - crop_y1
                ])

                try:
                    masks, scores, _ = state.sam_predictor.predict(
                        point_coords=None,
                        point_labels=None,
                        box=input_box[None, :],
                        multimask_output=False,
                    )

                    mask_coords_crop = self._mask_to_polygon(masks[0])

                    if mask_coords_crop is not None and len(mask_coords_crop) >= 3:
                        # Shift the polygon from crop space back to full-image space
                        mask_coords = mask_coords_crop + np.array([crop_x1, crop_y1])

                        results.append({
                            'mask': None,
                            'coords': mask_coords,
                            'bbox': bbox,
                            'confidence': confidence,
                            'sam_score': float(scores[0])
                        })
                except Exception as e:
                    # Stay best-effort per box, but remember the failure so it
                    # can be reported instead of vanishing silently.
                    failed += 1
                    last_error = e
                    continue

        if failed:
            print(f"⚠ SAM skipped {failed} boxes (last error: {last_error})")

        return results

    @staticmethod
    def _mask_to_polygon(mask):
        """Return the simplified outline of the largest external contour of ``mask``.

        Returns:
            An ``(N, 2)`` array of ``(x, y)`` points, or ``None`` when the mask
            has no contour. Callers check ``N >= 3`` themselves.
        """
        contours, _ = cv2.findContours(
            mask.astype(np.uint8),
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE
        )

        if len(contours) == 0:
            return None

        contour = max(contours, key=cv2.contourArea)
        # Simplification tolerance: 0.1% of the contour perimeter
        epsilon = 0.001 * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)

        return approx.reshape(-1, 2)


class CreateGeoDataFrameTransition(StateTransition):
    """Transition: Create GeoDataFrame from segmentation results.

    Each result's pixel polygon is measured (``area_pixels``), converted to
    map coordinates with ``state.transform`` and collected into ``state.gdf``
    together with its confidence, SAM score, pixel centroid and vertex count.
    The same per-detection attributes are kept in ``state.detection_info``.
    """

    def __call__(self, state):
        """Build ``state.gdf`` / ``state.detection_info`` and return ``state``."""
        segmentation_results = state.segmentation_results
        print(f"Creating GeoDataFrame from {len(segmentation_results)} results...")

        all_polygons = []
        detection_info = []
        skipped = 0

        # max(1, ...) keeps range() valid when there are no results; the
        # grouping itself only chunks the progress bars.
        batch_size = max(1, min(100, len(segmentation_results)))

        for i in range(0, len(segmentation_results), batch_size):
            batch = segmentation_results[i:i + batch_size]

            for result in tqdm(batch, desc=f"Polygons Batch {i//batch_size + 1}"):
                coords = result['coords']

                try:
                    polygon_shapely = Polygon(coords)
                    if not polygon_shapely.is_valid:
                        # buffer(0) is the usual repair for self-intersections
                        polygon_shapely = polygon_shapely.buffer(0)
                    area_pixels = polygon_shapely.area
                except Exception:
                    skipped += 1
                    continue

                # One vectorised call per polygon; xy takes (rows, cols) = (y, x).
                # Kept outside the try so a broken transform fails loudly.
                geo_xs, geo_ys = rasterio.transform.xy(
                    state.transform, coords[:, 1].tolist(), coords[:, 0].tolist()
                )

                try:
                    geo_polygon = Polygon(list(zip(geo_xs, geo_ys)))
                    if not geo_polygon.is_valid:
                        geo_polygon = geo_polygon.buffer(0)
                except Exception:
                    skipped += 1
                    continue

                if geo_polygon.is_empty:
                    # The repair can collapse a degenerate outline to nothing;
                    # an empty geometry is not a usable feature.
                    skipped += 1
                    continue

                all_polygons.append(geo_polygon)
                detection_info.append({
                    'centroid_x': coords[:, 0].mean(),
                    'centroid_y': coords[:, 1].mean(),
                    'area_pixels': area_pixels,
                    'confidence': result['confidence'],
                    'sam_score': result['sam_score'],
                    'num_points': len(coords)
                })

        if skipped:
            print(f"⚠ Skipped {skipped} results with unusable geometry")

        if len(all_polygons) == 0:
            state.gdf = gpd.GeoDataFrame()
            state.detection_info = []
            return state

        # Explicit column order; the CSV written later follows it
        state.gdf = gpd.GeoDataFrame({
            'geometry': all_polygons,
            'area_pixels': [d['area_pixels'] for d in detection_info],
            'confidence': [d['confidence'] for d in detection_info],
            'sam_score': [d['sam_score'] for d in detection_info],
            'centroid_x': [d['centroid_x'] for d in detection_info],
            'centroid_y': [d['centroid_y'] for d in detection_info],
            'num_points': [d['num_points'] for d in detection_info]
        }, crs=state.crs)

        state.detection_info = detection_info
        print(f"✓ Created GeoDataFrame with {len(state.gdf)} features")
        return state


class VisualizationTransition(StateTransition):
    """Transition: Visualize results.

    Draws every polygon in ``state.segmentation_results`` over the image
    (downsampled when its longer side exceeds ``config['max_viz_size']``,
    default 2048) and saves the figure as
    ``<output_dir>/<image_name>_segmentation.png``.
    """

    def __call__(self, state):
        """Render and save the overlay figure, then return ``state`` unchanged."""
        output_dir = state.config['output_dir']
        height, width = state.full_image.shape[:2]

        max_viz_size = state.config.get('max_viz_size', 2048)

        if max(width, height) > max_viz_size:
            print(f"⚠ Large image ({width}x{height}), downsampling for visualization...")
            scale = max_viz_size / max(width, height)
            viz_width = int(width * scale)
            viz_height = int(height * scale)
            viz_image = cv2.resize(state.full_image, (viz_width, viz_height),
                                  interpolation=cv2.INTER_AREA)
            print(f"  Visualization size: {viz_width}x{viz_height}")
        else:
            viz_image = state.full_image
            viz_width, viz_height = width, height
            scale = 1.0

        save_dpi = state.config.get('save_dpi', 150)
        image_name = state.config.get('image_name', 'output')
        output_path = f'{output_dir}/{image_name}_segmentation.png'

        # This stage runs before the save stage, so it cannot rely on the
        # output directory having been created elsewhere.
        os.makedirs(output_dir, exist_ok=True)

        fig, ax = plt.subplots(1, 1, figsize=(12, 8), dpi=150)
        try:
            ax.imshow(viz_image)
            ax.set_title(f'YOLO + SAM Segmentation - {len(state.segmentation_results)} Trees Detected',
                         fontsize=14, fontweight='bold')
            ax.set_axis_off()

            if state.segmentation_results:
                print(f"  Drawing {len(state.segmentation_results)} detections...")

                for result in state.segmentation_results:
                    # Polygon vertices are in full-image pixels; bring them to figure scale
                    coords = result['coords'] * scale
                    coords_closed = np.vstack([coords, coords[0]])

                    ax.plot(coords_closed[:, 0], coords_closed[:, 1],
                           'lime', linewidth=1, alpha=0.7)

                    polygon = MPLPolygon(coords, closed=True,
                                       facecolor='lime', alpha=0.2,
                                       edgecolor='lime', linewidth=1)
                    ax.add_patch(polygon)

                avg_conf = np.mean([r['confidence'] for r in state.segmentation_results])
                avg_sam = np.mean([r['sam_score'] for r in state.segmentation_results])

                stats_text = (f"Total Trees: {len(state.segmentation_results)}\n"
                             f"Avg YOLO Conf: {avg_conf:.3f}\n"
                             f"Avg SAM Score: {avg_sam:.3f}")

                ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
                       fontsize=12, verticalalignment='top',
                       bbox=dict(boxstyle="round", fc="white", alpha=0.9, pad=0.5))

            ax.set_xlim(0, viz_width)
            # Inverted y-limits keep the image origin at the top-left
            ax.set_ylim(viz_height, 0)

            fig.tight_layout()
            fig.savefig(output_path, dpi=save_dpi, bbox_inches='tight')
        finally:
            # Close only this figure (not every open figure), even if
            # drawing or saving raised.
            plt.close(fig)

        print(f"✓ Visualization saved to {output_path}")
        return state


class SaveResultsTransition(StateTransition):
    """Transition: write the detections to GeoJSON, CSV and a JSON statistics file."""

    def __call__(self, state):
        """Persist the results under ``config['output_dir']`` and report the elapsed time.

        Returns early (writing nothing) when ``state.gdf`` is missing or empty.
        Each individual write is best-effort: a failure is printed, not raised.
        """
        gdf = state.gdf
        if gdf is None or len(gdf) == 0:
            print("No detections to save")
            return state

        output_dir = state.config['output_dir']
        image_name = state.config.get('image_name', 'output')

        # Make sure there is somewhere to write to
        os.makedirs(output_dir, exist_ok=True)

        # GeoJSON is only written when the frame carries a CRS
        if gdf.crs is None:
            print("⚠ Skipping GeoJSON (no CRS)")
        else:
            geojson_path = f'{output_dir}/{image_name}_detections.geojson'
            try:
                gdf.to_file(geojson_path, driver='GeoJSON')
                print(f"✓ Saved GeoJSON to {geojson_path}")
            except Exception as e:
                print(f"⚠ Failed to save GeoJSON: {e}")

        # CSV holds the attribute columns without the geometry
        csv_path = f'{output_dir}/{image_name}_detections.csv'
        try:
            attributes = gdf.drop(columns='geometry')
            attributes.to_csv(csv_path, index=False)
            print(f"✓ Saved CSV to {csv_path}")
        except Exception as e:
            print(f"⚠ Failed to save CSV: {e}")

        # Summary statistics over the per-detection records
        info = state.detection_info
        if info:
            areas = [d['area_pixels'] for d in info]
            yolo_scores = [d['confidence'] for d in info]
            sam_scores = [d['sam_score'] for d in info]

            stats = {
                'total_detections': len(gdf),
                'average_area': float(np.mean(areas)),
                'median_area': float(np.median(areas)),
                'average_yolo_confidence': float(np.mean(yolo_scores)),
                'average_sam_score': float(np.mean(sam_scores))
            }

            stats_path = f'{output_dir}/{image_name}_statistics.json'
            try:
                with open(stats_path, 'w') as f:
                    json.dump(stats, f, indent=2)
                print(f"✓ Saved statistics to {stats_path}")
            except Exception as e:
                print(f"⚠ Failed to save statistics: {e}")

        # Difference of time.process_time() readings since the state was created
        processing_time = time.process_time() - state.start_time
        print(f"✓ Total processing time: {processing_time:.2f} seconds")

        return state


class DetectionPipeline:
    """Runs a fixed sequence of transitions, feeding each one the previous result."""

    def __init__(self, transitions):
        """Remember the ordered list of callables that make up the pipeline."""
        self.transitions = transitions

    def run(self, initial_state):
        """Apply every transition in order, printing a banner before each.

        Returns:
            Whatever the last transition returned (``initial_state`` itself
            when there are no transitions).
        """
        rule = '=' * 60
        state = initial_state

        for position, transition in enumerate(self.transitions, start=1):
            stage_name = transition.__class__.__name__
            print(f"\n{rule}")
            print(f"Transition {position}/{len(self.transitions)}: {stage_name}")
            print(rule)
            state = transition(state)

        return state


def create_detection_pipeline(config):
    """Assemble the standard nine-stage detection pipeline.

    ``config`` is accepted for interface symmetry but is not read here; the
    stages take their settings from the state they are run on.
    """
    stage_types = (
        LoadImageTransition,
        ResizeImageTransition,
        YOLODetectionTransition,
        NMSTransition,
        SAMInitTransition,
        SAMSegmentationTransition,
        CreateGeoDataFrameTransition,
        VisualizationTransition,
        SaveResultsTransition,
    )

    return DetectionPipeline([stage_type() for stage_type in stage_types])

if __name__ == "__main__":
    # Helpers needed only by this driver block
    import glob
    import traceback

    # Run configuration (tuned for Colab)
    config = {
        # Tiling / detection
        'tile_size': 1280,
        'overlap': 128,
        'resolution': 0.5,
        'conf_threshold': 0.20,
        'iou_threshold': 0.45,
        'max_det': 5000,
        'apply_nms': False,
        'nms_iou': 0.3,

        # SAM
        'sam_checkpoint': '/content/drive/MyDrive/AGRI/TreeCrown_Segmentation/models/segmentation/sam_vit_b_01ec64.pth',
        'sam_model_type': 'vit_b',
        'device': 'cuda',
        'sam_max_image_size': 4096,

        # Optimisation knobs
        'use_dask': True,
        'max_viz_size': 2048,
        'save_dpi': 150,

        # Where and under which name results are written
        'output_dir': '/content/drive/MyDrive/AGRI/TreeCrown_Segmentation/output',
        'image_name': 'test_optimized'
    }

    # The output directory must exist before any stage writes to it
    os.makedirs(config['output_dir'], exist_ok=True)

    # Detector weights
    from ultralytics import YOLO
    print("Loading YOLO model...")
    yolo_model = YOLO('/content/drive/MyDrive/AGRI/TreeCrown_Segmentation/models/detection/yolo11n-100epoch-v2.pt')
    print("✓ YOLO model loaded")

    # Starting state: input image, detector and settings
    initial_state = DetectionState(
        image_path='/content/drive/MyDrive/AGRI/TreeCrown_Segmentation/ortho_872.tif',
        model=yolo_model,
        config=config
    )

    # Build the pipeline and run it; any failure is reported with a traceback
    pipeline = create_detection_pipeline(config)

    try:
        final_state = pipeline.run(initial_state)
        print(f"\n{'='*60}")
        print("✓ Pipeline completed successfully!")
        print(f"Final state: {final_state}")

        has_detections = final_state.gdf is not None and len(final_state.gdf) > 0
        if has_detections:
            print(f"\nSummary:")
            print(f"  Detected trees: {len(final_state.gdf)}")
            print(f"  Output directory: {config['output_dir']}")
            print(f"  Output files:")
            # List whatever files of each expected type were produced for this image
            for ext in ('.geojson', '.csv', '.json', '.png'):
                file_path = f"{config['output_dir']}/{config['image_name']}_*{ext}"
                files = glob.glob(file_path)
                for f in files:
                    print(f"    - {os.path.basename(f)}")

    except Exception as e:
        print(f"\n{'='*60}")
        print(f"✗ Pipeline failed with error: {str(e)}")
        traceback.print_exc()

    print(f"{'='*60}")

Loading YOLO model...
✓ YOLO model loaded

Transition 1/9: LoadImageTransition
Original image size: 40244x25007

Transition 2/9: ResizeImageTransition
Processing at: 20122x12503 (resolution=0.5)

Transition 3/9: YOLODetectionTransition
Processing 198 tiles with YOLO...


YOLO Batch 1: 100%|██████████| 8/8 [00:00<00:00,  9.78it/s]
YOLO Batch 2: 100%|██████████| 8/8 [00:00<00:00, 73.33it/s]
YOLO Batch 3: 100%|██████████| 8/8 [00:00<00:00, 65.50it/s]
YOLO Batch 4: 100%|██████████| 8/8 [00:00<00:00, 76.42it/s]
YOLO Batch 5: 100%|██████████| 8/8 [00:00<00:00, 78.19it/s]
YOLO Batch 6: 100%|██████████| 8/8 [00:00<00:00, 76.05it/s]
YOLO Batch 7: 100%|██████████| 8/8 [00:00<00:00, 74.30it/s]
YOLO Batch 8: 100%|██████████| 8/8 [00:00<00:00, 78.88it/s]
YOLO Batch 9: 100%|██████████| 8/8 [00:00<00:00, 74.04it/s]
YOLO Batch 10: 100%|██████████| 8/8 [00:00<00:00, 74.79it/s]
YOLO Batch 11: 100%|██████████| 8/8 [00:00<00:00, 80.26it/s]
YOLO Batch 12: 100%|██████████| 8/8 [00:00<00:00, 63.28it/s]
YOLO Batch 13: 100%|██████████| 8/8 [00:00<00:00, 77.17it/s]
YOLO Batch 14: 100%|██████████| 8/8 [00:00<00:00, 75.67it/s]
YOLO Batch 15: 100%|██████████| 8/8 [00:00<00:00, 78.17it/s]
YOLO Batch 16: 100%|██████████| 8/8 [00:00<00:00, 75.13it/s]
YOLO Batch 17: 100%|██████████| 8

YOLO detected 2221 boxes before NMS

Transition 4/9: NMSTransition

Transition 5/9: SAMInitTransition

Loading SAM model (vit_b)...
SAM model loaded on cuda

Transition 6/9: SAMSegmentationTransition
Segmenting 2221 objects with SAM...
⚠ Large image (40244x25007), using optimized chunked processing
Grouped 2221 boxes into 1896 spatial groups


SAM Groups: 100%|██████████| 1896/1896 [13:38<00:00,  2.32it/s]


SAM segmentation complete: 2221 masks generated

Transition 7/9: CreateGeoDataFrameTransition
Creating GeoDataFrame from 2221 results...


Polygons Batch 1: 100%|██████████| 100/100 [00:00<00:00, 101.73it/s]
Polygons Batch 2: 100%|██████████| 100/100 [00:01<00:00, 92.82it/s]
Polygons Batch 3: 100%|██████████| 100/100 [00:00<00:00, 152.99it/s]
Polygons Batch 4: 100%|██████████| 100/100 [00:00<00:00, 149.78it/s]
Polygons Batch 5: 100%|██████████| 100/100 [00:00<00:00, 167.43it/s]
Polygons Batch 6: 100%|██████████| 100/100 [00:00<00:00, 158.28it/s]
Polygons Batch 7: 100%|██████████| 100/100 [00:00<00:00, 141.83it/s]
Polygons Batch 8: 100%|██████████| 100/100 [00:00<00:00, 147.46it/s]
Polygons Batch 9: 100%|██████████| 100/100 [00:00<00:00, 139.86it/s]
Polygons Batch 10: 100%|██████████| 100/100 [00:00<00:00, 136.82it/s]
Polygons Batch 11: 100%|██████████| 100/100 [00:00<00:00, 147.94it/s]
Polygons Batch 12: 100%|██████████| 100/100 [00:00<00:00, 145.34it/s]
Polygons Batch 13: 100%|██████████| 100/100 [00:00<00:00, 150.25it/s]
Polygons Batch 14: 100%|██████████| 100/100 [00:00<00:00, 140.96it/s]
Polygons Batch 15: 100%|██████

✓ Created GeoDataFrame with 2221 features

Transition 8/9: VisualizationTransition
⚠ Large image (40244x25007), downsampling for visualization...
  Visualization size: 2048x1272
  Drawing 2221 detections...
✓ Visualization saved to /content/drive/MyDrive/AGRI/TreeCrown_Segmentation/output/test_optimized_segmentation.png

Transition 9/9: SaveResultsTransition
✓ Saved GeoJSON to /content/drive/MyDrive/AGRI/TreeCrown_Segmentation/output/test_optimized_detections.geojson
✓ Saved CSV to /content/drive/MyDrive/AGRI/TreeCrown_Segmentation/output/test_optimized_detections.csv
✓ Saved statistics to /content/drive/MyDrive/AGRI/TreeCrown_Segmentation/output/test_optimized_statistics.json
✓ Total processing time: 881.27 seconds

✓ Pipeline completed successfully!
Final state: DetectionState(boxes=2221, segments=2221, gdf_size=2221)

Summary:
  Detected trees: 2221
  Output directory: /content/drive/MyDrive/AGRI/TreeCrown_Segmentation/output
  Output files:
    - test_optimized_detections.geojson
 