In [None]:
import random
import matplotlib.pyplot as plt
import numpy as np
import os
import json
import cv2
import tqdm

from data.create_dataset import create_dataset

# Load the registry of dataset configurations and list the dataset keys that the
# preview cell below can use. The `with` block closes the file handle once parsed.
with open('data/dataset_configs.json', 'r') as config_file:
    dataset_configs = json.load(config_file)
print(list(dataset_configs.keys()))
print()

In [None]:
from model.directsam import DirectSAM

# Optional DirectSAM model. The visualization loop adds prediction panels only when
# `model` is not None.
# To enable it, set DIRECTSAM_CHECKPOINT to a checkpoint name or a local checkpoint
# directory, e.g. "chendelong/DirectSAM-1800px-0424".
# Avoid committing absolute paths from your own machine.
DIRECTSAM_CHECKPOINT = None
DIRECTSAM_RESOLUTION = 1024
DIRECTSAM_DEVICE = 'cuda:1'

if DIRECTSAM_CHECKPOINT is not None:
    model = DirectSAM(
        DIRECTSAM_CHECKPOINT,
        resolution=DIRECTSAM_RESOLUTION,
        device=DIRECTSAM_DEVICE,
    )
else:
    model = None

In [None]:
def visualize_colored_label_map_with_contour(target_contour, label_map, image):
    """Render ``label_map`` as flat regions of mean colour with the contour drawn in white.

    Every region (the set of pixels sharing one id in ``label_map``) is filled with
    the mean colour of ``image`` over that region. Pixels set in ``target_contour``
    are then painted white.

    Args:
        target_contour: (H, W) contour mask. Non-zero entries are contour pixels.
            The mask is cast to bool first, so a 0/1 integer mask is not misread
            as row indices by NumPy fancy indexing.
        label_map: (H, W) array of region ids.
        image: (H, W, C) or (H, W) image, or anything ``np.asarray`` converts to one
            (e.g. a PIL image). Only the first three channels are averaged, so an
            alpha channel is ignored. A greyscale image is replicated to three
            channels.

    Returns:
        (H, W, 3) uint8 array.
    """
    image = np.asarray(image)
    label_map = np.asarray(label_map)
    height, width = label_map.shape[0], label_map.shape[1]

    if image.ndim == 2:
        image = image[..., np.newaxis]
    rgb = image[..., :3]
    num_channels = rgb.shape[-1]
    pixels = rgb.reshape(height * width, num_channels).astype(np.float64)

    # Per-region mean in one pass. Each pixel is mapped to the index of its id among
    # the unique ids, then each channel is summed per region with bincount. This is
    # O(H*W) regardless of the number of regions; one full-image mask per region
    # would not be.
    _, region_index = np.unique(label_map.ravel(), return_inverse=True)
    region_index = region_index.ravel()
    counts = np.bincount(region_index)
    region_means = np.stack(
        [
            np.bincount(region_index, weights=pixels[:, channel], minlength=counts.size) / counts
            for channel in range(num_channels)
        ],
        axis=1,
    )

    canvas = region_means[region_index].reshape(height, width, num_channels)
    if num_channels == 1:
        # Greyscale input: use the same mean for all three output channels.
        canvas = np.repeat(canvas, 3, axis=2)

    canvas = np.clip(canvas, 0, 255).astype(np.uint8)
    canvas[np.asarray(target_contour).astype(bool)] = [255, 255, 255]
    return canvas

def visualize_image_with_contour(contour, image, dim=0.6):
    """Darken ``image`` and draw ``contour`` on top of it in white.

    Args:
        contour: (H, W) contour mask. Non-zero entries are contour pixels. The mask
            is cast to bool so a 0/1 integer mask is not treated as row indices by
            NumPy fancy indexing.
        image: (H, W, C) or (H, W) image, or anything ``np.asarray`` converts to one
            (e.g. a PIL image). Only the first three channels are kept, and a
            greyscale image is replicated to three channels.
        dim: factor every pixel is multiplied by before drawing the contour.

    Returns:
        (H, W, 3) uint8 array.
    """
    image = np.array(image)
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    canvas = image[..., :3] * dim
    canvas[np.asarray(contour).astype(bool)] = [255, 255, 255]
    # Clip so a dim factor above 1 saturates instead of wrapping around in uint8.
    return np.clip(canvas, 0, 255).astype(np.uint8)



