In [1]:
import model.clip as clip
from model.model import ResidualAttentionBlock
import numpy as np
import torch
from torch import nn
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

import os
import skimage
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

from collections import OrderedDict
import torch

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [2]:
import os

# Send both plain and TLS traffic through the same hard-coded HTTP proxy.
# NOTE(review): the address is specific to one network; adjust or remove this
# cell when running elsewhere.
os.environ.update({
    'http_proxy': "http://192.41.170.23:3128",
    'https_proxy': "http://192.41.170.23:3128",
})

In [3]:
# clip.available_models()

In [4]:
# Load the ViT-B/32 checkpoint together with its preprocessing callable,
# move the model to the GPU and switch it to evaluation mode.
model, preprocess = clip.load("ViT-B/32")
model.cuda().eval()

# Cache a few model attributes that are reported below.
vocab_size = model.vocab_size
context_length = model.context_length
input_resolution = model.visual.input_resolution

# Total number of scalar parameters, printed with thousands separators.
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Input resolution: {input_resolution}")
print(f"Context length: {context_length}")
print(f"Vocab size: {vocab_size}")

Model parameters: 151,277,313
Input resolution: 224
Context length: 77
Vocab size: 49408


In [5]:
# Captions keyed by the skimage sample-image file stem (file name without
# extension). Only files whose stem appears here are used later on.
descriptions = dict(
    page="a page of text about segmentation",
    chelsea="a facial photo of a tabby cat",
    astronaut="a portrait of an astronaut with the American flag",
    rocket="a rocket standing on a launchpad",
    motorcycle_right="a red motorcycle standing in a garage",
    camera="a person looking at a camera on a tripod",
    horse="a black-and-white silhouette of a horse",
    coffee="a cup of coffee on a saucer",
)

In [None]:
# Collect the described skimage sample images three ways, index-aligned:
#   original_images - the PIL images as loaded (RGB)
#   images          - the same images after `preprocess`
#   texts           - the matching caption from `descriptions`
original_images = []
images = []
texts = []
plt.figure(figsize=(16, 5))

# os.listdir() returns entries in arbitrary, platform-dependent order. Sorting
# keeps the list indices (used by later cells to pick an image/caption pair)
# stable across machines and runs.
for filename in sorted(os.listdir(skimage.data_dir)):
    if not filename.endswith((".png", ".jpg")):
        continue

    name = os.path.splitext(filename)[0]
    if name not in descriptions:
        continue

    image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB")

    # 2 x 4 preview grid: one panel per described image, captioned with its text.
    plt.subplot(2, 4, len(images) + 1)
    plt.imshow(image)
    plt.title(f"{filename}\n{descriptions[name]}")
    plt.xticks([])
    plt.yticks([])

    original_images.append(image)
    images.append(preprocess(image))
    texts.append(descriptions[name])

plt.tight_layout()


In [None]:
# Batch the preprocessed images along a new leading dimension and move them to
# the GPU. The entries of `images` are torch tensors (a later cell calls
# `.permute` on one of them), so they are stacked directly with torch instead
# of round-tripping through NumPy.
image_input = torch.stack(images).cuda()
# Every caption is prefixed with "This is " before tokenization.
text_tokens = clip.tokenize(["This is " + desc for desc in texts]).cuda()

In [None]:
# Run both encoders without autograd bookkeeping, then cast the results to
# float32 for the similarity computations that follow.
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_tokens)
    image_features = image_features.float()
    text_features = text_features.float()

In [None]:
def text_image_similarity(image_feature, text_feature, num_patches=49):
    """Cosine similarity between one text vector and the rows of an image feature.

    Args:
        image_feature: Indexable tensor of feature vectors for a single image.
            Row 0 is compared separately and reported as the "cls" similarity;
            rows ``1 .. num_patches`` are treated as per-patch vectors.
            NOTE(review): assumes a (tokens, dim) layout with the class token
            first - confirm against the encoder's output.
        text_feature: Indexable tensor whose element ``[0]`` is the text vector
            every image vector is compared against.
        num_patches: Number of rows after row 0 to compare. The default of 49
            reproduces the previously hard-coded slice ``[1:50]``.

    Returns:
        ``(similarity_patches, cls_similarity)`` where ``similarity_patches`` is
        a list of 0-dim CPU tensors (one per patch row, in order) and
        ``cls_similarity`` is a 0-dim CPU tensor for row 0.

    Raises:
        ValueError: If ``num_patches`` is negative.
    """
    if num_patches < 0:
        raise ValueError(f"num_patches must be non-negative, got {num_patches}")

    cos = nn.CosineSimilarity(dim=0)
    text_vector = text_feature[0]

    # Slicing clamps at the end of the tensor, so fewer than `num_patches` rows
    # simply yields a shorter list (same as the original fixed slice).
    patch_vectors = image_feature[1:1 + num_patches]
    similarity_patches = [
        cos(text_vector, patch_vector).cpu() for patch_vector in patch_vectors
    ]

    cls_similarity = cos(image_feature[0], text_vector).cpu()
    return similarity_patches, cls_similarity

