In [10]:
!pip install qdrant-client==1.7.3 git+https://github.com/openai/CLIP.git \
opencv-python pillow matplotlib tqdm -q


  Preparing metadata (setup.py) ... done


In [11]:
from google.colab import drive

# Expose Google Drive under /content/drive so DATASET_ROOT / OUTPUT_ROOT resolve.
# Re-running is harmless: an already-mounted Drive only produces a notice.
drive.mount("/content/drive")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [21]:
# Source images are searched recursively; routed copies land in OUTPUT_ROOT/<type>/<metal>/.
DATASET_ROOT = "/content/drive/MyDrive/ring"               # input: walked with os.walk
OUTPUT_ROOT = "/content/drive/MyDrive/clustered_images"    # output: <type>/<metal>/ sub-folders


In [13]:
import os, hashlib, shutil
import filecmp
from collections import Counter

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance, PointStruct

import clip

# In-process Qdrant: collections live only in this kernel's RAM.
client = QdrantClient(":memory:")
print("✅ Qdrant running in RAM")

# Use the GPU when the runtime has one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

clip_model, preprocess = clip.load("ViT-B/32", device=device)
clip_model.eval()  # inference only
print("Device:", device)


✅ Qdrant running in RAM
Device: cpu


In [14]:
TYPES = ["ring", "necklace"]
METALS = ["gold", "silver", "copper"]

# One text prompt per TYPES entry. route_embedding maps the winning row index straight
# back into TYPES, so the prompts are derived from TYPES to keep the two aligned.
TYPE_VOCAB = [f"{t} jewellery" for t in TYPES]

# Several prompts per metal are averaged into one prototype (prompt ensembling).
# NOTE(review): some prompts mention "ring", which may skew metal scores for
# necklace images — consider type-neutral wording.
METAL_PROMPTS = {
    "gold": ["gold jewelry", "yellow gold ring", "luxury gold jewellery"],
    "silver": ["silver jewelry", "polished silver ring"],
    "copper": ["copper jewelry", "reddish copper ring"]
}
# route_embedding returns METAL_PROMPTS keys, which then index `hierarchy`
# (built from METALS) — a mismatch would raise KeyError in the middle of the scan.
assert set(METAL_PROMPTS) == set(METALS), "METAL_PROMPTS keys must match METALS"


def _encode_text(prompts):
    """Return L2-normalised CLIP text embeddings for `prompts`, one row per prompt."""
    with torch.no_grad():
        emb = clip_model.encode_text(clip.tokenize(prompts).to(device))
    return emb / emb.norm(dim=-1, keepdim=True)


type_emb = _encode_text(TYPE_VOCAB)

metal_text_emb = {}
for metal, prompts in METAL_PROMPTS.items():
    # The mean of unit vectors is shorter than 1 by an amount that depends on how spread
    # out the prompts are. Without renormalising, prototypes have different lengths and
    # the dot products in route_embedding favour the longest one rather than the closest.
    mean_emb = _encode_text(prompts).mean(dim=0)
    metal_text_emb[metal] = mean_emb / mean_emb.norm()


In [15]:
def get_clip_embedding(image):
    """Encode one PIL image with CLIP and return its L2-normalised embedding.

    The returned tensor keeps a leading batch axis of size 1 (callers take ``[0]``)
    and is computed on ``device``.
    """
    batch = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = clip_model.encode_image(batch)
        unit = features / features.norm(dim=-1, keepdim=True)
    return unit

def route_embedding(embedding):
    """Pick the (type, metal) bucket whose text prototypes best match `embedding`.

    `embedding` is a single 1-D image vector (e.g. ``get_clip_embedding(img)[0]``).
    The type is the TYPES entry whose TYPE_VOCAB row scores highest; the metal is
    the METAL_PROMPTS key whose prototype has the largest dot product (first wins on ties).
    """
    type_scores = (embedding @ type_emb.T).cpu().numpy()
    best_type = TYPES[np.argmax(type_scores)]

    def metal_score(metal):
        return torch.dot(embedding, metal_text_emb[metal].to(device)).item()

    best_metal = max(metal_text_emb, key=metal_score)
    return best_type, best_metal