In [None]:

# Datasets to preview. Any key printed by the first cell works, e.g. 'ADE20k',
# 'EntitySeg', 'COCONut_relabeld_COCO_val', 'LoveDA', 'DSA_merged', 'SA1B'.
keys = ['DSA_gen1']
splits = ['train']  # add 'validation' to preview that split as well
NUM_SAMPLES = 5     # random samples shown per (dataset, split)
SEED = 0            # fixed so that Restart & Run All shows the same samples
rng = random.Random(SEED)

for key in keys:
    config = dataset_configs[key]
    for split in splits:

        dataset = create_dataset(config, split, 1024, thickness=5)

        print(key, split, len(dataset), len(dataset.image_paths))
        print(config)
        print(f'\t"{key}" # {len(dataset)}')

        # Draw distinct indices without replacement. An empty dataset shows nothing
        # instead of crashing (random.randint(0, -1) raises ValueError).
        sample_indices = rng.sample(range(len(dataset)), k=min(NUM_SAMPLES, len(dataset)))

        for index in sample_indices:
            sample = dataset[index]

            image = sample['image']
            target_contour = sample['label']
            label_map = sample['label_map']

            if config['type'] == 'DSA':
                # NOTE(review): for DSA datasets `label_map` appears to be a binary
                # boundary mask (1 = boundary) — confirm. Regions are recovered as the
                # connected components of the non-boundary pixels.
                num_objects, label_map = cv2.connectedComponents(1-label_map.astype(np.uint8))

            fig = plt.figure(figsize=(25, 15))

            plt.subplot(1, 5, 1)
            plt.title('Image')
            plt.imshow(image)
            plt.axis('off')

            plt.subplot(1, 5, 2)
            plt.title('Label')
            plt.imshow(visualize_colored_label_map_with_contour(target_contour, label_map, image))
            plt.axis('off')

            if model is not None:

                probs = model(image)
                # Compare at the label's resolution; the model output may differ in size.
                label_height, label_width = target_contour.shape[0], target_contour.shape[1]

                plt.subplot(1, 5, 3)
                plt.title('Label and DirectSAM Label')
                pseudo_label = probs > 0.5
                pseudo_label = cv2.resize(pseudo_label.astype(np.uint8), (label_width, label_height), interpolation=cv2.INTER_NEAREST)

                plt.imshow(pseudo_label, cmap='Reds')
                plt.imshow(target_contour, cmap='Greens', alpha=0.5)
                plt.axis('off')

                num_objects, pseudo_label_map = cv2.connectedComponents(1-pseudo_label)
                plt.subplot(1, 5, 4)
                plt.title(f'DirectSAM Label Map ({num_objects} tokens)')
                plt.imshow(visualize_colored_label_map_with_contour(pseudo_label > 0, pseudo_label_map, image))
                plt.axis('off')

                plt.subplot(1, 5, 5)
                plt.title('DirectSAM Probabilities')
                image_array = np.array(image)
                # Resample the probabilities to the image size so the overlay lines up
                # with the image beneath it. cv2.resize takes (width, height). Clipping
                # keeps interpolated values valid as per-pixel alpha.
                probs_resized = cv2.resize(
                    np.asarray(probs, dtype=np.float32),
                    (image_array.shape[1], image_array.shape[0]),
                    interpolation=cv2.INTER_LINEAR,
                )
                probs_resized = np.clip(probs_resized, 0.0, 1.0)
                plt.imshow((image_array * 0.2).astype(np.uint8))
                plt.imshow(np.ones_like(probs_resized), alpha=probs_resized, cmap='binary')
                plt.axis('off')

            plt.tight_layout()
            plt.show()
            # Release the figure explicitly so figures do not accumulate across the loop.
            plt.close(fig)