In [1]:
!pip install openai-clip
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import CocoCaptions
from torchvision import transforms
import clip
import numpy as np
from PIL import Image
from typing import List, Tuple, Dict
import os
import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from tqdm import tqdm, trange
from google.colab import drive

Collecting openai-clip
  Downloading openai-clip-1.0.1.tar.gz (1.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m17.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ftfy (from openai-clip)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: openai-clip
  Building wheel for openai-clip (setup.py) ... [?25l[?25hdone
  Created wheel for openai-clip: filename=openai_clip-1.0.1-py3-none-any.whl size=1368605 sha256=0ea51f3df37216b41ad55767a96d0ea6782e5359f945b21c1a14d53141ea78bb
  Stored in directory: /root/.cache/pip/wheels/08/77/8e/8d2f862df6bf7fb4e2007062d2cbaeae49862ec7b56d041229
Successfully built openai-clip
Installing collected packages: ftfy, openai-clip
Successfully insta

In [2]:
class Model(nn.Module):
    def __init__(self, clip_model: str = "ViT-B/32"):
        super().__init__()
        self.clip_model, self.preprocess = clip.load(clip_model)

        for param in self.clip_model.parameters():
            param.requires_grad = False

    def forward(self, images: torch.Tensor, captions: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        image_features = self.clip_model.encode_image(images)
        text_features = self.clip_model.encode_text(clip.tokenize(captions).to(images.device))

        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        return image_features, text_features

In [3]:
drive.mount('/content/drive')
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Model().to(device)
model.load_state_dict(torch.load("/content/drive/MyDrive/model.pth", weights_only=True))
model.eval()
caption_embeddings = torch.load("/content/drive/MyDrive/caption_embeddings.pth")
caption_pool = torch.load("/content/drive/MyDrive/captions.pth")

Mounted at /content/drive


100%|████████████████████████████████████████| 338M/338M [00:02<00:00, 159MiB/s]
  caption_embeddings = torch.load("/content/drive/MyDrive/caption_embeddings.pth")
  caption_pool = torch.load("/content/drive/MyDrive/captions.pth")


In [4]:
def find_best_caption(model,
                     image,
                     caption_embeddings, caption_pool,
                     device="cuda"):
    transform = model.preprocess
    img_tensor = model.preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.clip_model.encode_image(img_tensor)
    image_features = image_features / image_features.norm(dim=1, keepdim=True)
    text_features = torch.stack(caption_embeddings).squeeze(1)
    similarities = (image_features @ text_features.t()).squeeze()
    best_idx = similarities.argmax().item()

    return caption_pool[best_idx], similarities[best_idx].item()

In [None]:
image = Image.open("image.jpg") # Replace with path to your image
best_caption, similarity = find_best_caption(model, image, caption_embeddings, caption_pool, device)
print(f"Caption: {best_caption}")
image.show()