In [None]:
# Image Segmentation using a Transformer Model (SegFormer) with Graph Clustering
#
# This notebook demonstrates how to perform semantic image segmentation using a pre-trained
# SegFormer model (a transformer-based segmentation network) and convert the results into a
# clustered graph representation.

# ============================================================================
# 1. Install Required Libraries
# ============================================================================

!pip install transformers torch torchvision pillow matplotlib numpy opencv-python
!pip install datasets networkx scikit-learn plotly community-detection
!pip install scikit-image

# ============================================================================
# 2. Import Libraries
# ============================================================================

import torch
import torch.nn as nn
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import cv2
from datasets import load_dataset
import requests
from io import BytesIO

# Graph libraries
import networkx as nx # Ensure this import is processed before using nx
from sklearn.cluster import SpectralClustering, KMeans
from sklearn.metrics import silhouette_score
from skimage.measure import regionprops, label
from skimage.segmentation import find_boundaries
import plotly.graph_objects as go
import plotly.express as px
from collections import defaultdict, Counter

# Run on the GPU when PyTorch can see CUDA, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# ============================================================================
# 3. Load Pre-trained Model
# ============================================================================

# Hugging Face checkpoint id (the name indicates SegFormer-B0 fine-tuned on ADE20K at 512x512).
model_name = "nvidia/segformer-b0-finetuned-ade-512-512"

# Fetch the matching image processor and segmentation model, place the model on
# the selected device and switch it to evaluation mode for inference.
processor = SegformerImageProcessor.from_pretrained(model_name)
model = SegformerForSemanticSegmentation.from_pretrained(model_name).to(device)
model.eval()

print(f"Model loaded: {model_name}")
print(f"Number of classes: {model.config.num_labels}")

# ============================================================================
# 4. Helper Functions
# ============================================================================

def load_image_from_url(url, timeout=30):
    """Download an image and return it as an RGB PIL image.

    Args:
        url: HTTP(S) address of the image.
        timeout: seconds to wait for the server (passed to ``requests.get``);
            without it a stalled server would block the notebook indefinitely.

    Returns:
        PIL.Image.Image converted to mode 'RGB'.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.Timeout: if the server does not respond within *timeout*.
    """
    response = requests.get(url, timeout=timeout)
    # Fail with a clear HTTP error instead of letting PIL choke on an error page.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert('RGB')

def segment_image(image, processor, model, device):
    """Run semantic segmentation on a single image.

    Args:
        image: PIL image; ``image.size`` is read as (width, height).
        processor: callable returning a dict of tensors for ``images=image``
            (e.g. a Hugging Face SegformerImageProcessor).
        model: segmentation model whose output exposes a 4-D ``.logits`` tensor
            (batch, num_classes, h, w).
        device: torch device the preprocessed inputs are moved to.

    Returns:
        numpy array of shape (height, width) holding the arg-max class id per pixel.
    """
    inputs = processor(images=image, return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # The logits may be at a different resolution than the input image; resize
    # them to the original (height, width) -- PIL's size is (width, height).
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )

    # Select batch element 0 explicitly: .squeeze() would also drop a spatial
    # axis of length 1 and return a 1-D map for 1-pixel-tall/wide images.
    return upsampled_logits.argmax(dim=1)[0].cpu().numpy()

def visualize_segmentation(image, segmentation_map, alpha=0.6):
    """Colour each class of *segmentation_map* and overlay it on *image*.

    Every distinct label gets its own colour sampled evenly from matplotlib's
    ``tab20`` colormap.

    Args:
        image: RGB image (PIL or array) with the same height/width as the map.
        segmentation_map: 2-D array of class ids.
        alpha: weight of the colour overlay in the blend (the photo gets 1 - alpha).

    Returns:
        (blended, colored_seg): the alpha-blended overlay and the pure colour
        map, both uint8 arrays of shape (height, width, 3).
    """
    class_ids = np.unique(segmentation_map)
    palette = plt.cm.tab20(np.linspace(0, 1, len(class_ids)))

    colored_seg = np.zeros(segmentation_map.shape + (3,), dtype=np.uint8)
    for class_id, rgba in zip(class_ids, palette):
        colored_seg[segmentation_map == class_id] = (rgba[:3] * 255).astype(np.uint8)

    blended = cv2.addWeighted(np.array(image), 1 - alpha, colored_seg, alpha, 0)
    return blended, colored_seg

