In [None]:
import os
import numpy as np
import torch
import open3d as o3d
import open3d.ml as _ml3d
import open3d.ml.torch as ml3d

def detect_grass_in_ply_simple(ply_file_path, model_path, green_sensitivity=0.1, ground_percentile=25, confidence_threshold=0.6):
    """
    Detect grass in a PLY point cloud with a rule-based colour/height heuristic.

    A point is labelled as grass when its green channel exceeds both red and
    blue by more than ``green_sensitivity``, its Z coordinate lies below the
    ``ground_percentile`` percentile of all Z values, and the mean of its
    green and height confidences exceeds ``confidence_threshold``.

    Args:
        ply_file_path (str): Path to your .ply file
        model_path (str): Path to the pretrained model weights (not used by
            this function; kept for future use)
        green_sensitivity (float): How much greener a point needs to be (0.0-1.0)
        ground_percentile (int): Percentile for ground detection (0-100)
        confidence_threshold (float): Confidence threshold for grass detection (0.0-1.0)

    Returns:
        tuple: ``(predicted_labels, grass_indices, confidence_scores)``.
            ``predicted_labels`` holds one int per point (6 = grass,
            1 = below the ground threshold but not grass, 0 = everything
            else), ``grass_indices`` holds the integer indices of the grass
            points and ``confidence_scores`` holds one float per point.
            ``(None, None, None)`` is returned when the cloud has no points.
    """
    print("Loading point cloud...")
    pcd = o3d.io.read_point_cloud(ply_file_path)

    if len(pcd.points) == 0:
        print("Error: Could not load point cloud or point cloud is empty")
        return None, None, None

    points = np.asarray(pcd.points)

    # Handle colors; clouds without per-point colors get all-zero RGB so the
    # feature matrix keeps its 6 columns.
    if pcd.has_colors():
        colors = np.asarray(pcd.colors)
        features = np.concatenate([points, colors], axis=1)
        print(f"Using XYZ + RGB features ({features.shape[1]}D)")
    else:
        print("No colors found. Adding dummy RGB values...")
        colors = np.zeros((len(points), 3))
        features = np.concatenate([points, colors], axis=1)

    print(f"Point cloud shape: {points.shape}")
    print(f"Features shape: {features.shape}")

    # Enhanced rule-based grass detection with confidence
    print("Using enhanced rule-based grass detection...")
    print(f"Settings: green_sensitivity={green_sensitivity}, ground_percentile={ground_percentile}, confidence_threshold={confidence_threshold}")

    # Initialize all points as unclassified
    predicted_labels = np.zeros(len(points), dtype=int)
    confidence_scores = np.zeros(len(points), dtype=float)

    # All-zero colors (including the dummy values above) carry no usable
    # information for the green test.
    if len(colors) > 0 and np.any(colors > 0):
        red, green, blue = colors[:, 0], colors[:, 1], colors[:, 2]

        # Green dominance: how far green exceeds the stronger of red/blue
        green_dominance = green - np.maximum(red, blue)
        green_mask = green_dominance > green_sensitivity

        # Ground detection based on Z coordinate
        z_coords = points[:, 2]
        ground_threshold = np.percentile(z_coords, ground_percentile)
        ground_mask = z_coords < ground_threshold

        # Confidence based on how green and how close to the lowest point
        green_confidence = np.clip(green_dominance / 0.3, 0, 1)
        height_range = ground_threshold - np.min(z_coords)
        if height_range > 0:
            height_confidence = np.clip((ground_threshold - z_coords) / height_range, 0, 1)
        else:
            # Threshold equals the lowest Z (flat cloud or percentile 0):
            # no point lies below it, and dividing would produce NaN.
            height_confidence = np.zeros(len(points), dtype=float)
        combined_confidence = (green_confidence + height_confidence) / 2

        # Apply masks and confidence threshold
        grass_candidates = green_mask & ground_mask
        high_confidence_grass = grass_candidates & (combined_confidence > confidence_threshold)

        # Assign labels
        predicted_labels[high_confidence_grass] = 6  # Vegetation class
        predicted_labels[ground_mask & ~high_confidence_grass] = 1  # Ground class
        confidence_scores = combined_confidence

        grass_indices = np.where(high_confidence_grass)[0]

        print("Enhanced detection results:")
        print(f"  - Total points: {len(points):,}")
        print(f"  - Ground points: {np.sum(ground_mask):,}")
        print(f"  - High-confidence grass points: {len(grass_indices):,}")
        print(f"  - Grass percentage: {len(grass_indices)/len(points)*100:.2f}%")
        if len(grass_indices) > 0:
            print(f"  - Average grass confidence: {np.mean(combined_confidence[grass_indices]):.3f}")
        else:
            print("  - No grass detected")

    else:
        print("No color information available for rule-based detection")
        # Integer dtype so the result is always usable as an index array
        grass_indices = np.array([], dtype=int)

    return predicted_labels, grass_indices, confidence_scores

