In [None]:
# Expose only GPU 1 to this kernel; this must run before torch is imported (next cell).
%env CUDA_VISIBLE_DEVICES=1
# Automatically re-import edited project modules before executing cells.
%load_ext autoreload
%autoreload 2
import sys
# Make the parent directory importable: project packages such as `utils` and
# `datasets` are imported from there in later cells.
sys.path.append("..")

In [None]:
import torch

# Cached distance tensors for each backbone, produced by the evaluation runs.
distances_dino = torch.load("../analysis_data/distances_dino.pt")
distances_dino_s = torch.load("../analysis_data/distances_dino_s.pt")
distances_diff = torch.load("../analysis_data/distances_diff.pt")
distances_dino_diff = torch.load("../analysis_data/distances_dino_diff.pt")

# CLIP distances are currently disabled:
# distances_clip = torch.load("../analysis_data/distances_clip.pt")

In [None]:
import matplotlib.pyplot as plt

# Distance histograms per backbone. Range/resolution are named constants so the
# builtin `max` is no longer shadowed for the rest of the notebook.
HIST_MIN, HIST_MAX = 0, 1000
N_BINS = 100
# "img 0.1" threshold: 0.1 x 768 (presumably a 768-pixel image side — TODO confirm).
x_thresh = 0.1 * 768

print(distances_dino_diff.max(), distances_dino_diff.min())

# Backbones to compare; uncomment entries to add them to the plot.
series = [
    ("ADD-XL", distances_diff),
    ("DINOv2 B/14", distances_dino),
    # ("DINOv2 B/14 and ADD-XL", distances_dino_diff),
    # ("DINOv2 B/14 336x336", distances_dino_s),
    # ("CLIP L/14", distances_clip),  # requires loading distances_clip first
]

# Plot against bin centres so the x axis is in distance units, not bin indices.
bin_width = (HIST_MAX - HIST_MIN) / N_BINS
bin_centers = HIST_MIN + bin_width * (torch.arange(N_BINS) + 0.5)

fig, ax = plt.subplots(figsize=(10, 5))
for label, dists in series:
    dists = dists.float()
    hist = torch.histc(dists, bins=N_BINS, min=HIST_MIN, max=HIST_MAX)
    ax.plot(bin_centers, hist, label=label)
    # Computed directly rather than from the histogram: summing whole bins rounds
    # the threshold down to a bin edge (76.8 -> 70), and histc ignores values
    # outside [HIST_MIN, HIST_MAX], which would drop large distances from the total.
    frac_below = (dists <= x_thresh).float().mean().item()
    print(f"{label}: {frac_below:.4f} of distances <= {x_thresh:.1f}")

ax.axvline(x=x_thresh, color='r', label='img 0.1 threshold')
ax.set(xlabel="distance", ylabel="count", title="Distance distribution per backbone")
ax.legend()
plt.show()

In [None]:
# Cached similarity-map histograms, one per backbone.
similarity_map_hists_diff = torch.load("../analysis_data/histograms_diff.pt")
similarity_map_hists_dino = torch.load("../analysis_data/histograms_dino.pt")
similarity_map_hists_dino_diff = torch.load("../analysis_data/histograms_dino_diff.pt")
similarity_map_hists_dino_s = torch.load("../analysis_data/histograms_dino_s.pt")

curves = (
    (similarity_map_hists_diff, "ADD-XL"),
    (similarity_map_hists_dino, "DINOv2 B/14"),
    (similarity_map_hists_dino_diff, "DINOv2 B/14 and ADD-XL"),
    (similarity_map_hists_dino_s, "DINOv2 B/14 336x336"),
)
fig, ax = plt.subplots(figsize=(10, 5))
for curve, curve_label in curves:
    ax.plot(curve, label=curve_label)
ax.legend()
plt.show()

In [None]:
# Histograms come from images that all have the same size; the predictions do not.

use_aligned_dino = False
dino_suffix = "_s" if use_aligned_dino else ""

# Each file holds (predictions, correctness flags, distances), one entry per sample.
# The distances are bound to `pred_distances_*` so they do not overwrite the
# `distances_*` tensors that the histogram cell above plots.
predictions_diff, correct_diff, pred_distances_diff = torch.load("../analysis_data/predictions_diff.pt")
predictions_dino, correct_dino, pred_distances_dino = torch.load(f"../analysis_data/predictions_dino{dino_suffix}.pt")
predictions_dino_diff, correct_dino_diff, pred_distances_dino_diff = torch.load(
    f"../analysis_data/predictions_dino_diff{dino_suffix}.pt"
)