# ============================================================================
# 5. Graph Creation Functions
# ============================================================================

def extract_segment_features(image, segmentation_map):
    """Compute simple descriptors for every label in *segmentation_map*.

    Args:
        image: image aligned pixel-for-pixel with *segmentation_map* (anything
            ``np.array`` accepts, e.g. a PIL image).
        segmentation_map: 2-D array of segment labels.

    Returns:
        dict mapping segment id -> {
            'area': number of pixels in the segment,
            'mean_color': mean pixel value over the segment (per channel),
            'centroid': (x, y) mean pixel coordinate,
            'bbox': (min_x, min_y, max_x, max_y),
            'compactness': perimeter**2 / area,
            'perimeter': number of outer-boundary pixels (skimage 'outer' mode),
        }
    """
    image_np = np.array(image)
    segment_features = {}

    for segment_id in np.unique(segmentation_map):
        mask = segmentation_map == segment_id

        area = np.sum(mask)
        mean_color = image_np[mask].mean(axis=0)

        y_coords, x_coords = np.where(mask)
        centroid_x, centroid_y = np.mean(x_coords), np.mean(y_coords)
        bbox_min_x, bbox_max_x = np.min(x_coords), np.max(x_coords)
        bbox_min_y, bbox_max_y = np.min(y_coords), np.max(y_coords)

        # find_boundaries returns a boolean image with the same shape as the
        # mask, so the perimeter is its number of True pixels -- len() of that
        # array would only be the image height.
        perimeter = int(np.count_nonzero(find_boundaries(mask, mode='outer')))
        compactness = (perimeter ** 2) / area if area > 0 else 0

        segment_features[segment_id] = {
            'area': area,
            'mean_color': mean_color,
            'centroid': (centroid_x, centroid_y),
            'bbox': (bbox_min_x, bbox_min_y, bbox_max_x, bbox_max_y),
            'compactness': compactness,
            'perimeter': perimeter,
        }

    return segment_features

def create_adjacency_graph(segmentation_map, connectivity_type='boundary'):
    """Build an undirected graph with one node per segment label.

    Args:
        segmentation_map: 2-D array of integer segment labels.
        connectivity_type:
            'boundary' -- connect two labels when a pixel of one lies in the
                8-neighbourhood of a pixel of the other (no edge attributes).
            'spatial'  -- connect two labels whose centroids are closer than
                20% of the shorter image side; edge ``weight`` = 1 / distance
                (distance clamped to at least one pixel).

    Returns:
        networkx.Graph whose nodes are the unique labels of the map.

    Raises:
        ValueError: if *connectivity_type* is not 'boundary' or 'spatial'.
    """
    labels = np.asarray(segmentation_map)
    G = nx.Graph()

    unique_segments = np.unique(labels)
    G.add_nodes_from(unique_segments)

    if connectivity_type == 'boundary':
        # Compare every pixel with its right, down, down-right and down-left
        # neighbour. Edges are undirected, so these four offsets cover the full
        # 8-neighbourhood; array slicing replaces a per-pixel Python loop.
        neighbour_views = (
            (labels[:, :-1], labels[:, 1:]),
            (labels[:-1, :], labels[1:, :]),
            (labels[:-1, :-1], labels[1:, 1:]),
            (labels[:-1, 1:], labels[1:, :-1]),
        )
        touching = []
        for a, b in neighbour_views:
            differs = a != b
            if differs.any():
                touching.append(np.column_stack((a[differs], b[differs])))
        if touching:
            # Sort each pair so (a, b) and (b, a) collapse, then de-duplicate.
            edges = np.unique(np.sort(np.concatenate(touching), axis=1), axis=0)
            G.add_edges_from(map(tuple, edges))

    elif connectivity_type == 'spatial':
        # Centroids (x, y) straight from the label map; no image is needed.
        centroids = {}
        for segment_id in unique_segments:
            ys, xs = np.nonzero(labels == segment_id)
            centroids[segment_id] = (xs.mean(), ys.mean())

        threshold_distance = min(labels.shape) * 0.2  # 20% of image dimension

        for i, seg1 in enumerate(unique_segments):
            x1, y1 = centroids[seg1]
            for seg2 in unique_segments[i + 1:]:
                x2, y2 = centroids[seg2]
                distance = float(np.hypot(x1 - x2, y1 - y2))
                if distance < threshold_distance:
                    # Clamp to one pixel: coincident centroids (e.g. a ring
                    # around another segment) would otherwise divide by zero.
                    G.add_edge(seg1, seg2, weight=1.0 / max(distance, 1.0))

    else:
        raise ValueError(
            f"Unknown connectivity_type: {connectivity_type!r} "
            "(expected 'boundary' or 'spatial')"
        )

    return G

