In [7]:
import os
import glob

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from transformers import CLIPProcessor, CLIPModel
from utils import Transform, ImageTextDataset, collate_fn, processor

In [8]:
# Load the CLIP checkpoint stored locally under ./out/m1 (presumably the
# fine-tuned model produced earlier — TODO confirm) and build the image loader.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("./out/m1").to(device)
# Be explicit that we are doing inference: train-time layers such as dropout
# must be disabled while extracting features.
model.eval()

# NOTE: despite its name, this loader iterates the "valid" split under `data`.
# Transform(224, False): presumably a 224px resize with augmentation disabled —
# TODO confirm against utils.Transform.
# shuffle=False (the default) is spelled out because the row order of the saved
# embedding matrix must match the dataset order.
test_loader = DataLoader(
    ImageTextDataset('data', "valid", transform=Transform(224, False)),
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn,
)

In [15]:
def image_features(batch) -> np.ndarray:
    """Embed one collated batch of images with the CLIP image encoder.

    Relies on the module-level ``model`` and ``device``. Each output row is
    L2-normalised along the last axis, so dot products between rows are
    cosine similarities.

    Args:
        batch: a collated batch (as yielded by the DataLoader); only its
            ``'pixel_values'`` entry is read.

    Returns:
        A NumPy array of normalised image embeddings, copied to the CPU.
    """
    with torch.no_grad():
        pixels = batch['pixel_values'].to(device)
        raw_embeddings = model.get_image_features(pixels)
        unit_embeddings = F.normalize(raw_embeddings, dim=-1)
    return unit_embeddings.cpu().numpy()

def save_image_features(save_name: str = "25k_features.npy", loader=None) -> None:
    """Embed every image yielded by ``loader`` and save the matrix to ``save_name``.

    Args:
        save_name: destination path handed to ``np.save``.
        loader: iterable of collated batches accepted by ``image_features``.
            Defaults to the module-level ``test_loader``, so existing calls
            behave as before.

    Raises:
        ValueError: if the loader yields no batches (there is nothing to save).
    """
    if loader is None:
        loader = test_loader

    batch_features = [
        image_features(batch) for batch in tqdm(loader, desc="embedding images")
    ]
    if not batch_features:
        raise ValueError("loader yielded no batches; nothing to save")

    # Join along the batch axis instead of stack + squeeze. This keeps a 2-D
    # (n_images, dim) result for any batch size, including a smaller final
    # batch, and does not collapse to 1-D when there is only one image.
    # (Assumes each per-batch array is (batch, dim) — true for
    # get_image_features output.)
    all_features = np.concatenate(batch_features, axis=0)
    np.save(save_name, all_features)

In [16]:
# Embed every image from test_loader and write the matrix to disk.
# This is slow: it runs one forward pass per loader batch.
save_image_features(save_name="25k_features.npy")

  0%|          | 0/24857 [00:00<?, ?it/s]

In [18]:
# Reload the cached image embeddings written by save_image_features().
precomputed_image_embeddings = np.load("25k_features.npy")

# Sanity-check the cache: one L2-normalised embedding per row. A 1-D array, or
# rows far from unit norm, means the file is stale or was written incorrectly.
assert precomputed_image_embeddings.ndim == 2, precomputed_image_embeddings.shape
assert np.allclose(
    np.linalg.norm(precomputed_image_embeddings, axis=1), 1.0, atol=1e-3
), "cached embeddings are not unit-normalised"