In [40]:
from psycopg import Cursor
from openai import OpenAI
import os
import psycopg

In [41]:
# Transcript file that is split into lines and embedded into the database.
conversation_file_path = "../data/conversation.txt"
# The API key is read from the OPENAI_KEY environment variable (None when unset).
openai_client = OpenAI(api_key=os.getenv("OPENAI_KEY"))
# The connection string can be overridden through the DB_CONNECTION_STR environment
# variable so real credentials do not have to live in the notebook. The fallback
# is the original local-development DSN, kept so existing runs behave the same.
db_connection_str = os.getenv(
    "DB_CONNECTION_STR",
    "dbname=chatbot_rag user=postgres password=1234 host=localhost port=5432",
)

In [42]:
def create_conversation_list(file_path: str) -> list[str]:
    """Load a conversation transcript as a list of non-empty text lines.

    The file is read as UTF-8 and split on newlines. Lines that start with
    ``<`` are dropped (presumably markup/tag lines; verify against the data
    file), and a leading run of exactly five spaces is removed from the
    remaining lines. Lines that are empty or whitespace-only after that step
    are skipped: they carry no text to embed, and the embeddings endpoint
    rejects an empty input string.

    Args:
        file_path: Path of the transcript text file.

    Returns:
        The retained lines, in file order.

    Raises:
        OSError: If the file cannot be opened.
    """
    with open(file_path, encoding="utf-8") as file:
        raw_lines = file.read().split("\n")

    utterances = []
    for raw_line in raw_lines:
        # The "<" test is applied to the raw line, before the indent is removed,
        # so an indented line starting with "<" is still kept.
        if raw_line.startswith("<"):
            continue
        utterance = raw_line.removeprefix("     ")
        if not utterance.strip():
            # Blank separator lines and the trailing "" produced by a final
            # newline would otherwise be embedded and stored as empty rows.
            continue
        utterances.append(utterance)
    return utterances
    
def calculate_embedding(corpus: str, client: OpenAI) -> list[float]:
    """Return the embedding vector of ``corpus``.

    Sends a single request to ``client.embeddings.create`` with the
    ``text-embedding-ada-002`` model and float encoding, and returns the
    ``embedding`` attribute of the first item in the response's ``data``.

    Args:
        corpus: Text to embed.
        client: Client object exposing ``embeddings.create``.

    Returns:
        The embedding of the first (and, for a single input, only) item.
    """
    response = client.embeddings.create(
        input=corpus,
        model="text-embedding-ada-002",
        encoding_format="float",
    )
    first_item = response.data[0]
    return first_item.embedding

def save_embedding(corpus: str, embedding: list[float], cursor: Cursor) -> None:
    """Insert one ``(corpus, embedding)`` row into the ``embeddings`` table.

    The statement is executed on the supplied cursor with bound parameters;
    nothing is committed here, so the caller owns the transaction.

    Args:
        corpus: Text that was embedded.
        embedding: Vector to store alongside the text.
        cursor: Open cursor on which the INSERT is executed.
    """
    insert_statement = """
        INSERT INTO embeddings (corpus, embedding) VALUES (%s, %s)
    """
    row_values = (corpus, embedding)
    cursor.execute(insert_statement, row_values)

def retrieve_similar_corpus(input_corpus: str, client: OpenAI, db_connection_str: str) -> tuple[int, str, list[float]]:
    """Fetch the stored row whose embedding ranks first for ``input_corpus``.

    The input text is embedded with ``calculate_embedding``, then a new
    database connection is opened and the ``embeddings`` table is ordered by
    the ``<=>`` operator (cosine distance in pgvector) against that vector,
    keeping only the first row.

    Args:
        input_corpus: Text to look up.
        client: OpenAI client used to embed ``input_corpus``.
        db_connection_str: Connection string passed to ``psycopg.connect``.

    Returns:
        The ``(id, corpus, embedding)`` row returned by ``fetchone()``. Per
        psycopg, ``fetchone()`` yields ``None`` when the query produces no
        row, so callers should expect ``None`` for an empty table.
        NOTE(review): the recorded notebook output shows the embedding column
        coming back as a text literal rather than a list of floats; confirm
        whether a pgvector adapter is meant to be registered.
    """
    query_vector = calculate_embedding(corpus=input_corpus, client=client)
    # The whitespace inside the literal is kept exactly as it was written.
    nearest_row_sql = """
                SELECT id, corpus, embedding
                FROM embeddings
                ORDER BY embedding <=> %s::vector
                LIMIT 1;
            """
    with psycopg.connect(db_connection_str) as connection, connection.cursor() as cursor:
        cursor.execute(nearest_row_sql, [query_vector])
        return cursor.fetchone()
        
