In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import random

from transformers import PretrainedConfig
from model.visual_token_embedding import VisualTokenEmbedding
from utils.visualization import visualize_masks

import torch.nn.functional as F
import json
from visual_tokenizer import get_visual_tokenizer
from data import get_dataset


In [None]:
# Source of the sample images used throughout the notebook.
# To inspect another dataset, swap this call for one of the commented alternatives below.
imagenet_root = '/datasets01/imagenet_full_size/061417'
dataset = get_dataset('imagenet', imagenet_root, split='train')

# dataset = get_dataset('clevr_caption', '/private/home/delong/workspace/data/clevr-caption', split='train')
# dataset = get_dataset('sharegpt4v', '/private/home/delong/workspace/data/ShareGPT4V', split='share-captioner_coco_lcs_sam_1246k_1107.json')

In [None]:
# Side length (in pixels) of the square images handed to the visual tokenizer;
# forwarded to get_visual_tokenizer as image_resolution.
tokenizer_input_resolution = 768
# Resolution setting for the embedding model (not used by the tokenizer itself).
embedding_input_resolution = 768
# Forwarded to get_visual_tokenizer as max_tokens
# (presumably the cap on masks per image -- confirm against the tokenizer).
max_tokens = 256

# Choose exactly one tokenizer config; the commented paths are drop-in alternatives.
tokenizer_config_path = 'configs/visual_tokenizer/directsam/directsam_tiny_sa1b_2ep@0.05.json'
# tokenizer_config_path = 'configs/visual_tokenizer/superpixel/superpixel_slic.json'
# tokenizer_config_path = 'configs/visual_tokenizer/panoptic/panoptic_mask2former_tiny.json'
# tokenizer_config_path = 'configs/visual_tokenizer/directsam/directsam_tiny_dsa_100ep@0.5.json'
# tokenizer_config_path = '/private/home/delong/workspace/subobjects-VLM/configs/visual_tokenizer/patch/patch_16_per_side_raster.json'

# Read the JSON through a context manager so the file handle is closed right away
# instead of lingering until garbage collection.
with open(tokenizer_config_path) as config_file:
    config = json.load(config_file)

# Every key of the JSON config becomes a keyword argument of the factory.
visual_tokenizer = get_visual_tokenizer(**config, image_resolution=tokenizer_input_resolution, max_tokens=max_tokens)

In [None]:
# Choose exactly one embedding config; the commented paths are drop-in alternatives.
# embedding_config_path = 'configs/visual_embedding/rgb_pixel.json'
# embedding_config_path = 'configs/visual_embedding/in1k_mobilenetv3_all.json'
# embedding_config_path = 'configs/visual_embedding/vae.json'
# embedding_config_path = 'configs/visual_embedding/convnext_in22k_stage3.json'
# embedding_config_path = 'configs/visual_embedding/dinov2_small.json'
# embedding_config_path = 'configs/visual_embedding/clip_resnet50.json'
# embedding_config_path = 'configs/visual_embedding/clip_vit_l_14_336.json'
embedding_config_path = 'configs/visual_embedding/clip_vit_b_32.json'

# Read the JSON through a context manager so the file handle is closed right away
# instead of lingering until garbage collection.
with open(embedding_config_path) as config_file:
    config = json.load(config_file)

# - - - - - - - - - - - - - - - - - - - - - - - - -
# Alternative: define the embedding config inline instead of loading a JSON file.
#
# config = {
#     "token_roi_resolution": 16,
#
#     # https://huggingface.co/models?sort=trending&search=facebook%2Fdinov2
#     # "vision_encoder_type": "hf_autobacbone",
#     # "vision_encoder_name": "facebook/dinov2-large", # small, base, large, giant
#
#     # https://huggingface.co/models?search=facebook/convnextv2
#     # "vision_encoder_type": "hf_autobacbone",
#     # "vision_encoder_name": "facebook/convnextv2-tiny-22k-384/stage3",
#
#     # https://huggingface.co/models?search=microsoft/resnet
#     # "vision_encoder_type": "hf_autobacbone",
#     # "vision_encoder_name": "microsoft/resnet-50", # 18, 34, 50, 101
#
#     # https://huggingface.co/timm
#     "vision_encoder_type": "timm_backbone",
#     "vision_encoder_name": "tf_mobilenetv3_small_minimal_100.in1k/all",
#     # "vision_encoder_name": "convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_320/all",
# }

