In [None]:
import json
from collections import defaultdict
from sklearn.metrics import precision_score, recall_score, f1_score
import numpy as np
from scipy.optimize import linear_sum_assignment


clustering_file = "/ekaterina/work/src/lca/lca/tmp/grevyszebra_consistency/consistency_clustering.json"
node2uuid_file = "/ekaterina/work/src/lca/lca/tmp/grevyszebra_consistency/consistency_node2uuid_file.json"
annotations_file = "/ekaterina/work/data/zebra/annotations/zebra_duplicates.json"


def _load_json(path):
    """Read and parse the JSON document stored at ``path``."""
    # JSON text is UTF-8 by specification; being explicit avoids depending on
    # the platform's default text encoding.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


# clusters:    cluster id -> list of node ids (as consumed further down)
# node2uuid:   node id -> annotation UUID
# annotations: dict with (at least) "images" and "annotations" record lists
clusters = _load_json(clustering_file)
node2uuid = _load_json(node2uuid_file)
annotations = _load_json(annotations_file)

# Preview some of the content
print(clusters)  # or print(json.dumps(clusters, indent=2))
print(node2uuid)
print(annotations.keys())


In [None]:
# Dump the raw image records of the annotation file for inspection.
print(annotations["images"])

In [None]:
# Hand-recorded mismatches between the LCA clustering and the ground truth.
# Each row is (node ids of the LCA cluster, node ids of the GT cluster).
_NON_EQUAL_PAIRS = [
    ([269, 468, 469, 470, 471, 473, 474], [265, 266, 267, 268, 269]),
    ([532, 533, 534, 535, 536, 537, 538, 414], [414]),
    ([941, 398, 944, 947, 312, 926], [398]),
    ([56, 57, 59, 55], [55, 56, 57, 58, 59, 60]),
    ([623, 624, 625, 726, 728], [624, 625, 623]),
    ([560, 289], [289, 290, 291, 292]),
    ([448, 609, 450, 740], [740]),
    ([647, 648, 650, 651, 652], [647, 648, 649, 650, 651, 652]),
    ([846, 847, 848, 849, 242, 243], [848, 849, 846, 847]),
    ([97, 98, 103], [97, 98]),
    ([876, 861], [861]),
    ([874, 875, 877, 878, 879], [874, 875, 876, 877, 878, 879]),
    ([743, 744, 745, 573, 575], [744, 745, 743]),
    ([864, 934, 845, 340, 341, 862, 863], [864, 862, 863]),
    ([354, 657, 658, 659, 660, 661, 662], [354]),
    ([929, 933], [928, 929, 930, 931, 932, 933, 927]),
    ([192, 906, 907, 908, 909, 910, 911], [192]),
    ([475, 839], [839]),
    ([576, 630], [630]),
    ([806, 807, 808, 71, 810, 809, 94, 95], [94, 95]),
    ([0, 869, 870, 871, 872, 873], [869, 870, 871, 872, 873]),
    ([384, 385, 293, 294, 39, 945, 949, 28], [28]),
    ([32, 772], [768, 769, 770, 771, 772, 773, 774, 775]),
    ([768, 769, 770, 771, 773, 774, 775], [768, 769, 770, 771, 772, 773, 774, 775]),
    ([721, 722, 723, 15], [15]),
    ([409, 586, 411, 156], [156]),
    ([201, 445], [445]),
    ([544, 540, 541, 542, 543], [544, 539, 540, 541, 542, 543]),
    ([440, 736], [736, 737, 738, 739]),
    ([577, 84], [84]),
    ([737, 738, 739], [736, 737, 738, 739]),
    ([763, 764], [763, 764, 765, 766]),
    ([272, 273, 274, 270], [270, 271, 272, 273, 274]),
    ([834, 187], [187]),
    ([928, 930, 931, 932, 927], [928, 929, 930, 931, 932, 933, 927]),
    ([580, 68], [68]),
    ([891, 69], [69]),
    ([24, 25, 26, 27], [23, 24, 25, 26, 27]),
    ([104, 105, 106], [104, 105, 106, 103]),
    ([765, 766], [763, 764, 765, 766]),
    ([664, 281, 282], [281, 282]),
    ([290, 291, 292], [289, 290, 291, 292]),
    ([43, 44, 852], [43, 44]),
    ([202, 866], [866]),
    ([418, 890], [889, 890]),
    ([164, 165], [164, 165, 166]),
    ([460, 461], [460, 461, 462]),
    ([265, 266, 267, 268], [265, 266, 267, 268, 269]),
    ([570, 571], [570, 571, 572]),
    ([372, 365], [365]),
    ([601, 853], [853]),
    ([476], [475, 476]),
    ([2], [1, 2]),
    ([34], [32, 34]),
    ([649], [647, 648, 649, 650, 651, 652]),
    ([386], [384, 385, 386]),
    ([410], [409, 410, 411]),
    ([521], [521, 522]),
    ([943], [944, 947, 941, 943]),
    ([579], [578, 579]),
    ([1], [1, 2]),
    ([60], [55, 56, 57, 58, 59, 60]),
    ([572], [570, 571, 572]),
    ([522], [521, 522]),
    ([166], [164, 165, 166]),
    ([271], [270, 271, 272, 273, 274]),
    ([946], [946, 948, 950]),
    ([948], [946, 948, 950]),
    ([844], [844, 845]),
    ([116], [115, 116]),
    ([23], [23, 24, 25, 26, 27]),
    ([539], [544, 539, 540, 541, 542, 543]),
    ([58], [55, 56, 57, 58, 59, 60]),
    ([889], [889, 890]),
    ([741], [741, 742]),
    ([742], [741, 742]),
    ([574], [573, 574, 575]),
    ([462], [460, 461, 462]),
    ([472], [468, 469, 470, 471, 472, 473, 474]),
    ([449], [448, 449, 450]),
    ([578], [578, 579]),
    ([950], [946, 948, 950]),
    ([727], [728, 726, 727]),
    ([115], [115, 116]),
]