def visualize_results(pcd, labels, class_names, grass_indices, confidence_scores=None):
    """Visualize the segmentation results with proper class colors.

    Overwrites ``pcd.colors`` with one color per class label, prints the
    class distribution and color legend, then opens an Open3D window.

    Args:
        pcd: Open3D point cloud; its colors are replaced in place.
        labels: integer class label per point, used to index the 8-entry
            color table below.
        class_names: class names indexed by label, used for the printed
            statistics and legend.
        grass_indices: indices of the points detected as grass.
        confidence_scores: optional per-point confidence; when given, grass
            colors are blended towards a darker green as confidence drops.
    """
    # Create color map for Toronto3D classes
    colors = np.array([
        [0.5, 0.5, 0.5],  # 0: Unclassified - gray
        [0.6, 0.4, 0.2],  # 1: Ground - brown
        [1.0, 0.0, 0.0],  # 2: Building - red
        [0.0, 0.0, 1.0],  # 3: Utility line - blue
        [1.0, 1.0, 0.0],  # 4: Pole - yellow
        [0.0, 0.8, 0.0],  # 5: Tree - dark green
        [0.0, 1.0, 0.0],  # 6: Vegetation/Grass - bright green
        [1.0, 0.0, 1.0],  # 7: Vehicle - magenta
    ])

    # Assign colors based on labels (fancy indexing returns a fresh array)
    point_colors = colors[labels]

    # Blend grass colors with a darker green in proportion to (1 - confidence).
    # One vectorized assignment replaces a Python loop over every grass point.
    if confidence_scores is not None and len(grass_indices) > 0:
        grass_idx = np.asarray(grass_indices, dtype=int)
        weight = np.asarray(confidence_scores, dtype=float)[grass_idx][:, None]
        low_confidence_color = np.array([0.0, 0.5, 0.0])
        point_colors[grass_idx] = point_colors[grass_idx] * weight + low_confidence_color * (1 - weight)

    pcd.colors = o3d.utility.Vector3dVector(point_colors)

    # Print class statistics
    print("\nClass distribution:")
    unique_labels, counts = np.unique(labels, return_counts=True)
    for label, count in zip(unique_labels, counts):
        if label < len(class_names):
            percentage = count / len(labels) * 100
            print(f"  {class_names[label]}: {count:,} points ({percentage:.1f}%)")

    # Print the legend, then open the viewer
    print("\nDisplaying segmentation results...")
    print("Class colors:")
    for i, name in enumerate(class_names):
        if i < len(colors):
            color_str = f"RGB({colors[i][0]:.1f}, {colors[i][1]:.1f}, {colors[i][2]:.1f})"
            print(f"  {name}: {color_str}")
    print("\nPress 'Q' to close the visualization")

    # custom_draw applies the point size and background color on each callback.
    # A plain o3d.visualization.draw_geometries([pcd], ...) call is the
    # alternative when no render options need to be set.
    o3d.visualization.draw_geometries_with_animation_callback([pcd], custom_draw)

def custom_draw(vis):
    """Visualizer callback that applies the script's render settings.

    Sets a black background and a point size of 1.0 on the render option
    object obtained from ``vis``.

    Returns:
        bool: always False.
        NOTE(review): presumably this tells Open3D that no geometry refresh
        is needed - confirm against the Open3D callback documentation.
    """
    render_option = vis.get_render_option()
    render_option.background_color = np.asarray([0, 0, 0])
    render_option.point_size = 1.0
    return False

def print_detection_settings():
    """Print a short reference of the tunable detection parameters.

    Covers ``green_sensitivity``, ``ground_percentile`` and
    ``confidence_threshold`` with their ranges, defaults and effect.
    """
    help_lines = (
        "\nAvailable detection settings:",
        "  green_sensitivity: 0.0-1.0 (default 0.1)",
        "    - Higher values = stricter green requirement",
        "    - Lower values = more permissive green detection",
        "  ground_percentile: 0-100 (default 25)",
        "    - Percentage of lowest points considered 'ground'",
        "    - Higher values = more points considered ground level",
        "  confidence_threshold: 0.0-1.0 (default 0.6)",
        "    - Minimum confidence score for grass classification",
        "    - Higher values = more conservative grass detection",
    )
    # A single write produces the same text as one print call per line
    print("\n".join(help_lines))

# Script entry point: run the rule-based grass detection on one PLY file and
# show the labelled result in an Open3D window.
if __name__ == "__main__":
    # Configuration
    ply_file_path = "final_v2.ply"  # Replace with your PLY file path
    model_path = "./logs/randlanet_toronto3d_202201071330utc.pth"

    # Detection parameters - adjust these for different results!
    GREEN_SENSITIVITY = 0.0      # How green does a point need to be? (0.0-1.0)
    GROUND_PERCENTILE = 50       # What % of lowest points are considered ground? (0-100)
    CONFIDENCE_THRESHOLD = 0.0   # Minimum confidence for grass detection (0.0-1.0)

    print("Enhanced Grass Detection with Confidence Settings")
    print("=" * 50)
    print_detection_settings()
    print("\nCurrent settings:")
    print(f"  Green sensitivity: {GREEN_SENSITIVITY}")
    print(f"  Ground percentile: {GROUND_PERCENTILE}")
    print(f"  Confidence threshold: {CONFIDENCE_THRESHOLD}")

    # Stop early when the input file is missing
    if not os.path.exists(ply_file_path):
        print(f"\nError: PLY file not found at {ply_file_path}")
        print("Please update the ply_file_path variable with the correct path to your PLY file")
        # `exit` is an interactive helper added by the `site` module and is
        # not guaranteed to exist; raising SystemExit works everywhere.
        raise SystemExit(1)

    print(f"\nProcessing: {ply_file_path}")
    print("Using enhanced rule-based detection with confidence scoring...")

    # Run enhanced grass detection
    predicted_labels, grass_indices, confidence_scores = detect_grass_in_ply_simple(
        ply_file_path,
        model_path,
        green_sensitivity=GREEN_SENSITIVITY,
        ground_percentile=GROUND_PERCENTILE,
        confidence_threshold=CONFIDENCE_THRESHOLD
    )

    # detect_grass_in_ply_simple returns None labels when the cloud is empty
    if predicted_labels is not None:
        print("\nGrass detection completed successfully!")

        # Toronto3D class names
        class_names = [
            'Unclassified',    # 0 - gray
            'Ground',          # 1 - brown
            'Building',        # 2 - red
            'Utility line',    # 3 - blue
            'Pole',           # 4 - yellow
            'Tree',           # 5 - dark green
            'Vegetation',     # 6 - bright green (grass)
            'Vehicle'         # 7 - magenta
        ]

        # The detector does not return the cloud, so load it again for display
        print("Loading point cloud for visualization...")
        pcd = o3d.io.read_point_cloud(ply_file_path)

        visualize_results(pcd, predicted_labels, class_names, grass_indices, confidence_scores)

    else:
        print("\nGrass detection failed!")

    print("\nTo adjust detection sensitivity, modify the parameters at the top of the script:")
    print("- GREEN_SENSITIVITY: Make stricter (higher) or more permissive (lower)")
    print("- GROUND_PERCENTILE: Include more (higher) or fewer (lower) points as ground")
    print("- CONFIDENCE_THRESHOLD: Require higher (stricter) or lower (more permissive) confidence")

