# Image Search


In [1]:
from dotenv import load_dotenv
import os
import google.generativeai as genai

genai.configure(api_key=os.getenv("GEMINI_API_KEY"))

load_dotenv()


True

In [5]:
from pymilvus import MilvusClient


milvus_client = MilvusClient(uri=os.getenv("MILVUS_ENDPOINT"))


## Indexing

```
git clone https://github.com/FlagOpen/FlagEmbedding.git
cd FlagEmbedding/research/visual_bge
pip install -e .
!curl -O https://huggingface.co/BAAI/bge-visualized/resolve/main/Visualized_base_en_v1.5.pth
```

In [2]:
import torch
from FlagEmbedding.research.visual_bge.modeling import Visualized_BGE


class Encoder:
    def __init__(self, model_name: str, model_path: str):
        self.model = Visualized_BGE(model_name_bge=model_name, model_weight=model_path)
        self.model.eval()

    def encode_query(self, image_path: str, text: str) -> list[float]:
        with torch.no_grad():
            query_emb = self.model.encode(image=image_path, text=text)
        return query_emb.tolist()[0]

    def encode_image(self, image_path: str) -> list[float]:
        with torch.no_grad():
            query_emb = self.model.encode(image=image_path)
        return query_emb.tolist()[0]
    
    def encode_text(self, text: str) -> list[float]:
        with torch.no_grad():
            query_emb = self.model.encode(text=text)
        return query_emb.tolist()[0]


model_name = "BAAI/bge-base-en-v1.5"
model_path = "./Visualized_base_en_v1.5.pth"
encoder = Encoder(model_name, model_path)

  self.load_state_dict(torch.load(model_weight, map_location='cpu'))


### Get Images

In [3]:
import os
from tqdm import tqdm
from glob import glob


data_dir = "./data"
image_list = glob(os.path.join(data_dir, "images", "*.png"))

image_list[:10]

['./data/images/2024_03_SIMA-1_jpg.png',
 './data/images/ploads_2021_03_Batch-Ad-March-3rd-1_png.png',
 './data/images/2021_08_ID-By-Eyeglasses-1_gif.png',
 './data/images/2024_07_DL_AI-Ad--6-_png.png',
 './data/images/2021_06_TheBatch-WorkingAIOmoju_jpeg.png',
 './data/images/2023_11_unnamed--30--2_jpg.png',
 './data/images/2025_01_The-Batch-ads-and-exclusive-banners--5--1_png.png',
 './data/images/ploads_2021_01_ezgif_com-optimize207_gif.png',
 './data/images/2024_12_unnamed--37-_gif.png',
 './data/images/2022_09_e5e01574-4b75-48c5-a144-8335f407450e_png.png']

### Get image embeddings

In [4]:
image_dict = {}

for image_path in tqdm(image_list, desc="Generating image embeddings: "):
    try:
        image_dict[image_path] = encoder.encode_image(image_path)
    except Exception as e:
        print(f"Failed to generate embedding for {image_path}. Skipped.")
        continue
print("Number of encoded images:", len(image_dict))

Generating image embeddings: 100%|██████████| 3641/3641 [10:59<00:00,  5.52it/s]

Number of encoded images: 3641





### Insert into Milvus

In [None]:
collection_name = "the_batch_image_rag_v2"

# if milvus_client.has_collection(collection_name):
#     milvus_client.drop_collection(collection_name)

In [7]:
dim = len(list(image_dict.values())[0])

milvus_client.create_collection(
    collection_name=collection_name,
    auto_id=True,
    dimension=dim,
    enable_dynamic_field=True,
)

milvus_client.insert(
    collection_name=collection_name,
    data=[{"image_path": k, "vector": v} for k, v in image_dict.items()],
)

{'insert_count': 3641, 'ids': [455940887088726016, 455940887088726017, 455940887088726018, 455940887088726019, 455940887088726020, 455940887088726021, 455940887088726022, 455940887088726023, 455940887088726024, 455940887088726025, 455940887088726026, 455940887088726027, 455940887088726028, 455940887088726029, 455940887088726030, 455940887088726031, 455940887088726032, 455940887088726033, 455940887088726034, 455940887088726035, 455940887088726036, 455940887088726037, 455940887088726038, 455940887088726039, 455940887088726040, 455940887088726041, 455940887088726042, 455940887088726043, 455940887088726044, 455940887088726045, 455940887088726046, 455940887088726047, 455940887088726048, 455940887088726049, 455940887088726050, 455940887088726051, 455940887088726052, 455940887088726053, 455940887088726054, 455940887088726055, 455940887088726056, 455940887088726057, 455940887088726058, 455940887088726059, 455940887088726060, 455940887088726061, 455940887088726062, 455940887088726063, 455940887

In [1]:
image_dict

NameError: name 'image_dict' is not defined

## Search

In [9]:
query_text = "group of friends"

query_vec = encoder.encode_text(text=query_text)

search_results = milvus_client.search(
    collection_name=collection_name,
    data=[query_vec],
    output_fields=["image_path"],
    limit=3,
    search_params={"metric_type": "COSINE", "params": {}},
)[0]

retrieved_images = [hit.get("entity").get("image_path") for hit in search_results]
print(retrieved_images)

['./data/images/2022_07_Screen-Shot-2022-06-22-at-9--1-_jpg.png', './data/images/2022_06_Screen-Shot-2022-06-22-at-9--1-_jpg.png', './data/images/0Shot%202022-03-01%20at%203_35_19%20PM_png?upscale=true&width=1200&upscale=true&name=Screen%20Shot%202022-03-01%20at%203_35_19%20PM_png.png']
