### Tutorial

We will use [**LlamaIndex**](https://www.llamaindex.ai/) with the [`llamaindex/vdr-2b-multi-v1`](https://huggingface.co/llamaindex/vdr-2b-multi-v1) model to generate multimodal embeddings, and [**Qdrant**](https://qdrant.tech) to store and retrieve them.

In [None]:
%pip install llama-index-embeddings-huggingface qdrant-client datasets Pillow tqdm

In [None]:
from qdrant_client import QdrantClient, models
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from datasets import load_dataset
from PIL import Image
import os
import json
from tqdm.notebook import tqdm

DATASET_NAME = "sujet-ai/Sujet-Finance-QA-Vision-100k"
SPLIT_NAME = "train"
NUM_IMAGES = 2000  # Number of dataset items to index
TEMP_IMAGE_DIR = "temp_images"  # Images are saved here; later cells embed them by file path

os.makedirs(TEMP_IMAGE_DIR, exist_ok=True)

print(f"Loading dataset {DATASET_NAME}...")
try:
    dataset = load_dataset(DATASET_NAME, split=SPLIT_NAME)
    # Clamp to the split size: selecting past the end raises, which would
    # otherwise needlessly trigger the streaming fallback below.
    dataset = dataset.select(range(min(NUM_IMAGES, len(dataset))))
except Exception as e:
    print(f"Error loading dataset: {e}")
    print("Attempting to stream...")
    dataset = load_dataset(DATASET_NAME, split=SPLIT_NAME, streaming=True)
    dataset = dataset.take(NUM_IMAGES)


print(f"Processing {NUM_IMAGES} items...")
# These three lists are index-aligned: each kept item contributes exactly one
# entry to every list, and a skipped/failed item contributes to none of them.
image_paths = []  # local PNG path of each kept item
contents = []  # text content of each kept item
metadata_list = []  # {"doc_id": ..., "qa_pairs": <JSON string>} for each kept item

for i, item in enumerate(tqdm(dataset, total=NUM_IMAGES)):
    try:
        doc_id = item.get("doc_id")
        if doc_id is None:  # also covers an explicit None value, not just a missing key
            doc_id = f"item_{i}"
        doc_id = str(doc_id)
        content = item.get("content") or ""  # None would break text embedding later
        qa_pairs = item.get("qa_pairs", [])
        image = item.get("image")

        if image is None or not hasattr(image, "save"):
            print(f"Warning: Skipping item {i} due to missing or invalid image.")
            continue

        # Serialize before touching any list so a failure here cannot leave
        # image_paths/contents one entry ahead of metadata_list.
        qa_pairs_json = json.dumps(qa_pairs)

        # "/" is replaced so doc_id cannot create subdirectories; the index keeps
        # filenames unique even if doc_ids repeat.
        image_filename = f"{doc_id.replace('/', '_')}_{i}.png"
        image_path = os.path.join(TEMP_IMAGE_DIR, image_filename)
        image.save(image_path)

        image_paths.append(image_path)
        contents.append(content)
        metadata_list.append({"doc_id": doc_id, "qa_pairs": qa_pairs_json})

    except Exception as e:
        print(f"Error processing item {i} (doc_id: {item.get('doc_id', 'N/A')}): {e}")
        continue


print(f"Successfully processed {len(image_paths)} items.")

In [None]:
# Connection settings come from the environment rather than being hard-coded, so
# no API key ends up saved in the notebook. For Qdrant Cloud, set QDRANT_URL to the
# cluster URL and QDRANT_API_KEY to its key; otherwise a local instance is assumed.
client = QdrantClient(
    url=os.environ.get("QDRANT_URL", "http://localhost:6333"),
    api_key=os.environ.get("QDRANT_API_KEY"),
)

Next, let's load the embedding model, which maps both the document images and their text content into a **shared embedding space**.

In [None]:
import torch

# Use the GPU when one is available; otherwise fall back to CPU (much slower)
# so the notebook still runs on machines without CUDA.
model = HuggingFaceEmbedding(
    model_name="llamaindex/vdr-2b-multi-v1",
    device="cuda" if torch.cuda.is_available() else "cpu",
    trust_remote_code=True,
)

In [None]:
BATCH_SIZE = 32  # Adjust this based on your VRAM capacity


def _embed_in_batches(items, embed_batch, desc):
    """Embed ``items`` in slices of ``BATCH_SIZE`` and return one flat list of vectors.

    ``embed_batch`` is one of the model's ``get_*_embedding_batch`` methods. Its own
    progress bar is turned off so that only the outer per-batch bar is displayed.
    """
    vectors = []
    for start in tqdm(range(0, len(items), BATCH_SIZE), desc=desc):
        chunk = items[start:start + BATCH_SIZE]
        vectors.extend(embed_batch(chunk, show_progress=False))
    return vectors


print("Generating text embeddings...")
if contents:
    text_embeddings = _embed_in_batches(
        contents, model.get_text_embedding_batch, "Text Embedding Batches"
    )
else:
    text_embeddings = []
    print("Warning: No content found to embed.")


print("Generating image embeddings...")
if image_paths:
    image_embeddings = _embed_in_batches(
        image_paths, model.get_image_embedding_batch, "Image Embedding Batches"
    )
else:
    image_embeddings = []
    print("Warning: No images found to embed.")

# Sanity-check the two modalities before they are paired into points.
if not (text_embeddings and image_embeddings):
    print("Error: Embedding generation failed. Cannot proceed.")
elif len(text_embeddings) == len(image_embeddings):
    print("Embeddings generated successfully.")
else:
    print(
        f"Error: Mismatch in number of text ({len(text_embeddings)}) and image ({len(image_embeddings)}) embeddings."
    )

Create a **Collection**

In [None]:
COLLECTION_NAME = "sujet-finance-multi"

# One named vector per modality; each vector size is taken from the first
# embedding of that kind instead of being hard-coded.
if not (image_embeddings and text_embeddings):
    print("Skipping collection creation due to embedding errors.")
elif client.collection_exists(COLLECTION_NAME):
    print(f"Collection {COLLECTION_NAME} already exists.")
else:
    print(f"Creating collection: {COLLECTION_NAME}")
    named_vectors = {
        vector_name: models.VectorParams(
            size=len(vectors[0]),
            distance=models.Distance.COSINE,
        )
        for vector_name, vectors in (("image", image_embeddings), ("text", text_embeddings))
    }
    client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=named_vectors,
    )

