In [None]:
import json

import matplotlib.pyplot as plt
import numpy as np
import random
import pandas as pd
import tqdm
import numpy as np

# NOTE: hard-coded GPU index; "cuda:3" requires at least four visible CUDA devices.
device = "cuda:3"

from create_dataset import create_dataset

# Load the per-dataset settings. The context manager closes the file handle
# right after parsing instead of leaving it open for the kernel's lifetime.
with open('config.json', 'r') as config_file:
    dataset_configs = json.load(config_file)
print(dataset_configs.keys())

In [None]:

import matplotlib.pyplot as plt
import torch
from transformers import AutoModelForSemanticSegmentation, AutoImageProcessor
import torch.nn as nn


class DirectSAM():
    """Callable wrapper that maps an image to a 2-D probability map.

    The segmentation model is loaded from `model_name`, moved to `device`,
    cast to half precision and put in eval mode. The image processor is
    configured to resize inputs to `resolution` x `resolution`.
    """

    def __init__(self, model_name, resolution, device):
        """Load the model and image processor.

        Args:
            model_name: checkpoint identifier passed to `from_pretrained`.
            resolution: side length used for both the processor's target
                size and the size of the returned probability map.
            device: torch device (string or object) the model is moved to.
        """
        self.model = AutoModelForSemanticSegmentation.from_pretrained(model_name).to(device).half().eval()
        # The processor always comes from this fixed checkpoint, independent of `model_name`.
        self.processor = AutoImageProcessor.from_pretrained('chendelong/DirectSAM-1800px-0424')

        self.processor.size['height'] = resolution
        self.processor.size['width'] = resolution
        self.resolution = resolution

    def __call__(self, image):
        """Run inference on one image.

        Returns:
            A 2-D numpy array of sigmoid probabilities with shape
            (resolution, resolution): the first channel of the first batch item.
        """
        # Inference only: without no_grad each call would record an autograd
        # graph (and keep activations alive) that is never used.
        with torch.no_grad():
            pixel_values = self.processor(image, return_tensors="pt").pixel_values.to(self.model.device).to(self.model.dtype)
            # Cast back to float32 on the CPU before interpolation.
            logits = self.model(pixel_values=pixel_values).logits.float().cpu()

            upsampled_logits = nn.functional.interpolate(
                logits,
                size=(self.resolution, self.resolution),
                mode="bicubic",
            )
            probabilities = torch.sigmoid(upsampled_logits).detach().numpy()[0,0]
        return probabilities


# Side length in pixels: used for the model's input/output size and passed to create_dataset.
resolution = 768

# Passed to create_dataset as `thickness`.
thickness = 2

# Default border width that get_metrics zeroes out on the target.
bzp_offset = resolution // 100


def odd_kernel_size(size):
    """Return `size` when it is odd, otherwise `size + 1` (so 0 becomes 1).

    calculate_recall_torch asserts that its window size is odd and positive.
    The earlier `resolution // 100 + resolution % 2` corrected the parity of
    the resolution rather than of the tolerance, so e.g. 1024 -> 10 and
    769 -> 8 tripped that assertion.
    """
    return int(size) | 1


