In [2]:
%%capture
# Install required packages
!pip install -U transformers faiss-gpu torch Pillay tqdm ipywidgets
!pip install -U git+https://github.com/360CVGroup/FG-CLIP.git


## Initializations

In [None]:
import torch
import faiss
import numpy as np
from pathlib import Path
from PIL import Image
from tqdm.auto import tqdm
from transformers import AutoProcessor, AutoModel
import matplotlib.pyplot as plt

# Configuration
IMAGE_ROOT = Path("debbiedebrauwer/Spanje")
MODEL_NAME = "qihoo360/fg-clip-large"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 32
INDEX_CONFIG = {
    "nlist": 256,          # Number of IVF clusters
    "nprobe": 16,          # Number of clusters to search
    "metric": faiss.METRIC_INNER_PRODUCT
}

# Initialize FG-CLIP model
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE).eval()




W0615 23:25:17.400000 18612 site-packages\torch\distributed\elastic\multiprocessing\redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


## Vector Embedder (turns images into vector embeddings)

In [2]:
def load_image_paths():
    """Load all images with location metadata"""
    locations = []
    image_paths = []
    
    for loc_dir in IMAGE_ROOT.iterdir():
        if loc_dir.is_dir():
            loc_id = loc_dir.name
            for img_path in loc_dir.glob("*.*"):
                if img_path.suffix.lower() in [".png", ".jpg", ".jpeg"]:
                    image_paths.append(img_path)
                    locations.append(loc_id)
    
    return image_paths, locations


def generate_embeddings(image_paths, batch_size=BATCH_SIZE):
    """Generate FG-CLIP embeddings with batch processing"""
    embeddings = []
    
    for i in tqdm(range(0, len(image_paths), batch_size)):
        batch_paths = image_paths[i:i+batch_size]
        images = [Image.open(p).convert("RGB") for p in batch_paths]
        
        with torch.no_grad():
            inputs = processor(images=images, return_tensors="pt", padding=True)
            inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
            features = model.get_image_features(**inputs)
            features = torch.nn.functional.normalize(features, dim=-1)
            embeddings.append(features.cpu().numpy())
    
    return np.concatenate(embeddings)

# Load and process all images
image_paths, location_ids = load_image_paths()
embeddings = generate_embeddings(image_paths)


  0%|          | 0/30 [00:00<?, ?it/s]

## Store embedding in database (only run once when storing vector embedding)

In [3]:
from qdrant_client import QdrantClient, models
import hashlib
import numpy as np
from datetime import datetime  # ADD THIS IMPORT


def store_embeddings(embeddings: np.ndarray, image_paths: list, location_ids: list):
    """Store embeddings in Qdrant while preserving existing data"""
    
    # Validate input shapes
    assert len(embeddings) == len(image_paths) == len(location_ids), \
        "Mismatched input array lengths"
    assert embeddings.ndim == 2, "Embeddings must be 2D array"

    # Initialize Qdrant client
    client = QdrantClient(url="http://localhost:6333")

    # Collection parameters
    COLLECTION_NAME = "film_locations"
    VECTOR_SIZE = embeddings.shape[1]

    # Create collection only if it doesn't exist
    if not client.collection_exists(COLLECTION_NAME):
        client.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config=models.VectorParams(
                size=VECTOR_SIZE,
                distance=models.Distance.COSINE,
                on_disk=True
            )
        )
        print(f"Created new collection: {COLLECTION_NAME}")
    else:
        # Verify vector size matches existing collection
        collection_info = client.get_collection(COLLECTION_NAME)
        existing_size = collection_info.config.params.vectors.size
        if existing_size != VECTOR_SIZE:
            raise ValueError(
                f"Existing vector size ({existing_size}) "
                f"doesn't match new embeddings ({VECTOR_SIZE})"
            )

    # Prepare points with conflict resolution
    def create_points():
        for path, loc_id, vector in zip(image_paths, location_ids, embeddings):
            yield models.PointStruct(
                id=hashlib.md5(str(path).encode()).hexdigest(),
                vector=vector.tolist(),
                payload={
                    "image_path": str(path),
                    "location_id": loc_id,
                    "updated_at": datetime.now().isoformat()
                }
            )

    # Upsert in batches with progress tracking
    BATCH_SIZE = 50
    total_points = len(embeddings)
    uploaded = 0
    
    print(f"Starting upsert of {total_points} points...")
    for batch_points in batch_generator(create_points(), BATCH_SIZE):
        client.upsert(
            collection_name=COLLECTION_NAME,
            points=batch_points,
            wait=True
        )
        uploaded += len(batch_points)
        print(f"Progress: {uploaded}/{total_points} ({uploaded/total_points:.1%})")

    print(f"Completed storing {total_points} embeddings in {COLLECTION_NAME}")
    print(f"Collection now contains {client.count(COLLECTION_NAME)} points")

def batch_generator(iterable, batch_size):
    """Helper to batch items from generator"""
    batch = []
    for item in iterable:   
        batch.append(item)
        if len(batch) >= batch_size:
            yield batch
            batch = []
    if batch:
        yield batch

store_embeddings(embeddings, image_paths, location_ids)


