In [None]:
import logging
import os
import s3fs

import chainlit as cl
import chainlit.data as cl_data
from langchain.schema.runnable.config import RunnableConfig
from langchain_core.prompts import PromptTemplate

from src.chain_building.build_chain import build_chain
from src.chain_building.build_chain_validator import build_chain_validator
from src.config import CHATBOT_TEMPLATE, EMB_MODEL_NAME
from src.db_building import (
    load_retriever,
    load_vector_database
)
from src.model_building import build_llm_model
from src.results_logging.log_conversations import log_feedback_to_s3, log_qa_to_s3
from src.utils.formatting_utilities import add_sources_to_messages, str_to_bool

# Logging configuration
# Root logger setup: DEBUG level, each record rendered as "<timestamp> <message>"
# with the timestamp on a 12-hour clock followed by AM/PM.
logging.basicConfig(
    level=logging.DEBUG,
    datefmt="%Y-%m-%d %I:%M:%S %p",
    format="%(asctime)s %(message)s",
)
# Logger named after the current module, used by the notebook's own messages.
logger = logging.getLogger(__name__)

# Remote file configuration
# MLflow tracking server, exported through the environment.
os.environ["MLFLOW_TRACKING_URI"] = "https://projet-llm-insee-open-data-mlflow.user.lab.sspcloud.fr/"

# S3 filesystem handle; the endpoint host is read from AWS_S3_ENDPOINT
# (a KeyError is raised here if that variable is not set).
s3_endpoint_host = os.environ["AWS_S3_ENDPOINT"]
fs = s3fs.S3FileSystem(
    client_kwargs={"endpoint_url": f"https://{s3_endpoint_host}"}
)

# PARAMETERS --------------------------------------

os.environ['UVICORN_TIMEOUT_KEEP_ALIVE'] = "0"


def _env_flag(name, default):
    """Read a boolean setting from the environment.

    ``os.getenv`` returns a *string* whenever the variable is set, and any
    non-empty string (including "False" or "0") is truthy. This helper
    parses the text instead so that such values really disable the flag.

    Args:
        name: Name of the environment variable.
        default: Value returned when the variable is not set.

    Returns:
        ``default`` if the variable is unset; otherwise ``True`` when the
        value (case-insensitive, stripped) is one of "1", "true", "yes",
        "y" or "on", and ``False`` for anything else.
    """
    raw_value = os.getenv(name)
    if raw_value is None:
        return default
    return raw_value.strip().lower() in {"1", "true", "yes", "y", "on"}


# NOTE(review): `model` and `quantization` are not referenced by any visible
# cell (LLM_MODEL / QUANTIZATION are used instead); kept in case other cells
# rely on them.
model = os.getenv("LLM_MODEL_NAME")
CHROMA_DB_LOCAL_DIRECTORY = "./data/chroma_db"
CLI_MESSAGE_SEPARATOR = f"{80*'-'} \n"
quantization = True
DEFAULT_MAX_NEW_TOKENS = 10
DEFAULT_MODEL_TEMPERATURE = 1
# Embedding model name: environment override, else the project default.
embedding = os.getenv("EMB_MODEL_NAME", EMB_MODEL_NAME)

LLM_MODEL = os.getenv("LLM_MODEL_NAME", "mistralai/Mistral-7B-Instruct-v0.2")
QUANTIZATION = _env_flag("QUANTIZATION", True)
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", DEFAULT_MAX_NEW_TOKENS))
# Parsed as float: int() would raise on values such as "0.7".
MODEL_TEMPERATURE = float(os.getenv("MODEL_TEMPERATURE", DEFAULT_MODEL_TEMPERATURE))
RETURN_FULL_TEXT = _env_flag("RETURN_FULL_TEXT", True)
DO_SAMPLE = _env_flag("DO_SAMPLE", True)
# Identifier of the run whose vector database is loaded further down.
DATABASE_RUN_ID = "32d4150a14fa40d49b9512e1f3ff9e8c"


In [None]:
# Generation settings forwarded to the model builder.
# NOTE(review): max_new_tokens is pinned to 100000 here instead of using
# MAX_NEW_TOKENS from the parameters cell - confirm this is intentional.
llm_generation_args = {
    "max_new_tokens": 100000,
    "return_full_text": RETURN_FULL_TEXT,
    "do_sample": DO_SAMPLE,
    "temperature": MODEL_TEMPERATURE,
}

# Build the LLM and its tokenizer; the Hugging Face token is read from HF_TOKEN.
llm, tokenizer = build_llm_model(
    model_name=LLM_MODEL,
    quantization_config=QUANTIZATION,
    config=True,
    token=os.getenv("HF_TOKEN"),
    streaming=False,
    generation_args=llm_generation_args,
)

