In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import cv2
from PIL import Image
import numpy as np
import random
import tqdm
import json
import matplotlib.pyplot as plt
from transformers import AutoModelForSemanticSegmentation, AutoImageProcessor
import torch.nn as nn
import cv2
from PIL import Image
import numpy as np
import torch

from visual_tokenizer import get_visual_tokenizer
from utils.visualization import visualize_masks
from data import get_dataset

In [None]:
# Training split of ImageNet, read from the cluster-local copy of the dataset.
dataset = get_dataset(
    'imagenet',
    '/datasets01/imagenet_full_size/061417',
    split='train',
)

In [162]:
# Hub repository that provides the segmentation weights, and the side length
# (in pixels) of the square input fed to the model.
checkpoint = 'chendelong/DirectSAM-gen3-1024px-1023'
image_resolution = 1024

# NOTE(review): the preprocessing config is pulled from a different repository
# than the weights above; confirm this is intentional.
image_processor = AutoImageProcessor.from_pretrained(
    "chendelong/DirectSAM-1800px-0424",
    reduce_labels=True,
)
# Override the processor's target size so both sides match `image_resolution`.
for side in ('height', 'width'):
    image_processor.size[side] = image_resolution

# Load the weights, move them to the GPU, cast to fp16 and switch to eval mode.
directsam = (
    AutoModelForSemanticSegmentation
    .from_pretrained(checkpoint)
    .to('cuda')
    .half()
    .eval()
)


def get_contour_probability(image, output_resolution=1024, temperature=1.0):
    """Run DirectSAM on one image and return its contour probability map.

    The image is preprocessed with the module-level ``image_processor`` and
    passed through the module-level ``directsam`` model. The logits are divided
    by ``temperature``, resized bicubically to a square map, and squashed with a
    sigmoid. Finally a thin frame along the four edges of the map is set to 0.

    Args:
        image: A single image in a format accepted by ``image_processor``.
        output_resolution: Side length, in pixels, of the square map returned.
        temperature: Divisor applied to the logits before the sigmoid. Values
            above 1 pull probabilities towards 0.5; values below 1 push them
            towards 0 or 1.

    Returns:
        A 2-D float32 NumPy array of shape
        ``(output_resolution, output_resolution)`` with values in ``[0, 1]``.
    """
    pixel_values = image_processor([image], return_tensors="pt").pixel_values
    # Follow the model's own device and dtype instead of assuming 'cuda'.
    pixel_values = pixel_values.to(directsam.device).to(directsam.dtype)

    # Pure inference: disabling autograd avoids retaining activations for a
    # backward pass that never happens.
    with torch.no_grad():
        logits = directsam(pixel_values=pixel_values).logits.float().cpu() / temperature
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=(output_resolution, output_resolution),
            mode="bicubic",
            align_corners=False,
        )
        # First batch element, first channel -> 2-D map.
        probabilities = torch.sigmoid(upsampled_logits)[0, 0].numpy()

    # Zero a frame about 1% of the side length wide along every edge.
    border = int(output_resolution / 100)
    # Guard against border == 0: `a[-0:]` is `a[0:]`, i.e. the WHOLE array, so
    # without this check any output_resolution < 100 would blank the full map.
    if border > 0:
        probabilities[:border, :] = 0
        probabilities[-border:, :] = 0
        probabilities[:, :border] = 0
        probabilities[:, -border:] = 0

    return probabilities

### Visualization

In [None]:
import random
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
from matplotlib import cm
from matplotlib.colors import LightSource

# The probability maps are rendered at a quarter of the model's input side.
output_resolution = image_resolution // 4

for i in range(5):
    # Draw one dataset entry uniformly at random and bring it to model size.
    sample_index = random.randint(0, len(dataset) - 1)
    sample = dataset[sample_index]
    image = sample['image'].resize((image_resolution, image_resolution))

    # Contour probability map for the sampled image.
    probabilities = get_contour_probability(
        image,
        output_resolution=output_resolution,
        temperature=2.5,
    )

    # ---- Figure 1: input image next to its probability map -----------------
    # The grid has three columns; only the first two slots are populated.
    fig = plt.figure(figsize=(30, 10))
    panels = (
        (1, image, {}, 'Original Image'),
        (2, probabilities, {'cmap': cm.Blues}, 'Contour Probabilities'),
    )
    for slot, picture, imshow_kwargs, caption in panels:
        panel_ax = fig.add_subplot(1, 3, slot)
        panel_ax.imshow(picture, **imshow_kwargs)
        panel_ax.set_title(caption)
        panel_ax.axis('off')
    plt.show()

    # ---- Figure 2: the same map drawn as a shaded 3D surface ---------------
    fig = plt.figure(figsize=(30, 30))

    # Smooth the map, then flip it vertically (reverse the y axis).
    probabilities = np.flip(gaussian_filter(probabilities, sigma=2), 0)
    Z = probabilities

    # Grid coordinates: X runs over columns, Y over rows.
    column_coords = np.arange(probabilities.shape[1])
    row_coords = np.arange(probabilities.shape[0])
    X, Y = np.meshgrid(column_coords, row_coords)

    ax = fig.add_subplot(111, projection='3d')
    ax.view_init(elev=80, azim=-80)

    # Face colours are precomputed with a light source, so matplotlib's own
    # surface shading is turned off below.
    ls = LightSource(270, 45)
    rgb = ls.shade(Z, cmap=cm.Blues, vert_exag=0.1, blend_mode='soft')

    surface_style = dict(
        rcount=output_resolution,
        ccount=output_resolution,
        facecolors=rgb,
        linewidth=1,
        antialiased=True,
        shade=False,
    )
    surf = ax.plot_surface(X, Y, Z, **surface_style)

    # Fit the axis limits to the extent of the surface.
    ax.set_xlim(0, probabilities.shape[1])
    ax.set_ylim(0, probabilities.shape[0])
    ax.set_zlim(0, np.max(Z))

    # Hide the grid and the axes for a cleaner look.
    ax.grid(False)
    ax.axis('off')
    plt.show()