def perform_graph_clustering(G, method='louvain', n_clusters=None):
    """Group the nodes of *G* into clusters.

    Args:
        G: networkx graph whose nodes are segment ids.
        method:
            'louvain'  -- modularity-based community detection. Uses the
                python-louvain package when importable, otherwise networkx's
                own ``louvain_communities`` if available, otherwise falls back
                to 'spectral'.
            'spectral' -- sklearn SpectralClustering on the adjacency matrix.
            'kmeans'   -- sklearn KMeans on the node ids (placeholder feature).
        n_clusters: target number of clusters for 'spectral'/'kmeans'. Defaults
            to 5 and is clamped to the number of nodes.

    Returns:
        dict mapping cluster id -> list of nodes.

    Raises:
        ValueError: if *method* is not one of the names above.
    """
    if method == 'louvain':
        partition = None
        try:
            # python-louvain installs its module under the name ``community``.
            import community as community_louvain
        except ImportError:
            community_louvain = None
        # getattr guards against an unrelated package also called ``community``.
        best_partition = getattr(community_louvain, 'best_partition', None)
        if best_partition is not None:
            partition = best_partition(G)
        else:
            louvain_communities = getattr(
                getattr(nx, 'community', None), 'louvain_communities', None
            )
            if louvain_communities is not None:
                partition = {
                    node: idx
                    for idx, members in enumerate(louvain_communities(G, seed=42))
                    for node in members
                }

        if partition is not None:
            clusters = defaultdict(list)
            for node, cluster_id in partition.items():
                clusters[cluster_id].append(node)
            return dict(clusters)

        print("Community detection library not available. Using spectral clustering.")
        method = 'spectral'

    if method == 'spectral':
        nodes = list(G.nodes)
        if len(nodes) < 2:
            return {0: nodes}

        # SpectralClustering cannot return more clusters than there are nodes.
        n_clusters = max(1, min(5 if n_clusters is None else n_clusters, len(nodes)))

        adjacency_matrix = nx.adjacency_matrix(G, nodelist=nodes).toarray()

        if adjacency_matrix.sum() == 0:  # No edges
            # Assign each node to its own cluster
            return {i: [node] for i, node in enumerate(nodes)}

        clustering = SpectralClustering(
            n_clusters=n_clusters,
            affinity='precomputed',
            random_state=42,
        )
        cluster_labels = clustering.fit_predict(adjacency_matrix)

        clusters = defaultdict(list)
        for node, cluster_id in zip(nodes, cluster_labels):
            clusters[cluster_id].append(node)
        return dict(clusters)

    if method == 'kmeans':
        nodes = list(G.nodes)
        if len(nodes) < 2:
            return {0: nodes}

        n_clusters = max(1, min(5 if n_clusters is None else n_clusters, len(nodes)))

        # NOTE: the only feature is the node id itself (the class index), so this
        # groups numerically close ids -- a placeholder, not a visual feature.
        node_features = [[node] for node in nodes]

        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        cluster_labels = kmeans.fit_predict(node_features)

        clusters = defaultdict(list)
        for node, cluster_id in zip(nodes, cluster_labels):
            clusters[cluster_id].append(node)
        return dict(clusters)

    raise ValueError(
        f"Unknown clustering method: {method!r} (expected 'louvain', 'spectral' or 'kmeans')"
    )

