<a href="https://colab.research.google.com/github/FreddeFrallan/Multilingual-CLIP/blob/main/Multilingual_CLIP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multilingual CLIP

## Install Packages and Setup
Run this cell to install the dependencies and configure everything; it may take a few minutes.

In [None]:
#@title  { display-mode: "code" }

# Environment setup (Colab): install a PyTorch build matching the runtime's CUDA
# toolkit plus the other dependencies, then clone Multilingual-CLIP and run its
# setup script. Lines starting with `!` / `%` are IPython shell/magic commands,
# so this cell only runs inside Jupyter/Colab.

import subprocess
# Extract the toolkit version from `nvcc --version`: take the ", "-separated
# field that starts with "release" and keep its last word (e.g. "release 11.0"
# -> "11.0"). NOTE(review): check_output raises if nvcc is not on PATH (e.g. a
# CPU-only runtime), and the [0] raises IndexError if no "release" field exists.
CUDA_version = [s for s in subprocess.check_output(["nvcc", "--version"]).decode("UTF-8").split(", ") if s.startswith("release")][0].split(" ")[-1]
print("CUDA version:", CUDA_version)

# Choose the local-version suffix for the torch/torchvision wheels: 10.0 and
# 10.1 get explicit suffixes, 10.2 gets none, and any other version falls back
# to the CUDA 11.0 ("+cu110") build.
if CUDA_version == "10.0":
    torch_version_suffix = "+cu100"
elif CUDA_version == "10.1":
    torch_version_suffix = "+cu101"
elif CUDA_version == "10.2":
    torch_version_suffix = ""
else:
    torch_version_suffix = "+cu110"

# NOTE(review): torch 1.7.1 / torchvision 0.8.2 are pinned, while transformers
# and CLIP (below) are installed unpinned — verify these still resolve together
# on the current Colab image.
!pip install torch==1.7.1{torch_version_suffix} torchvision==0.8.2{torch_version_suffix} -f https://download.pytorch.org/whl/torch_stable.html ftfy regex
!pip install ftfy==5.8
!pip install transformers

import os
import random

from PIL import Image
import numpy as np
import torch

import warnings
# Global filter: ignore all Python warnings from here on.
warnings.filterwarnings("ignore")

# OpenAI's CLIP package (provides `clip.load`, used in a later cell).
!pip install git+https://github.com/openai/CLIP.git
import clip

# Clone the repository and move into it so `from src import multilingual_clip`
# resolves in the next cell. get-weights.sh presumably downloads the pretrained
# text-encoder weights — TODO confirm.
!git clone https://github.com/FreddeFrallan/Multilingual-CLIP
%cd Multilingual-CLIP
!bash get-weights.sh

### Load The Multilingual Text Encoder

In [None]:
import src.multilingual_clip as multilingual_clip

# Pretrained multilingual text encoder. 'M-BERT-Distil-40' selects the checkpoint;
# presumably one of the weights fetched by get-weights.sh in the setup cell — verify.
text_model = multilingual_clip.load_model('M-BERT-Distil-40')

### Load The Matching CLIP Model

In [None]:
# CLIP model whose image embeddings the multilingual text encoder is compared
# against; `compose` is the image-preprocessing callable returned by clip.load.
clip_model, compose = clip.load('RN50x4')

# CLIP is installed unpinned from git, so support both model flavours:
# TorchScript builds expose these as 0-d tensor attributes on the model, while
# eager nn.Module builds store plain ints, with the image size on
# `clip_model.visual`. int() accepts either form.
if hasattr(clip_model, "input_resolution"):
    input_resolution = int(clip_model.input_resolution)
else:
    input_resolution = int(clip_model.visual.input_resolution)
context_length = int(clip_model.context_length)
vocab_size = int(clip_model.vocab_size)

print("Model parameters:", f"{sum(p.numel() for p in clip_model.parameters()):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

### Prepare Image Function

In [None]:
def prepare_images(compose, img_paths, device):
  """Load image files and turn each into a tensor with a leading batch axis.

  Args:
    compose: Transform applied to each PIL image (e.g. the preprocessing
      callable returned by `clip.load`).
    img_paths: Iterable of image file paths.
    device: Device the resulting tensors are moved to (e.g. "cuda" or "cpu").

  Returns:
    A list with one tensor per path, in input order, each given a leading
    dimension of size 1 by `unsqueeze(0)` and moved to `device`.
  """
  tensors = []
  for p in img_paths:
    # Read the pixels and release the file handle right away; converting to
    # RGB gives grayscale or alpha-channel images the three bands CLIP expects.
    with Image.open(p) as img:
      rgb = img.convert("RGB")
    # Apply the caller-supplied transform, not a notebook-global one.
    tensors.append(compose(rgb).unsqueeze(0).to(device))
  return tensors

### Read in the Images

In [None]:
from PIL import Image
import os

# Demo images bundled with the cloned repository: display name -> file name
# inside `main_path`.
# NOTE(review): absolute Colab path; assumes the repo was cloned under /content.
main_path = '/content/Multilingual-CLIP/Images/'
demo_images = {
    'Green Apple': 'green apple.jpg',
    'Red Apple': 'red apple.jpg',
    'Purple Apple': 'purple apple.png',
    'Orange Apple': 'Orange Apple.png',
    'Happy Person': 'happy person.jpg',
    'Sad Person': 'sad.jpg',
}

# Check the files this notebook expects (not merely whatever the directory
# contains) and name each one, so a missing or renamed file is easy to spot.
for name, filename in demo_images.items():
    path = os.path.join(main_path, filename)
    status = 'found' if os.path.isfile(path) else 'MISSING'
    print(f"{name}: {path} -> {status}")

In [None]:
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor

# Display-friendly preprocessing: bicubic resize and center crop to CLIP's input
# resolution, then conversion to a tensor. Normalization is intentionally left
# out so the tensors can still be plotted; it is applied by hand with
# `image_mean` / `image_std` right before the images are encoded.
preprocess = Compose(
    [
        Resize(input_resolution, interpolation=Image.BICUBIC),
        CenterCrop(input_resolution),
        ToTensor(),
    ]
)

# Per-channel mean / std used to normalize image tensors before CLIP encoding.
image_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device="cuda")
image_std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device="cuda")

