In [1]:
import os

import open_clip
import torch
from PIL import Image

In [2]:
# Sanity check: is a CUDA device visible to PyTorch?
torch.cuda.is_available()

True

In [3]:
class OpenCLIPEmbedder:
    """Thin wrapper around an OpenCLIP model that returns NumPy embeddings.

    Both ``encode_*`` methods return one embedding row per input, L2-normalised
    by default so that a dot product between rows is a cosine similarity.
    """

    def __init__(self, model_name="ViT-L-14", pretrained="laion2b_s32b_b82k", device=None):
        """Load the model, its eval preprocessing transform and its tokenizer.

        Args:
            model_name: OpenCLIP architecture name (also selects the tokenizer).
            pretrained: Pretrained-weights tag passed to OpenCLIP.
            device: Torch device string; defaults to "cuda" when available,
                otherwise "cpu".
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
        self.tokenizer = open_clip.get_tokenizer(model_name)
        # OpenCLIP hands the model back in training mode; switch to eval so
        # dropout / stochastic-depth layers are disabled for inference.
        self.model = self.model.to(self.device).eval()

    @staticmethod
    def _load_image(path):
        """Open an image file, decode it fully as RGB, and close the file handle."""
        # Image.open is lazy and keeps the file open; convert() forces the decode
        # so the context manager can release the handle immediately.
        with Image.open(path) as img:
            return img.convert("RGB")

    @staticmethod
    def _normalize(features):
        """Scale each row to unit L2 norm; all-zero rows stay zero instead of NaN."""
        return torch.nn.functional.normalize(features, dim=-1)

    def encode_text(self, texts, normalize=True):
        """Embed a caption or a sequence of captions.

        Args:
            texts: A single string, or a list/tuple of strings.
            normalize: If True, L2-normalise each embedding row.

        Returns:
            np.ndarray with one row per input text.
        """
        if isinstance(texts, str):
            texts = [texts]

        tokens = self.tokenizer(list(texts)).to(self.device)
        with torch.no_grad():
            features = self.model.encode_text(tokens)
            if normalize:
                features = self._normalize(features)
        return features.cpu().numpy()

    def encode_image(self, images, normalize=True):
        """Embed an image or a sequence of images.

        Args:
            images: A file path (str or os.PathLike), a PIL image, or a
                list/tuple whose items are any mix of those two forms.
            normalize: If True, L2-normalise each embedding row.

        Returns:
            np.ndarray with one row per input image.
        """
        if isinstance(images, (str, os.PathLike, Image.Image)):
            images = [images]

        # Paths are opened here (and their file handles closed); PIL images pass through.
        pil_images = [
            self._load_image(img) if isinstance(img, (str, os.PathLike)) else img
            for img in images
        ]
        batch = torch.stack([self.preprocess(img) for img in pil_images]).to(self.device)
        with torch.no_grad():
            features = self.model.encode_image(batch)
            if normalize:
                features = self._normalize(features)
        return features.cpu().numpy()


In [4]:
from pathlib import Path
import os
from datasets import load_dataset

In [5]:
# Dataset cache lives under <two levels above the working dir>/local_only/data.
DATASET_NAME = "nlphuji/flickr30k"
CACHE_DIR = str(Path.cwd().resolve().parents[1] / "local_only" / "data")
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)

In [6]:
def download_flickr30k_dataset(dataset_name, download_location=None):
    """Fetch a Hugging Face dataset through ``datasets.load_dataset``.

    Args:
        dataset_name: Hub identifier, e.g. "nlphuji/flickr30k".
        download_location: Directory passed as ``cache_dir``; None uses the
            library's default cache.

    Returns:
        Whatever ``load_dataset`` returns (indexed by split name later in
        this notebook).
    """
    return load_dataset(dataset_name, cache_dir=download_location)

In [7]:
# Download (or reuse the cached copy of) Flickr30k into CACHE_DIR.
dataset = download_flickr30k_dataset(DATASET_NAME, download_location=CACHE_DIR)

In [8]:
# NOTE(review): records in this split carry their own 'split' field (the data[2]
# record shown below reports 'train'), so this appears to hold the whole corpus
# rather than a held-out test set — filter on that column if a true test split
# is needed (confirm against the dataset card).
data = dataset['test']

In [9]:
# Inspect one record: a PIL image, its list of captions, and metadata fields.
data[2]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=375x500>,
 'caption': ['A child in a pink dress is climbing up a set of stairs in an entry way.',
  'A little girl in a pink dress going into a wooden cabin.',
  'A little girl climbing the stairs to her playhouse.',
  'A little girl climbing into a wooden playhouse.',
  'A girl going into a wooden building.'],
 'sentids': ['10', '11', '12', '13', '14'],
 'split': 'train',
 'img_id': '2',
 'filename': '1000268201.jpg'}

In [10]:
# Let the class choose "cuda" when available and fall back to "cpu" otherwise,
# so Restart & Run All also works on CPU-only machines.
embed_model = OpenCLIPEmbedder()

open_clip_pytorch_model.bin:   0%|          | 0.00/1.71G [00:00<?, ?B/s]

In [17]:
record = data[2]
texts = record["caption"]        # reference captions for this record
image = [record["image"]]        # one-element list, as passed to encode_image below

In [18]:
# Captions for the selected record.
texts

['A child in a pink dress is climbing up a set of stairs in an entry way.',
 'A little girl in a pink dress going into a wooden cabin.',
 'A little girl climbing the stairs to her playhouse.',
 'A little girl climbing into a wooden playhouse.',
 'A girl going into a wooden building.']

In [19]:
%%time
text_embeddings = embed_model.encode_text(texts)

CPU times: user 328 ms, sys: 11.4 ms, total: 339 ms
Wall time: 443 ms


In [21]:
# Embed the record's image (one row, L2-normalised).
image_embeddings = embed_model.encode_image(image, normalize=True)

In [25]:
import numpy as np

# One image vector against N caption vectors. The embeddings are L2-normalised,
# so the matrix-vector product gives cosine similarities, one per caption.
image_vec = image_embeddings[0]
text_vecs = text_embeddings
similarities = text_vecs @ image_vec

# Walk the captions from most to least similar.
sorted_indices = np.argsort(similarities)[::-1]
for idx in sorted_indices:
    print(f"Caption: {texts[idx]} | Similarity: {similarities[idx]:.4f}")


Caption: A little girl climbing into a wooden playhouse. | Similarity: 0.3564
Caption: A little girl climbing the stairs to her playhouse. | Similarity: 0.3455
Caption: A little girl in a pink dress going into a wooden cabin. | Similarity: 0.3359
Caption: A girl going into a wooden building. | Similarity: 0.3236
Caption: A child in a pink dress is climbing up a set of stairs in an entry way. | Similarity: 0.2667


In [27]:
def get_image_caption_similarities(image_embeddings, text_embeddings, texts=None):
    """Print and return captions ranked by similarity to a single image.

    Embeddings are assumed to be L2-normalised (OpenCLIPEmbedder's default),
    so the dot product equals cosine similarity.

    Args:
        image_embeddings: Image embedding(s); only the first row is used.
            A single 1-D vector is also accepted.
        text_embeddings: Caption embeddings, one row per caption.
        texts: Captions aligned row-for-row with ``text_embeddings``. When
            omitted, falls back to the notebook-global ``texts`` (the original
            behaviour); if that does not exist either, captions are labelled
            by index ("#0", "#1", ...).

    Returns:
        List of ``(caption, similarity)`` tuples, highest similarity first.

    Raises:
        ValueError: If ``texts`` and ``text_embeddings`` differ in length.
    """
    if texts is None:
        # Backward-compatible fallback: earlier versions silently read the global
        # `texts`. Prefer passing captions explicitly.
        texts = globals().get("texts")

    image_vec = np.atleast_2d(image_embeddings)[0]   # shape: (dim,)
    text_vecs = np.atleast_2d(text_embeddings)       # shape: (N, dim)
    similarities = text_vecs @ image_vec             # shape: (N,)

    if texts is None:
        texts = [f"#{i}" for i in range(len(similarities))]
    elif len(texts) != len(similarities):
        raise ValueError(
            f"got {len(texts)} captions but {len(similarities)} text embeddings"
        )

    ranked = []
    for idx in np.argsort(similarities)[::-1]:
        print(f"Caption: {texts[idx]} | Similarity: {similarities[idx]:.4f}")
        ranked.append((texts[idx], float(similarities[idx])))
    return ranked


In [32]:
for i in range(10):
    data_point = data[i]
    texts = data_point['caption']
    image = [data_point['image']]

    image_embeddings = embed_model.encode_image(image)
    text_embeddings = embed_model.encode_text(texts)

    # Caption labels come from the global `texts` assigned just above.
    get_image_caption_similarities(image_embeddings, text_embeddings)
    print('=' * 5)

Caption: Two young guys with shaggy hair look at their hands while hanging out in the yard. | Similarity: 0.3397
Caption: Two young, White males are outside near many bushes. | Similarity: 0.2736
Caption: Two men in green shirts are standing in a yard. | Similarity: 0.2658
Caption: Two friends enjoy time spent together. | Similarity: 0.2272
Caption: A man in a blue shirt standing in a garden. | Similarity: 0.1979
=====
Caption: Several men in hard hats are operating a giant pulley system. | Similarity: 0.3361
Caption: Workers look down from up above on a piece of equipment. | Similarity: 0.3074
Caption: Three men on a large rig. | Similarity: 0.2631
Caption: Four men on top of a tall structure. | Similarity: 0.2566
Caption: Two men working on a machine wearing hard hats. | Similarity: 0.2100
=====
Caption: A little girl climbing into a wooden playhouse. | Similarity: 0.3564
Caption: A little girl climbing the stairs to her playhouse. | Similarity: 0.3455
Caption: A little girl in a pin