# ControlNet Quality Metric using SAM Segmentation & Hungarian Matching

This notebook evaluates the quality of ControlNet image generation by:
1. **Extracting segments** from the original COCO segmentation map (colored regions)
2. **Segmenting** the two generated images using SAM (Segment Anything Model):
   - ControlNet + Spatial Conditioning (conditioned on segmentation maps)
   - ControlNet without Spatial Conditioning
3. **Treating all segments as class-agnostic** (same class)
4. **Using Hungarian algorithm** to find optimal segment matching
5. **Computing the mean IoU** over the optimally matched segment pairs (the matching itself maximizes the total IoU)

**Key Insight**: The original image is a segmentation map from the COCO dataset with colored regions. We extract one binary mask per colored region and match these masks against the SAM-generated masks of each of the two generated images, using Hungarian matching to maximize the total IoU.

## 1. Install Required Dependencies

In [None]:
!pip install -q torch torchvision
!pip install -q opencv-python matplotlib numpy scipy
!pip install -q git+https://github.com/facebookresearch/segment-anything.git
!pip install -q pillow

## 2. Import Libraries

In [None]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
from scipy.optimize import linear_sum_assignment
import torch
from pathlib import Path
from typing import List, Tuple, Dict
import warnings
warnings.filterwarnings('ignore')


from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Report the runtime environment so the outputs below can be interpreted
# (SAM runs on CPU when CUDA is unavailable).
print(f"PyTorch version: {torch.__version__}")
_cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {_cuda_ok}")
if _cuda_ok:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 3. Download and Load SAM Model



In [None]:
import urllib.request
import os


# Download URL and local file name for each SAM checkpoint, keyed by the
# model-type string that is also used to index `sam_model_registry` below.
SAM_MODELS = {
    'vit_h': {
        'url': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
        'checkpoint': 'sam_vit_h_4b8939.pth'
    },
    'vit_l': {
        'url': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
        'checkpoint': 'sam_vit_l_0b3195.pth'
    },
    'vit_b': {
        'url': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
        'checkpoint': 'sam_vit_b_01ec64.pth'
    }
}


# Which checkpoint to use; must be one of the keys of SAM_MODELS.
MODEL_TYPE = 'vit_h'
checkpoint_path = SAM_MODELS[MODEL_TYPE]['checkpoint']


if not os.path.exists(checkpoint_path):
    print(f"Downloading SAM {MODEL_TYPE} model...")
    # Download to a temporary name and rename only on success. Writing straight
    # to `checkpoint_path` would leave a truncated file behind if the download is
    # interrupted, and the existence check above would then accept it as complete
    # on the next run.
    partial_path = checkpoint_path + '.part'
    try:
        urllib.request.urlretrieve(SAM_MODELS[MODEL_TYPE]['url'], partial_path)
        os.replace(partial_path, checkpoint_path)
    finally:
        # Only present here if the download or the rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("Download complete!")
else:
    print(f"SAM model already downloaded: {checkpoint_path}")


# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Loading SAM model on {device}...")
sam = sam_model_registry[MODEL_TYPE](checkpoint=checkpoint_path)
sam.to(device=device)
print("SAM model loaded successfully!")

In [None]:
# Input images for the evaluation. Edit these three paths to point at your files.
original_segmentation_map = "/content/input.jpeg"  # COCO segmentation map
controlnet_spatial_image = "/content/controlnetwithspatial.jpeg"
controlnet_no_spatial_image = "/content/withoutspatial.jpeg"

# Report which inputs are present. A missing file is only reported here;
# execution continues.
for path in (original_segmentation_map, controlnet_spatial_image, controlnet_no_spatial_image):
    if not Path(path).exists():
        print(f"✗ NOT FOUND: {path}")
        print("  Please update the path above!")
        continue
    print(f"✓ Found: {path}")

## 4. Helper Functions for Segmentation & Matching

