# CLIP-DINOiser visualization demo 🖼️

In [1]:
from hydra.core.global_hydra import GlobalHydra
import os

from markdown_it.rules_inline import image
from torch import Tensor

from models.builder import build_model
from helpers.visualization import mask2rgb
from segmentation.datasets import PascalVOCDataset
from hydra import compose, initialize
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms as T
import torch.nn.functional as F
import numpy as np
from operator import itemgetter 
import torch
from torch.utils.data import DataLoader
import warnings
# Silence every Python warning for the rest of the session so the demo output
# stays readable. NOTE(review): this also hides deprecation warnings.
warnings.filterwarnings('ignore')
# Drop any Hydra state left over from a previous run of this cell;
# presumably needed because `initialize` cannot be called twice in one
# process without clearing first -- TODO confirm against the Hydra docs.
GlobalHydra.instance().clear()
# Point Hydra at the "configs" directory; later `compose(...)` calls
# resolve their config names against it.
initialize(config_path="configs", version_base=None)


  from .autonotebook import tqdm as notebook_tqdm


hydra.initialize()

In [2]:
# def load_support_image(image_path):
#     """Load and preprocess a single support image"""
#     image = Image.open(image_path).convert('RGB')
#     transform = T.Compose([
#         T.Resize((224, 224)),
#         T.ToTensor(),
#         T.Normalize((0.48145466, 0.4578275, 0.40821073), 
#                    (0.26862954, 0.26130258, 0.27577711))
#     ])
#     return transform(image).unsqueeze(0).unsqueeze(0)

# def load_support_images(support_image_paths):
#     """Load multiple support images"""
#     support_images = torch.tensor([])
#     for path in support_image_paths:
#         image_tensor = load_support_image(image_path=path)
#         print(image_tensor.shape)
#         support_images = torch.cat((support_images, image_tensor), dim=1)
    
#     print(support_images.shape)
#     return support_images

### Load and configure a model

In [5]:
from models import MaskClipHead

# Load the model checkpoint onto the CPU regardless of where it was saved;
# the weights are copied into the model further down.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
check_path = './checkpoints/last.pt'
check = torch.load(check_path, map_location='cpu')

# Load the configuration by name through Hydra's compose API.
dinoclip_cfg = "clip_dinoiser.yaml"
cfg = compose(config_name=dinoclip_cfg)

# Set the device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model from the `model` section of the config, using the
# Pascal VOC class names, and move it to the selected device.
model = build_model(cfg.model, class_names=PascalVOCDataset.CLASSES).to(device)

# Ensure the decode_head is an instance of MaskClipHead. When it is not, a
# new MaskClipHead is created that reuses the existing head's `clip_model`.
# NOTE(review): in_channels=3 / text_channels=512 / pretrained tag are
# hard-coded here -- confirm they match the checkpoint being loaded.
if not isinstance(model.clip_backbone.decode_head, MaskClipHead):
    model.clip_backbone.decode_head = MaskClipHead(
        clip_model=model.clip_backbone.decode_head.clip_model,
        class_names=PascalVOCDataset.CLASSES,
        in_channels=3,
        text_channels=512,
        use_templates=False,
        pretrained='laion2b_s34b_b88k'
    ).to(device)

# Switching off the imagenet templates for fast inference
model.clip_backbone.decode_head.use_templates = False

# Load the model state. strict=False means missing or unexpected keys in the
# checkpoint are not raised as errors, so a mismatched checkpoint loads
# without complaint -- inspect the return value when debugging.
model.load_state_dict(check['model_state_dict'], strict=False)
# Put the model in evaluation mode for inference.
model = model.eval()

