In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import cv2
from PIL import Image
import numpy as np
import random
import tqdm
import json
import matplotlib.pyplot as plt

from visual_tokenizer import get_visual_tokenizer
from utils.visualization import visualize_masks
from HEIT.metrics import create_circular_kernel, masks_to_contour, contour_recall


In [2]:
from datasets import load_dataset

# Candidate values for the `split` argument of `load_dataset` below
# (presumably the subsets available in "chendelong/HEIT" -- verify against the dataset card).
splits = [
    'SA1B',
    'COCONut_relabeld_COCO_val',
    'EntitySeg',
    'PascalPanopticParts',
    'plantorgans',
    'MapillaryMetropolis',
    'cityscapes',
    'NYUDepthv2',
    'tcd',
    'FoodSeg103',
    'ADE20k',
    'WireFrame',
    'ISAID',
    'PhenoBench',
    'EgoHOS',
    'LIP',
    'SOBA',
    'CIHP',
    'LoveDA',
    'SPIN',
    'SUIM',
    'MyFood',
    'DIS5K_DIS_VD',
    'DUTS_TE',
    'Fashionpedia',
    'PartImageNetPP',
    'SeginW',
    'LVIS',
    'PACO',
    'DRAM',
]

# Load a single split and show its summary followed by the first record.
dataset = load_dataset("chendelong/HEIT", split='EntitySeg')
for shown in (dataset, dataset[0]):
    print(shown)


  from .autonotebook import tqdm as notebook_tqdm


Dataset({
    features: ['image', 'contour'],
    num_rows: 1314
})
{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=1024x1024 at 0x7FECC4494CA0>, 'contour': <PIL.PngImagePlugin.PngImageFile image mode=1 size=1024x1024 at 0x7FECC4494E50>}


In [3]:
# Side length (in pixels) that images are resized to before tokenization,
# and the `max_tokens` value handed to the tokenizer factory.
image_resolution = 384
max_tokens = 256

# Tokenizer configuration to evaluate. Set `config_path` to exactly one of the
# JSON files below (paths are relative to the notebook's working directory):
#   configs/visual_tokenizer/patch/patch_8_per_side_raster.json
#   configs/visual_tokenizer/patch_16_per_side_random.json
#   configs/visual_tokenizer/directsam/directsam_tiny_dsa_100ep@0.1.json
#   configs/visual_tokenizer/directsam/directsam_tiny_dsa_100ep@0.1_x2.json
#   configs/visual_tokenizer/superpixel/superpixel_slic.json
#   configs/visual_tokenizer/panoptic/panoptic_mask2former_small.json
#   configs/visual_tokenizer/panoptic/panoptic_oneformer_large.json
#   configs/visual_tokenizer/sam/sam_vit_l.json
#   configs/visual_tokenizer/sam/sam_vit_h_64points_1layer.json
#   configs/visual_tokenizer/sam/fastsam.json
#   configs/visual_tokenizer/sam/mobilesamv2.json
#   configs/visual_tokenizer/sam/efficientvit.json
# NOTE(review): the default below is the file named in this cell's saved traceback,
# i.e. the last selection that was executed -- confirm it exists in your checkout.
config_path = 'configs/visual_tokenizer/directsam/directsam_tiny_dsa_100ep@0.1.json'

# Load inside a context manager so the file handle is closed, and so `config`
# is always defined before it is printed and unpacked below.
with open(config_path) as config_file:
    config = json.load(config_file)

# Optional per-run overrides of values loaded from the JSON file.
# config['threshold'] = 0.1
# config['crop'] = 3
print(config)

visual_tokenizer = get_visual_tokenizer(**config, image_resolution=image_resolution, max_tokens=max_tokens)

FileNotFoundError: [Errno 2] No such file or directory: 'configs/visual_tokenizer/directsam/directsam_tiny_dsa_100ep@0.1.json'

### Visualization

In [None]:

# `tolerance` is passed to `create_circular_kernel` and is also the width of the
# ground-truth border (in rows/columns) that is blanked before computing recall.
tolerance = 10
CIRCULAR_KERNEL = create_circular_kernel(tolerance)


