# Configuration

In [None]:
# Paths configuration (placeholders: set both before running)
EMBEDDING_DIR_PATH = ""  # Path to the directory containing embedding files
OUTPUT_DIR_PATH = ""  # Path to save retrieval results

# Retrieval parameters
# Required width of each embedding vector; multi-vector embeddings are
# mean-pooled to a single vector of this size for the FAISS index.
TARGET_EMBEDDING_DIM = 128
# Number of nearest candidates fetched from FAISS per query before reranking.
K_ANN_CANDIDATES = 100
# Number of reranked results kept per query.
K_FINAL_RETRIEVAL = 100

# Reranking can only reorder the ANN candidates, so asking for more final
# results than candidates would silently yield fewer rows than requested.
if K_FINAL_RETRIEVAL > K_ANN_CANDIDATES:
    raise ValueError(
        f"K_FINAL_RETRIEVAL ({K_FINAL_RETRIEVAL}) must not exceed "
        f"K_ANN_CANDIDATES ({K_ANN_CANDIDATES})."
    )

# Imports and Helper Functions

In [None]:
import torch
import faiss
import torch.nn as nn
from collections import defaultdict
from PIL import Image
import numpy as np
from tqdm import tqdm
import pickle
import os
from datetime import datetime
import pandas as pd

def to_torch_tensor(data):
    """Return *data* as a ``torch.Tensor``.

    Tensors are passed through unchanged (same object). NumPy arrays are
    wrapped with ``torch.from_numpy``, which shares memory with the array.

    Raises:
        TypeError: if *data* is neither a torch tensor nor a NumPy array.
    """
    if isinstance(data, torch.Tensor):
        return data
    if isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    raise TypeError(f"Unsupported data type: {type(data)}")

def process_embedding_for_ann(emb_raw, target_dim, name="embedding"):
    """
    Pool a raw embedding into a single float32 vector for ANN indexing.

    A leading dimension of size 1 (e.g. a batch of one) is squeezed away. What
    remains must be either a ``(target_dim,)`` vector, which is returned as is,
    or an ``(n, target_dim)`` matrix, which is mean-pooled over its rows.

    Args:
        emb_raw: Raw embedding tensor
        target_dim: Required size of the last (feature) dimension
        name: Label used in warning messages

    Returns:
        1-D ``torch.float32`` tensor of length ``target_dim``, or None (after
        printing a warning) if the embedding is empty or has an unexpected shape
    """
    if emb_raw.numel() == 0:
        print(f"Warning: {name} has an empty embedding. Skipping.")
        return None

    # squeeze(0) only drops dim 0 when it has size 1; it never changes numel,
    # so a second emptiness check afterwards would be unreachable.
    emb_processed = emb_raw.squeeze(0)

    is_2d_and_correct_dim = emb_processed.dim() == 2 and emb_processed.shape[1] == target_dim
    is_1d_and_correct_dim = emb_processed.dim() == 1 and emb_processed.shape[0] == target_dim

    if not (is_2d_and_correct_dim or is_1d_and_correct_dim):
        print(f"Warning: {name} has unexpected embedding shape: {emb_processed.shape}. Skipping.")
        return None

    # Cast BEFORE pooling: averaging in bf16/fp16 rounds the mean to that
    # low precision, and torch.mean refuses integer tensors altogether.
    emb_processed = emb_processed.to(torch.float32)

    if emb_processed.dim() == 2:
        return emb_processed.mean(dim=0)
    return emb_processed

def build_faiss_index(embeddings_list, target_dim):
    """
    Pool every embedding to one vector and load the results into a FAISS index.

    Embeddings that ``process_embedding_for_ann`` rejects are left out; the
    returned ``valid_indices`` translates a FAISS row id back to the
    embedding's position in ``embeddings_list``.

    Args:
        embeddings_list: List of raw embeddings
        target_dim: Target embedding dimension

    Returns:
        Tuple of (faiss_index, ann_vectors, valid_indices, embedding_map), where
        ``ann_vectors`` is the float32 NumPy matrix that was indexed (an exact
        L2 ``IndexFlatL2``) and ``embedding_map`` maps each kept position to
        its raw, unpooled embedding for later reranking

    Raises:
        ValueError: if no embedding could be pooled
    """
    print(f"Standardizing embeddings to {target_dim} dimensions...")

    # (original position, pooled vector) for every embedding that passed validation
    kept = []
    for pos, raw in enumerate(tqdm(embeddings_list, desc="Processing for ANN")):
        vec = process_embedding_for_ann(raw, target_dim, f"embedding {pos}")
        if vec is None:
            continue
        kept.append((pos, vec))

    if not kept:
        raise ValueError("No valid ANN vectors could be generated.")

    valid_indices = [pos for pos, _ in kept]
    ann_vectors = torch.stack([vec for _, vec in kept]).cpu().numpy()

    # Keep the unpooled embeddings so the reranker can use all token vectors.
    embedding_map = {pos: embeddings_list[pos] for pos in valid_indices}

    index = faiss.IndexFlatL2(ann_vectors.shape[1])
    index.add(ann_vectors)
    print(f"FAISS index created with {index.ntotal} vectors.")

    return index, ann_vectors, valid_indices, embedding_map

