In [None]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import os
from CLIP import clip
from torchvision.datasets import Food101
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torch.cuda.amp import autocast, GradScaler

# Load the CLIP model.
# The notebook used to detect CUDA and then unconditionally overwrite the
# result with "cpu". The effective device is still "cpu", but the override is
# now an explicit switch: set FORCE_CPU = False to use a GPU when one exists.
FORCE_CPU = True
device = "cuda" if (torch.cuda.is_available() and not FORCE_CPU) else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Load food101 (training split; downloaded into ./food101_data on first run).
food101_dataset = Food101(root='./food101_data', split='train', download=True)

# Number of images whose features get cached: 256 * 2**3 = 2048.
# (512 was used previously for a quicker run.)
subset_size = 256*2**3
print(subset_size)
print(len(food101_dataset))

# Draw the subset with a seeded, private generator so that "Restart & Run All"
# always selects the same images and the global torch RNG is left untouched.
SUBSET_SEED = 0
subset_rng = torch.Generator().manual_seed(SUBSET_SEED)
subset_indices = torch.randperm(len(food101_dataset), generator=subset_rng)[:subset_size]
# `subset_indices` stays a tensor for any later use; Subset itself is given
# plain Python ints so the wrapped dataset is never indexed with tensors.
subset_dataset = Subset(food101_dataset, subset_indices.tolist())


def custom_collate(batch):
    """Collate ``(image, label)`` samples into a single image batch for CLIP.

    Every image goes through ``preprocess``, the transform that ``clip.load``
    returned together with ``model``. The previous version replaced it with a
    bare ``Resize((224, 224))`` + ``ToTensor`` (rebuilt for every image), so
    the encoder was not fed images prepared by its own input pipeline.

    Args:
        batch: Sequence of ``(image, label)`` pairs from the dataset. The
            labels are discarded.

    Returns:
        The preprocessed images stacked along a new leading dimension and
        moved to ``device``.
    """
    batch_images = torch.stack([preprocess(img) for img, _ in batch])
    return batch_images.to(device)

batch_size = 16
dataloader = DataLoader(subset_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)

# CUDA autocast / grad scaling only make sense when running on a GPU; gating
# them on the device avoids requesting CUDA AMP on a CPU-only run.
use_amp = str(device).startswith("cuda")

# Not used by the inference loop below (nothing is being trained here); the
# name is kept in case other cells refer to it.
scaler = GradScaler(enabled=use_amp)

# Encode every image of the subset once and cache the features.
# This is pure inference, so autograd is switched off: otherwise each cached
# batch would keep its whole computation graph alive and memory would grow
# with every iteration.
feature_batches = []
with torch.no_grad():
    for batch_images in dataloader:
        with autocast(enabled=use_amp):
            batch_features = model.encode_image(batch_images)
        feature_batches.append(batch_features)

# One row per subset image, in dataloader (i.e. subset) order.
subset_features = torch.cat(feature_batches)

In [None]:
def retrieve_images(input, model, transform, subset_features, dataset, top_k=5):
    """Display the images whose cached features best match a text query.

    Args:
        input: Text query. It is tokenized with ``clip.tokenize`` and also
            shown in the figure title. (The name shadows the builtin but is
            kept so existing keyword callers keep working.)
        model: Object providing ``encode_text``.
        transform: Unused; kept so existing call sites stay valid.
        subset_features: Tensor with one feature row per item of ``dataset``,
            in the same order as ``dataset``.
        dataset: Indexable collection; ``dataset[i][0]`` is the image drawn
            for row ``i``.
        top_k: Maximum number of images to show. Clamped to the number of
            rows in ``subset_features`` so that ``torch.topk`` cannot fail
            when fewer images are available.

    Returns:
        None. A matplotlib figure is shown as a side effect.
    """
    # Retrieval is inference only, so no autograd graph is needed.
    with torch.no_grad():
        # Process the text description
        text = clip.tokenize([input]).to(device)
        text_features = model.encode_text(text)

        # Similarity between the single text row and every image row; both
        # sides are cast to float32 so their dtypes always agree.
        similarity_scores = torch.nn.functional.cosine_similarity(
            text_features.float(), subset_features.float()
        )

    # Never ask topk for more entries than there are candidates.
    top_k = max(0, min(top_k, similarity_scores.shape[0]))

    # Plain Python ints, ordered from most to least similar.
    top_k_indices = torch.topk(similarity_scores, top_k).indices.tolist()

    # Display the top-k images
    fig = plt.figure(figsize=(15, 3))
    fig.subplots_adjust(top=0.8)
    fig.suptitle(f"Input: {input}", fontsize=16)

    for i, idx in enumerate(top_k_indices):
        ax = fig.add_subplot(1, top_k, i + 1)
        ax.axis("off")
        ax.set_title(f"Top {i + 1}")
        ax.imshow(np.asarray(dataset[idx][0]))

    plt.show()

# Show the best matches in the cached subset for the query "soup".
retrieve_images(
    input="soup",
    model=model,
    transform=preprocess,
    subset_features=subset_features,
    dataset=subset_dataset,
)

In [None]:
# Show the best matches in the cached subset for the query "cereal".
retrieve_images(
    input="cereal",
    model=model,
    transform=preprocess,
    subset_features=subset_features,
    dataset=subset_dataset,
)

In [None]:
# Show the best matches in the cached subset for the query "the best food".
retrieve_images(
    input="the best food",
    model=model,
    transform=preprocess,
    subset_features=subset_features,
    dataset=subset_dataset,
)

In [None]:
# Show the best matches in the cached subset for the query "the worst food".
retrieve_images(
    input="the worst food",
    model=model,
    transform=preprocess,
    subset_features=subset_features,
    dataset=subset_dataset,
)