In [None]:
import random
import matplotlib.pyplot as plt
import numpy as np
import os
import json
import cv2
import tqdm

from data.create_dataset import create_dataset

# Per-dataset configurations, keyed by dataset name (e.g. 'ISAID'); each entry is
# passed to `create_dataset` below. A context manager closes the file handle promptly.
with open('data/dataset_configs.json', 'r') as configs_file:
    dataset_configs = json.load(configs_file)
print(dataset_configs.keys())

In [None]:
# Optional DirectSAM model whose predictions are overlaid on the labels in the preview.
# With USE_DIRECTSAM = False, `model` stays None and the preview shows labels only.
USE_DIRECTSAM = False

# Hub id or local checkpoint directory (e.g. a fine-tuned run such as
# runs/DSA_merged/.../checkpoint-5000). Set DIRECTSAM_CHECKPOINT in the environment
# instead of hardcoding a machine-specific absolute path here.
DIRECTSAM_CHECKPOINT = os.environ.get('DIRECTSAM_CHECKPOINT', 'chendelong/DirectSAM-1800px-0424')

if USE_DIRECTSAM:
    # Imported lazily so the notebook runs without the model package when disabled.
    from model.directsam import DirectSAM

    model = DirectSAM(
        DIRECTSAM_CHECKPOINT,
        resolution=1024,
        device='cuda',
    )
else:
    model = None

In [None]:
def visualize_colored_label_map(label_map, image, brightness=1.1):
    """Paint every segment of ``label_map`` with the mean color of its pixels in ``image``.

    Args:
        label_map: 2-D array of segment ids with the same height/width as ``image``.
        image: Image convertible with ``np.array`` (e.g. a PIL image), shaped (H, W, 3).
            Grayscale (H, W) or (H, W, 1) inputs are expanded to 3 channels, and any
            channels beyond the third (e.g. alpha) are ignored.
        brightness: Factor applied to the mean colors before clipping to [0, 255];
            the default 1.1 slightly brightens the visualization.

    Returns:
        (H, W, 3) uint8 array in which each segment is filled with its mean color.

    Raises:
        ValueError: If ``label_map`` and ``image`` differ in height/width.
    """
    label_map = np.asarray(label_map)
    image = np.array(image)

    # Normalize to exactly 3 channels so the output is always RGB.
    if image.ndim == 2:
        image = image[..., None]
    if image.shape[-1] == 1:
        image = np.repeat(image, 3, axis=-1)
    image = image[..., :3]

    height, width = label_map.shape[0], label_map.shape[1]
    if image.shape[:2] != (height, width):
        raise ValueError(
            f'label_map shape {label_map.shape[:2]} does not match image shape {image.shape[:2]}'
        )

    # Map each pixel to a dense segment index so all per-segment means come from a
    # single bincount pass instead of one full-image mask per segment. The input is
    # flattened first so the inverse is 1-D on every NumPy version.
    _, segment_index = np.unique(label_map.reshape(-1), return_inverse=True)
    pixels = image.reshape(-1, 3).astype(np.float64)
    counts = np.bincount(segment_index)
    channel_sums = np.stack(
        [np.bincount(segment_index, weights=pixels[:, channel]) for channel in range(3)],
        axis=1,
    )
    mean_colors = channel_sums / counts[:, None]

    canvas = mean_colors[segment_index].reshape(height, width, 3)
    canvas = np.clip(canvas * brightness, 0, 255)
    return canvas.astype(np.uint8)

