In [None]:
import json

import matplotlib.pyplot as plt
import numpy as np
import random
import pandas as pd
import tqdm
import numpy as np
import matplotlib.pyplot as plt
import torch
from transformers import AutoModelForSemanticSegmentation, AutoImageProcessor
import torch.nn as nn
from model.directsam import DirectSAM
from evaluation.metrics import recall_with_tolerance

# Device the model is placed on (assumes a CUDA GPU is available).
device = "cuda:0"

from data.create_dataset import create_dataset

# Per-dataset configuration, keyed by dataset name (indexed that way in the loop below).
# A context manager closes the file handle instead of leaking it.
with open('data/dataset_configs.json', 'r', encoding='utf-8') as config_file:
    dataset_configs = json.load(config_file)
print(dataset_configs.keys())

In [None]:
# Working resolution (px) and threshold handed to DirectSAM; the threshold's
# exact meaning lives in model/directsam.py (presumably a boundary-probability
# cut-off -- confirm there).
resolution = 1024
threshold = 0.3

# Boundary-matching tolerance: roughly 1% of the resolution. It is later used
# as a cv2.GaussianBlur kernel size, which must be odd, so even values are
# bumped up by one.
tolerance = resolution // 100
if tolerance % 2 == 0:
    tolerance += 1

import os

# Checkpoint to load. Set the DIRECTSAM_CHECKPOINT environment variable to
# point at a different checkpoint instead of editing this machine-specific
# absolute path. Alternative checkpoint: "chendelong/DirectSAM-1800px-0424".
checkpoint = os.environ.get(
    "DIRECTSAM_CHECKPOINT",
    "/home/dchenbs/workspace/DirectSAM/runs/directsam_pseudo_label_merged/0829-1210-1024px-from-chendelong_DirectSAM-1800px-0424/checkpoint-6000",
)

model = DirectSAM(
    checkpoint,
    resolution,
    threshold,
    device,
)

In [None]:
import cv2
from torch.nn import functional as F


def _odd_kernel_size(size):
    """Return ``size`` if it is odd, otherwise ``size + 1``.

    ``cv2.GaussianBlur`` rejects even kernel sizes, so even values (including 0)
    are bumped to the next odd number. Odd values pass through unchanged.
    """
    return size if size % 2 == 1 else size + 1


def compare_boundaries(target, prediction, tolerance, linewidth, brightness=192):
    """Render a colour-coded image comparing a target and a predicted boundary map.

    Boundary pixels are classified with a spatial tolerance, then thickened for display:

    * gray -- target pixels with a predicted boundary inside the tolerance window (matched);
    * red  -- target pixels with no predicted boundary nearby (missed);
    * blue -- predicted pixels with no target boundary nearby (spurious).

    Where the thickened masks overlap, later colours win (blue over red over gray).
    All other pixels are white, and a ``tolerance``-wide frame around the image is
    forced to white.

    Args:
        target: 2-D array of target boundaries; non-zero pixels count as boundary
            (assumed non-negative / binary).
        prediction: 2-D array of predicted boundaries, same shape as ``target``.
        tolerance: Matching window size in pixels, used as a Gaussian kernel size.
            Even values are rounded up to the next odd size.
        linewidth: Kernel size used to thicken the drawn lines. It is rounded up to
            odd in the same way.
        brightness: Channel intensity (0-255) used for the coloured pixels.

    Returns:
        ``np.uint8`` RGB image of shape ``(H, W, 3)``.
    """
    tolerance_ksize = _odd_kernel_size(tolerance)
    line_ksize = _odd_kernel_size(linewidth)

    def _spread(mask, ksize):
        # Thresholding a Gaussian blur at > 0 acts like a dilation by roughly
        # ksize // 2 pixels.
        return cv2.GaussianBlur(mask.astype(np.float32), (ksize, ksize), 0) > 0

    target_blurred = _spread(target, tolerance_ksize)
    prediction_blurred = _spread(prediction, tolerance_ksize)

    gray = target * prediction_blurred          # matched target boundary
    red = target * (prediction_blurred == 0)    # target boundary with no prediction nearby
    blue = prediction * (target_blurred == 0)   # prediction with no target boundary nearby

    # Thicken the masks so thin boundaries remain visible when plotted.
    gray = _spread(gray, line_ksize)
    red = _spread(red, line_ksize)
    blue = _spread(blue, line_ksize)

    image = np.ones((target.shape[0], target.shape[1], 3)) * 255
    # Paint order matters: later masks overwrite earlier ones where they overlap.
    image[gray] = [brightness, brightness, brightness]
    image[red] = [brightness, 0, 0]
    image[blue] = [0, 0, brightness]

    # Blank a tolerance-wide frame (presumably to hide edge artefacts).
    # Skip it when tolerance <= 0: image[-0:] would select the whole image.
    if tolerance > 0:
        image[:tolerance, :, :] = 255
        image[-tolerance:, :, :] = 255
        image[:, :tolerance, :] = 255
        image[:, -tolerance:, :] = 255

    return image.astype(np.uint8)



