In [None]:
from sqlalchemy import select
import os
import json
from tca.database.models import (
    DocumentChunk,
    OpenAITextEmbedding3LargeChunkEmbedding,
    OpenAITextEmbedding3LargeThemeEmbedding,
    Theme,
)
from tca.database.session_manager import PostgresSessionManager
from sqlalchemy import func
import pandas as pd
from tca.text.llm_clients import OpenAIAPIClient
from typing import List, Dict, Any, Set, Tuple

# Default number of best-matching chunks kept per document for each theme.
NB_CHUNKS_TO_RETRIEVE = 6

# Database session used by the queries in this cell.
postgres_session_manager = PostgresSessionManager()
session = postgres_session_manager.session

# LLM client; the model name and API key are read from the environment.
llm_client = OpenAIAPIClient(
    model_name=os.environ["OPENAI_LLM_MODEL"],
    api_key=os.environ["OPENAI_API_KEY"],
)
# Alternative local backend, kept for reference:
# llm_client = OllamaAPIClient(
#     endpoint=Url("http://localhost:11434/api/generate"),
#     model_name="llama3.1:8b-instruct-fp16",
#     timeout=60,
# )

# Load every theme together with its stored embedding vector.
query = select(Theme, OpenAITextEmbedding3LargeThemeEmbedding.embeddings).join(
    Theme,
    OpenAITextEmbedding3LargeThemeEmbedding.theme_id == Theme.id,
)
# Each result row is (Theme, embeddings); materialise the rows as plain tuples.
theme_with_embeddings = [tuple(row) for row in session.execute(query).all()]


def _parse_theme_names(response: str) -> List[Any]:
    """Decode the JSON list of theme names contained in a raw LLM response.

    The prompt asks for bare JSON, but a Markdown code fence around the
    payload (with an optional ``json`` language tag) is tolerated.

    Raises:
        ValueError: if the payload is not valid JSON (``json.JSONDecodeError``
            is a ``ValueError`` subclass) or does not decode to a list.
    """
    payload = response.strip()
    if payload.startswith("```"):
        # Remove the fence markers, then an optional language tag.
        payload = payload.strip("`").strip()
        if payload[:4].lower() == "json":
            payload = payload[4:]
    parsed = json.loads(payload)
    if not isinstance(parsed, list):
        raise ValueError(
            f"expected a JSON list of themes, got {type(parsed).__name__}"
        )
    return parsed


def find_themes_in_documents(
    theme_with_embeddings: list[Tuple[Theme, Any]],
    session: Any,
    llm_client: Any,
    nb_chunks_to_retrieve: int = NB_CHUNKS_TO_RETRIEVE,
    max_cosine_distance: float = 0.75,
) -> pd.DataFrame:
    """Find which themes are explicitly addressed in each document.

    For every theme, the chunks whose embedding is within
    ``max_cosine_distance`` of the theme embedding are ranked per document and
    the ``nb_chunks_to_retrieve`` closest ones are kept. The chunk texts
    collected for a document (across all themes) are then sent to the LLM in a
    single prompt together with the full theme list, and the LLM answers with
    the names of the themes it considers addressed.

    Args:
        theme_with_embeddings: ``(theme, embedding)`` pairs. The last entry of
            ``theme.themes`` is used as the theme name shown to the LLM.
        session: object exposing ``execute(query)`` whose result has ``.all()``.
        llm_client: object exposing ``generate_text(prompt) -> str``.
        nb_chunks_to_retrieve: maximum number of chunks kept per document and
            per theme.
        max_cosine_distance: chunks at or beyond this cosine distance from the
            theme embedding are ignored.

    Returns:
        A DataFrame with one row per (document, detected theme) and the
        columns ``Document``, ``Thème n1`` and ``Thème n2``. The columns are
        present even when nothing was detected.

    Notes:
        A document whose LLM response cannot be parsed, and any returned theme
        name that is not part of the provided themes, is reported on stdout
        and skipped instead of aborting the whole run.
    """
    theme_list: List[str] = []
    doc_sentences: Dict[str, Set[str]] = {}
    theme2_to_theme_info: Dict[str, Theme] = {}

    for theme_info, query_embeddings in theme_with_embeddings:
        theme: str = theme_info.themes[-1]
        theme2_to_theme_info[theme] = theme_info
        theme_list.append(f"{theme} : {theme_info.description}")

        # Build the distance expression once and reuse it for the projection,
        # the per-document ranking and the threshold filter.
        cos_distance = (
            OpenAITextEmbedding3LargeChunkEmbedding.embeddings.cosine_distance(
                query_embeddings
            )
        )

        chunk_select_query = (
            select(
                DocumentChunk,
                cos_distance.label("cos_distance"),
                # Rank of the chunk within its document, closest first.
                func.row_number()
                .over(
                    partition_by=DocumentChunk.document_id,
                    order_by=cos_distance,
                )
                .label("row_number"),
            )
            .join(
                OpenAITextEmbedding3LargeChunkEmbedding,
                OpenAITextEmbedding3LargeChunkEmbedding.chunk_id == DocumentChunk.id,
            )
            .filter(cos_distance < max_cosine_distance)
            .subquery()
        )

        # Keep the best-ranked chunks of each document and aggregate them.
        per_document_query = (
            select(
                chunk_select_query.c.document_id,
                func.array_agg(chunk_select_query.c.chunk_text).label("chunk_texts"),
                func.array_agg(chunk_select_query.c.cos_distance).label(
                    "cos_distances"
                ),
            )
            .filter(chunk_select_query.c.row_number <= nb_chunks_to_retrieve)
            .group_by(chunk_select_query.c.document_id)
            .order_by(func.min(chunk_select_query.c.cos_distance))
        )

        for document_id, chunk_texts, _cos_distances in session.execute(
            per_document_query
        ).all():
            doc_sentences.setdefault(document_id, set()).update(chunk_texts)

    themes_str: str = "\n- ".join(theme_list)

    records: List[Dict[str, Any]] = []
    for document_id, sentences in doc_sentences.items():
        print(f"document_id {document_id}")
        # Sort so that the prompt is identical from one run to the next
        # (set iteration order is not stable for strings across processes).
        sentences_str: str = "\n- ".join(sorted(sentences))
        prompt: str = f"""
Je souhaite déterminer si les thèmes suivants sont abordés explicitement dans un accord d'entreprise.
Les thèmes sont les suivants, avec la forme "theme : description du thème" :
- {themes_str}

Ces phrases proviennent de l'accord d'entreprise :

- {sentences_str}

Garde chaque thème qui est abordé explicitement en positif ou en négatif dans au moins une des phrases et ignore les autres thèmes.
Retourne le résultat sous forme de JSON (sans balises de bloc de code) avec le format suivant :

["theme1", "theme2", ...]
        """
        response: str = llm_client.generate_text(prompt)
        try:
            detected_themes = _parse_theme_names(response)
        except ValueError as error:
            # One unusable answer must not discard the documents already processed.
            print(f"document_id {document_id}: unusable LLM response ({error})")
            continue

        for detected_theme in detected_themes:
            matched_info = (
                theme2_to_theme_info.get(detected_theme)
                if isinstance(detected_theme, str)
                else None
            )
            if matched_info is None:
                # The model returned a name that is not in the catalogue.
                print(
                    f"document_id {document_id}: unknown theme {detected_theme!r} ignored"
                )
                continue
            records.append(
                {
                    "Document": document_id,
                    "Thème n1": matched_info.themes[0],
                    "Thème n2": matched_info.themes[1],
                }
            )

    # Explicit columns keep the schema stable when no theme was detected.
    return pd.DataFrame(records, columns=["Document", "Thème n1", "Thème n2"])


