In [1]:
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset

class CocoPerClass(torch.utils.data.Dataset):
    """Per-(image, category) view over a COCO detection dataset.

    Each item pairs one image with one of the categories annotated in it and
    carries the union mask of all annotations of that category, the image
    (untouched plus DINO-/CLIP-transformed copies) and the image captions.

    Args:
        coco_det: Detection dataset exposing ``ids``, ``coco`` and
            ``__getitem__`` returning ``(image, annotations)``. The image is
            expected to offer ``height``, ``width`` and ``copy()``.
        coco_caps: Caption dataset exposing ``ids`` and either a ``coco``
            annotation index or ``__getitem__`` returning ``(image, captions)``.
        dino_transform: Optional callable applied to a copy of the image.
        clip_transform: Optional callable applied to a copy of the image.
        binarize: If True the mask holds 1 for object pixels, otherwise the
            category id.
    """

    def __init__(self, coco_det, coco_caps, dino_transform=None, clip_transform=None, binarize=True):
        self.coco_det = coco_det
        self.coco_caps = coco_caps
        self.coco = coco_det.coco
        self.dino_transform = dino_transform
        self.clip_transform = clip_transform
        self.binarize = binarize

        # Build flat index: [(img_idx, cat_id), ...]
        self._flat_index = []
        for img_idx, img_id in enumerate(coco_det.ids):
            cat_ids = {ann["category_id"] for ann in self.coco.imgToAnns[img_id]}
            self._flat_index.extend((img_idx, cid) for cid in cat_ids)

        # ID to name mapping
        self.cat_id2name = {
            c["id"]: c["name"] for c in self.coco.loadCats(self.coco.getCatIds())
        }

        # Build image_id to captions mapping
        self.img_id_to_captions = self._load_captions(coco_caps)

    @staticmethod
    def _load_captions(coco_caps):
        """Map every image id of ``coco_caps`` to its list of caption strings.

        Indexing the caption dataset decodes the image file for every entry
        just to throw it away. When the dataset exposes its annotation index
        (``coco``) and has no ``transforms`` that could alter the captions,
        the strings are read straight from that index instead.
        """
        caps_api = getattr(coco_caps, "coco", None)
        read_from_index = (
            caps_api is not None
            and getattr(coco_caps, "transforms", None) is None
        )

        captions_by_image = {}
        for i, img_id in enumerate(coco_caps.ids):
            if read_from_index:
                captions = [ann["caption"] for ann in caps_api.imgToAnns.get(img_id, [])]
            else:
                # Fallback: let the dataset produce (possibly transformed) captions.
                _, captions = coco_caps[i]
            captions_by_image[img_id] = captions
        return captions_by_image

    def __len__(self):
        """Number of (image, category) pairs."""
        return len(self._flat_index)

    def __getitem__(self, idx):
        """Return the sample dictionary for flat index ``idx``.

        Keys: ``dino_image``, ``clip_image``, ``image``, ``category_id``,
        ``mask`` (uint8 tensor of the original image size), ``category_name``,
        ``image_id``, ``captions`` and ``idx``.
        """
        img_idx, cat_id = self._flat_index[idx]
        img, anns = self.coco_det[img_idx]
        img_id = self.coco_det.ids[img_idx]

        # Keep original image
        original_img = img.copy()

        # Build per-class mask: union of every annotation of this category
        H, W = img.height, img.width
        mask = np.zeros((H, W), dtype=np.uint8)

        for ann in anns:
            if ann["category_id"] == cat_id:
                ann_mask = self.coco.annToMask(ann).astype(bool)
                mask[ann_mask] = 1 if self.binarize else cat_id

        # Convert mask to tensor (don't transform it)
        mask = torch.as_tensor(mask, dtype=torch.uint8)

        # Apply transforms to images only
        dino_img = img.copy()
        clip_img = img.copy()

        if self.dino_transform is not None:
            dino_img = self.dino_transform(dino_img)

        if self.clip_transform is not None:
            clip_img = self.clip_transform(clip_img)

        # Get captions for this image (empty list when the image has none)
        captions = self.img_id_to_captions.get(img_id, [])

        return {
            'dino_image': dino_img,
            'clip_image': clip_img,
            'image': original_img,
            'category_id': cat_id,
            'mask': mask,
            'category_name': self.cat_id2name[cat_id],
            'image_id': img_id,
            'captions': captions,
            'idx': idx
        }

def collate_fn(batch):
    """Merge a list of ``CocoPerClass`` samples into one batch dictionary.

    The DINO and CLIP images are stacked into tensors. Every other field is
    returned as a plain list with one entry per sample, in batch order.
    """
    def column(key):
        # Gather one field across all samples of the batch.
        return [sample[key] for sample in batch]

    return {
        'dino_images': torch.stack(column('dino_image')),
        'clip_images': torch.stack(column('clip_image')),
        'original_images': column('image'),
        'masks': column('mask'),
        'category_ids': column('category_id'),
        'category_names': column('category_name'),
        'image_ids': column('image_id'),
        'captions_text': column('captions'),
        'indices': column('idx'),
    }

In [2]:
import os, gc, h5py, torch, torchvision.transforms as T
from tqdm.auto import tqdm
import open_clip                                            
from pathlib import Path
# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Backbone identifiers passed to torch.hub / open_clip below.
DINO_MODEL = 'dinov2_vitl14_reg'
CLIP_MODEL = 'hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K'
# Short tags for the two backbones.
DINO_TAG = 'dinov2_vitl14_reg'
CLIP_TAG = 'CLIP-ViT-L-14-DataComp'
BATCH_SIZE = 32

# Per-channel normalisation statistics used by the DINO transform.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

class MaybeToTensor(T.ToTensor):
    """``ToTensor`` variant that leaves inputs that are already tensors untouched."""

    def __call__(self, pic):
        # Tensors pass straight through; anything else takes the regular conversion.
        if isinstance(pic, torch.Tensor):
            return pic
        return super().__call__(pic)

# Preprocessing for DINO: resize straight to 224x224 (no crop), convert to a
# tensor and normalise with the ImageNet statistics defined above.
normalize = T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
dino_transform = T.Compose([
    T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC), # original dino does it to 256 but it messes up the aspect ratio
#     T.CenterCrop(224),
    MaybeToTensor(),
    normalize,
])

# Load CLIP together with its stock preprocessing pipeline.
model_clip, _, preprocess_clip = open_clip.create_model_and_transforms(CLIP_MODEL)

# Replace the first two stock CLIP transforms with a direct 224x224 resize so
# that both backbones see the same image geometry.
# NOTE(review): presumably the two dropped steps are open_clip's Resize and
# CenterCrop - confirm the transform order for this checkpoint.
transform_list = preprocess_clip.transforms
clip_transform = T.Compose([T.Resize(size=(224, 224), interpolation=T.InterpolationMode.BICUBIC, antialias=True)] 
                            + transform_list[2:])

clip_tokenizer = open_clip.get_tokenizer(CLIP_MODEL)

# Both backbones are used for inference only: move to the device, eval mode.
model_clip = model_clip.to(device).eval()
model_dino = torch.hub.load('facebookresearch/dinov2', DINO_MODEL).to(device).eval()

Using cache found in /home/a_n29343/.cache/torch/hub/facebookresearch_dinov2_main


In [3]:
from sparc.model.model_global import MultiStreamSparseAutoencoder
import json
# Infer modality dimensions (needed to create MS-SAE skeleton)
d_streams = {
    'dino'     : model_dino.num_features,             # DINOv2 CLS/pool dim
    'clip_img' : model_clip.visual.output_dim,
    # NOTE(review): the text stream reuses the visual output dimension; this
    # presumes CLIP's text and image embeddings have the same size - confirm.
    'clip_txt' : model_clip.visual.output_dim,
}

print('Loading MS-SAE checkpoint...')

# Checkpoint and run configuration are read from the same results directory.
SAE_CHECKPOINT = Path('../../final_results/msae_open_global_with_cross/msae_checkpoint.pth')
with open('../../final_results/msae_open_global_with_cross/run_config.json', 'r') as f:
    config = json.load(f)
msae = MultiStreamSparseAutoencoder(
    d_streams=d_streams,
    n_latents=config['args']['n_latents'],               # MUST match your training args
    k=config['args']['k'],
).to(device)
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary
# objects; only load checkpoints from a trusted source. If the file holds a
# plain state dict, weights_only=True should be sufficient - TODO confirm.
msae.load_state_dict(torch.load(SAE_CHECKPOINT, map_location=device, weights_only=False))
# Inference mode; as the last expression this also displays the module summary.
msae.eval()

Loading MS-SAE checkpoint...