def generate_response(input_corpus: str, client: OpenAI = openai_client, db_connection_str: str=db_connection_str):
    """Answer ``input_corpus`` by rephrasing the closest stored corpus line.

    The best-matching row is looked up with ``retrieve_similar_corpus`` and
    its text is sent to the ``gpt-4o`` chat model as the user message, with a
    system prompt asking for a coherent reformulation.

    NOTE(review): only the retrieved text is sent to the model; the user's
    original question is not part of the prompt. Confirm this is intended.

    Args:
        input_corpus: The user's message.
        client: OpenAI client; defaults to the notebook-level ``openai_client``
            as bound when this function was defined.
        db_connection_str: Connection string; defaults to the notebook-level
            ``db_connection_str`` as bound when this function was defined.

    Returns:
        The content of the first completion choice.

    Raises:
        LookupError: If the embeddings table yields no row to rephrase.
    """
    match = retrieve_similar_corpus(input_corpus=input_corpus, client=client, db_connection_str=db_connection_str)
    if match is None:
        # Without this guard an empty table surfaces as an opaque
        # "'NoneType' object is not subscriptable" TypeError.
        raise LookupError(
            "No stored corpus found in the embeddings table; populate it before calling generate_response()."
        )
    similar_text = match[1]  # row layout is (id, corpus, embedding)

    completion = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "Vous êtes un assistant chatbot serviable travaillant dans le service d'accueil d'une université. Vous devez reformuler des réponses extraites d'une base de données de manière cohérente et compréhensible pour l'utilisateur."},
            {
                "role": "user",
                "content": similar_text
            }
        ]
    )
    return completion.choices[0].message.content
    

In [11]:
import psycopg
import numpy as np

# Rebuild the embeddings table from scratch out of the conversation transcript.
with psycopg.connect(db_connection_str) as conn:
    with conn.cursor() as cur:
        # IF EXISTS lets the cell run against a database where the table has
        # never been created; a plain DROP TABLE raises in that case.
        cur.execute("DROP TABLE IF EXISTS embeddings")
        cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")

        # The vector dimension must match the length of the vectors returned
        # by calculate_embedding.
        cur.execute("""
            CREATE TABLE IF NOT EXISTS embeddings (
                    id serial PRIMARY KEY,
                    corpus text,
                    embedding vector(1536)
            );
        """)

        # One embeddings request and one INSERT per transcript line.
        corpus_list = create_conversation_list(file_path=conversation_file_path)
        for corpus in corpus_list:
            embedding = calculate_embedding(corpus=corpus, client=openai_client)
            save_embedding(corpus=corpus, embedding=embedding, cursor=cur)

        conn.commit()

In [36]:
# Example question used to exercise the retrieval step.
user_message = "Où se trouve le site?"
# The returned row is this cell's displayed output; it includes the embedding
# column, so the rendered output is very long.
retrieve_similar_corpus(input_corpus=user_message, client=openai_client, db_connection_str=db_connection_str)

(3,
 'h: oui # au site à Lorient à quel numéro',
 '[-0.0059153903,-0.009765194,-0.007924703,-0.007222933,0.000943417,0.032175485,-0.020655867,-0.009209074,-0.01182416,0.00032316172,0.039722823,0.025343161,-0.0062530343,-0.0057962216,-0.017358873,-0.004528402,0.021185504,0.0051407954,0.030533608,-0.01310191,-0.019172883,0.004012005,-0.014339939,-0.019464182,0.008666196,-0.01718674,0.0059749745,-0.017014608,-0.0020937237,0.007249415,0.022959791,0.0031298273,-0.0049057687,-0.023608597,-0.026839387,0.024469258,0.010586132,-0.0024826764,0.008182901,-0.009897603,0.0312751,-0.004141104,0.0026134306,-0.028309131,-0.009983669,0.028203202,0.007865119,-0.0101624215,-0.0031496887,0.0194377,0.011850642,0.0026829457,0.0077062272,-0.0032738226,-0.0050911414,0.0074943723,0.0011072736,0.018510835,-0.0015657413,0.00078287063,0.014485589,-0.027461711,-0.031354547,0.011201835,-0.028520986,-0.024297126,0.010890674,0.0076731252,-0.02759412,-0.01620691,0.014392902,0.029871562,-0.0014556759,-0.038848918,0.032

In [43]:
# Generate a reformulated answer for user_message, using the function's default
# client and connection string.
generate_response(input_corpus=user_message)

"Pour contacter le site de Lorient, vous pouvez appeler au numéro suivant : [insérez le bon numéro ici, car je n'ai pas accès à la base de données pour vous le fournir]."