# Load the Pascal VOC dataset. The image/annotation directories below are
# placeholders and must be replaced with real paths before running.
# The pipeline entries look like mmseg-style transform configs -- TODO confirm.
# NOTE(review): RandomFlip with flip_ratio=0.5 appears to be a training-time
# augmentation; confirm it is wanted when sampling support images for a demo.
dataset = PascalVOCDataset(
    img_dir='path/to/VOCdevkit/VOC2012/JPEGImages',
    ann_dir='path/to/VOCdevkit/VOC2012/SegmentationClass',
    pipeline=[
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations'),
        dict(type='Resize', img_scale=(512, 512), keep_ratio=True),
        dict(type='RandomFlip', flip_ratio=0.5),
        dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True),
        dict(type='Pad', size=(512, 512), pad_val=0, seg_pad_val=255),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_semantic_seg']),
    ],
    split='ImageSets/Segmentation/train.txt'
)
# One sample per batch, drawn in random order.
data_loader = DataLoader(dataset, batch_size=1, shuffle=True)

def select_support_images(data_loader, num_support_images=2):
    """Collect the first ``num_support_images`` image batches from a loader.

    Args:
        data_loader: Iterable yielding ``(img, target)`` pairs where ``img``
            is a tensor; only ``img`` is used.
        num_support_images: Maximum number of batches to take. Values below
            zero are treated as zero.

    Returns:
        The selected ``img`` tensors concatenated along dimension 0.

    Raises:
        RuntimeError: If no batch could be selected (empty loader or a
            non-positive ``num_support_images``).
    """
    from itertools import islice

    # islice stops *before* requesting batch number `limit + 1`; an
    # enumerate/break loop would load (and discard) one extra batch.
    limit = max(num_support_images, 0)
    support_images = [img for img, _ in islice(data_loader, limit)]

    if not support_images:
        # torch.cat on an empty list fails with an opaque message; say why.
        raise RuntimeError(
            "select_support_images: no support images selected "
            f"(num_support_images={num_support_images})"
        )
    return torch.cat(support_images, dim=0)

def load_image(image_path):
    """Read an image file and return it as a normalised batch of one.

    The image is converted to RGB, resized to 224x224, turned into a tensor,
    normalised per channel, and given a leading batch dimension, so the
    result has shape ``(1, 3, 224, 224)``.
    """
    norm_mean = (0.48145466, 0.4578275, 0.40821073)
    norm_std = (0.26862954, 0.26130258, 0.27577711)
    preprocess = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(norm_mean, norm_std),
    ])

    rgb = Image.open(image_path).convert('RGB')
    tensor = preprocess(rgb)
    return tensor.unsqueeze(0)


# Function to visualize per image
def visualize_per_image(file_path, support_image_paths, palette, model):
    """Segment one image with ``model`` and build a colour legend figure.

    Args:
        file_path: Path of the query image.
        support_image_paths: Path, or iterable of paths, of the support
            images. Each one is loaded with ``load_image``.
        palette: Sequence of RGB colours indexed by class id.
        model: Segmentation model. It is called as
            ``model(image, support_images=...)`` and must expose
            ``vit_patch_size``.

    Returns:
        ``(mask, fig, img)``: the colour mask produced by ``mask2rgb``, the
        matplotlib legend figure, and the query image as a PIL image.

    Raises:
        AssertionError: If ``file_path`` is not an existing file.
        ValueError: If no support image path is given.
    """
    assert os.path.isfile(file_path), f"No such file: {file_path}"

    # A lone path would otherwise be iterated character by character.
    if isinstance(support_image_paths, (str, os.PathLike)):
        support_image_paths = [support_image_paths]
    support_image_paths = list(support_image_paths)
    if not support_image_paths:
        raise ValueError("visualize_per_image: at least one support image path is required")

    # Run on whatever device the model lives on; hard-coding "cpu" breaks as
    # soon as the model has been moved to CUDA.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cpu")

    # Load the input image as a float tensor in [0, 1] with a batch dimension
    img = Image.open(file_path).convert('RGB')
    img_tens = (T.PILToTensor()(img).unsqueeze(0) / 255.).to(device)

    # Load the support images and stack them into (1, N, 3, 224, 224).
    # NOTE(review): this layout mirrors the shapes the notebook printed for
    # its earlier support-image loader; confirm it is what the model expects.
    support_images = torch.stack(
        [load_image(path) for path in support_image_paths], dim=1
    ).to(device)

    # Perform inference using the support images; no gradients are needed.
    h, w = img_tens.shape[-2:]
    with torch.no_grad():
        output = model(img_tens, support_images=support_images).cpu()
    # Upsample patch-level scores by the patch size, then crop to the input size.
    output = F.interpolate(output, scale_factor=model.vit_patch_size, mode="bilinear", align_corners=False)[..., :h, :w]
    output = output[0].argmax(dim=0)
    mask = mask2rgb(output, palette)

    # Build the legend: one colour swatch per class present in the prediction
    fig = plt.figure(figsize=(3, 1))
    classes = np.unique(output.numpy()).tolist()
    swatches = np.array([palette[c] for c in classes]).reshape(1, -1, 3)
    plt.imshow(swatches)
    plt.xticks(np.arange(len(classes)), [f"Class {i}" for i in classes], rotation=45)
    plt.yticks([])

    return mask, fig, img

