In [None]:
# Download and install Dependencies.
! pip install --upgrade pymilvus openai datasets opencv-python timm einops ftfy peft tqdm
! pip install -e git+https://github.com/FlagOpen/FlagEmbedding.git

! wget https://huggingface.co/BAAI/bge-visualized/resolve/main/Visualized_base_en_v1.5.pth

In [1]:
import os
from glob import glob

import torch
from FlagEmbedding.visual.modeling import Visualized_BGE
from pymilvus import MilvusClient
from tqdm import tqdm


class Encoder:
    def __init__(self, model_name: str, model_path: str):
        self.model = Visualized_BGE(model_name_bge=model_name, model_weight=model_path)
        self.model.eval()

    def encode_query(self, image_path: str, text: str) -> list[float]:
        with torch.no_grad():
            query_emb = self.model.encode(image=image_path, text=text)
        return query_emb.tolist()[0]

    def encode_image(self, image_path: str) -> list[float]:
        with torch.no_grad():
            query_emb = self.model.encode(image=image_path)
        return query_emb.tolist()[0]
    
model_name = "BAAI/bge-base-en-v1.5"
model_path = "./Visualized_base_en_v1.5.pth"  # Change to your own value if using a different model path
encoder = Encoder(model_name, model_path)

  from pandas.core import (


In [6]:
# load data

data_dir = (
    "/home/stardust/Downloads/test_images/"  # Change to your own value if using a different data directory
)

image_list = glob(
    os.path.join(data_dir, "images", "*.jpg")
)  # We will only use images ending with ".jpg"
image_dict = {}
for image_path in tqdm(image_list, desc="Generating image embeddings: "):
    try:
        image_dict[image_path] = encoder.encode_image(image_path)
    except Exception:
        print(f"Failed to generate embedding for {image_path}. Skipped.")
        continue
print("Number of encoded images:", len(image_dict))

  self.load_state_dict(torch.load(model_weight, map_location='cpu'))
Generating image embeddings: 100%|██████████| 3000/3000 [00:41<00:00, 71.55it/s]

Number of encoded images: 3000





In [7]:

# Insert into Milvus

dim = len(list(image_dict.values())[0])
collection_name = "multimodal_rag_demo"

milvus_client = MilvusClient(uri="./milvus_demo.db")

milvus_client.create_collection(
    collection_name=collection_name,
    auto_id=True,
    dimension=dim,
    enable_dynamic_field=True,
)

milvus_client.insert(
    collection_name=collection_name,
    data=[{"image_path": k, "vector": v} for k, v in image_dict.items()],
)


DEBUG:pymilvus.milvus_client.milvus_client:Created new connection using: a5e253c6a3e24a4c9cb5140c8bcaa3c8
DEBUG:pymilvus.milvus_client.milvus_client:Successfully created collection: multimodal_rag_demo
DEBUG:pymilvus.milvus_client.milvus_client:Successfully created an index on collection: multimodal_rag_demo


{'insert_count': 3000, 'ids': [452766870258581504, 452766870258581505, 452766870258581506, 452766870258581507, 452766870258581508, 452766870258581509, 452766870258581510, 452766870258581511, 452766870258581512, 452766870258581513, 452766870258581514, 452766870258581515, 452766870258581516, 452766870258581517, 452766870258581518, 452766870258581519, 452766870258581520, 452766870258581521, 452766870258581522, 452766870258581523, 452766870258581524, 452766870258581525, 452766870258581526, 452766870258581527, 452766870258581528, 452766870258581529, 452766870258581530, 452766870258581531, 452766870258581532, 452766870258581533, 452766870258581534, 452766870258581535, 452766870258581536, 452766870258581537, 452766870258581538, 452766870258581539, 452766870258581540, 452766870258581541, 452766870258581542, 452766870258581543, 452766870258581544, 452766870258581545, 452766870258581546, 452766870258581547, 452766870258581548, 452766870258581549, 452766870258581550, 452766870258581551, 452766870

In [8]:
# Multimodal Search with Generative Reranker

query_image = os.path.join(
    data_dir, "image.png"
)  # Change to your own query image path
query_text = "find the waiting-benches of this style"

query_vec = encoder.encode_query(image_path=query_image, text=query_text)

search_results = milvus_client.search(
    collection_name=collection_name,
    data=[query_vec],
    output_fields=["image_path"],
    limit=9,  # Max number of search results to return
    search_params={"metric_type": "COSINE", "params": {}},  # Search parameters
)[0]

retrieved_images = [hit.get("entity").get("image_path") for hit in search_results]
print(retrieved_images)

['/home/stardust/Downloads/test_images/images/cam2.1716384066.989815000.jpg', '/home/stardust/Downloads/test_images/images/cam2.1716384062.989702000.jpg', '/home/stardust/Downloads/test_images/images/cam2.1716382373.31635000.jpg', '/home/stardust/Downloads/test_images/images/cam2.1716382521.968060000.jpg', '/home/stardust/Downloads/test_images/images/cam2.1716384122.3443000.jpg', '/home/stardust/Downloads/test_images/images/cam2.1716382459.54305000.jpg', '/home/stardust/Downloads/test_images/images/cam2.1716384057.981531000.jpg', '/home/stardust/Downloads/test_images/images/cam2.1716382375.971737000.jpg', '/home/stardust/Downloads/test_images/images/cam2.1716382517.63982000.jpg']