def perform_retrieval_with_reranking(query_embeddings, faiss_index, valid_indices,
                                     embedding_map, metadata, processor,
                                     k_candidates=100, k_final=100, target_dim=None):
    """
    Retrieve candidates with FAISS, then rerank them with the processor's scorer.

    Stage 1 pools each query to one vector (``process_embedding_for_ann``) and
    runs a single batched ANN search for the ``k_candidates`` nearest pooled
    candidates. Stage 2 rescores those candidates against the original,
    unpooled query embedding with ``processor.score_multi_vector`` and keeps
    the ``k_final`` highest-scoring ones.

    Args:
        query_embeddings: List of raw query embeddings
        faiss_index: FAISS index built over the pooled candidate vectors
        valid_indices: Maps a FAISS row id to the candidate's position in the
            original candidate list (as returned by ``build_faiss_index``)
        embedding_map: Original candidate position -> raw candidate embedding
        metadata: Per-candidate metadata dicts indexed by original position;
            each dict is merged into that candidate's result rows
        processor: Object exposing ``score_multi_vector(query, candidate)``
        k_candidates: Number of candidates from ANN
        k_final: Maximum number of reranked results kept per query
        target_dim: Pooled vector size; defaults to the module-level
            ``TARGET_EMBEDDING_DIM``

    Returns:
        List of result dicts with ``query_id`` (position in
        ``query_embeddings``), ``result_rank`` (1-based), ``result_id``
        (original candidate position) and ``score`` (rounded to 4 decimals),
        plus the candidate's metadata keys. Queries whose embedding cannot be
        pooled produce no rows.

    Raises:
        ValueError: if no query embedding can be pooled
    """
    if target_dim is None:
        target_dim = TARGET_EMBEDDING_DIM

    # Pool queries, remembering each one's position so query_id stays aligned
    # with the caller's query list even when invalid queries are skipped.
    # (Previously an invalid query was replaced by a zero vector and still
    # searched/reranked, producing meaningless rows.)
    query_ids = []
    pooled_queries = []
    for query_id, q_emb_raw in enumerate(tqdm(query_embeddings, desc="Processing queries for ANN")):
        single_vector = process_embedding_for_ann(q_emb_raw, target_dim, f"query {query_id}")
        if single_vector is None:
            continue
        query_ids.append(query_id)
        pooled_queries.append(single_vector)

    if not pooled_queries:
        raise ValueError("No valid query vectors could be generated.")

    print(f"Total queries to process: {len(query_embeddings)}")
    skipped = len(query_embeddings) - len(query_ids)
    if skipped:
        print(f"Skipped {skipped} queries with invalid embeddings.")

    # One batched search instead of one FAISS call per query.
    query_ann_vectors = torch.stack(pooled_queries).cpu().numpy()
    _, candidate_rows = faiss_index.search(query_ann_vectors, k_candidates)

    all_search_results = []
    for query_id, faiss_candidate_indices in tqdm(
        zip(query_ids, candidate_rows), total=len(query_ids), desc="Retrieving and Reranking"
    ):
        query_original_emb = query_embeddings[query_id]

        reranked_scores = []
        for faiss_idx in faiss_candidate_indices:
            # FAISS pads with -1 when the index holds fewer than k vectors.
            if faiss_idx == -1:
                continue
            original_idx = valid_indices[faiss_idx]
            # assumes score_multi_vector returns a single-element tensor for one
            # query/candidate pair — TODO confirm expected input shapes
            score = processor.score_multi_vector(query_original_emb, embedding_map[original_idx])
            reranked_scores.append((score.item(), original_idx))

        # Highest score first; the sort is stable, so ties keep FAISS order.
        reranked_scores.sort(key=lambda pair: pair[0], reverse=True)

        for rank, (score, result_idx) in enumerate(reranked_scores[:k_final], start=1):
            result = {
                'query_id': query_id,
                'result_rank': rank,
                'result_id': result_idx,
                'score': round(score, 4),
            }
            result.update(metadata[result_idx])
            all_search_results.append(result)

    return all_search_results

