In [1]:
__import__('pysqlite3')
import sys
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
import chromadb
from transformers import AutoTokenizer, AutoModel
import numpy as np
import torch
import json

# Initialize Chroma DB
client = chromadb.Client()

# Create a collection in Chroma DB
collection_name = "rag-example"
collection = client.create_collection(name=collection_name)

# Load the embedding model
model_name = "BAAI/bge-m3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

def embed_text(texts):
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        embeddings = model(**inputs).last_hidden_state.mean(dim=1).numpy()
    return embeddings

# Add some documents to the Chroma DB collection
documents = [
    {"id": "1", "text": "LangChain is a library for building language models."},
    {"id": "2", "text": "Transformers are powerful tools for NLP tasks."},
    {"id": "3", "text": "FAISS is a library for efficient similarity search."}
]
embeddings = embed_text([doc["text"] for doc in documents])
vectors = [(doc, embedding) for doc, embedding in zip(documents, embeddings)]

for doc, embedding in vectors:
    collection.upsert(doc["id"], embedding.tolist(), doc["text"])

# Save collection to a JSON file
def save_collection_to_file(collection, file_path):
    all_documents = collection.get_all()
    with open(file_path, 'w') as f:
        json.dump(all_documents, f)

save_collection_to_file(collection, "collection.json")

# Load collection from a JSON file
def load_collection_from_file(client, collection_name, file_path):
    collection = client.create_collection(name=collection_name)
    with open(file_path, 'r') as f:
        all_documents = json.load(f)
    for doc in all_documents:
        collection.upsert(doc["id"], doc["embedding"], doc["metadata"]["text"])
    return collection

# Define a simple retrieval function
def retrieve(query, k=3):
    query_embedding = embed_text([query])[0]
    results = collection.query(query_embedding.tolist(), top_k=k)
    return [(result["metadata"]["text"], result["score"]) for result in results]

# Example usage
query = "What are transformers used for?"
results = retrieve(query)
for result in results:
    print(result)


  from .autonotebook import tqdm as notebook_tqdm


ValueError: Expected metadatas to be a list, got LangChain is a library for building language models.