diff --git a/script.sql b/script.sql index 27185d3..6fa8cf5 100644 --- a/script.sql +++ b/script.sql @@ -1,14 +1,22 @@ - +create table users +( + id serial primary key, + email varchar(250) null, + password varchar(50) null, + name varchar(50) null +); create table chat_messages ( - id serial - primary key, - user_id varchar(250), + id serial primary key, + user_ref int not null, chatbot_id varchar(50), message text, is_bot_reply boolean, - createdat timestamp default (now() AT TIME ZONE 'UTC'::text) + createdat timestamp default (now() AT TIME ZONE 'UTC'::text), + CONSTRAINT fk_user + FOREIGN KEY(user_ref) + REFERENCES users(id) ); alter table chat_messages @@ -20,6 +28,5 @@ create index idx_timestamp create index idx_chatbot_id on chat_messages (chatbot_id); -create index idx_user_id - on chat_messages (user_id); - +create unique index idx_user_email + on users (email); diff --git a/server/src/api/apps.py b/server/src/api/apps.py index e88ac03..85c3deb 100644 --- a/server/src/api/apps.py +++ b/server/src/api/apps.py @@ -92,11 +92,11 @@ def delete_app(app_key: str, current_user: User = Depends(get_current_user)): @r.get("/applications/{app_key}/history") def get_app_conversation(app_key: str, user: User = Depends(get_current_user)): if app_key == "chat": - return chat_history.get_latest_messages(user.email, app_key) + return chat_history.get_latest_messages(user.pk, app_key) for a in apps.get_by_user_email(user.email): if app_key == a.app_key: - return chat_history.get_latest_messages(user.email, app_key) + return chat_history.get_latest_messages(user.pk, app_key) raise HTTPException(status_code=404, detail="App not found") diff --git a/server/src/api/factory.py b/server/src/api/factory.py index 98886bd..b5d49a6 100644 --- a/server/src/api/factory.py +++ b/server/src/api/factory.py @@ -16,5 +16,5 @@ chat_history = factory_chat_history(pg_conn) agent = agent_factory(chat_history, cost_service, llm_service) apps = AppDao() -user_service = UserService(UserDao(), apps) +user_service = UserService(UserDao(pg_conn), apps) current_session = {} diff --git a/server/src/core/agent/agent_service.py b/server/src/core/agent/agent_service.py index b9d08df..b755ab1 100644 --- a/server/src/core/agent/agent_service.py +++ b/server/src/core/agent/agent_service.py @@ -53,7 +53,7 @@ def is_action(self, req: UserInputDto) -> bool: async def handle_user_input(self, req: UserInputDto) -> dict: current_user = req.user add_message_dto = AddMessageDto( - user_email=current_user.email, + user_ref=req.user.pk, app_key=req.app.app_key, session_id=req.session_id, message=MessageCompletion( @@ -76,7 +76,7 @@ async def handle_user_input(self, req: UserInputDto) -> dict: message = llm_resp.message add_message_dto = AddMessageDto( - user_email=current_user.email, + user_ref=req.user.pk, app_key=req.app.app_key, session_id=req.session_id, message=MessageCompletion( @@ -96,7 +96,7 @@ async def handle_user_input(self, req: UserInputDto) -> dict: def user_history_process(self, req: UserInputDto) -> dict: add_message_dto = AddMessageDto( - user_email=req.user.email, + user_ref=req.user.pk, app_key=req.app.app_key, session_id=req.session_id, message=MessageCompletion( diff --git a/server/src/core/common/pg.py b/server/src/core/common/pg.py index bb2cc71..e1b4108 100644 --- a/server/src/core/common/pg.py +++ b/server/src/core/common/pg.py @@ -33,3 +33,13 @@ def fetch_all(self, query: str, params=None) -> List[Dict]: for row in rows: result.append(dict(zip(cols, row))) return result + + def fetch_one(self, query: str, params=None) -> Dict: + self.cursor.execute(query, params) + row = self.cursor.fetchone() + if row is None: + return None + # Convert rows to list of dictionaries so they're easier to work with + cols = [desc[0] for desc in self.cursor.description] + result = [] + return dict(zip(cols, row)) diff --git a/server/src/core/history/chat_history.py b/server/src/core/history/chat_history.py index d7a9ebb..31f4e07 100644 --- a/server/src/core/history/chat_history.py +++ b/server/src/core/history/chat_history.py @@ -9,7 +9,7 @@ class AddMessageDto(BaseModel): - user_email: str + user_ref: int app_key: str session_id: str message: MessageCompletion @@ -21,7 +21,7 @@ def __init__(self, dao: HistoryDao): self.history = CacheMemory(30) def add_message(self, req: AddMessageDto): - user_email = req.user_email + user_ref = req.user_ref app_key = req.app_key session_id = req.session_id @@ -32,18 +32,18 @@ def add_message(self, req: AddMessageDto): history.append(req.message) self.history.put(session_id, history) - self.persist_message(user_email, app_key, req.message) + self.persist_message(user_ref, app_key, req.message) def get_history(self, key): return self.history.get(key) - def persist_message(self, user_email, app_key, message): + def persist_message(self, user_ref, app_key, message): is_bot_replay = message.role == MessageRole.ASSISTANT msg = message.response if is_bot_replay else message.query - self.dao.persist_message(user_email, app_key, msg, is_bot_replay) + self.dao.persist_message(user_ref, app_key, msg, is_bot_replay) - def get_latest_messages(self, user_email: str, app_key: str, page: int) -> List[Dict]: - return self.dao.get_latest_messages(user_email, app_key, page) + def get_latest_messages(self, user_ref: int, app_key: str, page: int) -> List[Dict]: + return self.dao.get_latest_messages(user_ref, app_key, page) def factory_chat_history(pg_conn: DBConnection): diff --git a/server/src/core/history/history_dao.py b/server/src/core/history/history_dao.py index b526987..8fcd6ab 100644 --- a/server/src/core/history/history_dao.py +++ b/server/src/core/history/history_dao.py @@ -9,21 +9,21 @@ class HistoryDao: def __init__(self, db: DBConnection): self.db = db - def persist_message(self, user_email, app_key, msg, is_bot_replay): + def persist_message(self, user_ref, app_key, msg, is_bot_replay): insert_query = sql.SQL( """ - INSERT INTO chat_messages(user_id, chatbot_id, message, is_bot_reply) + INSERT INTO chat_messages(user_ref, chatbot_id, message, is_bot_reply) VALUES(%s, %s, %s, %s) """ ) - self.db.execute(insert_query, (user_email, app_key, msg, is_bot_replay)) + self.db.execute(insert_query, (user_ref, app_key, msg, is_bot_replay)) - def get_latest_messages(self, user_email: str, app_key: str) -> List[Dict]: + def get_latest_messages(self, user_ref: str, app_key: str) -> List[Dict]: return self.db.fetch_all( """ - SELECT * FROM chat_messages where chatbot_id = %s and user_id = %s + SELECT * FROM chat_messages where chatbot_id = %s and user_ref = %s ORDER BY createdat DESC LIMIT 50 """, - (app_key, user_email), + (app_key, user_ref), ) diff --git a/server/src/core/user/user_dao.py b/server/src/core/user/user_dao.py index 949b999..fd24966 100644 --- a/server/src/core/user/user_dao.py +++ b/server/src/core/user/user_dao.py @@ -1,4 +1,6 @@ -from typing import Dict +from typing import Dict, Optional + +from psycopg2 import IntegrityError from pydantic import BaseModel @@ -6,43 +8,47 @@ class User(BaseModel): + pk: Optional[int] = None email: str - password: str + password: str # TODO: encrypted password name: str = None class UserDao: - def __init__(self): - self.db = FileDB('./file_db/users') + def __init__(self, pg_conn): + self.db = pg_conn def get_all(self) -> Dict[str, User]: - users = self.db.get("all") - if users is None: - return {} - return users + users = self.db.fetch_all("SELECT * FROM users") + return [User(**u) for u in users] - def get(self, user) -> User: - return self.get_all().get(user) + def get_by_email(self, email: str) -> User: + user = self.db.fetch_one("SELECT * FROM users WHERE email=%s", (email,)) + user["pk"] = user["id"] # slightly a hack + return User(**user) def add(self, user: User): - if self.get(user.email) is not None: + try: + self.db.execute( + "INSERT INTO users (email, password, name) VALUES (%s, %s, %s)", + (user.email, user.password, user.name) + ) + except IntegrityError as e: + print("Inserting user:", e) return False - users = self.get_all() - users[user.email] = user - - self.db.put("all", users) - def edit(self, user: User): - users = self.get_all() - if users.get(user.email) is None: + try: + self.db.execute( + "UPDATE users SET email=%s, password=%s, name=%s", + (user.email, user.password, user.name) + ) + except IntegrityError as e: + print("Updating user:", e) return False - users[user.email] = user - self.db.put("all", users) - def remove(self, user: str): - if self.get(user) is None: - return False - users = self.get_all() - del users[user] - self.db.put("all", users) + def remove(self, pk: int): + self.db.execute( + "DELETE FROM users WHERE id=%s", + (pk,) + ) diff --git a/server/src/core/user/user_service.py b/server/src/core/user/user_service.py index 5fa37a4..d5e7da3 100644 --- a/server/src/core/user/user_service.py +++ b/server/src/core/user/user_service.py @@ -1,5 +1,7 @@ from typing import List +from psycopg2 import IntegrityError + from core.app.app_dao import AppDao from core.user.user_dao import UserDao, User @@ -8,11 +10,15 @@ class UserService: def __init__(self, user_storage: UserDao, app_dao: AppDao): self.app_dao = app_dao self.dao = user_storage - if self.dao.get_all() == {}: + try: self.add_user(User(name="Alex", password="123", email="admin@gmail.com")) + self.dao.db.conn.commit() + except IntegrityError as e: + print("IGNORED EXCEPTION", e) # TODO: Remove this line. + self.dao.db.conn.rollback() def authenticate_user(self, email, password): - user = self.dao.get(email) + user = self.dao.get_by_email(email) if user is None: return None if user.password == password: @@ -20,6 +26,7 @@ def authenticate_user(self, email, password): return None + # TODO: wrong responsibility def exists_app(self, email: str, app_key: str) -> bool: app = self.app_dao.get_by_id(email, app_key) if app is None: @@ -27,11 +34,12 @@ def exists_app(self, email: str, app_key: str) -> bool: return True def get_user_by_email(self, email): - return self.dao.get(email) + return self.dao.get_by_email(email) def add_user(self, user: User): self.dao.add(user) + # TODO: Should be removed. def get_all_users(self) -> List[User]: users = self.dao.get_all().values() for user in users: