In [None]:
"""Detection visualization with GMM-colored patches and re-ID evaluation."""
%load_ext autoreload 
%autoreload 2

import random
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from src.config.config import MainConfig
from src.data.preprocessed_dataset import PreprocessedDataset
from src.data.coco_loader import COCOLoader
from src.pca.incremental_pca import IncrementalPCAProcessor
from src.codebook.gmm_trainer import load_gmm_model
from src.visualization import (
    draw_bbox, draw_patches, patch_coords_in_crop, get_crop_bounds,
    compute_patch_responsibilities, responsibilities_to_colors
)
from src.evaluation import match_detections_to_gt, get_identity_mapping, get_image_uuid_from_detection_id

In [None]:
# Load the run configuration, the preprocessed detections, the PCA processor
# and the GMM model that the later cells rely on.
config = MainConfig.from_yaml(Path("config_zebra_test.yaml"))
# config = MainConfig.from_yaml(Path("/fs/ess/PAS2136/ggr_data/GZGC/config_zebra_test.yaml"))
dataset = PreprocessedDataset(config.output_root)
pca = IncrementalPCAProcessor(config.pca, config.output_root)
gmm, gmm_metadata = load_gmm_model(config.gmm_model_path)

# Load COCO annotations and match detections to ground truth.
coco_loader = COCOLoader(config.coco_json_path, config.dataset_root)
matched = match_detections_to_gt(dataset, coco_loader, iou_threshold=0.5, category_ids=[1])  # 1 = zebra_grevys
identity_map = get_identity_mapping(matched)

# Query the total once, and guard the percentage so an empty dataset reports
# 0.0% instead of raising ZeroDivisionError.
total_detections = dataset.get_total_detection_count()
matched_pct = len(matched) / total_detections * 100 if total_detections else 0.0

print(f"Loaded {total_detections} detections")
print(f"Matched to GT: {len(matched)} ({matched_pct:.1f}%)")
print(f"GMM: {gmm.n_components} components")

In [None]:
# Choose one image at random among the keys of the dataset's
# 'image_to_detections' index, then load the image and its detections.
image_paths = list(dataset._index['image_to_detections'].keys())
image_path = random.choice(image_paths)
image = Image.open(image_path).convert('RGB')
detections = dataset.get_detections_for_image(image_path)

# Short summary of what was picked.
print(
    f"Image: {Path(image_path).name}",
    f"Detections: {len(detections)}",
    sep="\n",
)

In [None]:
# Show the full image with each detection's bounding box overlaid.
# Set SHOW_SQUARE_CROP to True to also outline each detection's square crop
# box in blue. The legend in the title is built from what is actually drawn,
# so it no longer mentions a blue box that is not on the figure.
SHOW_SQUARE_CROP = False

fig, ax = plt.subplots(figsize=(12, 8))
ax.imshow(image)
for det in detections:
    draw_bbox(ax, det.bbox, color='red', linewidth=2)
    if SHOW_SQUARE_CROP:
        draw_bbox(ax, det.square_crop_bbox, color='blue', linewidth=1)

legend = "red=bbox, blue=square_crop" if SHOW_SQUARE_CROP else "red=bbox"
ax.set_title(f"{len(detections)} detections ({legend})")
ax.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# Show each detection's square crop with GMM-colored patches, laid out on a
# grid of at most four columns.
img_w, img_h = image.size
n_detections = len(detections)

if n_detections == 0:
    # With no detections the column count would be 0 and the row computation
    # below would divide by zero, so there is nothing to lay out.
    print("No detections to display for this image")