In [None]:
import os
import numpy as np
import torch
import open3d as o3d
import open3d.ml as _ml3d
import open3d.ml.torch as ml3d

def detect_grass_in_ply_simple(ply_file_path, model_path, green_sensitivity=0.1, ground_percentile=25, confidence_threshold=0.6):
    """
    Detect grass in a PLY point cloud with a rule-based colour/height heuristic.

    A point is labelled as grass when its green channel exceeds both red and
    blue by more than ``green_sensitivity``, its Z coordinate lies below the
    ``ground_percentile`` percentile of all Z values, and the mean of its
    green and height confidences exceeds ``confidence_threshold``.

    Args:
        ply_file_path (str): Path to your .ply file
        model_path (str): Path to the pretrained model weights (not used by
            this function; kept for future use)
        green_sensitivity (float): How much greener a point needs to be (0.0-1.0)
        ground_percentile (int): Percentile for ground detection (0-100)
        confidence_threshold (float): Confidence threshold for grass detection (0.0-1.0)

    Returns:
        tuple: ``(predicted_labels, grass_indices, confidence_scores)``.
            ``predicted_labels`` holds one int per point (6 = grass,
            1 = below the ground threshold but not grass, 0 = everything
            else), ``grass_indices`` holds the integer indices of the grass
            points and ``confidence_scores`` holds one float per point.
            ``(None, None, None)`` is returned when the cloud has no points.
    """
    print("Loading point cloud...")
    pcd = o3d.io.read_point_cloud(ply_file_path)

    if len(pcd.points) == 0:
        print("Error: Could not load point cloud or point cloud is empty")
        return None, None, None

    points = np.asarray(pcd.points)

    # Handle colors; clouds without per-point colors get all-zero RGB so the
    # feature matrix keeps its 6 columns.
    if pcd.has_colors():
        colors = np.asarray(pcd.colors)
        features = np.concatenate([points, colors], axis=1)
        print(f"Using XYZ + RGB features ({features.shape[1]}D)")
    else:
        print("No colors found. Adding dummy RGB values...")
        colors = np.zeros((len(points), 3))
        features = np.concatenate([points, colors], axis=1)

    print(f"Point cloud shape: {points.shape}")
    print(f"Features shape: {features.shape}")

    # Enhanced rule-based grass detection with confidence
    print("Using enhanced rule-based grass detection...")
    print(f"Settings: green_sensitivity={green_sensitivity}, ground_percentile={ground_percentile}, confidence_threshold={confidence_threshold}")

    # Initialize all points as unclassified
    predicted_labels = np.zeros(len(points), dtype=int)
    confidence_scores = np.zeros(len(points), dtype=float)

    # All-zero colors (including the dummy values above) carry no usable
    # information for the green test.
    if len(colors) > 0 and np.any(colors > 0):
        red, green, blue = colors[:, 0], colors[:, 1], colors[:, 2]

        # Green dominance: how far green exceeds the stronger of red/blue
        green_dominance = green - np.maximum(red, blue)
        green_mask = green_dominance > green_sensitivity

        # Ground detection based on Z coordinate
        z_coords = points[:, 2]
        ground_threshold = np.percentile(z_coords, ground_percentile)
        ground_mask = z_coords < ground_threshold

        # Confidence based on how green and how close to the lowest point
        green_confidence = np.clip(green_dominance / 0.3, 0, 1)
        height_range = ground_threshold - np.min(z_coords)
        if height_range > 0:
            height_confidence = np.clip((ground_threshold - z_coords) / height_range, 0, 1)
        else:
            # Threshold equals the lowest Z (flat cloud or percentile 0):
            # no point lies below it, and dividing would produce NaN.
            height_confidence = np.zeros(len(points), dtype=float)
        combined_confidence = (green_confidence + height_confidence) / 2

        # Apply masks and confidence threshold
        grass_candidates = green_mask & ground_mask
        high_confidence_grass = grass_candidates & (combined_confidence > confidence_threshold)

        # Assign labels
        predicted_labels[high_confidence_grass] = 6  # Vegetation class
        predicted_labels[ground_mask & ~high_confidence_grass] = 1  # Ground class
        confidence_scores = combined_confidence

        grass_indices = np.where(high_confidence_grass)[0]

        print("Enhanced detection results:")
        print(f"  - Total points: {len(points):,}")
        print(f"  - Ground points: {np.sum(ground_mask):,}")
        print(f"  - High-confidence grass points: {len(grass_indices):,}")
        print(f"  - Grass percentage: {len(grass_indices)/len(points)*100:.2f}%")
        if len(grass_indices) > 0:
            print(f"  - Average grass confidence: {np.mean(combined_confidence[grass_indices]):.3f}")
        else:
            print("  - No grass detected")

    else:
        print("No color information available for rule-based detection")
        # Integer dtype so the result is always usable as an index array
        grass_indices = np.array([], dtype=int)

    return predicted_labels, grass_indices, confidence_scores