print(len(predictions_diff), len(correct_diff), len(pred_distances_diff))
print(predictions_diff[0].shape, correct_diff[0].shape, pred_distances_diff[0].shape)

# The comparison below pairs diff and dino outcomes index by index, so the two
# result sets must line up sample-for-sample and point-for-point.
assert len(correct_diff) == len(correct_dino), "diff and dino cover different numbers of samples"

# Confusion counts, in order: (diff correct, dino wrong), (diff wrong, dino correct),
# (both correct), (both wrong).
OUTCOME_INDEX = {(True, False): 0, (False, True): 1, (True, True): 2, (False, False): 3}
confusion = torch.zeros(4)
for sample_diff, sample_dino in zip(correct_diff, correct_dino):
    assert len(sample_diff) == len(sample_dino), "diff and dino predict different numbers of points"
    for ok_diff, ok_dino in zip(sample_diff, sample_dino):
        confusion[OUTCOME_INDEX[(bool(ok_diff), bool(ok_dino))]] += 1

print(confusion)
print(confusion / confusion.sum())

In [None]:
# Where do the four diff/dino outcomes end up under DINO+Diff?
# The flow from each outcome to DINO+Diff correct/wrong is drawn as a Sankey diagram.
import pandas as pd
from pysankey import sankey

leftLabels = ['diff correct, dino wrong', 'diff wrong, dino correct', 'both correct', 'both wrong']
rightLabels = ['correct', 'wrong']
# (diff correct?, dino correct?) -> row index into leftLabels
outcome_index = {(True, False): 0, (False, True): 1, (True, True): 2, (False, False): 3}

# Rows: initial diff/dino outcome; columns: DINO+Diff [correct, wrong].
confusion_switch = torch.zeros((4, 2))
flow_records = []  # one (from, to) label pair per keypoint, for the Sankey diagram

for i in range(len(predictions_diff)):
    for j in range(len(predictions_diff[i])):
        initial_state = outcome_index[(bool(correct_diff[i][j]), bool(correct_dino[i][j]))]
        new_state = 0 if correct_dino_diff[i][j] else 1
        confusion_switch[initial_state, new_state] += 1
        flow_records.append((leftLabels[initial_state], rightLabels[new_state]))

# Row-normalize; a row without any keypoints yields NaN.
proportions = confusion_switch / confusion_switch.sum(dim=1, keepdim=True)

print("DINO+Diff flow:")
for row, left in enumerate(leftLabels):
    for col, right in enumerate(rightLabels):
        print(f"From {left} to {right}: ", proportions[row, col].item())

# Built once from the record list; no global pandas display options are modified.
flow_df = pd.DataFrame(flow_records, columns=['From', 'To'])

ax = sankey(
    left=flow_df['From'],
    right=flow_df['To'],
    aspect=20,
    leftLabels=leftLabels,
    rightLabels=rightLabels,
    fontsize=12,
)
plt.show()

In [None]:
from utils.correspondence import flip_points
from utils.dataset import load_dataset, read_dataset_config
from utils.visualization import display_image_pair

# Load the SPair-71k evaluation set using its entry in the shared dataset config.
dataset_name = "SPair-71k"
dataset_config = read_dataset_config('../dataset_config.yaml')
config = dataset_config[dataset_name]
dataset = load_dataset(dataset_name, config)

In [None]:
# Per-sample fraction of keypoints in each diff/dino outcome bucket.
n_samples = len(predictions_diff)
n_both_correct = torch.zeros(n_samples)
n_both_incorrect = torch.zeros(n_samples)
n_only_diff_correct = torch.zeros(n_samples)
n_only_dino_correct = torch.zeros(n_samples)
for i in range(n_samples):
    n_pts = len(predictions_diff[i])
    if n_pts == 0:
        continue  # leave fractions at 0 instead of producing 0/0 = NaN
    for j in range(n_pts):
        ok_diff, ok_dino = bool(correct_diff[i][j]), bool(correct_dino[i][j])
        if ok_diff and ok_dino:
            n_both_correct[i] += 1
        elif not ok_diff and not ok_dino:
            n_both_incorrect[i] += 1
        elif ok_diff:
            n_only_diff_correct[i] += 1
        else:
            n_only_dino_correct[i] += 1

    # Divide by number of points
    n_both_correct[i] /= n_pts
    n_both_incorrect[i] /= n_pts
    n_only_diff_correct[i] /= n_pts
    n_only_dino_correct[i] /= n_pts

