In [1]:
import os
import glob

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

from transformers import CLIPProcessor, CLIPModel
from utils import Transform, ImageTextDataset, collate_fn, tokenizer

In [2]:
# Use the first CUDA GPU (index 0) when one is present, otherwise fall back to CPU.
device = "cpu"
if torch.cuda.is_available():
    device = 0
device

0

In [3]:
model = CLIPModel.from_pretrained('./out/m1/')


def text_features(captions, max_length: int = 32) -> np.ndarray:
    """Embed captions with the fine-tuned CLIP text tower.

    Args:
        captions: caption text(s) in whatever form the project `tokenizer`
            accepts (a list of strings is what this notebook passes).
        max_length: token length each caption is padded/truncated to.

    Returns:
        NumPy array of text features from `model.get_text_features`, one row
        per caption. The features are not L2-normalised.
    """
    # Put the token tensors wherever the model currently lives, so this keeps
    # working if the model is later moved with `model.to(device)`.
    model_device = next(model.parameters()).device
    with torch.no_grad():
        tokens = tokenizer(
            captions,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        tokens = {name: tensor.to(model_device) for name, tensor in tokens.items()}
        features = model.get_text_features(**tokens)

    # .cpu() is a no-op for CPU tensors and is required before .numpy() on CUDA.
    return features.cpu().numpy()

In [8]:
import io

import requests
from PIL import Image as PILImage  # aliased: a later cell imports IPython's `Image`

# The processor must be created in this notebook so the cell runs on a fresh kernel.
# NOTE(review): "b32" in the feature file name suggests a ViT-B/32 base model; this
# should be the checkpoint ./out/m1/ was fine-tuned from. Its tokenizer may also
# differ from `utils.tokenizer` used by `text_features` — TODO confirm.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
response = requests.get(url, timeout=30)
response.raise_for_status()
# Decode from the fully downloaded bytes so no open HTTP stream is left behind.
image = PILImage.open(io.BytesIO(response.content)).convert("RGB")

inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True)

# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    outputs = model(**inputs)
logits_per_image = outputs.logits_per_image  # image-text similarity scores, one row per image
probs = logits_per_image.softmax(dim=1)  # softmax across the captions gives label probabilities

In [9]:
probs  # probabilities for ["a photo of a cat", "a photo of a dog"] given the image (softmax over logits_per_image)

tensor([[9.9962e-01, 3.7815e-04]], grad_fn=<SoftmaxBackward0>)

In [6]:
# Image embeddings computed offline; presumably one row per image in meta/valid.tsv.
precomputed_image_embed = np.load(file='b32w2x3_25k_features.npy')
np.shape(precomputed_image_embed)

(24857, 512)

In [7]:
# Validation-split metadata, stored tab-separated.
df = pd.read_csv(filepath_or_buffer='meta/valid.tsv', delimiter='\t')

In [9]:
import html

from IPython.display import HTML, Image, display

QUERY = "A photo of a train leaving the train station"
TOP_K = 5  # how many nearest images to display


def _l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each row of `x` to unit L2 norm; `eps` guards against zero-norm rows."""
    return x / np.maximum(np.linalg.norm(x, axis=-1, keepdims=True), eps)


# CLIP scores image-text pairs by cosine similarity, so both sides are
# L2-normalised before the dot product. If the precomputed image embeddings are
# already unit-norm, this leaves the ranking unchanged.
image_embeds = _l2_normalize(precomputed_image_embed)
query_embeds = _l2_normalize(text_features([QUERY]))

for query_embed in query_embeds:
    sim = image_embeds @ query_embed  # one cosine score per image row
    top = np.argsort(-sim)[:TOP_K]  # row indices of the most similar images, best first

    # NOTE(review): assumes row i of meta/valid.tsv describes row i of the
    # embedding file and that column 0 holds the Unsplash photo id — TODO confirm.
    for pid in df.iloc[top, 0]:
        photo_url = f'https://unsplash.com/photos/{pid}'
        display(Image(url=f'{photo_url}/download?force=true&w=360'))
        display(HTML(f'<a href="{html.escape(photo_url)}"> source </a>'))

[13214  4722 14901 ...  7521 10265   250]
