Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handling users through the DB users table #52

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions script.sql
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@

create table users
(
id serial primary key,
email varchar(250) null,
password varchar(50) null,
name varchar(50) null
);

create table chat_messages
(
id serial
primary key,
user_id varchar(250),
id serial primary key,
user_ref int not null,
chatbot_id varchar(50),
message text,
is_bot_reply boolean,
createdat timestamp default (now() AT TIME ZONE 'UTC'::text)
createdat timestamp default (now() AT TIME ZONE 'UTC'::text),
CONSTRAINT fk_user
FOREIGN KEY(user_ref)
REFERENCES users(id)
);

alter table chat_messages
Expand All @@ -20,6 +28,5 @@ create index idx_timestamp
create index idx_chatbot_id
on chat_messages (chatbot_id);

create index idx_user_id
on chat_messages (user_id);

create unique index idx_user_email
on users (email);
4 changes: 2 additions & 2 deletions server/src/api/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,11 @@ def delete_app(app_key: str, current_user: User = Depends(get_current_user)):
@r.get("/applications/{app_key}/history")
def get_app_conversation(app_key: str, user: User = Depends(get_current_user)):
if app_key == "chat":
return chat_history.get_latest_messages(user.email, app_key)
return chat_history.get_latest_messages(user.pk, app_key)

for a in apps.get_by_user_email(user.email):
if app_key == a.app_key:
return chat_history.get_latest_messages(user.email, app_key)
return chat_history.get_latest_messages(user.pk, app_key)

raise HTTPException(status_code=404, detail="App not found")