# Expand the table into the list-of-dicts shape the visualization cells read.
non_equal = [
    {"LCA clustering": lca_nodes, "GT clustering": gt_nodes}
    for lca_nodes, gt_nodes in _NON_EQUAL_PAIRS
]

In [None]:
# Annotation UUID -> ground-truth label (the "name_viewpoint" field).
uuid_to_gt = {
    record["uuid"]: record["name_viewpoint"]
    for record in annotations["annotations"]
}

# --- Predicted clustering flattened to node -> cluster id (both as int) ---
node_to_pred_cluster = {
    int(member): int(cid)
    for cid, members in clusters.items()
    for member in members
}

# --- Inverse of node2uuid: annotation UUID -> node id ---
uuid_to_node = {ann_uuid: int(node_key) for node_key, ann_uuid in node2uuid.items()}

# --- Group node ids by ground-truth label ---
# Annotations whose UUID has no node id are skipped.
gt_clusters = defaultdict(list)
for ann_uuid, label in uuid_to_gt.items():
    if ann_uuid not in uuid_to_node:
        continue
    gt_clusters[label].append(uuid_to_node[ann_uuid])

# --- node -> GT cluster id, numbering GT clusters in insertion order ---
node_to_gt_cluster = {
    member: gt_index
    for gt_index, members in enumerate(gt_clusters.values())
    for member in members
}

# --- Label vectors restricted to nodes present in both clusterings ---
common_nodes = list(set(node_to_gt_cluster) & set(node_to_pred_cluster))
y_true = [node_to_gt_cluster[member] for member in common_nodes]
y_pred = [node_to_pred_cluster[member] for member in common_nodes]

