In [None]:
import os
# Select which GPU this notebook may see. This assignment sits before the imports
# below, presumably so it takes effect before anything (e.g. the project tokenizer,
# whose outputs are torch-like tensors moved with `.cpu()`) initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

# NOTE(review): cv2 and PIL.Image appear unused in the visible cells.
import cv2
from PIL import Image
import numpy as np
import random
import tqdm
import json
import matplotlib.pyplot as plt

# Project-local modules: tokenizer factory, mask overlay helper, dataset factory.
from visual_tokenizer import get_visual_tokenizer
from utils.visualization import visualize_masks
from data import get_dataset

In [None]:
# Choose ONE dataset: leave exactly one `dataset = ...` line uncommented.
# NOTE(review): these are absolute, machine-specific paths; adjust DATA_ROOT
# (and the /share/datasets paths) for your setup.
DATA_ROOT = '/home/dchenbs/workspace/datasets'

# dataset = get_dataset('imagenet', '/share/datasets/imagenet', split='train')
# dataset = get_dataset('coco', '/share/datasets/coco2017', split='train')
dataset = get_dataset('clevr_caption', f'{DATA_ROOT}/CLEVR_v1.0', split='train')
# dataset = get_dataset('image_paragraph_captioning', f'{DATA_ROOT}/VisualGenome', split='train')

# ShareGPT4V caption annotation files:
# dataset = get_dataset('sharegpt4v', f'{DATA_ROOT}/sharegpt4v/ShareGPT4V/sharegpt4v_mix665k_cap23k_coco-ap9k_lcs3k_sam9k_div2k.json', split='train')
# dataset = get_dataset('sharegpt4v', f'{DATA_ROOT}/sharegpt4v/ShareGPT4V/share-captioner_coco_lcs_sam_1246k_1107.json', split='train')
# dataset = get_dataset('sharegpt4v', f'{DATA_ROOT}/sharegpt4v/ShareGPT4V/sharegpt4v_instruct_gpt4-vision_cap100k.json', split='train')

In [None]:
image_resolution = 768  # input images are resized to (image_resolution, image_resolution)

token_per_side = 6
max_tokens = token_per_side * token_per_side  # token budget passed to the visual tokenizer

# Alternative tokenizer configs (assign one of these to `config_path` instead):
#   'configs/visual_tokenizer/patch_8_per_side_random.json'
#   'configs/visual_tokenizer/patch_8_per_side_raster.json'
#   'configs/visual_tokenizer/directsam_0424.json'
config_path = 'configs/visual_tokenizer/directsam_tiny.json'

# Context manager closes the file handle promptly (bare `json.load(open(...))` leaks it).
with open(config_path) as config_file:
    config = json.load(config_file)


In [None]:

# Show the config as loaded, then override its 'threshold' entry before the
# tokenizer is built from it (the dict itself is modified).
print(config)
config.update(threshold=0.1)

# Every config entry is forwarded as a keyword argument, alongside the
# notebook-level resolution and token budget.
visual_tokenizer = get_visual_tokenizer(
    **config,
    image_resolution=image_resolution,
    max_tokens=max_tokens,
)

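### Visualization

For a few random samples, show the resized image, the mask overlay (titled with the number of non-empty masks), a map of which mask index covers each pixel, and every token mask individually (titled with its sum).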

In [None]:
num_vis_samples = 4  # number of random samples to visualize

for _ in range(num_vis_samples):
    sample = dataset[random.randint(0, len(dataset) - 1)]
    image = sample['image'].resize((image_resolution, image_resolution))
    batch_masks = visual_tokenizer(image).cpu().numpy()
    # Only the first batch item is visualized.
    # NOTE(review): presumably batch_masks is (batch, num_masks, H, W) — TODO confirm.
    masks = batch_masks[0]
    mask_areas = np.sum(masks, axis=(1, 2))
    num_nonempty = int((mask_areas > 0).sum())

    # Row 1: input image | mask overlay | per-pixel mask index.
    plt.figure(figsize=(20, 10))
    plt.subplot(1, 3, 1)
    plt.imshow(image)

    plt.subplot(1, 3, 2)
    plt.imshow(visualize_masks(image, masks))
    plt.title(f'{num_nonempty} non-empty masks')

    # Paint each pixel with the 1-based index of the mask covering it (0 = uncovered).
    # Assignment (rather than summing `(idx + 1) * mask`) keeps labels meaningful if
    # masks overlap, and avoids the int32 `+=` float casting error for float masks.
    labels = np.zeros(masks.shape[1:], dtype=np.int32)
    for mask_idx, mask in enumerate(masks):
        labels[mask > 0] = mask_idx + 1

    plt.subplot(1, 3, 3)
    plt.imshow(labels, cmap='plasma')
    plt.title('order')

    print(sample['text'])

    plt.show()

    # Grid of individual masks, one panel per token slot. Slicing (instead of
    # indexing 0..grid_size-1) avoids an IndexError when fewer masks are returned.
    grid_size = token_per_side * token_per_side
    plt.figure(figsize=(15, 15))
    for mask_idx, mask in enumerate(masks[:grid_size]):
        plt.subplot(token_per_side, token_per_side, mask_idx + 1)
        plt.imshow(mask)
        plt.axis('off')
        plt.title(mask.sum())

    plt.show()

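### Statistics and Efficiency

Tokenize random images and record, per image, the number of non-empty masks and each mask's area as a percentage of the image. The tqdm progress bar reports tokenizer throughput.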

In [None]:
steps = 10  # number of random images to sample
effective_masks = []  # per image: number of non-empty masks
mask_sizes = []  # per image: area of every mask, in percent of the mask's pixel count

for _ in tqdm.tqdm(range(steps)):
    image = dataset[random.randint(0, len(dataset) - 1)]['image']
    image = image.resize((image_resolution, image_resolution))
    # NOTE(review): presumably (num_masks, H, W) after taking batch item 0 — TODO confirm.
    masks = visual_tokenizer(image)[0].cpu().numpy()

    areas = np.sum(masks, axis=(1, 2))
    effective_masks.append((areas > 0).sum())
    # Normalize by the masks' own spatial size instead of assuming it equals
    # image_resolution**2, so percentages stay correct if the tokenizer emits
    # masks at a different resolution than the input image.
    mask_sizes.append(areas / (masks.shape[-2] * masks.shape[-1]) * 100)

mask_sizes = np.array(mask_sizes)  # (steps, num_masks)
avg_mask_sizes = np.mean(mask_sizes, axis=0)  # mean area per mask index across images

In [None]:

# Per-mask area (percent of image, log scale) against mask index: one faint line
# per sampled image, plus the mean across images in black.
fig, ax = plt.subplots(figsize=(20, 10))

for sample_mask_sizes in mask_sizes:
    ax.plot(sample_mask_sizes, alpha=0.1)

ax.plot(avg_mask_sizes, color='black', linewidth=2)

# Labeled horizontal reference lines at 50%, 10%, 1% and 0.1% (each drawn once).
for y in [50, 10, 1, 0.1]:
    ax.axhline(y=y, color='r', linestyle='--')
    ax.text(0, y * 1.1, f'{y}%', color='r')

# Vertical markers at 4, 8 and 32 tokens, annotated with the mean area of that
# mask index. Markers beyond the number of masks are skipped (avoids IndexError
# when max_tokens < 32).
num_masks = len(avg_mask_sizes)
for x in [4, 8, 32]:
    if x > num_masks:
        continue
    avg_size = avg_mask_sizes[x - 1]
    ax.axvline(x=x - 1, color='b', linestyle='--')
    ax.scatter(x - 1, avg_size, color='b')
    ax.text(x + 1, avg_size, f'{x} tokens ({avg_size:.3f}%)', color='b')

ax.set_yscale('log')
ax.set_xlabel('mask index')
ax.set_ylabel('mask area (% of image)')
ax.set_title('Mask area by mask index')
plt.show()

In [None]:
# Distribution of the number of non-empty masks per sampled image.
# The counts are integers, so use unit-width bins centered on each integer.
# 50 arbitrary bins can leave gaps or merge neighboring counts.
# The upper edge also covers any count that exceeds max_tokens.
max_count = max(max_tokens, max(effective_masks, default=0))
bins = np.arange(max_count + 2) - 0.5

plt.figure(figsize=(10, 5))
plt.hist(effective_masks, bins=bins)
plt.xlabel('number of non-empty masks per image')
plt.ylabel('number of images')
plt.show()