def create_cluster_visualization(segmentation_map, clusters, original_image):
    """Map every pixel to its segment's cluster and colour the result.

    Args:
        segmentation_map: 2-D array of segment ids.
        clusters: dict cluster id -> list of segment ids.
        original_image: unused; kept so existing calls keep working.

    Returns:
        (cluster_map, colored_cluster): an array like *segmentation_map* holding
        the cluster id per pixel (0 where a segment belongs to no cluster), and
        a uint8 (height, width, 3) image coloured with matplotlib's Set3 map.
    """
    cluster_map = np.zeros_like(segmentation_map)
    colored_cluster = np.zeros((*segmentation_map.shape, 3), dtype=np.uint8)

    colors = plt.cm.Set3(np.linspace(0, 1, len(clusters)))

    # Colour by rank among the sorted cluster ids: identical to indexing by id
    # when ids are 0..k-1, but non-contiguous ids (e.g. labels KMeans left
    # empty) neither raise IndexError nor stay uncoloured.
    for color_idx, cluster_id in enumerate(sorted(clusters)):
        mask = np.isin(segmentation_map, list(clusters[cluster_id]))
        cluster_map[mask] = cluster_id
        colored_cluster[mask] = (colors[color_idx][:3] * 255).astype(np.uint8)

    return cluster_map, colored_cluster

# ============================================================================
# 6. Interactive Graph Visualization
# ============================================================================

def create_interactive_graph(G, clusters, segment_features, ade20k_labels):
    """Build an interactive Plotly figure of the segment graph.

    Nodes are placed with a spring layout, coloured by cluster (Plotly's
    qualitative Set3 palette) and sized by segment area; hovering a node shows
    its class, cluster, area, centroid and mean colour.

    Args:
        G: networkx graph whose nodes are segment ids.
        clusters: dict cluster id -> list of segment ids.
        segment_features: dict segment id -> feature dict with 'area',
            'centroid' and 'mean_color' (as built by extract_segment_features).
        ade20k_labels: dict class id -> class name.

    Returns:
        plotly.graph_objects.Figure.
    """
    pos = nx.spring_layout(G, k=3, iterations=50)

    cluster_colors = px.colors.qualitative.Set3
    node_to_cluster = {
        segment: cluster_id
        for cluster_id, segments in clusters.items()
        for segment in segments
    }

    node_x, node_y = [], []
    node_text, node_colors, node_sizes = [], [], []

    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)

        features = segment_features.get(node, {})
        area = features.get('area', 0)
        centroid = features.get('centroid', (0, 0))
        mean_color = features.get('mean_color', [0, 0, 0])
        class_label = ade20k_labels.get(node, f"Class {node}")
        cluster_id = node_to_cluster.get(node, 0)

        node_text.append(
            f"Segment: {node}<br>"
            f"Class: {class_label}<br>"
            f"Cluster: {cluster_id}<br>"
            f"Area: {area} pixels<br>"
            f"Centroid: ({centroid[0]:.1f}, {centroid[1]:.1f})<br>"
            f"Mean Color: RGB({mean_color[0]:.0f}, {mean_color[1]:.0f}, {mean_color[2]:.0f})"
        )
        node_colors.append(cluster_colors[cluster_id % len(cluster_colors)])
        node_sizes.append(max(10, min(50, area / 100)))  # Scale size by area

    # One line segment per edge; the None entries break the polyline between edges.
    edge_x, edge_y = [], []
    for u, v in G.edges():
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=1, color='rgba(125,125,125,0.5)'),
        hoverinfo='none',
        mode='lines',
    )

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        hoverinfo='text',
        text=[str(node) for node in G.nodes()],
        textposition="middle center",
        hovertext=node_text,
        marker=dict(
            size=node_sizes,
            color=node_colors,
            line=dict(width=2, color='black'),
            opacity=0.8,
        ),
    )

    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            # title.font replaces the deprecated ``titlefont`` property, which
            # Plotly 6 removed (passing it raises ValueError there).
            title=dict(text='Segment Adjacency Graph with Clustering', font=dict(size=16)),
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[dict(
                text="Node size represents segment area. Colors represent clusters.",
                showarrow=False,
                xref="paper", yref="paper",
                x=0.005, y=-0.002,
                xanchor='left', yanchor='bottom',
                font=dict(size=12),
            )],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        ),
    )

    return fig

# ============================================================================
# 7. Main Processing Pipeline
# ============================================================================

