In [57]:
from pymongo import MongoClient
from pymongo.results import InsertOneResult, InsertManyResult
from pymongo.errors import BulkWriteError

from typing import Union

import os
from dotenv import load_dotenv
load_dotenv()


True

In [64]:
# A short sample conversation (a list of {'role': ..., 'content': ...} dicts)
# that is passed to handler.update further down the notebook.
example_history = [
    dict(role=speaker, content=text)
    for speaker, text in (
        ('user', 'What is the capital of France?'),
        ('chatbot', 'The capital of France is Paris.'),
        ('user', 'And what is its population?'),
    )
]

In [16]:
def get_chat_history(db_constants: dict) -> list:
    """
    Get the chat history from the database
    Args:
        db_constants (dict): Connection settings with the keys 'connection_string',
            'histories_db' (database name) and 'user_id' (collection name)
    Returns:
        list: A list of strings containing the chat history in the format 'role: content'
    Raises:
        RuntimeError: If an error occurs while trying to read the chat history
            (including a missing key in db_constants); the original error is chained
        KeyError: If a stored message lacks a 'role' or 'content' field
    """
    client = None
    try:
        client = MongoClient(db_constants['connection_string'])
        collection = client[db_constants['histories_db']][db_constants['user_id']]
        # find() hands back a cursor that is only consumed while iterating; drain it
        # here so failures during iteration are caught and wrapped like the rest.
        messages = list(collection.find())
    except Exception as e:
        # Raising a plain string is a TypeError in Python 3; raise a real exception.
        raise RuntimeError(f'An error occurred while trying to get the chat history: {e}') from e
    finally:
        # A new client is created on every call; close it so its resources are released.
        if client is not None:
            client.close()

    # Format the messages in the desired 'role: content' format
    return [f"{message['role']}: {message['content']}" for message in messages]


def update_chat_history(db_constants: dict, message: Union[dict, list]) -> Union[InsertOneResult, InsertManyResult]:
    """
    Insert one or more messages into the chat history in the database
    Args:
        db_constants (dict): Connection settings with the keys 'connection_string',
            'histories_db' (database name) and 'user_id' (collection name)
        message (dict or list): A dictionary or a list of dictionaries containing the messages
            to be added to the chat history. Shallow copies of the dicts are inserted, so the
            driver does not add an '_id' key to the caller's objects.
    Returns:
        InsertOneResult or InsertManyResult: The result of insert_one (dict input) or
            insert_many (list input), carrying the ID(s) of the inserted document(s)
    Raises:
        ValueError: If the message is not a dictionary or a list
        RuntimeError: If an error occurs while trying to update the chat history;
            the original error is chained
    """
    # Validate before the try so the documented ValueError reaches the caller unwrapped.
    if not isinstance(message, (dict, list)):
        raise ValueError('Message must be a dictionary or a list of dictionaries.')

    client = None
    try:
        client = MongoClient(db_constants['connection_string'])
        collection = client[db_constants['histories_db']][db_constants['user_id']]
        if isinstance(message, dict):
            return collection.insert_one(dict(message))
        # Copy dict items so re-inserting the same objects later does not collide on '_id';
        # anything else is passed through for the driver to validate as before.
        return collection.insert_many(
            [dict(item) if isinstance(item, dict) else item for item in message]
        )
    except Exception as e:
        # Raising a plain string is a TypeError in Python 3; raise a real exception.
        raise RuntimeError(f'An error occurred while trying to update the chat history: {e}') from e
    finally:
        # A new client is created on every call; close it so its resources are released.
        if client is not None:
            client.close()
    

def reset_chat_history(db_constants: dict):
    """
    Delete all the chat history from the collection
    Args:
        db_constants (dict): Connection settings with the keys 'connection_string',
            'histories_db' (database name) and 'user_id' (collection name)
    Raises:
        KeyError: If db_constants is missing one of the keys above
        RuntimeError: If an error occurs while trying to delete the chat history;
            the original error is chained
    """
    client = MongoClient(db_constants['connection_string'])
    try:
        collection = client[db_constants['histories_db']][db_constants['user_id']]
        try:
            # An empty filter matches every document in the collection.
            collection.delete_many({})
        except Exception as e:
            # Raising a plain string is a TypeError in Python 3; raise a real exception.
            raise RuntimeError(f'An error occurred while trying to reset the chat history: {e}') from e
    finally:
        # A new client is created on every call; close it so its resources are released.
        client.close()
    
    

In [59]:


class DBHandler:
    """
    Access one user's documents in the 'embeddings' and 'histories' MongoDB databases.

    The collection named ``user_id`` in each database is exposed as
    ``embeddings_collection`` and ``history_collection``. Call ``close()`` (or use the
    handler in a ``with`` statement) to release the underlying MongoClient.
    """

    def __init__(self, user_id: str, connection_string: Union[str, None] = None):
        """
        Initialize the DBHandler class
        Args:
            user_id (str): The name of the collection containing the chat histories
                (the same name is used for the embeddings collection)
            connection_string (str, optional): The connection string to the MongoDB database.
                When falsy, the MONGODB_CONNECTION_STRING environment variable is used.
        Raises:
            ValueError: If no non-empty string connection string is available
            RuntimeError: If creating the MongoClient fails; the original error is chained
        """
        # constants
        self.embeddings_db = 'embeddings'
        self.histories_db = 'histories'

        if not connection_string:
            connection_string = os.getenv('MONGODB_CONNECTION_STRING')
        # Validate outside the try so callers receive the ValueError itself.
        if not connection_string or not isinstance(connection_string, str):
            raise ValueError('Connection string must be a non-empty string.')
        try:
            # NOTE(review): MongoClient may connect lazily, so an unreachable server
            # might only surface on the first operation rather than here — confirm.
            self.client = MongoClient(connection_string)
        except Exception as e:
            # Raising a plain string is a TypeError in Python 3; raise a real exception.
            raise RuntimeError(f'An error occurred while trying to connect to the database: {e}') from e

        self.embeddings_collection = self.client[self.embeddings_db][user_id]
        self.history_collection = self.client[self.histories_db][user_id]

    def __enter__(self):
        """Return the handler itself so it can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Close the client when leaving a ``with`` block; exceptions are not suppressed."""
        self.close()
        return False

    def close(self):
        """Close the underlying MongoClient; the handler should not be used afterwards."""
        self.client.close()

    def get_history(self) -> list:
        """
        Get the chat history from the database
        Returns:
            list: A list of strings containing the chat history in the format 'role: content'
        Raises:
            RuntimeError: If an error occurs while trying to read the chat history;
                the original error is chained
            KeyError: If a stored message lacks a 'role' or 'content' field
        """
        try:
            # find() hands back a cursor that is only consumed while iterating; drain it
            # here so failures during iteration are caught and wrapped as well.
            messages = list(self.history_collection.find())
        except Exception as e:
            raise RuntimeError(f'An error occurred while trying to get the chat history: {e}') from e

        # Format the messages in the desired 'role: content' format
        return [f"{message['role']}: {message['content']}" for message in messages]

    def update(self, db: str, items: Union[dict, list]) -> Union[InsertOneResult, InsertManyResult]:
        """
        Insert one or more documents into the embeddings or the chat-history collection
        Args:
            db (str): The name of the db to update, either 'embeddings' or 'history'
            items (dict or list): A dictionary or a list of dictionaries to insert. Shallow
                copies of the dicts are inserted, so the driver does not add an '_id' key to
                the caller's objects (re-inserting the same objects then no longer collides).
        Returns:
            InsertOneResult or InsertManyResult: The result of insert_one (dict input) or
                insert_many (list input), carrying the ID(s) of the inserted document(s)
        Raises:
            ValueError: If db is not 'embeddings'/'history', or items is not a dictionary or a list
            RuntimeError: If the insert fails; the original error is chained
        """
        if db == 'embeddings':
            collection = self.embeddings_collection
        elif db == 'history':
            collection = self.history_collection
        else:
            raise ValueError('The db must be either "embeddings" or "history".')

        # Validate before the try so the documented ValueError is not wrapped in RuntimeError.
        if not isinstance(items, (dict, list)):
            raise ValueError('items must be a dictionary or a list of dictionaries.')

        try:
            if isinstance(items, dict):
                result = collection.insert_one(dict(items))
            else:
                # Non-dict list items are passed through for the driver to validate.
                result = collection.insert_many(
                    [dict(item) if isinstance(item, dict) else item for item in items]
                )
        except BulkWriteError as bwe:
            raise RuntimeError(f'Duplicate key error occurred: {bwe.details}') from bwe
        except Exception as e:
            raise RuntimeError(f'An error occurred while trying to update the chat history: {str(e)}') from e

        return result

    def reset_history(self):
        """
        Delete all the chat history from the collection
        Raises:
            RuntimeError: If an error occurs while trying to reset the chat history;
                the original error is chained
        """
        try:
            # An empty filter matches every document in the collection.
            self.history_collection.delete_many({})
        except Exception as e:
            raise RuntimeError(f'An error occurred while trying to reset the chat history: {e}') from e
        
        

In [60]:
# Handler bound to the 'maccabi' collections; no connection string is passed, so
# DBHandler falls back to MONGODB_CONNECTION_STRING (populated via load_dotenv()).
handler = DBHandler(user_id='maccabi')

In [65]:
# Insert the sample conversation (a list, hence insert_many) into the history collection.
handler.update(items=example_history, db='history')

InsertManyResult([ObjectId('670d683e416706117c59b39c'), ObjectId('670d683e416706117c59b39d'), ObjectId('670d683e416706117c59b39e')], acknowledged=True)

In [66]:
# Read the stored history back as 'role: content' strings.
handler.get_history()

['user: What is the capital of France?',
 'chatbot: The capital of France is Paris.',
 'user: And what is its population?',
 'user: What is the capital of France?',
 'chatbot: The capital of France is Paris.',
 'user: And what is its population?',
 'user: What is the capital of France?',
 'chatbot: The capital of France is Paris.',
 'user: And what is its population?']

In [48]:
# Remove every document from this user's history collection (delete_many with an empty filter).
handler.reset_history()