MultiStreamSparseAutoencoder(
  (encoders): ModuleDict(
    (dino): Linear(in_features=1024, out_features=8192, bias=False)
    (clip_img): Linear(in_features=768, out_features=8192, bias=False)
    (clip_txt): Linear(in_features=768, out_features=8192, bias=False)
  )
  (decoders): ModuleDict(
    (dino): Linear(in_features=8192, out_features=1024, bias=False)
    (clip_img): Linear(in_features=8192, out_features=768, bias=False)
    (clip_txt): Linear(in_features=8192, out_features=768, bias=False)
  )
  (pre_biases): ParameterDict(
      (dino): Parameter containing: [torch.cuda.FloatTensor of size 1024 (cuda:0)]
      (clip_img): Parameter containing: [torch.cuda.FloatTensor of size 768 (cuda:0)]
      (clip_txt): Parameter containing: [torch.cuda.FloatTensor of size 768 (cuda:0)]
  )
  (latent_biases): ParameterDict(
      (dino): Parameter containing: [torch.cuda.FloatTensor of size 8192 (cuda:0)]
      (clip_img): Parameter containing: [torch.cuda.FloatTensor of size 8192 (cuda:

### Get the latents

In [4]:
def get_coco_latents(
    coco_dataset,
    model_dino,
    model_clip,
    msae,
    clip_tokenizer,
    batch_size: int = 32,
    num_workers: int = 4
):
    """Encode every dataset sample and collect its MS-SAE sparse codes.

    Args:
        coco_dataset: Dataset whose samples can be merged by ``collate_fn``.
        model_dino: Image encoder called as ``model_dino(images)``.
        model_clip: Model exposing ``encode_image`` and ``encode_text``.
        msae: Callable taking ``{'dino', 'clip_img', 'clip_txt'}`` features and
            returning a mapping that may contain ``sparse_codes_<stream>``.
        clip_tokenizer: Callable turning a list of strings into a token tensor.
        batch_size: Number of samples per forward pass.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A list with one dict per sample, in dataset order, holding
        ``image_id``, ``captions``, ``category_id``, ``category_name`` and
        ``latents`` (stream name -> numpy array of that sample's sparse code).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def custom_collate(batch):
        collated_batch = collate_fn(batch)

        # Only the first caption of each image is encoded; images without
        # captions fall back to the empty string.
        first_captions = [
            cap_list[0] if cap_list else ""
            for cap_list in collated_batch['captions_text']
        ]

        if first_captions:
            collated_batch['tokenized_captions'] = clip_tokenizer(first_captions)

        return collated_batch

    # Referenced through torch.utils.data so the function does not depend on a
    # bare `DataLoader` name that is only imported further down the notebook.
    val_loader = torch.utils.data.DataLoader(
        coco_dataset,
        batch_size=batch_size,
        shuffle=False,  # keeps results aligned with dataset indices
        num_workers=num_workers,
        collate_fn=custom_collate,
        pin_memory=True
    )

    all_results = []

    with torch.no_grad():
        for batch_data in tqdm(val_loader, desc="Extracting latents"):
            dino_images = batch_data['dino_images'].to(device)
            clip_images = batch_data['clip_images'].to(device)
            tokenized_captions = batch_data['tokenized_captions'].to(device)
            image_ids = batch_data['image_ids']
            captions_text_batch = batch_data['captions_text']

            # Some encoders return a tuple; the first element is used as the features.
            dino_features = model_dino(dino_images)
            if isinstance(dino_features, tuple):
                dino_features = dino_features[0]

            clip_img_features = model_clip.encode_image(clip_images)
            if isinstance(clip_img_features, tuple):
                clip_img_features = clip_img_features[0]

            clip_txt_features = model_clip.encode_text(tokenized_captions)

            msae_inputs = {
                'dino': dino_features,
                'clip_img': clip_img_features,
                'clip_txt': clip_txt_features
            }

            msae_outputs = msae(msae_inputs)

            # Collect the sparse code of every stream that was fed in. The
            # streams are taken from the inputs built above rather than from
            # a notebook-level global.
            batch_latents_data = {}
            for stream_name in msae_inputs:
                sparse_code_key = f'sparse_codes_{stream_name}'
                if sparse_code_key in msae_outputs:
                    batch_latents_data[stream_name] = msae_outputs[sparse_code_key].cpu().numpy()

            for i in range(len(image_ids)):
                sample_latents = {
                    stream_name: stream_batch[i]
                    for stream_name, stream_batch in batch_latents_data.items()
                }

                all_results.append({
                    'image_id': image_ids[i],
                    'captions': captions_text_batch[i],
                    'category_id': batch_data['category_ids'][i],
                    'category_name': batch_data['category_names'][i],
                    'latents': sample_latents
                })

    return all_results

In [6]:
from torchvision.datasets import CocoDetection, CocoCaptions
from torchvision import transforms

# COCO val2017 images plus the instance (detection/segmentation) and caption
# annotation files, given relative to the notebook location.
img_dir  = "../../dataset/COCO/val2017"
det_ann_file = "../../dataset/COCO/annotations/instances_val2017.json"
cap_ann_file = "../../dataset/COCO/annotations/captions_val2017.json"

# Both torchvision datasets read the same image folder; they differ only in
# the annotation file they index.
coco_det = CocoDetection(root=img_dir, annFile=det_ann_file)
coco_caps = CocoCaptions(root=img_dir, annFile=cap_ann_file)
# One sample per (image, category) pair, with DINO and CLIP preprocessing.
dataset = CocoPerClass(coco_det, coco_caps, dino_transform, clip_transform)

loading annotations into memory...
Done (t=0.60s)
creating index...
index created!
loading annotations into memory...
Done (t=0.03s)
creating index...
index created!


In [7]:
import pickle

In [8]:
latent_file_path = '../../final_results/segmentation/coco_latents_global.pickle'
# Create the directory that actually receives the cache file (this also
# creates its parent, final_results/). Without it the write below fails
# after the whole extraction has already run.
os.makedirs(os.path.dirname(latent_file_path), exist_ok=True)

if os.path.isfile(latent_file_path):
    # NOTE(security): pickle.load can execute arbitrary code; only load cache
    # files produced by this notebook.
    with open(latent_file_path, 'rb') as f:
        all_latents = pickle.load(f)
else:
    all_latents = get_coco_latents(dataset, model_dino, model_clip, msae, clip_tokenizer)
    with open(latent_file_path, 'wb') as f:
        pickle.dump(all_latents, f)

In [9]:
latent_class_file_path = '../../final_results/segmentation/coco_latents_count_per_class_global.pickle'
if os.path.isfile(latent_class_file_path):
    # NOTE(security): pickle.load can execute arbitrary code; only load cache
    # files produced by this notebook.
    with open(latent_class_file_path, 'rb') as f:
        latent_class_count = pickle.load(f)
else:
    # latent_class_count[s, l, c] = number of samples of class c in which
    # latent l of stream s is active (> 0).
    num_latents = len(all_latents[0]['latents']['dino'])
    num_classes = max(dataset.cat_id2name.keys())+1
    num_streams = len(all_latents[0]['latents'])
    latent_class_count = np.zeros((num_streams, num_latents, num_classes))

    for sample in tqdm(all_latents):
        # The category id is stored with the latents, so there is no need to
        # index the dataset (which would decode and transform the image again).
        class_id = sample['category_id']
        for stream_idx, stream in enumerate(sample['latents']):
            active_idx = np.where(sample['latents'][stream]>0)[0]
            latent_class_count[stream_idx, active_idx, class_id] += 1

    # Write only after the counts are complete, so an interrupted run cannot
    # leave an empty cache file that a later run would try to unpickle.
    os.makedirs(os.path.dirname(latent_class_file_path), exist_ok=True)
    with open(latent_class_file_path, 'wb') as f:
        pickle.dump(latent_class_count, f)

### Heatmaps

In [10]:
from tqdm import tqdm
from scipy.sparse import csr_matrix
from scipy.sparse import vstack
from itertools import combinations
import numpy as np

import matplotlib.pyplot as plt

from sparc.heatmaps.attention_relevance import interpret_sparc, interpret_clip
from sparc.heatmaps.gradcam import compute_gradcam
from sparc.heatmaps.visualization import (show_clean_text_attribution, 
                                    plot_all_attributions, 
                                    show_clean_gradcam_text,  
                                    plot_relevancy_attributions)

from sparc.heatmaps.attention_relevance import get_all_latents
from sparc.heatmaps.attention_relevance import get_attention_blocks, compute_attention_relevancy
from sparc.heatmaps.clip import create_wrapped_clip, patch_clip_keep_last
from sparc.heatmaps.dino import create_wrapped_dinov2, patch_dinov2_keep_last
from sparc.heatmaps.attention_relevance import compute_attention_relevancy

In [11]:
def get_article(word):
    """Return the capitalised indefinite article ("A" or "An") for ``word``.

    The choice looks only at the first letter: a, e, i, o or u (any case)
    gives "An", everything else gives "A". An empty word gives "A" instead of
    raising IndexError.
    """
    if not word:
        return "A"
    return "An" if word[0].lower() in 'aeiou' else "A"

In [14]:
def get_image_mask(image_relevance, size=224):
    """Upsample per-patch relevance to image resolution and min-max normalise it.

    Args:
        image_relevance: Tensor of shape ``(batch, n_patches)`` where
            ``n_patches`` is a perfect square (a flattened square patch grid).
        size: Output resolution passed to ``interpolate`` (an int for a square
            map, or an ``(H, W)`` tuple). Defaults to 224.

    Returns:
        A numpy array of shape ``(batch, H, W)`` where each map is rescaled to
        the range [0, 1]. A map with constant values yields NaNs (0/0), which
        callers can detect with ``np.isnan``.
    """
    batch_size, spatial_size = image_relevance.shape
    dim = int(spatial_size ** 0.5)
    image_relevance = image_relevance.reshape(batch_size, 1, dim, dim)
    image_relevance = torch.nn.functional.interpolate(image_relevance, size=size, mode='bilinear')
    # Works for tensors on any device; no GPU round trip is required, so the
    # function also runs on CPU-only machines.
    heat = image_relevance.squeeze(1).detach().cpu().numpy()
    flat = heat.reshape(batch_size, -1)
    lowest = flat.min(1, keepdims=True)
    highest = flat.max(1, keepdims=True)
    return ((flat - lowest) / (highest - lowest)).reshape(heat.shape)

In [15]:
# Rebuild both backbones as wrapped versions for the heatmap computation.
# NOTE: this rebinds `model_clip` / `model_dino` from the earlier cell, so
# cells that ran before this one used the unwrapped models.
model_clip = create_wrapped_clip(CLIP_MODEL, device)
# NOTE(review): the return value is reassigned here but ignored for DINO
# below - confirm that enable_attention_capture() returns the model itself.
model_clip = model_clip.enable_attention_capture()
patch_clip_keep_last(model_clip, last_layer=23)

model_dino = create_wrapped_dinov2(DINO_MODEL, device)
model_dino.enable_attention_capture()
patch_dinov2_keep_last(model_dino, last_layer=23)

Using cache found in /home/a_n29343/.cache/torch/hub/facebookresearch_dinov2_main


In [16]:
# Flag forwarded to get_all_latents() in the processing loop below.
global_msae = True

In [18]:
from torch import nn
from typing import Tuple
def interpret_clip_sim(
    image: torch.Tensor, 
    texts: torch.Tensor, 
    model: nn.Module, 
    device: str, 
    start_layer: int = -1, 
    start_layer_text: int = -1
) -> torch.Tensor:
    """
    Image relevance of the CLIP image-text similarity for each text.

    The image is repeated once per text, the scaled cosine-similarity logits
    are computed, and the sum of the matched (diagonal) logits is handed to
    ``compute_attention_relevancy`` over the visual transformer blocks.

    Args:
        image: Input image tensor (repeated along the batch axis per text)
        texts: Input text tokens, one row per text
        model: CLIP model exposing encode_image, encode_text, logit_scale and
            visual.transformer.resblocks
        device: Device forwarded to compute_attention_relevancy
        start_layer: Starting layer for the image relevancy; -1 selects the
            last visual block
        start_layer_text: Unused; kept for interface compatibility

    Returns:
        Image relevance tensor: row 0 of the relevancy for every sample with
        the first token column removed (presumably the CLS token).
    """
    batch_size = texts.shape[0]
    images = image.repeat(batch_size, 1, 1, 1)
    # Forward
    image_features = model.encode_image(images)
    text_features = model.encode_text(texts)

    # Normalized features
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    # Cosine similarity as logits
    logit_scale = model.logit_scale.exp()
    logits_per_image = logit_scale * image_features @ text_features.t()

    # Select the logit of each (repeated image, i-th text) pair, i.e. the diagonal.
    index = [i for i in range(batch_size)]
    one_hot = np.zeros((logits_per_image.shape[0], logits_per_image.shape[1]), dtype=np.float32)
    one_hot[torch.arange(logits_per_image.shape[0]), index] = 1
    one_hot = torch.from_numpy(one_hot).requires_grad_(True)
    # Follow the logits' device instead of assuming CUDA is available.
    one_hot = torch.sum(one_hot.to(logits_per_image.device) * logits_per_image)
    model.zero_grad()

    image_attn_blocks = list(dict(model.visual.transformer.resblocks.named_children()).values())

    if start_layer == -1:
        start_layer = len(image_attn_blocks) - 1

    image_relevance = compute_attention_relevancy(
        one_hot, image_attn_blocks, device, batch_size, start_layer
    )
    image_relevance = image_relevance[:, 0, 1:]
    return image_relevance

In [19]:
from torch.utils.data import DataLoader
# Loader for the heatmap evaluation: one (image, category) sample per step,
# in dataset order, loaded in the main process.
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=0
)

In [20]:
def image_np(np_array):
    """Scale ``np_array`` by 255, cast to uint8 and return it as a 224x224 PIL image."""
    scaled = np_array * 255
    as_bytes = np.uint8(scaled)
    picture = Image.fromarray(as_bytes)
    return picture.resize((224, 224))

In [None]:
import os
import json
import numpy as np

def convert_numpy_types(obj):
    """Recursively convert numpy types to Python native types.

    Handles numpy booleans, integers, floats and arrays, and walks through
    dicts (keys and values), lists and tuples. Tuples stay tuples. Any other
    object is returned unchanged.
    """
    if isinstance(obj, np.bool_):
        # np.bool_ is neither np.integer nor a Python bool and is rejected by json.
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        # Numpy scalar keys are converted too; json refuses them as keys.
        return {
            convert_numpy_types(key): convert_numpy_types(value)
            for key, value in obj.items()
        }
    if isinstance(obj, list):
        return [convert_numpy_types(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(convert_numpy_types(item) for item in obj)
    return obj

def save_segmentation_data(data, filepath):
    """Save complete segmentation data for all methods as JSON.

    Args:
        data: Mapping from method name to that method's collected records.
        filepath: Destination JSON file. Missing parent directories are
            created; a bare filename writes into the current directory.

    The file holds the method names, the evaluation thresholds (read from the
    module-level IOU_THR and THR_FALLBACK at call time) and the data itself.
    """
    # Create directory if it doesn't exist. os.makedirs('') raises, so skip
    # the call when the path has no directory component.
    out_dir = os.path.dirname(filepath)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Convert numpy types to Python native types
    clean_data = convert_numpy_types(data)

    # Add metadata
    save_data = {
        'methods': list(data.keys()),
        'metadata': {
            'iou_threshold': float(IOU_THR),
            'fallback_threshold': float(THR_FALLBACK),
            'total_methods': len(data.keys())
        },
        'data': clean_data
    }

    print(f"Saving segmentation data to {filepath}...")
    with open(filepath, 'w') as f:
        json.dump(save_data, f, indent=2)
    print("✓ Segmentation data saved!")

def load_segmentation_data(filepath):
    """Read a saved segmentation JSON file and report what it contains.

    Returns:
        ``(data, metadata)``: the per-method records stored under ``'data'``
        and the ``'metadata'`` dictionary (empty when the file has none).
    """
    print(f"Loading segmentation data from {filepath}...")

    with open(filepath, 'r') as handle:
        payload = json.load(handle)

    per_method = payload['data']
    metadata = payload.get('metadata', {})

    print(f"✓ Loaded data for {len(per_method)} methods: {list(per_method.keys())}")
    for name in per_method.keys():
        n_images = len(per_method[name]['images'])
        n_predictions = len(per_method[name]['anns_dt'])
        print(f"  - {name}: {n_images} images, {n_predictions} predictions")

    return per_method, metadata


In [None]:
import numpy as np, cv2, torch
from tqdm.auto import tqdm
from skimage.filters import threshold_otsu
from pycocotools import mask as mu
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

# IoU threshold used for the COCO AP evaluation.
IOU_THR = 0.20
# When Otsu thresholding selects no pixel, keep pixels at or above this
# fraction of the heatmap maximum instead.
THR_FALLBACK = 0.30

def rle(arr):
    """Encode a uint8 H×W mask as a COCO RLE dict whose counts are an ASCII str (json-safe)."""
    # pycocotools expects column-major (Fortran-ordered) input.
    encoded = mu.encode(np.asfortranarray(arr))
    counts_text = encoded["counts"].decode("ascii")
    encoded["counts"] = counts_text
    return encoded

def batch_resize(t, size_hw):
    """Bilinearly resize tensor ``t`` to ``size_hw`` (without corner alignment)."""
    interpolate = torch.nn.functional.interpolate
    resized = interpolate(t, size=size_hw, mode="bilinear", align_corners=False)
    return resized

# Names of the heatmap methods that are evaluated
methods = [
    'dino_sparc_sum',
    'clip_sparc_sum',
    'clip_sparc_cross',
    'dino_sparc_cross',
    'clip_sim',
]

# One independent record store per method: images, ground-truth annotations,
# predictions, the class-name -> id mapping and the next annotation id.
data = {
    method: {'images': [], 'anns_gt': [], 'anns_dt': [], 'cls2id': {}, 'ann_id': 1}
    for method in methods
}

# Per-latent class counts summed over the stream axis. The array is not
# modified inside the loop, so it is computed once instead of on every batch.
class_count_per_latent = latent_class_count.sum(0)

# Main processing loop: for every (image, category) sample compute one
# heatmap per method, threshold it into a mask and record ground truth and
# prediction in COCO annotation format.
for i, batch in enumerate(tqdm(dataloader, desc="Processing batches")):
    # Resetting CUDA statistics is only possible when a GPU is present.
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()

    dino_imgs = batch['dino_images'].to(device)
    clip_imgs = batch['clip_images'].to(device)
    masks = batch['masks']
    cls_names = batch['category_names']
    img_ids = batch['image_ids']

    # Get latents (the text prompt is simply the class name)
    sparc_texts = [f'{name}' for name in cls_names]
    tokenized_sparc_texts = clip_tokenizer(sparc_texts).to(device)

    clip_txt_latent, clip_img_latent, dino_latent, batch_size = get_all_latents(
        tokenized_sparc_texts, msae, model_dino, dino_imgs, model_clip, clip_imgs, global_msae
    )

    # Compute targets
    class_id = batch['category_ids']

    # Latents that were active in more than 50 samples of this class.
    # NOTE: `class_id` is a list; with the batch size of 1 used here it
    # selects a single class column.
    class_latents = np.where(class_count_per_latent[:, class_id]>50)[0]

    targets = {
        'dino_sparc_sum': torch.sum(dino_latent[:, class_latents]),
        'clip_sparc_sum': torch.sum(clip_img_latent[:, class_latents]),
        'clip_sparc_cross': clip_txt_latent @ clip_img_latent.T,
        'dino_sparc_cross': clip_txt_latent @ dino_latent.T
    }

    model_clip.zero_grad()
    model_dino.zero_grad()

    # Get attention blocks
    clip_img_attn_blocks, clip_txt_attn_blocks = get_attention_blocks(model_clip, 'clip')
    dino_img_attn_blocks, _ = get_attention_blocks(model_dino, 'dino')

    # Compute relevancies and masks
    masks_dict = {}

    # SPARC methods
    for method in ['dino_sparc_sum', 'dino_sparc_cross']:
        relevance = compute_attention_relevancy(
            targets[method], dino_img_attn_blocks, device, batch_size, start_layer=23
        )
        relevance = relevance[:, 0, 5:]  # Skip CLS + register tokens
        masks_dict[method] = get_image_mask(relevance)

    for method in ['clip_sparc_sum', 'clip_sparc_cross']:
        relevance = compute_attention_relevancy(
            targets[method], clip_img_attn_blocks, device, batch_size, start_layer=23
        )
        relevance = relevance[:, 0, 1:]  # Skip CLS token
        masks_dict[method] = get_image_mask(relevance)

    # CLIP SIM method
    clip_sim_relevance = interpret_clip_sim(clip_imgs, tokenized_sparc_texts, model_clip, device)
    masks_dict['clip_sim'] = get_image_mask(clip_sim_relevance)

    # Convert to numpy
    for method in methods:
        if isinstance(masks_dict[method], torch.Tensor):
            masks_dict[method] = masks_dict[method].detach().cpu().numpy()

    B = len(img_ids)

    # Process each sample
    for b in range(B):
        img_id = img_ids[b]
        gt_mask = masks[b]
        H, W = gt_mask.shape
        cls = cls_names[b]

        if gt_mask.sum() == 0:
            continue

        # Prepare ground truth (same for all methods)
        gt_bin = (gt_mask > 0).numpy().astype(np.uint8)
        g_rle = rle(gt_bin)
        gt_ys, gt_xs = np.where(gt_bin)
        gt_x0, gt_y0 = int(gt_xs.min()), int(gt_ys.min())
        # Kept in its own variable: the prediction box computed below must
        # not overwrite the ground-truth box used for the remaining methods.
        gt_bbox = [gt_x0, gt_y0, int(gt_xs.max()) + 1 - gt_x0, int(gt_ys.max()) + 1 - gt_y0]

        # Process all methods
        for method in methods:
            # Add ground truth
            cid = data[method]['cls2id'].setdefault(cls, len(data[method]['cls2id'])+1)
            data[method]['images'].append({"id": img_id, "width": W, "height": H})
            data[method]['anns_gt'].append({
                "id": data[method]['ann_id'], "image_id": img_id, "category_id": cid,
                "segmentation": g_rle, "area": int(mu.area(g_rle)),
                "bbox": list(gt_bbox), "iscrowd": 0
            })
            data[method]['ann_id'] += 1

            # Process prediction: resize the heatmap to the mask size and
            # binarise it with Otsu's threshold.
            h = cv2.resize(masks_dict[method][b], (W, H), interpolation=cv2.INTER_LINEAR)
            if np.isnan(h).any():
                # A constant heatmap normalises to NaN; no prediction is stored.
                continue
            thr = threshold_otsu(h)
            m = h >= thr
            if m.sum() == 0:
                m = h >= (h.max() * THR_FALLBACK)
            if m.sum() > 0:
                m_np = m.astype(np.uint8)
                d_rle = rle(m_np)
                dt_ys, dt_xs = np.where(m_np)
                dt_x0, dt_y0 = int(dt_xs.min()), int(dt_ys.min())
                dt_bbox = [dt_x0, dt_y0, int(dt_xs.max()) + 1 - dt_x0, int(dt_ys.max()) + 1 - dt_y0]
                data[method]['anns_dt'].append({
                    "image_id": img_id, "category_id": cid, "score": 1.0,
                    "segmentation": d_rle, "bbox": dt_bbox
                })

In [22]:
# os.makedirs("../../final_results/", exist_ok=True)
# save_segmentation_data(data, '../../final_results/segmentation/segmentation_result_latent_global.json')

In [23]:
# Evaluate all methods
def evaluate_all_methods(data, methods):
    """Run the COCO segmentation evaluation for every method and compare them.

    Args:
        data: Mapping from method name to a dict with ``images``, ``anns_gt``,
            ``anns_dt`` and ``cls2id``.
        methods: Names of the methods to evaluate.

    Returns:
        Dict mapping each method to its AP at the single IoU threshold
        ``IOU_THR`` (first entry of the COCOeval stats). A method without any
        prediction gets 0.0. An empty ``methods`` list returns an empty dict.
    """
    results = {}

    for method in methods:
        print("="*60)
        print(f"EVALUATING {method.upper()}")
        print("="*60)

        cats = [{'id': v, 'name': k} for k, v in data[method]['cls2id'].items()]
        gt = {'images': data[method]['images'], 'annotations': data[method]['anns_gt'], 'categories': cats}

        print(f"{method}: {len(data[method]['images'])} images, {len(data[method]['anns_dt'])} predictions")

        if not data[method]['anns_dt']:
            # loadRes inspects the first result and fails on an empty list;
            # with nothing predicted the precision is recorded as 0.
            print(f"{method}: no predictions, AP set to 0.0")
            results[method] = 0.0
            continue

        coco_gt = COCO()
        coco_gt.dataset = gt
        coco_gt.createIndex()
        coco_dt = coco_gt.loadRes(data[method]['anns_dt'])

        eval_obj = COCOeval(coco_gt, coco_dt, iouType="segm")
        # Evaluate at one IoU threshold only.
        eval_obj.params.iouThrs = np.array([IOU_THR])
        eval_obj.evaluate()
        eval_obj.accumulate()
        eval_obj.summarize()

        results[method] = eval_obj.stats[0]

    # Comparison
    print("="*60)
    print("COMPARISON SUMMARY")
    print("="*60)

    for method, score in results.items():
        print(f"{method} mAP@{IOU_THR}: {score:.4f}")

    print("\nRankings:")
    sorted_results = sorted(results.items(), key=lambda x: x[1], reverse=True)
    for rank, (method, score) in enumerate(sorted_results, 1):
        print(f"{rank}. {method}: {score:.4f}")

    # There is no winner to report when nothing was evaluated.
    if sorted_results:
        best_method, best_score = sorted_results[0]
        print(f"\nWinner: {best_method} with {best_score:.4f}")
    return results

In [31]:
# Load previously saved results; this replaces the in-memory `data` built by
# the processing loop above with the contents of the JSON file.
data, metadata = load_segmentation_data('../../final_results/segmentation/segmentation_result_latent_global.json')

Loading segmentation data from final_results/segmentation/segmentation_result_latent_global.json...
✓ Loaded data for 5 methods: ['dino_sparc_sum', 'clip_sparc_sum', 'clip_sparc_cross', 'dino_sparc_cross', 'clip_sim']
  - dino_sparc_sum: 14631 images, 14541 predictions
  - clip_sparc_sum: 14631 images, 14523 predictions
  - clip_sparc_cross: 14631 images, 14631 predictions
  - dino_sparc_cross: 14631 images, 14631 predictions
  - clip_sim: 14631 images, 14631 predictions


In [32]:
# Evaluate every method found in the loaded data; the returned
# {method: AP} dict is displayed as the cell output.
evaluate_all_methods(data, list(data.keys()))

EVALUATING DINO_SPARC_SUM
dino_sparc_sum: 14631 images, 14541 predictions
creating index...
index created!
Loading and preparing results...
DONE (t=0.01s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type *segm*
DONE (t=4.96s).
Accumulating evaluation results...
DONE (t=0.61s).
 Average Precision  (AP) @[ IoU=0.20:0.20 | area=   all | maxDets=100 ] = 0.222
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.20:0.20 | area= small | maxDets=100 ] = 0.019
 Average Precision  (AP) @[ IoU=0.20:0.20 | area=medium | maxDets=100 ] = 0.244
 Average Precision  (AP) @[ IoU=0.20:0.20 | area= large | maxDets=100 ] = 0.264
 Average Recall     (AR) @[ IoU=0.20:0.20 | area=   all | maxDets=  1 ] = 0.352
 Average Recall     (AR) @[ IoU=0.20:0.20 | area=   all | maxDets= 10 ] = 0.352
 Average Recall     (AR) @[ IoU=0.20:0.

{'dino_sparc_sum': 0.22183031405480547,
 'clip_sparc_sum': 0.22248569137634533,
 'clip_sparc_cross': 0.2229675861393922,
 'dino_sparc_cross': 0.22223839494804298,
 'clip_sim': 0.24770786555228405}

### mIOU

In [None]:
import numpy as np
from pycocotools import mask as mu
from collections import defaultdict

def compute_iou(mask1_rle, mask2_rle):
    """Intersection-over-union of two RLE-encoded masks (0.0 when both are empty)."""
    # Decode both RLE dicts into dense binary arrays.
    first, second = (mu.decode(encoded) for encoded in (mask1_rle, mask2_rle))

    union = np.logical_or(first, second).sum()
    if union == 0:
        return 0.0

    overlap = np.logical_and(first, second).sum()
    return overlap / union

def compute_miou(data, method):
    """Mean IoU of one method over all of its ground-truth annotations.

    Ground truth and predictions are matched on ``(image_id, category_id)``.
    A ground-truth annotation without a prediction contributes an IoU of 0.

    Returns:
        ``(miou, matched_pairs)``; ``(0.0, 0)`` when there is no ground truth.
    """
    print(f"Computing mIOU for {method}...")

    def by_image_and_category(annotations):
        # Later annotations with the same key replace earlier ones.
        return {(ann['image_id'], ann['category_id']): ann for ann in annotations}

    gt_lookup = by_image_and_category(data[method]['anns_gt'])
    dt_lookup = by_image_and_category(data[method]['anns_dt'])

    ious = []
    matched_pairs = 0
    for key, gt_ann in gt_lookup.items():
        if key not in dt_lookup:
            # Missing prediction: penalize with IoU = 0
            ious.append(0.0)
            continue
        ious.append(compute_iou(gt_ann['segmentation'], dt_lookup[key]['segmentation']))
        matched_pairs += 1

    if not ious:
        return 0.0, 0

    miou = np.mean(ious)
    print(f"  {method}: {matched_pairs} matched pairs, mIOU = {miou:.4f}")

    return miou, matched_pairs

def compute_all_miou(data):
    """Compute mIOU for all methods and print a comparison.

    Args:
        data: Mapping from method name to its ground-truth and prediction
            annotations, as consumed by ``compute_miou``.

    Returns:
        Dict mapping each method to ``{'mIOU': ..., 'matched_pairs': ...}``.
        An empty ``data`` returns an empty dict.
    """
    print("="*60)
    print("COMPUTING mIOU FOR ALL METHODS")
    print("="*60)

    results = {}

    for method in data.keys():
        miou, matched_pairs = compute_miou(data, method)
        results[method] = {
            'mIOU': miou,
            'matched_pairs': matched_pairs
        }

    # Print comparison
    print("\n" + "="*40)
    print("mIOU COMPARISON")
    print("="*40)

    for method, result in results.items():
        print(f"{method}: mIOU = {result['mIOU']:.4f} ({result['matched_pairs']} pairs)")

    # Rankings
    print("\nRankings by mIOU:")
    sorted_results = sorted(results.items(), key=lambda x: x[1]['mIOU'], reverse=True)
    for rank, (method, result) in enumerate(sorted_results, 1):
        print(f"{rank}. {method}: {result['mIOU']:.4f}")

    # There is no best method to report when nothing was evaluated.
    if sorted_results:
        best_method, best_result = sorted_results[0]
        print(f"\nBest mIOU: {best_method} with {best_result['mIOU']:.4f}")

    return results

In [34]:
# Mean IoU per method, computed on the `data` loaded from disk above.
miou_results = compute_all_miou(data)

COMPUTING mIOU FOR ALL METHODS
Computing mIOU for dino_sparc_sum...
  dino_sparc_sum: 14541 matched pairs, mIOU = 0.1368
Computing mIOU for clip_sparc_sum...
  clip_sparc_sum: 14523 matched pairs, mIOU = 0.1307
Computing mIOU for clip_sparc_cross...
  clip_sparc_cross: 14631 matched pairs, mIOU = 0.1381
Computing mIOU for dino_sparc_cross...
  dino_sparc_cross: 14631 matched pairs, mIOU = 0.1429
Computing mIOU for clip_sim...
  clip_sim: 14631 matched pairs, mIOU = 0.1567

mIOU COMPARISON
dino_sparc_sum: mIOU = 0.1368 (14541 pairs)
clip_sparc_sum: mIOU = 0.1307 (14523 pairs)
clip_sparc_cross: mIOU = 0.1381 (14631 pairs)
dino_sparc_cross: mIOU = 0.1429 (14631 pairs)
clip_sim: mIOU = 0.1567 (14631 pairs)

Rankings by mIOU:
1. clip_sim: 0.1567
2. dino_sparc_cross: 0.1429
3. clip_sparc_cross: 0.1381
4. dino_sparc_sum: 0.1368
5. clip_sparc_sum: 0.1307

Best mIOU: clip_sim with 0.1567