In [None]:
# Textual descriptions of skimage sample images. Keys are image file names
# without extension; only files whose stem appears in `descriptions` are used.
# The English set is kept for reference; the Swedish set below is what the
# notebook actually feeds to the multilingual text encoder.
descriptions_en = {
    "page": "a page of text about segmentation",
    "chelsea": "a facial photo of a tabby cat",
    "astronaut": "a portrait of an astronaut with the American flag",
    "rocket": "a rocket standing on a launchpad",
    "motorcycle_right": "a red motorcycle standing in a garage",
    "camera": "a person looking at a camera on a tripod",
    "horse": "a black-and-white silhouette of a horse",
    "coffee": "a cup of coffee on a saucer"
}

descriptions = {
    "page": "en sida med text om segmentering",
    "chelsea": "ett porträttfoto på en randig katt",
    "astronaut": "ett porträtt av en astronaut med den amerikanska flaggan",
    "rocket": "en raket på sin uppskjutningsplats",
    "motorcycle_right": "en röd motorcykel i ett garage",
    "camera": "en person som tittar på en kamera på ett stativ",
    # Names the subject, like the English original; without "av en häst" the
    # prompt would not mention a horse at all.
    "horse": "en svartvit siluett av en häst",
    "coffee": "en kopp kaffe och ett fat"
}

In [None]:
import matplotlib.pyplot as plt
# scikit-image: provides `skimage.data_dir`. The cell previously raised
# NameError because this import was missing.
import skimage

images = []
texts = []
plt.figure(figsize=(16, 5))

# Walk skimage's bundled sample images in sorted (deterministic) order and keep
# only those with a description; `images[k]` and `texts[k]` stay aligned.
for filename in sorted(os.listdir(skimage.data_dir)):
    name, ext = os.path.splitext(filename)
    if ext not in (".png", ".jpg") or name not in descriptions:
        continue

    # Convert to RGB while the file is open, then release the handle.
    with Image.open(os.path.join(skimage.data_dir, filename)) as raw:
        image = preprocess(raw.convert("RGB"))
    images.append(image)
    texts.append(descriptions[name])

    plt.subplot(2, 4, len(images))
    plt.imshow(image.permute(1, 2, 0))
    plt.title(f"{filename}\n{descriptions[name]}")
    plt.xticks([])
    plt.yticks([])

plt.tight_layout()
plt.show()

NameError: ignored

<Figure size 1152x360 with 0 Axes>

In [None]:
# Batch the preprocessed images on the GPU, then apply CLIP's per-channel
# normalization, broadcasting mean/std over the height and width axes.
image_input = torch.stack(images).cuda()
image_input = (image_input - image_mean[:, None, None]) / image_std[:, None, None]

In [None]:
# Turn each Swedish description into a sentence ("Det här är ..." = "This is ...").
# Items that already carry the prefix are left alone, so re-running this cell
# does not prepend it twice.
PROMPT_PREFIX = "Det här är "
texts = [desc if desc.startswith(PROMPT_PREFIX) else PROMPT_PREFIX + desc for desc in texts]

In [None]:
# Embed the images with CLIP's image encoder and the Swedish prompts with the
# multilingual text encoder (`text_model`, loaded earlier; `sweclip` was never
# defined). Inference only, so no autograd graph is built.
with torch.no_grad():
    image_features = clip_model.encode_image(image_input).float()
    text_features = text_model(texts).float()

In [None]:
# L2-normalize both embedding sets so the dot products below are cosine
# similarities; `similarity[t, i]` compares text t with image i.
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
similarity = np.matmul(text_features.cpu().numpy(), image_features.cpu().numpy().T)

In [None]:
# Heatmap of text-vs-image cosine similarity (rows: prompts in `texts`,
# columns: `images`), with the images drawn as a strip above the matrix.
# Size from the pairs actually collected, not from `descriptions`: a description
# with no matching file in skimage.data_dir yields no image/text pair.
count = len(texts)