# ADE20K class labels. The hand-written subset covers ids 0-30; the loaded
# checkpoint's own id2label table (when present) is merged over it so segments
# with higher ids get a real name instead of "Class N".
ade20k_labels = {
    0: 'wall', 1: 'building', 2: 'sky', 3: 'floor', 4: 'tree', 5: 'ceiling', 6: 'road',
    7: 'bed', 8: 'windowpane', 9: 'grass', 10: 'cabinet', 11: 'sidewalk', 12: 'person',
    13: 'earth', 14: 'door', 15: 'table', 16: 'mountain', 17: 'plant', 18: 'curtain',
    19: 'chair', 20: 'car', 21: 'water', 22: 'painting', 23: 'sofa', 24: 'shelf',
    25: 'house', 26: 'sea', 27: 'mirror', 28: 'rug', 29: 'field', 30: 'armchair'
}
ade20k_labels.update(
    {int(k): v for k, v in (getattr(model.config, "id2label", None) or {}).items()}
)

# Sample image URLs
sample_urls = [
    "https://images.unsplash.com/photo-1506905925346-21bda4d32df4?w=512",  # Landscape
    "https://images.unsplash.com/photo-1449824913935-59a10b8d2000?w=512",  # City
    "https://images.unsplash.com/photo-1551782450-a2132b4ba21d?w=512",   # Food
]

# Load and process image
image_url = sample_urls[0]  # Change index to try different images
image = load_image_from_url(image_url)

print(f"Original image size: {image.size}")

# Perform segmentation
segmentation_map = segment_image(image, processor, model, device)

print(f"Segmentation map shape: {segmentation_map.shape}")
print(f"Number of unique segments: {len(np.unique(segmentation_map))}")

# Extract segment features
segment_features = extract_segment_features(image, segmentation_map)

# Create adjacency graph
print("Creating adjacency graph...")
G = create_adjacency_graph(segmentation_map, connectivity_type='boundary')

print(f"Graph created with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges")

# Perform clustering. Spectral clustering cannot produce more clusters than
# there are nodes, and a simple image may contain fewer than 5 classes.
print("Performing graph clustering...")
clusters = perform_graph_clustering(G, method='spectral', n_clusters=min(5, G.number_of_nodes()))

print(f"Found {len(clusters)} clusters:")
for cluster_id, segments in clusters.items():
    class_names = [ade20k_labels.get(seg, f"Class {seg}") for seg in segments]
    # dict.fromkeys de-duplicates while keeping first-seen order (set order varies between runs).
    print(f"  Cluster {cluster_id}: {len(segments)} segments - {', '.join(dict.fromkeys(class_names))}")

# ============================================================================
# 8. Visualization
# ============================================================================

# Build the overlay and cluster images shown below.
blended_image, colored_seg = visualize_segmentation(image, segmentation_map)
cluster_map, colored_cluster = create_cluster_visualization(segmentation_map, clusters, image)

# 2x3 grid: five image panels plus a static drawing of the adjacency graph.
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

image_panels = (
    ((0, 0), image, 'Original Image', None),
    ((0, 1), segmentation_map, 'Segmentation Map', 'tab20'),
    ((0, 2), blended_image, 'Blended Segmentation', None),
    ((1, 0), cluster_map, 'Cluster Map', 'Set3'),
    ((1, 1), colored_cluster, 'Colored Clusters', None),
)
for cell, picture, title, cmap in image_panels:
    panel = axes[cell]
    panel.imshow(picture, cmap=cmap)
    panel.set_title(title, fontsize=14)
    panel.axis('off')

# Static graph panel: nodes coloured by cluster id and sized by segment area.
pos = nx.spring_layout(G, k=2, iterations=50)
node_to_cluster = {
    segment: cluster_id
    for cluster_id, segments in clusters.items()
    for segment in segments
}
node_colors = [node_to_cluster.get(node, 0) for node in G.nodes()]

graph_panel = axes[1, 2]
graph_panel.set_title('Segment Adjacency Graph', fontsize=14)
nx.draw(
    G, pos, ax=graph_panel,
    node_color=node_colors,
    node_size=[segment_features.get(node, {}).get('area', 100) / 10 for node in G.nodes()],
    cmap='Set3',
    with_labels=True,
    font_size=8,
    edge_color='gray',
    alpha=0.8,
)

plt.tight_layout()
plt.show()

# Interactive (Plotly) rendering of the same graph.
print("Creating interactive graph visualization...")
interactive_fig = create_interactive_graph(G, clusters, segment_features, ade20k_labels)
interactive_fig.show()

# ============================================================================
# 9. Cluster Analysis
# ============================================================================

