In [27]:
import torch
from sentence_transformers import SentenceTransformer
import chromadb
# Run the encoder on the GPU when torch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Sentence encoder shared by the indexing and query helpers in this notebook.
model = SentenceTransformer(
    'distiluse-base-multilingual-cased-v2',
    device=device,
)
# Chroma client backed by the local ./chromadb-docs directory.
client = chromadb.PersistentClient(path="./chromadb-docs")

def get_data(id: int, data: str):
    """Embed one document and store it in the "embeddings" Chroma collection.

    The collection is fetched, or created on first use with cosine distance
    (``"hnsw:space": "cosine"``), every time this function is called.

    Args:
        id: Identifier for the document; it is stored as ``str(id)``.
            (The name shadows the builtin but is kept so existing keyword
            callers keep working.)
        data: Document text to encode and index.

    Returns:
        None. The effect is the write to the collection.
    """
    vector = model.encode(data).tolist()
    collection = client.get_or_create_collection(
        name="embeddings",
        metadata={"hnsw:space": "cosine"},
    )
    # The collection API is batch-shaped: every field is a list with one entry
    # per record, so the single embedding is wrapped just like the document
    # and the id.
    # upsert (rather than add) keeps re-running the notebook idempotent: an
    # id that already exists in the persistent store is overwritten with the
    # current text/embedding instead of being left stale.
    collection.upsert(
        documents=[data],
        embeddings=[vector],
        ids=[str(id)],
    )

def get_collection_length():
    """Return the record count reported by the "embeddings" collection."""
    return client.get_collection(name="embeddings").count()

def query(question: str, n_results: int = 2):
    """Find the stored documents most similar to a question.

    Args:
        question: Text to encode and search with.
        n_results: Maximum number of matches to request from the collection.
            Defaults to 2, the value that was previously hard-coded.

    Returns:
        A list of ``{'id': ..., 'score': ...}`` dicts, in the order the
        collection returned the matches, where ``score`` is ``1 - distance``.
    """
    question_vector = model.encode(question).tolist()
    collection = client.get_collection(name="embeddings")
    results = collection.query(
        # A single query vector, passed in the explicit batch-of-one form.
        query_embeddings=[question_vector],
        n_results=n_results,
    )
    # Results are nested one list per query; index 0 is our only query.
    matched_ids = results['ids'][0]
    distances = results['distances'][0]

    # get_data creates the collection with cosine distance, so 1 - distance
    # reads as a cosine similarity (higher means more similar).
    return [
        {'id': doc_id, 'score': 1 - distance}
        for doc_id, distance in zip(matched_ids, distances)
    ]






Question: Where is the Eiffel Tower located?
Results: [{'id': '0', 'score': 0.7442867185246337}, {'id': '4', 'score': 0.20982066127848387}]


In [None]:
# Small sample corpus; each entry is indexed under its list position as the id.
test_docs = [
    "The Eiffel Tower is located in Paris.",
    "The Great Wall of China is visible from space.",
    "The Mona Lisa is a famous painting by Leonardo da Vinci.",
    "The Pyramids of Giza are one of the Seven Wonders of the Ancient World.",
    "The Great Wall of China is the longest wall in the world.",
]

# Sample questions to run against the indexed corpus.
test_questions = [
    "Where is the Eiffel Tower located?",
    "Can you see the Great Wall of China from space?",
    "Who painted the Mona Lisa?",
    "What are the Pyramids of Giza?",
]

# Index every sample document, reporting the collection size after each
# insert followed by a blank separator line.
for i, doc in enumerate(test_docs):
    get_data(i, doc)
    print(f"length: {get_collection_length()}", end="\n\n")

In [None]:
# Smoke-test retrieval: run the first sample question through query() and
# print the question next to the id/score pairs it returns.
test = test_questions[0]
results = query(test)
print(f"Question: {test}")
print("Results:", results)