else:
    cols = min(4, n_detections)
    rows = (n_detections + cols - 1) // cols  # ceiling division

    # squeeze=False always yields a 2-D array of axes, so the single-detection
    # case needs no special handling before flattening.
    fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows), squeeze=False)
    axes = axes.flatten()

    for idx, det in enumerate(detections):
        ax = axes[idx]

        # Crop the image to the detection's square crop bounds.
        x1, y1, x2, y2 = get_crop_bounds(det.square_crop_bbox, img_w, img_h)
        crop = image.crop((x1, y1, x2, y2))
        crop_w, crop_h = x2 - x1, y2 - y1

        ax.imshow(crop)

        # Compute GMM responsibilities and map them to colors.
        resp = compute_patch_responsibilities(det.features, det.patch_mask, pca, gmm)
        colors = responsibilities_to_colors(resp)

        # Overlay the colored patches in crop coordinates.
        coords = patch_coords_in_crop(det.patch_mask, crop_w, crop_h)
        draw_patches(ax, coords, colors=colors, alpha=0.6)

        ax.set_title(f"Det {idx}: {det.patch_mask.sum().item():.0f} patches")
        ax.axis('off')

    # Hide grid cells left over when the detections do not fill the last row.
    for spare_ax in axes[n_detections:]:
        spare_ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Load Fisher vectors
import numpy as np
from src.data.fv_dataset import FisherVectorDataset

# Load every reduced Fisher vector together with its detection id.
fv_dataset = FisherVectorDataset(config.output_root / "fisher_vectors_reduced")
all_det_ids, all_fvs = fv_dataset.get_all_fisher_vectors()

# L2-normalize each row so a dot product between rows is a cosine similarity.
# The norm is clamped away from zero: an all-zero Fisher vector would
# otherwise produce a NaN row, and NaNs sort to the top of a descending
# argsort in the ranking cells.
fv_norms = np.linalg.norm(all_fvs, axis=1, keepdims=True)
all_fvs_norm = all_fvs / np.maximum(fv_norms, 1e-12)

print(f"Loaded {len(all_det_ids)} Fisher vectors, dim: {all_fvs.shape[1]}")

In [None]:
# Visualize top-5 matches with identity verification (green=correct, red=incorrect)
from matplotlib.patches import Rectangle

def show_detection_with_patches(ax, det, title, alpha=0.4, border_color=None):
    """Draw a detection's square crop on ``ax`` with GMM-colored patch overlays.

    The image at ``det.image_path`` is opened, converted to RGB and cropped to
    the bounds that ``get_crop_bounds`` returns for ``det.square_crop_bbox``.
    Patch colors come from ``responsibilities_to_colors`` applied to the
    output of ``compute_patch_responsibilities``, which is called with the
    notebook-level ``pca`` and ``gmm`` objects.

    Args:
        ax: Matplotlib axes to draw on; its ticks are removed.
        det: Detection object providing ``image_path``, ``square_crop_bbox``,
            ``features`` and ``patch_mask``.
        title: Text used as the axes title.
        alpha: Opacity forwarded to ``draw_patches``.
        border_color: Optional edge color. When truthy, a frame of that color
            is drawn around the crop.
    """
    # Open inside a context manager so the file handle is released as soon as
    # the RGB copy has been made.
    with Image.open(det.image_path) as source_image:
        img = source_image.convert('RGB')
    img_w, img_h = img.size
    x1, y1, x2, y2 = get_crop_bounds(det.square_crop_bbox, img_w, img_h)
    crop = img.crop((x1, y1, x2, y2))
    crop_w, crop_h = x2 - x1, y2 - y1

    ax.imshow(crop)

    # Draw GMM-colored patches
    resp = compute_patch_responsibilities(det.features, det.patch_mask, pca, gmm)
    colors = responsibilities_to_colors(resp)
    coords = patch_coords_in_crop(det.patch_mask, crop_w, crop_h)
    draw_patches(ax, coords, colors=colors, alpha=alpha)

    ax.set_title(title, fontsize=9)
    ax.set_xticks([])
    ax.set_yticks([])

    # Add colored border using Rectangle patch
    if border_color:
        rect = Rectangle((0, 0), crop_w - 1, crop_h - 1,
                         linewidth=6, edgecolor=border_color, facecolor='none')
        ax.add_patch(rect)

# Number of ranked matches shown next to each query.
TOP_K = 5

