In [None]:
# Expose only CUDA device 1 to this kernel. The variable only takes effect if it
# is set before CUDA is initialised, which is why it is the first statement here.
%env CUDA_VISIBLE_DEVICES=1
# Automatically re-import edited project modules before each cell executes.
%load_ext autoreload
%autoreload 2
import sys
# Make the parent directory importable -- presumably the repository root that
# holds the `datasets` package imported in the next cell (confirm).
sys.path.append("..")

In [None]:
from datasets.correspondence import S2K

# S2K is configured from a plain dict; only the "path" entry, pointing at the
# PASCAL-Part directory, is supplied here.
s2k_config = {"path": "/export/group/datasets/PASCAL-Part"}
dataset = S2K(s2k_config)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def generate_part_class_to_color_map(annotations):
    """Assign a distinct semi-transparent RGBA colour to every part class.

    Args:
        annotations: Iterable of part annotations. Each must be a mapping with a
            ``'class'`` key; no other key is read here.

    Returns:
        dict mapping each unique part class to an ``(r, g, b, 0.5)`` tuple,
        sampled at evenly spaced positions of Matplotlib's ``'hsv'`` colormap.
    """
    # Sort so the class -> colour assignment is reproducible across kernel
    # restarts; iterating a raw set of strings depends on per-process hash
    # randomisation.
    unique_part_classes = sorted(set(part['class'] for part in annotations))
    # 'hsv' is cyclic: its first and last entries are both red. Sampling one
    # extra (never used) colour keeps the first and last classes distinguishable.
    # `plt.get_cmap(name, lut)` replaces `plt.cm.get_cmap`, which was deprecated
    # in Matplotlib 3.7 and removed in 3.9.
    cmap = plt.get_cmap('hsv', len(unique_part_classes) + 1)
    return {
        part_class: cmap(i)[:3] + (0.5,)  # RGB from the colormap, fixed 50% alpha
        for i, part_class in enumerate(unique_part_classes)
    }

def plot_image_and_masks(image, annotations, uparts, part_class_to_color):
    """Draw ``image`` on the current axes and overlay the selected part masks.

    The axes title is ``"<object class>: <part class>, <part class>, ..."``.

    Args:
        image: Anything ``plt.imshow`` accepts.
        annotations: Mapping with a ``'class'`` string and a ``'parts'`` sequence.
            Each part is a mapping with a ``'class'`` string and a 2-D
            ``'mask'`` array whose foreground pixels equal 1.
        uparts: Indices into ``annotations['parts']`` selecting the parts to draw.
        part_class_to_color: Mapping from part class to an RGBA colour. A class
            missing from it gets a random RGB colour with alpha 0.5. That colour
            is shared by all parts of the class within this call, and the
            mapping itself is not modified.
    """
    # Resolve the selected parts once; they drive both the title and the overlays.
    selected_parts = [annotations['parts'][i] for i in uparts]

    plt.imshow(image)
    plt.title(annotations['class'] + ': ' + ', '.join(part['class'] for part in selected_parts))

    fallback_colors = {}
    for part in selected_parts:
        part_class = part['class']
        color = part_class_to_color.get(part_class)
        if color is None:
            # Draw a random colour only when the class is actually unmapped (the
            # previous default argument consumed global RNG state for every part)
            # and reuse it so same-class parts match.
            if part_class not in fallback_colors:
                fallback_colors[part_class] = np.random.rand(3,).tolist() + [0.5]
            color = fallback_colors[part_class]

        part_mask = part['mask']
        # RGBA overlay: pixels outside the mask stay (0, 0, 0, 0), i.e. fully
        # transparent, so only the part region is tinted.
        colored_part_mask = np.zeros((part_mask.shape[0], part_mask.shape[1], 4))
        colored_part_mask[part_mask == 1] = color
        plt.imshow(colored_part_mask, interpolation='nearest')

# Sample dataset
# Visualise a few hand-picked samples as a source | target pair of panels.
# Each row lists the (image, annotation, part-index) keys for one panel, in
# subplot order.
_PANEL_KEYS = (
    ("source_image", "source_annotation", "source_parts"),
    ("target_image", "target_annotation", "target_parts"),
)

for i in [0, 25, 60, 80]:
    sample = dataset[i]

    # One colour map built from both images' parts, so a given part class is
    # drawn in the same colour on the source and target panels.
    combined_annotations = (
        sample["source_annotation"]["parts"]
        + sample["target_annotation"]["parts"]
    )
    part_class_to_color = generate_part_class_to_color_map(combined_annotations)

    plt.figure(figsize=(20, 10))
    for panel_index, (image_key, annotation_key, parts_key) in enumerate(_PANEL_KEYS, start=1):
        plt.subplot(1, 2, panel_index)
        plot_image_and_masks(
            sample[image_key],
            sample[annotation_key],
            sample[parts_key],
            part_class_to_color,
        )

    plt.show()