Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: message generation stability #90

Merged
merged 5 commits into from
Apr 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions backend/app/operators/assistant/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ async def create(self, create_dict: Dict, **kwargs) -> ModelEntity:
chat = await self._create_entity(conn, create_dict, **kwargs)

# update assistant num_chats
assistant_ops.update(
await assistant_ops.update(
postgres_conn=conn,
assistant=assistant,
assistant_id=assistant_id,
update_dict={"num_chats": assistant.num_chats + 1},
)

Expand Down
8 changes: 7 additions & 1 deletion backend/app/routes/manage/manage.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fastapi import APIRouter
from fastapi import APIRouter, HTTPException
from tkhelper.schemas.base import BaseEmptyResponse, BaseDataResponse
from app.config import CONFIG
from app.services.auth.admin import create_default_admin_if_needed
Expand All @@ -17,6 +17,12 @@
response_model=BaseEmptyResponse,
)
async def api_health_check():
    """Liveness probe: verify that both backing stores are reachable.

    Raises HTTP 500 with a service-specific detail message when either
    Redis or Postgres fails its health check; otherwise returns an empty
    success response.
    """
    # Imported lazily so the route module does not pull in live database
    # handles at import time.
    from app.database import redis_conn, postgres_pool

    # Redis is checked first, then Postgres — same order as the original.
    for service, conn in (("Redis", redis_conn), ("Postgres", postgres_pool)):
        if not await conn.health_check():
            raise HTTPException(status_code=500, detail=f"{service} health check failed.")
    return BaseEmptyResponse()


Expand Down
6 changes: 3 additions & 3 deletions backend/app/services/assistant/generation/normal_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,17 @@ async def generate(self, system_prompt_variables: Dict):
break

message = await self.create_assistant_message(chat_completion_assistant_message["content"])
await self.chat.unlock()
return BaseDataResponse(data=message.to_response_dict())

except MessageGenerationException as e:
await self.chat.unlock()
logger.error(f"NormalSession.generate: MessageGenerationException error = {e}")
raise_http_error(ErrorCode.GENERATION_ERROR, message=str(e))

except Exception as e:
await self.chat.unlock()
logger.error(f"NormalSession.generate: Exception error = {e}")
raise_http_error(
ErrorCode.INTERNAL_SERVER_ERROR, message=str("Assistant message not generated due to an unknown error.")
)

finally:
await self.chat.unlock()
5 changes: 4 additions & 1 deletion backend/app/services/assistant/generation/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ async def prepare(self, stream: bool, system_prompt_variables: Dict, retrieval_l
raise MessageGenerationException(f"Chat {self.chat.chat_id} is locked. Please try again later.")

# 1. Get model
self.model = await get_model(self.assistant.model_id)
try:
self.model = await get_model(self.assistant.model_id)
except Exception as e:
raise MessageGenerationException(f"Failed to load model {self.assistant.model_id}.")

# 2. model streaming
if not self.model.allow_streaming() and stream:
Expand Down
13 changes: 7 additions & 6 deletions backend/app/services/inference/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,11 @@ async def chat_completion_stream(
if buffer.endswith("\n\n"):
lines = buffer.strip().split("\n")
event_data = lines[0][len("data: ") :]
try:
data = json.loads(event_data)
yield data
except json.decoder.JSONDecodeError:
print("JSONDecodeError")
continue
if event_data != "[DONE]":
try:
data = json.loads(event_data)
yield data
except json.decoder.JSONDecodeError:
logger.error(f"Failed to parse json: {event_data}")
continue
buffer = ""
2 changes: 1 addition & 1 deletion backend/app/services/model/model_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async def sync_model_schema_data():
model_schema.type
for model_schema in model_schemas
if model_schema.provider_id == provider_data["provider_id"]
and model_schema.type != ModelType.WILDCARD
# and model_schema.type != ModelType.WILDCARD
]
)
)
Expand Down
39 changes: 39 additions & 0 deletions backend/tkhelper/database/postgres/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,42 @@ async def clean_data(self):
logger.info(f"Postgres database {self.db_name} clean done.")

await self._migration_if_needed()

# -- log connection info --

async def log_connection_info(self):
    """Log client-pool occupancy and server-side connection counts for this database."""
    if self.db_pool is None:
        logger.warning(f"log_connection_info: Postgres database {self.db_name} pool is not initialized")
        return

    # Server-side view: backends currently connected, and the configured cap.
    async with self.get_db_connection() as conn:
        used_on_server = await conn.fetchval("SELECT COUNT(*) FROM pg_stat_activity;")
        server_cap = await conn.fetchval(
            "SELECT setting FROM pg_settings WHERE name=$1;", "max_connections"
        )

    # Client-side view: how full our own connection pool is.
    pool_size = self.db_pool.get_size()
    pool_cap = self.db_pool.get_max_size()

    logger.info(
        f"db[{self.db_name}]: "
        f"client pool size = {pool_size}/{pool_cap}, "
        f"server connections = {used_on_server}/{server_cap}"
    )

