Skip to content

Commit

Permalink
⚡️ Switch to async database operations
Browse files Browse the repository at this point in the history
  • Loading branch information
agn-7 committed Nov 29, 2023
1 parent 2089145 commit ce3d184
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 23 deletions.
16 changes: 8 additions & 8 deletions ifsguid/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async def create_interaction(
return interaction


def delete_interaction(db: Session, id: UUID) -> None:
async def delete_interaction(db: AsyncSession, id: UUID) -> None:
interaction = (
db.query(models.Interaction).filter(models.Interaction.id == id).first()
)
Expand All @@ -52,8 +52,8 @@ def delete_interaction(db: Session, id: UUID) -> None:
return False


def update_interaction(
db: Session, id: UUID, settings: schemas.Settings
async def update_interaction(
db: AsyncSession, id: UUID, settings: schemas.Settings
) -> models.Interaction:
interaction: models.Interaction = (
db.query(models.Interaction).filter(models.Interaction.id == id).first()
Expand All @@ -67,8 +67,8 @@ def update_interaction(
return None


def get_messages(
db: Session, interaction_id: UUID = None, page: int = None, per_page: int = 10
async def get_messages(
db: AsyncSession, interaction_id: UUID = None, page: int = None, per_page: int = 10
) -> List[models.Message]:
query = db.query(models.Message)

Expand All @@ -81,8 +81,8 @@ def get_messages(
return query.all()


def create_message(
db: Session, messages: List[schemas.MessageCreate], interaction_id: UUID
async def create_message(
db: AsyncSession, messages: List[schemas.MessageCreate], interaction_id: UUID
) -> List[models.Message]:
messages_db = []
for msg in messages:
Expand All @@ -93,5 +93,5 @@ def create_message(
db.add(message)
messages_db.append(message)

db.commit()
await db.commit()
return messages_db
31 changes: 16 additions & 15 deletions ifsguid/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from g4f import ModelUtils

from . import crud, schemas, modules
from .database import async_session, AsyncSession
Expand Down Expand Up @@ -48,7 +48,7 @@ async def get_interactions(
async def create_interactions(
prompt: schemas.Prompt,
chat_model: schemas.ChatModel = Depends(),
db: AsyncSession = Depends(get_db)
db: AsyncSession = Depends(get_db),
) -> schemas.Interaction:
settings = schemas.Settings(
model=chat_model.model, prompt=prompt.prompt, role=prompt.role
Expand All @@ -59,7 +59,7 @@ async def create_interactions(


@router.delete("/interactions", response_model=Dict[str, Any], include_in_schema=False)
async def delete_interaction(id: UUID, db: Session = Depends(get_db)) -> None:
async def delete_interaction(id: UUID, db: AsyncSession = Depends(get_db)) -> None:
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="NotImplementedError"
)
Expand All @@ -69,7 +69,7 @@ async def delete_interaction(id: UUID, db: Session = Depends(get_db)) -> None:
"/interactions/{id}", response_model=schemas.Interaction, include_in_schema=False
)
async def update_interaction(
id: UUID, settings: schemas.Settings, db: Session = Depends(get_db)
id: UUID, settings: schemas.Settings, db: AsyncSession = Depends(get_db)
) -> schemas.Interaction:
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="NotImplementedError"
Expand All @@ -83,7 +83,7 @@ async def get_all_message_in_interaction(
interaction_id: UUID,
page: Optional[int] = None,
per_page: Optional[int] = None,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
) -> List[schemas.Message]:
interaction = crud.get_interaction(db=db, id=str(interaction_id))

Expand All @@ -104,9 +104,11 @@ async def get_all_message_in_interaction(
"/interactions/{interactions_id}/messages", response_model=List[schemas.Message]
)
async def create_message(
interaction_id: UUID, message: schemas.MessageCreate, db: Session = Depends(get_db)
interaction_id: UUID,
message: schemas.MessageCreate,
db: AsyncSession = Depends(get_db),
) -> schemas.Message:
interaction = crud.get_interaction(db=db, id=str(interaction_id))
interaction = await crud.get_interaction(db=db, id=str(interaction_id))

if not interaction:
raise HTTPException(
Expand All @@ -117,17 +119,16 @@ async def create_message(

messages = []
if message.role == "human":
ai_content = modules.generate_ai_response(
content=message.content, model=interaction.settings.model_name
ai_content = await modules.generate_ai_response(
content=message.content,
model=ModelUtils.convert[interaction.settings.model],
)
ai_message = schemas.MessageCreate(role="ai", content=ai_content)

messages.append(message)
messages.append(ai_message)

return [
schemas.Message.model_validate(message)
for message in crud.create_message(
db=db, messages=messages, interaction_id=str(interaction_id)
)
]
messages = await crud.create_message(
db=db, messages=messages, interaction_id=str(interaction_id)
)
return [schemas.Message.model_validate(message) for message in messages]
1 change: 1 addition & 0 deletions ifsguid/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class Prompt(BaseModel):
role: Literal["System"] = "System"
prompt: str


class Settings(ChatModel, Prompt):
pass

Expand Down

0 comments on commit ce3d184

Please sign in to comment.