def save_point_cloud_with_labels(pcd, labels, confidence_scores, output_path, class_names=None):
    """
    Save the point cloud recolored by class label, plus a text summary.

    The points of ``pcd`` are written to ``output_path`` with one color per
    class; grass colors (label 6) are blended towards a darker green as
    confidence drops. On success a ``<output stem>_metadata.txt`` file with
    class statistics is written next to it.

    Args:
        pcd: Open3D point cloud object (only its points are used)
        labels: numpy array of predicted class labels
        confidence_scores: numpy array of confidence scores, or None
        output_path: path where to save the result
        class_names: list of class names (optional)

    Returns:
        The success flag returned by ``o3d.io.write_point_cloud``.
    """
    print(f"\nSaving point cloud with labels to: {output_path}")

    labels = np.asarray(labels)

    # Create color map for Toronto3D classes
    colors = np.array([
        [0.5, 0.5, 0.5],  # 0: Unclassified - gray
        [0.6, 0.4, 0.2],  # 1: Ground - brown
        [1.0, 0.0, 0.0],  # 2: Building - red
        [0.0, 0.0, 1.0],  # 3: Utility line - blue
        [1.0, 1.0, 0.0],  # 4: Pole - yellow
        [0.0, 0.8, 0.0],  # 5: Tree - dark green
        [0.0, 1.0, 0.0],  # 6: Vegetation/Grass - bright green
        [1.0, 0.0, 1.0],  # 7: Vehicle - magenta
    ])

    # Assign colors based on labels
    point_colors = colors[labels]

    # Modulate grass color intensity based on confidence (vectorized)
    grass_indices = np.where(labels == 6)[0]
    has_grass_confidence = len(grass_indices) > 0 and confidence_scores is not None
    if has_grass_confidence:
        weight = np.asarray(confidence_scores, dtype=float)[grass_indices][:, None]
        low_confidence_color = np.array([0.0, 0.5, 0.0])
        point_colors[grass_indices] = point_colors[grass_indices] * weight + low_confidence_color * (1 - weight)

    # Build a new cloud so the caller's colors are left untouched
    pcd_labeled = o3d.geometry.PointCloud()
    pcd_labeled.points = pcd.points
    pcd_labeled.colors = o3d.utility.Vector3dVector(point_colors)

    # Save the point cloud
    success = o3d.io.write_point_cloud(output_path, pcd_labeled)

    if success:
        print("✓ Point cloud saved successfully!")

        # Derive the metadata name from the file stem. A plain
        # str.replace('.ply', ...) leaves other extensions unchanged, which
        # would overwrite the point cloud that was just written.
        metadata_path = os.path.splitext(output_path)[0] + '_metadata.txt'
        with open(metadata_path, 'w', encoding='utf-8') as f:
            f.write("Grass Detection Results\n")
            f.write("=" * 30 + "\n\n")

            # Write class statistics
            unique_labels, counts = np.unique(labels, return_counts=True)
            f.write("Class Distribution:\n")
            for label, count in zip(unique_labels, counts):
                if class_names and label < len(class_names):
                    percentage = count / len(labels) * 100
                    f.write(f"  {class_names[label]}: {count:,} points ({percentage:.1f}%)\n")

            f.write(f"\nTotal points: {len(labels):,}\n")

            # Write confidence statistics for grass
            if has_grass_confidence:
                grass_confidence = confidence_scores[grass_indices]
                f.write("\nGrass Confidence Statistics:\n")
                f.write(f"  Average confidence: {np.mean(grass_confidence):.3f}\n")
                f.write(f"  Min confidence: {np.min(grass_confidence):.3f}\n")
                f.write(f"  Max confidence: {np.max(grass_confidence):.3f}\n")

            # Write color mapping
            f.write("\nClass Color Mapping:\n")
            if class_names:
                for i, name in enumerate(class_names):
                    if i < len(colors):
                        f.write(f"  {name}: RGB({colors[i][0]:.1f}, {colors[i][1]:.1f}, {colors[i][2]:.1f})\n")

        print(f"✓ Metadata saved to: {metadata_path}")

    else:
        print("✗ Failed to save point cloud")

    return success

def save_grass_only_point_cloud(pcd, grass_indices, confidence_scores, output_path):
    """
    Write a point cloud that contains only the detected grass points.

    Args:
        pcd: Original Open3D point cloud object
        grass_indices: indices of detected grass points
        confidence_scores: confidence scores for all points, or None
        output_path: path where to save the grass-only point cloud

    Returns:
        False when ``grass_indices`` is empty, otherwise the success flag
        returned by ``o3d.io.write_point_cloud``.
    """
    grass_count = len(grass_indices)
    if grass_count == 0:
        print("No grass points to save!")
        return False

    print(f"\nSaving grass-only point cloud to: {output_path}")

    # Pull the grass subset out of the full cloud
    all_points = np.asarray(pcd.points)
    subset = all_points[grass_indices]

    grass_pcd = o3d.geometry.PointCloud()
    grass_pcd.points = o3d.utility.Vector3dVector(subset)

    if confidence_scores is None:
        # No confidence available: uniform pure green
        shades = np.tile([0.0, 1.0, 0.0], (grass_count, 1))
    else:
        # Green scales with confidence; a little red marks low confidence
        per_point_conf = confidence_scores[grass_indices]
        shades = np.zeros((grass_count, 3))
        shades[:, 1] = per_point_conf
        shades[:, 0] = 0.2 * (1 - per_point_conf)
    grass_pcd.colors = o3d.utility.Vector3dVector(shades)

    success = o3d.io.write_point_cloud(output_path, grass_pcd)

    if not success:
        print("✗ Failed to save grass-only point cloud")
    else:
        print("✓ Grass-only point cloud saved successfully!")
        print(f"  Contains {grass_count:,} grass points")

    return success