def image_id(path):
    """Deterministic Qdrant point id for `path`: its MD5 digest reduced modulo 10**12."""
    digest = hashlib.md5(path.encode()).digest()
    return int.from_bytes(digest, "big") % 10**12


In [16]:
# os.walk order depends on the filesystem; sorting makes routing (and therefore the
# copy / indexing order in later cells) reproducible across runs.
image_paths = sorted(
    os.path.join(root, f)
    for root, _, files in os.walk(DATASET_ROOT)
    for f in files
    if f.lower().endswith((".jpg", ".jpeg", ".png"))
)

print("Images found:", len(image_paths))

hierarchy = {(t, m): [] for t in TYPES for m in METALS}
unreadable = []

for p in tqdm(image_paths):
    # A single corrupt/unreadable file used to abort the whole scan; record and skip it.
    try:
        with Image.open(p) as img:
            image = img.convert("RGB")  # convert() loads pixels, so the file can close
    except OSError as exc:  # includes PIL.UnidentifiedImageError
        unreadable.append((p, exc))
        continue
    emb = get_clip_embedding(image)[0]
    t, m = route_embedding(emb)
    hierarchy[(t, m)].append(p)

if unreadable:
    print(f"⚠️ Skipped {len(unreadable)} unreadable image(s):")
    for p, exc in unreadable:
        print("  ", p, "-", exc)


Images found: 82


100%|██████████| 82/82 [00:22<00:00,  3.70it/s]


In [17]:
# CLIP image-embedding width (512 for ViT-B/32). Read it from the model so that changing
# the checkpoint in the setup cell cannot silently mismatch the collection schema.
embed_dim = clip_model.visual.output_dim

for (t, m), paths in hierarchy.items():
    # Collection names mirror the output folder layout ("type/metal"); force_insert builds
    # the same names, so change both together.
    # NOTE(review): "/" works with the in-memory client but may be rejected by a Qdrant server.
    name = f"{t}/{m}"

    # recreate_collection drops any previous contents, so re-running this cell is idempotent.
    client.recreate_collection(
        collection_name=name,
        vectors_config=VectorParams(size=embed_dim, distance=Distance.COSINE)
    )

    # NOTE(review): this re-embeds every image already embedded by the routing cell;
    # caching those embeddings there would avoid the second CLIP pass.
    points = []
    for p in tqdm(paths, desc=name):
        with Image.open(p) as img:
            image = img.convert("RGB")
        vector = get_clip_embedding(image)[0].cpu().numpy().tolist()
        points.append(PointStruct(id=image_id(p), vector=vector, payload={"path": p}))

    # One batched upsert per collection instead of one call per image.
    if points:
        client.upsert(collection_name=name, points=points)

print("✅ RAM DB ready")


100%|██████████| 52/52 [00:12<00:00,  4.22it/s]
100%|██████████| 17/17 [00:04<00:00,  4.17it/s]
100%|██████████| 4/4 [00:01<00:00,  3.14it/s]
100%|██████████| 9/9 [00:02<00:00,  3.82it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]

✅ RAM DB ready





In [22]:
for (t, m), paths in hierarchy.items():

    folder = os.path.join(OUTPUT_ROOT, t, m)
    os.makedirs(folder, exist_ok=True)

    print(f"Saving → {t}/{m} | {len(paths)}")

    # The dataset is walked recursively, so images from different sub-folders can share a
    # basename; copying them flat into `folder` would silently overwrite all but the last.
    # Names occurring more than once in this bucket get a short path-derived suffix
    # (deterministic, so re-running the cell rewrites the same files).
    name_counts = Counter(os.path.basename(p) for p in paths)

    for p in tqdm(paths):
        name = os.path.basename(p)
        if name_counts[name] > 1:
            stem, ext = os.path.splitext(name)
            name = f"{stem}_{hashlib.md5(p.encode()).hexdigest()[:8]}{ext}"
        shutil.copy2(p, os.path.join(folder, name))

