In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch

from transformers import PretrainedConfig
from model.visual_token_embedding import VisualTokenEmbedding
from utils.visualization import visualize_masks
from utils.logging import get_params_count_summary
from tokenizer.directsam import DirectSAMTokenizer


In [None]:

# Side length (in pixels) used for both the tokenizer and the embedding config.
image_resolution = 1024

# Gather the tokenizer settings in one place, then build the tokenizer from them.
_tokenizer_settings = dict(
    ckpt="chendelong/DirectSAM-tiny-distilled-30ep-plus-50ep-1024px-0910",
    threshold=0.1,
    image_resolution=image_resolution,
    max_tokens=128,
    device="cuda",
)
visual_tokenizer = DirectSAMTokenizer(**_tokenizer_settings)

In [None]:
# Settings for the visual token embedding. Exactly one
# ("vision_encoder_type", "vision_encoder_name") pair is active at a time;
# the commented-out pairs are alternatives that can be swapped in.
_embedding_settings = {
    "image_resolution": image_resolution,
    "num_heads": 8,
    "embedding_dim": 384,

    # "vision_encoder_type": "rgb_pixel",
    # "vision_encoder_name": "rgb_pixel",

    # "vision_encoder_type": "diffusers_vae",
    # "vision_encoder_name": "chendelong/stable-diffusion-3-medium-vae",

    # https://huggingface.co/models?sort=trending&search=facebook%2Fdinov2
    # "vision_encoder_type": "hf_autobacbone",
    # "vision_encoder_name": "facebook/dinov2-giant", # small, base, large, giant

    # https://huggingface.co/models?search=facebook/convnextv2
    "vision_encoder_type": "hf_autobacbone",
    "vision_encoder_name": "facebook/convnextv2-tiny-22k-384/stage2",

    # # https://huggingface.co/models?search=microsoft/resnet
    # "vision_encoder_type": "hf_autobacbone",
    # "vision_encoder_name": "microsoft/resnet-50", # 18, 34, 50, 101

    # https://huggingface.co/timm
    # "vision_encoder_type": "timm_backbone",
    # "vision_encoder_name": "convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_320/-2",
}
visual_token_embedding_config = PretrainedConfig.from_dict(_embedding_settings)

# Build the embedding module from the config and move it to the GPU.
visual_token_embedding = VisualTokenEmbedding(visual_token_embedding_config)
visual_token_embedding = visual_token_embedding.cuda()

# Report where the module lives, its dtype, and the encoder's channel count.
print(visual_token_embedding.device, visual_token_embedding.dtype)
print(visual_token_embedding.vision_encoder.feature_channels, "channels")

In [None]:
# Fetch a sample image and run it through the vision encoder.
from io import BytesIO

from PIL import Image
import requests

url = "http://images.cocodataset.org/val2017/000000039769.jpg"

# The timeout keeps this cell from hanging forever on a stalled connection,
# and raise_for_status() reports an HTTP failure here rather than letting it
# surface later as an opaque "cannot identify image file" error from PIL.
response = requests.get(url, timeout=30)
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert("RGB")

# The features are only used for visualisation, so no autograd graph is needed.
with torch.no_grad():
    feature_maps = visual_token_embedding.vision_encoder([image])
print(feature_maps.shape, feature_maps.dtype, feature_maps.device)

# .detach() ensures the NumPy conversion cannot fail on a tensor that
# requires grad.
feature_maps = feature_maps.detach().cpu().numpy()

In [None]:
from sklearn.decomposition import PCA
import numpy as np

def apply_pca(feature_maps):
    """Reduce channel-first feature maps to 3 channels for RGB display.

    Args:
        feature_maps: NumPy array laid out as (N, C, H, W); axis 1 is the
            feature/channel axis.

    Returns:
        If C > 3: an array of shape (N, 3, H, W) holding the features
        projected onto the 3 leading PCA directions, rescaled jointly (one
        global min/max over all images and components) to [0, 1]. If the
        projected values are all identical, an all-zero array is returned
        instead of dividing by zero.
        If C <= 3: the input array itself, unchanged and not rescaled.
    """
    n_components = 3

    # Nothing to reduce when there are already at most 3 channels.
    if feature_maps.shape[1] <= n_components:
        return feature_maps

    # (N, C, H, W) -> (N, W, H, C): put the feature axis last so every
    # spatial location becomes one PCA sample.
    channels_last = np.swapaxes(feature_maps, -1, 1)
    X = np.reshape(channels_last, (-1, channels_last.shape[-1]))

    pca = PCA(n_components=n_components)
    pca.fit(X)
    projection = pca.components_  # (n_components, feature_dim)

    # Project onto the principal directions. The features are multiplied by
    # components_ directly, i.e. without subtracting pca.mean_ first, so this
    # is not the same as pca.transform.
    projected = np.matmul(channels_last, projection.T)  # (N, W, H, 3)
    projected = np.swapaxes(projected, -1, 1)  # (N, 3, H, W)

    # Rescale to [0, 1]; guard the constant case to avoid a 0/0 -> NaN image.
    lowest = np.min(projected)
    value_range = np.max(projected) - lowest
    if value_range == 0:
        return np.zeros_like(projected)
    return (projected - lowest) / value_range


In [None]:
import matplotlib.pyplot as plt

# Side-by-side view: the input image (left) and its PCA-coloured feature map (right).
figure, (image_ax, feature_ax) = plt.subplots(1, 2, figsize=(20, 10))

image_ax.imshow(image.resize((512, 512)))
image_ax.axis("off")

# Take the first image's result and move its leading axis to the end,
# which is where imshow expects the colour channels.
feature_maps_rgb = np.transpose(apply_pca(feature_maps)[0], (1, 2, 0))
feature_ax.imshow(feature_maps_rgb)
feature_ax.axis("off")

plt.show()


In [None]:
# Resize the image to the configured square resolution, then run the visual tokenizer on it.
target_size = (image_resolution, image_resolution)
image = image.resize(target_size)
batch_masks = visual_tokenizer(image)

# Left panel: the resized image. Right panel: the mask visualisation.
figure, (image_ax, mask_ax) = plt.subplots(1, 2, figsize=(20, 10))
image_ax.imshow(image)

# Only the first 64 entries for the first image are drawn; the title shows
# how many entries there are in total.
first_image_masks = batch_masks[0]
mask_ax.imshow(visualize_masks(image, first_image_masks[:64]))
mask_ax.set_title(len(first_image_masks))
plt.show()

In [None]:
# Run the embedding module on the image (wrapped in a one-element list) together
# with the tokenizer output computed in the previous cell.
embeddings = visual_token_embedding([image], batch_masks)