# Image UUID of every gallery entry, derived once instead of once per query.
gallery_image_uuids = [get_image_uuid_from_detection_id(det_id) for det_id in all_det_ids]

for query_det in detections:
    query_identity = identity_map.get(query_det.detection_id)
    query_image_uuid = get_image_uuid_from_detection_id(query_det.detection_id)

    # Cosine similarity of the query against every gallery Fisher vector.
    # The norm is clamped so an all-zero query vector does not produce NaNs.
    query_fv = fv_dataset.get_fisher_vector(query_det.detection_id)
    query_fv_norm = query_fv / max(np.linalg.norm(query_fv), 1e-12)
    similarities = all_fvs_norm @ query_fv_norm

    # Exclude detections from the query's own image (this includes the query).
    same_image = np.array(
        [image_uuid == query_image_uuid for image_uuid in gallery_image_uuids],
        dtype=bool,
    )
    similarities[same_image] = -np.inf

    # Best TOP_K candidates. Excluded entries are dropped so they are never
    # displayed when fewer than TOP_K other-image candidates exist.
    ranked = np.argsort(similarities)[::-1][:TOP_K]
    top_indices = [idx for idx in ranked if not same_image[idx]]

    # One row: the query followed by its matches.
    fig, axes = plt.subplots(1, TOP_K + 1, figsize=(3 * (TOP_K + 1), 3))

    # Query (blue border)
    show_detection_with_patches(axes[0], query_det,
        f"Query: {query_identity[:8] if query_identity else 'unknown'}...", border_color='blue')

    for i, idx in enumerate(top_indices):
        match_det = dataset.get_detection(all_det_ids[idx])
        match_identity = identity_map.get(all_det_ids[idx])

        # Border color: green if same identity, red if different, gray if
        # either identity is unknown.
        if query_identity and match_identity:
            border_color = 'green' if query_identity == match_identity else 'red'
        else:
            border_color = 'gray'

        identity_str = match_identity[:8] if match_identity else 'unknown'
        show_detection_with_patches(axes[i+1], match_det,
            f"#{i+1} sim={similarities[idx]:.2f}\n{identity_str}...",
            border_color=border_color)

    # Blank out panels left unused when there are fewer than TOP_K matches.
    for spare_ax in axes[1 + len(top_indices):]:
        spare_ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Compute top-k accuracy metrics
from src.evaluation import compute_reid_accuracy

# Detection id -> image UUID, handed to compute_reid_accuracy together with
# exclude_same_image=True.
det_to_image = dict(
    (det_id, get_image_uuid_from_detection_id(det_id)) for det_id in all_det_ids
)

# Detection id -> Fisher vector (the matching row of all_fvs).
fisher_vectors = dict(
    (det_id, all_fvs[row]) for row, det_id in enumerate(all_det_ids)
)

# Run the re-identification evaluation.
metrics = compute_reid_accuracy(
    fisher_vectors=fisher_vectors,
    identity_mapping=identity_map,
    exclude_same_image=True,
    detection_to_image=det_to_image,
)

# Assemble the report line by line and emit it in a single print.
rule = "=" * 50
report_lines = [
    rule,
    "Re-Identification Metrics",
    rule,
    f"Queries: {metrics.num_queries} (detections with known identity)",
    f"Gallery: {metrics.num_gallery}",
    "",
    f"Top-1 Accuracy:  {metrics.top1_accuracy:.2%}",
    f"Top-5 Accuracy:  {metrics.top5_accuracy:.2%}",
    f"Top-10 Accuracy: {metrics.top10_accuracy:.2%}",
    f"Mean Reciprocal Rank: {metrics.mean_reciprocal_rank:.4f}",
]
print("\n".join(report_lines))

In [None]:
# Visualize patch-level matches between query and top-1 match
from src.matching import (
    extract_valid_patches, compute_patch_similarities, find_matches_ratio_test
)
from src.visualization import visualize_patch_matches