def visualize_results(pcd, labels, class_names, grass_indices, confidence_scores=None):
    """Visualize the segmentation results with proper class colors.

    Overwrites ``pcd.colors`` with one color per class label, prints the
    class distribution and color legend, then opens an Open3D window.

    Args:
        pcd: Open3D point cloud; its colors are replaced in place.
        labels: integer class label per point, used to index the 8-entry
            color table below.
        class_names: class names indexed by label, used for the printed
            statistics and legend.
        grass_indices: indices of the points detected as grass.
        confidence_scores: optional per-point confidence; when given, grass
            colors are blended towards a darker green as confidence drops.
    """
    # Create color map for Toronto3D classes
    colors = np.array([
        [0.5, 0.5, 0.5],  # 0: Unclassified - gray
        [0.6, 0.4, 0.2],  # 1: Ground - brown
        [1.0, 0.0, 0.0],  # 2: Building - red
        [0.0, 0.0, 1.0],  # 3: Utility line - blue
        [1.0, 1.0, 0.0],  # 4: Pole - yellow
        [0.0, 0.8, 0.0],  # 5: Tree - dark green
        [0.0, 1.0, 0.0],  # 6: Vegetation/Grass - bright green
        [1.0, 0.0, 1.0],  # 7: Vehicle - magenta
    ])

    # Assign colors based on labels (fancy indexing returns a fresh array)
    point_colors = colors[labels]

    # Blend grass colors with a darker green in proportion to (1 - confidence).
    # One vectorized assignment replaces a Python loop over every grass point.
    if confidence_scores is not None and len(grass_indices) > 0:
        grass_idx = np.asarray(grass_indices, dtype=int)
        weight = np.asarray(confidence_scores, dtype=float)[grass_idx][:, None]
        low_confidence_color = np.array([0.0, 0.5, 0.0])
        point_colors[grass_idx] = point_colors[grass_idx] * weight + low_confidence_color * (1 - weight)

    pcd.colors = o3d.utility.Vector3dVector(point_colors)

    # Print class statistics
    print("\nClass distribution:")
    unique_labels, counts = np.unique(labels, return_counts=True)
    for label, count in zip(unique_labels, counts):
        if label < len(class_names):
            percentage = count / len(labels) * 100
            print(f"  {class_names[label]}: {count:,} points ({percentage:.1f}%)")

    # Print the legend, then open the viewer
    print("\nDisplaying segmentation results...")
    print("Class colors:")
    for i, name in enumerate(class_names):
        if i < len(colors):
            color_str = f"RGB({colors[i][0]:.1f}, {colors[i][1]:.1f}, {colors[i][2]:.1f})"
            print(f"  {name}: {color_str}")
    print("\nPress 'Q' to close the visualization")

    # custom_draw applies the point size and background color on each callback
    o3d.visualization.draw_geometries_with_animation_callback([pcd], custom_draw)

def custom_draw(vis):
    """Visualizer callback that applies the script's render settings.

    Sets a black background and a point size of 1.0 on the render option
    object obtained from ``vis``.

    Returns:
        bool: always False.
        NOTE(review): presumably this tells Open3D that no geometry refresh
        is needed - confirm against the Open3D callback documentation.
    """
    render_option = vis.get_render_option()
    render_option.background_color = np.asarray([0, 0, 0])
    render_option.point_size = 1.0
    return False

def print_detection_settings():
    """Print a short reference of the tunable detection parameters.

    Covers ``green_sensitivity``, ``ground_percentile`` and
    ``confidence_threshold`` with their ranges, defaults and effect.
    """
    help_lines = (
        "\nAvailable detection settings:",
        "  green_sensitivity: 0.0-1.0 (default 0.1)",
        "    - Higher values = stricter green requirement",
        "    - Lower values = more permissive green detection",
        "  ground_percentile: 0-100 (default 25)",
        "    - Percentage of lowest points considered 'ground'",
        "    - Higher values = more points considered ground level",
        "  confidence_threshold: 0.0-1.0 (default 0.6)",
        "    - Minimum confidence score for grass classification",
        "    - Higher values = more conservative grass detection",
    )
    # A single write produces the same text as one print call per line
    print("\n".join(help_lines))

# Script entry point: detect grass in one PLY file, save the labelled and
# grass-only clouds, then show the result in an Open3D window.
if __name__ == "__main__":
    # Configuration
    ply_file_path = "final_v2.ply"  # Replace with your PLY file path
    model_path = "./logs/randlanet_toronto3d_202201071330utc.pth"

    # Output file paths
    output_labeled_path = "grass_detection_results.ply"        # Full point cloud with labels
    output_grass_only_path = "grass_only_results.ply"         # Only grass points

    # Detection parameters - adjust these for different results!
    GREEN_SENSITIVITY = 0.0      # How green does a point need to be? (0.0-1.0)
    GROUND_PERCENTILE = 50       # What % of lowest points are considered ground? (0-100)
    CONFIDENCE_THRESHOLD = 0.0   # Minimum confidence for grass detection (0.0-1.0)

    print("Enhanced Grass Detection with Save Functionality")
    print("=" * 50)
    print_detection_settings()
    print("\nCurrent settings:")
    print(f"  Green sensitivity: {GREEN_SENSITIVITY}")
    print(f"  Ground percentile: {GROUND_PERCENTILE}")
    print(f"  Confidence threshold: {CONFIDENCE_THRESHOLD}")
    print("\nOutput files:")
    print(f"  Labeled point cloud: {output_labeled_path}")
    print(f"  Grass-only point cloud: {output_grass_only_path}")

    # Stop early when the input file is missing
    if not os.path.exists(ply_file_path):
        print(f"\nError: PLY file not found at {ply_file_path}")
        print("Please update the ply_file_path variable with the correct path to your PLY file")
        # `exit` is an interactive helper added by the `site` module and is
        # not guaranteed to exist; raising SystemExit works everywhere.
        raise SystemExit(1)

    print(f"\nProcessing: {ply_file_path}")
    print("Using enhanced rule-based detection with confidence scoring...")

    # Run enhanced grass detection
    predicted_labels, grass_indices, confidence_scores = detect_grass_in_ply_simple(
        ply_file_path,
        model_path,
        green_sensitivity=GREEN_SENSITIVITY,
        ground_percentile=GROUND_PERCENTILE,
        confidence_threshold=CONFIDENCE_THRESHOLD
    )

    # Only files that were really written are reported at the end
    saved_files = []

    # detect_grass_in_ply_simple returns None labels when the cloud is empty
    if predicted_labels is not None:
        print("\nGrass detection completed successfully!")

        # Toronto3D class names
        class_names = [
            'Unclassified',    # 0 - gray
            'Ground',          # 1 - brown
            'Building',        # 2 - red
            'Utility line',    # 3 - blue
            'Pole',           # 4 - yellow
            'Tree',           # 5 - dark green
            'Vegetation',     # 6 - bright green (grass)
            'Vehicle'         # 7 - magenta
        ]

        # The detector does not return the cloud, so load it again here
        print("Loading point cloud for saving and visualization...")
        pcd = o3d.io.read_point_cloud(ply_file_path)

        # Save results
        print("\n" + "="*50)
        print("SAVING RESULTS")
        print("="*50)

        # Save full labeled point cloud (also writes the metadata text file)
        labeled_saved = save_point_cloud_with_labels(
            pcd, predicted_labels, confidence_scores,
            output_labeled_path, class_names
        )

        # Save grass-only point cloud (returns False when no grass was found)
        grass_saved = save_grass_only_point_cloud(
            pcd, grass_indices, confidence_scores,
            output_grass_only_path
        )

        if labeled_saved:
            saved_files.append(f"{output_labeled_path} (full point cloud with class labels)")
        if grass_saved:
            saved_files.append(f"{output_grass_only_path} (grass points only)")
        if labeled_saved:
            saved_files.append(f"{output_labeled_path.replace('.ply', '_metadata.txt')} (detection statistics)")

        # Visualize results; done after saving because this recolors pcd
        print("\n" + "="*50)
        print("VISUALIZATION")
        print("="*50)
        visualize_results(pcd, predicted_labels, class_names, grass_indices, confidence_scores)

    else:
        print("\nGrass detection failed!")

    if saved_files:
        print("\nFiles saved:")
        for entry in saved_files:
            print(f"  - {entry}")
    else:
        print("\nNo output files were written.")
    print("\nTo adjust detection sensitivity, modify the parameters at the top of the script!")

