In [1]:
import logging
import os
import s3fs

import chainlit as cl
import chainlit.data as cl_data
from langchain.schema.runnable.config import RunnableConfig
from langchain_core.prompts import PromptTemplate

from src.chain_building.build_chain import build_chain
from src.chain_building.build_chain_validator import build_chain_validator
from src.config import CHATBOT_TEMPLATE, EMB_MODEL_NAME
from src.db_building import (
    load_retriever,
    load_vector_database
)
from src.model_building import build_llm_model
from src.results_logging.log_conversations import log_feedback_to_s3, log_qa_to_s3
from src.utils.formatting_utilities import add_sources_to_messages, str_to_bool

# Logging configuration
logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s %(message)s",
    datefmt="%Y-%m-%d %I:%M:%S %p",
    level=logging.DEBUG,
)

# Remote file configuration
os.environ['MLFLOW_TRACKING_URI'] = "https://projet-llm-insee-open-data-mlflow.user.lab.sspcloud.fr/"
fs = s3fs.S3FileSystem(client_kwargs={"endpoint_url": f"""https://{os.environ["AWS_S3_ENDPOINT"]}"""})

# PARAMETERS --------------------------------------

os.environ['UVICORN_TIMEOUT_KEEP_ALIVE'] = "0"

model = os.getenv("LLM_MODEL_NAME")
CHROMA_DB_LOCAL_DIRECTORY = "./data/chroma_db"
CLI_MESSAGE_SEPARATOR = f"{80*'-'} \n"
quantization = True
DEFAULT_MAX_NEW_TOKENS = 10
DEFAULT_MODEL_TEMPERATURE = 1
embedding = os.getenv("EMB_MODEL_NAME", EMB_MODEL_NAME)

LLM_MODEL = os.getenv("LLM_MODEL_NAME", "mistralai/Mistral-7B-Instruct-v0.2")
QUANTIZATION = os.getenv("QUANTIZATION", True)
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", DEFAULT_MAX_NEW_TOKENS))
MODEL_TEMPERATURE = int(os.getenv("MODEL_TEMPERATURE", DEFAULT_MODEL_TEMPERATURE))
RETURN_FULL_TEXT = os.getenv("RETURN_FULL_TEXT", True)
DO_SAMPLE = os.getenv("DO_SAMPLE", True)
DATABASE_RUN_ID = "32d4150a14fa40d49b9512e1f3ff9e8c"


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
llm, tokenizer = build_llm_model(
            model_name=LLM_MODEL,
            quantization_config=QUANTIZATION,
            config=True,
            token=os.getenv("HF_TOKEN"),
            streaming=False,
            generation_args={
                "max_new_tokens": 100000,
                "return_full_text": RETURN_FULL_TEXT,
                "do_sample": DO_SAMPLE,
                "temperature": MODEL_TEMPERATURE
            },
    )

2024-10-04 08:28:24 - Found credentials in environment variables.
2024-10-04 08:28:24 - Fetching model mistralai/Mistral-7B-Instruct-v0.2 from S3.


`low_cpu_mem_usage` was None, now set to True since model is quantized.
Downloading shards: 100%|██████████| 3/3 [00:00<00:00,  9.03it/s]
Loading checkpoint shards: 100%|██████████| 3/3 [00:07<00:00,  2.34s/it]


In [3]:
print(
    llm.invoke("quels sont les chiffres du chômage")
)

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


quels sont les chiffres du chômage en Suisse

There are different ways to measure unemployment in Switzerland. Here are some key figures based on the official statistics from the Swiss Federal Statistical Office:

- Unemployment rate: The unemployment rate was 2.5% in the first quarter of 2023. This means that out of the total labor force, 2.5% were unemployed and looking for work.
- Number of unemployed people: In the same period, there were approximately 93,000 unemployed persons in Switzerland.
- Youth unemployment rate: The unemployment rate among young people (15-24 years old) was 5.2% in the same period.
- Long-term unemployment rate: The rate of long-term unemployment (people unemployed for more than one year) was 18.6% in the first quarter of 2023.

It is important to note that these figures are based on a relatively tight definition of unemployment, which requires that the unemployed person is actively seeking work and is available to start a job. Other definitions of unemploy

In [4]:
db = load_vector_database(
            filesystem=fs,
            database_run_id=DATABASE_RUN_ID
            # hard coded pour le moment
    )

Downloading artifacts: 100%|██████████| 6/6 [01:57<00:00, 19.60s/it]

