In [None]:
# %pip install lightrag-hku
# %pip install pipmaster

In [None]:
import os
import asyncio
import nest_asyncio
nest_asyncio.apply()

import numpy as np

from pathlib import Path
from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc, setup_logger
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.llm.ollama import ollama_embed
from lightrag.llm.azure_openai import azure_openai_complete, azure_openai_complete_if_cache
from mistralai import Mistral
import ollama


In [None]:
from dotenv import load_dotenv

# Load variables from a local .env file (if present) into os.environ;
# by default existing environment variables are not overridden.
load_dotenv()

# Mistral credentials and model names (models fall back to the defaults below).
API_KEY = os.getenv("MISTRAL_API_KEY")
COMPLETION_MODEL = os.getenv("MISTRAL_MODEL", "mistral-small-latest")
MISTRAL_EMBED_MODEL = os.getenv("MISTRAL_EMBED_MODEL", "mistral-embed")
# Ollama embedding model used by ollama_embed; configurable like the Mistral settings.
# NOTE: the RAG config declares embedding_dim=768, which matches nomic-embed-text —
# change it too if you point this at a different model.
OLLAMA_EMBED_MODEL = os.getenv("OLLAMA_EMBED_MODEL", "nomic-embed-text")

In [None]:

# Shared Mistral SDK client, used by the embedding and LLM wrappers defined below.
# NOTE(review): API_KEY is None when MISTRAL_API_KEY is unset; Mistral calls will
# presumably fail at request time in that case — confirm the key is configured.
mistral_client = Mistral(api_key=API_KEY)

In [None]:
# --- Mistral embeddings, requested in fixed-size batches ---
async def get_mistral_embeddings(texts: list[str], chunk_size: int = 50) -> list[list[float]]:
    """
    Embed ``texts`` with the Mistral SDK, sending at most ``chunk_size`` texts per request.

    The SDK call is blocking, so every request runs through ``asyncio.to_thread`` to keep
    the event loop free. Requests are issued sequentially and the returned vectors keep
    the same order as ``texts``.
    """
    batches = [texts[start:start + chunk_size] for start in range(0, len(texts), chunk_size)]

    vectors: list[list[float]] = []
    for batch in batches:
        response = await asyncio.to_thread(
            mistral_client.embeddings.create,
            model=MISTRAL_EMBED_MODEL,
            inputs=batch,
        )
        vectors.extend(item.embedding for item in response.data)
    return vectors

# --- Async wrapper for LLM using Mistral SDK ---
async def mistral_llm(prompt: str, system_prompt: str = None, history_messages: list = None, **kwargs) -> str:
    """
    LightRAG-compatible LLM callback backed by Mistral's chat completion API.

    The blocking SDK call runs in a worker thread via ``asyncio.to_thread``.

    Args:
        prompt: The user message for this turn.
        system_prompt: Optional system instruction; a generic assistant prompt is used
            when it is empty or None.
        history_messages: Optional prior turns as ``{"role": ..., "content": ...}`` dicts;
            they are sent between the system prompt and ``prompt``.
        **kwargs: Extra options supplied by the caller; accepted for signature
            compatibility and ignored.

    Returns:
        The text content of the first completion choice.
    """
    messages = [{"role": "system", "content": system_prompt or "You are a helpful assistant."}]
    # Prior conversation turns must precede the new user message, otherwise the model
    # answers without the context the caller provided.
    messages.extend(history_messages or [])
    messages.append({"role": "user", "content": prompt})

    def call_sdk():
        resp = mistral_client.chat.complete(model=COMPLETION_MODEL, messages=messages)
        return resp.choices[0].message.content

    # Delegate the blocking call to a thread so the event loop is not blocked.
    return await asyncio.to_thread(call_sdk)

