In [11]:
import os
import numpy as np
from PIL import Image
import glob
from sklearn.cluster import KMeans
from collections import Counter

def count_classes_in_segmentation_maps(
    folder_path,
    color_tolerance=10,
    max_clusters=50,
    extensions=("*.png",),
):
    """
    Find the image with the most distinct classes in a folder of segmentation maps.

    Each image is loaded into a numpy array. Arrays with a channel axis
    (3-D) are counted with count_clustered_colors(); 2-D arrays are counted
    by their raw number of unique values (no grouping is applied to them).
    Images that fail to load or process are reported and skipped.

    Args:
        folder_path (str): Path to the folder containing segmentation maps
        color_tolerance (int): Distance threshold forwarded to
            count_clustered_colors()
        max_clusters (int): Maximum number of clusters forwarded to
            count_clustered_colors()
        extensions (str | iterable of str): Glob pattern(s), relative to
            folder_path, selecting the files to process

    Returns:
        tuple: (max_classes, max_classes_image, all_class_counts) where
            all_class_counts is a list of (file name, class count) pairs for
            every image that was processed successfully. Returns
            (0, None, []) when no file matches.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)

    # glob returns files in filesystem order; sorting makes the processing
    # order (and therefore the tie-break for the "max" image) reproducible.
    # The set removes duplicates when patterns overlap.
    image_files = sorted({
        path
        for pattern in extensions
        for path in glob.glob(os.path.join(folder_path, pattern))
    })

    if not image_files:
        print(f"No image files found in {folder_path}")
        return 0, None, []

    max_classes = 0
    max_classes_image = None
    all_class_counts = []
    total = len(image_files)

    print(f"Processing {total} images...")
    print(f"Using color tolerance: {color_tolerance}")

    for i, image_path in enumerate(image_files, start=1):
        try:
            # The context manager releases the file handle as soon as the
            # pixel data has been copied into the array.
            with Image.open(image_path) as img:
                img_array = np.array(img)

            if img_array.ndim == 3:
                # Multi-channel image - group similar colors
                num_classes = count_clustered_colors(img_array, color_tolerance, max_clusters)
            else:
                print("WARNING: **********GRAYSCALE DETECTED**********")
                # No channel axis - count raw unique values (no grouping)
                num_classes = len(np.unique(img_array))

            all_class_counts.append((os.path.basename(image_path), num_classes))

            # Strict ">" keeps the first image seen when several tie
            if num_classes > max_classes:
                max_classes = num_classes
                max_classes_image = image_path

            # Print progress every 10 images
            if i % 10 == 0:
                print(f"Processed {i}/{total} images...")

        except Exception as e:
            # Best effort: one unreadable file must not abort the whole batch
            print(f"Error processing {image_path}: {e}")
            continue

    return max_classes, max_classes_image, all_class_counts

def count_clustered_colors(img_array, tolerance=10, max_clusters=50):
    """
    Count the colors of an image, capping the result via K-means clustering.

    The image is flattened to a list of pixels (last axis = channels) and its
    exact unique colors are counted. If there are at most `max_clusters` of
    them, that exact count is returned. Otherwise the unique colors are
    clustered with K-means using `max_clusters` clusters and the number of
    clusters that actually received members is returned.

    NOTE: `tolerance` is only used by the distance-based fallback that runs
    when clustering raises; the K-means path ignores it, so its result is
    bounded by `max_clusters` regardless of how far apart the colors are.

    Args:
        img_array: numpy array of the image, with channels on the last axis
        tolerance: color distance threshold (fallback path only)
        max_clusters: maximum number of clusters to consider

    Returns:
        int: number of unique colors / color clusters
    """
    # Flatten the image to one row per pixel
    pixels = img_array.reshape(-1, img_array.shape[-1])

    unique_colors = np.unique(pixels, axis=0)
    n_unique = len(unique_colors)

    # Few enough colors: the exact count is the answer, no clustering needed
    if n_unique <= max_clusters:
        return n_unique

    try:
        # n_unique > max_clusters here, so max_clusters is the cluster count
        kmeans = KMeans(n_clusters=max_clusters, random_state=42, n_init=10)
        kmeans.fit(unique_colors)
    except Exception:
        # Catch Exception rather than everything so Ctrl-C / SystemExit still
        # stop a long batch run; fall back to the simpler distance grouping.
        return count_colors_by_distance(unique_colors, tolerance)

    # Count the clusters that ended up with at least one color assigned
    return len(np.unique(kmeans.labels_))

def count_colors_by_distance(colors, tolerance=10):
    """
    Count unique colors by grouping colors within a distance threshold.

    Colors are visited in order. A color joins the first existing group whose
    representative (the color that opened the group) lies within `tolerance`
    in Euclidean distance; otherwise it opens a new group. The result
    therefore depends on the order of `colors`.

    Args:
        colors: array-like of colors, one color per row
        tolerance: distance threshold for grouping (inclusive)

    Returns:
        int: number of unique color groups
    """
    # Work in float: subtracting unsigned 8-bit colors wraps around modulo 256
    # (e.g. 10 - 20 -> 246), which would make near-identical colors look far
    # apart and inflate the group count.
    colors = np.asarray(colors, dtype=np.float64)

    if len(colors) == 0:
        return 0

    # One representative color per group
    representatives = []

    for color in colors:
        is_close_to_existing = any(
            np.linalg.norm(color - representative) <= tolerance
            for representative in representatives
        )
        # If not close to any group, this color starts a new group
        if not is_close_to_existing:
            representatives.append(color)

    return len(representatives)

def analyze_specific_image(image_path, color_tolerance=10, max_clusters=50):
    """
    Print a report of the unique classes/colors found in one image.

    For an array with a channel axis (3-D) the report contains the raw number
    of unique colors, the count returned by count_clustered_colors(), and the
    ten most frequent colors with their pixel share. For a 2-D array it
    contains the number of unique values and the first 20 of them. Any error
    is caught and printed; nothing is returned.

    Args:
        image_path (str): Path to the image
        color_tolerance (int): Distance threshold for grouping colors
        max_clusters (int): Maximum clusters to consider
    """
    try:
        # Close the file handle as soon as the pixels are in memory
        with Image.open(image_path) as img:
            img_array = np.array(img)

        print(f"\nAnalyzing: {os.path.basename(image_path)}")
        print(f"Image shape: {img_array.shape}")

        if img_array.ndim == 3:
            # One row per pixel, channels on the last axis
            pixels = img_array.reshape(-1, img_array.shape[-1])

            # A single pass yields both the distinct colors and their pixel
            # counts, avoiding a Python tuple per pixel.
            unique_colors, pixel_counts = np.unique(pixels, axis=0, return_counts=True)

            print(f"Raw unique colors: {len(unique_colors)}")

            # Group similar colors
            clustered_count = count_clustered_colors(img_array, color_tolerance, max_clusters)
            print(f"Clustered unique colors (tolerance={color_tolerance}): {clustered_count}")

            # Indices of the 10 most frequent colors; the stable sort keeps
            # equally frequent colors in np.unique's sorted order.
            top_indices = np.argsort(-pixel_counts, kind="stable")[:10]
            total_pixels = len(pixels)

            print("Most common colors:")
            for rank, idx in enumerate(top_indices, start=1):
                # Plain ints so the tuple prints as (r, g, b) rather than
                # with numpy scalar reprs.
                color = tuple(int(channel) for channel in unique_colors[idx])
                count = int(pixel_counts[idx])
                percentage = (count / total_pixels) * 100
                print(f"  {rank}. RGB{color}: {count} pixels ({percentage:.1f}%)")

        else:
            # No channel axis: report the raw values
            unique_values = np.unique(img_array)
            print(f"Number of unique values: {len(unique_values)}")
            print("Unique values:", unique_values[:20])  # Show first 20 values

    except Exception as e:
        print(f"Error analyzing {image_path}: {e}")

# Script entry point: scan a folder of segmentation maps and print a summary
if __name__ == "__main__":
    # Folder holding the segmentation maps; point this at your own data
    folder_path = "./CVUSA_subset/polarmap/segmap/"

    # Tunables: grouping distance for colors and the expected upper bound
    # on the number of semantic classes
    color_tolerance = 15
    max_clusters = 30

    print("Configuration:")
    print(f"  Color tolerance: {color_tolerance}")
    print(f"  Max clusters: {max_clusters}")
    print()

    # Scan every image in the folder
    max_classes, max_image, class_counts = count_classes_in_segmentation_maps(
        folder_path,
        color_tolerance,
        max_clusters,
    )

    if max_classes <= 0:
        print("No valid images found or processed.")
    else:
        separator = "=" * 50
        print(f"\n{separator}")
        print("RESULTS:")
        print(separator)
        print(f"Maximum number of classes in a single image: {max_classes}")
        print(f"Image with maximum classes: {os.path.basename(max_image) if max_image else 'None'}")

        # Aggregate statistics over the per-image class counts
        counts = [entry[1] for entry in class_counts]
        print("\nStatistics across all images:")
        print(f"  Average classes per image: {np.mean(counts):.2f}")
        print(f"  Minimum classes in an image: {np.min(counts)}")
        print(f"  Maximum classes in an image: {np.max(counts)}")

        # Ranking of the ten images with the highest class counts
        print("\nTop 10 images with most classes:")
        sorted_counts = sorted(class_counts, key=lambda entry: entry[1], reverse=True)
        for rank, (filename, count) in enumerate(sorted_counts[:10], start=1):
            print(f"  {rank:2d}. {filename}: {count} classes")

        # Detailed report for the image with the most classes
        if max_image:
            analyze_specific_image(max_image, color_tolerance, max_clusters)

Configuration:
  Color tolerance: 15
  Max clusters: 30

Processing 8862 images...
Using color tolerance: 15
Processed 10/8862 images...
Processed 20/8862 images...
Processed 30/8862 images...
Processed 40/8862 images...
Processed 50/8862 images...
Processed 60/8862 images...
Processed 70/8862 images...
Processed 80/8862 images...
Processed 90/8862 images...
Processed 100/8862 images...
Processed 110/8862 images...
Processed 120/8862 images...
Processed 130/8862 images...
Processed 140/8862 images...
Processed 150/8862 images...
Processed 160/8862 images...
Processed 170/8862 images...
Processed 180/8862 images...
Processed 190/8862 images...
Processed 200/8862 images...
Processed 210/8862 images...
Processed 220/8862 images...
Processed 230/8862 images...
Processed 240/8862 images...
Processed 250/8862 images...
Processed 260/8862 images...
Processed 270/8862 images...
Processed 280/8862 images...
Processed 290/8862 images...
Processed 300/8862 images...
Processed 310/8862 images...
P