Now let's upload our images, together with their text content, to the **Collection**. Each image and its text form a single [Point](https://qdrant.tech/documentation/concepts/points/) in Qdrant, holding one named vector per modality.

In [None]:
# Upload one point per processed item, carrying both a "text" and an "image" vector.
# Both embedding lists must line up with metadata_list; checking only one of them
# would let a shorter list raise IndexError halfway through building the points.
num_items = len(metadata_list)
if (
    text_embeddings
    and image_embeddings
    and len(text_embeddings) == num_items
    and len(image_embeddings) == num_items
):
    print(f"Uploading {num_items} points to collection {COLLECTION_NAME}...")
    client.upload_points(
        collection_name=COLLECTION_NAME,
        points=[
            models.PointStruct(
                id=idx,
                vector={
                    "text": text_vector,
                    "image": image_vector,
                },
                payload={
                    "doc_id": meta["doc_id"],
                    "content": content,
                    "qa_pairs_json": meta["qa_pairs"],
                    "image_path": image_path,
                },
            )
            for idx, (text_vector, image_vector, meta, content, image_path) in enumerate(
                zip(text_embeddings, image_embeddings, metadata_list, contents, image_paths)
            )
        ],
        wait=True,
    )
    print("Upload complete.")
else:
    print("Skipping point upload due to inconsistencies in generated data.")

In [None]:
import shutil

# Each uploaded point's "image_path" payload points into TEMP_IMAGE_DIR, so deleting
# the directory here makes search results' images impossible to open afterwards.
# Set this to True (or re-run this cell at the very end) to free the disk space.
DELETE_TEMP_IMAGES = False

if not DELETE_TEMP_IMAGES:
    print(f"Keeping temporary image directory: {TEMP_IMAGE_DIR}")
elif os.path.exists(TEMP_IMAGE_DIR):
    print(f"Removing temporary image directory: {TEMP_IMAGE_DIR}")
    try:
        shutil.rmtree(TEMP_IMAGE_DIR)
        print("Temporary directory removed.")
    except OSError as e:
        print(f"Error removing temporary directory {TEMP_IMAGE_DIR}: {e}")

print(f"\nStored metadata for {len(metadata_list)} items.")
if metadata_list:
    first_meta = metadata_list[0]
    qa_preview = first_meta["qa_pairs"]  # JSON string produced during processing
    if len(qa_preview) > 100:
        qa_preview = qa_preview[:100] + "..."
    print("Example metadata (first item):")
    print(f"  doc_id: {first_meta['doc_id']}")
    print(f"  qa_pairs (JSON string): {qa_preview}")

Let's see which image we get for the query "*Adventures on snow hills*":

In [None]:
from PIL import Image

# The query text is generic; the collection is built from a finance dataset,
# so expect the closest-matching page rather than a literal match.
find_image = model.get_query_embedding("Adventures on snow hills")

# Search the "image" vectors with a text query. Uploaded points store the image
# location under "image_path" (there is no "image" payload key).
hit = client.query_points(
    collection_name=COLLECTION_NAME,
    query=find_image,
    using="image",
    with_payload=["doc_id", "image_path"],
    limit=1,
).points[0]

# The temp image directory may have been removed by the cleanup cell; in that
# case show the payload instead of failing on a missing file.
image_path = hit.payload["image_path"]
Image.open(image_path) if os.path.exists(image_path) else hit.payload

Let's also run the same query in Italian and compare the results.

In [None]:
# Same search as above, with the query phrased in Italian. Uploaded points store
# the image location under "image_path" (there is no "image" payload key).
hit_it = client.query_points(
    collection_name=COLLECTION_NAME,
    query=model.get_query_embedding("Avventure sulle colline innevate"),
    using="image",
    with_payload=["doc_id", "image_path"],
    limit=1,
).points[0]

# Fall back to the payload if the temp image directory has already been removed.
image_path_it = hit_it.payload["image_path"]
Image.open(image_path_it) if os.path.exists(image_path_it) else hit_it.payload

Now let's do a reverse search, using the following image as the query:

In [None]:
# Display the image used as the reverse-search query below.
# NOTE(review): images/image-2.png is not created by this notebook — assumes the
# file ships alongside it; confirm the path before running.
Image.open('images/image-2.png')

In [None]:
client.query_points(
    collection_name=COLLECTION_NAME,
    query=model.get_image_embedding("images/image-2.png"),
    # Now we are searching only among text vectors with our image query
    using="text",
    # The page text is stored under "content"; there is no "caption" payload key.
    with_payload=["content"],
    limit=1,
).points[0].payload["content"]