In [9]:
import os
import numpy as np
import open3d as o3d
from collections import defaultdict

def load_point_cloud_with_colors(ply_file_path):
    """
    Load a point cloud and expose its points and colors as numpy arrays.

    Args:
        ply_file_path (str): Path to the PLY file

    Returns:
        pcd: Open3D point cloud
        points: numpy array of points
        colors: numpy array of colors (has no rows when the file stores none)

        ``(None, None, None)`` is returned when the cloud has no points.
    """
    print(f"Loading point cloud from: {ply_file_path}")

    # Load point cloud
    pcd = o3d.io.read_point_cloud(ply_file_path)

    if len(pcd.points) == 0:
        print("Error: Could not load point cloud or point cloud is empty")
        return None, None, None

    points = np.asarray(pcd.points)
    colors = np.asarray(pcd.colors)

    # Only claim colors when there is one per point; the return value is the
    # same either way so existing callers keep working.
    if len(colors) == len(points):
        print(f"Loaded {len(points):,} points with colors")
    else:
        print(f"Loaded {len(points):,} points")
        print(f"Warning: found {len(colors):,} colors for {len(points):,} points; color-based analysis will not work")

    return pcd, points, colors

def is_green_point(color, green_threshold=0.5):
    """
    Tell whether an RGB color counts as green (grass-like).

    Args:
        color: sequence of three values unpacked as (red, green, blue)
        green_threshold: value the green channel must strictly exceed

    Returns:
        Truthy when green is above the threshold and strictly larger than
        both the red and the blue channel, falsy otherwise.
    """
    red, green, blue = color
    bright_enough = green > green_threshold
    dominant = green > red and green > blue
    return bright_enough and dominant

def create_voxel_grid(points, voxel_size=0.1):
    """
    Compute the integer voxel coordinate of every point.

    Each coordinate is divided by ``voxel_size`` and floored, so negative
    coordinates map to negative voxel indices. The grid bounds are printed
    for information only.

    Args:
        points: array-like of 3D points
        voxel_size: size of each voxel

    Returns:
        voxel_indices: integer array of voxel coordinates, one row per
        point (empty when ``points`` is empty)
    """
    print(f"Creating voxel grid with voxel size: {voxel_size}")

    # Convert points to voxel coordinates
    points = np.asarray(points, dtype=float)
    voxel_indices = np.floor(points / voxel_size).astype(int)

    # min/max are undefined for an empty cloud, so skip the bounds report
    if voxel_indices.size == 0:
        print("Voxel grid is empty: no points supplied")
        return voxel_indices

    # Get voxel grid bounds
    min_voxel = np.min(voxel_indices, axis=0)
    max_voxel = np.max(voxel_indices, axis=0)
    grid_size = max_voxel - min_voxel + 1

    print(f"Voxel grid bounds: {min_voxel} to {max_voxel}")
    print(f"Grid size: {grid_size} voxels")

    return voxel_indices