# Display the k samples with the highest fraction of points in each bucket.
k = 2

arg_both_correct = n_both_correct.argsort(descending=True)
arg_both_incorrect = n_both_incorrect.argsort(descending=True)
arg_only_diff_correct = n_only_diff_correct.argsort(descending=True)
arg_only_dino_correct = n_only_dino_correct.argsort(descending=True)

labels = ["Ground Truth", "Diff", "DINO", "Diff+DINO"]


def _show_top_examples(order, title, k):
    """Print `title` and display the first k samples of `order` with GT plus all three models' predictions."""
    for rank in range(min(k, len(order))):
        idx = int(order[rank])
        # NOTE(review): this overwrites "target_points" on the dict returned by
        # dataset[idx]; harmless only if the dataset builds a fresh dict per access — confirm.
        sample = dataset[idx]
        print(title)
        sample["target_points"] = [sample["target_points"]] + [
            flip_points(predictions_diff[idx]),
            flip_points(predictions_dino[idx]),
            flip_points(predictions_dino_diff[idx]),
        ]
        display_image_pair(sample, show_bbox=True, labels=labels)


_show_top_examples(arg_both_correct, "Both correct", k)
_show_top_examples(arg_both_incorrect, "Both incorrect", k)
_show_top_examples(arg_only_diff_correct, "Only diff correct", k)
_show_top_examples(arg_only_dino_correct, "Only dino correct", k)

In [None]:
# Observation: most failure cases are due to occlusions, a lack of spatial coherence (when objects face different directions) and a lack of accuracy.

In [None]:
# Distances between the Diff and DINO predictions of the same keypoint.
dist_dino_diff = torch.cat([
    torch.linalg.norm(pred_diff.float() - pred_dino.float(), dim=1)
    for pred_diff, pred_dino in zip(predictions_diff, predictions_dino)
])
print(dist_dino_diff.mean(), dist_dino_diff.std())

# Split those same distances by outcome; each norm is computed only once above.
group_colors = {
    "both correct": 'r',
    "both incorrect": 'g',
    "only diff correct": 'b',
    "only dino correct": 'y',
}
group_dists = {group: [] for group in group_colors}
offset = 0  # index of sample i's first keypoint inside dist_dino_diff
for i in range(len(predictions_diff)):
    n_pts = len(predictions_diff[i])
    for j in range(n_pts):
        ok_diff, ok_dino = bool(correct_diff[i][j]), bool(correct_dino[i][j])
        if ok_diff and ok_dino:
            group = "both correct"
        elif not ok_diff and not ok_dino:
            group = "both incorrect"
        elif ok_diff:
            group = "only diff correct"
        else:
            group = "only dino correct"
        group_dists[group].append(dist_dino_diff[offset + j])
    offset += n_pts

# Mean/std per outcome; empty groups are reported instead of yielding NaN.
group_means = {}
for group, dists in group_dists.items():
    if not dists:
        print(f"{group}: no keypoints")
        continue
    dists = torch.stack(dists)
    group_means[group] = dists.mean().item()
    print(f"{group}: mean={group_means[group]:.3f}, std={dists.std().item():.3f}")

fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(dist_dino_diff.numpy(), bins=100)
# One vertical line per outcome at its mean distance.
for group, mean in group_means.items():
    ax.axvline(x=mean, color=group_colors[group], label=f"{group} ({mean})")
ax.set(xlabel="distance between Diff and DINO predictions", ylabel="number of keypoints")
ax.legend()
plt.show()

In [None]:
# Display the k samples whose Diff and DINO predictions disagree most on average.
k = 5

# Per-sample mean distance. Bound to its own name so the per-keypoint
# `dist_dino_diff` tensor from the previous cell is not overwritten.
mean_dist_per_sample = torch.stack([
    torch.linalg.norm(pred_diff.float() - pred_dino.float(), dim=1).mean()
    for pred_diff, pred_dino in zip(predictions_diff, predictions_dino)
])
arg_dist_dino_diff = mean_dist_per_sample.argsort(descending=True)

