In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Intelligent Foreground-Middleground-Background Segmentation System - Complete Version

This system integrates:
1. Theory-based initial partitioning
2. K-means boundary optimization
3. Intelligent connectivity processing
4. Multi-strategy layer decision
5. Intelligent hole filling post-processing
6. Forced semantic rules (e.g., sky is always background)

Author: Kai
Version: 2.1 (Updated with forced semantic rules)
"""

import json
import os
import re
import warnings
from dataclasses import dataclass
from typing import Dict, Tuple, List, Optional

import cv2
import matplotlib.pyplot as plt
import numpy as np
from scipy import ndimage
from sklearn.cluster import KMeans
warnings.filterwarnings('ignore')

# Google Colab compatibility check: `google.colab.files` is only importable
# inside Colab, so its availability decides the IN_COLAB flag.
try:
    from google.colab import files
except ImportError:
    IN_COLAB = False
    print("Note: Not in Google Colab environment, file upload functionality will be skipped")
else:
    IN_COLAB = True

# ==================== Data Structure Definitions ====================

@dataclass
class ObjectFeatures:
    """Comprehensive feature description of an object (one image region).

    Instances are built by IntelligentLayerDecisionSystem.extract_object_features;
    the comments below describe how that method fills each field. Coordinates are
    numpy row (y, 0 = top row) / column (x) indices.
    """
    # Basic attributes
    num_pixels: int          # number of pixels in the region
    semantic_class: int      # most frequent semantic label inside the region
    is_closed: bool          # openness[semantic_class - 1] == 1 (False for class 0)
    is_countable: bool       # countability[semantic_class - 1] == 1 (False for class 0)

    # Shape features
    effective_aspect_ratio: float  # distinct occupied rows / distinct occupied columns
    bbox_aspect_ratio: float       # bounding-box height / bounding-box width
    compactness: float             # pixel count / bounding-box area
    elongation: float              # sqrt(max / min eigenvalue) of pixel-coordinate covariance; 1.0 for <= 5 pixels

    # Depth features
    depth_mean: float
    depth_std: float
    depth_range: float                # max - min depth inside the region
    vertical_depth_gradient: float    # |mean depth of bottom 20% rows - top 20% rows| / row extent
    horizontal_depth_gradient: float  # |mean depth of right 20% columns - left 20% columns| / column extent

    # Spatial position
    centroid_y: float
    centroid_x: float
    bottom_y: int                  # largest row index covered by the region
    relative_y_position: float     # centroid_y / image height

    # Neighborhood relationships
    touching_ground: bool          # bottom row index >= image height - 5
    touching_sky: bool             # top row index <= 5
    neighbor_classes: List[int]    # unique semantic labels in a 3-iteration dilation ring around the region

    # Texture and structure
    edge_density: float                   # mean Sobel magnitude of the region-masked depth map inside the region
    internal_structure_complexity: float  # depth std / depth mean (0.0 for regions under 20 pixels)

# ==================== Forced Semantic Rules ====================

class ForcedSemanticRules:
    """
    Forced Semantic Rules System

    Defines semantic categories that are ALWAYS assigned to a fixed layer
    (0 = foreground, 1 = middleground, 2 = background), regardless of depth
    or any other analysis. For example: sky is always background.

    Class names are matched on whole word tokens (runs of letters/digits, so
    comma-separated synonym lists such as "sea, ocean" work). This prevents
    false hits from plain substring tests, e.g. 'skyscraper' is NOT sky and
    'seat' / 'season' are NOT sea.
    """

    # Word tokens that force a class into the background layer (2).
    FORCED_BACKGROUND_WORDS = frozenset({'sky', 'sea', 'ocean'})

    def __init__(self, idx_to_info: Dict):
        """
        Args:
            idx_to_info: mapping class index -> info dict; only info['name'] is read.
        """
        self.idx_to_info = idx_to_info
        self.forced_background_classes = []
        self.forced_foreground_classes = []
        self.forced_middleground_classes = []

        self._initialize_forced_rules()

    @staticmethod
    def _name_tokens(name: str) -> set:
        """Split a class name into lowercase word tokens (underscores also separate)."""
        return set(re.findall(r'[^\W_]+', name.lower()))

    def _initialize_forced_rules(self):
        """Populate the forced-layer lists from the class names in idx_to_info."""
        for idx, info in self.idx_to_info.items():
            # Sky, sea and ocean are ALWAYS BACKGROUND - this is a forced rule.
            if self._name_tokens(info['name']) & self.FORCED_BACKGROUND_WORDS:
                self.forced_background_classes.append(idx)
                print(f"  [Forced Rule] Class '{info['name']}' (idx={idx}) -> ALWAYS BACKGROUND")

    def get_forced_layer(self, semantic_class: int) -> Optional[int]:
        """Return the forced layer (0/1/2) for semantic_class, or None if not forced."""
        if semantic_class in self.forced_background_classes:
            return 2
        elif semantic_class in self.forced_foreground_classes:
            return 0
        elif semantic_class in self.forced_middleground_classes:
            return 1
        return None

    def is_forced(self, semantic_class: int) -> bool:
        """Return True if semantic_class has a forced layer assignment."""
        return self.get_forced_layer(semantic_class) is not None


# ==================== Intelligent Hole Filling System ====================

class IntelligentHoleFilling:
    """Intelligent Hole Filling System.

    For each layer of an FMB map (0 = Foreground, 1 = Middleground,
    2 = Background) it finds enclosed "holes": connected pixels that are not
    labelled with the layer but are surrounded by it. A hole is relabelled to
    the surrounding layer when its depth is consistent with the surrounding
    depth; holes with clearly different depth are preserved.
    """

    min_hole_size = 10            # holes smaller than this (pixels) are skipped
    max_hole_size = 5000          # holes larger than this (pixels) are skipped
    depth_threshold_ratio = 0.15  # max normalized depth difference that still allows filling

    def __init__(self, depth_map: np.ndarray, fmb_map: np.ndarray):
        """
        Args:
            depth_map: 2-D per-pixel depth array; its shape defines H and W.
            fmb_map: per-pixel layer labels (0/1/2); assumed to have the same
                shape as depth_map (not checked).
        """
        self.depth_map = depth_map
        self.fmb_map = fmb_map
        self.H, self.W = depth_map.shape
        self.neighbor_radius = 5  # dilation iterations used to sample neighbor depths

    def process(self) -> Tuple[np.ndarray, Dict]:
        """Fill qualifying holes in every layer and report what was done.

        Returns:
            (filled_map, fill_info): a relabelled copy of ``fmb_map`` (the input
            is not modified) and a summary dict with counters and per-hole
            ``details``. Holes outside [min_hole_size, max_hole_size] count as
            detected but are neither filled nor preserved.
        """
        filled_map = self.fmb_map.copy()
        fill_info = {
            'total_holes_detected': 0,
            'holes_filled': 0,
            'holes_preserved': 0,
            'details': []
        }

        for layer in [0, 1, 2]:
            layer_name = ['Foreground', 'Middleground', 'Background'][layer]
            print(f"\nProcessing layer {layer} ({layer_name})...")

            holes = self._detect_holes_in_layer(filled_map, layer)
            fill_info['total_holes_detected'] += len(holes)
            print(f"  Detected {len(holes)} holes")

            for hole_id, hole_mask in enumerate(holes):
                hole_size = int(np.sum(hole_mask))

                if hole_size < self.min_hole_size or hole_size > self.max_hole_size:
                    continue

                # Analyze against the same (partially filled) map the hole was found in.
                should_fill, analysis = self._analyze_hole(hole_mask, layer, filled_map)

                if should_fill:
                    filled_map[hole_mask] = layer
                    fill_info['holes_filled'] += 1
                    print(f"    Hole {hole_id}: Filled (size={hole_size}, depth_diff={analysis['depth_difference']:.2f})")
                else:
                    fill_info['holes_preserved'] += 1
                    print(f"    Hole {hole_id}: Preserved (size={hole_size}, depth_diff={analysis['depth_difference']:.2f})")

                fill_info['details'].append({
                    'layer': layer, 'hole_id': hole_id, 'size': hole_size,
                    'filled': should_fill, 'analysis': analysis
                })

        print(f"\nFilling statistics:")
        print(f"  Total holes detected: {fill_info['total_holes_detected']}")
        print(f"  Holes filled: {fill_info['holes_filled']}")
        print(f"  Holes preserved: {fill_info['holes_preserved']}")

        return filled_map, fill_info

    def _detect_holes_in_layer(self, fmb_map: np.ndarray, layer: int) -> List[np.ndarray]:
        """Return boolean masks of the holes enclosed by `layer` in fmb_map."""
        layer_mask = fmb_map == layer
        # 3x3 cross; identical to cv2.getStructuringElement(MORPH_ELLIPSE, (3, 3)).
        cross = ndimage.generate_binary_structure(2, 1)
        # Morphological closing to bridge 1-pixel gaps. Eroding with border_value=1
        # keeps the image border from eating into the layer (cv2 MORPH_CLOSE semantics).
        dilated = ndimage.binary_dilation(layer_mask, structure=cross)
        layer_closed = ndimage.binary_erosion(dilated, structure=cross, border_value=1)
        layer_filled = self._fill_holes_morphological(layer_closed).astype(bool)
        holes_mask = layer_filled & ~layer_mask
        labeled_holes, num_holes = ndimage.label(holes_mask)

        holes = []
        for hole_id in range(1, num_holes + 1):
            hole_mask = labeled_holes == hole_id
            if self._is_hole_surrounded(hole_mask, fmb_map, layer):
                holes.append(hole_mask)
        return holes

    def _fill_holes_morphological(self, binary_mask: np.ndarray) -> np.ndarray:
        """Return binary_mask (as 0/1 uint8) with enclosed background regions set to 1.

        Background components not connected to the image border count as holes.
        """
        return ndimage.binary_fill_holes(binary_mask).astype(np.uint8)

    def _is_hole_surrounded(self, hole_mask: np.ndarray, fmb_map: np.ndarray, layer: int) -> bool:
        """Return True if more than 80% of the 2-pixel ring around the hole is `layer`."""
        ring = ndimage.binary_dilation(hole_mask, iterations=2) & ~hole_mask
        ring_values = fmb_map[ring]
        if ring_values.size == 0:
            return False
        return np.count_nonzero(ring_values == layer) / ring_values.size > 0.8

    def _analyze_hole(self, hole_mask: np.ndarray, surrounding_layer: int,
                      fmb_map: Optional[np.ndarray] = None) -> Tuple[bool, Dict]:
        """Decide whether a hole should take the surrounding layer's label.

        Args:
            hole_mask: boolean mask of the hole.
            surrounding_layer: layer enclosing the hole.
            fmb_map: label map used to pick neighbor pixels; defaults to the
                original ``self.fmb_map``.

        Returns:
            (should_fill, analysis). Filling requires the hole/neighbor mean depth
            difference, normalized by the local depth range, to be below
            ``depth_threshold_ratio``, and the hole's depth std to be at most twice
            the neighbors' std.
        """
        layer_map = self.fmb_map if fmb_map is None else fmb_map

        hole_depths = self.depth_map[hole_mask]
        hole_mean_depth = np.mean(hole_depths)
        hole_std_depth = np.std(hole_depths)

        dilated = ndimage.binary_dilation(hole_mask, iterations=self.neighbor_radius)
        neighbor_mask = dilated & ~hole_mask & (layer_map == surrounding_layer)

        if np.sum(neighbor_mask) == 0:
            return False, {
                'depth_difference': float('inf'),
                'hole_mean_depth': hole_mean_depth,
                'neighbor_mean_depth': None,
                'decision_reason': 'no_valid_neighbors'
            }

        neighbor_depths = self.depth_map[neighbor_mask]
        neighbor_mean_depth = np.mean(neighbor_depths)
        neighbor_std_depth = np.std(neighbor_depths)
        depth_difference = abs(hole_mean_depth - neighbor_mean_depth)

        local_depths = np.concatenate([hole_depths, neighbor_depths])
        # Floor the range at 1 so near-flat areas do not blow up the ratio.
        local_depth_range = max(np.max(local_depths) - np.min(local_depths), 1)

        normalized_difference = depth_difference / local_depth_range
        should_fill = normalized_difference < self.depth_threshold_ratio

        if hole_std_depth > neighbor_std_depth * 2:
            should_fill = False
            decision_reason = 'high_internal_variance'
        elif should_fill:
            decision_reason = 'small_depth_difference'
        else:
            decision_reason = 'large_depth_difference'

        analysis = {
            'depth_difference': depth_difference,
            'normalized_difference': normalized_difference,
            'hole_mean_depth': hole_mean_depth,
            'hole_std_depth': hole_std_depth,
            'neighbor_mean_depth': neighbor_mean_depth,
            'neighbor_std_depth': neighbor_std_depth,
            'local_depth_range': local_depth_range,
            'decision_reason': decision_reason
        }
        return should_fill, analysis

# ==================== Intelligent Layer Decision System ====================

class IntelligentLayerDecisionSystem:
    """Intelligent Integrated Layer Decision System"""

    def __init__(self, depth_map, semantic_map, fmb_map, openness, countability, idx_to_info,
                 forced_rules: ForcedSemanticRules = None):
        self.depth_map = depth_map
        self.semantic_map = semantic_map
        self.fmb_map = fmb_map
        self.openness = openness
        self.countability = countability
        self.idx_to_info = idx_to_info
        self.H, self.W = depth_map.shape
        self.forced_rules = forced_rules

        self.strategy_weights = {
            'bottom_sampling': 0.3,
            'horizontal_analysis': 0.3,
            'vertical_structure': 0.2,
            'spatial_context': 0.1,
            'depth_consistency': 0.1
        }
        self._initialize_semantic_categories()

    def _initialize_semantic_categories(self):
        """Initialize spatial expectations for semantic categories - UPDATED for new config"""
        self.sky_classes = []
        self.ground_classes = []
        self.tall_classes = []

        for idx, info in self.idx_to_info.items():
            name = info['name'].lower()
            
            if 'sky' in name:
                self.sky_classes.append(idx)
            elif any(word in name for word in [
                'road', 'route', 'pavement', 'sidewalk', 'grass', 'ground', 'land', 
                'floor', 'flooring', 'path', 'sand', 'dirt', 'track', 'runway',
                'field', 'earth', 'rug', 'carpet', 'stairway', 'staircase', 
                'escalator', 'step', 'stair', 'stairs'
            ]):
                self.ground_classes.append(idx)
            elif any(word in name for word in [
                'tree', 'building', 'edifice', 'tower', 'pole', 'light', 'lamp',
                'palm', 'skyscraper', 'column', 'pillar', 'streetlight', 'chandelier'
            ]):
                self.tall_classes.append(idx)

    def extract_object_features(self, region_mask) -> ObjectFeatures:
        """Extract comprehensive features of an object"""
        y_coords, x_coords = np.where(region_mask)
        if len(y_coords) == 0:
            return None

        num_pixels = len(y_coords)
        y_min, y_max = np.min(y_coords), np.max(y_coords)
        x_min, x_max = np.min(x_coords), np.max(x_coords)

        semantic_values = self.semantic_map[region_mask]
        semantic_class = np.bincount(semantic_values).argmax()
        is_closed = self.openness[semantic_class - 1] == 1 if semantic_class > 0 else False
        is_countable = self.countability[semantic_class - 1] == 1 if semantic_class > 0 else False

        effective_aspect_ratio = self._calculate_effective_aspect_ratio(region_mask)
        bbox_aspect_ratio = (y_max - y_min + 1) / (x_max - x_min + 1)
        compactness = num_pixels / ((y_max - y_min + 1) * (x_max - x_min + 1))

        if num_pixels > 5:
            coords = np.column_stack((x_coords - np.mean(x_coords), y_coords - np.mean(y_coords)))
            cov_matrix = np.cov(coords.T)
            eigenvalues = np.linalg.eigvalsh(cov_matrix)
            elongation = np.sqrt(max(eigenvalues) / (min(eigenvalues) + 1e-6))
        else:
            elongation = 1.0

        depths = self.depth_map[region_mask]
        depth_mean = np.mean(depths)
        depth_std = np.std(depths)
        depth_range = np.max(depths) - np.min(depths)

        vertical_gradient = self._calculate_vertical_gradient(region_mask, depths, y_coords)
        horizontal_gradient = self._calculate_horizontal_gradient(region_mask, depths, x_coords)

        centroid_y = np.mean(y_coords)
        centroid_x = np.mean(x_coords)
        bottom_y = y_max
        relative_y_position = centroid_y / self.H

        touching_ground = y_max >= self.H - 5
        touching_sky = y_min <= 5
        neighbor_classes = self._get_neighbor_classes(region_mask)
        edge_density = self._calculate_edge_density(region_mask)
        internal_structure = self._calculate_internal_structure(region_mask)

        return ObjectFeatures(
            num_pixels=num_pixels, semantic_class=semantic_class,
            is_closed=is_closed, is_countable=is_countable,
            effective_aspect_ratio=effective_aspect_ratio, bbox_aspect_ratio=bbox_aspect_ratio,
            compactness=compactness, elongation=elongation,
            depth_mean=depth_mean, depth_std=depth_std, depth_range=depth_range,
            vertical_depth_gradient=vertical_gradient, horizontal_depth_gradient=horizontal_gradient,
            centroid_y=centroid_y, centroid_x=centroid_x, bottom_y=bottom_y,
            relative_y_position=relative_y_position,
            touching_ground=touching_ground, touching_sky=touching_sky,
            neighbor_classes=neighbor_classes,
            edge_density=edge_density, internal_structure_complexity=internal_structure
        )

    def intelligent_layer_decision(self, region_mask) -> Tuple[int, float, Dict]:
        """Main function for intelligent layer decision"""
        features = self.extract_object_features(region_mask)
        if features is None:
            return None, 0, {}

        # CHECK FORCED RULES FIRST
        if self.forced_rules and self.forced_rules.is_forced(features.semantic_class):
            forced_layer = self.forced_rules.get_forced_layer(features.semantic_class)
            class_name = self.idx_to_info.get(features.semantic_class, {}).get('name', 'Unknown')
            print(f"    [FORCED] Class '{class_name}' -> Layer {forced_layer} (Background)")
            return forced_layer, 1.0, {'reason': 'forced_semantic_rule', 'class_name': class_name, 'features': features}

        if not features.is_closed:
            dominant_label = self._get_dominant_kmeans_label(region_mask)
            return dominant_label, 0.95, {'reason': 'open_object', 'features': features}

        strategies_results = {}
        strategies_results['bottom'] = self._adaptive_bottom_sampling_strategy(region_mask, features)
        strategies_results['horizontal'] = self._intelligent_horizontal_strategy(region_mask, features)
        strategies_results['vertical'] = self._vertical_structure_strategy(region_mask, features)
        strategies_results['spatial'] = self._spatial_context_strategy(region_mask, features)
        strategies_results['depth'] = self._depth_consistency_strategy(region_mask, features)

        return self._intelligent_fusion(strategies_results, features)

    def _adaptive_bottom_sampling_strategy(self, region_mask, features):
        bottom_ratio = self._get_intelligent_bottom_ratio(features)
        weight_modifier = 1.0
        if not features.touching_ground:
            weight_modifier = 0.5
        if features.effective_aspect_ratio < 0.5:
            weight_modifier *= 0.7

        y_coords, x_coords = np.where(region_mask)
        num_bottom_pixels = max(int(features.num_pixels * bottom_ratio), min(10, features.num_pixels), min(features.num_pixels, 50))
        y_sorted_indices = np.argsort(y_coords)[::-1]
        bottom_indices = y_sorted_indices[:num_bottom_pixels]

        bottom_labels = []
        bottom_depths = []
        for idx in bottom_indices:
            y, x = y_coords[idx], x_coords[idx]
            bottom_labels.append(self.fmb_map[y, x])
            bottom_depths.append(self.depth_map[y, x])

        layer_scores = self._calculate_layer_scores(bottom_labels, bottom_depths, features)
        best_layer = max(layer_scores, key=layer_scores.get)
        confidence = layer_scores[best_layer] / sum(layer_scores.values())

        return {'layer': best_layer, 'confidence': confidence * weight_modifier, 'method': 'bottom_sampling',
                'details': {'bottom_ratio': bottom_ratio, 'num_samples': num_bottom_pixels, 'layer_scores': layer_scores}}

    def _intelligent_horizontal_strategy(self, region_mask, features):
        """Dispatch to a horizontal-analysis routine based on region shape.

        Wide (rows/cols < 0.7) or elongated (> 2.0) regions get the per-row
        analysis; everything else uses the optimized analysis.
        """
        per_row = features.effective_aspect_ratio < 0.7 or features.elongation > 2.0
        analyze = self._row_by_row_analysis if per_row else self._optimized_horizontal_analysis
        return analyze(region_mask, features)

    def _row_by_row_analysis(self, region_mask, features):
        """Classify the region row by row using a sliding vertical window.

        For every row the region occupies, the dominant FMB label within a
        window of neighbouring rows (clipped to the region's vertical extent) is
        recorded with its share of the window pixels. The per-row results are
        then reduced to a vertical pattern and a final (layer, confidence).
        """
        ys, xs = np.where(region_mask)
        first_row, last_row = np.min(ys), np.max(ys)
        half_window = max(3, int((last_row - first_row) * 0.1)) // 2

        row_analyses = []
        for y in range(first_row, last_row + 1):
            pixel_count = np.sum(ys == y)
            if pixel_count == 0:
                continue
            lo = max(first_row, y - half_window)
            hi = min(last_row, y + half_window)
            in_window = (ys >= lo) & (ys <= hi)
            window_labels = self.fmb_map[ys[in_window], xs[in_window]]
            labels, counts = np.unique(window_labels, return_counts=True)
            top = np.argmax(counts)
            row_analyses.append({
                'y': y,
                'dominant_layer': labels[top],
                'confidence': counts[top] / len(window_labels),
                'pixel_count': pixel_count,
            })

        vertical_pattern = self._analyze_vertical_pattern_advanced(row_analyses)
        final_layer, final_confidence = self._decide_from_row_analysis(row_analyses, vertical_pattern, features)

        return {
            'layer': final_layer,
            'confidence': final_confidence,
            'method': 'row_by_row',
            'details': {'num_rows': len(row_analyses), 'vertical_pattern': vertical_pattern},
        }

    def _vertical_structure_strategy(self, region_mask, features):
        y_coords, x_coords = np.where(region_mask)
        y_min, y_max = np.min(y_coords), np.max(y_coords)
        height = y_max - y_min + 1
        top_threshold = y_min + height * 0.33
        bottom_threshold = y_max - height * 0.33

        sections = {
            'top': (y_coords <= top_threshold),
            'middle': (y_coords > top_threshold) & (y_coords < bottom_threshold),
            'bottom': (y_coords >= bottom_threshold)
        }

        section_analysis = {}
        for section_name, section_mask in sections.items():
            if np.sum(section_mask) > 0:
                section_x = x_coords[section_mask]
                section_y = y_coords[section_mask]
                section_labels = self.fmb_map[section_y, section_x]
                section_depths = self.depth_map[section_y, section_x]
                unique_labels, counts = np.unique(section_labels, return_counts=True)
                dominant_label = unique_labels[np.argmax(counts)]
                section_analysis[section_name] = {
                    'dominant_layer': dominant_label,
                    'layer_confidence': counts[np.argmax(counts)] / np.sum(counts),
                    'mean_depth': np.mean(section_depths),
                    'depth_std': np.std(section_depths)
                }

        if 'bottom' in section_analysis and 'top' in section_analysis:
            if section_analysis['bottom']['dominant_layer'] == section_analysis['top']['dominant_layer']:
                final_layer = section_analysis['bottom']['dominant_layer']
                confidence = 0.8
            else:
                final_layer = section_analysis['bottom']['dominant_layer']
                confidence = 0.6
        else:
            available_section = list(section_analysis.values())[0]
            final_layer = available_section['dominant_layer']
            confidence = available_section['layer_confidence'] * 0.7

        return {'layer': final_layer, 'confidence': confidence, 'method': 'vertical_structure', 'details': section_analysis}

    def _spatial_context_strategy(self, region_mask, features):
        if features.touching_sky and features.semantic_class in self.sky_classes:
            return {'layer': 2, 'confidence': 0.9, 'method': 'spatial_context', 'details': {'reason': 'touching_sky'}}

        if features.touching_ground and features.relative_y_position > 0.7:
            return {'layer': 0, 'confidence': 0.7, 'method': 'spatial_context', 'details': {'reason': 'touching_ground_bottom'}}

        if features.relative_y_position < 0.3:
            suggested_layer = 2
            confidence = 0.4
        elif features.relative_y_position > 0.7:
            suggested_layer = 0
            confidence = 0.4
        else:
            suggested_layer = 1
            confidence = 0.3

        neighbor_layers = self._analyze_neighbor_layers(region_mask, features)
        if neighbor_layers:
            most_common_neighbor = max(neighbor_layers, key=neighbor_layers.count)
            if neighbor_layers.count(most_common_neighbor) > len(neighbor_layers) * 0.6:
                suggested_layer = most_common_neighbor
                confidence += 0.2

        return {'layer': suggested_layer, 'confidence': min(confidence, 0.6), 'method': 'spatial_context',
                'details': {'relative_position': features.relative_y_position, 'neighbor_influence': len(neighbor_layers) > 0}}

    def _depth_consistency_strategy(self, region_mask, features):
        """Match the region's mean depth against each layer's depth profile.

        _get_layer_depth_profiles (defined elsewhere) is unpacked as
        layer -> (mean, std). Each layer is scored with a Gaussian-shaped
        exp(-z^2 / 2). Scores are damped for regions with high depth spread,
        and confidence is boosted (capped at 0.9) for depth-uniform regions.
        """
        profiles = self._get_layer_depth_profiles()
        match_scores = {
            layer: np.exp(-0.5 * (abs(features.depth_mean - mu) / (sigma + 1e-6)) ** 2)
            for layer, (mu, sigma) in profiles.items()
        }

        if features.depth_std > 20:
            match_scores = {layer: score * 0.7 for layer, score in match_scores.items()}

        best_layer = max(match_scores, key=match_scores.get)
        confidence = match_scores[best_layer]
        if features.depth_std < 10:
            confidence = min(confidence * 1.2, 0.9)

        return {
            'layer': best_layer,
            'confidence': confidence,
            'method': 'depth_consistency',
            'details': {'match_scores': match_scores, 'depth_std': features.depth_std},
        }

    def _intelligent_fusion(self, strategies_results, features):
        layer_votes = {0: 0, 1: 0, 2: 0}
        strategy_details = {}
        adjusted_weights = self._adjust_strategy_weights(features)

        for strategy_name, result in strategies_results.items():
            layer = result['layer']
            confidence = result['confidence']
            weight = adjusted_weights.get(strategy_name, 0.2)
            layer_votes[layer] += confidence * weight
            strategy_details[strategy_name] = {'layer': layer, 'confidence': confidence, 'weight': weight, 'weighted_vote': confidence * weight}

        total_votes = sum(layer_votes.values())
        if total_votes > 0:
            for layer in layer_votes:
                layer_votes[layer] /= total_votes

        final_layer = max(layer_votes, key=layer_votes.get)
        base_confidence = layer_votes[final_layer]

        agreeing_strategies = sum(1 for r in strategies_results.values() if r['layer'] == final_layer)
        if agreeing_strategies >= 3:
            consistency_bonus = 0.1 * (agreeing_strategies - 2)
            base_confidence = min(base_confidence + consistency_bonus, 0.95)

        vote_entropy = -sum(v * np.log(v + 1e-10) for v in layer_votes.values() if v > 0)
        max_entropy = np.log(3)
        if vote_entropy > max_entropy * 0.8:
            base_confidence *= 0.8

        final_confidence = self._apply_special_rules(final_layer, base_confidence, features, strategy_details)

        return final_layer, final_confidence, {
            'layer_votes': layer_votes, 'strategy_details': strategy_details,
            'adjusted_weights': adjusted_weights, 'agreeing_strategies': agreeing_strategies,
            'vote_entropy': vote_entropy, 'features': features
        }

    def _adjust_strategy_weights(self, features):
        weights = self.strategy_weights.copy()
        if features.effective_aspect_ratio < 0.5:
            weights['horizontal_analysis'] *= 1.5
            weights['bottom_sampling'] *= 0.7
        elif features.effective_aspect_ratio > 2.0:
            weights['vertical_structure'] *= 1.3

        if features.num_pixels < 100:
            weights['vertical_structure'] *= 0.5
            weights['spatial_context'] *= 1.2
        elif features.num_pixels > 5000:
            for key in weights:
                weights[key] *= 1.1

        if features.touching_ground:
            weights['bottom_sampling'] *= 1.3
        if features.touching_sky:
            weights['spatial_context'] *= 1.4

        if features.depth_std < 10:
            weights['depth_consistency'] *= 1.5
        elif features.depth_std > 50:
            weights['depth_consistency'] *= 0.5

        total_weight = sum(weights.values())
        for key in weights:
            weights[key] /= total_weight
        return weights

    def _apply_special_rules(self, layer, confidence, features, strategy_details):
        if features.is_countable and features.depth_range > 50:
            if confidence < 0.6:
                confidence *= 0.8

        if features.relative_y_position < 0.1 and layer != 2:
            confidence *= 0.8
        elif features.relative_y_position > 0.9 and layer != 0:
            confidence *= 0.8

        if 'spatial_context' in strategy_details:
            spatial_conf = strategy_details['spatial_context']['confidence']
            if spatial_conf > 0.7:
                confidence = confidence * 0.8 + spatial_conf * 0.2

        return min(confidence, 0.95)

    # ========== Helper Methods ==========

    def _calculate_effective_aspect_ratio(self, region_mask):
        y_coords, x_coords = np.where(region_mask)
        if len(y_coords) == 0:
            return 1.0
        unique_rows = len(np.unique(y_coords))
        unique_cols = len(np.unique(x_coords))
        return unique_rows / unique_cols if unique_cols > 0 else 1.0

    def _get_intelligent_bottom_ratio(self, features):
        if features.num_pixels < 100:
            base_ratio = 0.25
        elif features.num_pixels < 500:
            base_ratio = 0.15
        elif features.num_pixels < 2000:
            base_ratio = 0.10
        else:
            base_ratio = 0.05

        if features.effective_aspect_ratio < 0.5:
            base_ratio *= 1.5
        elif features.elongation > 3.0:
            base_ratio *= 0.7
        if not features.touching_ground:
            base_ratio *= 0.5
        return min(base_ratio, 0.5)

    def _calculate_vertical_gradient(self, region_mask, depths, y_coords):
        if len(y_coords) < 10:
            return 0.0
        y_min, y_max = np.min(y_coords), np.max(y_coords)
        if y_max <= y_min:
            return 0.0
        top_mask = y_coords <= y_min + (y_max - y_min) * 0.2
        bottom_mask = y_coords >= y_max - (y_max - y_min) * 0.2
        if np.sum(top_mask) > 0 and np.sum(bottom_mask) > 0:
            top_depth = np.mean(depths[top_mask])
            bottom_depth = np.mean(depths[bottom_mask])
            return abs(bottom_depth - top_depth) / (y_max - y_min)
        return 0.0

    def _calculate_horizontal_gradient(self, region_mask, depths, x_coords):
        if len(x_coords) < 10:
            return 0.0
        x_min, x_max = np.min(x_coords), np.max(x_coords)
        if x_max <= x_min:
            return 0.0
        left_mask = x_coords <= x_min + (x_max - x_min) * 0.2
        right_mask = x_coords >= x_max - (x_max - x_min) * 0.2
        if np.sum(left_mask) > 0 and np.sum(right_mask) > 0:
            left_depth = np.mean(depths[left_mask])
            right_depth = np.mean(depths[right_mask])
            return abs(right_depth - left_depth) / (x_max - x_min)
        return 0.0

    def _get_neighbor_classes(self, region_mask):
        dilated = ndimage.binary_dilation(region_mask, iterations=3)
        neighbor_mask = dilated & (~region_mask)
        if np.sum(neighbor_mask) == 0:
            return []
        neighbor_values = self.semantic_map[neighbor_mask]
        unique_classes = np.unique(neighbor_values)
        return unique_classes.tolist()

    def _calculate_edge_density(self, region_mask):
        if np.sum(region_mask) < 10:
            return 0.0
        region_depth = self.depth_map.copy()
        region_depth[~region_mask] = 0
        edges_x = ndimage.sobel(region_depth, axis=1)
        edges_y = ndimage.sobel(region_depth, axis=0)
        edge_magnitude = np.sqrt(edges_x**2 + edges_y**2)
        return np.mean(edge_magnitude[region_mask])

    def _calculate_internal_structure(self, region_mask):
        if np.sum(region_mask) < 20:
            return 0.0
        region_depth = self.depth_map[region_mask]
        return np.std(region_depth) / (np.mean(region_depth) + 1e-6)

    def _get_dominant_kmeans_label(self, region_mask):
        labels = self.fmb_map[region_mask]
        unique_labels, counts = np.unique(labels, return_counts=True)
        return unique_labels[np.argmax(counts)]

    def _calculate_layer_scores(self, labels, depths, features):
        layer_scores = {0: 0.01, 1: 0.01, 2: 0.01}
        for layer in [0, 1, 2]:
            layer_mask = np.array(labels) == layer
            if np.sum(layer_mask) > 0:
                base_score = np.sum(layer_mask) / len(labels)
                layer_depths = np.array(depths)[layer_mask]
                depth_consistency = 1.0 / (1.0 + np.std(layer_depths) / 10)
                semantic_bonus = self._get_semantic_layer_compatibility(features.semantic_class, layer)
                layer_scores[layer] = base_score * depth_consistency * semantic_bonus

        total_score = sum(layer_scores.values())
        if total_score > 0:
            for layer in layer_scores:
                layer_scores[layer] /= total_score
        return layer_scores

    def _get_semantic_layer_compatibility(self, semantic_class, layer):
        """Get compatibility between semantic class and layer - UPDATED keywords"""
        if semantic_class in self.idx_to_info:
            name = self.idx_to_info[semantic_class]['name'].lower()

            # Foreground tendency
            if any(word in name for word in [
                'person', 'individual', 'people', 'car', 'auto', 'automobile', 
                'chair', 'trash', 'ashcan', 'garbage', 'bin', 'bicycle', 'bike', 
                'bus', 'truck', 'van', 'motorbike', 'animal', 'dog', 'cat',
                'table', 'desk', 'sofa', 'couch', 'bed', 'armchair', 'seat',
                'bench', 'stool', 'ottoman'
            ]):
                return [1.2, 1.0, 0.8][layer]

            # Background tendency
            elif any(word in name for word in [
                'sky', 'mountain', 'mount', 'hill', 'sea', 'cloud', 'horizon'
            ]):
                return [0.6, 0.8, 1.3][layer]

            # Middleground tendency
            elif any(word in name for word in [
                'tree', 'building', 'edifice', 'wall', 'house', 'fence', 
                'tower', 'bridge', 'skyscraper', 'palm'
            ]):
                return [0.9, 1.1, 0.9][layer]

        return 1.0

    def _analyze_vertical_pattern_advanced(self, row_analyses):
        if len(row_analyses) < 3:
            return 'uniform'
        layers = [r['dominant_layer'] for r in row_analyses]
        changes = sum(1 for i in range(1, len(layers)) if layers[i] != layers[i-1])
        max_streak = 1
        current_streak = 1
        for i in range(1, len(layers)):
            if layers[i] == layers[i-1]:
                current_streak += 1
                max_streak = max(max_streak, current_streak)
            else:
                current_streak = 1

        change_ratio = changes / len(layers)
        streak_ratio = max_streak / len(layers)

        if change_ratio < 0.1:
            return 'uniform'
        elif change_ratio < 0.3 and streak_ratio > 0.5:
            return 'gradient'
        elif changes == 1:
            return 'split'
        else:
            return 'chaotic'

    def _decide_from_row_analysis(self, row_analyses, vertical_pattern, features):
        layer_votes = {0: 0, 1: 0, 2: 0}
        for analysis in row_analyses:
            layer = analysis['dominant_layer']
            confidence = analysis['confidence']
            pixels = analysis['pixel_count']
            weight = confidence * pixels
            layer_votes[layer] += weight

        total_votes = sum(layer_votes.values())
        if total_votes > 0:
            best_layer = max(layer_votes, key=layer_votes.get)
            base_confidence = layer_votes[best_layer] / total_votes

            if vertical_pattern == 'uniform':
                final_confidence = min(base_confidence + 0.1, 0.95)
            elif vertical_pattern == 'chaotic':
                final_confidence = base_confidence * 0.8
            else:
                final_confidence = base_confidence
            return best_layer, final_confidence
        return 1, 0.5

    def _optimized_horizontal_analysis(self, region_mask, features):
        labels = self.fmb_map[region_mask]
        unique_labels, counts = np.unique(labels, return_counts=True)
        dominant_label = unique_labels[np.argmax(counts)]
        confidence = counts[np.argmax(counts)] / len(labels)

        if len(unique_labels) > 1:
            distribution_entropy = -sum((c/len(labels)) * np.log(c/len(labels)) for c in counts)
            max_entropy = np.log(len(unique_labels))
            if distribution_entropy > max_entropy * 0.7:
                confidence *= 0.8

        return {'layer': dominant_label, 'confidence': confidence, 'method': 'horizontal_simple',
                'details': {'unique_labels': len(unique_labels), 'dominant_ratio': confidence}}

    def _get_layer_depth_profiles(self):
        profiles = {}
        for layer in [0, 1, 2]:
            layer_mask = self.fmb_map == layer
            if np.sum(layer_mask) > 0:
                layer_depths = self.depth_map[layer_mask]
                profiles[layer] = (np.mean(layer_depths), np.std(layer_depths))
            else:
                if layer == 0:
                    profiles[layer] = (200, 30)
                elif layer == 1:
                    profiles[layer] = (128, 40)
                else:
                    profiles[layer] = (50, 30)
        return profiles

    def _analyze_neighbor_layers(self, region_mask, features):
        neighbor_layers = []
        for neighbor_class in features.neighbor_classes:
            neighbor_mask = self.semantic_map == neighbor_class
            if np.sum(neighbor_mask) > 0:
                neighbor_labels = self.fmb_map[neighbor_mask]
                unique_labels, counts = np.unique(neighbor_labels, return_counts=True)
                dominant_label = unique_labels[np.argmax(counts)]
                neighbor_layers.append(dominant_label)
        return neighbor_layers

# ==================== Basic Functions ====================

def hex2rgb(hex_color):
    """Convert a hexadecimal color string to an ``(r, g, b)`` tuple of ints.

    Accepted forms (case-insensitive, optional leading ``#``, surrounding
    whitespace ignored):

    * ``'RRGGBB'`` - the usual 6-digit form;
    * ``'RGB'``    - CSS shorthand, each digit doubled (``'#abc'`` == ``'#aabbcc'``);
    * longer strings such as ``'RRGGBBAA'`` - only the first three channel
      pairs are used (alpha is ignored).

    Raises:
        ValueError: if fewer than 6 hex digits remain after shorthand
            expansion, or if a channel contains non-hex characters.
    """
    digits = hex_color.strip().lstrip('#')
    if len(digits) == 3:
        digits = ''.join(ch * 2 for ch in digits)
    if len(digits) < 6:
        # Without this check a 5-digit string silently decoded its last
        # channel from a single digit.
        raise ValueError(f"Invalid hex color: {hex_color!r}")
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))

def read_config_from_json(json_path):
    """Load the semantic-class configuration from a UTF-8 JSON file.

    The file holds a list of entries, each with ``'name'``, ``'color'`` (hex
    string), ``'openness'`` and ``'countable'``. Classes are numbered from 1
    in file order.

    Returns:
        tuple ``(color_to_idx, idx_to_info, openness, countability)``:
        ``color_to_idx`` maps RGB tuples to class indices; ``idx_to_info`` maps
        each index to a dict with ``name``, ``openness``, ``countable``,
        ``color`` (upper-cased hex) and ``rgb``. ``openness`` and
        ``countability`` are lists of ints where position ``i`` describes
        class ``i + 1``.
    """
    with open(json_path, 'r', encoding='utf-8') as handle:
        entries = json.load(handle)

    color_to_idx, idx_to_info = {}, {}
    openness, countability = [], []

    for class_idx, entry in enumerate(entries, start=1):
        rgb = hex2rgb(entry['color'])
        color_to_idx[rgb] = class_idx
        is_open = int(entry['openness'])
        is_countable = int(entry['countable'])
        idx_to_info[class_idx] = {
            'name': entry['name'],
            'openness': is_open,
            'countable': is_countable,
            'color': entry['color'].upper(),
            'rgb': rgb,
        }
        openness.append(is_open)
        countability.append(is_countable)

    return color_to_idx, idx_to_info, openness, countability

def semantic_rgb_to_idx_map(semantic_img, color_to_idx):
    """Convert an RGB semantic image to a per-pixel class-index map.

    Args:
        semantic_img: (H, W, 3) RGB image.
        color_to_idx: mapping from ``(r, g, b)`` tuples to positive class indices.

    Returns:
        (H, W) array of class indices. Colors missing from ``color_to_idx``
        map to 0. The dtype is ``uint8`` when every index fits in it, as
        before. Otherwise it is widened to ``uint16`` or ``uint32``, so
        configurations with more than 255 classes are not truncated.
    """
    H, W, _ = semantic_img.shape

    max_idx = max(color_to_idx.values(), default=0)
    if max_idx <= np.iinfo(np.uint8).max:
        index_dtype = np.uint8
    elif max_idx <= np.iinfo(np.uint16).max:
        index_dtype = np.uint16
    else:
        index_dtype = np.uint32

    # A single pass over the pixels: `inverse` holds each pixel's row in
    # `unique_colors`, so one small lookup table replaces a full-image
    # comparison per color.
    unique_colors, inverse = np.unique(semantic_img.reshape(-1, 3), axis=0, return_inverse=True)
    print(f"\nDetected image colors: {unique_colors.shape[0]} unique colors")

    lookup = np.array(
        [color_to_idx.get(tuple(color.tolist()), 0) for color in unique_colors],
        dtype=index_dtype,
    )
    # reshape(-1) keeps this robust to NumPy versions that return a 2-D inverse.
    return lookup[inverse.reshape(-1)].reshape(H, W)

def process_fmb_segmentation_with_intelligent_system(depth_map, semantic_map, openness,
                                                   countability, idx_to_info,
                                                   enable_hole_filling=True):
    """Perform FMB segmentation using intelligent system.

    Pipeline:
      1. Theory-based depth thresholds split the valid pixels 51% / 26% / 23%.
         Larger depth values are treated as nearer (foreground).
      2. K-means runs on per-pixel (theory label, boundary sensitivity)
         features. If K-means raises, the theory labels are used directly.
      3. Clusters are ranked by mean depth. The cluster with the largest mean
         becomes layer 0, the next layer 1 and the last layer 2.
      4. Forced semantic rules are applied, then a per-connected-component
         layer decision is made for closed classes, then hole filling
         (optional).

    Args:
        depth_map: (H, W) depth array; pixels with value <= 0 are invalid.
        semantic_map: (H, W) class-index map; 0 means unlabelled.
        openness: list where ``openness[c - 1] == 1`` marks class ``c`` closed.
        countability: list forwarded to ``IntelligentLayerDecisionSystem``.
        idx_to_info: class index -> metadata dict (needs ``'name'``).
        enable_hole_filling: run ``IntelligentHoleFilling`` as the last step.

    Returns:
        tuple ``(fmb_map, kmeans_original, adjustments, fmb_before_filling,
        fill_info)``:
        * ``fmb_map`` and ``kmeans_original`` are uint8 maps with 0 =
          foreground, 1 = middleground, 2 = background. Invalid-depth pixels
          start as 2.
        * ``fmb_before_filling`` and ``fill_info`` are ``None`` when hole
          filling is disabled or no valid depth exists.
    """
    layer_names = ['Foreground', 'Middleground', 'Background']

    H, W = depth_map.shape
    print(f"\nProcessing image size: {W}x{H}")
    valid_mask = depth_map > 0
    depth_valid = depth_map[valid_mask]

    if len(depth_valid) == 0:
        print("Warning: No valid depth values!")
        empty_map = np.full((H, W), 2, dtype=np.uint8)
        # Independent copy so mutating one returned map cannot alter the other.
        return empty_map, empty_map.copy(), [], None, None

    # Theory thresholds: the largest 51% of depth values are foreground, the
    # next 26% middleground and the remaining 23% background.
    depth_sorted = np.sort(depth_valid)[::-1]
    idx1 = int(len(depth_sorted) * 0.51)
    idx2 = int(len(depth_sorted) * 0.77)

    thresh1 = depth_sorted[idx1] if idx1 < len(depth_sorted) else depth_sorted[-1]
    thresh2 = depth_sorted[idx2] if idx2 < len(depth_sorted) else depth_sorted[-1]

    print(f"Theoretical depth thresholds based on 51%-26%-23%:")
    print(f"  Foreground>{thresh1:.1f} (51%), Middleground {thresh2:.1f}-{thresh1:.1f} (26%), Background<{thresh2:.1f} (23%)")

    # Feature engineering: [theory label, boundary sensitivity] per valid pixel.
    n_features = 2
    features = np.zeros((np.sum(valid_mask), n_features))

    initial_labels = np.zeros_like(depth_valid, dtype=float)
    initial_labels[depth_valid > thresh1] = 0.0
    initial_labels[(depth_valid <= thresh1) & (depth_valid > thresh2)] = 1.0
    initial_labels[depth_valid <= thresh2] = 2.0
    features[:, 0] = initial_labels

    # Gaussian weight (sigma = boundary_range / 3) of each pixel's distance to
    # the nearest threshold, zeroed beyond boundary_range. Vectorized: the
    # former per-pixel Python loop was O(N) interpreter work.
    boundary_range = (np.max(depth_valid) - np.min(depth_valid)) * 0.08
    depth_as_float = depth_valid.astype(np.float64)
    min_boundary_dist = np.minimum(np.abs(depth_as_float - thresh1),
                                   np.abs(depth_as_float - thresh2))
    if boundary_range > 0:
        sigma = boundary_range / 3
        boundary_sensitivity = np.where(
            min_boundary_dist <= boundary_range,
            np.exp(-min_boundary_dist ** 2 / (2 * sigma ** 2)),
            0.0,
        )
    else:
        # Constant depth: sigma would be 0 and the Gaussian 0/0 = NaN, which
        # K-means rejects. Treat no pixel as boundary-sensitive instead.
        boundary_sensitivity = np.zeros_like(depth_as_float)

    features[:, 1] = boundary_sensitivity
    weights = np.array([0.95, 0.05])
    weighted_features = features * weights

    print(f"\nFeature engineering statistics:")
    print(f"  - Pixels near boundaries (sensitivity > 0.1): {np.sum(boundary_sensitivity > 0.1)} "
          f"({np.sum(boundary_sensitivity > 0.1) / len(depth_valid) * 100:.1f}%)")
    print(f"  - Average boundary sensitivity: {np.mean(boundary_sensitivity):.3f}")

    # K-means, initialised at the weighted theory-label centers.
    centers = np.zeros((3, n_features))
    for i in range(3):
        centers[i][0] = i * weights[0]
        centers[i][1] = 0.0

    print("\nPerforming simplified K-means clustering...")

    try:
        kmeans = KMeans(n_clusters=3, init=centers, n_init=1, max_iter=50, random_state=42, tol=1e-4)
        cluster_labels = kmeans.fit_predict(weighted_features)
        print("\nValidating K-means adherence to theory...")

        for i in range(3):
            cluster_mask = cluster_labels == i
            if np.sum(cluster_mask) > 0:
                theory_labels_in_cluster = initial_labels[cluster_mask]
                unique_theory, counts = np.unique(theory_labels_in_cluster, return_counts=True)
                main_theory_count = counts[np.argmax(counts)]
                total_in_cluster = np.sum(cluster_mask)
                purity = main_theory_count / total_in_cluster
                print(f"  Cluster {i}: {purity*100:.1f}% consistent with theory")

    except Exception as e:
        # Deliberate fallback: any K-means failure reverts to the theory split.
        print(f"K-means failed: {e}, using theory-based assignment")
        cluster_labels = initial_labels.astype(int)

    # Rank clusters by mean depth (largest first); empty clusters rank last.
    cluster_depths = []
    for i in range(3):
        mask = cluster_labels == i
        if np.sum(mask) > 0:
            cluster_depths.append(np.mean(depth_valid[mask]))
        else:
            cluster_depths.append(0)

    depth_order = np.argsort(cluster_depths)[::-1]
    # layer_of_cluster[c] is the FMB layer assigned to K-means cluster c.
    layer_of_cluster = np.empty(3, dtype=np.uint8)
    layer_of_cluster[depth_order] = np.arange(3, dtype=np.uint8)

    fmb_map = np.full((H, W), 2, dtype=np.uint8)
    # Boolean-mask assignment visits pixels in the same row-major order used
    # to extract depth_valid (and hence cluster_labels).
    fmb_map[valid_mask] = layer_of_cluster[cluster_labels]

    kmeans_original = fmb_map.copy()

    # Initialize Forced Semantic Rules
    print("\n" + "=" * 50)
    print("Initializing Forced Semantic Rules...")
    print("=" * 50)
    forced_rules = ForcedSemanticRules(idx_to_info)

    # Apply Forced Rules FIRST
    print("\n" + "=" * 50)
    print("Applying Forced Semantic Rules...")
    print("=" * 50)

    forced_adjustments = 0
    for sem_class in np.unique(semantic_map):
        if sem_class == 0:
            continue

        forced_layer = forced_rules.get_forced_layer(sem_class)
        if forced_layer is not None:
            class_mask = semantic_map == sem_class
            pixel_count = np.sum(class_mask)
            if pixel_count > 0:
                class_name = idx_to_info.get(sem_class, {}).get('name', f'Class_{sem_class}')
                current_layers = fmb_map[class_mask]
                changed_pixels = np.sum(current_layers != forced_layer)
                fmb_map[class_mask] = forced_layer
                # Report the actual target layer instead of always "Background".
                forced_name = layer_names[forced_layer] if forced_layer in (0, 1, 2) else 'Unknown'
                print(f"  {class_name}: {pixel_count} pixels -> Layer {forced_layer} ({forced_name})")
                print(f"    Changed {changed_pixels} pixels ({changed_pixels/pixel_count*100:.1f}%)")
                forced_adjustments += 1

    print(f"\nForced rule adjustments: {forced_adjustments} classes processed")

    # Intelligent Connectivity Processing
    print("\n" + "=" * 50)
    print("Processing closed objects with intelligent system...")
    print("=" * 50)

    decision_system = IntelligentLayerDecisionSystem(
        depth_map, semantic_map, fmb_map, openness, countability, idx_to_info,
        forced_rules=forced_rules
    )

    adjustments = []
    adjustment_count = 0

    unique_classes = np.unique(semantic_map)
    for sem_class in unique_classes:
        if sem_class == 0:
            continue

        if sem_class <= len(openness) and openness[sem_class - 1] == 1:
            class_name = idx_to_info[sem_class]['name']

            if forced_rules.is_forced(sem_class):
                print(f"\nSkipping {class_name} (already processed by forced rules)")
                continue

            print(f"\nProcessing closed class: {class_name}")

            class_mask = (semantic_map == sem_class).astype(np.uint8)
            labeled_array, num_features = ndimage.label(class_mask)
            print(f"  Found {num_features} connected components")

            for region_id in range(1, num_features + 1):
                region_mask = labeled_array == region_id
                pixel_count = np.sum(region_mask)

                if pixel_count < 20:
                    continue

                result = decision_system.intelligent_layer_decision(region_mask)
                if result[0] is not None:
                    layer, confidence, details = result

                    current_labels = fmb_map[region_mask]
                    if not np.all(current_labels == layer):
                        fmb_map[region_mask] = layer
                        adjustment_count += 1

                        adjustments.append({
                            'class': sem_class, 'class_name': idx_to_info[sem_class]['name'],
                            'region': region_id, 'layer': layer, 'confidence': confidence, 'details': details
                        })

                        print(f"\n  Adjusted object: {idx_to_info[sem_class]['name']} (region {region_id})")
                        print(f"    Decided layer: {layer_names[layer]}")
                        print(f"    Confidence: {confidence:.2f}")

    if adjustment_count > 0:
        print(f"\nProcessing complete: Unified {adjustment_count} closed objects to single layers")

    # Hole Filling
    fmb_before_filling = None
    fill_info = None

    if enable_hole_filling:
        print("\n" + "=" * 50)
        print("Executing intelligent hole filling post-processing...")
        print("=" * 50)

        fmb_before_filling = fmb_map.copy()
        hole_filler = IntelligentHoleFilling(depth_map, fmb_map)
        fmb_map, fill_info = hole_filler.process()
        adjustments.append({'type': 'hole_filling', 'info': fill_info})

    # Statistics over valid-depth pixels only. Invalid pixels default to layer 2,
    # so counting them against a valid-only denominator inflated the background
    # share and could push the three percentages past 100% in total.
    print("\nFinal deviation from theoretical distribution:")
    valid_layers = fmb_map[valid_mask]
    total_valid = valid_layers.size
    for i in range(3):
        actual_count = np.sum(valid_layers == i)
        actual_pct = (actual_count / total_valid) * 100 if total_valid > 0 else 0
        theory_pct = [51.0, 26.0, 23.0][i]
        deviation = actual_pct - theory_pct
        print(f"  Layer {i}: {actual_pct:.1f}% (theory: {theory_pct}%, deviation: {deviation:+.1f}%)")

    return fmb_map, kmeans_original, adjustments, fmb_before_filling, fill_info

# ==================== Visualization Functions ====================

def create_colored_fmb_map(fmb_map):
    """Render a 2-D FMB layer map as an RGB image.

    Layer 0 (foreground) becomes (220, 20, 60), layer 1 (middleground)
    (46, 125, 50) and layer 2 (background) (30, 144, 255). Any other value
    stays black.

    Returns:
        (H, W, 3) uint8 array.
    """
    rows, cols = fmb_map.shape
    rgb_image = np.zeros((rows, cols, 3), dtype=np.uint8)
    palette = (
        (0, (220, 20, 60)),
        (1, (46, 125, 50)),
        (2, (30, 144, 255)),
    )
    for layer_value, rgb in palette:
        rgb_image[fmb_map == layer_value] = rgb
    return rgb_image

def visualize_semantic_with_original_colors(semantic_map, idx_to_info):
    """Paint every class index in a 2-D ``semantic_map`` with its configured color.

    Each class ``idx`` in ``idx_to_info`` is drawn with ``idx_to_info[idx]['rgb']``.
    Pixels whose index is not a key of ``idx_to_info`` (such as 0) stay black.

    Returns:
        (H, W, 3) uint8 array.
    """
    height, width = semantic_map.shape
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    for class_idx in idx_to_info:
        canvas[semantic_map == class_idx] = idx_to_info[class_idx]['rgb']
    return canvas

def visualize_closed_objects(semantic_map, openness):
    """Give every connected component of every closed class its own color.

    Class ``c`` (``c > 0``) is closed when ``openness[c - 1] == 1``.
    Components come from ``scipy.ndimage.label`` with its default structuring
    element. Colors are drawn from a private ``np.random.RandomState(42)`` and
    resampled until their channel mean is within [80, 200]. The output is
    reproducible and NumPy's global RNG is left untouched.

    Returns:
        tuple ``(colored_objects, object_count)``: an (H, W, 3) uint8 image in
        which non-closed pixels are black, and the number of colored components.
    """
    H, W = semantic_map.shape
    colored_objects = np.zeros((H, W, 3), dtype=np.uint8)
    object_id = 0
    # Private generator: yields the same colors as the former
    # np.random.seed(42) + np.random.randint, without reseeding the global RNG
    # as a hidden side effect for every other caller.
    rng = np.random.RandomState(42)
    unique_classes = np.unique(semantic_map)
    print(f"\nChecking closed objects...")

    for sem_class in unique_classes:
        if not (0 < sem_class <= len(openness)) or openness[sem_class - 1] != 1:
            continue
        class_mask = (semantic_map == sem_class).astype(np.uint8)
        labeled_array, num_features = ndimage.label(class_mask)
        for region_id in range(1, num_features + 1):
            region_mask = labeled_array == region_id
            color = rng.randint(50, 256, 3)
            # Resample until the color is neither too dark nor too bright.
            while np.mean(color) < 80 or np.mean(color) > 200:
                color = rng.randint(50, 256, 3)
            colored_objects[region_mask] = color
            object_id += 1

    return colored_objects, object_id

def visualize_adjustment_process(kmeans_original, fmb_map, semantic_map, openness):
    """Highlight closed-class objects whose layer changed after K-means.

    The function considers every connected component of every closed class
    (``openness[c - 1] == 1``) that has at least one pixel differing between
    ``kmeans_original`` and ``fmb_map``. For each one it compares the majority
    layer before and after; both maps must hold non-negative integers.

    * Objects moved to a larger layer index (farther) are painted (255, 100, 100).
    * Objects moved to a smaller index (nearer) are painted (100, 255, 100).
    * If ``semantic_map`` or ``openness`` is ``None``, only the pixel mask is
      filled in.

    Returns:
        dict with:
        * ``'mask'`` - the pixel-level change mask;
        * ``'adjusted_objects_map'`` - an (H, W, 3) uint8 image;
        * ``'object_info'`` - a list of dicts with ``class``, ``object_id``,
          ``from_layer``, ``to_layer`` and ``pixel_count``.
    """
    H, W = fmb_map.shape
    adjustment_vis = {}
    adjustment_mask = kmeans_original != fmb_map
    adjustment_vis['mask'] = adjustment_mask
    adjusted_objects_map = np.zeros((H, W, 3), dtype=np.uint8)
    object_adjustment_info = []

    if semantic_map is not None and openness is not None:
        for sem_class in np.unique(semantic_map):
            if not (0 < sem_class <= len(openness)) or openness[sem_class - 1] != 1:
                continue
            class_mask = (semantic_map == sem_class).astype(np.uint8)
            labeled_array, num_features = ndimage.label(class_mask)
            for region_id in range(1, num_features + 1):
                region_mask = labeled_array == region_id
                if not np.any(adjustment_mask[region_mask]):
                    continue
                # Majority vote on both sides. Hole filling can leave a region
                # non-uniform, so its first pixel is not a reliable stand-in
                # for the region's final layer.
                orig_label = np.bincount(kmeans_original[region_mask]).argmax()
                final_label = np.bincount(fmb_map[region_mask]).argmax()
                if orig_label == final_label:
                    continue
                color = [255, 100, 100] if orig_label < final_label else [100, 255, 100]
                adjusted_objects_map[region_mask] = color
                object_adjustment_info.append({
                    'class': sem_class, 'object_id': region_id,
                    'from_layer': orig_label, 'to_layer': final_label,
                    'pixel_count': np.sum(region_mask)
                })

    adjustment_vis['adjusted_objects_map'] = adjusted_objects_map
    adjustment_vis['object_info'] = object_adjustment_info
    return adjustment_vis

def create_semantic_table(idx_to_info, semantic_map):
    """Summarize and print the semantic classes present in ``semantic_map``.

    Index 0 is ignored. For each remaining class the row is
    ``[class_id, name, openness, countability, pixels, percent]``, where
    ``pixels`` is a comma-grouped string and ``percent`` is relative to the
    whole map. Classes missing from ``idx_to_info`` are reported as
    ``Unknown_Class_<id>``. Rows are printed sorted by pixel count (descending).

    Returns:
        list of rows in ascending class-id order.
    """
    present_classes = np.unique(semantic_map[semantic_map > 0])
    print(f"\nSemantic categories found in image: {present_classes}")
    total_pixels = semantic_map.size

    table_data = []
    for class_id in present_classes:
        if class_id in idx_to_info:
            info = idx_to_info[class_id]
            name = info['name']
            openness_str = "Closed" if info['openness'] == 1 else "Open"
            countable_str = "Countable" if info['countable'] == 1 else "Uncountable"
        else:
            name = f"Unknown_Class_{class_id}"
            openness_str = countable_str = "Unknown"
        pixels = np.sum(semantic_map == class_id)
        percentage = (pixels / total_pixels) * 100
        table_data.append([class_id, name, openness_str, countable_str, f"{pixels:,}", f"{percentage:.1f}%"])

    rule = "-" * 90
    print("\nSemantic Categories in Image:")
    print(rule)
    print(f"{'ID':>4} | {'Category Name':<40} | {'Openness':<8} | {'Countable':<10} | {'Pixels':>10} | {'%':>6}")
    print(rule)
    by_pixels = sorted(table_data, key=lambda row: int(row[4].replace(',', '')), reverse=True)
    for cid, cname, open_s, count_s, px, pct in by_pixels:
        print(f"{cid:>4} | {cname:<40} | {open_s:<8} | {count_s:<10} | {px:>10} | {pct:>6}")
    print(rule)
    return table_data

def visualize_hole_filling_results(original_fmb, filled_fmb, depth_map, fill_info):
    """Build a 2x2 RGB comparison panel for the hole-filling step.

    Quadrants:
    * top-left: the original FMB colors;
    * top-right: the filled FMB colors;
    * bottom-left: the original with changed pixels in yellow;
    * bottom-right: the depth map under OpenCV's VIRIDIS colormap.

    Titles and a "Filled / Preserved" summary are drawn in white.

    Args:
        original_fmb: (H, W) layer map before hole filling.
        filled_fmb: (H, W) layer map after hole filling.
        depth_map: (H, W) depth array.
        fill_info: dict with at least ``'holes_filled'`` and ``'holes_preserved'``.

    Returns:
        (2H, 2W, 3) uint8 image in RGB channel order.
    """
    H, W = original_fmb.shape
    comparison = np.zeros((H * 2, W * 2, 3), dtype=np.uint8)

    # Reuse the shared layer palette rather than duplicating it here.
    original_colored = create_colored_fmb_map(original_fmb)
    comparison[:H, :W] = original_colored
    comparison[:H, W:] = create_colored_fmb_map(filled_fmb)

    changes = original_fmb != filled_fmb
    change_vis = original_colored.copy()
    change_vis[changes] = [255, 255, 0]
    comparison[H:, :W] = change_vis

    depth_min = depth_map.min()
    depth_span = depth_map.max() - depth_min
    if depth_span > 0:
        depth_normalized = ((depth_map - depth_min) / depth_span * 255).astype(np.uint8)
    else:
        # Constant depth: avoid 0/0 (NaN), whose uint8 cast is undefined.
        depth_normalized = np.zeros(depth_map.shape, dtype=np.uint8)
    # cv2.applyColorMap returns BGR, but this panel is assembled in RGB and
    # main() converts it RGB->BGR before cv2.imwrite. Converting here stops
    # the colormap's red and blue channels from ending up swapped.
    depth_colored = cv2.cvtColor(cv2.applyColorMap(depth_normalized, cv2.COLORMAP_VIRIDIS),
                                 cv2.COLOR_BGR2RGB)
    comparison[H:, W:] = depth_colored

    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.putText(comparison, 'Original FMB', (10, 30), font, 1, (255, 255, 255), 2)
    cv2.putText(comparison, 'After Hole Filling', (W + 10, 30), font, 1, (255, 255, 255), 2)
    cv2.putText(comparison, 'Changes (Yellow)', (10, H + 30), font, 1, (255, 255, 255), 2)
    cv2.putText(comparison, 'Depth Reference', (W + 10, H + 30), font, 1, (255, 255, 255), 2)

    stats_text = f"Filled: {fill_info['holes_filled']}, Preserved: {fill_info['holes_preserved']}"
    cv2.putText(comparison, stats_text, (10, H * 2 - 10), font, 0.8, (255, 255, 255), 2)

    return comparison

# ==================== Main Function ====================

def main(enable_hole_filling=True):
    """Main function, executes complete intelligent FMB segmentation workflow"""
    print("=" * 50)
    print("Intelligent FMB Segmentation System")
    print("  Foreground-Middleground-Background Segmentation")
    print("  Version 2.1 (with Forced Semantic Rules)")
    print("  - Sky is ALWAYS Background (forced rule)")
    print("=" * 50)
    print("\nPlease upload the following files:")
    print("1. Semantic segmentation image (semantic.png)")
    print("2. Depth map (depth.png)")
    print("3. JSON configuration file")
    print("\n" + "=" * 50)

    if IN_COLAB:
        uploaded = files.upload()
    else:
        print("Please ensure files are in the current directory")
        uploaded = {}
        import glob
        png_files = glob.glob("*.png")
        json_files = glob.glob("*.json")
        for f in png_files + json_files:
            uploaded[f] = None

    semantic_path = None
    depth_path = None
    json_path = None

    for filename in uploaded.keys():
        lower_name = filename.lower()
        if 'semantic' in lower_name and lower_name.endswith('.png'):
            semantic_path = filename
        elif 'depth' in lower_name and lower_name.endswith('.png'):
            depth_path = filename
        elif lower_name.endswith('.json'):
            json_path = filename

    if semantic_path is None:
        for filename in uploaded.keys():
            if filename.lower().endswith('.png') and 'depth' not in filename.lower():
                semantic_path = filename
                break

    if depth_path is None:
        for filename in uploaded.keys():
            if filename.lower().endswith('.png') and filename != semantic_path:
                depth_path = filename
                break

    print(f"\nIdentified files:")
    print(f"  Semantic: {semantic_path}")
    print(f"  Depth: {depth_path}")
    print(f"  JSON: {json_path}")

    if not all([semantic_path, depth_path, json_path]):
        print("\nError: Could not identify all required files!")
        return

    print("\nReading configuration...")
    color_to_idx, idx_to_info, openness, countability = read_config_from_json(json_path)
    print(f"  Loaded {len(color_to_idx)} semantic classes")

    print("\nReading images...")
    semantic_img = cv2.imread(semantic_path)
    semantic_img = cv2.cvtColor(semantic_img, cv2.COLOR_BGR2RGB)
    depth_img = cv2.imread(depth_path, cv2.IMREAD_GRAYSCALE)
    depth_map = depth_img.astype(np.float32)

    H, W = depth_map.shape
    print(f"  Image size: {W}x{H}")

    print("\nConverting semantic image...")
    semantic_map = semantic_rgb_to_idx_map(semantic_img, color_to_idx)

    create_semantic_table(idx_to_info, semantic_map)

    print("\n" + "=" * 50)
    print("Starting FMB Segmentation...")
    print("=" * 50)

    fmb_map, kmeans_original, adjustments, fmb_before_filling, fill_info = \
        process_fmb_segmentation_with_intelligent_system(
            depth_map, semantic_map, openness, countability, idx_to_info,
            enable_hole_filling=enable_hole_filling
        )

    print("\n" + "=" * 50)
    print("Generating visualizations...")
    print("=" * 50)

    colored_kmeans = create_colored_fmb_map(kmeans_original)
    colored_fmb = create_colored_fmb_map(fmb_map)
    colored_semantic = visualize_semantic_with_original_colors(semantic_map, idx_to_info)
    colored_objects, total_objects = visualize_closed_objects(semantic_map, openness)
    adjustment_vis = visualize_adjustment_process(kmeans_original, fmb_map, semantic_map, openness)

    if fmb_before_filling is not None:
        colored_before_filling = create_colored_fmb_map(fmb_before_filling)
        colored_after_filling = create_colored_fmb_map(fmb_map)

    closed_semantic_colored = np.zeros((H, W, 3), dtype=np.uint8)
    for i in range(len(openness)):
        if openness[i] == 1:
            mask = semantic_map == i+1
            if i+1 in idx_to_info:
                closed_semantic_colored[mask] = idx_to_info[i+1]['rgb']

    # Create visualization figure
    fig = plt.figure(figsize=(24, 18))

    ax1 = plt.subplot(3, 4, 1)
    ax1.imshow(colored_semantic)
    ax1.set_title('Semantic Segmentation (Original Colors)', fontsize=12)
    ax1.axis('off')

    ax2 = plt.subplot(3, 4, 2)
    ax2.imshow(depth_map, cmap='viridis')
    ax2.set_title('Depth Map (0=far, 255=near)', fontsize=12)
    ax2.axis('off')

    ax3 = plt.subplot(3, 4, 3)
    ax3.imshow(colored_objects)
    ax3.set_title(f'Closed Objects ({total_objects} objects)', fontsize=12)
    ax3.axis('off')

    ax4 = plt.subplot(3, 4, 4)
    ax4.imshow(closed_semantic_colored)
    ax4.set_title('Closed Classes Only (Original Colors)', fontsize=12)
    ax4.axis('off')

    ax5 = plt.subplot(3, 4, 5)
    ax5.imshow(colored_kmeans)
    ax5.set_title('K-means Original', fontsize=12)
    ax5.axis('off')

    ax6 = plt.subplot(3, 4, 6)
    ax6.imshow(colored_fmb)
    ax6.set_title('After Intelligent Processing' + (' + Hole Filling' if enable_hole_filling else ''), fontsize=12)
    ax6.axis('off')

    ax7 = plt.subplot(3, 4, 7)
    diff_map = np.zeros_like(colored_fmb)
    diff_mask = kmeans_original != fmb_map
    diff_map[~diff_mask] = colored_kmeans[~diff_mask] * 0.3
    diff_map[diff_mask] = [255, 255, 0]
    ax7.imshow(diff_map)
    ax7.set_title('Changes (Yellow = Adjusted)', fontsize=12)
    ax7.axis('off')

    ax8 = plt.subplot(3, 4, 8)
    ax8.imshow(adjustment_vis['adjusted_objects_map'])
    ax8.set_title('Adjusted Objects Detail', fontsize=12)
    ax8.axis('off')

    ax9 = plt.subplot(3, 4, 9)
    comparison = np.hstack([colored_kmeans[:, :W//2], colored_fmb[:, W//2:]])
    ax9.imshow(comparison)
    ax9.set_title('K-means (left) vs Final (right)', fontsize=12)
    ax9.axis('off')

    ax10 = plt.subplot(3, 4, 10)
    overlay = colored_fmb.copy()
    object_edges = ndimage.sobel(colored_objects.mean(axis=2)) > 0
    overlay[object_edges] = [255, 255, 255]
    ax10.imshow(overlay)
    ax10.set_title('FMB + Object Boundaries', fontsize=12)
    ax10.axis('off')

    ax11 = plt.subplot(3, 4, 11)
    adjustment_counts = {'Foreground': 0, 'Middleground': 0, 'Background': 0}
    for info in adjustment_vis['object_info']:
        to_name = ['Foreground', 'Middleground', 'Background'][info['to_layer']]
        adjustment_counts[to_name] += 1
    if any(adjustment_counts.values()):
        bars = ax11.bar(adjustment_counts.keys(), adjustment_counts.values())
        ax11.set_title('Objects Moved To Each Layer', fontsize=12)
        ax11.set_ylabel('Number of Objects')
    else:
        ax11.text(0.5, 0.5, 'No Adjustments Made', ha='center', va='center', transform=ax11.transAxes)
        ax11.set_title('Adjustment Statistics', fontsize=12)

    ax12 = plt.subplot(3, 4, 12)
    ax12.axis('off')
    legend_elements = [
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=[220/255, 20/255, 60/255], markersize=10, label='Foreground (Near)'),
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=[46/255, 125/255, 50/255], markersize=10, label='Middleground'),
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=[30/255, 144/255, 255/255], markersize=10, label='Background (Far)'),
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='yellow', markersize=10, label='Adjusted Pixels'),
    ]
    ax12.legend(handles=legend_elements, loc='center', fontsize=10)
    ax12.set_title('Legend', fontsize=12)

    plt.tight_layout()
    plt.show()

    # Hole filling visualization
    if fmb_before_filling is not None and fill_info is not None:
        fig_holes = plt.figure(figsize=(20, 10))

        ax1 = plt.subplot(2, 3, 1)
        ax1.imshow(create_colored_fmb_map(fmb_before_filling))
        ax1.set_title('Before Hole Filling', fontsize=14)
        ax1.axis('off')

        ax2 = plt.subplot(2, 3, 2)
        ax2.imshow(create_colored_fmb_map(fmb_map))
        ax2.set_title('After Hole Filling', fontsize=14)
        ax2.axis('off')

        ax3 = plt.subplot(2, 3, 3)
        changes_mask = fmb_before_filling != fmb_map
        change_vis = create_colored_fmb_map(fmb_before_filling).copy()
        change_vis[changes_mask] = [255, 255, 0]
        ax3.imshow(change_vis)
        ax3.set_title(f'Filled Holes (Yellow) - {fill_info["holes_filled"]} holes', fontsize=14)
        ax3.axis('off')

        ax4 = plt.subplot(2, 3, 4)
        ax4.imshow(depth_map, cmap='viridis')
        ax4.set_title('Depth Map Reference', fontsize=14)
        ax4.axis('off')

        ax5 = plt.subplot(2, 3, 5)
        labels = ['Holes Filled', 'Holes Preserved']
        sizes = [fill_info['holes_filled'], fill_info['holes_preserved']]
        if sum(sizes) > 0:
            ax5.pie(sizes, labels=labels, autopct='%1.0f%%', startangle=90)
        else:
            ax5.text(0.5, 0.5, 'No holes detected', ha='center', va='center', transform=ax5.transAxes)
        ax5.set_title('Hole Filling Statistics', fontsize=14)

        ax6 = plt.subplot(2, 3, 6)
        ax6.axis('off')
        legend_text = f"""
Hole Filling Summary:
- Total holes detected: {fill_info['total_holes_detected']}
- Holes filled: {fill_info['holes_filled']}
- Holes preserved: {fill_info['holes_preserved']}
"""
        ax6.text(0.1, 0.5, legend_text, transform=ax6.transAxes, fontsize=12, verticalalignment='center', fontfamily='monospace')

        plt.tight_layout()
        plt.show()

    # Print statistics
    print("\nFMB Distribution Statistics:")
    print("-" * 60)
    print(f"{'Layer':<15} {'K-means Original':>20} {'Final Result':>20}")
    print("-" * 60)
    total_pixels = fmb_map.size
    for layer, layer_name in enumerate(['Foreground', 'Middleground', 'Background']):
        kmeans_count = np.sum(kmeans_original == layer)
        final_count = np.sum(fmb_map == layer)
        kmeans_pct = (kmeans_count / total_pixels) * 100
        final_pct = (final_count / total_pixels) * 100
        print(f"{layer_name:<15} {kmeans_count:>8,} ({kmeans_pct:>5.1f}%) {final_count:>8,} ({final_pct:>5.1f}%)")
    print("-" * 60)
    changed_pixels = np.sum(kmeans_original != fmb_map)
    changed_pct = (changed_pixels / total_pixels) * 100
    print(f"\nAdjusted pixels: {changed_pixels:,} ({changed_pct:.1f}%)")

    # Save results
    cv2.imwrite('fmb_kmeans_original.png', cv2.cvtColor(colored_kmeans, cv2.COLOR_RGB2BGR))
    cv2.imwrite('fmb_segmentation_final.png', cv2.cvtColor(colored_fmb, cv2.COLOR_RGB2BGR))
    cv2.imwrite('closed_objects.png', cv2.cvtColor(colored_objects, cv2.COLOR_RGB2BGR))
    cv2.imwrite('semantic_original_colors.png', cv2.cvtColor(colored_semantic, cv2.COLOR_RGB2BGR))
    cv2.imwrite('closed_classes_original_colors.png', cv2.cvtColor(closed_semantic_colored, cv2.COLOR_RGB2BGR))

    if adjustment_vis['adjusted_objects_map'] is not None:
        cv2.imwrite('adjustment_visualization.png', cv2.cvtColor(adjustment_vis['adjusted_objects_map'], cv2.COLOR_RGB2BGR))

    print(f"\nResults saved:")
    print(f"  - fmb_kmeans_original.png")
    print(f"  - fmb_segmentation_final.png")
    print(f"  - closed_objects.png")
    print(f"  - semantic_original_colors.png")
    print(f"  - closed_classes_original_colors.png")
    print(f"  - adjustment_visualization.png")

    if fmb_before_filling is not None:
        hole_comparison = visualize_hole_filling_results(fmb_before_filling, fmb_map, depth_map, fill_info)
        cv2.imwrite('hole_filling_comparison.png', cv2.cvtColor(hole_comparison, cv2.COLOR_RGB2BGR))
        print(f"  - hole_filling_comparison.png")

    if IN_COLAB:
        try:
            files.download('fmb_kmeans_original.png')
            files.download('fmb_segmentation_final.png')
            files.download('closed_objects.png')
            files.download('adjustment_visualization.png')
            files.download('semantic_original_colors.png')
            files.download('closed_classes_original_colors.png')
            if fmb_before_filling is not None:
                files.download('hole_filling_comparison.png')
            print("\nProcessing complete! Files downloaded.")
        except Exception as e:
            print(f"\nProcessing complete! Files saved but download failed: {e}")
    else:
        print(f"\nProcessing complete!")

# Entry point
def in_notebook():
    """Return True when running inside an IPython kernel (Jupyter/Colab).

    The check looks for an ``IPKernelApp`` section in the active IPython
    configuration. Returns False when IPython is not installed or when no
    IPython session is active (``get_ipython()`` returns None).
    """
    try:
        from IPython import get_ipython
    except ImportError:
        return False
    shell = get_ipython()
    # get_ipython() returns None when there is no active IPython session.
    if shell is None:
        return False
    return 'IPKernelApp' in getattr(shell, 'config', {})


if __name__ == "__main__":
    import subprocess
    import sys

    enable_hole_filling = True

    if not in_notebook():
        import argparse
        parser = argparse.ArgumentParser(description='Intelligent FMB Segmentation System')
        parser.add_argument('--no-hole-filling', action='store_true', help='Disable hole filling')
        # parse_known_args: tolerate unrelated arguments instead of aborting.
        args, _ = parser.parse_known_args()
        enable_hole_filling = not args.no_hole_filling
    else:
        print("Running in Jupyter/Colab environment")
        print(f"   Hole filling: {'Enabled' if enable_hole_filling else 'Disabled'}")

    print("Installing required libraries...")
    # Install into the interpreter actually running this script (not whatever
    # `pip` is first on PATH) and silence output without relying on a POSIX
    # shell's `> /dev/null` redirection.
    # NOTE(review): scipy and scikit-learn are already imported at the top of
    # this file, so installing them here cannot rescue a missing install.
    try:
        install_result = subprocess.run(
            [sys.executable, '-m', 'pip', 'install',
             'pandas', 'openpyxl', 'scikit-learn', 'scipy'],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=False,
        )
        install_error = None if install_result.returncode == 0 else f"pip exit code {install_result.returncode}"
    except OSError as exc:
        install_error = str(exc)

    if install_error is None:
        print("Installation complete!\n")
    else:
        # Best-effort: continue with whatever is already installed.
        print(f"Warning: dependency installation failed ({install_error}); continuing.\n")

    plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
    main(enable_hole_filling=enable_hole_filling)