# NOTE(review): files from earlier runs are never removed, so an image whose routing changed
# can end up in two folders; clear OUTPUT_ROOT first if an exact mirror is needed.
print("✅ Images saved")


Saving → ring/gold | 52


100%|██████████| 52/52 [00:00<00:00, 71.57it/s]


Saving → ring/silver | 17


100%|██████████| 17/17 [00:00<00:00, 74.99it/s]


Saving → ring/copper | 4


100%|██████████| 4/4 [00:00<00:00, 72.98it/s]


Saving → necklace/gold | 9


100%|██████████| 9/9 [00:00<00:00, 62.20it/s]


Saving → necklace/silver | 0


0it [00:00, ?it/s]


Saving → necklace/copper | 0


0it [00:00, ?it/s]

✅ Images saved





In [19]:
def force_insert(new_path):
    """Route, index and copy new images into the existing type/metal hierarchy.

    Parameters
    ----------
    new_path : str
        A single image file, or a directory searched recursively for
        .jpg/.jpeg/.png files (case-insensitive extension match).

    Each image is embedded with CLIP, routed with ``route_embedding``, upserted
    into the matching ``"<type>/<metal>"`` collection (created if missing) and
    copied to ``OUTPUT_ROOT/<type>/<metal>/``.

    Notes
    -----
    * A missing ``new_path`` is reported instead of silently processing 0 images.
    * Unreadable images are reported and skipped rather than aborting the batch.
    * A *different* file already present under the same name in the target folder
      is not overwritten; the new copy gets an ``_<md5[:8]>`` suffix instead.
    * The global ``hierarchy`` is not updated, so re-running the collection-build
      cell (which recreates collections from ``hierarchy``) drops these points.
    """
    if not os.path.exists(new_path):
        print("⚠️ Path not found:", new_path)
        return

    if os.path.isfile(new_path):
        paths = [new_path]
    else:
        # Sorted for a reproducible insertion order.
        paths = sorted(
            os.path.join(root, f)
            for root, _, files in os.walk(new_path)
            for f in files
            if f.lower().endswith((".jpg", ".jpeg", ".png"))
        )

    print("Images:", len(paths))

    # List collections once and track creations locally instead of re-listing per image.
    existing = {c.name for c in client.get_collections().collections}

    for p in tqdm(paths):
        try:
            with Image.open(p) as img:
                image = img.convert("RGB")
        except OSError as exc:  # includes PIL.UnidentifiedImageError
            print("⚠️ Skipping unreadable image:", p, f"({exc})")
            continue

        emb = get_clip_embedding(image)[0]
        t, m = route_embedding(emb)
        collection = f"{t}/{m}"

        if collection not in existing:
            client.create_collection(
                collection_name=collection,
                vectors_config=VectorParams(size=clip_model.visual.output_dim, distance=Distance.COSINE)
            )
            existing.add(collection)

        client.upsert(
            collection_name=collection,
            points=[PointStruct(
                id=image_id(p),
                vector=emb.cpu().numpy().tolist(),
                payload={"path": p}
            )]
        )

        folder = os.path.join(OUTPUT_ROOT, t, m)
        os.makedirs(folder, exist_ok=True)
        name = os.path.basename(p)
        dst = os.path.join(folder, name)
        # Never clobber a different image that happens to share this filename.
        if os.path.exists(dst) and not filecmp.cmp(p, dst, shallow=False):
            stem, ext = os.path.splitext(name)
            dst = os.path.join(folder, f"{stem}_{hashlib.md5(p.encode()).hexdigest()[:8]}{ext}")
        shutil.copy2(p, dst)

        print("Inserted →", collection)

    print("✅ Force insert complete")


In [20]:
# Route and index any images dropped into this Drive folder after the initial build.
force_insert(new_path="/content/drive/MyDrive/new_images")


Images: 0


0it [00:00, ?it/s]

✅ Force insert complete