# 1-based rank of the gallery match to inspect (1 = most similar detection
# from another image). The previous code computed the argmax and then
# overwrote it with index k=10 of the descending ranking, i.e. the
# 11th-ranked match, while labelling the result "top-1" / "Top-10".
MATCH_RANK = 1
# Threshold forwarded to find_matches_ratio_test.
RATIO_TEST_THRESHOLD = 0.99

# Image UUID of every gallery entry, derived once instead of once per query.
gallery_image_uuids = [get_image_uuid_from_detection_id(det_id) for det_id in all_det_ids]

for query_det in detections:
    query_identity = identity_map.get(query_det.detection_id)
    query_image_uuid = get_image_uuid_from_detection_id(query_det.detection_id)

    # Cosine similarity of the query against every gallery Fisher vector.
    # The norm is clamped so an all-zero query vector does not produce NaNs.
    query_fv = fv_dataset.get_fisher_vector(query_det.detection_id)
    query_fv_norm = query_fv / max(np.linalg.norm(query_fv), 1e-12)
    similarities = all_fvs_norm @ query_fv_norm

    # Rank candidates by descending similarity, dropping detections that come
    # from the query's own image (this includes the query).
    same_image = np.array(
        [image_uuid == query_image_uuid for image_uuid in gallery_image_uuids],
        dtype=bool,
    )
    ranked = np.argsort(similarities)[::-1]
    ranked = ranked[~same_image[ranked]]

    if len(ranked) < MATCH_RANK:
        # Too few candidates from other images to reach the requested rank.
        print(f"Skipping {query_det.detection_id}: fewer than {MATCH_RANK} candidates from other images")
        continue

    match_idx = ranked[MATCH_RANK - 1]
    match_det = dataset.get_detection(all_det_ids[match_idx])
    match_identity = identity_map.get(all_det_ids[match_idx])

    # Extract patches (with optional PCA)
    patches1, coords1 = extract_valid_patches(query_det.features, query_det.patch_mask, pca=pca)
    patches2, coords2 = extract_valid_patches(match_det.features, match_det.patch_mask, pca=pca)

    # Compute patch similarities and find matches with ratio test.
    # patch_sims is kept at notebook level because the diagnostic cell that
    # follows reads the value left by the last iteration.
    patch_sims = compute_patch_similarities(patches1, patches2)
    matches = find_matches_ratio_test(patch_sims, coords1, coords2, ratio=RATIO_TEST_THRESHOLD)

    # Classify the match; UNKNOWN when either identity is missing.
    if not (query_identity and match_identity):
        match_status = "UNKNOWN"
    elif query_identity == match_identity:
        match_status = "CORRECT"
    else:
        match_status = "INCORRECT"

    title = f"Query: {query_identity[:8] if query_identity else 'unk'}... -> Top-{MATCH_RANK}: {match_identity[:8] if match_identity else 'unk'}... [{match_status}]"
    fig = visualize_patch_matches(query_det, match_det, matches, title=title)
    plt.show()

In [None]:
# Diagnostic: for every row of patch_sims, the ratio of the second-highest to
# the highest similarity. Rows whose best similarity is not positive are
# skipped. patch_sims is the value left by the last iteration of the
# patch-matching loop in the previous cell.
sims = np.asarray(patch_sims)

ratios = []
# A second-best value only exists with at least two candidate columns; the
# previous per-row indexing raised IndexError otherwise.
if sims.ndim == 2 and sims.shape[1] >= 2:
    descending = np.sort(sims, axis=1)[:, ::-1]
    best, runner_up = descending[:, 0], descending[:, 1]
    has_positive_best = best > 0
    ratios = (runner_up[has_positive_best] / best[has_positive_best]).tolist()

if ratios:
    print(f"Ratio distribution: min={min(ratios):.3f}, max={max(ratios):.3f}, mean={np.mean(ratios):.3f}")
else:
    # min()/max() on an empty list would raise ValueError.
    print("Ratio distribution: no row has two candidates and a positive best similarity")