def find_greenest_clusters(points, colors, voxel_size=0.1, top_clusters=3, green_threshold=0.5):
    """
    Find the greenest clusters using voxel-based analysis

    Points are binned into voxels. A voxel is "green" when more than half
    of its points have a green channel above ``green_threshold`` that also
    exceeds their red and blue channels. Each green voxel is scored by its
    own green ratio multiplied by the number of green face-neighbours and
    by their average green ratio (isolated voxels keep their own ratio).
    The ``top_clusters`` best-scoring voxels are returned.

    Args:
        points: numpy array of 3D points
        colors: numpy array of RGB colors, one row per point
        voxel_size: size of voxels for clustering
        top_clusters: number of top clusters to identify
        green_threshold: minimum green value to consider as green

    Returns:
        top_green_voxels: set of voxel keys for top green clusters
        voxel_dict: dictionary mapping voxel keys to point indices

    Raises:
        ValueError: if ``colors`` does not hold exactly one row per point.
    """
    print("\nFinding greenest clusters...")
    print(f"Voxel size: {voxel_size}")
    print(f"Green threshold: {green_threshold}")
    print(f"Looking for top {top_clusters} clusters")

    # Fail early with a clear message instead of an IndexError deep inside
    # the per-voxel loop (e.g. for clouds loaded without colors).
    colors = np.asarray(colors)
    if len(colors) != len(points):
        raise ValueError(
            f"Expected one color per point, got {len(colors)} colors for {len(points)} points"
        )

    voxel_dict = defaultdict(list)
    if len(points) == 0:
        print("Warning: No points supplied!")
        return set(), voxel_dict

    # Create voxel grid
    voxel_indices = create_voxel_grid(points, voxel_size)

    # Map voxel index -> list of point indices. tolist() yields plain ints,
    # which hash and compare equal to the numpy integers they came from.
    print("Building voxel dictionary...")
    for idx, voxel in enumerate(map(tuple, voxel_indices.tolist())):
        voxel_dict[voxel].append(idx)

    print(f"Created {len(voxel_dict)} occupied voxels")

    # Per-point green test, evaluated once for the whole cloud. This applies
    # the same rule as is_green_point() without a Python call per point.
    red, green, blue = colors[:, 0], colors[:, 1], colors[:, 2]
    point_is_green = (green > green_threshold) & (green > red) & (green > blue)

    # Calculate greenness score for each voxel
    print("Calculating voxel greenness scores...")
    voxel_greenness = {}
    for voxel_key, indices in voxel_dict.items():
        greenness_ratio = np.count_nonzero(point_is_green[indices]) / len(indices)

        # Only consider voxels that are majority green
        if greenness_ratio > 0.5:
            voxel_greenness[voxel_key] = greenness_ratio

    print(f"Found {len(voxel_greenness)} green voxels out of {len(voxel_dict)} total voxels")

    if len(voxel_greenness) == 0:
        print("Warning: No green voxels found!")
        return set(), voxel_dict

    # Calculate cluster connectivity scores over the 6 face-neighbours
    print("Calculating green cluster connectivity...")
    green_cluster_scores = {}
    neighbor_offsets = [
        (-1, 0, 0), (1, 0, 0),
        (0, -1, 0), (0, 1, 0),
        (0, 0, -1), (0, 0, 1)
    ]

    for voxel_key, own_greenness in voxel_greenness.items():
        x, y, z = voxel_key
        green_neighbors = 0
        total_neighbor_greenness = 0

        for dx, dy, dz in neighbor_offsets:
            neighbor_greenness = voxel_greenness.get((x + dx, y + dy, z + dz))
            if neighbor_greenness is not None:
                green_neighbors += 1
                total_neighbor_greenness += neighbor_greenness

        # Score = own greenness * number of green neighbors * average neighbor greenness
        if green_neighbors > 0:
            avg_neighbor_greenness = total_neighbor_greenness / green_neighbors
            cluster_score = own_greenness * green_neighbors * avg_neighbor_greenness
        else:
            cluster_score = own_greenness  # Isolated green voxel

        green_cluster_scores[voxel_key] = cluster_score

    print(f"Scored {len(green_cluster_scores)} green voxels for clustering")

    # Get top greenest clusters (highest score first; ties keep dict order)
    top_green_clusters = sorted(green_cluster_scores.items(), key=lambda item: -item[1])[:top_clusters]
    top_green_voxels = {voxel for voxel, _ in top_green_clusters}

    print(f"\nTop {len(top_green_clusters)} greenest clusters:")
    for i, (voxel, score) in enumerate(top_green_clusters):
        center = (np.array(voxel) + 0.5) * voxel_size
        point_count = len(voxel_dict[voxel])
        greenness = voxel_greenness[voxel]
        print(f"  Cluster {i+1}: Center ({center[0]:.3f}, {center[1]:.3f}, {center[2]:.3f})")
        print(f"             Score: {score:.3f}, Greenness: {greenness:.1%}, Points: {point_count}")

    return top_green_voxels, voxel_dict

def visualize_greenest_clusters(pcd, points, top_green_voxels, voxel_dict, highlight_color=[0, 1, 0], background_color=[0.2, 0.2, 0.2]):
    """
    Recolor the point cloud so the top clusters stand out, then display it.

    Every point is first painted with ``background_color``; points belonging
    to a voxel listed in ``top_green_voxels`` are then painted with
    ``highlight_color``. The colors of ``pcd`` are overwritten in place before
    the viewer window is opened.

    Args:
        pcd: Open3D point cloud (its colors are replaced)
        points: numpy array of points; only its length is used here
        top_green_voxels: set of voxel keys for top clusters
        voxel_dict: dictionary mapping voxel keys to point indices
        highlight_color: RGB color for top clusters
        background_color: RGB color for other points

    Note:
        The list defaults are only read, never mutated, so sharing them
        between calls is harmless.
    """
    print("\nVisualizing greenest clusters...")
    print(f"Highlighting {len(top_green_voxels)} top clusters in bright green")
    print("Other points shown in gray")

    # Force a float buffer. Without an explicit dtype np.full infers it from
    # the fill value, so an all-integer background such as [0, 0, 0] would
    # yield an integer array and silently truncate fractional highlight colors.
    new_colors = np.full((len(points), 3), background_color, dtype=np.float64)
    highlight = np.asarray(highlight_color, dtype=np.float64)
    highlighted_points = 0

    for voxel_key, indices in voxel_dict.items():
        if voxel_key not in top_green_voxels:
            continue
        # One indexed assignment per voxel instead of one per point.
        # The explicit integer dtype keeps an empty index list valid.
        idx = np.asarray(indices, dtype=np.intp)
        new_colors[idx] = highlight
        highlighted_points += len(idx)

    print(f"Highlighted {highlighted_points:,} points in top greenest clusters")

    # Update point cloud colors
    pcd.colors = o3d.utility.Vector3dVector(new_colors)

    # Visualize
    print("\nPress 'Q' to close the visualization")
    o3d.visualization.draw_geometries_with_animation_callback([pcd], custom_draw)

