diff --git a/ifsguid/crud.py b/ifsguid/crud.py index edc0f47..e965484 100644 --- a/ifsguid/crud.py +++ b/ifsguid/crud.py @@ -39,7 +39,7 @@ async def create_interaction( return interaction -def delete_interaction(db: Session, id: UUID) -> None: +async def delete_interaction(db: AsyncSession, id: UUID) -> None: interaction = ( db.query(models.Interaction).filter(models.Interaction.id == id).first() ) @@ -52,8 +52,8 @@ def delete_interaction(db: Session, id: UUID) -> None: return False -def update_interaction( - db: Session, id: UUID, settings: schemas.Settings +async def update_interaction( + db: AsyncSession, id: UUID, settings: schemas.Settings ) -> models.Interaction: interaction: models.Interaction = ( db.query(models.Interaction).filter(models.Interaction.id == id).first() @@ -67,8 +67,8 @@ def update_interaction( return None -def get_messages( - db: Session, interaction_id: UUID = None, page: int = None, per_page: int = 10 +async def get_messages( + db: AsyncSession, interaction_id: UUID = None, page: int = None, per_page: int = 10 ) -> List[models.Message]: query = db.query(models.Message) @@ -81,8 +81,8 @@ def get_messages( return query.all() -def create_message( - db: Session, messages: List[schemas.MessageCreate], interaction_id: UUID +async def create_message( + db: AsyncSession, messages: List[schemas.MessageCreate], interaction_id: UUID ) -> List[models.Message]: messages_db = [] for msg in messages: @@ -93,5 +93,5 @@ def create_message( db.add(message) messages_db.append(message) - db.commit() + await db.commit() return messages_db diff --git a/ifsguid/endpoints.py b/ifsguid/endpoints.py index bc1c15d..2329884 100644 --- a/ifsguid/endpoints.py +++ b/ifsguid/endpoints.py @@ -2,7 +2,7 @@ from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, status -from sqlalchemy.orm import Session +from g4f import ModelUtils from . import crud, schemas, modules from .database import async_session, AsyncSession @@ -48,7 +48,7 @@ async def get_interactions( async def create_interactions( prompt: schemas.Prompt, chat_model: schemas.ChatModel = Depends(), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), ) -> schemas.Interaction: settings = schemas.Settings( model=chat_model.model, prompt=prompt.prompt, role=prompt.role @@ -59,7 +59,7 @@ async def create_interactions( @router.delete("/interactions", response_model=Dict[str, Any], include_in_schema=False) -async def delete_interaction(id: UUID, db: Session = Depends(get_db)) -> None: +async def delete_interaction(id: UUID, db: AsyncSession = Depends(get_db)) -> None: raise HTTPException( status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="NotImplementedError" ) @@ -69,7 +69,7 @@ async def delete_interaction(id: UUID, db: Session = Depends(get_db)) -> None: "/interactions/{id}", response_model=schemas.Interaction, include_in_schema=False ) async def update_interaction( - id: UUID, settings: schemas.Settings, db: Session = Depends(get_db) + id: UUID, settings: schemas.Settings, db: AsyncSession = Depends(get_db) ) -> schemas.Interaction: raise HTTPException( status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="NotImplementedError" @@ -83,7 +83,7 @@ async def get_all_message_in_interaction( interaction_id: UUID, page: Optional[int] = None, per_page: Optional[int] = None, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_db), ) -> List[schemas.Message]: interaction = crud.get_interaction(db=db, id=str(interaction_id)) @@ -104,9 +104,11 @@ async def get_all_message_in_interaction( "/interactions/{interactions_id}/messages", response_model=List[schemas.Message] ) async def create_message( - interaction_id: UUID, message: schemas.MessageCreate, db: Session = Depends(get_db) + interaction_id: UUID, + message: schemas.MessageCreate, + db: AsyncSession = Depends(get_db), ) -> schemas.Message: - interaction = crud.get_interaction(db=db, id=str(interaction_id)) + interaction = await crud.get_interaction(db=db, id=str(interaction_id)) if not interaction: raise HTTPException( @@ -117,17 +119,16 @@ async def create_message( messages = [] if message.role == "human": - ai_content = modules.generate_ai_response( - content=message.content, model=interaction.settings.model_name + ai_content = await modules.generate_ai_response( + content=message.content, + model=ModelUtils.convert[interaction.settings.model], ) ai_message = schemas.MessageCreate(role="ai", content=ai_content) messages.append(message) messages.append(ai_message) - return [ - schemas.Message.model_validate(message) - for message in crud.create_message( - db=db, messages=messages, interaction_id=str(interaction_id) - ) - ] + messages = await crud.create_message( + db=db, messages=messages, interaction_id=str(interaction_id) + ) + return [schemas.Message.model_validate(message) for message in messages] diff --git a/ifsguid/schemas.py b/ifsguid/schemas.py index fe2d31c..c005748 100644 --- a/ifsguid/schemas.py +++ b/ifsguid/schemas.py @@ -28,6 +28,7 @@ class Prompt(BaseModel): role: Literal["System"] = "System" prompt: str + class Settings(ChatModel, Prompt): pass