# If mean is close to 1.0, features are basically indistinguishable

In [None]:
# Texture Identity Diagnostic using Fisher Vectors
from pathlib import Path
import numpy as np

from src.config.config import MainConfig
from src.data.fv_dataset import FisherVectorDataset
from src.evaluation import (
    compute_texture_identity_separation,
    plot_diagnostic_distributions,
    match_detections_to_gt,
    get_identity_mapping,
)
from src.data.coco_loader import COCOLoader
from src.data.preprocessed_dataset import PreprocessedDataset

# Two configurations: one supplies the "semantic" Fisher Vectors, the other
# the "textural" ones.
# NOTE(review): the textural config file is named "disk" while the comments
# and messages below say SIFT - confirm which descriptor it actually uses.
semantic_config = MainConfig.from_yaml(Path("config_zebra_test.yaml"))  # DINO for pose
textural_config = MainConfig.from_yaml(Path("config_zebra_disk.yaml"))  # SIFT for texture

# Semantic Fisher Vectors (DINO), keyed by detection id.
print("Loading semantic Fisher Vectors (DINO for pose matching)...")
semantic_fv_dir = Path(semantic_config.output_root) / "fisher_vectors_reduced"
semantic_fv_dataset = FisherVectorDataset(semantic_fv_dir)
sem_det_ids, sem_fv_matrix = semantic_fv_dataset.get_all_fisher_vectors()
semantic_features = dict(zip(sem_det_ids, sem_fv_matrix))
print(f"Loaded {len(semantic_features)} semantic Fisher Vectors, dim: {sem_fv_matrix.shape[1]}")

# Textural Fisher Vectors (SIFT), keyed by detection id.
print("\nLoading textural Fisher Vectors (SIFT for texture)...")
textural_fv_dir = Path(textural_config.output_root) / "fisher_vectors_reduced"
textural_fv_dataset = FisherVectorDataset(textural_fv_dir)
tex_det_ids, tex_fv_matrix = textural_fv_dataset.get_all_fisher_vectors()
textural_features = dict(zip(tex_det_ids, tex_fv_matrix))
print(f"Loaded {len(textural_features)} textural Fisher Vectors, dim: {tex_fv_matrix.shape[1]}")

# Identity mapping, built by matching the semantic dataset's detections to
# the COCO ground truth.
semantic_dataset = PreprocessedDataset(semantic_config.output_root)
coco_loader = COCOLoader(semantic_config.coco_json_path, semantic_config.dataset_root)
matched = match_detections_to_gt(semantic_dataset, coco_loader, iou_threshold=0.5, category_ids=[1])
identity_map = get_identity_mapping(matched)
print(f"\nIdentity map: {len(identity_map)} detections with known identity")

# Banner for the diagnostic sweep.
banner = "=" * 60
print("\n" + banner)
print("Running texture identity diagnostic with Fisher Vectors")
print(banner)

# Arguments shared by every run of the sweep; only skip_top_k varies.
diagnostic_kwargs = dict(
    semantic_features=semantic_features,
    textural_features=textural_features,
    identity_map=identity_map,
    n_queries=1000,
    k_pose_neighbors=50,
    min_same_identity=1,
)

for skip_k in (0, 5, 10, 20):
    result = compute_texture_identity_separation(skip_top_k=skip_k, **diagnostic_kwargs)
    summary_lines = [
        f"\nSkip top {skip_k}:",
        f"  ROC-AUC: {result.roc_auc:.4f}",
        f"  Same-ID sim: {result.same_identity_mean:.4f} ± {result.same_identity_std:.4f}",
        f"  Diff-ID sim: {result.diff_identity_mean:.4f} ± {result.diff_identity_std:.4f}",
        f"  Pairs: {result.n_same_identity_pairs} same / {result.n_diff_identity_pairs} diff",
    ]
    print("\n".join(summary_lines))

# Summarize and plot the result of the last sweep setting (skip_top_k=20).
print("\n")
result.print_summary()
plot_diagnostic_distributions(result)