def save_cluster_results(pcd, output_path, top_green_voxels, voxel_dict, voxel_size):
    """
    Save the cluster analysis results

    Writes the point cloud to ``output_path`` and, if that succeeds, a text
    summary next to it named ``<output stem>_cluster_info.txt``.

    Args:
        pcd: Open3D point cloud with cluster colors
        output_path: path to save the results
        top_green_voxels: set of top cluster voxel keys
        voxel_dict: voxel dictionary
        voxel_size: voxel size used

    Returns:
        The result of ``o3d.io.write_point_cloud``; the metadata file is only
        written when it is truthy.
    """
    print(f"\nSaving greenest clusters to: {output_path}")

    # Save colored point cloud
    success = o3d.io.write_point_cloud(output_path, pcd)

    if not success:
        print("✗ Failed to save cluster results")
        return success

    print("✓ Greenest clusters saved successfully!")

    # Build the sidecar name from the path without its extension. A plain
    # str.replace('.ply', ...) leaves the path unchanged when it has no '.ply'
    # substring (e.g. '.pcd' or '.PLY'), which would make the metadata write
    # overwrite the point cloud just saved; it would also rewrite '.ply'
    # occurring inside a directory name.
    stem, _ = os.path.splitext(output_path)
    metadata_path = stem + '_cluster_info.txt'

    lines = [
        "Greenest Cluster Analysis Results\n",
        "=" * 40 + "\n\n",
        f"Voxel size: {voxel_size}\n",
        f"Number of top clusters: {len(top_green_voxels)}\n\n",
        "Top Greenest Cluster Centers (x, y, z):\n",
    ]
    # Numbering follows the iteration order of top_green_voxels.
    for i, voxel in enumerate(top_green_voxels):
        # Voxel keys are integer grid coordinates; +0.5 moves to the cell center
        center = (np.array(voxel) + 0.5) * voxel_size
        point_count = len(voxel_dict[voxel])
        lines.append(
            f"  Cluster {i+1}: ({center[0]:.3f}, {center[1]:.3f}, {center[2]:.3f}) - {point_count} points\n"
        )

    with open(metadata_path, 'w', encoding='utf-8') as f:
        f.writelines(lines)

    print(f"✓ Cluster metadata saved to: {metadata_path}")

    return success

def custom_draw(vis):
    """Animation callback: apply 1.0 point size and a black background.

    Returns False on every call.
    """
    render_option = vis.get_render_option()
    black = np.zeros(3, dtype=int)
    render_option.point_size = 1.0
    render_option.background_color = black
    return False

if __name__ == "__main__":
    # Script entry point: load the grass-detection output, find the greenest
    # voxel clusters, show them, and save the recolored cloud plus a summary.
    # Exits with status 1 on any failure.

    # Configuration
    input_file = "grass_detection_results.ply"  # Input file from grass detection
    output_file = "greenest_clusters.ply"       # Output file for cluster results

    # Clustering parameters
    VOXEL_SIZE = 0.4        # Size of voxels for clustering
    TOP_CLUSTERS = 3        # Number of top clusters to identify
    GREEN_THRESHOLD = 0.5   # Minimum green channel value to consider as green

    # Visualization colors
    HIGHLIGHT_COLOR = [0, 1, 0]      # Bright green for top clusters
    BACKGROUND_COLOR = [0.2, 0.2, 0.2]  # Dark gray for other points

    print("Greenest Cluster Analysis")
    print("=" * 30)
    print(f"Input file: {input_file}")
    print(f"Output file: {output_file}")
    print(f"Voxel size: {VOXEL_SIZE}")
    print(f"Green threshold: {GREEN_THRESHOLD}")
    print(f"Top clusters to find: {TOP_CLUSTERS}")

    # Check if input file exists.
    # SystemExit is raised directly rather than calling exit(): that helper is
    # installed by the site module for interactive use and is not guaranteed
    # to exist (e.g. under `python -S` or in frozen builds).
    if not os.path.exists(input_file):
        print(f"\nError: Input file not found at {input_file}")
        print("Please make sure you have run the grass detection script first")
        print("and that the grass_detection_results.ply file exists")
        raise SystemExit(1)

    # Load point cloud with colors
    pcd, points, colors = load_point_cloud_with_colors(input_file)

    if pcd is None:
        print("Failed to load point cloud!")
        raise SystemExit(1)

    # Find greenest clusters
    top_green_voxels, voxel_dict = find_greenest_clusters(
        points, colors,
        voxel_size=VOXEL_SIZE,
        top_clusters=TOP_CLUSTERS,
        green_threshold=GREEN_THRESHOLD
    )

    if len(top_green_voxels) == 0:
        print("No green clusters found! Try adjusting the green threshold or voxel size.")
        raise SystemExit(1)

    # Print cluster center coordinates
    print("\n" + "="*50)
    print("TOP GREENEST CLUSTER CENTERS")
    print("="*50)
    print("Greenest cluster centers (x, y, z):")
    for i, voxel in enumerate(top_green_voxels):
        # Voxel keys are grid coordinates; +0.5 moves to the cell center
        center = (np.array(voxel) + 0.5) * VOXEL_SIZE
        print(f"  Cluster {i+1}: {center[0]:.3f}, {center[1]:.3f}, {center[2]:.3f}")

    # Visualize results
    visualize_greenest_clusters(
        pcd, points, top_green_voxels, voxel_dict,
        highlight_color=HIGHLIGHT_COLOR,
        background_color=BACKGROUND_COLOR
    )

    # Save results. The helper reports a failed write itself; stop here with a
    # non-zero status instead of claiming the results were saved.
    saved = save_cluster_results(pcd, output_file, top_green_voxels, voxel_dict, VOXEL_SIZE)
    if not saved:
        raise SystemExit(1)

    print("\nAnalysis complete!")
    print(f"Results saved to: {output_file}")
    print(f"Cluster info saved to: {output_file.replace('.ply', '_cluster_info.txt')}")
    print("\nTo adjust clustering:")
    print("  - Decrease VOXEL_SIZE for more detailed clusters")
    print("  - Increase GREEN_THRESHOLD for stricter green detection")
    print("  - Change TOP_CLUSTERS to find more/fewer clusters")