2024-10-04 08:33:20 - Load pretrained SentenceTransformer: OrdalieTech/Solon-embeddings-large-0.1



  emb_model = HuggingFaceEmbeddings(


2024-10-04 08:34:15 - Anonymized telemetry enabled. See                     https://docs.trychroma.com/telemetry for more information.


  db = Chroma(


2024-10-04 08:34:21 - ⚠️ It looks like you upgraded from a version below 0.6 and could benefit from vacuuming your database. Run chromadb utils vacuum --help for more information.
2024-10-04 08:34:21 - The database (collection insee_data) has been reloaded from directory /tmp/tmpimuq2rf5/chroma


In [5]:
retriever, vectorstore = load_retriever(
                emb_model_name=embedding,
                persist_directory=CHROMA_DB_LOCAL_DIRECTORY,
                vectorstore=db,
                retriever_params={
                    "search_type": "similarity",
                    "search_kwargs": {"k": 30}
                },
            )

2024-10-04 08:34:21 - vectorstore being provided, skipping the reloading


In [6]:
retriever.invoke("je veux les chiffres du chomage")

Batches: 100%|██████████| 1/1 [00:00<00:00, 15.62it/s]


[Document(metadata={'Header 1': 'Chômage et halo autour du chômage en 2019', 'Header 2': 'Résumé :', 'Header 3': "Téléchargement des tableaux à l'unité", 'Header 4': 'Chômage', 'categorie': 'Chiffres détaillés', 'collection': '', 'dateDiffusion': '2020-06-23 12:00', 'libelleAffichageGeo': 'France', 'theme': 'Emploi – Population active', 'titre': 'Chômage et halo autour du chômage en 2019', 'url': 'https://www.insee.fr/fr/statistiques/4498582'}, page_content='#### Chômage'),
 Document(metadata={'Header 1': 'Chômage', 'categorie': 'Publications pour expert', 'collection': 'Note de conjoncture', 'dateDiffusion': '2017-03-16 17:00', 'libelleAffichageGeo': 'France', 'theme': 'Chômage', 'titre': 'Chômage', 'url': 'https://www.insee.fr/fr/statistiques/2662536'}, page_content='# Chômage'),
 Document(metadata={'Header 1': 'Enquête emploi en continu en Guyane - Le chômage est stable en 2017', 'Header 2': 'Nombre de chômeurs et taux de chômage', 'categorie': 'Publications grand public', 'collecti

In [7]:
db_docs = db.get()["documents"]
ndocs = f"Ma base de connaissance du site Insee comporte {len(db_docs)} documents"
ndocs

'Ma base de connaissance du site Insee comporte 287086 documents'

In [8]:
RAG_PROMPT_TEMPLATE = tokenizer.apply_chat_template(
    CHATBOT_TEMPLATE, tokenize=False, add_generation_prompt=True
)
prompt = PromptTemplate(
    input_variables=["context", "question"], template=RAG_PROMPT_TEMPLATE
)


In [9]:
validator = build_chain_validator(
        evaluator_llm=llm, tokenizer=tokenizer
    )

In [10]:
from src.chain_building.build_chain import build_chain

In [11]:
chain = build_chain(
        retriever=retriever,
        prompt=prompt,
        llm=llm,
        reranker="BM25",
    )

Il y a un problème en ce moment sur la chaine, au niveau de la génération après retrieval. Je vais tester sur un exemple plus minime, à partir de 

* https://python.langchain.com/docs/integrations/llms/huggingface_pipelines/#gpu-inference
* https://python.langchain.com/v0.1/docs/expression_language/interface/#async-stream

In [12]:
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

retrieval_chain = (
    {
        "context": retriever.with_config(run_name="Docs"),
        "question": RunnablePassthrough(),
    }
    | prompt
    | llm
    | StrOutputParser()
)

In [16]:
docs = retriever.invoke("je veux les chiffres du chômage")

Batches: 100%|██████████| 1/1 [00:00<00:00, 77.45it/s]


In [21]:
await retrieval_chain.ainvoke("je veux les chiffres du chômage")

Batches: 100%|██████████| 1/1 [00:00<00:00, 92.99it/s]


'<s> [INST] Tu es un assistant spécialisé dans la statistique publique.\n    Tu réponds à des questions concernant les données de l\'Insee, l\'institut national statistique Français.\n    Réponds en FRANCAIS UNIQUEMENT. [/INST] \nEn utilisant UNIQUEMENT les informations présentes dans le contexte, réponds de manière argumentée à la question posée.\nLa réponse doit être développée et citer ses sources.\n\nSi tu ne peux pas induire ta réponse du contexte, ne réponds pas.\n</s> [INST] Voici le contexte sur lequel tu dois baser ta réponse :\nContexte:\n[Document(metadata={\'Header 1\': \'Enquête emploi en continu en Guyane - Le chômage est stable en 2017\', \'Header 2\': \'Nombre de chômeurs et taux de chômage\', \'categorie\': \'Publications grand public\', \'collection\': \'Insee Analyses\', \'dateDiffusion\': \'2018-04-10 15:00\', \'libelleAffichageGeo\': \'Guyane\', \'theme\': \'Emploi – Population active\', \'titre\': \'Enquête emploi en continu en Guyane - Le chômage est stable en 20

In [22]:
retrieval_chain.batch(["je veux les chiffres du chômage", "quelle est la définition de l'inflation"])

Batches: 100%|██████████| 1/1 [00:00<00:00, 39.18it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 43.19it/s]


['<s> [INST] Tu es un assistant spécialisé dans la statistique publique.\n    Tu réponds à des questions concernant les données de l\'Insee, l\'institut national statistique Français.\n    Réponds en FRANCAIS UNIQUEMENT. [/INST] \nEn utilisant UNIQUEMENT les informations présentes dans le contexte, réponds de manière argumentée à la question posée.\nLa réponse doit être développée et citer ses sources.\n\nSi tu ne peux pas induire ta réponse du contexte, ne réponds pas.\n</s> [INST] Voici le contexte sur lequel tu dois baser ta réponse :\nContexte:\n[Document(metadata={\'Header 1\': \'Enquête emploi en continu en Guyane - Le chômage est stable en 2017\', \'Header 2\': \'Nombre de chômeurs et taux de chômage\', \'categorie\': \'Publications grand public\', \'collection\': \'Insee Analyses\', \'dateDiffusion\': \'2018-04-10 15:00\', \'libelleAffichageGeo\': \'Guyane\', \'theme\': \'Emploi – Population active\', \'titre\': \'Enquête emploi en continu en Guyane - Le chômage est stable en 2

In [28]:
answer = retrieval_chain.invoke("je veux les chiffres du chômage")

Batches: 100%|██████████| 1/1 [00:00<00:00, 89.95it/s]


In [32]:
answer

str

In [19]:
def format_docs(docs: list):
    return "\n\n".join(
        [
            f"""
            Doc {i + 1}:\nTitle: {doc.metadata.get("Header 1")}\n
            Source: {doc.metadata.get("url")}\n
            Content:\n{doc.page_content}
            """
            for i, doc in enumerate(docs)
        ]
    )

In [36]:
# stream = retrieval_chain.invoke("donne chiffres chomage")

llm.invoke("donne chiffres chomage")

HuggingFacePipeline(pipeline=<transformers.pipelines.text_generation.TextGenerationPipeline object at 0x7fe1b4e84530>)

In [None]:
async for event in retrieval_chain.astream_events(
    "je veux les chiffres du chômage", version="v1", include_names=["Docs", "my_llm"]
):
    kind = event["event"]
    if kind == "on_chat_model_stream":
        print(event["data"]["chunk"].content, end="|")
    elif kind in {"on_chat_model_start"}:
        print()
        print("Streaming LLM:")
    elif kind in {"on_chat_model_end"}:
        print()
        print("Done streaming LLM.")
    elif kind == "on_retriever_end":
        print("--")
        print("Retrieved the following documents:")
        print(event["data"]["output"]["documents"])
    elif kind == "on_tool_end":
        print(f"Ended tool: {event['name']}")
    else:
        pass

In [None]:
        async for chunk in chain.astream(
            "je veux les chiffres du chomage",
            config=RunnableConfig(
                callbacks=[cl.AsyncLangchainCallbackHandler(stream_final_answer=True)]
            ),
        ):
            if "answer" in chunk:
                await answer_msg.stream_token(chunk["answer"])
                generated_answer = chunk["answer"]

            if "context" in chunk:
                docs = chunk["context"]
                for doc in docs:
                    sources.append(doc.metadata.get("url"))
                    titles.append(doc.metadata.get("Header 1"))


In [None]:
chain.invoke("Je veux les chiffres du chômage")