# Run the detection for every loaded theme. Despite its name, the result is
# the DataFrame returned by find_themes_in_documents, not a dict.
found_themes_per_doc = find_themes_in_documents(
    theme_with_embeddings,
    session,
    llm_client,
    nb_chunks_to_retrieve=NB_CHUNKS_TO_RETRIEVE,
)

In [None]:
import json

# `found_themes_per_doc` is the DataFrame returned by find_themes_in_documents
# (columns "Document", "Thème n1", "Thème n2"). Indexing it with a document id
# would be a column lookup and raise KeyError, so the rows of the document are
# selected explicitly instead.
inspected_document_id = (
    "1d6937456045163545c4d5d7cf65178c9e7b1ea96922a36727abd4afe3733b7b"
)
if "Document" in found_themes_per_doc.columns:
    inspected_document_rows = found_themes_per_doc[
        found_themes_per_doc["Document"] == inspected_document_id
    ].to_dict(orient="records")
else:
    # A DataFrame built from no records has no columns at all.
    inspected_document_rows = []

print(
    json.dumps(
        inspected_document_rows,
        indent=4,
        ensure_ascii=False,
    )
)
# Hand-written theme -> justification mapping; nothing in this notebook
# produces or reads it. Presumably a sample answer from an earlier prompt
# variant - TODO confirm or remove.
result = {
    "salaires": "Amélioration des conditions salariales mentionnée.",
    "commissions": "Création d'une commission ad hoc mentionnée.",
    "fin de conflit": "Protocole de fin de conflit mentionné.",
    "indemnités": "Conditions d'accès à l'indemnité des paniers repas mentionnées.",
    "droit syndical irp & expression des salariés": "Participation des organisations syndicales CGT et STC mentionnée.",
    "dispositions conditions de travail": "Conditions de travail et plannings d'intervention mentionnés.",
    "déplacement": "Temps de déplacement considéré comme temps de travail effectif mentionné.",
}

In [None]:
from tca.database.models import DocumentChunk
from tca.database.models import (
    DocumentChunk,
)
from tca.database.session_manager import PostgresSessionManager

# Open a session for this cell so it can be run on its own.
postgres_session_manager = PostgresSessionManager()
session = postgres_session_manager.session

doc_id = "1d6937456045163545c4d5d7cf65178c9e7b1ea96922a36727abd4afe3733b7b"

# Any chunk of the document carries its name, so the first match is enough.
document = (
    session.query(DocumentChunk)
    .filter(DocumentChunk.document_id == doc_id)
    .first()
)
print(document.document_name if document else f"No document found with id {doc_id}")