# --- Clustering accuracy via Hungarian matching ---
def clustering_accuracy(y_true, y_pred):
    """Return clustering accuracy under the best one-to-one label matching.

    Predicted cluster labels are matched to true labels with the Hungarian
    algorithm so that the number of agreeing samples is maximal; the result is
    that number divided by the number of samples.

    Args:
        y_true: 1-D sequence of ground-truth labels, one per sample.
        y_pred: 1-D sequence of predicted labels, aligned with ``y_true``.

    Returns:
        Accuracy as a float in ``[0, 1]``; ``0.0`` when there are no samples.

    Raises:
        ValueError: if the two sequences differ in length.
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_true.size != y_pred.size:
        raise ValueError(
            f"y_true and y_pred must have the same length "
            f"({y_true.size} != {y_pred.size})"
        )
    if y_true.size == 0:
        return 0.0

    # Re-index labels to 0..k-1 so the contingency matrix is only as large as
    # the number of distinct labels, not the largest label value.
    true_labels, true_idx = np.unique(y_true, return_inverse=True)
    pred_labels, pred_idx = np.unique(y_pred, return_inverse=True)

    # w[p, t] = number of samples with predicted label p and true label t.
    w = np.zeros((pred_labels.size, true_labels.size), dtype=np.int64)
    np.add.at(w, (pred_idx, true_idx), 1)

    # linear_sum_assignment minimises cost, so flip counts into costs.
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return float(w[row_ind, col_ind].sum() / y_true.size)

# clustering_accuracy converts its arguments to arrays itself.
accuracy = clustering_accuracy(y_true, y_pred)

def hungarian_remap(y_true, y_pred):
    """Relabel ``y_pred`` so its cluster ids line up with those of ``y_true``.

    A one-to-one assignment between predicted and true labels is found with
    the Hungarian algorithm (maximising the number of agreeing samples) and
    every predicted label is replaced by the label it was assigned to.

    Args:
        y_true: 1-D sequence of non-negative integer ground-truth labels.
        y_pred: 1-D sequence of non-negative integer predicted labels,
            aligned with ``y_true``.

    Returns:
        A list with one remapped label per element of ``y_pred``; an empty
        list when there are no samples.

    Raises:
        ValueError: if the two sequences differ in length.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.size != y_pred.size:
        raise ValueError(
            f"y_true and y_pred must have the same length "
            f"({y_true.size} != {y_pred.size})"
        )
    if y_pred.size == 0:
        return []

    # A square matrix over the full label range guarantees that every
    # predicted label receives a partner, so the lookup below never misses.
    # NOTE: memory grows with the square of the largest label value.
    D = int(max(y_true.max(), y_pred.max())) + 1
    w = np.zeros((D, D), dtype=np.int64)
    # w[p, t] = number of samples with predicted label p and true label t.
    np.add.at(w, (y_pred, y_true), 1)

    # linear_sum_assignment minimises cost, so flip counts into costs.
    row_ind, col_ind = linear_sum_assignment(w.max() - w)

    # Mapping from predicted label -> assigned true label.
    label_map = dict(zip(row_ind, col_ind))
    return [label_map[p] for p in y_pred]

# Align predicted labels with the ground-truth labels first.
y_pred_aligned = hungarian_remap(y_true, y_pred)

# Macro-averaged precision, recall and F1 on the aligned labels.
precision, recall, f1 = (
    scorer(y_true, y_pred_aligned, average='macro')
    for scorer in (precision_score, recall_score, f1_score)
)


In [None]:
# --- Identify mismatched clusters ---
# Each cluster becomes a set of integer node ids so that clusters can be
# compared regardless of member order.
pred_sets = [set(map(int, nodes)) for nodes in clusters.values()]
gt_sets = [set(nodes) for nodes in gt_clusters.values()]
predicted_sets = pred_sets  # second name for the same list, kept for later cells

# Hash the predicted clusters once so that each GT cluster is checked with a
# single lookup instead of a scan over every predicted cluster.
_pred_lookup = {frozenset(members) for members in pred_sets}

# GT clusters that have no exactly equal predicted cluster.
mismatched_clusters = [
    gt_set for gt_set in gt_sets if frozenset(gt_set) not in _pred_lookup
]

# Count how many GT clusters exactly match a predicted cluster
correct_clusters = len(gt_sets) - len(mismatched_clusters)
total_predicted_clusters = len(predicted_sets)

frac_correct = correct_clusters / len(gt_sets) if gt_sets else 0.0

# --- Output results ---
print("📊 Clustering Evaluation Summary")
print(f"Accuracy         : {accuracy:.4f}")
print(f"Precision (macro): {precision:.4f}")
print(f"Recall (macro)   : {recall:.4f}")
print(f"F1 Score (macro) : {f1:.4f}")
print(f"Correct Clusters : {correct_clusters} / {len(gt_sets)}")
print(f"Number of LCA Clusters : {total_predicted_clusters}")
print(f"Fraction of correct clusters : {frac_correct:.4f}")



In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import os

# Directory that annotation file names are joined onto; adjust per machine.
image_dir = "/ekaterina/work/data/zebra/images"

# Node id (int) -> annotation UUID; the JSON keys are converted to int.
node_to_uuid = {
    int(node_key): ann_uuid for node_key, ann_uuid in node2uuid.items()
}

# Annotation UUID -> UUID of the image the annotation belongs to.
uuid_to_imageuuid = {
    record["uuid"]: record["image_uuid"] for record in annotations["annotations"]
}

# Image UUID -> image file name.
imageuuid_to_filename = {
    record["uuid"]: record["file_name"] for record in annotations["images"]
}

# Map node_id → image_path
def get_image_path_from_node(node_id):
    """Resolve a node id to the path of its image file.

    Follows node id -> annotation UUID -> image UUID -> file name and joins
    the file name onto ``image_dir``.  Returns ``None`` as soon as any step
    of the chain yields nothing.
    """
    ann_uuid = node_to_uuid.get(node_id)
    image_uuid = uuid_to_imageuuid.get(ann_uuid) if ann_uuid else None
    file_name = imageuuid_to_filename.get(image_uuid) if image_uuid else None
    return os.path.join(image_dir, file_name) if file_name else None