# -- health check --
async def health_check(self) -> bool:
    """Check whether the Postgres database answers a trivial query.

    Returns:
        True when the pool is initialized and ``SELECT 1`` succeeds;
        False otherwise.  Failures are logged, never raised.
    """
    if self.db_pool is None:
        logger.error("Postgres health check failed: db pool is not initialized")
        return False

    try:
        async with self.get_db_connection() as conn:
            await conn.fetchval("SELECT 1")
            return True
    except Exception as e:
        # Fixed stray ']' in the log message ("Postgres[...]]" -> "Postgres[...]").
        logger.error(f"Postgres[{self.db_name}] health check failed: error={e}")
        return False
102 changes: 84 additions & 18 deletions backend/tkhelper/database/redis/connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import aioredis
from aioredis import Redis
import asyncio
import json
from typing import Dict, Optional
import logging
Expand All @@ -8,22 +8,77 @@


class RedisConnection:
_instance = None
_lock = asyncio.Lock()

def __new__(cls, *args, **kwargs):
    """Singleton constructor: every instantiation returns the same object."""
    instance = cls._instance
    if instance is None:
        instance = super().__new__(cls)
        cls._instance = instance
    return instance

def __init__(self, url: str):
    """Record the connection URL; one-time state setup for the singleton.

    Because __new__ returns a shared instance, __init__ runs on every
    RedisConnection(...) call — so mutable connection state is only
    initialized the first time (guarded by the ``initialized`` sentinel).
    The unguarded legacy assignment ``self.redis = None`` is removed: it
    would have reset the live client on every re-construction.
    """
    self.url = url
    if not hasattr(self, "initialized"):
        self.redis: Optional[aioredis.Redis] = None
        self.initialized = False          # True once init() has built the client
        self.health_check_failures = 0    # consecutive failed health checks

# -- connection management --

async def init(self):
    """Create (or re-create) the aioredis client; safe to call repeatedly.

    Serialized by the class-level lock so concurrent callers cannot build
    two clients. Re-runs the setup only when never initialized or the
    client has been dropped; also applies the LRU eviction policy and
    resets the health-check failure counter.
    """
    async with self._lock:
        if not self.initialized or self.redis is None:
            if self.redis is not None:
                # Stale client left over from a partial init: release it
                # before building a fresh one.
                await self.redis.close()
            self.redis = await aioredis.from_url(self.url)
            await self.redis.config_set("maxmemory-policy", "allkeys-lru")
            logger.info("Set redis maxmemory-policy to allkeys-lru")
            logger.info("Redis pool initialized or reinitialized.")
            self.initialized = True
            self.health_check_failures = 0

async def close(self):
    """Tear down the Redis client and reset connection state.

    No-op when the client was never initialized. Serialized by the
    class-level lock so it cannot race with init().
    """
    async with self._lock:
        if self.redis is None or not self.initialized:
            return
        await self.redis.close()
        self.redis = None
        self.initialized = False
        self.health_check_failures = 0
        logger.info("Redis pool closed.")

# -- health check --

async def restart_redis(self):
    """Recycle the Redis client: tear down the current connection, then rebuild it."""
    await self.close()  # close Redis connection
    await self.init()  # restart Redis connection
    logger.info("Redis client has been restarted.")

async def health_check(self) -> bool:
    """Ping Redis; restart the client after too many consecutive failures.

    Returns:
        True when the server answers PING, False otherwise.  Never raises:
        every failure path increments ``health_check_failures``, and once
        more than 10 consecutive failures accumulate the client is
        recycled via restart_redis().
    """
    if self.redis is None:
        self.health_check_failures += 1
        logger.error(f"Redis health check failed: redis is not initialized, failures={self.health_check_failures}")
        return False
    try:
        pong = await self.redis.ping()
        if pong:
            self.health_check_failures = 0
            return True
        self.health_check_failures += 1
        # Fixed stray ".," typo in the original message.
        logger.error(f"Redis health check failed: did not receive PONG, failures={self.health_check_failures}")
    except asyncio.CancelledError:
        # NOTE(review): swallowing CancelledError blocks task cancellation;
        # kept to preserve existing behavior — confirm this is intentional.
        self.health_check_failures += 1
        logger.error(f"Redis health check failed: operation was cancelled, failures={self.health_check_failures}")
    except Exception as e:
        self.health_check_failures += 1
        logger.error(f"Redis health check failed: error={e}, failures={self.health_check_failures}")

    if self.health_check_failures > 10:
        # Message now matches the `> 10` threshold: the restart fires on
        # the 11th consecutive failure.
        logger.warning("Redis health check failed more than 10 times, attempting to restart Redis client.")
        await self.restart_redis()
        self.health_check_failures = 0

    return False

