In [5]:
import pandas as pd
import uuid
import torch
import clip
from qdrant_client import QdrantClient
from dotenv import load_dotenv
import os
from tqdm import tqdm
import json
from qdrant_client.models import PointStruct
from PIL import Image

In [7]:
# Run CLIP on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# clip.load returns the model together with the image transform that is later
# applied to every PIL image before it is stacked into a batch.
model, preprocess = clip.load("ViT-B/32", device=device)
# NOTE(review): an explicit model.eval() call is left disabled here — confirm
# that clip.load already hands back the model in inference mode.

In [8]:
# Populate the process environment (python-dotenv reads a .env file by default)
# so the Qdrant credentials never have to be written into the notebook.
load_dotenv()

# Name of the Qdrant collection every cell below reads from and writes to.
COLLECTION_NAME = "GNOSIS"

# Either value is None when the corresponding variable is not set.
QDRANT_URL = os.environ.get("QDRANT_URL")
QDRANT_KEY = os.environ.get("QDRANT_KEY")

client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_KEY)


In [9]:
def load_image_safe(path):
    """Load an image from disk as an RGB PIL image, or return None on failure.

    Args:
        path: Filesystem path of the image to open.

    Returns:
        A new RGB image produced by ``convert("RGB")``, or ``None`` when the
        file cannot be opened or decoded (the failure is printed, not raised,
        so one bad file does not abort a long ingestion run).
    """
    try:
        # Image.open is lazy and keeps the underlying file open. convert()
        # forces the pixel data to be read and returns an independent image,
        # so the handle can be released as soon as the `with` block exits
        # instead of leaking one descriptor per ingested file.
        with Image.open(path) as handle:
            return handle.convert("RGB")
    except Exception as e:
        print("❌ Failed to load:", path, e)
        return None


In [10]:
# Root folder holding the "real" and "fake" image sub-folders.
# Set GNOSIS_IMAGE_ROOT (in the environment or the .env file) to run the
# notebook on another machine; the original local path remains the default.
IMAGE_ROOT = os.getenv(
    "GNOSIS_IMAGE_ROOT",
    r"D:/STUDY/PROJECTS/GNOSIS/Resources/Images",
)

In [11]:
# Number of preprocessed images embedded per CLIP forward pass.
BATCH_SIZE = 32
# Minimum number of buffered points before a Qdrant upsert is issued.
UPLOAD_BATCH = 256

# Working buffers filled by the ingestion loop:
#   image_batch   - preprocessed image tensors waiting to be embedded
#   meta_batch    - payload dicts, index-aligned with image_batch
#   points_buffer - built points waiting to be uploaded
image_batch, meta_batch, points_buffer = [], [], []


In [12]:
# Start from empty buffers so re-running this cell never re-embeds or
# re-uploads items left over from a previous (possibly interrupted) run.
image_batch, meta_batch, points_buffer = [], [], []

for label in ["real", "fake"]:
    folder = os.path.join(IMAGE_ROOT, label)

    # os.listdir returns entries in arbitrary order; sort for reproducible runs.
    for fname in tqdm(sorted(os.listdir(folder)), desc=f"Processing {label}"):

        path = os.path.join(folder, fname)
        img = load_image_safe(path)
        if img is None:
            # Unreadable file: already reported by load_image_safe, skip it.
            continue

        image_batch.append(preprocess(img))

        meta_batch.append({
            "modality": "image",
            "label": label,
            "source_path": path,
            "filename": fname,
            "domain": "medical",   # or political later
            "dataset": "Medical_Image_Demo",
        })

        # =========================
        # When batch full → embed
        # =========================
        if len(image_batch) < BATCH_SIZE:
            continue

        images_tensor = torch.stack(image_batch).to(device)

        with torch.no_grad():
            vecs = model.encode_image(images_tensor).cpu().numpy()

        # Build points
        for vec, meta in zip(vecs, meta_batch):
            # Deterministic ID derived from dataset/label/filename: upserting
            # the same image again overwrites its point instead of adding a
            # duplicate, which a random uuid4 would do on every re-run.
            point_key = f"{meta['dataset']}/{meta['label']}/{meta['filename']}"
            points_buffer.append(
                PointStruct(
                    id=str(uuid.uuid5(uuid.NAMESPACE_URL, point_key)),
                    vector={
                        "vision": vec.tolist()
                    },
                    payload=meta
                )
            )

        image_batch = []
        meta_batch = []

        # =========================
        # Upload chunk
        # =========================
        if len(points_buffer) >= UPLOAD_BATCH:
            client.upsert(collection_name=COLLECTION_NAME, points=points_buffer)
            print("⬆️ Uploaded", len(points_buffer), "image points")
            points_buffer = []