# Window size for the recall computation; always odd.
tolerance = odd_kernel_size(resolution // 100)

# Number of samples evaluated per dataset.
n_samples = 1000
threshold_steps = 0.01


# Evenly spaced thresholds in (0, 1), excluding both endpoints.
# round() guards against 1 / step landing just below an integer in floating point.
thresholds = np.linspace(threshold_steps, 1-threshold_steps, round(1 / threshold_steps)-1)

# Build the predictor once; the evaluation loop calls it for every sample.
# Alternative checkpoint: "chendelong/DirectSAM-1800px-0424"
model = DirectSAM(
    model_name="chendelong/DirectSAM-tiny-distilled-15ep-768px-0821",
    resolution=resolution,
    device=device,
)

In [None]:
import numpy as np
import cv2
from scipy.ndimage import distance_transform_edt
import torch.nn.functional as F

def edge_zero_padding(boundary, offset):
    """Zero a border of `offset` pixels on all four sides of `boundary`.

    The array is modified in place and the same object is returned.

    Args:
        boundary: 2-D array supporting slice assignment.
        offset: border width in pixels; 0 leaves the array untouched.

    Raises:
        ValueError: if `offset` is negative.
    """
    if offset < 0:
        raise ValueError(f"offset must be non-negative, got {offset}")
    if offset == 0:
        # `boundary[-0:, :]` is the same slice as `boundary[0:, :]`, so falling
        # through would wipe the entire array instead of doing nothing.
        return boundary

    boundary[:offset, :] = 0
    boundary[-offset:, :] = 0
    boundary[:, :offset] = 0
    boundary[:, -offset:] = 0
    return boundary


def calculate_recall_torch(target, predictions, r):
    """Recall of `target` against each prediction map, with a spatial tolerance.

    Every prediction channel is filtered with an r x r all-ones kernel; a
    target pixel counts as recalled when the filtered value at its location is
    positive. The recalled target mass is divided by the total target mass.

    Args:
        target: array-like multiplied element-wise with the (C, H, W) mask;
            presumably a single (H, W) map -- verify against caller.
        predictions: array-like of shape (C, H, W), one map per channel.
        r: window size; must be a positive odd integer.

    Returns:
        numpy array with one recall value per prediction channel. There is no
        guard for `target.sum() == 0`.
    """
    r = int(r)
    assert r > 0 and r % 2 == 1

    # Uses the module-level `device` for all tensors.
    target_t = torch.tensor(target, dtype=torch.float32, device=device)
    preds_t = torch.tensor(predictions, dtype=torch.float32, device=device)

    C, H, W = preds_t.shape
    # One independent box filter per channel (depthwise convolution).
    box_filter = torch.ones((C, 1, r, r), device=device)
    neighbourhood_sum = F.conv2d(preds_t, box_filter, groups=C, padding=r // 2)

    covered = neighbourhood_sum > 0
    recalled_mass = (target_t * covered).sum(dim=(1, 2))
    recall = recalled_mass / target_t.sum()

    return recall.cpu().numpy()


def get_num_tokens(boundary):
    """Return the label count OpenCV reports for the complement of `boundary`.

    `1 - boundary` is cast to uint8 and passed to `cv2.connectedComponents`
    with its default settings; only the count is returned.

    NOTE(review): per the OpenCV docs the count appears to include the
    background label 0 -- confirm whether that extra label is intended.
    """
    non_boundary = (1 - boundary).astype(np.uint8)
    label_count, _label_map = cv2.connectedComponents(non_boundary)
    return label_count


def get_metrics(target, prob, thresholds, bzp_offset=bzp_offset, tolerance=tolerance, step=threshold_steps):
    """Compute token counts and recall for `prob` binarized at each threshold.

    Args:
        target: ground-truth map; its border is zeroed by edge_zero_padding,
            which modifies the passed array in place.
        prob: probability map compared against each threshold with `>`.
        thresholds: iterable of cut-off values.
        bzp_offset: border width handed to edge_zero_padding.
        tolerance: window size handed to calculate_recall_torch.
        step: not used inside this function.

    Returns:
        (all_num_tokens, all_recall): a list with one get_num_tokens result per
        threshold, and the array returned by calculate_recall_torch.
    """
    target = edge_zero_padding(target, bzp_offset)

    # One binary map per threshold, stacked along the first axis.
    predictions = np.array([prob > cutoff for cutoff in thresholds])

    all_num_tokens = [get_num_tokens(binary_map) for binary_map in predictions]
    all_recall = calculate_recall_torch(target, predictions, tolerance)

    return all_num_tokens, all_recall

In [None]:
# for dataset_name in ['ADE20k', 'EntitySeg', 'COCONut_relabeld_COCO_val', 'LoveDA', 'PascalPanopticParts', 'PartImageNet++']:

# Per-dataset results. `all_num_tokens` / `all_recall` are reset for each
# dataset, so on their own they only hold the last dataset evaluated.
results_by_dataset = {}

for dataset_name in ['PascalPanopticParts']:

    dataset_config = dataset_configs[dataset_name]

    dataset = create_dataset(dataset_config, split='validation', resolution=resolution, thickness=thickness)

    print(dataset_config)
    print(dataset_name)
    print(len(dataset))

    all_num_tokens = []
    all_recall = []

    # Do not index past the end of a split that has fewer than n_samples items.
    n_eval = min(n_samples, len(dataset))

    for i in tqdm.tqdm(range(n_eval)):
        sample = dataset[i]
        # sample = dataset[random.randint(0, len(dataset)-1)]

        # Samples are either mappings with 'image'/'label' keys or (image, target) pairs.
        # isinstance also accepts dict subclasses, which would otherwise be
        # tuple-unpacked into their keys.
        if isinstance(sample, dict):
            image = sample['image']
            target = sample['label']
        else:
            image, target = sample

        prob = model(image)

        num_tokens, recall = get_metrics(target, prob, thresholds)
        all_num_tokens.append(num_tokens)
        all_recall.append(recall)

        # if i<1:

        #     plt.figure(figsize=(20, 5))
        #     plt.subplot(1, 4, 1)
        #     plt.imshow(image)

        #     plt.subplot(1, 4, 2)
        #     plt.imshow(target)

        #     plt.subplot(1, 4, 3)
        #     plt.imshow(prob)

        #     plt.subplot(1, 4, 4)
        #     plt.imshow(prob > 0.5)

        #     plt.show()

    results_by_dataset[dataset_name] = {
        'num_tokens': all_num_tokens,
        'recall': all_recall,
    }

In [None]:
# Width of each token-count bin and the upper limit of the binned / plotted range.
bin_step = 32
max_num_tokens = 256

# One (num_tokens, recall) point per sample and threshold.
y = np.array(all_recall).flatten()
x = np.array(all_num_tokens).flatten()

df = pd.DataFrame({'num_tokens': x, 'recall': y})
bins = np.arange(0, max_num_tokens + bin_step, bin_step)
# Points falling outside the bin edges get a NaN bin and drop out of the groupby.
df['binned_tokens'] = pd.cut(df['num_tokens'], bins)
# observed=False keeps empty bins in the table on every pandas version instead
# of depending on the version-specific default (which also emits a FutureWarning).
grouped = df.groupby('binned_tokens', observed=False)['recall'].agg(['mean', 'std']).reset_index()
print(grouped)

# Midpoint of each interval, computed once and reused for both red overlays.
bin_centers = grouped['binned_tokens'].apply(lambda interval: interval.mid)

plt.figure(figsize=(20, 10))
plt.scatter(x, y, color='blue', alpha=0.05, s=1)

# plt.errorbar(bin_centers, grouped['mean'], yerr=grouped['std'], fmt='o', color='red', capsize=5)

plt.scatter(bin_centers, grouped['mean'], color='red')
plt.plot(bin_centers, grouped['mean'], color='red')

plt.xlabel('Number of Tokens')
plt.ylabel('Recall')
plt.xlim(0, max_num_tokens)
plt.ylim(0, 1)
plt.show()