In [None]:
# Datasets to visualise. Alternative selections used previously:
#   ['GTA5', 'DRAM', 'SOBA', 'UDA-Part', 'SeginW', 'CIHP', 'Fashionpedia', 'PartIT', 'PascalPanopticParts', 'SPIN', 'PartImageNet++', 'ADE20k', 'EntitySeg', 'LoveDA', 'COCONut_relabeld_COCO_val', 'COCONut-s', 'COCONut-b', 'COCO2017', 'LVIS']
#   ['directsam_pseudo_label_merged']
#   ['COIFT', 'DIS5K-DIS-TR', 'DIS5K-DIS-VD', 'DUTS-TE', 'DUTS-TR', 'ecssd', 'fss_all', 'HRSOD', 'MSRA_10K', 'ThinObject5K']
#   list(dataset_configs.keys())
dataset_names = ['EntitySeg', 'PascalPanopticParts']

# Number of samples plotted per dataset, taken from the start of the validation split.
# To plot random samples instead, index with random.randint(0, len(dataset) - 1).
num_samples = 5


def _unpack_sample(sample):
    """Split a dataset sample into ``(image, target)``.

    A sample is either a dict with ``'image'`` / ``'label'`` keys or a
    2-item ``(image, target)`` sequence; both forms are accepted.
    """
    # isinstance (not type ==) so dict subclasses are not unpacked as key tuples.
    if isinstance(sample, dict):
        return sample['image'], sample['label']
    image, target = sample
    return image, target


def _show_comparison(image, target, prediction, num_tokens, recall, tolerance):
    """Plot input, boundary comparison, ground truth and prediction in a 2x2 grid."""
    fig, axes = plt.subplots(2, 2, figsize=(20, 20))
    ax_input, ax_compare, ax_target, ax_pred = axes.ravel()

    ax_input.imshow(image)
    ax_input.set_title('Input image')

    ax_compare.imshow(compare_boundaries(target, prediction, tolerance=tolerance, linewidth=3))
    ax_compare.set_title(f'Recall: {recall:.2f}')

    ax_target.imshow(target, cmap='Reds')
    ax_target.imshow(image, alpha=0.3)
    ax_target.set_title('Ground Truth Label')

    ax_pred.imshow(prediction, cmap='Blues')
    ax_pred.imshow(image, alpha=0.3)
    ax_pred.set_title(f'DirectSAM Pseudo label ({num_tokens} tokens)')

    for ax in axes.ravel():
        ax.axis('off')

    fig.tight_layout()
    plt.show()
    # Release the figure explicitly: this runs once per sample inside a loop.
    plt.close(fig)


for dataset_name in dataset_names:
    dataset_config = dataset_configs[dataset_name]
    dataset = create_dataset(dataset_config, split='validation', resolution=resolution, thickness=2)

    print(dataset_config)
    print(dataset_name)
    print(len(dataset))

    # Cap at the split size so small validation splits are not indexed past their end.
    for i in tqdm.tqdm(range(min(num_samples, len(dataset)))):
        image, target = _unpack_sample(dataset[i])

        prediction, num_tokens = model(image)
        recall = recall_with_tolerance(target, prediction, tolerance)

        _show_comparison(image, target, prediction, num_tokens, recall, tolerance)