Expand Down
2 changes: 1 addition & 1 deletion server/src/api/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@
chat_history = factory_chat_history(pg_conn)
agent = agent_factory(chat_history, cost_service, llm_service)
apps = AppDao()
user_service = UserService(UserDao(), apps)
user_service = UserService(UserDao(pg_conn), apps)
current_session = {}
6 changes: 3 additions & 3 deletions server/src/core/agent/agent_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def is_action(self, req: UserInputDto) -> bool:
async def handle_user_input(self, req: UserInputDto) -> dict:
current_user = req.user
add_message_dto = AddMessageDto(
user_email=current_user.email,
user_ref=req.user.pk,
app_key=req.app.app_key,
session_id=req.session_id,
message=MessageCompletion(
Expand All @@ -76,7 +76,7 @@ async def handle_user_input(self, req: UserInputDto) -> dict:
message = llm_resp.message

add_message_dto = AddMessageDto(
user_email=current_user.email,
user_ref=req.user.pk,
app_key=req.app.app_key,
session_id=req.session_id,
message=MessageCompletion(
Expand All @@ -96,7 +96,7 @@ async def handle_user_input(self, req: UserInputDto) -> dict:

def user_history_process(self, req: UserInputDto) -> dict:
add_message_dto = AddMessageDto(
user_email=req.user.email,
user_ref=req.user.pk,
app_key=req.app.app_key,
session_id=req.session_id,
message=MessageCompletion(
Expand Down
10 changes: 10 additions & 0 deletions server/src/core/common/pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,13 @@ def fetch_all(self, query: str, params=None) -> List[Dict]:
for row in rows:
result.append(dict(zip(cols, row)))
return result

def fetch_one(self, query: str, params=None) -> Dict:
self.cursor.execute(query, params)
row = self.cursor.fetchone()
if row is None:
return None
# Convert rows to list of dictionaries so they're easier to work with
cols = [desc[0] for desc in self.cursor.description]
result = []
return dict(zip(cols, row))
14 changes: 7 additions & 7 deletions server/src/core/history/chat_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


class AddMessageDto(BaseModel):
user_email: str
user_ref: int
app_key: str
session_id: str
message: MessageCompletion
Expand All @@ -21,7 +21,7 @@ def __init__(self, dao: HistoryDao):
self.history = CacheMemory(30)

def add_message(self, req: AddMessageDto):
user_email = req.user_email
user_ref = req.user_ref
app_key = req.app_key
session_id = req.session_id

Expand All @@ -32,18 +32,18 @@ def add_message(self, req: AddMessageDto):
history.append(req.message)
self.history.put(session_id, history)

self.persist_message(user_email, app_key, req.message)
self.persist_message(user_ref, app_key, req.message)

def get_history(self, key):
return self.history.get(key)

def persist_message(self, user_email, app_key, message):
def persist_message(self, user_ref, app_key, message):
is_bot_replay = message.role == MessageRole.ASSISTANT
msg = message.response if is_bot_replay else message.query
self.dao.persist_message(user_email, app_key, msg, is_bot_replay)
self.dao.persist_message(user_ref, app_key, msg, is_bot_replay)

def get_latest_messages(self, user_email: str, app_key: str, page: int) -> List[Dict]:
return self.dao.get_latest_messages(user_email, app_key, page)
def get_latest_messages(self, user_ref: int, app_key: str, page: int) -> List[Dict]:
return self.dao.get_latest_messages(user_ref, app_key, page)


def factory_chat_history(pg_conn: DBConnection):
Expand Down
12 changes: 6 additions & 6 deletions server/src/core/history/history_dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@ class HistoryDao:
def __init__(self, db: DBConnection):
self.db = db

def persist_message(self, user_email, app_key, msg, is_bot_replay):
def persist_message(self, user_ref, app_key, msg, is_bot_replay):
insert_query = sql.SQL(
"""
INSERT INTO chat_messages(user_id, chatbot_id, message, is_bot_reply)
INSERT INTO chat_messages(user_ref, chatbot_id, message, is_bot_reply)
VALUES(%s, %s, %s, %s)
"""
)
self.db.execute(insert_query, (user_email, app_key, msg, is_bot_replay))
self.db.execute(insert_query, (user_ref, app_key, msg, is_bot_replay))

def get_latest_messages(self, user_email: str, app_key: str) -> List[Dict]:
def get_latest_messages(self, user_ref: str, app_key: str) -> List[Dict]:
return self.db.fetch_all(
"""
SELECT * FROM chat_messages where chatbot_id = %s and user_id = %s
SELECT * FROM chat_messages where chatbot_id = %s and user_ref = %s
ORDER BY createdat DESC
LIMIT 50
""",
(app_key, user_email),
(app_key, user_ref),
)
57 changes: 31 additions & 26 deletions server/src/core/user/user_dao.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,53 @@
from typing import Dict
from typing import Dict, Optional

from psycopg2 import IntegrityError

from pydantic import BaseModel

from core.common.file_db import FileDB


class User(BaseModel):
pk: Optional[int] = None
email: str
password: str
password: str # TODO: encrypted password
name: str = None


class UserDao:
def __init__(self):
self.db = FileDB('./file_db/users')
def __init__(self, pg_conn):
self.db = pg_conn

def get_all(self) -> Dict[str, User]:
users = self.db.get("all")
if users is None:
return {}
return users
users = self.db.fetch_all("SELECT * FROM users")
return [User(**u) for u in users]

def get(self, user) -> User:
return self.get_all().get(user)
def get_by_email(self, email: str) -> User:
user = self.db.fetch_one("SELECT * FROM users WHERE email=%s", (email,))
return User(**user)

def add(self, user: User):
if self.get(user.email) is not None:
try:
self.db.execute(
"INSERT INTO users (email, password, name) VALUES (%s, %s, %s)",
(user.email, user.password, user.name)
)
except IntegrityError as e:
print("Inserting user:", e)
return False

users = self.get_all()
users[user.email] = user

self.db.put("all", users)

def edit(self, user: User):
users = self.get_all()
if users.get(user.email) is None:
try:
self.db.execute(
"UPDATE users SET email=%s, password=%s, name=%s",
(user.email, user.password, user.name)
)
except IntegrityError as e:
print("Updating user:", e)
return False
users[user.email] = user
self.db.put("all", users)

def remove(self, user: str):
if self.get(user) is None:
return False
users = self.get_all()
del users[user]
self.db.put("all", users)
def remove(self, pk: int):
self.db.execute(
"DELETE FROM users WHERE id=%s",
(pk,)
)
14 changes: 11 additions & 3 deletions server/src/core/user/user_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import List

from psycopg2 import IntegrityError

from core.app.app_dao import AppDao
from core.user.user_dao import UserDao, User

Expand All @@ -8,30 +10,36 @@ class UserService:
def __init__(self, user_storage: UserDao, app_dao: AppDao):
self.app_dao = app_dao
self.dao = user_storage
if self.dao.get_all() == {}:
try:
self.add_user(User(name="Alex", password="123", email="admin@gmail.com"))
self.dao.db.conn.commit()
except IntegrityError as e:
print("IGNORED EXCEPTION", e) # TODO: Remove this line.
self.dao.db.conn.rollback()

def authenticate_user(self, email, password):
user = self.dao.get(email)
user = self.dao.get_by_email(email)
if user is None:
return None
if user.password == password:
return user

return None

# TODO: wrong responsibility
def exists_app(self, email: str, app_key: str) -> bool:
app = self.app_dao.get_by_id(email, app_key)
if app is None:
return False
return True

def get_user_by_email(self, email):
return self.dao.get(email)
return self.dao.get_by_email(email)

def add_user(self, user: User):
self.dao.add(user)

# TODO: Should be removed.
def get_all_users(self) -> List[User]:
users = self.dao.get_all().values()
for user in users:
Expand Down