print("Helper functions loaded.")

# Retrieval with ColQwen2.5 (Image to Text)

In [None]:
from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor

# Load model and processor
# Only the processor is used below (its score_multi_vector reranks candidates);
# the embeddings themselves are precomputed and loaded from disk.
# NOTE(review): model_name is the 7b checkpoint while file_name holds
# ColQwen2.5-3b embeddings — confirm which pairing is intended.
model_name = "Metric-AI/ColQwen2.5-7b-multilingual-v1.0"
model_name_only = model_name.split("/")[-1]
processor = ColQwen2_5_Processor.from_pretrained(model_name)

# Load embeddings
# pickle.load can execute arbitrary code: only load files from a trusted source.
file_name = "image_text_embeddings_ColQwen2.5-3b-multilingual-v1.0_0_3600.pkl"
with open(os.path.join(EMBEDDING_DIR_PATH, file_name), 'rb') as f:
    embeddings = pickle.load(f)

# Extract caption embeddings from all languages as candidates
all_caption_embeddings = []
caption_metadata = []

print("Extracting caption embeddings from all languages...")
for entry_idx, entry in enumerate(tqdm(embeddings, desc="Processing entries")):
    if 'text_embeddings' in entry:
        for lang_key, lang_embeddings in entry['text_embeddings'].items():
            # Extract language code from key (e.g., 'caption_embedding_en' -> 'en')
            lang_code = lang_key.replace('caption_embedding_', '')

            for caption_idx, caption_emb in enumerate(lang_embeddings):
                all_caption_embeddings.append(to_torch_tensor(caption_emb))
                caption_metadata.append({
                    'entry_idx': entry_idx,
                    'language': lang_code,
                    'caption_idx': caption_idx,
                    'image_key': entry.get('image_key', f'entry_{entry_idx}')
                })

# Extract image embeddings for queries
all_query_embeddings = []
query_metadata = []

print("Extracting image embeddings for queries...")
for entry_idx, entry in enumerate(tqdm(embeddings, desc="Processing queries")):
    if 'image_embedding' in entry:
        all_query_embeddings.append(to_torch_tensor(entry['image_embedding']))
        query_metadata.append({
            'entry_idx': entry_idx,
            'image_key': entry.get('image_key', f'entry_{entry_idx}')
        })

# Build FAISS index for captions
faiss_index, _, valid_caption_indices, caption_embedding_map = build_faiss_index(
    all_caption_embeddings, TARGET_EMBEDDING_DIM
)

# Perform retrieval with reranking
all_search_results = perform_retrieval_with_reranking(
    all_query_embeddings,
    faiss_index,
    valid_caption_indices,
    caption_embedding_map,
    caption_metadata,
    processor,
    k_candidates=K_ANN_CANDIDATES,
    k_final=K_FINAL_RETRIEVAL
)

# Add query metadata to results; query_id indexes into query_metadata.
for result in all_search_results:
    query_id = result['query_id']
    result['query_image_key'] = query_metadata[query_id]['image_key']
    result['result_image_key'] = result.pop('image_key')

# Save results
# os.makedirs("") raises FileNotFoundError, so only create a directory when one
# is configured; an empty OUTPUT_DIR_PATH writes into the working directory.
if OUTPUT_DIR_PATH:
    os.makedirs(OUTPUT_DIR_PATH, exist_ok=True)
df = pd.DataFrame(all_search_results)
output_file = os.path.join(OUTPUT_DIR_PATH, f'{file_name}_multilingual_results.csv')
df.to_csv(output_file, index=False)

print(f"Results saved to: {output_file}")
print(f"DataFrame shape: {df.shape}")
# An empty frame has no columns, so df['language'] would raise KeyError.
if df.empty:
    print("No results were retrieved.")
else:
    print("Language distribution:")
    print(df['language'].value_counts())
    print("Sample results:")
    print(df.head(10))

