# Image Search


In [None]:
import pandas as pd
from dotenv import load_dotenv
import os
import json

load_dotenv()


In [3]:
import google.generativeai as genai

genai.configure(api_key=os.getenv("GEMINI_API_KEY"))

In [4]:
pd.set_option('display.max_colwidth', 100)

In [5]:
# input_filename = 'data/single_articles_cleaned.csv'
input_filename = 'data/weekly_articles_cleaned.csv'
raw_df = pd.read_csv(input_filename)

In [None]:
df = raw_df.copy()
df.shape

In [None]:
df.tail(5)

# Indexing

```
git clone https://github.com/FlagOpen/FlagEmbedding.git
cd FlagEmbedding/research/visual_bge
pip install -e .
```

In [None]:
!curl -O https://huggingface.co/BAAI/bge-visualized/resolve/main/Visualized_base_en_v1.5.pth

In [11]:
import torch
from FlagEmbedding.research.visual_bge.modeling import Visualized_BGE


class Encoder:
    def __init__(self, model_name: str, model_path: str):
        self.model = Visualized_BGE(model_name_bge=model_name, model_weight=model_path)
        self.model.eval()

    def encode_query(self, image_path: str, text: str) -> list[float]:
        with torch.no_grad():
            query_emb = self.model.encode(image=image_path, text=text)
        return query_emb.tolist()[0]

    def encode_image(self, image_path: str) -> list[float]:
        with torch.no_grad():
            query_emb = self.model.encode(image=image_path)
        return query_emb.tolist()[0]
    
    def encode_text(self, text: str) -> list[float]:
        with torch.no_grad():
            query_emb = self.model.encode(text=text)
        return query_emb.tolist()[0]


model_name = "BAAI/bge-base-en-v1.5"
model_path = "./Visualized_base_en_v1.5.pth"  # Change to your own value if using a different model path
encoder = Encoder(model_name, model_path)

  self.load_state_dict(torch.load(model_weight, map_location='cpu'))


In [None]:

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16") 

## Image list

In [3]:
import os
from tqdm import tqdm
from glob import glob


# Generate embeddings for the image dataset
data_dir = (
    "./data"  # Change to your own value if using a different data directory
)
image_list = glob(
    os.path.join(data_dir, "images", "*.png")
)

image_list[:10]

['./data/images/image_2025_01_unnamed--36-.png.png',
 './data/images/image_2024_12_unnamed--27--1.png.png',
 './data/images/image_2021_06_Andrew20Letter-1-1.gif.png',
 './data/images/image_2025_01_unnamed--46--1.gif.png',
 './data/images/image_2021_08_AI-in-Regions-Rich-and-Poor-1.gif.png',
 './data/images/image_2025_01_unnamed--37-.png.png',
 './data/images/image_ploads_2021_01_Gender20ASPECT201.png.png',
 './data/images/image_2021_09_Perceptrons-Are-All-You-Need-1.gif.png',
 './data/images/image_2024_12_unnamed--33-.png.png',
 './data/images/image_2023_02_unnamed--14-.jpg.png']

In [4]:
image_dict = {}

for image_path in tqdm(image_list, desc="Generating image embeddings: "):
    try:
        image_dict[image_path] = encoder.encode_image(image_path)
    except Exception as e:
        print(f"Failed to generate embedding for {image_path}. Skipped.")
        continue
print("Number of encoded images:", len(image_dict))

Generating image embeddings: 100%|██████████| 3641/3641 [12:05<00:00,  5.02it/s]

Number of encoded images: 3641





In [8]:
from pymilvus import MilvusClient


dim = len(list(image_dict.values())[0])
collection_name = "the_batch_image_rag"

# Connect to Milvus client given URI
milvus_client = MilvusClient(uri="./thebatch_text.db")

# Create Milvus Collection
# By default, vector field name is "vector"
milvus_client.create_collection(
    collection_name=collection_name,
    auto_id=True,
    dimension=dim,
    enable_dynamic_field=True,
)

# Insert data into collection
milvus_client.insert(
    collection_name=collection_name,
    data=[{"image_path": k, "vector": v} for k, v in image_dict.items()],
)