In [None]:
# Smoke test: query the LLM directly (no retrieval step) and print what it returns.
llm_smoke_test_answer = llm.invoke("quels sont les chiffres du chômage")
print(llm_smoke_test_answer)

In [None]:
# Load the vector database through the S3 filesystem handle.
# The run id (DATABASE_RUN_ID) is hard coded for now.
db = load_vector_database(filesystem=fs, database_run_id=DATABASE_RUN_ID)

In [None]:
# Retriever settings: "similarity" search type with k=30 passed as search kwargs.
retriever_params = {
    "search_type": "similarity",
    "search_kwargs": {"k": 30},
}

# Build the retriever (and get the vector store back) on top of the loaded database.
retriever, vectorstore = load_retriever(
    emb_model_name=embedding,
    persist_directory=CHROMA_DB_LOCAL_DIRECTORY,
    vectorstore=db,
    retriever_params=retriever_params,
)

In [None]:
# Sanity check of the retriever on its own: the value returned for this
# sample query is displayed as the cell output.
retriever.invoke("je veux les chiffres du chomage")

In [None]:
# Pull the stored documents out of the vector database and report how many there are.
db_docs = db.get()["documents"]
n_db_docs = len(db_docs)
ndocs = f"Ma base de connaissance du site Insee comporte {n_db_docs} documents"
ndocs

In [None]:
# Render the chatbot template through the tokenizer's chat template as plain
# text (tokenize=False), with the generation prompt appended.
RAG_PROMPT_TEMPLATE = tokenizer.apply_chat_template(
    CHATBOT_TEMPLATE,
    tokenize=False,
    add_generation_prompt=True,
)

# Prompt with two input variables: "context" and "question".
prompt = PromptTemplate(
    template=RAG_PROMPT_TEMPLATE,
    input_variables=["context", "question"],
)


In [None]:
# Validator chain; the same LLM and tokenizer as the main chain act as evaluator.
validator = build_chain_validator(evaluator_llm=llm, tokenizer=tokenizer)

In [None]:
from src.chain_building.build_chain import build_chain

In [None]:
# Assemble the full chain from the retriever, prompt and LLM, asking for the
# "BM25" reranker.
chain_components = {
    "retriever": retriever,
    "prompt": prompt,
    "llm": llm,
    "reranker": "BM25",
}
chain = build_chain(**chain_components)

Il y a actuellement un problème sur la chaîne, au niveau de la génération après le retrieval. Je vais tester sur un exemple plus minimal, à partir de :

* https://python.langchain.com/docs/integrations/llms/huggingface_pipelines/#gpu-inference
* https://python.langchain.com/v0.1/docs/expression_language/interface/#async-stream

In [None]:
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

# Minimal retrieval chain used to isolate the generation problem:
# retriever output and the raw question feed the prompt, then the LLM,
# then the output is parsed to a string.
retrieval_chain = (
    {
        # Named "Docs" so retriever events can be selected by name when streaming.
        "context": retriever.with_config(run_name="Docs"),
        "question": RunnablePassthrough(),
    }
    | prompt
    # Named "my_llm" so that an `include_names=["Docs", "my_llm"]` event filter
    # actually matches the LLM step; without a run name its events are dropped.
    | llm.with_config(run_name="my_llm")
    | StrOutputParser()
)

In [None]:
async for event in retrieval_chain.astream_events(
    "je veux les chiffres du chômage", version="v1", include_names=["Docs", "my_llm"]
):
    kind = event["event"]
    if kind == "on_chat_model_stream":
        print(event["data"]["chunk"].content, end="|")
    elif kind in {"on_chat_model_start"}:
        print()
        print("Streaming LLM:")
    elif kind in {"on_chat_model_end"}:
        print()
        print("Done streaming LLM.")
    elif kind == "on_retriever_end":
        print("--")
        print("Retrieved the following documents:")
        print(event["data"]["output"]["documents"])
    elif kind == "on_tool_end":
        print(f"Ended tool: {event['name']}")
    else:
        pass

In [None]:
        async for chunk in chain.astream(
            "je veux les chiffres du chomage",
            config=RunnableConfig(
                callbacks=[cl.AsyncLangchainCallbackHandler(stream_final_answer=True)]
            ),
        ):
            if "answer" in chunk:
                await answer_msg.stream_token(chunk["answer"])
                generated_answer = chunk["answer"]

            if "context" in chunk:
                docs = chunk["context"]
                for doc in docs:
                    sources.append(doc.metadata.get("url"))
                    titles.append(doc.metadata.get("Header 1"))


In [None]:
# Run the full chain once, without streaming, on a sample question; the
# returned value is displayed as the cell output.
chain.invoke("Je veux les chiffres du chômage")