In [1]:
# Configuration Constants
# Everything a reader might tune lives in this cell; later cells only read
# these names.

# Model configuration
# Identifier handed to CLIPModel.from_pretrained / CLIPProcessor.from_pretrained
MODEL_NAME = "openai/clip-vit-base-patch32"
# Number of image-caption pairs embedded per forward pass
EMBEDDING_BATCH_SIZE = 32
# Captions are tokenized with truncation=True and this max_length
# (presumably CLIP's text context length - confirm against the model config)
TEXT_MAX_LENGTH = 77

# Dataset configuration
# Dataset handle passed to kagglehub.dataset_download
DATASET_ID = "adityajn105/flickr8k"
# Output HDF5 file; a relative path, so it resolves against the working
# directory the notebook is run from
EMBEDDINGS_OUTPUT_FILE = "../data/flickr8k_embeddings.h5"

# HDF5 storage configuration
# Compression filter name passed to h5py create_dataset
HDF5_COMPRESSION = 'gzip'
# compression_opts value passed alongside HDF5_COMPRESSION when writing datasets
HDF5_COMPRESSION_LEVEL = 6


In [2]:
# Logging setup

import os
import sys
import h5py
import torch
import logging
import kagglehub
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
from typing import Dict
from datetime import datetime
from torch.utils.data import Dataset
from transformers import CLIPProcessor, CLIPModel

# Send INFO-and-above records to stdout (basicConfig defaults to stderr) in a
# compact "LEVEL: message" format. force=True removes handlers already attached
# to the root logger, so re-running this cell reconfigures logging instead of
# being silently ignored.
logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s: %(message)s',
    stream=sys.stdout,
    force=True
)
# Module-level logger used by the functions and cells below
logger = logging.getLogger(__name__)
logger.info("Logging initialized")


INFO: Logging initialized


In [3]:
# Download the dataset to local data folder

# Set custom cache directory to project's data folder.
# The project root is taken to be the parent of the current working directory,
# i.e. this assumes the notebook is run from a direct subfolder of the project.
project_root = os.path.dirname(os.getcwd())
data_folder = os.path.join(project_root, 'data')
os.makedirs(data_folder, exist_ok=True)
# Must be set before the download call below
# (presumably kagglehub reads this variable at download time - confirm)
os.environ['KAGGLEHUB_CACHE'] = data_folder

# Download and store paths.
# The dataset is expected to contain an 'Images' folder and a 'captions.txt'
# file directly under the returned dataset directory.
dataset_path = kagglehub.dataset_download(DATASET_ID)
images_path = os.path.join(dataset_path, 'Images')
captions_path = os.path.join(dataset_path, 'captions.txt')

# Record the resolved locations in the cell output
logger.info(f"Dataset downloaded to: {dataset_path}")
logger.info(f"Images: {images_path}")
logger.info(f"Captions: {captions_path}")


INFO: Dataset downloaded to: /home/capacap/Projects/search_engine_demo/data/datasets/adityajn105/flickr8k/versions/1
INFO: Images: /home/capacap/Projects/search_engine_demo/data/datasets/adityajn105/flickr8k/versions/1/Images
INFO: Captions: /home/capacap/Projects/search_engine_demo/data/datasets/adityajn105/flickr8k/versions/1/captions.txt


In [4]:
# Set device, load model and processor

def get_device():
    """Choose the torch device string and log what was selected.

    Returns:
        "cuda" when torch reports CUDA support, otherwise "cpu".
    """
    has_cuda = torch.cuda.is_available()
    device = "cuda" if has_cuda else "cpu"
    logger.info(f"Device: {device}")

    # Nothing more to report when running on the CPU
    if not has_cuda:
        return device

    # Log the name and total memory (in GB) of GPU 0
    logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
    gpu_props = torch.cuda.get_device_properties(0)
    logger.info(f"VRAM: {gpu_props.total_memory / 1e9:.1f} GB")
    return device

def load_model(model_name, device):
    """Fetch the pretrained CLIP model and its matching processor.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both objects.
        device: Target handed to the model's ``.to()``.

    Returns:
        A ``(model, processor)`` tuple; only the model is moved to ``device``.
    """
    logger.info(f"Loading {model_name}...")

    # Model first (placed on the requested device), then its processor
    clip_model = CLIPModel.from_pretrained(model_name)
    clip_model = clip_model.to(device)
    clip_processor = CLIPProcessor.from_pretrained(model_name)

    logger.info(f"Loaded {model_name}")
    return clip_model, clip_processor

# Resolve the compute device, then load the model (moved to that device)
# together with its processor
device = get_device()
model, processor = load_model(MODEL_NAME, device)


INFO: Device: cuda
INFO: GPU: NVIDIA GeForce RTX 3060 Ti
INFO: VRAM: 8.2 GB
INFO: Loading openai/clip-vit-base-patch32...


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


INFO: Loaded openai/clip-vit-base-patch32


In [5]:
# Custom Dataset Class for Flickr8k

class Flickr8kDataset(Dataset):
    """Image-caption pairs read from a Flickr8k-style captions file.

    The captions file is parsed as CSV and must provide an ``image`` column
    (file name relative to ``images_dir``) and a ``caption`` column. Rows with
    missing values, and rows whose image file is absent or fails PIL
    verification, are dropped when the dataset is constructed.

    Attributes:
        images_dir: Directory that contains the image files.
        data: DataFrame holding the retained image-caption rows.
    """

    def __init__(self, captions_file: str, images_dir: str):
        """Load and validate the image-caption pairs.

        Args:
            captions_file: Path to the CSV file with image names and captions.
            images_dir: Directory the image names are resolved against.
        """
        self.images_dir = images_dir

        # Load and parse image-caption pairs
        self.data = self._load_data_from_file(captions_file)

    def _load_data_from_file(self, captions_file: str) -> pd.DataFrame:
        """Read the captions file and keep only rows with a usable image.

        Args:
            captions_file: Path to the CSV file to read.

        Returns:
            The rows (original index preserved) whose image exists and passes
            ``Image.verify()``.

        Raises:
            ValueError: If no row references a usable image.
            Exception: Any other error raised while reading or validating is
                logged and re-raised unchanged.
        """
        try:
            # Read the CSV file and drop rows with missing values
            df = pd.read_csv(captions_file).dropna()

            # Several captions share one image, so check every distinct file
            # once instead of once per caption row.
            valid_images = set()
            n_missing, n_invalid = 0, 0
            for image_name in df['image'].unique():
                image_path = os.path.join(self.images_dir, image_name)

                # Check if image file exists
                if not os.path.exists(image_path):
                    n_missing += 1
                    logger.info(f"Skipping missing image {image_name}")
                    continue

                # Check if image file is valid
                try:
                    with Image.open(image_path) as img:
                        img.verify()
                except Exception as exc:
                    n_invalid += 1
                    logger.info(f"Skipping corrupted image {image_name}: {exc}")
                else:
                    valid_images.add(image_name)

            # Log cleanup results
            if n_missing > 0:
                logger.info(f"Skipped {n_missing} missing images")

            if n_invalid > 0:
                logger.info(f"Skipped {n_invalid} invalid images")

            valid_df = df[df['image'].isin(valid_images)]

            # Raise error if no valid entries were found. The message must not
            # reference an exception variable: none is bound at this point.
            if valid_df.empty:
                raise ValueError(f"No valid image-caption pairs found in {captions_file}")

            logger.info(f"Loaded {len(valid_df)} image-caption pairs from {captions_file}")

            return valid_df

        except Exception as e:
            logger.error(f"Error loading captions file: {e}")
            raise

    def __len__(self) -> int:
        """Return the number of retained image-caption pairs."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict:
        """Return the image-caption pair at position ``idx``.

        Args:
            idx: Positional index into the retained rows.

        Returns:
            A dict with keys ``image`` (RGB PIL image), ``caption``,
            ``image_name`` and ``image_path``.
        """
        # Find image-caption pair
        row = self.data.iloc[idx]
        image_name = row['image']
        caption = row['caption']

        # Load image; convert() yields a separate image object, so the
        # source file can be closed as soon as the block exits
        image_path = os.path.join(self.images_dir, image_name)
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        return {
            'image': image,
            'caption': caption,
            'image_name': image_name,
            'image_path': image_path
        }

In [6]:
# Function that turns the dataset into normalized image and text embeddings

def generate_embeddings(
    dataset: Flickr8kDataset,
    model: CLIPModel,
    processor: CLIPProcessor,
    device: str,
    batch_size: int
    ):
    """Embed every image-caption pair of ``dataset`` in batches.

    Args:
        dataset: Source of items with ``image``, ``caption`` and
            ``image_name`` entries.
        model: Model providing ``get_image_features`` / ``get_text_features``.
        processor: Processor that prepares images and captions for ``model``.
        device: Device the processed inputs are moved to.
        batch_size: Number of items handled per iteration.

    Returns:
        A dict with the lists ``image_embeddings`` and ``text_embeddings``
        (one row per item, each divided by its L2 norm), ``image_names`` and
        ``captions``, all in dataset order.
    """
    def _unit_rows(features):
        # Bring the features to host memory and divide each row by its L2 norm
        rows = features.cpu().numpy()
        return rows / np.linalg.norm(rows, axis=1, keepdims=True)

    image_vectors, text_vectors = [], []
    names, captions = [], []
    n_items = len(dataset)

    # Inference only: evaluation mode, no gradient tracking
    model.eval()
    with torch.no_grad():
        for start in tqdm(range(0, n_items, batch_size), desc="Generating embeddings", unit="batch"):
            stop = min(start + batch_size, n_items)

            # Gather the items of this batch
            items = [dataset[i] for i in range(start, stop)]
            batch_images = [item["image"] for item in items]
            batch_captions = [item["caption"] for item in items]
            batch_names = [item["image_name"] for item in items]

            # Image branch
            image_inputs = processor(
                images=batch_images,
                return_tensors="pt",
                padding=True,
            ).to(device)
            image_features = model.get_image_features(**image_inputs)

            # Caption branch (captions longer than TEXT_MAX_LENGTH are truncated)
            caption_inputs = processor(
                text=batch_captions,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=TEXT_MAX_LENGTH,
            ).to(device)
            caption_features = model.get_text_features(**caption_inputs)

            # Normalize and accumulate, images before captions
            image_vectors.extend(_unit_rows(image_features))
            text_vectors.extend(_unit_rows(caption_features))
            names.extend(batch_names)
            captions.extend(batch_captions)

            # Release cached CUDA memory between batches
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return {
        'image_embeddings': image_vectors,
        'text_embeddings': text_vectors,
        'image_names': names,
        'captions': captions,
    }

In [7]:
# Function to save embeddings to HDF5 file

def save_embeddings_to_file(
    embeddings_data: dict,
    filepath: str,
    model_name: str,
    dataset_path: str,
    images_path: str,
    captions_path: str
    ):
    """Write embeddings, names, captions and provenance metadata to HDF5.

    Args:
        embeddings_data: Dict with the equally long sequences
            ``image_embeddings``, ``text_embeddings``, ``image_names`` and
            ``captions``.
        filepath: Destination ``.h5`` file; it is overwritten if it exists.
        model_name: Stored in the ``model_name`` file attribute.
        dataset_path: Stored in the ``dataset_path`` file attribute.
        images_path: Stored in the ``images_path`` file attribute.
        captions_path: Stored in the ``captions_path`` file attribute.

    Raises:
        ValueError: If there are no embeddings or the four sequences differ
            in length. Raised before the destination file is touched.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    logger.info(f"Saving {len(embeddings_data['image_embeddings'])} embeddings to {filepath}...")
    try:
        # Convert to numpy arrays
        img_embs = np.array(embeddings_data['image_embeddings'], dtype=np.float32)
        txt_embs = np.array(embeddings_data['text_embeddings'], dtype=np.float32)
        image_names = list(embeddings_data['image_names'])
        captions = list(embeddings_data['captions'])

        # Validate before opening the file: mode 'w' truncates an existing
        # file, so a late failure would destroy previously saved embeddings.
        if img_embs.ndim != 2 or img_embs.shape[0] == 0:
            raise ValueError("No image embeddings to save (expected a non-empty 2-D array)")
        n_samples = img_embs.shape[0]
        if not (txt_embs.shape[0] == len(image_names) == len(captions) == n_samples):
            raise ValueError(
                "Mismatched lengths: "
                f"{n_samples} image embeddings, {txt_embs.shape[0]} text embeddings, "
                f"{len(image_names)} image names, {len(captions)} captions"
            )

        # Ensure target directory exists; a bare file name has no directory
        # part and os.makedirs('') would raise
        target_dir = os.path.dirname(filepath)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        # Same compression settings for every dataset
        storage_opts = {
            'compression': HDF5_COMPRESSION,
            'compression_opts': HDF5_COMPRESSION_LEVEL,
        }

        # Create the .h5 file
        with h5py.File(filepath, 'w') as f:
            # Store image and text embeddings
            f.create_dataset('image_embeddings', data=img_embs, **storage_opts)
            f.create_dataset('text_embeddings', data=txt_embs, **storage_opts)

            # Set data type for string data
            dt = h5py.string_dtype(encoding='utf-8')

            # Store image names and captions
            f.create_dataset('image_names', data=image_names, dtype=dt, **storage_opts)
            f.create_dataset('captions', data=captions, dtype=dt, **storage_opts)

            # Set metadata
            f.attrs['creation_date'] = datetime.now().isoformat()
            f.attrs['total_samples'] = n_samples
            f.attrs['embedding_dim'] = img_embs.shape[1]
            f.attrs['model_name'] = model_name
            f.attrs['dataset_path'] = dataset_path
            f.attrs['images_path'] = images_path
            f.attrs['captions_path'] = captions_path

        logger.info(f"Embeddings data saved to {filepath}")
        logger.info(f"Images path stored: {images_path}")

    except Exception as e:
        logger.error(f"Failed to save embeddings data: {e}")
        raise

In [8]:
# Execute the main workflow

# Load dataset
dataset = Flickr8kDataset(captions_path, images_path)

# Generate embeddings
embeddings_data = generate_embeddings(
    dataset=dataset,
    model=model,
    processor=processor,
    device=device,
    batch_size=EMBEDDING_BATCH_SIZE
)

# Save data to file
save_embeddings_to_file(
    embeddings_data=embeddings_data,
    filepath=EMBEDDINGS_OUTPUT_FILE,
    model_name=MODEL_NAME,
    dataset_path=dataset_path,
    images_path=images_path,
    captions_path=captions_path
)

# Basic validation of saved data
with h5py.File(EMBEDDINGS_OUTPUT_FILE, 'r') as f:
    print(f"File structure: {list(f.keys())}")
    print(f"Metadata: {dict(f.attrs)}")

    # Fail loudly if what was written does not line up with the dataset:
    # every stored array must hold one row per image-caption pair.
    n_saved = int(f.attrs['total_samples'])
    assert n_saved == len(dataset), f"Saved {n_saved} samples, dataset has {len(dataset)}"
    for key in ('image_embeddings', 'text_embeddings', 'image_names', 'captions'):
        assert f[key].shape[0] == n_saved, f"'{key}' has {f[key].shape[0]} rows, expected {n_saved}"
    assert f['image_embeddings'].shape[1] == int(f.attrs['embedding_dim']), "embedding_dim metadata mismatch"

INFO: Loaded 40455 image-caption pairs from /home/capacap/Projects/search_engine_demo/data/datasets/adityajn105/flickr8k/versions/1/captions.txt


Generating embeddings: 100%|██████████| 1265/1265 [04:52<00:00,  4.33batch/s]

INFO: Saving 40455 embeddings to ../data/flickr8k_embeddings.h5...





INFO: Embeddings data saved to ../data/flickr8k_embeddings.h5
INFO: Images path stored: /home/capacap/Projects/search_engine_demo/data/datasets/adityajn105/flickr8k/versions/1/Images
File structure: ['captions', 'image_embeddings', 'image_names', 'text_embeddings']
Metadata: {'captions_path': '/home/capacap/Projects/search_engine_demo/data/datasets/adityajn105/flickr8k/versions/1/captions.txt', 'creation_date': '2025-08-28T21:38:35.963176', 'dataset_path': '/home/capacap/Projects/search_engine_demo/data/datasets/adityajn105/flickr8k/versions/1', 'embedding_dim': np.int64(512), 'images_path': '/home/capacap/Projects/search_engine_demo/data/datasets/adityajn105/flickr8k/versions/1/Images', 'model_name': 'openai/clip-vit-base-patch32', 'total_samples': np.int64(40455)}
