In [1]:
import asyncio
import os
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Annotated, Optional

import openai
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.responses import JSONResponse
from packaging import version
from pydantic import BaseModel
from sqlmodel import Field, Session, SQLModel, create_engine, select

import functions

# Check the installed OpenAI SDK meets the minimum version this notebook needs.
# The minimum is defined once so the comparison and the error message agree.
MIN_OPENAI_VERSION = "1.1.1"
required_version = version.parse(MIN_OPENAI_VERSION)
current_version = version.parse(openai.__version__)
# os.getenv returns None (instead of raising KeyError) when the variable is unset.
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
if current_version < required_version:
    raise ValueError(f"Error: OpenAI version {openai.__version__} is less than the required version {MIN_OPENAI_VERSION}")
print("OpenAI version is compatible.")

# Database model for logging queries and responses
class QueryLog(SQLModel, table=True):
    """One logged exchange: a user's query and the assistant's response."""

    # Optional because the object is created with id=None and the database
    # generates the primary key when the row is inserted.
    id: Optional[int] = Field(default=None, primary_key=True)
    user_query: str
    assistant_response: str
    # Naive UTC timestamp taken when the Python object is constructed.
    # NOTE(review): datetime.utcnow is deprecated since Python 3.12; kept to
    # avoid changing stored values from naive to timezone-aware.
    timestamp: datetime = Field(default_factory=datetime.utcnow)

class Location(SQLModel, table=True):
    """A person's stored location, keyed by the person's name."""

    # The name is both the (indexed) primary key and the lookup key, so at
    # most one location row exists per name.
    name: str = Field(primary_key=True, index=True)
    location: str

# Database setup: the connection string comes from the DATABASE_URL environment
# variable so credentials are never hard-coded in the notebook. When it is
# unset, fall back to a local SQLite file so Restart & Run All still works.
database_url = os.getenv("DATABASE_URL", "sqlite:///locations.db")
engine = create_engine(database_url)

def create_db_and_tables():
    """Create all tables registered on ``SQLModel.metadata`` via ``engine``.

    ``create_all`` checks for existing tables first, so running this again is
    harmless.
    """
    metadata = SQLModel.metadata
    metadata.create_all(engine)

@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler that creates the database tables on startup.

    FastAPI expects ``lifespan`` to be an async context manager: code before
    ``yield`` runs at application startup, code after it at shutdown. It only
    takes effect when passed as ``FastAPI(lifespan=lifespan)``.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).
    """
    create_db_and_tables()
    yield


OpenAI version is compatible.


In [13]:
# ASGI application instance.
# NOTE(review): no `lifespan` is passed, so create_db_and_tables() does not run
# automatically when the app starts; tables must be created explicitly.
app = FastAPI()

# Request/response payload models for the chat flow.
class ChatRequest(BaseModel):
    """Body accepted by ``chat``: a target thread id and the user's message."""

    thread_id: str  # id of an existing OpenAI Assistants thread
    message: str  # the user's message text

class ChatResponse(BaseModel):
    """Response body model carrying a single ``response`` string."""

    response: str

# Shared OpenAI client for every cell below; the key is the OPENAI_API_KEY
# environment value read in the setup cell.
client = openai.OpenAI(
    api_key=OPENAI_API_KEY,
)

# `functions.create_assistant` is a project-local helper. Judging by the cell
# output ("Loaded existing assistant ID.") it reuses an existing assistant when
# one is available — TODO confirm against functions.py.
assistant_id = functions.create_assistant(client)

Loaded existing assistant ID.


In [3]:
async def start_conversation():
    """Create a new Assistants-API thread and return its id.

    The OpenAI client is synchronous, so the HTTP call runs in a worker thread
    via ``asyncio.to_thread`` instead of blocking the event loop.

    Returns:
        dict: ``{"thread_id": <id of the newly created thread>}``.
    """
    thread = await asyncio.to_thread(client.beta.threads.create)
    result = {"thread_id": thread.id}
    print(result)
    return result

In [4]:
def read_all_persons():
    """
    Fetch every ``Location`` row stored in the database.

    Returns:
        list: All ``Location`` records (one per person); empty when none exist.
    """
    statement = select(Location)
    with Session(engine) as session:
        persons = session.exec(statement).all()
    return persons

In [5]:
def create_person(person_data: Location):
    """
    Insert ``person_data`` into the database and return it.

    Args:
        person_data (Location): the person's name and location to store.

    Returns:
        Location: the same object, refreshed from the database after commit.
    """
    with Session(engine) as session:
        session.add(person_data)
        session.commit()
        # Reload the attributes expired by the commit so they stay readable
        # after the session is closed.
        session.refresh(person_data)
    return person_data

In [6]:
def get_location_or_404(name: str) -> Location:
    """Return the ``Location`` row whose ``name`` equals ``name``.

    Raises:
        HTTPException: 404 Not Found when no row matches ``name``.
    """
    query = select(Location).where(Location.name == name)
    with Session(engine) as session:
        found = session.exec(query).first()
    if not found:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"No location found for {name}",
        )
    return found
    
def get_person_location(
    name: str,
    location: Annotated[Optional[Location], Depends(get_location_or_404)] = None,
):
    """
    Retrieve the location of a person by their name.

    When used as a FastAPI dependency/route, ``location`` is injected by
    ``get_location_or_404``. When called directly (e.g. through the
    ``available_functions`` tool table with only ``name=...``), ``location``
    is omitted and is looked up here with the same helper.

    Args:
        name (str): The name of the person.
        location (Location, optional): An already-resolved location row; looked
            up by ``name`` when not supplied.

    Returns:
        Location: The location of the person.

    Raises:
        HTTPException: 404 when no location is stored for ``name`` (raised by
            ``get_location_or_404``).
    """
    print(f"Fetching location for {name}")
    if location is None:
        location = get_location_or_404(name)
    print(f"Retrieved location data: {location}")
    return location

In [9]:
async def chat(request: ChatRequest):
    """Post the user's message to an existing thread and start an assistant run.

    Args:
        request (ChatRequest): the target ``thread_id`` and the user's ``message``.

    Returns:
        dict | JSONResponse: the created run as a dict (its ``id`` and
        ``status`` can be used to poll for completion), or a 400
        ``JSONResponse`` when the thread id or the message is missing.
    """
    thread_id = request.thread_id
    user_input = request.message

    if not thread_id:
        return JSONResponse(content={"Error": "Missing thread id"}, status_code=400)
    if not user_input:
        return JSONResponse(content={"Error": "Missing message"}, status_code=400)

    # The OpenAI client is synchronous; run its HTTP calls in a worker thread
    # so they do not block the event loop.
    await asyncio.to_thread(
        client.beta.threads.messages.create,
        thread_id=thread_id,
        role="user",
        content=user_input,
    )
    run = await asyncio.to_thread(
        client.beta.threads.runs.create,
        thread_id=thread_id,
        assistant_id=assistant_id,
    )
    # TODO: poll the run (dispatching requires_action tool calls) and return a
    # ChatResponse with the assistant's reply instead of the raw run.
    return dict(run)

In [10]:
from openai.types.beta.thread import Thread

# Scratch thread for the step-by-step walkthrough: later cells post a message
# to it and start a run on it.
thread: Thread = client.beta.threads.create()
print(thread)

Thread(id='thread_bbjiO05ZL2MXogaAaDhd8ZkY', created_at=1708804320, metadata={}, object='thread')


In [12]:
from openai.types.beta.threads.thread_message import ThreadMessage

# First request: post a user question to the scratch thread and display the
# created message.
message = client.beta.threads.messages.create(
    role="user",
    thread_id=thread.id,
    content="share the location of zia?",
)
dict(message)

{'id': 'msg_vzNgX6c8y2d1n2rwEgurTNwn',
 'assistant_id': None,
 'content': [MessageContentText(text=Text(annotations=[], value='share the location of zia?'), type='text')],
 'created_at': 1708804359,
 'file_ids': [],
 'metadata': {},
 'object': 'thread.message',
 'role': 'user',
 'run_id': None,
 'thread_id': 'thread_bbjiO05ZL2MXogaAaDhd8ZkY'}

In [15]:
from openai.types.beta.threads.run import Run

# Start a run on the scratch thread using the assistant returned by
# functions.create_assistant in the setup cell. This keeps the notebook from
# depending on an assistant id that exists in only one OpenAI account.
run: Run = client.beta.threads.runs.create(
    thread_id=thread.id,
    assistant_id=assistant_id,
)
dict(run)

{'id': 'run_o0EoP3Otfkzw4OwQH4QiYOBp',
 'assistant_id': 'asst_NhaXiIu3IJjKIQ9g7AxNyY66',
 'cancelled_at': None,
 'completed_at': None,
 'created_at': 1708804430,
 'expires_at': 1708805030,
 'failed_at': None,
 'file_ids': [],
 'instructions': '\n              The assistant will be responsible for communicating with the database to share locations of friends\n              ',
 'last_error': None,
 'metadata': {},
 'model': 'gpt-3.5-turbo-0125',
 'object': 'thread.run',
 'required_action': None,
 'started_at': None,
 'status': 'queued',
 'thread_id': 'thread_bbjiO05ZL2MXogaAaDhd8ZkY',
 'tools': [ToolAssistantToolsFunction(function=FunctionDefinition(name='create_person', description='Create a new person record with name and location', parameters={'type': 'object', 'properties': {'name': {'type': 'string', 'description': 'The name of the person'}, 'location': {'type': 'string', 'description': 'The location of the person'}}, 'required': ['name', 'location']}), type='function'),
  ToolAssis

In [16]:
# Dispatch table: function-tool name -> local Python implementation.
available_functions = dict(
    read_all_persons=read_all_persons,
    create_person=create_person,
    get_person_location=get_person_location,
)

In [17]:
# Display the scratch thread's id as this cell's output.
thread.id

'thread_bbjiO05ZL2MXogaAaDhd8ZkY'

In [21]:
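# NOTE(review): disabled draft of the run-polling / tool-dispatch loop. It will
# not work if simply uncommented:
# - `json` and `time` are not imported (the `import time` below is itself
#   commented out).
# - `show_json` is undefined.
# - `read_all_persons` is assigned rather than called.
# - The `create_person` call is missing a comma and passes keyword arguments
#   its signature does not accept.
# - Tool outputs must be strings, not model objects.
# - The "failed"/"in_progress" branches test the stale `run.status` instead of
#   `runStatus.status`.
# import time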

#   # Loop until the run completes or requires action
# while True:
#     runStatus = client.beta.threads.runs.retrieve(thread_id=thread.id,
#                                                   run_id=run.id)
#     # Add run steps retrieval here for debugging
#     run_steps = client.beta.threads.runs.steps.list(thread_id=thread.id, run_id=run.id)
#     # show_json("Run Steps:", run_steps)
#     print(runStatus.status ,'.....')

#     # This means run is making a function call   
#     if runStatus.status == "requires_action":
#         print(runStatus.status ,'.....')
#         print("Status: ", "requires_action")
#         show_json("submit_tool_outputs", runStatus.required_action)
#         if runStatus.required_action.submit_tool_outputs and runStatus.required_action.submit_tool_outputs.tool_calls:
#             print("toolCalls present:")
#             toolCalls = runStatus.required_action.submit_tool_outputs.tool_calls

#             tool_outputs = []
#             for toolcall in toolCalls:
#                 function_name = toolcall.function.name
#                 function_args = json.loads(toolcall.function.arguments)
                
#                 if function_name in available_functions:
                    
                    
#                     function_to_call = available_functions[function_name]
#                     print(function_to_call,function_to_call.__name__=="read_all_persons","================================================================")
                  
#                     if function_to_call.__name__ == "read_all_persons":
                        
#                         response = function_to_call
                        
                        
#                         tool_outputs.append({
#                                   "tool_call_id": toolcall.id,
#                                   "output": response
#                               })
#                     elif function_to_call.__name__ == "get_person_location":
#                         response = function_to_call(
#                           name=function_args.get("name")
#                           )
#                         tool_outputs.append({
#                           "tool_call_id": toolcall.id,
#                           "output": response,
#                               })
#                     elif function_to_call.__name__ == "create_person":
#                         response = function_to_call(
#                           name=function_args.get("name")
#                           location=function_args.get("location")
#                           )
#                         tool_outputs.append({
#                           "tool_call_id": toolcall.id,
#                           "output": response,
#                               })
#             print(tool_outputs,">>>>>") 
#             # Submit tool outputs and update the run
#             client.beta.threads.runs.submit_tool_outputs(
#                 thread_id=thread.id,
#                 run_id=run.id,
#                 tool_outputs=tool_outputs)
      
#     elif runStatus.status == "completed":
#         # List the messages to get the response
#         print("completed...........logic")
#         messages: list[ThreadMessage] = client.beta.threads.messages.list(thread_id=thread.id)
#         for message in messages.data:
#             role_label = "User" if message.role == "user" else "Assistant"
#             message_content = message.content[0].text.value
#             print(f"{role_label}: {message_content}\n")
#         break  # Exit the loop after processing the completed run

#     elif run.status == "failed":
#       print("Run failed.")
#       break

#     elif run.status in ["in_progress", "queued"]:
#       print(f"Run is {run.status}. Waiting...")
#       time.sleep(5)  # Wait for 5 seconds before checking again

#     else:
#       print(f"Unexpected status: {run.status}")
#       break