In [8]:
!pip install torch torchvision transformers pillow tqdm matplotlib
!pip install git+https://github.com/openai/CLIP.git

import os
import torch
import clip
from PIL import Image
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import requests
import zipfile
from io import BytesIO
import matplotlib.pyplot as plt

def download_flickr8k():
    print("Downloading Flickr8k images...")
    image_url = "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip"
    response = requests.get(image_url)
    with zipfile.ZipFile(BytesIO(response.content)) as zip_ref:
        zip_ref.extractall(".")

    print("Downloading Flickr8k captions...")
    caption_url = "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip"
    response = requests.get(caption_url)
    with zipfile.ZipFile(BytesIO(response.content)) as zip_ref:
        zip_ref.extractall(".")

def custom_collate(batch):
    images = torch.stack([item['image'] for item in batch])
    image_raw = [item['image_raw'] for item in batch]
    captions = [item['captions'] for item in batch]
    image_ids = [item['image_id'] for item in batch]
    image_paths = [item['image_path'] for item in batch]

    return {
        'image': images,
        'image_raw': image_raw,
        'captions': captions,
        'image_id': image_ids,
        'image_path': image_paths
    }

class Flickr8kDataset(Dataset):
    def __init__(self, image_dir, captions_file, transform=None):
      self.image_dir = image_dir
      self.transform = transform
      self.image_captions = {}
      self.image_paths = []

      with open(captions_file, 'r', encoding='utf-8') as f:
          for line in f:
              parts = line.strip().split('\t')
              if len(parts) != 2:
                  continue

              image_name = parts[0].split('#')[0]
              if '.jpg.1' in image_name:
                  image_name = image_name.replace('.jpg.1', '.jpg')
              caption = parts[1].strip()

              image_path = os.path.join('/content', self.image_dir, image_name)

              if image_name in self.image_captions:
                  self.image_captions[image_name].append(caption)
              else:
                  self.image_captions[image_name] = [caption]
                  self.image_paths.append(image_path)

      self.image_paths = list(dict.fromkeys(self.image_paths))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image_name = os.path.basename(image_path)

        try:
            image = Image.open(image_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            raise

        if self.transform:
            image_tensor = self.transform(image)

        return {
            'image': image_tensor,
            'image_raw': image,
            'captions': self.image_captions[image_name],
            'image_id': idx,
            'image_path': image_path
        }

class CLIPRetrieval:
    def __init__(self, device='cuda'):
        self.device = device
        self.model, self.preprocess = clip.load('ViT-B/32', device)
        self.dataset = None
        self.image_features = None
        self.all_images = []

    def load_dataset(self, image_dir, captions_file):
        self.dataset = Flickr8kDataset(image_dir, captions_file, self.preprocess)
        dataloader = DataLoader(
            self.dataset,
            batch_size=32,
            shuffle=False,
            collate_fn=custom_collate
        )

        image_features = []

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Processing images"):
                images = batch['image'].to(self.device)
                self.all_images.extend(batch['image_raw'])

                batch_image_features = self.model.encode_image(images)
                batch_image_features = F.normalize(batch_image_features, dim=1)
                image_features.append(batch_image_features)

        self.image_features = torch.cat(image_features)

    def evaluate_recalls(self):
        if self.dataset is None:
            raise ValueError("Dataset not loaded. Call load_dataset first.")

        dataloader = DataLoader(
            self.dataset,
            batch_size=32,
            shuffle=False,
            collate_fn=custom_collate
        )
        text_features = []

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Processing text"):
                captions = [caps[0] for caps in batch['captions']]
                text_tokens = clip.tokenize(captions).to(self.device)
                batch_text_features = self.model.encode_text(text_tokens)
                batch_text_features = F.normalize(batch_text_features, dim=1)
                text_features.append(batch_text_features)

        text_features = torch.cat(text_features)
        similarity = torch.matmul(text_features, self.image_features.t())

        results = calculate_recalls(similarity)
        print("\nCLIP Recall@K Results:")
        for k, v in results.items():
            print(f"{k}: {v:.2f}%")

        return results

    def retrieve_images(self, query_caption, k=5):
        with torch.no_grad():
            text_tokens = clip.tokenize([query_caption]).to(self.device)
            text_features = self.model.encode_text(text_tokens)
            text_features = F.normalize(text_features, dim=1)

            similarity = torch.matmul(text_features, self.image_features.t())[0]
            top_k_scores, top_k_indices = similarity.topk(k)

            retrieved_results = []
            for score, idx in zip(top_k_scores.cpu(), top_k_indices.cpu()):
                retrieved_results.append((self.all_images[idx], score.item()))

            return retrieved_results

def calculate_recalls(similarity_matrix, k_values=[1, 5, 10]):
    results = {}
    num_samples = similarity_matrix.shape[0]


    ground_truth = torch.arange(num_samples).to(similarity_matrix.device)
    ranks = torch.where(torch.argsort(similarity_matrix, dim=1, descending=True) == ground_truth.unsqueeze(1))[1]

    for k in k_values:
        recall_k = (ranks < k).float().mean().item()
        results[f'R@{k}'] = recall_k * 100

    return results

def find_image_for_caption(caption_file, target_caption):

    with open(caption_file, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) != 2:
                continue
            image_name = parts[0].split('#')[0]
            if '.jpg.1' in image_name:
                image_name = image_name.replace('.jpg.1', '.jpg')
            caption = parts[1].strip()

            if caption.lower() == target_caption.lower():
                return image_name
    return None

def visualize_retrieval_results(query_caption, top_5_images, save_path='retrieval_results.png', caption_file='/content/Flickr8k.token.txt'):
    image_name = find_image_for_caption(caption_file, query_caption)
    if image_name:
        try:
            original_image = Image.open(os.path.join('/content/Flickr8k_Dataset', image_name)).convert('RGB')
        except Exception as e:
            print(f"Warning: Could not load original image: {e}")
            original_image = None
    else:
        print("Warning: Could not find original image for the caption")
        original_image = None

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle(f'Query: "{query_caption}"', fontsize=12, wrap=True)

    if original_image:
        axes[0, 0].imshow(original_image)
        axes[0, 0].set_title('Original Image', color='red')
    else:
        axes[0, 0].text(0.5, 0.5, 'Original image\nnot found',
                       ha='center', va='center', color='red')
    axes[0, 0].axis('off')

    for idx, (img, score) in enumerate(top_5_images):
        row = (idx + 1) // 3
        col = (idx + 1) % 3
        axes[row, col].imshow(img)
        axes[row, col].axis('off')
        axes[row, col].set_title(f'Score: {score:.3f}')

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()
    return save_path

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-nro_wskk
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-nro_wskk
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [13]:
if not os.path.exists('Flickr8k_Dataset'):
  download_flickr8k()

  print("\nCurrent directory structure:")
  os.system('ls -R Flickr8k_Dataset')


Downloading Flickr8k images...
Downloading Flickr8k captions...

Current directory structure:


In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
clip_retrieval = CLIPRetrieval(device)

clip_retrieval.load_dataset('Flickr8k_Dataset', '/content/Flickr8k.token.txt')

clip_retrieval.evaluate_recalls()

query_caption = "a dog playing in the water"
retrieved_images = clip_retrieval.retrieve_images(query_caption, k=5)

result_path = visualize_retrieval_results(query_caption, retrieved_images)
print(f"\nRetrieval results visualization saved to: {result_path}")

Processing images: 100%|██████████| 253/253 [01:10<00:00,  3.59it/s]
Processing text: 100%|██████████| 253/253 [01:15<00:00,  3.34it/s]



CLIP Recall@K Results:
R@1: 29.80%
R@5: 53.03%
R@10: 63.14%

Retrieval results visualization saved to: retrieval_results.png


In [17]:
query_caption = "A girl going into a wooden building ."
retrieved_images = clip_retrieval.retrieve_images(query_caption, k=5)

result_path = visualize_retrieval_results(query_caption, retrieved_images)
print(f"\nRetrieval results visualization saved to: {result_path}")


Retrieval results visualization saved to: retrieval_results.png
