In [None]:
#https://towardsdev.com/building-visionrag-ai-powered-image-search-with-llama-3-2-qdrant-and-litserve-10bf22df5d41
#https://medium.com/kx-systems/guide-to-multimodal-rag-for-images-and-text-10dab36e3117

In [1]:
import base64
import getpass
import mimetypes
import os
import uuid
from glob import glob

import requests
import torch
from PIL import Image
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance, PointStruct
from sentence_transformers import SentenceTransformer
from transformers import CLIPProcessor, CLIPModel

In [2]:
# Do not hardcode the key in the notebook and do not clobber a key that is
# already exported in the environment: prompt (without echo) only when missing.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API key: ")

In [3]:
# ========= 1. GPT-4o OCR =========
def gpt4o_ocr(image_path, api_key=None, model="gpt-4o-mini", prompt=None, timeout=60):
    """Send one image to the OpenAI chat-completions endpoint and return the reply text.

    Args:
        image_path: Path of the image file to upload (sent inline as a base64 data URL).
        api_key: OpenAI API key. Falls back to the OPENAI_API_KEY environment variable.
        model: Chat-completions model name, e.g. "gpt-4o-mini" or "gpt-4o".
        prompt: Instruction sent together with the image. Defaults to the
            battery-brand prompt this notebook was written for.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The message content of the first choice in the API response.

    Raises:
        ValueError: If no API key is supplied or found in the environment.
        requests.HTTPError: If the API responds with an error status.
    """
    # Always get the API key from argument or environment
    api_key = api_key or os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
    if prompt is None:
        prompt = "Tell me the brand of this battery. Just give me the brand name. Do not give other explanation and text. If you cannot recogize, Just tell us NA:"
    endpoint = "https://api.openai.com/v1/chat/completions"

    # Declare the real media type (JPEG files are ingested too); fall back to
    # PNG, the previously hard-coded value, when the extension is not an image type.
    mime_type, _ = mimetypes.guess_type(image_path)
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/png"

    with open(image_path, "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")

    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
    payload = {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{img_b64}"}},
                ],
            }
        ],
        "max_tokens": 512,
    }
    # Without a timeout a stalled connection would block the notebook cell forever.
    resp = requests.post(endpoint, headers=headers, json=payload, timeout=timeout)
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]

In [4]:
# ========= 2. Embedding =========
# Run on the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Sentence encoder used for text embeddings (GTE-base), loaded onto `device`.
text_model = SentenceTransformer("thenlper/gte-base", device=device)

In [5]:
# Image embedding (CLIP): both the preprocessor and the model come from the
# same "openai/clip-vit-base-patch32" checkpoint.
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_model = clip_model.to(device)

In [6]:
def embed_text(text):
    """Encode `text` with the module-level `text_model` and return the NumPy result."""
    vector = text_model.encode(text, convert_to_numpy=True)
    return vector

