In [1]:
from qdrant_client import QdrantClient
from qdrant_client.http import models

In [2]:
client = QdrantClient(host="localhost", port=6333)

In [3]:
collections = client.get_collections()
print("Kết nối thành công! Các collections:", collections)

Kết nối thành công! Các collections: collections=[CollectionDescription(name='shape_collection'), CollectionDescription(name='image_collection'), CollectionDescription(name='text_collection')]


In [4]:
# Tạo collection cho text embeddings
client.recreate_collection(
    collection_name="text_collection",
    vectors_config=models.VectorParams(
        size=1280,  # Kích thước vector của bạn là 1280
        distance=models.Distance.COSINE
    )
)

  client.recreate_collection(


True

In [6]:
# Tạo collection cho image embeddings
client.recreate_collection(
    collection_name="image_collection",
    vectors_config=models.VectorParams(
        size=1280,
        distance=models.Distance.COSINE
    )
)

  client.recreate_collection(


True

In [7]:
# Tạo collection cho shape embeddings
client.recreate_collection(
    collection_name="shape_collection",
    vectors_config=models.VectorParams(
        size=1280,
        distance=models.Distance.COSINE
    )
)

  client.recreate_collection(


True

In [4]:
import os
import numpy as np
from qdrant_client import QdrantClient
from qdrant_client import models
import glob
from tqdm import tqdm

In [5]:
# Định nghĩa đường dẫn đến dữ liệu
BASE_DIR = "D:/private" # Thay đổi thành đường dẫn chính nếu cần
TEXT_EMBED_DIR = os.path.join(BASE_DIR, "text_embed")
IMAGE_EMBED_DIR = os.path.join(BASE_DIR, "image_embed")
SHAPE_EMBED_DIR = os.path.join(BASE_DIR, "objects_dataset_npy_10000/objects")
OBJECTS_DIR = os.path.join(BASE_DIR, "objects_dataset/objects")
AUGMENT_DIR = os.path.join(BASE_DIR, "augment2d_dataset/objects")
SCENES_DIR = os.path.join(BASE_DIR, "scenes")

In [11]:
def upload_text_embeddings():
    uuid_dirs = [d for d in os.listdir(TEXT_EMBED_DIR) if os.path.isdir(os.path.join(TEXT_EMBED_DIR, d))]

    for uuid in tqdm(uuid_dirs):
        embed_path = os.path.join(TEXT_EMBED_DIR, uuid, "text_embed.npy")
        query_path = os.path.join(SCENES_DIR, uuid, "query.txt")
        if os.path.exists(embed_path):
            embedding = np.load(embed_path)
            with open(query_path, "r") as f:
                query = f.read()


            client.upsert(
                collection_name="text_collection",
                points=[
                    models.PointStruct(
                        id=uuid.replace("-", ""),
                        vector=embedding.flatten().tolist(),
                        payload={
                            "uuid": uuid,
                            "type": "text",
                            "query": query
                        }
                    )
                ]
            )


In [12]:
upload_text_embeddings()

100%|██████████| 50/50 [00:02<00:00, 21.41it/s]


In [19]:
def upload_image_embeddings():
    uuid_dirs = [
        d for d in os.listdir(IMAGE_EMBED_DIR)
        if os.path.isdir(os.path.join(IMAGE_EMBED_DIR, d))
    ]
    point_id = 0

    for uuid in uuid_dirs:
        uuid_dir = os.path.join(IMAGE_EMBED_DIR, uuid)
        embed_files = glob.glob(os.path.join(uuid_dir, "*.npy"))

        for embed_file in tqdm(embed_files):
            file_name = os.path.basename(embed_file)
            origin_path = os.path.join(OBJECTS_DIR, uuid, "image.jpg")
            if file_name[-5] == "1": image_path = origin_path
            else: image_path = os.path.join(AUGMENT_DIR, file_name)


            embedding = np.load(embed_file)

            # unique_id = file_name.replace(".npy", "").replace("-", "").replace("_", "img")
            unique_id = f"{uuid.replace('-', '')}_{os.path.splitext(file_name)[0]}"

            client.upsert(
                collection_name="image_collection",
                points=[
                    models.PointStruct(
                        id=point_id,
                        vector=embedding.flatten().tolist(),
                        payload={
                            "uuid": uuid,
                            "type": "image",
                            "image_path": image_path,
                            "origin_path": origin_path,
                            "file_name": file_name,
                        }
                    )
                ]
            )
            point_id += 1


In [20]:
upload_image_embeddings()

100%|██████████| 13/13 [00:00<00:00, 36.41it/s]
100%|██████████| 13/13 [00:00<00:00, 47.75it/s]
100%|██████████| 13/13 [00:00<00:00, 46.30it/s]
100%|██████████| 13/13 [00:00<00:00, 49.11it/s]
100%|██████████| 13/13 [00:00<00:00, 41.34it/s]
100%|██████████| 13/13 [00:00<00:00, 42.07it/s]
100%|██████████| 13/13 [00:00<00:00, 42.19it/s]
100%|██████████| 13/13 [00:00<00:00, 41.50it/s]
100%|██████████| 13/13 [00:00<00:00, 44.89it/s]
100%|██████████| 13/13 [00:00<00:00, 51.46it/s]
100%|██████████| 13/13 [00:00<00:00, 41.03it/s]
100%|██████████| 13/13 [00:00<00:00, 41.43it/s]
100%|██████████| 13/13 [00:00<00:00, 45.92it/s]
100%|██████████| 13/13 [00:00<00:00, 35.63it/s]
100%|██████████| 13/13 [00:00<00:00, 40.66it/s]
100%|██████████| 13/13 [00:00<00:00, 46.76it/s]
100%|██████████| 13/13 [00:00<00:00, 38.96it/s]
100%|██████████| 13/13 [00:00<00:00, 59.71it/s]
100%|██████████| 13/13 [00:00<00:00, 44.04it/s]
100%|██████████| 13/13 [00:00<00:00, 48.14it/s]
100%|██████████| 13/13 [00:00<00:00, 45.

In [15]:
def upload_shape_embeddings():
    uuid_dirs = [
        d for d in os.listdir(SHAPE_EMBED_DIR)
        if os.path.isdir(
            os.path.join(SHAPE_EMBED_DIR, d)
        )
    ]

    for uuid in tqdm(uuid_dirs):
        embed_path = os.path.join(SHAPE_EMBED_DIR, uuid, "shape_embedding.npy")

        if os.path.exists(embed_path):
            embedding = np.load(embed_path)

            client.upsert(
                collection_name="shape_collection",
                points=[
                    models.PointStruct(
                        id=uuid.replace("-", ""),
                        vector=embedding.flatten().tolist(),
                        payload={
                            "uuid": uuid,
                            "type": "shape",
                            "model_path": os.path.join(SHAPE_EMBED_DIR, uuid, "normalized_model.npy")
                        }
                    )
                ]
            )

In [16]:
upload_shape_embeddings()

100%|██████████| 50/50 [00:01<00:00, 25.59it/s]


In [17]:
# Test thu truy van
scroll_result = client.scroll(
    collection_name="text_collection",
    limit=2,
    with_vectors=True
)
vector = scroll_result[0][1].vector
query = scroll_result[0][1].payload["query"]
uuid = scroll_result[0][1].payload["uuid"]
query, uuid

('a twin-sized bed frame with a folded-down trundle bed, featuring dark brown wooden legs and slats, a light gray fabric headboard and footboard, and two matching gray upholstered seat cushions with decorative buttons.',
 '08df38e7-b9ec-40d1-8652-b1857959a6c7')

In [18]:
search_result = client.search(
    collection_name="shape_collection",
    query_vector=vector,
    limit=5
)

for i, hit in enumerate(search_result):
    print(f"{i}. ID: {hit.id}, Score: {hit.score}")

print("-" * 20)
search_result = client.search(
    collection_name="image_collection",
    query_vector=vector,
    limit=5
)

for i, hit in enumerate(search_result):
    print(f"{i}. ID: {hit.payload['uuid']}, Score: {hit.score}")

0. ID: ec366e16-681d-4621-ac33-56536ce237a1, Score: 0.11819115
1. ID: 27c0c74b-d03c-476c-a8e0-01d2546fc894, Score: 0.11192697
2. ID: a63e6333-b3b8-4487-b3ae-7c8c5e3092e8, Score: 0.101223946
3. ID: 24af0746-d902-450e-b0b5-98ab28db8c2d, Score: 0.101177335
4. ID: 94e953b8-4c64-4475-accb-335b1a120e48, Score: 0.099268466
--------------------
0. ID: a63e6333-b3b8-4487-b3ae-7c8c5e3092e8, Score: 0.3459917
1. ID: a63e6333-b3b8-4487-b3ae-7c8c5e3092e8, Score: 0.3375696
2. ID: a63e6333-b3b8-4487-b3ae-7c8c5e3092e8, Score: 0.3360669
3. ID: a63e6333-b3b8-4487-b3ae-7c8c5e3092e8, Score: 0.33280343
4. ID: a63e6333-b3b8-4487-b3ae-7c8c5e3092e8, Score: 0.32907182


  search_result = client.search(
  search_result = client.search(
