In [None]:
import json
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import random
import pandas as pd
import tqdm
import numpy as np
import matplotlib.pyplot as plt
import torch
from transformers import AutoModelForSemanticSegmentation, AutoImageProcessor
import torch.nn as nn
from model.directsam import DirectSAM
from evaluation.metrics import recall_with_tolerance
from evaluation.visualization import compare_boundaries
import cv2

# Device string handed to the DirectSAM constructor in the model cell.
device = "cuda:2"

from data.create_dataset import create_dataset

# Load the per-dataset configuration. The context manager closes the file
# handle as soon as parsing finishes instead of leaving it to the garbage
# collector.
with open('data/dataset_configs.json', 'r') as config_file:
    dataset_configs = json.load(config_file)
print(dataset_configs.keys())

# Dataset names grouped by intended use.
# NOTE(review): presumably these are keys of `dataset_configs` — confirm.
train_datasets = [
    'LIP', 'CelebA', 'SOBA', 'SeginW', 'CIHP', 'Fashionpedia',
    'PascalPanopticParts', 'SPIN', 'PartImageNet++', 'ADE20k', 'EntitySeg',
    'LoveDA', 'COCONut-s', 'COCONut-b', 'COCONut-l', 'PACO', 'LVIS', 'COIFT',
    'DIS5K-DIS-TR', 'DUTS-TR', 'ecssd', 'fss_all', 'HRSOD', 'MSRA_10K',
    'ThinObject5K',
]

# The original (misspelled) name is kept so existing references keep working.
validataion_datases = [
    'LIP', 'DRAM', 'SOBA', 'SeginW', 'CIHP', 'Fashionpedia',
    'PascalPanopticParts', 'SPIN', 'PartImageNet++', 'ADE20k', 'EntitySeg',
    'LoveDA', 'COCONut_relabeld_COCO_val', 'PACO', 'LVIS', 'DIS5K-DIS-VD',
    'DUTS-TE',
]
# Correctly spelled alias for new code; refers to the same list object.
validation_datasets = validataion_datases

In [None]:
# Evaluation settings.
#   inference_resolution  - passed to the DirectSAM constructor.
#   evaluation_resolution - used when building the dataset and when resizing
#                           predictions before they are compared with labels.
#   threshold             - passed to the DirectSAM constructor.
inference_resolution, evaluation_resolution = 768, 768
threshold = 0.3

# Boundary-matching tolerance: one hundredth of the inference resolution,
# bumped by one when that value is even so the result is always odd.
# NOTE(review): the tolerance is applied to arrays at `evaluation_resolution`;
# deriving it from `inference_resolution` only matters if the two diverge —
# confirm which one is intended.
tolerance = inference_resolution // 100
if tolerance % 2 == 0:
    tolerance += 1

# Alternative checkpoint:
#   "chendelong/DirectSAM-tiny-distilled-70ep-1024px-0920"
model = DirectSAM("chendelong/DirectSAM-1800px-0424", inference_resolution, threshold, device)

In [None]:
# Qualitative check: for every selected dataset, draw random validation
# samples, run DirectSAM on each, and plot the prediction next to the label.
#
# Alternative dataset selections:
# for dataset_name in ['directsa_plus']:
# for dataset_name in ['COCONut-l']:
# for dataset_name in train_datasets:
# for dataset_name in list(dataset_configs.keys()):
for dataset_name in ['directsam_pseudo_label_merged_denoised']:

    dataset_config = dataset_configs[dataset_name]

    dataset = create_dataset(dataset_config, split='validation', resolution=evaluation_resolution, thickness=-1)

    print(dataset_config)
    print(dataset_name)
    print(len(dataset))

    # An empty split cannot be sampled; skip it so a multi-dataset sweep is
    # not aborted by the sampling call below.
    if len(dataset) == 0:
        print(f"Skipping {dataset_name}: no samples available")
        continue

    all_tokens = []
    all_recall = []

    for i in tqdm.tqdm(range(10)):
        # Uniformly random sample; use `dataset[i]` instead to walk the first
        # samples in order.
        sample = dataset[random.randrange(len(dataset))]

        # Samples arrive either as a mapping with 'image'/'label' entries or
        # as an (image, target) pair. isinstance also accepts dict subclasses,
        # which an exact type comparison would send to the pair branch.
        if isinstance(sample, dict):
            image, target = sample['image'], sample['label']
        else:
            image, target = sample

        prediction, num_tokens = model(image, post_processing=True)
        # Bring the prediction to the evaluation resolution; nearest-neighbour
        # interpolation avoids introducing intermediate values.
        prediction = cv2.resize(
            prediction.astype(np.float32),
            (evaluation_resolution, evaluation_resolution),
            interpolation=cv2.INTER_NEAREST,
        ).astype(np.uint8)

        # The recall computation is currently disabled; 0 is a placeholder, so
        # the recall values printed and plotted below are not measurements.
        # recall = recall_with_tolerance(target, prediction, tolerance)
        recall = 0
        print(f"Recall: {recall:.4f}\tNumber of tokens: {num_tokens}")

        fig, ((ax_input, ax_compare), (ax_label, ax_pred)) = plt.subplots(2, 2, figsize=(15, 15))

        ax_input.imshow(image)
        ax_input.axis('off')
        ax_input.set_title('Input image')

        ax_compare.imshow(compare_boundaries(target, prediction, tolerance=tolerance, linewidth=3))
        ax_compare.set_title(f'Recall: {recall:.2f}')
        ax_compare.axis('off')

        # Label with the input image drawn semi-transparently on top.
        ax_label.imshow(target, cmap='Reds')
        ax_label.imshow(image, alpha=0.3)
        ax_label.axis('off')
        ax_label.set_title('Ground Truth Label')

        # Prediction with the input image drawn semi-transparently on top.
        ax_pred.imshow(prediction, cmap='Blues')
        ax_pred.imshow(image, alpha=0.3)
        ax_pred.set_title(f'DirectSAM Pseudo label ({num_tokens} tokens)')
        ax_pred.axis('off')

        fig.tight_layout()
        plt.show()
        # Release the figure explicitly so large figures do not pile up across
        # iterations.
        plt.close(fig)

        all_tokens.append(num_tokens)
        all_recall.append(recall)

    print(f"Average number of tokens: {np.mean(all_tokens):.2f}")
    print(f"Average recall: {np.mean(all_recall):.4f}")