In [None]:
import io

import clip
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import torch.nn as nn
from multilingual_clip import pt_multilingual_clip
from PIL import Image
from transformers import AutoTokenizer

def display_img(image, title=None):
    """Render an image with its axes hidden.

    Args:
        image: anything ``Axes.imshow`` accepts (e.g. an array or a PIL image).
        title: optional title drawn above the image; omitted when ``None``.
    """
    # Use an explicit, fresh figure so we never draw into whatever axes
    # happen to be "current" from an earlier plotting call.
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.axis('off')
    if title is not None:
        ax.set_title(title)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def score(image_features, text_features, logit_scale=100):
    """Compute scaled cosine-similarity logits between images and captions.

    Args:
        image_features: 2-D tensor of image embeddings, one row per image.
        text_features: 2-D tensor of text embeddings, one row per caption,
            with the same embedding width as ``image_features``.
        logit_scale: multiplier applied to the cosine similarities before
            they are returned (default 100, the previously hard-coded value).

    Returns:
        ``(logits_per_image, logits_per_text)``: float32 tensors of shape
        ``(n_images, n_texts)`` and its transpose ``(n_texts, n_images)``.
    """
    # Cast to float32 *before* normalizing: features may arrive as float16
    # (e.g. CLIP on GPU), and the two inputs may differ in dtype, which the
    # matrix product below would otherwise reject.
    image_features = image_features.type(torch.float32)
    text_features = text_features.type(torch.float32)

    # L2-normalize each row so the dot product is a cosine similarity.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    logits_per_image = logit_scale * image_features @ text_features.t()
    logits_per_text = logits_per_image.t()
    return logits_per_image, logits_per_text


In [None]:
# Run the English CLIP model on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# English CLIP: provides the image encoder (used for both comparisons) and
# the English text encoder, plus the matching image preprocessing transform.
model_name = "ViT-L/14"
print(f'Loading CLIP Model {model_name} ...')
model, preprocess = clip.load(model_name, device=device)

# Multilingual text encoder whose name indicates it pairs with the ViT-L/14
# image encoder above.
# NOTE(review): this model is left on the CPU; presumably its forward()
# tokenizes into CPU tensors, so moving it to `device` alone would cause a
# device mismatch — confirm before changing.
model_name_multilingual = 'M-CLIP/LABSE-Vit-L-14'
print(f'Loading Multilingual CLIP Model {model_name_multilingual} ...')
model_multilingual = pt_multilingual_clip.MultilingualCLIP.from_pretrained(model_name_multilingual)
tokenizer = AutoTokenizer.from_pretrained(model_name_multilingual)


In [None]:
# Get the images
cat_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
dog_url = "https://farm6.staticflickr.com/5119/7204468584_6eed877236_z.jpg"


def load_image_from_url(url, timeout=30):
    """Download ``url`` and return it as a decoded PIL image.

    The whole body is read into memory so the HTTP response is released
    right away (a streamed ``.raw`` handle stays open until PIL lazily
    reads it). HTTP errors raise here instead of surfacing later as an
    opaque image-decoding error. ``timeout`` (seconds) prevents hanging
    forever on an unresponsive server.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content))
    image.load()  # decode now rather than lazily on first use
    return image


cat_image = load_image_from_url(cat_url)
dog_image = load_image_from_url(dog_url)
images = [cat_image, dog_image]
# Set the captions in English and Arabic
captions = ["a photo of a cat", "a photo of a dog", "صورة قطة", "صورة كلب"]

# Inference only: no_grad avoids building autograd graphs for the encoders.
with torch.no_grad():
    # Preprocess the images and get the image embeddings from CLIP
    processed_images = [preprocess(image).unsqueeze(0).to(device) for image in images]
    img = torch.cat(processed_images)
    img_features = model.encode_image(img)

    # Tokenize the captions and get the text embeddings from CLIP
    text = clip.tokenize(captions).to(device)
    text_features = model.encode_text(text)

    # Get the text embeddings from Multilingual-CLIP (it is given the raw
    # captions and the tokenizer). Call the module itself rather than
    # .forward() so any registered nn.Module hooks run.
    text_features_multilingual = model_multilingual(captions, tokenizer).to(device)

    # Similarity between each image and every caption, for both text encoders
    logits_per_image, logits_per_text = score(img_features, text_features)
    logits_per_image_multilingual, logits_per_text_multilingual = score(img_features, text_features_multilingual)

# Softmax over all captions (both languages) for each image
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
probs_multilingual = logits_per_image_multilingual.softmax(dim=-1).cpu().numpy()

# Display the images and the captions with the similarity score from CLIP and Multilingual-CLIP
for i, image in enumerate(images):
    print("Image: ", i + 1)
    display(image)
    for j, caption in enumerate(captions):
        print("Caption: ", caption)
        print(f"CLIP: {probs[i][j] * 100:.2f}%")
        print(f"Multilingual-CLIP: {probs_multilingual[i][j] * 100:.2f}%")
        print("")
