In [None]:
import json

import matplotlib.pyplot as plt
import numpy as np
import random
import pandas as pd
import tqdm
import numpy as np

# Device string handed to DirectSAM further down, which moves the model onto it.
# "cuda:4" hard-codes one specific GPU index; adjust it for the machine in use.
device = "cuda:4"

from data.create_dataset import create_dataset

# Load the per-dataset configuration mapping. The context manager closes the
# file handle deterministically instead of leaving it to the garbage collector.
with open('data/dataset_configs.json', 'r') as config_file:
    dataset_configs = json.load(config_file)

# Show which dataset names are available as keys for the cells below.
print(dataset_configs.keys())

In [None]:

import matplotlib.pyplot as plt
import torch
from transformers import AutoModelForSemanticSegmentation, AutoImageProcessor
import torch.nn as nn


class DirectSAM:
    """Inference wrapper around a DirectSAM semantic-segmentation checkpoint.

    The model is loaded in half precision, moved to ``device`` and put in
    eval mode. The image processor's target ``size`` is overridden to
    ``resolution`` x ``resolution``. Calling the instance on an image returns
    a 2-D NumPy array of sigmoid probabilities of shape
    ``(resolution, resolution)``.
    """

    def __init__(self, model_name, resolution, device):
        """Load the model and the image processor.

        Args:
            model_name: Checkpoint identifier passed to
                ``AutoModelForSemanticSegmentation.from_pretrained``.
            resolution: Value written to the processor's ``size['height']``
                and ``size['width']``; also the output side length of
                ``__call__``.
            device: Torch device (e.g. ``"cuda:0"``) the model is moved to.
        """
        self.model = AutoModelForSemanticSegmentation.from_pretrained(model_name).to(device).half().eval()
        # NOTE(review): the processor always comes from this fixed checkpoint,
        # independent of `model_name` -- presumably the preprocessing config is
        # shared across DirectSAM checkpoints; confirm.
        self.processor = AutoImageProcessor.from_pretrained('chendelong/DirectSAM-1800px-0424')

        self.processor.size['height'] = resolution
        self.processor.size['width'] = resolution
        self.resolution = resolution

    def __call__(self, image):
        """Run the model on one image and return its probability map.

        Args:
            image: Any input accepted by the image processor.

        Returns:
            ``numpy.ndarray`` holding the sigmoid of the bicubically
            upsampled logits for batch element 0, channel 0, with shape
            ``(self.resolution, self.resolution)``.
        """
        # Inference only: with autograd disabled no graph is recorded, so
        # activations are freed immediately and `.detach()` is unnecessary.
        with torch.no_grad():
            pixel_values = self.processor(image, return_tensors="pt").pixel_values
            # Match the model's device and dtype (the model is held in half precision).
            pixel_values = pixel_values.to(self.model.device).to(self.model.dtype)

            # Back to float32 on the CPU before interpolation.
            logits = self.model(pixel_values=pixel_values).logits.float().cpu()

            upsampled_logits = nn.functional.interpolate(
                logits,
                size=(self.resolution, self.resolution),
                mode="bicubic",
            )
            probabilities = torch.sigmoid(upsampled_logits)

        return probabilities.numpy()[0, 0]


# Settings shared with the dataset-visualization cell further down.
resolution = 1024  # given to DirectSAM here and to create_dataset later
thickness = 3      # forwarded to create_dataset as `thickness`

# Build the inference wrapper; keyword arguments mirror DirectSAM.__init__.
model = DirectSAM(
    model_name="chendelong/DirectSAM-tiny-distilled-15ep-768px-0821",
    resolution=resolution,
    device=device,
)

In [None]:
# Visualize model predictions on a few random validation samples per dataset.
# To sweep more datasets, swap in a longer list, e.g.:
# ['ADE20k', 'EntitySeg', 'COCONut_relabeld_COCO_val', 'LoveDA', 'PascalPanopticParts', 'PartImageNet++']
num_samples_to_show = 5

for dataset_name in ['EntitySeg']:

    dataset_config = dataset_configs[dataset_name]

    dataset = create_dataset(dataset_config, split='validation', resolution=resolution, thickness=thickness)

    print(dataset_config)
    print(dataset_name)
    print(len(dataset))

    for sample_idx in range(num_samples_to_show):
        # randrange(n) draws uniformly from [0, n), same as randint(0, n - 1).
        sample = dataset[random.randrange(len(dataset))]

        # Samples arrive either as a mapping or as an (image, target) pair;
        # isinstance also accepts dict subclasses, unlike `type(x) == dict`.
        if isinstance(sample, dict):
            image = sample['image']
            target = sample['label']
        else:
            image, target = sample

        prob = model(image)

        # One row, four panels: input, target, probabilities, thresholded mask.
        fig, axes = plt.subplots(1, 4, figsize=(30, 6))
        panels = [
            (image, 'image'),
            (target, 'target'),
            (prob, 'predicted probability'),
            (prob > 0.5, 'probability > 0.5'),
        ]
        for ax, (content, title) in zip(axes, panels):
            ax.imshow(content)
            ax.set_title(title)

        plt.show()
        # Release the figure explicitly so repeated iterations do not pile up
        # open figures on backends that keep them alive after show().
        plt.close(fig)