# Show images for a cluster
def show_cluster_images(cluster_nodes, title, max_images=10):
    """Plot the images of a cluster side by side in a single row.

    Args:
        cluster_nodes: iterable of node ids belonging to the cluster.
        title: figure title.
        max_images: upper bound on the number of images shown; members beyond
            it are not drawn.

    Nodes whose image path cannot be resolved, or whose file does not exist,
    are drawn as a "Missing" placeholder.  Nothing is plotted when there is
    no node to show.
    """
    shown = list(cluster_nodes)[:max_images]
    if not shown:
        # plt.subplots rejects a zero-column grid.
        return

    # squeeze=False always yields a 2-D axes array, so a one-column figure
    # needs no special casing.
    fig, axes = plt.subplots(1, len(shown), figsize=(15, 3), squeeze=False)
    fig.suptitle(title, fontsize=16)
    for ax, node_id in zip(axes[0], shown):
        img_path = get_image_path_from_node(node_id)
        if img_path and os.path.exists(img_path):
            # Close the file handle once the pixels were handed to matplotlib.
            with Image.open(img_path) as img:
                ax.imshow(img)
            ax.set_title(str(node_id), fontsize=8)
        else:
            ax.text(0.5, 0.5, 'Missing', horizontalalignment='center', verticalalignment='center')
        ax.axis('off')
    plt.tight_layout()
    plt.show()



In [None]:

import matplotlib.pyplot as plt
from PIL import Image
import os

# --- Lookup tables used by the crop/visualization helpers ---
# Node id (int) -> annotation UUID; the JSON keys are converted to int.
node_to_uuid = {
    int(node_key): ann_uuid for node_key, ann_uuid in node2uuid.items()
}
# Annotation UUID -> UUID of the image the annotation belongs to.
uuid_to_imageuuid = {
    record["uuid"]: record["image_uuid"] for record in annotations["annotations"]
}
# Annotation UUID -> bounding box stored with the annotation.
uuid_to_bbox = {
    record["uuid"]: record["bbox"] for record in annotations["annotations"]
}
# Image UUID -> image file name.
imageuuid_to_filename = {
    record["uuid"]: record["file_name"] for record in annotations["images"]
}

# Get image crop from node
def get_bbox_crop(node_id):
    """Return the bounding-box crop of the annotation behind ``node_id``.

    Resolves node id -> annotation UUID -> image file and bounding box, opens
    the image under ``image_dir`` and crops it to the box.

    Returns:
        The cropped ``PIL.Image.Image``, or ``None`` when the node, its image
        file or its bounding box is unknown, the file does not exist, or the
        image cannot be read/cropped.
    """
    uuid = node_to_uuid.get(node_id)
    if not uuid:
        return None
    image_uuid = uuid_to_imageuuid.get(uuid)
    bbox = uuid_to_bbox.get(uuid)
    file_name = imageuuid_to_filename.get(image_uuid)
    if not file_name or not bbox:
        return None
    img_path = os.path.join(image_dir, file_name)
    if not os.path.exists(img_path):
        return None
    try:
        # bbox is unpacked as (x, y, width, height); presumably in pixels.
        x, y, w, h = bbox
        # The context manager releases the file handle; load() makes sure the
        # crop owns its pixel data before the source image is closed.
        with Image.open(img_path) as img:
            crop = img.crop((x, y, x + w, y + h))
            crop.load()
            return crop
    except Exception:
        # Best effort: an unreadable image or malformed bbox is reported as
        # missing, while KeyboardInterrupt/SystemExit still propagate.
        return None

# Visualize one mismatch pair
def show_non_equal_pair(pred_nodes, gt_nodes, index, max_images=10):
    """Plot one mismatching pair: LCA cluster on top, GT cluster below.

    Args:
        pred_nodes: node ids of the LCA (predicted) cluster.
        gt_nodes: node ids of the ground-truth cluster.
        index: number shown in the figure title.
        max_images: number of columns; only the first ``max_images`` nodes of
            each cluster are drawn.

    Unused cells are left blank; nodes without a crop show "Missing".
    """
    # squeeze=False keeps axes 2-D so axes[row, col] also works for a single
    # column.
    fig, axes = plt.subplots(2, max_images, figsize=(15, 5), squeeze=False)
    fig.suptitle(f"❌ Mismatch {index}: LCA Clustering (top) vs GT (bottom)", fontsize=14)

    for row, node_list in enumerate([pred_nodes, gt_nodes]):
        for i in range(max_images):
            ax = axes[row, i]
            if i < len(node_list):
                crop = get_bbox_crop(node_list[i])
                if crop is not None:
                    ax.imshow(crop)
                else:
                    ax.text(0.5, 0.5, "Missing", ha='center', va='center')
            ax.axis('off')

    plt.tight_layout()
    plt.show()