for rank in range(min(k, len(arg_dist_dino_diff))):
    idx = int(arg_dist_dino_diff[rank])
    # NOTE(review): relies on `dataset` still being SPair-71k; a later cell rebinds it.
    sample = dataset[idx]
    print(f"Highest distance ({mean_dist_per_sample[idx].item():.1f})")
    sample["target_points"] = [
        flip_points(predictions_diff[idx]),
        flip_points(predictions_dino[idx]),
        flip_points(predictions_dino_diff[idx]),
    ]
    display_image_pair(sample, show_bbox=True, labels=["Diff", "DINO", "Diff+DINO"])

In [None]:
# Observation: all of these failure cases are due to a lack of spatial coherence (when objects face different directions) in both models.

In [None]:
# Mean/std of the per-point distance from the DINO+Diff prediction to each
# single-backbone prediction: (dino+diff, dino) and (dino+diff, diff).
dist_dino_diff_dino = []
dist_dino_diff_diff = []

for sample_idx, diff_points in enumerate(predictions_diff):
    fused_points = predictions_dino_diff[sample_idx]
    dino_points = predictions_dino[sample_idx]
    for point_idx in range(len(diff_points)):
        fused = fused_points[point_idx].float()
        dist_dino_diff_dino.append(torch.linalg.norm(fused - dino_points[point_idx].float()))
        dist_dino_diff_diff.append(torch.linalg.norm(fused - diff_points[point_idx].float()))

for dists in (dist_dino_diff_dino, dist_dino_diff_diff):
    dists_t = torch.tensor(dists)
    print(dists_t.mean(), dists_t.std())

In [None]:
# Assess individual performance on: object class, view direction, truncation and occlusion.
# Loads the raw SPair-71k pair annotations, which carry those per-pair attributes.
import os
import json

# Dataset root; override with the SPAIR_DIR environment variable.
dataset_directory = os.environ.get("SPAIR_DIR", "/export/group/datasets/SPair-71k")
split = "test"

images_dir = os.path.join(dataset_directory, 'JPEGImages')
annotations_dir = os.path.join(dataset_directory, 'PairAnnotation', split)
# NOTE(review): os.listdir order is filesystem-dependent, yet later cells index
# annotations[i] alongside predictions_*[i]. This order must match the one the
# predictions were produced in — confirm against utils.dataset.load_dataset
# (sorting here is only correct if that loader sorts as well).
annotations_files = [f for f in os.listdir(annotations_dir) if f.endswith('.json')]

annotations = []
for annotation_file in annotations_files:
    with open(os.path.join(annotations_dir, annotation_file), 'r') as file:
        annotations.append(json.load(file))

if not annotations:
    raise FileNotFoundError(f"No .json pair annotations found in {annotations_dir}")

# Inspect one example annotation (the last one loaded).
annotation = annotations[-1]
print(annotation.keys())
print(annotation['src_pose'], annotation['trg_pose'], annotation['category'], annotation['truncation'], annotation['occlusion'], annotation['viewpoint_variation'], annotation['scale_variation'])

In [None]:
# Distinct source-image poses present in the loaded pair annotations.
poses = {ann['src_pose'] for ann in annotations}
print(poses)

In [None]:
# Categorize the correct points of diff and dino by: category, viewpoint_variation,
# scale_variation, truncation, occlusion and pose.
# Each histogram is normalized by the model's total number of correct points. It
# therefore shows how a model's correct predictions are distributed over the
# buckets, not the per-bucket accuracy.
# NOTE(review): annotations[i] is assumed to describe the same pair as
# correct_*[i]; this holds only if annotation files were loaded in prediction order.

category_map = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'train', 'tvmonitor']
category_map = {category: i for i, category in enumerate(category_map)}

# Bar labels per attribute; each list's length fixes that histogram's bucket count.
attribute_buckets = {
    'category': list(category_map.keys()),
    'viewpoint_variation': ['easy', 'medium', 'hard'],
    'scale_variation': ['easy', 'medium', 'hard'],
    'truncation': ['none', 'source', 'target', 'both'],
    'occlusion': ['none', 'source', 'target', 'both'],
    'pose': ['same', 'different', 'opposite'],
}
opposite_pose_pairs = {frozenset(('Left', 'Right')), frozenset(('Frontal', 'Rear'))}