RuntimeError: CLIP_DINOiser: MaskClip: number of dims don't match in permute

### Example with 'background' class

In [None]:
# file = 'assets/vintage_bike.jpeg'
# PALETTE = [(0, 0, 0), (156, 143, 189), (79, 158, 101)]
# 
# # specify your prompts
# TEXT_PROMPTS = ['leather bag']
# model.clip_backbone.decode_head.update_vocab(TEXT_PROMPTS)
# model.to(device)
# 
# # set apply FOUND (background detector) to True
# model.apply_found = True
# 
# # run segmentation
# mask, ticks, img = visualize_per_image(file, TEXT_PROMPTS, PALETTE, model)
# 
# fig, ax = plt.subplots(nrows=1, ncols=2)
# alpha=0.5
# blend = (alpha)*np.array(img)/255. + (1-alpha) * mask/255.
# ax[0].imshow(blend)
# ax[1].imshow(mask)
# ax[0].axis('off')
# ax[1].axis('off')

### Example without 'background' class

In [None]:
# Example usage
file = 'assets/vintage_bike.jpeg'
PALETTE = [[25, 29, 136], [128, 112, 112], [85, 124, 85], [250, 112, 112], [250, 250, 0], [250, 0, 0]]
# Number of support images to select from the dataset.
# Not used by this cell: the support images come from the explicit paths below.
num_support_images = 2

# Paths to the support images
support_image_paths = [
    "assets/bike.jpeg",
    "assets/bag.jpeg",
]

# # Load the support images
# support_images_tensor = load_support_images(support_image_paths)

# # add the query image
# support_images_tensor = torch.cat((support_images_tensor, load_support_image(file).to("cpu")), dim=1)
# print(f'Support images tensor shape: {support_images_tensor.shape}')

# Specify whether applying FOUND or not
model.apply_found = True

# Run segmentation. The arguments must be the names defined in this cell
# (`file`, `support_image_paths`, `PALETTE`); the legend figure gets its own
# name so the side-by-side figure below does not overwrite it.
mask, legend_fig, img = visualize_per_image(file, support_image_paths, PALETTE, model)


# Display the results: blended overlay on the left, raw mask on the right
fig, ax = plt.subplots(nrows=1, ncols=2)
alpha = 0.5
blend = (alpha) * np.array(img) / 255. + (1 - alpha) * mask / 255.
ax[0].imshow(blend)
ax[1].imshow(mask)
ax[0].axis('off')
ax[1].axis('off')
plt.show()

torch.Size([1, 1, 3, 224, 224])
torch.Size([1, 1, 3, 224, 224])
torch.Size([1, 2, 3, 224, 224])
Support images tensor shape: torch.Size([1, 3, 3, 224, 224])


NameError: name 'model' is not defined