Skip to content

Commit

Permalink
Add session IDs to table
Browse files Browse the repository at this point in the history
  • Loading branch information
pycui committed Jul 19, 2023
1 parent 5c8533d commit c367ea7
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 1 deletion.
24 changes: 24 additions & 0 deletions alembic/versions/3821f7adaca9_add_session_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Add session ID
Revision ID: 3821f7adaca9
Revises: 27fe156a6d72
Create Date: 2023-07-18 22:44:33.107380
"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '3821f7adaca9'
down_revision = '27fe156a6d72'
branch_labels = None
depends_on = None


def upgrade() -> None:
op.add_column('interactions', sa.Column('session_id', sa.String(50), nullable=True))


def downgrade() -> None:
op.drop_column('interactions', 'session_id')
1 change: 1 addition & 0 deletions realtime_ai_character/models/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class Interaction(Base):

id = Column(Integer, primary_key=True, index=True, nullable=False)
client_id = Column(Integer)
session_id = Column(String(50))
client_message = Column(String) # deprecated, use client_message_unicode instead
server_message = Column(String) # deprecated, use server_message_unicode instead
client_message_unicode = Column(Unicode(65535))
Expand Down
7 changes: 6 additions & 1 deletion realtime_ai_character/websocket_routes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import os
import uuid

from fastapi import APIRouter, Depends, Path, WebSocket, WebSocketDisconnect, Query
from requests import Session
Expand Down Expand Up @@ -64,13 +65,15 @@ async def handle_receive(
text_to_speech: TextToSpeech):
try:
conversation_history = ConversationHistory()
session_id = str(uuid.uuid4().hex)

# 0. Receive client platform info (web, mobile, terminal)
data = await websocket.receive()
if data['type'] != 'websocket.receive':
raise WebSocketDisconnect('disconnected')
platform = data['text']
logger.info(f"Client #{client_id}:{platform} connected to server")
logger.info(f"Client #{client_id}:{platform} connected to server with "
f"session_id {session_id}")

# 1. User selected a character
character = None
Expand Down Expand Up @@ -164,6 +167,7 @@ async def stop_audio():
# 4. Persist interaction in the database
Interaction(
client_id=client_id,
session_id=session_id,
client_message_unicode=msg_data,
server_message_unicode=response,
platform=platform,
Expand Down Expand Up @@ -199,6 +203,7 @@ async def tts_task_done_call_back(response):
# Persist interaction in the database
Interaction(
client_id=client_id,
session_id=session_id,
client_message_unicode=transcript,
server_message_unicode=response,
platform=platform,
Expand Down

0 comments on commit c367ea7

Please sign in to comment.