Processing real:  35%|███▍      | 267/765 [00:09<00:54,  9.17it/s]

⬆️ Uploaded 256 image points


Processing real:  68%|██████▊   | 517/765 [00:17<00:29,  8.45it/s]

⬆️ Uploaded 256 image points


Processing real: 100%|██████████| 765/765 [00:23<00:00, 32.02it/s]
Processing fake:   1%|          | 10/897 [00:02<03:07,  4.74it/s]

⬆️ Uploaded 256 image points


Processing fake:  30%|██▉       | 265/897 [00:08<00:29, 21.22it/s]

⬆️ Uploaded 256 image points


Processing fake:  58%|█████▊    | 520/897 [00:15<00:50,  7.53it/s]

⬆️ Uploaded 256 image points


Processing fake:  87%|████████▋ | 779/897 [00:23<00:09, 12.01it/s]

⬆️ Uploaded 256 image points


Processing fake: 100%|██████████| 897/897 [00:25<00:00, 35.17it/s]


In [13]:
# Embed whatever is left in the last, partially filled batch.
if image_batch:
    images_tensor = torch.stack(image_batch).to(device)

    with torch.no_grad():
        vecs = model.encode_image(images_tensor).cpu().numpy()

    for vec, meta in zip(vecs, meta_batch):
        # Deterministic ID derived from dataset/label/filename so that
        # upserting the same image again overwrites its point instead of
        # creating a duplicate (a random uuid4 is new on every run).
        point_key = f"{meta['dataset']}/{meta['label']}/{meta['filename']}"
        points_buffer.append(
            PointStruct(
                id=str(uuid.uuid5(uuid.NAMESPACE_URL, point_key)),
                vector={
                    "vision": vec.tolist()
                },
                payload=meta
            )
        )

    # Clear the consumed batch so re-running this cell does not embed it again.
    image_batch = []
    meta_batch = []

# Upload every point still waiting in the buffer.
if points_buffer:
    client.upsert(collection_name=COLLECTION_NAME, points=points_buffer)
    print("⬆️ Final upload:", len(points_buffer))
    # Clear after a successful upload so a re-run is a no-op.
    points_buffer = []

print("✅ Image ingestion complete")


⬆️ Final upload: 126
✅ Image ingestion complete


In [15]:
# Load the query image. Image.open keeps the file open until closed, so use a
# context manager; convert() returns an independent RGB image that stays valid
# after the handle is released.
with Image.open("test.jpeg") as raw_query:
    query_img = raw_query.convert("RGB")

# preprocess yields a single image tensor; add a leading batch dimension of 1.
img_tensor = preprocess(query_img).unsqueeze(0).to(device)

with torch.no_grad():
    # Take row 0: the batch holds exactly one image.
    qvec = model.encode_image(img_tensor).cpu().numpy()[0]

# Nearest-neighbour search against the named "vision" vector of the collection.
result = client.query_points(
    collection_name=COLLECTION_NAME,
    query=qvec.tolist(),
    using="vision",
    limit=5,
    with_payload=True
)

# One line per hit: similarity score followed by the stored filename.
for hit in result.points:
    print(hit.score, hit.payload.get("filename"))


0.99999905 Ags4N6XvgDwDCU8gxwcpBmupXbkH6akCetnwNWxaIoWF.jpeg
0.99999905 Ags4N6XvgDwDCU8gxwcpBmupXbkH6akCetnwNWxaIoWF.jpeg
0.92613494 AiYPWxuanAdVGba89NylTcRpNXWK8zfK39ouZQIRoXZs.jpeg
0.9221602 Ap5wy9cEivyD5Sh6Y2OhyaKBdaLLOeaDXICgPkK5KaHy.jpeg
0.91311544 ArbNSctglNjG6bALN4HGcZe8nZOImUqmZ6gLe1DeR7M0.jpeg