def visualize_image_with_contour(contour, image, dim=0.6, color=(255, 255, 255)):
    """Dim ``image`` and draw the pixels selected by ``contour`` in ``color``.

    Args:
        contour: 2-D mask with the same height/width as ``image``. Any nonzero value
            marks a contour pixel; bool and 0/1 integer masks both work.
        image: Image convertible with ``np.array`` (e.g. a PIL image), shaped (H, W, 3).
            Grayscale (H, W) or (H, W, 1) inputs are expanded to 3 channels, and any
            channels beyond the third (e.g. alpha) are ignored.
        dim: Factor the image is multiplied by so the contour stands out (default 0.6).
        color: RGB color written at contour pixels (default white).

    Returns:
        (H, W, 3) uint8 array.
    """
    image = np.array(image, dtype=np.float64)

    # Normalize to exactly 3 channels so the RGB color assignment always broadcasts.
    if image.ndim == 2:
        image = image[..., None]
    if image.shape[-1] == 1:
        image = np.repeat(image, 3, axis=-1)
    image = image[..., :3] * dim

    # Cast to bool: an integer 0/1 mask would otherwise be used as fancy row indices
    # and paint rows 0 and 1 instead of the contour pixels.
    mask = np.asarray(contour).astype(bool)
    image[mask] = color
    return image.astype(np.uint8)

In [None]:

# Datasets to preview; swap in one of the alternative lists to inspect others.
# keys = ['DSA_merged']
keys = ['ISAID']
# keys = ['ADE20k', 'EntitySeg', 'COCONut_relabeld_COCO_val', 'LoveDA']

# keys = ['SA1B', 'ADE20k', 'COCONut-s', 'COCONut-b', 'COCONut-l', 'LIP', 'CelebA', 'SOBA', 'CIHP', 'LoveDA', 'EntitySeg', 'PascalPanopticParts', 'SPIN', 'SUIM', 'MyFood', 'COIFT', 'DIS5K-DIS-TR','DUTS-TR', 'ecssd', 'fss_all', 'HRSOD', 'MSRA_10K', 'ThinObject5K', 'Fashionpedia', 'PartImageNetPP', 'SeginW', 'LVIS', 'PACO', 'GTA5']

SPLITS = ['train']      # e.g. ['train', 'validation']
RESOLUTION = 1024       # resolution argument passed to create_dataset
THICKNESS = 5           # contour thickness passed to create_dataset
NUM_SAMPLES = 3         # random samples previewed per (dataset, split)
SEED = 0                # fixed so every re-run previews the same samples
PROB_THRESHOLD = 0.5    # binarization threshold applied to the model's output

rng = random.Random(SEED)

for key in keys:
    config = dataset_configs[key]
    for split in SPLITS:

        dataset = create_dataset(config, split, RESOLUTION, thickness=THICKNESS)

        print(key, split, len(dataset), len(dataset.image_paths))
        print(config)
        print(f'\t"{key}" # {len(dataset)}')

        if len(dataset) == 0:
            print(f'skipping {key}/{split}: dataset is empty')
            continue

        # Distinct indices, and never more than the dataset actually holds.
        indices = rng.sample(range(len(dataset)), k=min(NUM_SAMPLES, len(dataset)))

        for index in indices:
            sample = dataset[index]

            image = sample['image']
            label = sample['label']
            label_map = sample['label_map']

            fig, axes = plt.subplots(1, 4, figsize=(20, 5))
            for ax in axes:
                ax.axis('off')

            axes[0].set_title('Label Map')
            axes[0].imshow(visualize_colored_label_map(label_map, image))

            axes[1].set_title('Label Contour')
            axes[1].imshow(visualize_image_with_contour(label, image))

            if model is None:
                # No model: show the label alone; the 4th panel stays blank.
                axes[2].set_title('Label')
                axes[2].imshow(label, cmap='Greens', alpha=0.5)
            else:
                probs = model(image)

                # Resize the binarized prediction to the label's (width, height) for overlay.
                pseudo_label = (probs > PROB_THRESHOLD).astype(np.uint8)
                pseudo_label = cv2.resize(
                    pseudo_label,
                    (label.shape[1], label.shape[0]),
                    interpolation=cv2.INTER_NEAREST,
                )

                axes[2].set_title('Label and Pseudo Label')
                axes[2].imshow(pseudo_label, cmap='Reds')
                axes[2].imshow(label, cmap='Greens', alpha=0.5)

                axes[3].set_title('DirectSAM Probabilities')
                axes[3].imshow(probs, cmap='Reds')

            fig.tight_layout()
            plt.show()
            # Release the figure so long preview loops do not accumulate memory.
            plt.close(fig)