{'insert_count': 3641, 'ids': [455903994051573988, 455903994051573989, 455903994051573990, 455903994051573991, 455903994051573992, 455903994051573993, 455903994051573994, 455903994051573995, 455903994051573996, 455903994051573997, 455903994051573998, 455903994051573999, 455903994051574000, 455903994051574001, 455903994051574002, 455903994051574003, 455903994051574004, 455903994051574005, 455903994051574006, 455903994051574007, 455903994051574008, 455903994051574009, 455903994051574010, 455903994051574011, 455903994051574012, 455903994051574013, 455903994051574014, 455903994051574015, 455903994051574016, 455903994051574017, 455903994051574018, 455903994051574019, 455903994051574020, 455903994051574021, 455903994051574022, 455903994051574023, 455903994051574024, 455903994051574025, 455903994051574026, 455903994051574027, 455903994051574028, 455903994051574029, 455903994051574030, 455903994051574031, 455903994051574032, 455903994051574033, 455903994051574034, 455903994051574035, 455903994

In [12]:
query_text = "walking robot"

# Generate query embedding given image and text instructions
query_vec = encoder.encode_text(text=query_text)

search_results = milvus_client.search(
    collection_name=collection_name,
    data=[query_vec],
    output_fields=["image_path"],
    limit=9,  # Max number of search results to return
    search_params={"metric_type": "COSINE", "params": {}},  # Search parameters
)[0]

retrieved_images = [hit.get("entity").get("image_path") for hit in search_results]
print(retrieved_images)

['./data/images/image_2023_08_fdsfd-2.png.png', './data/images/image_2025_01_unnamed--45--1.gif.png', './data/images/image_2025_01_unnamed--45-.gif.png', './data/images/image_2022_10_51dc4ce5-5182-4040-b8a4-a9ad3867b6b8--1-.png.png', './data/images/image_2022_09_51dc4ce5-5182-4040-b8a4-a9ad3867b6b8.png.png', './data/images/image_ploads_2021_01_Knightscope.gif.png', './data/images/image_2021_07_What-the-Watchbot-Sees-1.gif.png', './data/images/image_2021_07_A-Robot-in-Every-Kitchen-1.gif.png', './data/images/image_ploads_2021_01_Toyota20Robot20Resized.gif.png']


## Gemini Text Embeddings

In [52]:
doc_embeddings = genai.embed_content(
    model="models/text-embedding-004", content=df.text
)["embedding"]

In [53]:
with open('embeddings_weekly.txt', 'w') as file:
    file.write(str(doc_embeddings))

In [54]:
with open('embeddings_weekly.txt', 'r') as file:
    doc_embeddings = json.load(file)

## Milvus Configuration

In [63]:
from pymilvus import MilvusClient

milvus_client = MilvusClient(uri="./thebatch_text.db")

collection_name = "the_batch_text_rag"

In [26]:
# if milvus_client.has_collection(collection_name):
#     milvus_client.drop_collection(collection_name)

In [27]:
milvus_client.create_collection(
    collection_name=collection_name,
    dimension=768,  # "models/text-embedding-004" dimension
    vector_field_name= "text_vector",
    metric_type="IP",  # Inner product distance
    consistency_level="Strong",  # Strong consistency level
)

### Indexing data

In [56]:
data = []

for index, row in df.iterrows():
    data.append({
        "id": index,
        "text_vector": doc_embeddings[index],
        "text": row.text,
        "article_url": row.article_url, 
        "image_url": row.images # .image_header_cleaned
        })

In [None]:
milvus_client.insert(collection_name=collection_name, data=data)

# Index Search

In [59]:
question = "What can you tell me about deepseek?"

In [60]:
question_embedding = genai.embed_content(
    model="models/text-embedding-004", content=question
)["embedding"]

In [64]:
search_res = milvus_client.search(
    collection_name=collection_name,
    data=[question_embedding],
    limit=1,  # Return top 1 results
    search_params={"metric_type": "IP", "params": {}},  # Inner product distance
    output_fields=["text", "article_url", "image_url"],  # Return the text field
)

In [None]:
retrieved_lines_with_distances = [
    (
        res["distance"],
        res["entity"]["article_url"],
        res["entity"]["image_url"],
        res["entity"]["text"].replace(u'\u2019', u'\'')
    ) for res in search_res[0]
]
print(json.dumps(retrieved_lines_with_distances, indent=4))

# LLM Answering

In [None]:
context = "\n".join(
    [line_with_distance[3] for line_with_distance in retrieved_lines_with_distances]
)
print(context)

In [39]:
SYSTEM_PROMPT = """
Human: You are an AI assistant. You are able to find answers to the questions from the articles provided.
"""
USER_PROMPT = f"""
Use the following pieces of information enclosed in <context> tags to provide an answer to the question enclosed in <question> tags.
<context>
{context}
</context>
<question>
{question}
</question>
"""

In [40]:
gemini_model = genai.GenerativeModel(
    "gemini-2.0-flash-lite-preview-02-05", system_instruction=SYSTEM_PROMPT
)


In [None]:
gemini_model

In [None]:
response = gemini_model.generate_content(USER_PROMPT)
print(response.text)

In [None]:
# ULEPSZENIA: chunkowanie dokumentów