for _ in range(1):
    # Draw one random sample per iteration (replace with `dataset[k]` for a fixed sample).
    sample = dataset[random.randint(0, len(dataset) - 1)]
    image = sample['image'].resize((image_resolution, image_resolution))
    batch_masks = visual_tokenizer(image).cpu().numpy()
    token_masks = batch_masks[0]  # masks of the first (only) image in the batch

    # --- Figure 1: input image, mask overlay, and token order map ---
    plt.figure(figsize=(20, 10))
    plt.subplot(1, 3, 1)
    plt.imshow(image)

    plt.subplot(1, 3, 2)
    plt.imshow(visualize_masks(image, token_masks))
    # A token counts as "effective" when its mask has at least one non-zero pixel.
    plt.title(f"{(np.sum(token_masks, axis=(1, 2)) > 0).sum()} tokens")

    # Label map: each pixel accumulates (index + 1) of every non-empty mask covering it,
    # so the colour encodes the position of the token in the output sequence.
    labels = np.zeros_like(token_masks[0]).astype(np.int32)
    for token_idx, mask in enumerate(token_masks):
        if np.sum(mask) == 0:
            continue
        labels += (token_idx + 1) * mask

    plt.subplot(1, 3, 3)
    plt.imshow(labels, cmap='plasma')
    plt.title('order')
    plt.show()

    # --- Figure 2: contour recall against the ground truth ---
    gt_contour = np.array(sample['contour'])
    # Zero out a `tolerance`-wide frame on all four sides of the ground-truth contour map.
    gt_contour[:tolerance] = gt_contour[-tolerance:] = gt_contour[:, :tolerance] = gt_contour[:, -tolerance:] = 0
    pred_contour = masks_to_contour(token_masks, CIRCULAR_KERNEL)

    # calculate recall
    recall = contour_recall(gt_contour, pred_contour)
    # Ground-truth contour pixels that the predicted contour does not cover.
    missing_contour = np.logical_and(gt_contour, np.logical_not(pred_contour))

    plt.figure(figsize=(20, 10))
    plt.subplot(1, 3, 1)
    plt.imshow(gt_contour, cmap='Greens')
    plt.title('Ground Truth')

    plt.subplot(1, 3, 2)
    plt.imshow(missing_contour, cmap='Reds')
    # Resize the faint image overlay to the contour map's own size (PIL expects
    # (width, height)) instead of assuming a fixed 1024x1024 ground truth.
    overlay_size = (missing_contour.shape[1], missing_contour.shape[0])
    plt.imshow(image.resize(overlay_size), alpha=0.1)
    plt.title(f'Recall: {recall:.2f}')

    plt.subplot(1, 3, 3)
    plt.imshow(missing_contour, cmap='Reds')
    plt.imshow(pred_contour, cmap='Blues', alpha=0.3)
    plt.title('Predicted')

    plt.show()

    # --- Figure 3: grid of the first individual token masks ---
    plt.figure(figsize=(20, 20))

    # Grid side length: floor(sqrt(max_tokens)), capped at 6 (at most 36 panels).
    token_per_side = min(6, int(np.sqrt(max_tokens)))
    # Never index past the number of masks the tokenizer actually returned.
    n_shown = min(token_per_side * token_per_side, len(token_masks))

    for grid_idx in range(n_shown):
        plt.subplot(token_per_side, token_per_side, grid_idx + 1)
        plt.imshow(token_masks[grid_idx], cmap='Blues')
        plt.imshow(image, alpha=0.1)
        plt.axis('off')
        # plt.title(token_masks[grid_idx].sum())

    plt.show()

### Statistics and Efficiency

In [None]:
# Number of randomly drawn images to gather statistics over.
steps = 5
effective_masks, mask_sizes = [], []

for _ in tqdm.tqdm(range(steps)):
    sample_idx = random.randint(0, len(dataset) - 1)
    image = dataset[sample_idx]['image'].resize((image_resolution, image_resolution))
    masks = visual_tokenizer(image)[0].cpu().numpy()

    # Per-token sum over the two trailing (spatial) axes of the mask array.
    areas = np.sum(masks, axis=(1, 2))
    effective_masks.append((areas > 0).sum())  # tokens with a non-empty mask
    mask_sizes.append(areas)

# Express each token's area as a percentage of the resized image area,
# then average per token position across the sampled images.
mask_sizes = np.array(mask_sizes) / (image_resolution * image_resolution) * 100
avg_mask_sizes = mask_sizes.mean(axis=0)

In [6]:

# plt.figure(figsize=(20, 10))

# for sample_mask_sizes in mask_sizes:
#     plt.plot(sample_mask_sizes, alpha=0.1)

# plt.plot(avg_mask_sizes, color='black', linewidth=2)

# # plot a horizontal reference line at each mask-size percentage listed below and label it
# for y in [1, 50, 10, 1, 0.1]:
#     plt.axhline(y=y, color='r', linestyle='--')
#     plt.text(0, y*1.1, f'{y}%', color='r')

# # # plot a vertical line at each token count listed below (perfect squares up to 100) and label it with the average mask size
# for x in [4, 9, 16, 25, 36, 64, 81, 100]:
#     if x >= max_tokens:
#         continue
#     plt.axvline(x=x-1, color='b', linestyle='--')
#     plt.scatter(x-1, avg_mask_sizes[x-1], color='b')
#     plt.text(x+1, avg_mask_sizes[x-1], f'{x} tokens ({avg_mask_sizes[x-1]:.3f}%)', color='b')

# #log y axis
# plt.xlim(0, max_tokens)
# plt.yscale('log')
# plt.show()

In [7]:
# plt.figure(figsize=(10, 5))
# plt.hist(effective_masks, bins=50)
# plt.show()