def _pose_bucket(src_pose, trg_pose):
    """Map a (source, target) pose pair to 0=same, 1=different, 2=opposite; None if either is 'Unspecified'."""
    # Check 'Unspecified' first: previously two unspecified poses were counted as "same".
    if 'Unspecified' in (src_pose, trg_pose):
        return None
    if src_pose == trg_pose:
        return 0
    if frozenset((src_pose, trg_pose)) in opposite_pose_pairs:
        return 2
    return 1


def _correct_distribution(correct):
    """Histogram one model's correct points over every attribute in attribute_buckets, each normalized to sum to 1."""
    counts = {key: torch.zeros(len(names)) for key, names in attribute_buckets.items()}
    for i in range(len(correct)):
        # All attributes are per pair, so every correct point of pair i adds to the same buckets.
        n_correct = sum(1 for ok in correct[i] if ok)
        if n_correct == 0:
            continue
        ann = annotations[i]
        counts['category'][category_map[ann['category']]] += n_correct
        for key in ('viewpoint_variation', 'scale_variation', 'truncation', 'occlusion'):
            counts[key][ann[key]] += n_correct
        pose = _pose_bucket(ann['src_pose'], ann['trg_pose'])
        if pose is not None:
            counts['pose'][pose] += n_correct
    return {key: hist / hist.sum() for key, hist in counts.items()}


correct_diff_dist = _correct_distribution(correct_diff)
correct_dino_dist = _correct_distribution(correct_dino)

# Plot one figure per attribute, with DINO overlaid semi-transparently on ADD-XL.
for key, bar_labels in attribute_buckets.items():
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(bar_labels, correct_diff_dist[key], label="ADD-XL")
    ax.bar(bar_labels, correct_dino_dist[key], label="DINO", alpha=0.5)
    if key == 'category':
        ax.tick_params(axis='x', labelrotation=45)
    ax.set(title=key.replace('_', ' ').capitalize(), ylabel="share of correct points")
    ax.legend()
    plt.show()

In [None]:
# Aggregate correct part and subpart predictions, then visualize the results.

from datasets.correspondence import S2K

num_samples = 1000

# PASCAL-Part keypoint pairs, restricted via the "only_non_unique" option.
# NOTE: this rebinds `dataset` (previously SPair-71k); cells above that index
# `dataset` must reload SPair before being re-run.
dataset = S2K({
    "path": "/export/group/datasets/PASCAL-Part",
    "only_non_unique": True,
})

# Seeded shuffle, then keep the first num_samples pairs. The seed presumably must
# match the run that produced the cached s2k predictions — TODO confirm.
torch.manual_seed(42)
dataset.data = [dataset.data[idx] for idx in torch.randperm(len(dataset))][:num_samples]

In [None]:
# Materialize the sampled S2K pairs once; later cells index s2k_data[i].
s2k_data = []
for sample_idx in range(num_samples):
    s2k_data.append(dataset[sample_idx])

In [None]:
# Cached S2K predictions. Each file holds a 3-tuple; only the first element (the
# predicted points) is used below, and the other two are discarded into `_`.
(predictions_s2k_diff, _, _) = torch.load(
    "../analysis_data/predictions_s2k_diff.pt"
)
(predictions_s2k_dino_s, _, _) = torch.load(
    "../analysis_data/predictions_s2k_dino_s.pt"
)
(predictions_s2k_dino_diff_s, _, _) = torch.load(
    "../analysis_data/predictions_s2k_dino_diff_s.pt"
)

In [None]:
# For each prediction, check whether it lands in a target part of the same class as
# the ground-truth source part. "subclass" is the full part name and "class" is its
# last whitespace-separated word.


def _part_hits(predictions):
    """Count predictions that land in a same-class / same-subclass target part.

    predictions[i][j] is the (y, x) point predicted for source part j of s2k_data[i].
    Returns (class_hits, subclass_hits, n_points) for one model.
    """
    class_hits, subclass_hits, n_pts = 0, 0, 0
    for i, sample_preds in enumerate(predictions):
        sample = s2k_data[i]
        source_parts = sample["source_annotation"]["parts"]
        target_parts = sample["target_annotation"]["parts"]
        for j, (y, x) in enumerate(sample_preds):
            gt_subclass = source_parts[j]["class"]
            gt_class = gt_subclass.split()[-1]
            # Classes of every target part whose mask contains the predicted point.
            hit_classes = [p["class"] for p in target_parts if p["mask"][y, x] == 1]
            subclass_hits += int(gt_subclass in hit_classes)
            class_hits += int(any(c.split()[-1] == gt_class for c in hit_classes))
            n_pts += 1
    return class_hits, subclass_hits, n_pts


