diff --git a/backend/app/operators/assistant/chat.py b/backend/app/operators/assistant/chat.py index 92dbf643..d1e8725a 100644 --- a/backend/app/operators/assistant/chat.py +++ b/backend/app/operators/assistant/chat.py @@ -32,9 +32,9 @@ async def create(self, create_dict: Dict, **kwargs) -> ModelEntity: chat = await self._create_entity(conn, create_dict, **kwargs) # update assistant num_chats - assistant_ops.update( + await assistant_ops.update( postgres_conn=conn, - assistant=assistant, + assistant_id=assistant_id, update_dict={"num_chats": assistant.num_chats + 1}, ) diff --git a/backend/app/routes/manage/manage.py b/backend/app/routes/manage/manage.py index 8b5a651d..1d03640c 100644 --- a/backend/app/routes/manage/manage.py +++ b/backend/app/routes/manage/manage.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException from tkhelper.schemas.base import BaseEmptyResponse, BaseDataResponse from app.config import CONFIG from app.services.auth.admin import create_default_admin_if_needed @@ -17,6 +17,12 @@ response_model=BaseEmptyResponse, ) async def api_health_check(): + from app.database import redis_conn, postgres_pool + + if not await redis_conn.health_check(): + raise HTTPException(status_code=500, detail="Redis health check failed.") + if not await postgres_pool.health_check(): + raise HTTPException(status_code=500, detail="Postgres health check failed.") return BaseEmptyResponse() diff --git a/backend/app/services/assistant/generation/normal_session.py b/backend/app/services/assistant/generation/normal_session.py index 99805e62..c17b06d7 100644 --- a/backend/app/services/assistant/generation/normal_session.py +++ b/backend/app/services/assistant/generation/normal_session.py @@ -53,17 +53,17 @@ async def generate(self, system_prompt_variables: Dict): break message = await self.create_assistant_message(chat_completion_assistant_message["content"]) + await self.chat.unlock() return 
BaseDataResponse(data=message.to_response_dict()) except MessageGenerationException as e: + await self.chat.unlock() logger.error(f"NormalSession.generate: MessageGenerationException error = {e}") raise_http_error(ErrorCode.GENERATION_ERROR, message=str(e)) except Exception as e: + await self.chat.unlock() logger.error(f"NormalSession.generate: Exception error = {e}") raise_http_error( ErrorCode.INTERNAL_SERVER_ERROR, message=str("Assistant message not generated due to an unknown error.") ) - - finally: - await self.chat.unlock() diff --git a/backend/app/services/assistant/generation/session.py b/backend/app/services/assistant/generation/session.py index 4fb9c1fb..bb15aadd 100644 --- a/backend/app/services/assistant/generation/session.py +++ b/backend/app/services/assistant/generation/session.py @@ -58,7 +58,10 @@ async def prepare(self, stream: bool, system_prompt_variables: Dict, retrieval_l raise MessageGenerationException(f"Chat {self.chat.chat_id} is locked. Please try again later.") # 1. Get model - self.model = await get_model(self.assistant.model_id) + try: + self.model = await get_model(self.assistant.model_id) + except Exception as e: + raise MessageGenerationException(f"Failed to load model {self.assistant.model_id}.") # 2. 
model streaming if not self.model.allow_streaming() and stream: diff --git a/backend/app/services/inference/chat_completion.py b/backend/app/services/inference/chat_completion.py index bbf339aa..44d296ca 100644 --- a/backend/app/services/inference/chat_completion.py +++ b/backend/app/services/inference/chat_completion.py @@ -75,10 +75,11 @@ async def chat_completion_stream( if buffer.endswith("\n\n"): lines = buffer.strip().split("\n") event_data = lines[0][len("data: ") :] - try: - data = json.loads(event_data) - yield data - except json.decoder.JSONDecodeError: - print("JSONDecodeError") - continue + if event_data != "[DONE]": + try: + data = json.loads(event_data) + yield data + except json.decoder.JSONDecodeError: + logger.error(f"Failed to parse json: {event_data}") + continue buffer = "" diff --git a/backend/app/services/model/model_schema.py b/backend/app/services/model/model_schema.py index 0bc9466d..eccfa4b5 100644 --- a/backend/app/services/model/model_schema.py +++ b/backend/app/services/model/model_schema.py @@ -78,7 +78,7 @@ async def sync_model_schema_data(): model_schema.type for model_schema in model_schemas if model_schema.provider_id == provider_data["provider_id"] - and model_schema.type != ModelType.WILDCARD + # and model_schema.type != ModelType.WILDCARD ] ) ) diff --git a/backend/tkhelper/database/postgres/pool.py b/backend/tkhelper/database/postgres/pool.py index 1e32404f..c826a45f 100644 --- a/backend/tkhelper/database/postgres/pool.py +++ b/backend/tkhelper/database/postgres/pool.py @@ -100,3 +100,42 @@ async def clean_data(self): logger.info(f"Postgres database {self.db_name} clean done.") await self._migration_if_needed() + + # -- log connection info -- + + async def log_connection_info(self): + if self.db_pool is None: + logger.warning(f"log_connection_info: Postgres database {self.db_name} pool is not initialized") + return + + # get server connections + async with self.get_db_connection() as conn: + server_connections = await 
conn.fetchval("SELECT COUNT(*) FROM pg_stat_activity;") + server_max_connections = await conn.fetchval( + "SELECT setting FROM pg_settings WHERE name=$1;", "max_connections" + ) + + # get client connections + client_connections = self.db_pool.get_size() + max_client_connections = self.db_pool.get_max_size() + + logger.info( + f"db[{self.db_name}]: " + f"client pool size = {client_connections}/{max_client_connections}, " + f"server connections = {server_connections}/{server_max_connections}" + ) + + # -- health check -- + async def health_check(self) -> bool: + """Check if postgres database is healthy""" + if self.db_pool is None: + logger.error("Postgres health check failed: db pool is not initialized") + return False + + try: + async with self.get_db_connection() as conn: + await conn.fetchval("SELECT 1") + return True + except Exception as e: + logger.error(f"Postgres[{self.db_name}] health check failed: error={e}") + return False diff --git a/backend/tkhelper/database/redis/connection.py b/backend/tkhelper/database/redis/connection.py index 98fdea44..02751be5 100644 --- a/backend/tkhelper/database/redis/connection.py +++ b/backend/tkhelper/database/redis/connection.py @@ -1,5 +1,5 @@ import aioredis -from aioredis import Redis +import asyncio import json from typing import Dict, Optional import logging @@ -8,22 +8,77 @@ class RedisConnection: + _instance = None + _lock = asyncio.Lock() + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + def __init__(self, url: str): self.url = url - self.redis: Redis = None + if not hasattr(self, "initialized"): + self.redis: Optional[aioredis.Redis] = None + self.initialized = False + self.health_check_failures = 0 # -- connection management -- async def init(self): - self.redis = await aioredis.from_url(self.url) - await self.redis.config_set("maxmemory-policy", "allkeys-lru") - logger.info("Set redis maxmemory-policy to allkeys-lru") - 
logger.info("Redis pool initialized.") + async with self._lock: + if not self.initialized or self.redis is None: + if self.redis is not None: + await self.redis.close() + self.redis = await aioredis.from_url(self.url) + await self.redis.config_set("maxmemory-policy", "allkeys-lru") + logger.info("Set redis maxmemory-policy to allkeys-lru") + logger.info("Redis pool initialized or reinitialized.") + self.initialized = True + self.health_check_failures = 0 async def close(self): - if self.redis is not None: - await self.redis.close() - logger.info("Redis pool closed.") + async with self._lock: + if self.redis is not None and self.initialized: + await self.redis.close() + self.redis = None + self.initialized = False + self.health_check_failures = 0 + logger.info("Redis pool closed.") + + # -- health check -- + + async def restart_redis(self): + await self.close() # close Redis connection + await self.init() # restart Redis connection + logger.info("Redis client has been restarted.") + + async def health_check(self): + if self.redis is None: + self.health_check_failures += 1 + logger.error(f"Redis health check failed: redis is not initialized, failures={self.health_check_failures}") + return False + try: + pong = await self.redis.ping() + if pong: + self.health_check_failures = 0 + return True + else: + self.health_check_failures += 1 + logger.error(f"Redis health check failed: did not receive PONG., failures={self.health_check_failures}") + except asyncio.CancelledError: + self.health_check_failures += 1 + logger.error(f"Redis health check failed: operation was cancelled, failures={self.health_check_failures}") + except Exception as e: + self.health_check_failures += 1 + logger.error(f"Redis health check failed: error={e}, failures={self.health_check_failures}") + + if self.health_check_failures > 10: + logger.warning("Redis health check failed 10 times, attempting to restart Redis client.") + await self.restart_redis() + self.health_check_failures = 0 + + return 
False # -- clean -- @@ -38,10 +93,10 @@ async def set_int(self, key: str, value: int, expire: int = 3600 * 4): if self.redis is None: return try: - await self.redis.set(key, value) - if expire: - await self.redis.expire(key, expire) + await self.redis.set(key, value, ex=expire) logger.debug(f"set_int: key={key}, value={value}") + except asyncio.CancelledError: + logger.error(f"set_int: operation was cancelled, key={key}") except Exception as e: logger.error(f"set_int: error={e}") @@ -49,10 +104,10 @@ async def set_object(self, key: str, value: Dict, expire: int = 3600 * 4): if self.redis is None: return try: - await self.redis.set(key, json.dumps(value)) - if expire: - await self.redis.expire(key, expire) + await self.redis.set(key, json.dumps(value), ex=expire) logger.debug(f"set_object: key={key}, value={value}") + except asyncio.CancelledError: + logger.error(f"set_object: operation was cancelled, key={key}") except Exception as e: logger.error(f"set_object: error={e}") @@ -60,10 +115,10 @@ async def set_string(self, key: str, value: str, expire: int = 3600 * 4): if self.redis is None: return try: - await self.redis.set(key, value) - if expire: - await self.redis.expire(key, expire) + await self.redis.set(key, value, ex=expire) logger.debug(f"set_string: key={key}, value={value}") + except asyncio.CancelledError: + logger.error(f"set_string: operation was cancelled, key={key}") except Exception as e: logger.error(f"set_string: error={e}") @@ -73,6 +128,8 @@ async def pop(self, key: str): try: await self.redis.delete(key) logger.debug(f"pop: key={key}") + except asyncio.CancelledError: + logger.error(f"pop: operation was cancelled, key={key}") except Exception as e: logger.error(f"pop: error={e}") @@ -83,6 +140,9 @@ async def get_string(self, key: str): value_string = await self.redis.get(key) logger.debug(f"get_string: key={key}, value={value_string}") return value_string + except asyncio.CancelledError: + logger.error(f"get_string: operation was 
cancelled, key={key}") + return None except Exception as e: logger.error(f"get_string: error={e}") return None @@ -95,6 +155,9 @@ async def get_object(self, key: str) -> Optional[Dict]: logger.debug(f"get_object: key={key}, value={value_string}") if value_string: return json.loads(value_string) + except asyncio.CancelledError: + logger.error(f"get_object: operation was cancelled, key={key}") + return None except Exception as e: logger.error(f"get_object: error={e}") return None @@ -107,6 +170,9 @@ async def get_int(self, key: str): logger.debug(f"get_int: key={key}, value={value}") if value: return int(value) + except asyncio.CancelledError: + logger.error(f"get_int: operation was cancelled, key={key}") + return None except Exception as e: logger.error(f"get_int: error={e}") return None