def embed_image(image_path):
    """Return the CLIP image-feature vector for the image at `image_path` as NumPy.

    The image is opened with PIL, converted to RGB, preprocessed by the
    module-level `clip_processor`, and encoded by `clip_model` on `device`
    with gradient tracking disabled.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    batch = clip_processor(images=rgb_image, return_tensors="pt")
    batch = batch.to(device)
    with torch.no_grad():
        features = clip_model.get_image_features(**batch)
    # Drop the singleton batch dimension before handing the vector back.
    flat = features.squeeze()
    return flat.cpu().numpy()

In [16]:
# ========= 3. Qdrant =========
# Qdrant client backed by a local path. The default is a root-level directory
# that a normal user often cannot write to, so the location can be overridden
# with the QDRANT_PATH environment variable; the default itself is unchanged.
client = QdrantClient(path=os.environ.get("QDRANT_PATH", "/qdrant_data"))

In [17]:
# (Re)create the "multimodal_docs" collection with two named vectors per point:
# "text_embedding" (768-d) and "image_embedding" (512-d), both compared by cosine
# distance. The sizes are presumably the output widths of the text and CLIP
# encoders loaded above -- verify if either model is swapped.
client.recreate_collection(
    collection_name="multimodal_docs",
    vectors_config=dict(
        text_embedding=VectorParams(size=768, distance=Distance.COSINE),
        image_embedding=VectorParams(size=512, distance=Distance.COSINE),
    ),
)

In [18]:
# ========= 4. Ingest an image: OCR text, text embedding, image embedding =========
def insert_image_doc(image_path):
    """Describe, embed and store one image in the "multimodal_docs" collection.

    The image is sent to `gpt4o_ocr`; the returned text is embedded with
    `embed_text` and the image itself with `embed_image`. Both vectors are
    upserted as a single point whose payload keeps the source path and the text.

    The point id is derived deterministically from the absolute image path, so
    ingesting the same file again overwrites its point instead of adding a
    duplicate (a random id per call would make `upsert` behave like a plain insert).

    Args:
        image_path: Path of the image file to ingest.
    """
    ocr_text = gpt4o_ocr(image_path)
    text_vec = embed_text(ocr_text)
    img_vec = embed_image(image_path)

    # Stable id: same file path -> same UUID on every run.
    point_id = str(uuid.uuid5(uuid.NAMESPACE_URL, os.path.abspath(image_path)))
    client.upsert(
        collection_name="multimodal_docs",
        points=[
            PointStruct(
                id=point_id,
                vector={
                    "text_embedding": text_vec.tolist(),
                    "image_embedding": img_vec.tolist(),
                },
                payload={
                    "source": image_path,
                    "ocr_text": ocr_text,
                },
            )
        ],
    )
    print(f"Inserted {image_path} with OCR: {ocr_text[:50]}...")


In [19]:
# ========= 5. Query =========
def search_by_text(query, top_k=3):
    """Search "multimodal_docs" by text.

    The query string is embedded with `embed_text` and matched against the
    "text_embedding" named vector; at most `top_k` hits are returned.
    """
    named_query = ("text_embedding", embed_text(query).tolist())
    return client.search(
        collection_name="multimodal_docs",
        query_vector=named_query,
        limit=top_k,
    )

def search_by_image(image_path, top_k=3):
    """Search "multimodal_docs" by image.

    The image at `image_path` is embedded with `embed_image` and matched
    against the "image_embedding" named vector; at most `top_k` hits are returned.
    """
    named_query = ("image_embedding", embed_image(image_path).tolist())
    return client.search(
        collection_name="multimodal_docs",
        query_vector=named_query,
        limit=top_k,
    )

In [20]:
# ========= 1. Read all images in the folder =========
def insert_folder_images(folder_path, extensions=(".png", ".jpg", ".jpeg")):
    """Ingest every image directly inside `folder_path` via `insert_image_doc`.

    Files are selected by extension, compared case-insensitively so that e.g.
    ".JPG" is not silently skipped on case-sensitive filesystems, and processed
    in sorted order so repeated runs are deterministic. A failure on one image
    is reported and does not stop the remaining images (best-effort ingest).

    Args:
        folder_path: Directory to scan (not recursive).
        extensions: File extensions, including the leading dot, to ingest.
    """
    wanted = tuple(ext.lower() for ext in extensions)
    image_paths = sorted(
        path
        for path in glob(os.path.join(folder_path, "*"))
        if os.path.splitext(path)[1].lower() in wanted
    )

    if not image_paths:
        # Make a wrong or empty folder visible instead of silently doing nothing.
        print(f"No images found in {folder_path}")

    for img_path in image_paths:
        try:
            insert_image_doc(img_path)
        except Exception as e:
            print(f"❌ Failed {img_path}: {e}")

In [21]:
# ========= 2. Query function: text, image, or hybrid =========
def rag_query(query_text=None, query_image=None, mode="text", top_k=3):
    """Dispatch a retrieval query to the text, image, or hybrid search.

    Args:
        query_text: Query string; required for mode="text".
        query_image: Path of the query image; required for mode="image".
        mode: One of "text", "image" or "hybrid".
        top_k: Maximum number of hits to return.

    Returns:
        Whatever the selected search function returns.

    Raises:
        ValueError: If `mode` is unknown or the input required by the mode is missing.
        NotImplementedError: If mode="hybrid" is requested but no `search_hybrid`
            function has been defined.
    """
    if mode not in ("text", "image", "hybrid"):
        raise ValueError("mode must be 'text', 'image', or 'hybrid'")

    # Fail early with a clear message instead of passing None into the encoders.
    if mode == "text" and not query_text:
        raise ValueError("query_text is required when mode='text'")
    if mode == "image" and not query_image:
        raise ValueError("query_image is required when mode='image'")
    if mode == "hybrid" and not (query_text or query_image):
        raise ValueError("query_text and/or query_image is required when mode='hybrid'")

    if mode == "text":
        return search_by_text(query_text, top_k)
    if mode == "image":
        return search_by_image(query_image, top_k)

    # `search_hybrid` is not defined alongside the other search helpers; turn
    # the bare NameError into an explicit, actionable error.
    try:
        hybrid_search = search_hybrid
    except NameError:
        raise NotImplementedError(
            "mode='hybrid' requires a search_hybrid(query_text, query_image, alpha, top_k) "
            "function, which is not defined"
        ) from None
    return hybrid_search(query_text=query_text, query_image=query_image, alpha=0.7, top_k=top_k)

In [23]:
if __name__ == "__main__":
    # Ingest every supported image found in the image folder
    # (path is relative to the current working directory).
    folder = "image/rag"
    insert_folder_images(folder_path=folder)

In [30]:
# Demo: retrieve the stored images whose OCR text is closest to a text query.
print("\n🔍 Text Query RAG:")
hits = rag_query(
    query_text="Give me panasonic battery images",
    mode="text",
    top_k=3,
)
for h in hits:
    # Each hit's payload holds the fields stored at ingest time.
    print(h.payload)

In [29]:
# Demo: retrieve the stored images closest to a query image.
print("\n🖼 Image Query RAG:")
hits = rag_query(
    query_image="image/Panasonic-Alkaline-AA.jpg",
    mode="image",
    top_k=3,
)

for h in hits:
    # Each hit's payload holds the fields stored at ingest time.
    print(h.payload)