def build_similarity_map(similarity_patches, image_size, patch_size):
    """Expand per-patch scores into a pixel-resolution map.

    The scores are laid out row by row on a grid that is
    ``image_size // patch_size`` patches wide, and every score is repeated over
    a ``patch_size x patch_size`` square.

    Args:
        similarity_patches: Sequence of scalar scores (Python numbers or 0-dim
            tensors) in row-major patch order.
        image_size: Width of the target map in pixels.
        patch_size: Side length of one patch in pixels.

    Returns:
        A 2-D CPU tensor in the default floating dtype, with shape
        ``(rows * patch_size, cols * patch_size)`` where
        ``cols = image_size // patch_size`` and ``rows = len(scores) // cols``.

    Raises:
        ValueError: If there are no scores, if ``patch_size`` is not positive or
            larger than ``image_size``, or if the number of scores does not fill
            complete rows.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")

    # Integer division: the column count must be a whole number of patches.
    patches_per_row = image_size // patch_size
    if patches_per_row <= 0:
        raise ValueError(
            f"image_size ({image_size}) must be at least patch_size ({patch_size})"
        )

    # float() accepts both plain numbers and 0-dim tensors on any device.
    scores = torch.tensor([float(score) for score in similarity_patches])
    if scores.numel() == 0:
        raise ValueError("similarity_patches is empty")
    if scores.numel() % patches_per_row != 0:
        raise ValueError(
            f"{scores.numel()} scores do not fill rows of {patches_per_row} patches"
        )

    grid = scores.reshape(-1, patches_per_row)
    # Blow each grid cell up to a patch_size x patch_size block of pixels.
    image_filter = grid.repeat_interleave(patch_size, dim=0).repeat_interleave(
        patch_size, dim=1
    )
    return image_filter

def plot_image_similarity(image, image_filter, text):
    """Draw an image with a translucent similarity map and a colour bar on top.

    Args:
        image: Tensor that is permuted with ``(1, 2, 0)`` before display.
            NOTE(review): presumably (channels, height, width) - confirm.
        image_filter: 2-D array-like drawn over the image with the 'jet'
            colour map at 40% opacity.
        text: Title of the plot.
    """
    channels_last = image.permute(1, 2, 0)
    plt.imshow(channels_last, interpolation='nearest')
    plt.title(text)

    # The colour bar describes the overlay, which is the last image drawn.
    overlay = plt.imshow(image_filter, alpha=0.4, cmap='jet')
    plt.colorbar(overlay)

In [None]:
# Index of the image / caption pair to inspect, and the patch grid geometry.
image_no, text_no = 4, 4
patch_size, image_size = 32, 224

# text_feature_test = text_feature_test
# text_test = "This is a image of stand camera"

# Inputs for the selected pair.
image_test = images[image_no]
image_feature_test = image_features[image_no]
text_test = texts[text_no]
text_feature_test = text_features[text_no]

print(f"Image Discription: {text_test}")

# Per-patch scores plus the score for the first (class) row of the image feature.
similarity_per_patch, cls_similarity = text_image_similarity(
    image_feature_test,
    text_feature_test,
)
print("Cls image similarity:", cls_similarity)

# Upsample the scores to pixel resolution and overlay them on the image.
image_similarity_map = build_similarity_map(
    similarity_per_patch,
    image_size,
    patch_size,
)
plot_image_similarity(image_test, image_similarity_map, text_test)

In [None]:
# Show the raw per-patch similarity scores returned by text_image_similarity
# in the previous cell (one entry per patch).
similarity_per_patch