same_class_diff, same_subclass_diff, n_points_diff = _part_hits(predictions_s2k_diff)
same_class_dino_s, same_subclass_dino_s, n_points_dino_s = _part_hits(predictions_s2k_dino_s)
same_class_dino_diff_s, same_subclass_dino_diff_s, n_points_dino_diff_s = _part_hits(predictions_s2k_dino_diff_s)
# Kept for downstream code that normalizes by the Diff point count.
n_points = n_points_diff

# Each model is normalized by its own point count (previously all used Diff's count).
print("Diff (correct class, correct subclass)", same_class_diff / max(n_points_diff, 1), same_subclass_diff / max(n_points_diff, 1))
print("DINO (correct class, correct subclass)", same_class_dino_s / max(n_points_dino_s, 1), same_subclass_dino_s / max(n_points_dino_s, 1))
print("Diff+DINO (correct class, correct subclass)", same_class_dino_diff_s / max(n_points_dino_diff_s, 1), same_subclass_dino_diff_s / max(n_points_dino_diff_s, 1))

In [None]:
# Class-wise accuracy on the part and subpart level. Each keypoint is attributed to
# both its source and its target category. Per-category accuracy = hits / keypoints
# attributed to that category, so the ratio is unaffected when source == target.

category_map = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'train', 'tvmonitor']
category_map = {category: i for i, category in enumerate(category_map)}


def _classwise_part_accuracy(predictions):
    """Return per-category (class accuracy, subclass accuracy) tensors for one model's S2K predictions."""
    n_categories = len(category_map)
    class_hits = torch.zeros(n_categories)
    subclass_hits = torch.zeros(n_categories)
    totals = torch.zeros(n_categories)
    for i, sample_preds in enumerate(predictions):
        sample = s2k_data[i]
        pair_categories = (category_map[sample["source_category"]], category_map[sample["target_category"]])
        source_parts = sample["source_annotation"]["parts"]
        target_parts = sample["target_annotation"]["parts"]
        for j, (y, x) in enumerate(sample_preds):
            gt_subclass = source_parts[j]["class"]
            gt_class = gt_subclass.split()[-1]
            hit_classes = [p["class"] for p in target_parts if p["mask"][y, x] == 1]
            subclass_hit = int(gt_subclass in hit_classes)
            class_hit = int(any(c.split()[-1] == gt_class for c in hit_classes))
            for cat in pair_categories:
                totals[cat] += 1
                class_hits[cat] += class_hit
                subclass_hits[cat] += subclass_hit
    # Categories with no keypoints report 0 instead of NaN.
    totals = totals.clamp_min(1)
    return class_hits / totals, subclass_hits / totals


class_correct_diff, subclass_correct_diff = _classwise_part_accuracy(predictions_s2k_diff)
class_correct_dino_s, subclass_correct_dino_s = _classwise_part_accuracy(predictions_s2k_dino_s)
class_correct_dino_diff_s, subclass_correct_dino_diff_s = _classwise_part_accuracy(predictions_s2k_dino_diff_s)

# Grouped bars (one group per category) so the three models do not occlude each other.
category_names = list(category_map.keys())
positions = list(range(len(category_names)))
bar_width = 0.27
model_names = ("ADD-XL", "DINO", "DINO+ADD-XL")
panels = (
    ("Class", (class_correct_diff, class_correct_dino_s, class_correct_dino_diff_s)),
    ("Part", (subclass_correct_diff, subclass_correct_dino_s, subclass_correct_dino_diff_s)),
)
for title, accuracies in panels:
    fig, ax = plt.subplots(figsize=(10, 5))
    for model_idx, (model_name, acc) in enumerate(zip(model_names, accuracies)):
        shifted = [p + (model_idx - 1) * bar_width for p in positions]
        ax.bar(shifted, acc, width=bar_width, label=model_name)
    ax.set_xticks(positions)
    ax.set_xticklabels(category_names, rotation=45)
    ax.set(title=title, ylabel="accuracy")
    ax.legend()
    plt.show()