# Retrieval with ColQwen2 (Image to Text)

In [None]:
from colpali_engine.models import ColQwen2, ColQwen2Processor

# Load model and processor
# Only the processor is used below (its score_multi_vector reranks candidates);
# the embeddings themselves are precomputed and loaded from disk.
model_name = "vidore/colqwen2-v1.0"
model_name_only = model_name.split("/")[-1]
processor = ColQwen2Processor.from_pretrained(model_name)

# Load embeddings
# pickle.load can execute arbitrary code: only load files from a trusted source.
file_name = "image_text_embeddings_colqwen2-v1.0_0_3600.pkl"
with open(os.path.join(EMBEDDING_DIR_PATH, file_name), 'rb') as f:
    embeddings = pickle.load(f)

# Extract embeddings from schema
all_candidate_embeddings = []
candidate_metadata = []
all_query_embeddings = []
query_metadata = []

print("Extracting embeddings from schema...")
for entry_idx, entry in enumerate(tqdm(embeddings, desc="Processing entries")):
    # Extract image embedding (used as query)
    if 'image_embedding' in entry:
        image_emb = to_torch_tensor(entry['image_embedding'])
        all_query_embeddings.append(image_emb)
        query_metadata.append({
            'entry_idx': entry_idx,
            'image_key': entry.get('image_key', f'entry_{entry_idx}')
        })
    else:
        print(f"Warning: No image_embedding found in entry {entry_idx}")
        # NOTE(review): this also drops the entry's captions from the candidate
        # pool, whereas the ColQwen2.5 cell indexes captions regardless of
        # whether an image embedding exists — confirm this is intended.
        continue

    # Extract text embeddings from all languages (used as candidates)
    if 'text_embeddings' in entry:
        for lang_key, lang_embeddings_list in entry['text_embeddings'].items():
            # e.g. 'caption_embedding_en' -> 'en'
            lang_code = lang_key.replace('caption_embedding_', '')
            for caption_idx, caption_emb_array in enumerate(lang_embeddings_list):
                caption_emb = to_torch_tensor(caption_emb_array)
                all_candidate_embeddings.append(caption_emb)
                candidate_metadata.append({
                    'entry_idx': entry_idx,
                    'type': 'caption',
                    'language': lang_code,
                    'caption_idx': caption_idx,
                    'image_key': entry.get('image_key', f'entry_{entry_idx}')
                })
    else:
        print(f"Warning: No text_embeddings found in entry {entry_idx}")

print(f"Total candidate embeddings: {len(all_candidate_embeddings)}")
print(f"Total query embeddings: {len(all_query_embeddings)}")

# Build FAISS index for candidates
faiss_index, _, valid_indices, embedding_map = build_faiss_index(
    all_candidate_embeddings, TARGET_EMBEDDING_DIM
)

# Perform retrieval with reranking
all_search_results = perform_retrieval_with_reranking(
    all_query_embeddings,
    faiss_index,
    valid_indices,
    embedding_map,
    candidate_metadata,
    processor,
    k_candidates=K_ANN_CANDIDATES,
    k_final=K_FINAL_RETRIEVAL
)

# Add query metadata to results; query_id indexes into query_metadata.
# NOTE(review): the result's own key stays in 'image_key' here, while the
# ColQwen2.5 cell renames it to 'result_image_key' — align if CSVs are compared.
for result in all_search_results:
    query_id = result['query_id']
    result['query_image_key'] = query_metadata[query_id]['image_key']

# Save results
# os.makedirs("") raises FileNotFoundError, so only create a directory when one
# is configured; an empty OUTPUT_DIR_PATH writes into the working directory.
if OUTPUT_DIR_PATH:
    os.makedirs(OUTPUT_DIR_PATH, exist_ok=True)
df = pd.DataFrame(all_search_results)
output_file = os.path.join(OUTPUT_DIR_PATH, f'{file_name}_results.csv')
df.to_csv(output_file, index=False)

print(f"Results saved to: {output_file}")
print(f"DataFrame shape: {df.shape}")
# An empty frame has no columns, so df['type'] would raise KeyError.
if df.empty:
    print("No results were retrieved.")
else:
    print("Candidate types distribution:")
    print(df['type'].value_counts())
    if 'language' in df.columns:
        print("Language distribution:")
        print(df['language'].value_counts())