# --- Plot the first three entries of non_equal (kept small for speed) ---
for idx, mismatch in enumerate(non_equal[:3]):
    show_non_equal_pair(
        mismatch["LCA clustering"],
        mismatch["GT clustering"],
        index=idx,
    )


In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import os
import math

# image_dir must already be defined before the helpers below are used.

# --- Lookup tables used by the crop/visualization helpers ---
# Node id (int) -> annotation UUID; the JSON keys are converted to int.
node_to_uuid = {
    int(node_key): ann_uuid for node_key, ann_uuid in node2uuid.items()
}
# Annotation UUID -> UUID of the image the annotation belongs to.
uuid_to_imageuuid = {
    record["uuid"]: record["image_uuid"] for record in annotations["annotations"]
}
# Annotation UUID -> bounding box stored with the annotation.
uuid_to_bbox = {
    record["uuid"]: record["bbox"] for record in annotations["annotations"]
}
# Image UUID -> image file name.
imageuuid_to_filename = {
    record["uuid"]: record["file_name"] for record in annotations["images"]
}

# --- Function to get cropped image from node ID ---
def get_bbox_crop(node_id):
    """Return the bounding-box crop of the annotation behind ``node_id``.

    Resolves node id -> annotation UUID -> image file and bounding box, opens
    the image under ``image_dir`` and crops it to the box.

    Returns:
        The cropped ``PIL.Image.Image``, or ``None`` when the node, its image
        file or its bounding box is unknown, the file does not exist, or the
        image cannot be read/cropped.
    """
    uuid = node_to_uuid.get(node_id)
    if not uuid:
        return None
    image_uuid = uuid_to_imageuuid.get(uuid)
    bbox = uuid_to_bbox.get(uuid)
    file_name = imageuuid_to_filename.get(image_uuid)
    if not file_name or not bbox:
        return None
    img_path = os.path.join(image_dir, file_name)
    if not os.path.exists(img_path):
        return None
    try:
        # bbox is unpacked as (x, y, width, height); presumably in pixels.
        x, y, w, h = bbox
        # The context manager releases the file handle; load() makes sure the
        # crop owns its pixel data before the source image is closed.
        with Image.open(img_path) as img:
            crop = img.crop((x, y, x + w, y + h))
            crop.load()
            return crop
    except Exception:
        # Best effort: an unreadable image or malformed bbox is reported as
        # missing, while KeyboardInterrupt/SystemExit still propagate.
        return None

# --- Visualization helper ---
def show_images_grid(title, node_list, images_per_row=3, max_images=9):
    """Plot the bounding-box crops of ``node_list`` in a grid.

    Args:
        title: figure title.
        node_list: node ids to draw.
        images_per_row: number of grid columns.
        max_images: only the first ``max_images`` nodes are drawn.

    Each drawn cell is titled with its node id; nodes without a crop show
    "Missing".  Nothing is plotted when there is no node to show.
    """
    node_list = list(node_list)[:max_images]
    total = len(node_list)
    if total == 0:
        # A grid with zero rows cannot be created by plt.subplots.
        return

    rows = math.ceil(total / images_per_row)
    # squeeze=False always returns a 2-D axes array, whatever the grid shape.
    fig, axes = plt.subplots(
        rows,
        images_per_row,
        figsize=(images_per_row * 4, rows * 4),
        squeeze=False,
    )
    fig.suptitle(title, fontsize=14)

    # Walk the grid row by row; cells beyond the last node stay blank.
    for i, ax in enumerate(axes.ravel()):
        if i < total:
            crop = get_bbox_crop(node_list[i])
            if crop is not None:
                ax.imshow(crop)
            else:
                ax.text(0.5, 0.5, 'Missing', ha='center', va='center')
            ax.set_title(str(node_list[i]), fontsize=8)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# --- Display the six mismatches non_equal[9:15] as image grids ---
# NOTE: the printed number counts from 0 within this slice; it is not the
# position of the entry in non_equal.
for idx, mismatch in enumerate(non_equal[9:15]):
    print(f"\n❌ Mismatch {idx}:")
    for grid_title, key in (
        ("🔷 LCA Predicted Cluster", "LCA clustering"),
        ("🟢 Ground-Truth Cluster", "GT clustering"),
    ):
        show_images_grid(grid_title, mismatch[key], images_per_row=3)