# Wrap the plain dict in a PretrainedConfig and attach the resolution settings
# chosen in the tokenizer cell before building the embedding module.
visual_token_embedding_config = PretrainedConfig.from_dict(config)
visual_token_embedding_config.image_resolution = embedding_input_resolution
visual_token_embedding_config.output_resolution = tokenizer_input_resolution
# Build the module and move it to the GPU.
visual_token_embedding = VisualTokenEmbedding(visual_token_embedding_config).cuda()

# Quick sanity check of where the module lives and how wide its features are.
print(visual_token_embedding.device, visual_token_embedding.dtype)
print(visual_token_embedding.vision_encoder.feature_channels, 'channels')

In [None]:
n_samples = 5

# Square target size shared by the image resize and the feature upsampling.
target_size = (tokenizer_input_resolution, tokenizer_input_resolution)

# Draw n_samples independent random items (repeats are possible) and resize each one.
images = [
    dataset[random.randint(0, len(dataset) - 1)]['image'].resize(target_size)
    for _ in range(n_samples)
]

# Encoder output, plus a bilinearly upsampled copy at the image resolution.
feature_maps = visual_token_embedding.vision_encoder(images)
feature_maps_upsampled = F.interpolate(feature_maps, size=target_size, mode='bilinear')

print(feature_maps.shape, feature_maps.dtype, feature_maps.device)
print(feature_maps_upsampled.shape, feature_maps_upsampled.dtype, feature_maps_upsampled.device)

# From here on the notebook works with NumPy arrays on the host.
feature_maps, feature_maps_upsampled = (
    feature_maps.cpu().numpy(),
    feature_maps_upsampled.cpu().numpy(),
)

In [6]:
from sklearn.decomposition import PCA
import numpy as np

def apply_pca(feature_maps):
    """Convert a batch of feature maps into channel-last images scaled to [0, 1].

    When the maps have more than three channels, one 3-component PCA is fitted
    over every spatial position of every map in the batch, so the resulting
    colours are consistent within the batch. Maps with three or fewer channels
    skip the PCA and are only transposed. The output is min-max scaled with the
    global minimum and maximum of the whole batch.

    Args:
        feature_maps: NumPy array of shape (N, C, H, W).

    Returns:
        NumPy array of shape (N, H, W, 3) if C > 3, otherwise (N, H, W, C),
        with values in [0, 1]. A constant input yields all zeros rather than
        NaNs from a zero-width value range.

    Raises:
        ValueError: if ``feature_maps`` is not 4-dimensional (shape unpacking fails).
    """
    N, C, H, W = feature_maps.shape

    # (N, C, H, W) -> (N, H, W, C): channel-last layout used for plotting.
    channel_last = np.moveaxis(feature_maps, 1, -1)

    if C > 3:
        # One row per pixel, one column per channel: (N*H*W, C).
        pixel_features = channel_last.reshape(-1, C)
        projected = PCA(n_components=3).fit_transform(pixel_features)  # (N*H*W, 3)
        # Rows are ordered (n, h, w), so they fold straight back into images.
        channel_last = projected.reshape(N, H, W, 3)

    lowest = np.min(channel_last)
    value_range = np.max(channel_last) - lowest
    # A constant input has zero range; divide by 1 instead so the result is
    # all zeros rather than NaN (0 / 0).
    if not value_range > 0:
        value_range = 1
    return (channel_last - lowest) / value_range

# Project the encoder features and their upsampled copy to displayable images.
# Each call fits its own PCA, so colours are not guaranteed to match between the two results.
feature_maps_rgb = apply_pca(feature_maps)
feature_maps_upsampled_rgb = apply_pca(feature_maps_upsampled)

