In [None]:
!pip install datasets
!pip install sentence_transformers
!pip install scikit-learn
!pip install qdrant-client sentence-transformers

In [None]:
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import numpy as np
from google.colab import userdata
from qdrant_client import QdrantClient
from qdrant_client.http import models
from tqdm import tqdm
import uuid

In [None]:
# Load the "ruslanmv/ai-medical-chatbot" dataset with datasets.load_dataset.
# Later cells read its "train" split (ds["train"]).
ds = load_dataset("ruslanmv/ai-medical-chatbot")

In [None]:
# Quick look at the data: show the "Description" field of the first training row.
ds["train"][0]["Description"]

In [None]:
# Read the patient questions and doctor answers column-wise from the train
# split. This avoids building a dict for every row, and it also works for an
# empty split (unpacking zip(*[]) into two names raises ValueError).
questions = tuple(ds["train"]["Patient"])
answers = tuple(ds["train"]["Doctor"])

# Paired (question, answer) view, kept as a list of tuples under the same name.
qa_pairs = list(zip(questions, answers))

In [None]:
# Initialize model with GPU
# Loads the "all-MiniLM-L6-v2" sentence-embedding model and requests the CUDA
# device for it.
# NOTE(review): device="cuda" is hard-coded, so this presumably fails on a
# runtime without a GPU — confirm a GPU runtime is selected before running.
embed_model = SentenceTransformer("all-MiniLM-L6-v2", device="cuda")

In [None]:
# ---- Config ----
# Qdrant endpoint and API key are read from Colab's userdata store, so the
# credentials are not hard-coded in the notebook.
QDRANT_HOST = userdata.get("QDRANT_HOST")  # passed as `url` to QdrantClient
QDRANT_API_KEY = userdata.get("QDRANT_API_KEY")
# Name of the Qdrant collection that is created and filled further down.
COLLECTION_NAME = "ruslanmv-ai-medical-chatbot"

In [None]:
# ---- Initialize Qdrant Client ----
# One shared client for the Qdrant instance at QDRANT_HOST, authenticated
# with QDRANT_API_KEY; every later Qdrant call goes through it.
client = QdrantClient(url=QDRANT_HOST, api_key=QDRANT_API_KEY)

In [None]:
# ---- Create Collection ----
# Create the collection only when it does not exist yet, so re-running this
# cell leaves an already populated collection untouched.
if not client.collection_exists(collection_name=COLLECTION_NAME):
    print("Creating collection", COLLECTION_NAME)
    client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=models.VectorParams(
            # Ask the model for its embedding width instead of hard-coding it
            # (384 for all-MiniLM-L6-v2), so the collection always matches
            # whatever model embed_model was loaded with.
            size=embed_model.get_sentence_embedding_dimension(),
            distance=models.Distance.COSINE,
        ),
    )
else:
    # NOTE(review): an existing collection is reused as-is; its vector size is
    # not checked against the model here.
    print("Collection already exists")

In [None]:
# Process embeddings in batches
# batch_size is also reused by the upload loop as the Qdrant upsert batch size.
batch_size = 64  # Larger batch size for GPU

# Encode every question into one dense vector (returned as a NumPy array).
# - questions is a tuple; encode() documents a list of strings, so pass a list.
# - No device override: encode() runs on the device embed_model was loaded on,
#   which keeps the device choice in a single place.
question_embeddings = embed_model.encode(
    list(questions),
    batch_size=batch_size,
    convert_to_numpy=True,
    show_progress_bar=True,
)

# The upload step pairs questions and vectors by position, so they must line up.
assert len(question_embeddings) == len(questions), "expected one embedding per question"

In [None]:
# Upload to Qdrant
# Send the questions and their embeddings to the collection in chunks of
# batch_size points.
#
# Point IDs are derived deterministically (uuid5) from the collection name and
# the question's position, so re-running this cell overwrites the same points
# instead of adding a fresh copy of the whole dataset under new random IDs.
# Points written by an earlier run with random IDs are not replaced.
#
# NOTE(review): only the question text is stored in the payload; `answers` is
# not uploaded. Confirm whether retrieval needs the answer stored as well.
for start in tqdm(range(0, len(questions), batch_size)):
    batch_questions = questions[start : start + batch_size]
    batch_vectors = question_embeddings[start : start + batch_size]

    points_batch = [
        models.PointStruct(
            id=str(uuid.uuid5(uuid.NAMESPACE_URL, f"{COLLECTION_NAME}/{start + offset}")),
            # Hand Qdrant a plain list of Python floats rather than a NumPy row.
            vector=vector.tolist(),
            payload={"question": question},
        )
        for offset, (question, vector) in enumerate(zip(batch_questions, batch_vectors))
    ]
    client.upsert(
        collection_name=COLLECTION_NAME,
        points=points_batch,
    )