# -- clean --

async def set_int(self, key: str, value: int, expire: int = 3600 * 4):
    """Store an integer under ``key`` with a TTL (default 4 hours).

    Best-effort: a missing client or any Redis error is logged, never
    raised.  Resolved the leftover old/new diff lines to a single
    atomic ``SET ... EX`` (one round trip instead of SET + EXPIRE).
    """
    if self.redis is None:
        return
    try:
        # `expire or None` keeps the old `if expire:` semantics: a falsy
        # expire means "no TTL" rather than an invalid EX argument.
        await self.redis.set(key, value, ex=expire or None)
        logger.debug(f"set_int: key={key}, value={value}")
    except asyncio.CancelledError:
        # Fixed copy-pasted "get_object" label in the log message.
        logger.error(f"set_int: operation was cancelled, key={key}")
    except Exception as e:
        logger.error(f"set_int: error={e}")

async def set_object(self, key: str, value: Dict, expire: int = 3600 * 4):
    """Store a JSON-serializable dict under ``key`` with a TTL (default 4 hours).

    Best-effort: a missing client or any Redis error is logged, never
    raised.  Resolved the leftover old/new diff lines to a single
    atomic ``SET ... EX`` (one round trip instead of SET + EXPIRE).
    """
    if self.redis is None:
        return
    try:
        # `expire or None` keeps the old `if expire:` semantics: a falsy
        # expire means "no TTL" rather than an invalid EX argument.
        await self.redis.set(key, json.dumps(value), ex=expire or None)
        logger.debug(f"set_object: key={key}, value={value}")
    except asyncio.CancelledError:
        # Fixed copy-pasted "get_object" label in the log message.
        logger.error(f"set_object: operation was cancelled, key={key}")
    except Exception as e:
        logger.error(f"set_object: error={e}")

async def set_string(self, key: str, value: str, expire: int = 3600 * 4):
    """Store a string under ``key`` with a TTL (default 4 hours).

    Best-effort: a missing client or any Redis error is logged, never
    raised.  Resolved the leftover old/new diff lines to a single
    atomic ``SET ... EX`` (one round trip instead of SET + EXPIRE).
    """
    if self.redis is None:
        return
    try:
        # `expire or None` keeps the old `if expire:` semantics: a falsy
        # expire means "no TTL" rather than an invalid EX argument.
        await self.redis.set(key, value, ex=expire or None)
        logger.debug(f"set_string: key={key}, value={value}")
    except asyncio.CancelledError:
        # Fixed copy-pasted "get_object" label in the log message.
        logger.error(f"set_string: operation was cancelled, key={key}")
    except Exception as e:
        logger.error(f"set_string: error={e}")

Expand All @@ -73,6 +128,8 @@ async def pop(self, key: str):
try:
await self.redis.delete(key)
logger.debug(f"pop: key={key}")
except asyncio.CancelledError:
logger.error(f"get_object: operation was cancelled, key={key}")
except Exception as e:
logger.error(f"pop: error={e}")

Expand All @@ -83,6 +140,9 @@ async def get_string(self, key: str):
value_string = await self.redis.get(key)
logger.debug(f"get_string: key={key}, value={value_string}")
return value_string
except asyncio.CancelledError:
logger.error(f"get_object: operation was cancelled, key={key}")
return None
except Exception as e:
logger.error(f"get_string: error={e}")
return None
Expand All @@ -95,6 +155,9 @@ async def get_object(self, key: str) -> Optional[Dict]:
logger.debug(f"get_object: key={key}, value={value_string}")
if value_string:
return json.loads(value_string)
except asyncio.CancelledError:
logger.error(f"get_object: operation was cancelled, key={key}")
return None
except Exception as e:
logger.error(f"get_object: error={e}")
return None
Expand All @@ -107,6 +170,9 @@ async def get_int(self, key: str):
logger.debug(f"get_int: key={key}, value={value}")
if value:
return int(value)
except asyncio.CancelledError:
logger.error(f"get_object: operation was cancelled, key={key}")
return None
except Exception as e:
logger.error(f"get_int: error={e}")
return None