plt.figure(figsize=(20, 14))
plt.imshow(similarity, vmin=0.1, vmax=0.3)
plt.yticks(range(count), texts, fontsize=18)
plt.xticks([])
for i, image in enumerate(images):
    plt.imshow(image.permute(1, 2, 0), extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
for x in range(similarity.shape[1]):
    for y in range(similarity.shape[0]):
        plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)

for side in ["left", "top", "right", "bottom"]:
    plt.gca().spines[side].set_visible(False)

# Extend the y-range upward (to -2) to make room for the image strip.
plt.xlim([-0.5, count - 0.5])
plt.ylim([count + 0.5, -2])

plt.title("Cosine similarity between text and image features", size=20)
plt.show()

In [None]:
from torchvision.datasets import CIFAR100

# CIFAR-100, downloaded into ~/.cache when not already present and wrapped with
# the display `preprocess` transform.
# NOTE(review): `cifar100` is not used by the executable code visible here; the
# Swedish class list below is hand-written.
cifar100 = CIFAR100(
    root=os.path.expanduser("~/.cache"),
    transform=preprocess,
    download=True,
)

In [None]:
# Swedish names for the 100 CIFAR-100 classes, used to build the zero-shot
# prompts. Intended to be index-aligned with `cifar100.classes` — NOTE(review):
# keep the order when editing. Fixed several literal mistranslations of
# ambiguous English words: rose -> 'ros' (not the verb 'reste sig'),
# plain -> 'slätt', ray -> 'rocka', shrew -> 'näbbmus', table -> 'bord',
# tank -> 'stridsvagn', sweet_pepper -> 'paprika', orange -> 'apelsin'.
# All names are lowercase for consistency.
swe_cifar_classes = ['äpple', 'akvariefisk', 'bebis', 'björn', 'bäver', 'säng',
                     'bi', 'skalbagge', 'cykel', 'flaska', 'skål', 'pojke',
                     'bro', 'buss', 'fjäril', 'kamel', 'burk', 'slott', 'larv',
                     'nötkreatur', 'stol', 'schimpans', 'klocka', 'moln',
                     'kackerlacka', 'soffa', 'krabba', 'krokodil', 'kopp',
                     'dinosaurie', 'delfin', 'elefant', 'plattfisk', 'skog',
                     'räv', 'flicka', 'hamster', 'hus', 'känguru',
                     'tangentbord', 'lampa', 'gräsklippare', 'leopard', 'lejon',
                     'ödla', 'hummer', 'man', 'lönnträd', 'motorcykel', 'fjäll',
                     'mus', 'svamp', 'ekträd', 'apelsin', 'orkidé', 'utter',
                     'palmträd', 'päron', 'pickup', 'tall', 'slätt', 'tallrik',
                     'vallmo', 'piggsvin', 'opossum', 'kanin', 'tvättbjörn',
                     'rocka', 'väg', 'raket', 'ros', 'hav', 'säl', 'haj',
                     'näbbmus', 'skunk', 'skyskrapa', 'snigel', 'orm', 'spindel',
                     'ekorre', 'spårvagn', 'solros', 'paprika', 'bord',
                     'stridsvagn', 'telefon', 'tv', 'tiger', 'traktor', 'tåg', 'öring',
                     'tulpan', 'sköldpadda', 'garderob', 'val', 'pilträd', 'varg',
                     'kvinna', 'mask']

In [None]:
# One Swedish prompt per class: "Det här är ett foto på <klass>"
# ("This is a photo of <class>").
text_descriptions = []
for label in swe_cifar_classes:
    text_descriptions.append(f"Det här är ett foto på {label}")

In [None]:
# Zero-shot classification: embed one Swedish prompt per class with the
# multilingual text encoder (`text_model`; `sweclip` was never defined),
# L2-normalize, and score every image against every prompt.
with torch.no_grad():
    text_features = text_model(text_descriptions).float().cpu()
    text_features /= text_features.norm(dim=-1, keepdim=True)

# Both operands are on the CPU, so the probabilities are too. Scaling by 100
# sharpens the softmax into a per-image distribution over the classes.
text_probs = (100.0 * image_features.cpu() @ text_features.T).softmax(dim=-1)
top_probs, top_labels = text_probs.topk(5, dim=-1)

In [None]:
plt.figure(figsize=(16, 16))

# Two panels per image on a 4x4 grid: the image itself, then a horizontal bar
# chart of its highest-probability classes (Swedish labels), best at the top.
for idx, img in enumerate(images):
    img_slot = 2 * idx + 1
    plt.subplot(4, 4, img_slot)
    plt.imshow(img.permute(1, 2, 0))
    plt.axis("off")

    plt.subplot(4, 4, img_slot + 1)
    ranks = np.arange(top_probs.shape[-1])
    label_names = [swe_cifar_classes[k] for k in top_labels[idx].numpy()]
    plt.grid()
    plt.barh(ranks, top_probs[idx])
    ax = plt.gca()
    ax.invert_yaxis()
    ax.set_axisbelow(True)
    plt.yticks(ranks, label_names)
    plt.xlabel("probability")

plt.subplots_adjust(wspace=0.5)
plt.show()