def analyze_clusters(clusters, segment_features, ade20k_labels):
    """Print a human-readable report for every cluster.

    For each cluster it prints the segment count, the class names of its
    segments with counts, the summed area, the mean compactness and the mean
    of the segments' mean colours. Missing feature entries count as area 0,
    compactness 0 and colour (0, 0, 0); unknown ids print as "Class <id>".
    """
    rule = "=" * 50
    print("\n" + rule)
    print("CLUSTER ANALYSIS")
    print(rule)

    for cluster_id, members in clusters.items():
        print(f"\nCluster {cluster_id}:")
        print(f"  Number of segments: {len(members)}")

        member_features = [segment_features.get(seg, {}) for seg in members]

        class_counts = Counter(ade20k_labels.get(seg, f"Class {seg}") for seg in members)
        print(f"  Classes: {dict(class_counts)}")

        total_area = sum(feat.get('area', 0) for feat in member_features)
        print(f"  Total area: {total_area} pixels")

        compactness = [feat.get('compactness', 0) for feat in member_features]
        mean_compactness = np.mean(compactness) if compactness else 0
        print(f"  Average compactness: {mean_compactness:.4f}")

        member_colors = [feat.get('mean_color', [0, 0, 0]) for feat in member_features]
        if not member_colors:
            continue
        mean_rgb = np.mean(member_colors, axis=0)
        print(f"  Average color: RGB({mean_rgb[0]:.0f}, {mean_rgb[1]:.0f}, {mean_rgb[2]:.0f})")

# Print a per-cluster summary (segment count, class names, total area, mean
# compactness, mean colour) for the clustering computed above.
analyze_clusters(clusters, segment_features, ade20k_labels)

# ============================================================================
# 10. Save Results
# ============================================================================

def save_graph_results(G, clusters, segment_features, filename_prefix="graph_result",
                       class_labels=None):
    """Write the graph and a plain-text cluster report to disk.

    Creates ``{filename_prefix}_graph.graphml`` (networkx GraphML) and
    ``{filename_prefix}_clusters.txt``.

    Args:
        G: networkx graph to serialise.
        clusters: dict cluster id -> list of segment ids.
        segment_features: dict segment id -> feature dict (only 'area' is read).
        filename_prefix: path prefix for both output files.
        class_labels: optional dict class id -> name for the report. When
            omitted, the module-level ``ade20k_labels`` dict is used if it
            exists (the original implicit behaviour); otherwise ids are shown
            as "Class <id>".
    """
    if class_labels is None:
        # Keep the notebook behaviour without raising NameError elsewhere.
        class_labels = globals().get("ade20k_labels", {})

    nx.write_graphml(G, f"{filename_prefix}_graph.graphml")

    with open(f"{filename_prefix}_clusters.txt", "w", encoding="utf-8") as f:
        f.write("Cluster Analysis Results\n")
        f.write("=" * 30 + "\n\n")

        for cluster_id, segments in clusters.items():
            # Unwrap NumPy scalars so ids are written as plain numbers
            # (NumPy >= 2 reprs them as e.g. np.int64(3) inside a list).
            segment_ids = [seg.item() if isinstance(seg, np.generic) else seg for seg in segments]
            total_area = sum(segment_features.get(seg, {}).get('area', 0) for seg in segments)
            # Sorted so the report is identical between runs (set order is not).
            class_names = sorted({class_labels.get(seg, f"Class {seg}") for seg in segments})

            f.write(f"Cluster {cluster_id}:\n")
            f.write(f"  Segments: {segment_ids}\n")
            f.write(f"  Size: {len(segments)}\n")
            f.write(f"  Total area: {total_area} pixels\n")
            f.write(f"  Classes: {class_names}\n\n")

    print(f"Graph results saved with prefix: {filename_prefix}")

# Persist the graph (GraphML) and a text summary of the clusters.
save_graph_results(G, clusters, segment_features, "segmentation_graph")

# Summarise the structure of the adjacency graph.
print("\nGraph Statistics:")
summary_lines = [
    f"  Nodes: {G.number_of_nodes()}",
    f"  Edges: {G.number_of_edges()}",
    f"  Density: {nx.density(G):.4f}",
    f"  Connected components: {nx.number_connected_components(G)}",
    f"  Clusters found: {len(clusters)}",
]
for line in summary_lines:
    print(line)

print("\nCompleted! Graph analysis of segmented image finished.")