In [None]:
import matplotlib.pyplot as plt
# Tokenize the sampled images; keep the raw output for the embedding step
# and a NumPy copy for plotting.
batch_masks = visual_tokenizer(images)
batch_masks_cpu = batch_masks.cpu().numpy()

display_size = (tokenizer_input_resolution, tokenizer_input_resolution)

for b, image in enumerate(images):
    masks = batch_masks_cpu[b]

    # Number of masks whose pixel sum is positive, i.e. the non-empty ones.
    mask_sum = (masks.sum(axis=(1, 2)) > 0).sum()
    print(mask_sum)

    image = image.resize(display_size)

    # One row of four panels: image | masks | PCA features | upsampled PCA features.
    plt.figure(figsize=(40, 10))

    plt.subplot(1, 4, 1)
    plt.imshow(image)

    plt.subplot(1, 4, 2)
    mask_overlay = visualize_masks(image, masks[:1024])
    plt.imshow(mask_overlay)

    pca_panels = (feature_maps_rgb[b], feature_maps_upsampled_rgb[b])
    for position, rgb in enumerate(pca_panels, start=3):
        plt.subplot(1, 4, position)
        plt.imshow(rgb)

    plt.show()

In [None]:
# Run the embedding module on the sampled images together with their masks.
roi_boxes, roi_masks, embeddings = visual_token_embedding(images, batch_masks)

# Report the shape of each output before converting it.
print('embeddings', embeddings.shape)
print('roi_boxes', roi_boxes.shape)
print('roi_masks', roi_masks.shape)

# Move all three outputs to host memory as NumPy arrays for plotting.
roi_boxes, embeddings, roi_masks = (
    output.cpu().numpy() for output in (roi_boxes, embeddings, roi_masks)
)

In [None]:
# Embedding dimensions; only needed by the disabled embedding panel below.
C = visual_token_embedding.vision_encoder.feature_channels
token_roi_resolution = visual_token_embedding.config.token_roi_resolution

# Upper bound on how many tokens are inspected per image.
n_tokens_to_show = 6

for b, image in enumerate(images):
    image = image.resize((tokenizer_input_resolution, tokenizer_input_resolution))

    # Never index past the tokens that actually exist for this image.
    n_shown = min(
        n_tokens_to_show,
        len(batch_masks_cpu[b]),
        len(roi_boxes[b]),
        len(roi_masks[b]),
    )

    # One figure per token: image | full-size mask with ROI box | cropped ROI mask.
    for i in range(n_shown):
        plt.figure(figsize=(20, 8))
        plt.subplot(1, 6, 1)
        plt.imshow(image)

        # Token mask with the image faintly overlaid for context.
        plt.subplot(1, 6, 2)
        plt.imshow(batch_masks_cpu[b, i], cmap='inferno')
        plt.imshow(image, alpha=0.2)

        # Scale the box to pixel coordinates.
        # NOTE(review): assumes roi_boxes are normalised (x1, y1, x2, y2) in [0, 1] -- confirm.
        x1, y1, x2, y2 = (roi_boxes[b][i] * tokenizer_input_resolution).astype(int)
        plt.plot([x1, x2, x2, x1, x1], [y1, y1, y2, y2, y1], 'r')
        plt.title(f'ROI [{x1}, {y1}, {x2}, {y2}]')

        # ROI mask, titled with its mean value.
        plt.subplot(1, 6, 3)
        plt.title(f'Mask {np.average(roi_masks[b][i])}')
        plt.imshow(roi_masks[b][i])

        # Disabled panel: PCA view of the token embedding.
        # plt.subplot(1, 6, 4)
        # plt.title('Embedding')
        # embedding = embeddings[b][i]
        # embedding = embedding.reshape(C, token_roi_resolution, token_roi_resolution)

        # # unsqueeze embedding
        # embedding = np.expand_dims(embedding, axis=0)
        # plt.imshow(apply_pca(embedding)[0])
        # # plt.imshow(feature_maps_rgb * roi_masks[b][i][:, :, None])

        plt.show()