def _make_ollama_embedding_func(
    host: str = "http://localhost:11434",
    embedding_dim: int = 768,
    max_token_size: int = 8192,
) -> EmbeddingFunc:
    """
    Build the Ollama embedding function shared by ``initialize_rag`` and ``load_rag``.

    Both entry points must embed with the same model and dimension, because vectors
    already persisted in the working directory are compared against vectors produced here.

    Args:
        host: URL of the Ollama server.
        embedding_dim: Vector size produced by ``OLLAMA_EMBED_MODEL`` (768 for nomic-embed-text).
        max_token_size: Token limit handed to LightRAG's ``EmbeddingFunc``.
    """
    return EmbeddingFunc(
        embedding_dim=embedding_dim,
        max_token_size=max_token_size,
        func=lambda texts: ollama_embed(texts, embed_model=OLLAMA_EMBED_MODEL, host=host),
    )


async def _open_rag(working_dir: str, llm_model_func) -> LightRAG:
    """
    Create a LightRAG instance with the shared embedding config and initialize it.

    ``initialize_storages`` opens the KV/vector/graph stores under ``working_dir``,
    loading any existing data instead of recreating it.
    """
    setup_logger("lightrag", level="INFO")
    rag = LightRAG(
        working_dir=working_dir,
        embedding_func=_make_ollama_embedding_func(),
        llm_model_func=llm_model_func,
    )
    await rag.initialize_storages()
    await initialize_pipeline_status()
    return rag


async def initialize_rag(working_dir: str, llm_model_func=None):
    """
    Initialize a LightRAG instance using Ollama embeddings.

    Args:
        working_dir: Directory where LightRAG persists its stores.
        llm_model_func: LLM callback; defaults to ``azure_openai_complete``.

    Returns:
        The initialized ``LightRAG`` instance.
    """
    return await _open_rag(working_dir, llm_model_func or azure_openai_complete)


async def load_rag(working_dir: str, llm_model_func=None):
    """
    Re-open an existing LightRAG store with the same embedding config as ``initialize_rag``.

    Args:
        working_dir: Directory containing the previously built stores.
        llm_model_func: LLM callback; defaults to ``mistral_llm``.

    Returns:
        The initialized ``LightRAG`` instance.
    """
    # NOTE(review): the default LLM here (mistral_llm) differs from initialize_rag's
    # (azure_openai_complete). Embeddings — which must match the stored vectors — are
    # shared; pass llm_model_func explicitly if the same LLM is wanted in both.
    return await _open_rag(working_dir, llm_model_func or mistral_llm)

In [None]:
# Directory where LightRAG persists its KV, vector and graph stores.
WORK_DIR = "./rag_storage"

# Create the RAG instance (top-level `await` works because the notebook kernel runs cells in an event loop).
rag = await initialize_rag(WORK_DIR)

# Alternative: re-open the existing store via load_rag (note: it is configured with a different LLM function).
# rag = await load_rag(WORK_DIR)

In [None]:
# Raw text corpus to index (path relative to this notebook's directory).
TXT_DIR = Path("../../knowledge_extraction/txt/raw")

# glob() order depends on the filesystem; sort so every run inserts documents in the same order.
txt_files = sorted(str(p) for p in TXT_DIR.glob("*.txt") if p.is_file())
assert txt_files, f"No .txt files found in {TXT_DIR.resolve()}"

text_contents = [Path(p).read_text(encoding="utf-8") for p in txt_files]

# Drop the "../" (POSIX) and "..\\" (Windows) parent hops so stored source paths
# do not depend on where the notebook lives.
source_filepaths = [p.replace("..\\", "").replace("../", "") for p in txt_files]
source_filepaths

In [None]:
rag.insert(input=text_contents, file_paths=source_filepaths)

In [None]:
# Ask a question
question = "comment faire une esquisse dans catia ?"
response = rag.query(
    question,
    param=QueryParam(mode="hybrid", top_k=20)  # hybrid local+global retrieval
)

print(f"Q: {question}\nA: {response}")

In [None]:
# Ensure storages are properly closed
# await rag.finalize_storages()