In [None]:
def extract_masks_from_segmentation_map(seg_map_path: str, min_area: int = 100) -> List[np.ndarray]:
    """
    Extract individual masks from a colored segmentation map (e.g., COCO format).
    Each unique color represents a different segment; pure black ([0, 0, 0]) is
    treated as background and skipped.

    Args:
        seg_map_path: Path to segmentation map image
        min_area: A region is kept only if it covers strictly more than this
            many pixels (default 100).

    Returns:
        List of binary uint8 masks (1 inside the region, 0 elsewhere), one per
        kept color, sorted by area with the largest first.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    image = cv2.imread(seg_map_path)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising;
        # without this check the error would surface later inside cvtColor.
        raise FileNotFoundError(f"Could not read segmentation map: {seg_map_path}")
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    h, w = image_rgb.shape[:2]

    # One pass over the pixels yields the distinct colors, a per-pixel index into
    # them and each color's pixel count, so the area filter can be applied before
    # any full-size mask is built.
    unique_colors, inverse, counts = np.unique(
        image_rgb.reshape(-1, 3), axis=0, return_inverse=True, return_counts=True
    )
    label_map = inverse.reshape(h, w)

    # Remove black (background) - assuming background is [0, 0, 0]
    is_background = np.all(unique_colors == 0, axis=1)
    print(f"  Found {int((~is_background).sum())} unique colored regions")

    # NOTE(review): the inputs used in this notebook are JPEG files; lossy
    # compression presumably adds blended colors along region borders, which show
    # up as many tiny "regions" - the min_area filter is what removes them.
    keep = np.flatnonzero(~is_background & (counts > min_area))

    # Largest first; the stable sort keeps equally sized regions in color order.
    keep = keep[np.argsort(-counts[keep], kind="stable")]

    return [(label_map == color_idx).astype(np.uint8) for color_idx in keep]


def segment_image_with_sam(image_path: str, sam_model) -> List[np.ndarray]:
    """
    Segment an image using SAM automatic mask generation.

    Args:
        image_path: Path to image
        sam_model: Loaded SAM model

    Returns:
        List of binary uint8 masks, sorted by area with the largest first.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    # Read image
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with a clear message rather than inside cvtColor.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Generate masks automatically
    mask_generator = SamAutomaticMaskGenerator(
        model=sam_model,
        points_per_side=32,
        pred_iou_thresh=0.86,
        stability_score_thresh=0.92,
        crop_n_layers=1,
        crop_n_points_downscale_factor=2,
        min_mask_region_area=100,  # Minimum pixels for a valid mask
    )

    print(f"  Generating masks with SAM for {Path(image_path).name}...")
    masks = mask_generator.generate(image_rgb)

    # Each SAM result is a dict; only its 'segmentation' array is used here.
    binary_masks = [mask['segmentation'].astype(np.uint8) for mask in masks]

    # Sort by area (largest first)
    binary_masks = sorted(binary_masks, key=lambda m: m.sum(), reverse=True)

    return binary_masks


def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
    """
    Intersection over Union of two binary masks.

    Any non-zero entry counts as "inside" the mask. If both masks are empty
    (the union has no pixels) the result is 0.0.
    """
    overlap = np.logical_and(mask1, mask2).sum()
    combined = np.logical_or(mask1, mask2).sum()

    # Guard against dividing by zero when neither mask has any pixel set.
    return overlap / combined if combined != 0 else 0.0


def build_cost_matrix(masks1: List[np.ndarray], masks2: List[np.ndarray]) -> np.ndarray:
    """
    Pairwise matching costs for the Hungarian algorithm.

    Entry (i, j) is 1 - IoU(masks1[i], masks2[j]), so minimizing the total
    cost is the same as maximizing the total IoU. The result has shape
    (len(masks1), len(masks2)).
    """
    costs = np.zeros((len(masks1), len(masks2)))

    # Visit every (row, column) cell; nothing is visited when either list is empty.
    for row, col in np.ndindex(costs.shape):
        costs[row, col] = 1 - compute_iou(masks1[row], masks2[col])

    return costs


def hungarian_matching(masks1: List[np.ndarray], masks2: List[np.ndarray]) -> Tuple[float, List[Tuple[int, int]], List[float]]:
    """
    One-to-one matching of two mask lists that maximizes the total IoU.

    Returns:
        (mean_iou, matching_pairs, individual_ious), where matching_pairs holds
        (index into masks1, index into masks2) tuples and individual_ious the
        IoU of each pair in the same order. If either list is empty the result
        is (0.0, [], []).

    NOTE(review): the mean is taken over the matched pairs only; masks left
    without a partner (when the lists differ in length) do not lower the score.
    """
    if min(len(masks1), len(masks2)) == 0:
        return 0.0, [], []

    # Minimizing the summed (1 - IoU) costs maximizes the summed IoU.
    rows, cols = linear_sum_assignment(build_cost_matrix(masks1, masks2))

    matching_pairs = []
    ious = []
    for ref_idx, gen_idx in zip(rows, cols):
        matching_pairs.append((ref_idx, gen_idx))
        ious.append(compute_iou(masks1[ref_idx], masks2[gen_idx]))

    if ious:
        mean_iou = np.mean(ious)
    else:
        mean_iou = 0.0

    return mean_iou, matching_pairs, ious


# Confirmation message: printed once this cell has run through to the end.
print("✓ Helper functions defined!")

## 5. Extract Masks from Original Segmentation Map & Segment Generated Images

- **Original**: Extract colored regions (no SAM needed)
- **Generated Images**: Run SAM segmentation

In [None]:

print("EXTRACTING MASKS & SEGMENTING IMAGES")