Created new collection: film_locations
Starting upsert of 954 points...
Progress: 50/954 (5.2%)
Progress: 100/954 (10.5%)
Progress: 150/954 (15.7%)
Progress: 200/954 (21.0%)
Progress: 250/954 (26.2%)
Progress: 300/954 (31.4%)
Progress: 350/954 (36.7%)
Progress: 400/954 (41.9%)
Progress: 450/954 (47.2%)
Progress: 500/954 (52.4%)
Progress: 550/954 (57.7%)
Progress: 600/954 (62.9%)
Progress: 650/954 (68.1%)
Progress: 700/954 (73.4%)
Progress: 750/954 (78.6%)
Progress: 800/954 (83.9%)
Progress: 850/954 (89.1%)
Progress: 900/954 (94.3%)
Progress: 950/954 (99.6%)
Progress: 954/954 (100.0%)
Completed storing 954 embeddings in film_locations
Collection now contains count=954 points


## Similarity Search (run this for the similarity search below to work)

In [2]:
import matplotlib.pyplot as plt
from qdrant_client import QdrantClient, models
from PIL import Image

# Collection parameters
COLLECTION_NAME = "film_locations"

def text_search(query: str, top_k: int = 5):
    """Search using text query."""
    with torch.no_grad():
        text_inputs = processor(text=query, return_tensors="pt", padding=True)
        text_inputs = {k: v.to(DEVICE) for k, v in text_inputs.items()}
        text_features = model.get_text_features(**text_inputs)
        query_embedding = torch.nn.functional.normalize(text_features, dim=-1)
    
    results = client.search(
        collection_name=COLLECTION_NAME,
        query_vector=query_embedding[0].cpu().numpy().tolist(),
        limit=top_k
    )
    
    return [{
        "score": hit.score,
        "image_path": hit.payload["image_path"],
        "location": hit.payload["location_id"]
    } for hit in results]

def plot_search_results(results, query, query_type="Text"):
    n = len(results)
    fig, axs = plt.subplots(1, n, figsize=(4*n, 4))
    if n == 1:
        axs = [axs]
    for ax, res in zip(axs, results):
        img = Image.open(res["image_path"])
        ax.imshow(img)
        ax.axis('off')
        ax.set_title(f"Loc: {res['location']}\nScore: {res['score']:.2f}")
    plt.show()

def image_search(image_path: str, top_k: int = 5):
    """Search using image query."""
    image = Image.open(image_path).convert("RGB")
    with torch.no_grad():
        image_inputs = processor(images=image, return_tensors="pt")
        image_inputs = {k: v.to(DEVICE) for k, v in image_inputs.items()}
        image_features = model.get_image_features(**image_inputs)
        query_embedding = torch.nn.functional.normalize(image_features, dim=-1)
    
    results = client.search(
        collection_name=COLLECTION_NAME,
        query_vector=query_embedding[0].cpu().numpy().tolist(),
        limit=top_k
    )
    
    return [{
        "score": hit.score,
        "image_path": hit.payload["image_path"],
        "location": hit.payload["location_id"]
    } for hit in results]    

# QDRANT path
QDRANT_PATH = "./qdrant_data"

# Initialize Qdrant client in local (on-disk) mode
client = QdrantClient(path=QDRANT_PATH)

## Similarity search query (run above for this to work)

In [None]:
query = "outside nature look with mountains"
text_results = text_search(query, top_k=5)
plot_search_results(text_results, query, query_type="Text")

#image_results = image_search("path/to/your/query_image.jpg", top_k=5)
#for res in image_results:
#    print(f"{res['score']:.3f} | {res['location']}: {res['image_path']}")

In [4]:

save_embeddings(embeddings, 'fgclip_embeddings.npy')
'''
save_faiss_index(index, 'faiss_index.index')
save_metadata(image_paths, location_ids, 'metadata.pkl')
'''

Embeddings saved to fgclip_embeddings.npy


"\nsave_faiss_index(index, 'faiss_index.index')\nsave_metadata(image_paths, location_ids, 'metadata.pkl')\n"

In [3]:
import os
import pickle

def load_embeddings(filename='fgclip_embeddings.npy'):
    if os.path.exists(filename):
        embeddings = np.load(filename)
        print(f"Embeddings loaded from {filename}, shape: {embeddings.shape}")
        return embeddings
    else:
        print(f"File {filename} not found.")
        return None
    
def load_faiss_index(filename='faiss_index.index'):
    if os.path.exists(filename):
        index = faiss.read_index(filename)
        print(f"FAISS index loaded from {filename}")
        return index
    else:
        print(f"File {filename} not found.")
        return None
    
def load_metadata(filename='metadata.pkl'):
    if os.path.exists(filename):
        with open(filename, 'rb') as f:
            data = pickle.load(f)
        print(f"Metadata loaded from {filename}")
        return data['image_paths'], data['location_ids']
    else:
        print(f"File {filename} not found.")
        return None, None

In [4]:
# load saved data
embeddings = load_embeddings('fgclip_embeddings.npy')
'''
index = load_faiss_index('faiss_index.index')
image_paths, location_ids = load_metadata('metadata.pkl')

# Initiate retrieval system
retrieval_system = LocationRetrievalSystem(index, image_paths, location_ids)
'''

Embeddings loaded from fgclip_embeddings.npy, shape: (823, 768)


"\nindex = load_faiss_index('faiss_index.index')\nimage_paths, location_ids = load_metadata('metadata.pkl')\n\n# Initiate retrieval system\nretrieval_system = LocationRetrievalSystem(index, image_paths, location_ids)\n"