# Reference masks: read directly from the colored regions of the COCO map (no SAM).
print("\n1. Original Segmentation Map (COCO)", "   Extracting colored regions...", sep="\n")
masks_original = extract_masks_from_segmentation_map(seg_map_path=original_segmentation_map)
print(f"   → Extracted {len(masks_original)} segments")

# Generated image 1: ControlNet with spatial conditioning, segmented by SAM.
print("\n2. ControlNet + Spatial Conditioning")
masks_spatial = segment_image_with_sam(image_path=controlnet_spatial_image, sam_model=sam)
print(f"   → Found {len(masks_spatial)} segments")

# Generated image 2: ControlNet without spatial conditioning, segmented by SAM.
print("\n3. ControlNet without Spatial Conditioning")
masks_no_spatial = segment_image_with_sam(image_path=controlnet_no_spatial_image, sam_model=sam)
print(f"   → Found {len(masks_no_spatial)} segments")


print("Mask extraction & segmentation complete!")


## 6. Visualize Segmentation Results

In [None]:
def visualize_masks(image_path: str, masks: List[np.ndarray], title: str, is_segmap: bool = False,
                    max_segments: int = 20):
    """
    Show an image next to a color-coded rendering of its masks.

    Three panels are drawn: the image itself, the masks painted onto a black
    canvas, and either the unmodified image (when `is_segmap` is True) or a
    50/50 blend of image and mask canvas.

    Args:
        image_path: Path of the image the masks belong to.
        masks: Binary masks; only the first `max_segments` entries are painted.
        title: Heading for the first panel.
        is_segmap: True when the image is itself a segmentation map, in which
            case the third panel shows it as-is instead of a blend.
        max_segments: Maximum number of masks to paint (default 20).

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    shown = masks[:max_segments]

    # Index the palette with integers so every painted mask gets its own entry
    # (cycling if more masks are shown than the palette has colors). Sampling the
    # colormap with linspace(0, 1, len(masks)) and then using only the first few
    # samples makes the painted masks share a handful of near-identical colors
    # whenever there are many masks.
    palette = plt.cm.tab20(np.arange(len(shown)) % plt.cm.tab20.N)

    # Create colored overlay; later masks are painted over earlier ones.
    overlay = np.zeros_like(image_rgb)
    for mask, rgba in zip(shown, palette):
        overlay[mask > 0] = (rgba[:3] * 255).astype(np.uint8)

    # For segmentation map, show it as-is, otherwise blend
    if is_segmap:
        blended = image_rgb  # Show original colored segments
    else:
        blended = cv2.addWeighted(image_rgb, 0.5, overlay, 0.5, 0)

    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    plt.imshow(image_rgb)
    plt.title(f'{title}\n(Original)')
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.imshow(overlay)
    plt.title(f'Extracted Masks\n({len(masks)} total)')
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.imshow(blended)
    plt.title('Segmentation Map' if is_segmap else 'Overlay')
    plt.axis('off')

    plt.tight_layout()
    plt.show()


# Reference map first (shown as-is), then the two generated images with blended overlays.
visualize_masks(
    image_path=original_segmentation_map,
    masks=masks_original,
    title="Original Segmentation Map",
    is_segmap=True,
)
visualize_masks(
    image_path=controlnet_spatial_image,
    masks=masks_spatial,
    title="ControlNet + Spatial",
)
visualize_masks(
    image_path=controlnet_no_spatial_image,
    masks=masks_no_spatial,
    title="ControlNet - No Spatial",
)

## 7. Run Hungarian Matching & Compute IoU

This is the core evaluation step:
1. Match segments from **Original** ↔ **ControlNet+Spatial**
2. Match segments from **Original** ↔ **ControlNet-NoSpatial**
3. Compare IoU scores to determine quality

In [None]:

# Matching 1: reference masks against the spatially conditioned image
print("\n Original vs ControlNet + Spatial Conditioning")

mean_iou_spatial, pairs_spatial, ious_spatial = hungarian_matching(
    masks1=masks_original, masks2=masks_spatial
)
print(f"   Matched pairs: {len(pairs_spatial)}")
print(f"   Mean IoU: {mean_iou_spatial:.4f}")
if ious_spatial:
    print(f"   Max IoU: {max(ious_spatial):.4f}")
    print(f"   Min IoU: {min(ious_spatial):.4f}")
else:
    # Nothing was matched, so there is no max/min to compute.
    print("   Max IoU: 0.0000")
    print("   Min IoU: 0.0000")

# Matching 2: reference masks against the image generated without spatial conditioning
print("\n  Original vs ControlNet - No Spatial Conditioning")
print("-" * 70)
mean_iou_no_spatial, pairs_no_spatial, ious_no_spatial = hungarian_matching(
    masks1=masks_original, masks2=masks_no_spatial
)
print(f"   Matched pairs: {len(pairs_no_spatial)}")
print(f"   Mean IoU: {mean_iou_no_spatial:.4f}")
if ious_no_spatial:
    print(f"   Max IoU: {max(ious_no_spatial):.4f}")
    print(f"   Min IoU: {min(ious_no_spatial):.4f}")
else:
    # Nothing was matched, so there is no max/min to compute.
    print("   Max IoU: 0.0000")
    print("   Min IoU: 0.0000")


## 8. Final Results & Comparison

In [None]:
# Absolute gain in mean matched IoU from adding spatial conditioning.
improvement = mean_iou_spatial - mean_iou_no_spatial

# The relative gain is undefined when the baseline IoU is 0. Report it as
# "n/a" (and store NaN) instead of a misleading "+0.00%".
if mean_iou_no_spatial > 0:
    improvement_percent = improvement / mean_iou_no_spatial * 100
    relative_text = f"{improvement_percent:+.2f}%"
else:
    improvement_percent = float("nan")
    relative_text = "n/a (baseline IoU is 0)"


print(" FINAL QUALITY METRICS")

print("\n IoU Scores:")
print(f"   Original vs ControlNet+Spatial:    {mean_iou_spatial:.4f}")
print(f"   Original vs ControlNet-NoSpatial:  {mean_iou_no_spatial:.4f}")
print("\n Improvement from Spatial Conditioning:")
print(f"   Absolute: {improvement:+.4f}")
print(f"   Relative: {relative_text}")

# The verdict depends only on the sign of the absolute difference.
if improvement > 0:
    print("\nRESULT: Spatial conditioning IMPROVES spatial structure preservation")
    print("   → ControlNet with segmentation maps better preserves object layout!")
elif improvement < 0:
    print("\n RESULT: Spatial conditioning DECREASES spatial structure preservation")
    print("   → Unexpected: spatial conditioning performed worse")
else:
    print("\n RESULT: No difference detected")
    print("   → Both methods preserve spatial structure equally")



## 9. Detailed Visualization of Results

In [None]:

fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Load the three images as RGB (OpenCV reads them as BGR).
img_orig, img_spatial, img_no_spatial = (
    cv2.cvtColor(cv2.imread(_p), cv2.COLOR_BGR2RGB)
    for _p in (original_segmentation_map, controlnet_spatial_image, controlnet_no_spatial_image)
)

# Row 1: the images themselves. Only the spatially conditioned panel gets a
# colored title: green if it scored higher than the unconditioned one, red otherwise.
_top_row = (
    (img_orig, 'Original Segmentation Map\n(COCO Dataset)', {}),
    (img_spatial, f'ControlNet + Spatial\nIoU: {mean_iou_spatial:.4f}',
     {'color': 'green' if improvement > 0 else 'red'}),
    (img_no_spatial, f'ControlNet - No Spatial\nIoU: {mean_iou_no_spatial:.4f}', {}),
)
for _ax, (_img, _heading, _extra) in zip(axes[0], _top_row):
    _ax.imshow(_img)
    _ax.set_title(_heading, fontsize=14, fontweight='bold', **_extra)
    _ax.axis('off')

# Row 2: Mask overlays
def create_overlay(masks, image, blend: bool = True):
    """
    Paint up to the first 20 masks in distinct tab20 colors.

    With `blend=True` the painted canvas is mixed with `image` (40% image,
    60% canvas); with `blend=False` the canvas on a black background is
    returned on its own.
    """
    visible = masks[:20]
    palette = plt.cm.tab20(np.linspace(0, 1, min(len(masks), 20)))

    canvas = np.zeros_like(image)
    # Later masks are painted over earlier ones where they overlap.
    for rgba, region in zip(palette, visible):
        canvas[region > 0] = (rgba[:3] * 255).astype(np.uint8)

    if not blend:
        return canvas
    return cv2.addWeighted(image, 0.4, canvas, 0.6, 0)

# Row 2: mask renderings. The reference masks are drawn on a black canvas;
# the SAM masks are blended with their generated image.
_bottom_row = (
    (create_overlay(masks_original, img_orig, blend=False),
     f'Extracted Masks\n({len(masks_original)} regions)'),
    (create_overlay(masks_spatial, img_spatial),
     f'SAM Segments\n({len(masks_spatial)} total)'),
    (create_overlay(masks_no_spatial, img_no_spatial),
     f'SAM Segments\n({len(masks_no_spatial)} total)'),
)
for _ax, (_rendering, _heading) in zip(axes[1], _bottom_row):
    _ax.imshow(_rendering)
    _ax.set_title(_heading, fontsize=12)
    _ax.axis('off')

plt.suptitle('Segmentation Map vs Generated Images - Quality Comparison', fontsize=16, fontweight='bold', y=0.98)
plt.tight_layout()
plt.show()