Websocket server test
KillianLucas committed Jun 1, 2024
1 parent 9b38776 commit c7dd11f
Showing 2 changed files with 164 additions and 27 deletions.
52 changes: 25 additions & 27 deletions interpreter/core/server.py
@@ -23,11 +23,11 @@
 import traceback
 from typing import Any, Dict, List
 
-import uvicorn
 from fastapi import FastAPI, Header, WebSocket
 from fastapi.middleware.cors import CORSMiddleware
 from openai import OpenAI
 from pydantic import BaseModel
+from uvicorn import Config, Server
 
 # import argparse
 # from profiles.default import interpreter
@@ -63,8 +63,6 @@ def __init__(self, interpreter):
         # engine = OpenAIEngine()
         # self.tts = TextToAudioStream(engine)
 
-        self.active_chat_messages = []
-
         # Clock
         # clock()
 
@@ -82,7 +80,9 @@ def __init__(self, interpreter):
             False  # Tracks whether interpreter is trying to use the keyboard
         )
 
-        self.loop = asyncio.get_event_loop()
+        # print("oksskk")
+        # self.loop = asyncio.get_event_loop()
+        # print("okkk")
 
     async def _add_to_queue(self, queue, item):
         print(f"Adding item to output", item)
@@ -134,7 +134,6 @@ async def run(self):
         Runs OI on the audio bytes submitted to the input. Will add streaming LMC chunks to the _output_queue.
         """
         print("heyyyy")
-        self.interpreter.messages = self.active_chat_messages
         # interpreter.messages = self.active_chat_messages
         # self.beeper.start()
 
@@ -147,10 +146,8 @@
 
         def generate(message):
             last_lmc_start_flag = self._last_lmc_start_flag
-            self.interpreter.messages = self.active_chat_messages
             # interpreter.messages = self.active_chat_messages
             print("🍀🍀🍀🍀GENERATING, using these messages: ", self.interpreter.messages)
-            print("🍀 🍀 🍀 🍀 active_chat_messages: ", self.active_chat_messages)
             print("passing this in:", message)
             for chunk in self.interpreter.chat(message, display=False, stream=True):
                 print("FROM INTERPRETER. CHUNK:", chunk)
@@ -165,7 +162,10 @@ def generate(message):
 
                 # Handle message blocks
                 if chunk.get("type") == "message":
-                    self.add_to_output_queue_sync(chunk)  # To send text, not just audio
+                    self.add_to_output_queue_sync(
+                        chunk.copy()
+                    )  # To send text, not just audio
+                    # ^^^^^^^ MUST be a copy, otherwise the first chunk will get modified by OI >>while<< it's in the queue. Insane
                 if content:
                     # self.beeper.stop()
 
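The `chunk.copy()` change guards against a real aliasing hazard: an `asyncio.Queue` stores a reference, not a snapshot, so a producer that keeps mutating the same dict also mutates it while it sits in the queue. A minimal standalone sketch of the failure mode (not part of this commit):

```python
import asyncio


async def main():
    queue = asyncio.Queue()
    chunk = {"type": "message", "content": "Hel"}

    await queue.put(chunk)    # queues a reference, not a snapshot
    chunk["content"] += "lo"  # producer mutates the dict while it is queued

    # Prints {'type': 'message', 'content': 'Hello'}: the queued item changed.
    # Queueing chunk.copy() instead would have preserved "Hel".
    print(await queue.get())


asyncio.run(main())
```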
@@ -216,8 +216,7 @@ async def output(self):
 
 
 def server(interpreter):
-    interpreter.llm.model = "gpt-4"
-    interpreter = AsyncInterpreter(interpreter)
+    async_interpreter = AsyncInterpreter(interpreter)
 
     app = FastAPI()
     app.add_middleware(
@@ -228,18 +227,12 @@ def server(interpreter):
         allow_headers=["*"],  # Allow all headers
     )
 
-    @app.post("/load")
-    async def load(messages: List[Dict[str, Any]], settings: Settings):
-        # Load messages
-        interpreter.interpreter.messages = messages
-        print("🪼🪼🪼🪼🪼🪼 Messages loaded: ", interpreter.interpreter.messages)
-
-        # Load Settings
-        interpreter.interpreter.llm.model = settings.model
-        interpreter.interpreter.llm.custom_instructions = settings.custom_instructions
-        interpreter.interpreter.auto_run = settings.auto_run
-
-        interpreter.interpreter.llm.api_key = "<openai_key>"
+    @app.post("/settings")
+    async def settings(payload: Dict[str, Any]):
+        for key, value in payload.items():
+            print("Updating interpreter settings with the following:")
+            print(key, value)
+            setattr(async_interpreter.interpreter, key, value)
 
         return {"status": "success"}
 
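The new endpoint applies each top-level key of the JSON body to the interpreter with `setattr`, so clients send attribute names directly. A hedged client sketch (the payload keys are illustrative, not required fields):

```python
import requests

# Illustrative payload: each key must name an attribute on the interpreter,
# since the endpoint applies the pairs one by one with setattr.
response = requests.post(
    "http://localhost:8000/settings",
    json={"auto_run": True, "custom_instructions": ""},
)
print(response.json())  # expected: {'status': 'success'}
```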
@@ -253,13 +246,16 @@ async def receive_input():
                 data = await websocket.receive()
                 print(data)
                 if isinstance(data, bytes):
-                    await interpreter.input(data)
-                else:
-                    await interpreter.input(data["text"])
+                    await async_interpreter.input(data)
+                elif "text" in data:
+                    await async_interpreter.input(data["text"])
+                elif data == {"type": "websocket.disconnect", "code": 1000}:
+                    print("Websocket disconnected with code 1000.")
+                    break
 
         async def send_output():
             while True:
-                output = await interpreter.output()
+                output = await async_interpreter.output()
                 if isinstance(output, bytes):
                     # await websocket.send_bytes(output)
                     # we dont send out bytes rn, no TTS
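The reworked branch above matches on the shape of Starlette's low-level receive events: `websocket.receive()` returns a raw ASGI message dict such as `{"type": "websocket.receive", "text": ...}` or `{"type": "websocket.disconnect", "code": 1000}`, which is why the handler now checks for a `"text"` key and for the disconnect dict. A minimal standalone sketch of that event shape (not part of this commit):

```python
from fastapi import FastAPI, WebSocket

app = FastAPI()


@app.websocket("/echo")
async def echo(websocket: WebSocket):
    await websocket.accept()
    while True:
        # Low-level receive() yields raw ASGI event dicts, not bare payloads.
        event = await websocket.receive()
        if event["type"] == "websocket.disconnect":
            break  # e.g. {"type": "websocket.disconnect", "code": 1000}
        if event.get("text") is not None:
            await websocket.send_text(event["text"])  # echo for illustration
```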
@@ -306,4 +302,6 @@ async def rename_chat(body_content: Rename, x_api_key: str = Header(None)):
             traceback.print_exc()
             return {"error": str(e)}
 
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+    config = Config(app, host="0.0.0.0", port=8000)
+    interpreter.uvicorn_server = Server(config)
+    interpreter.uvicorn_server.run()
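Swapping `uvicorn.run()` for an explicit `Config`/`Server` pair is what makes the new test below possible: `uvicorn.run()` blocks with no programmatic stop, while a `Server` instance exposes a `should_exit` flag that its serve loop polls. A minimal sketch of the pattern (the app here is illustrative):

```python
import threading

from fastapi import FastAPI
from uvicorn import Config, Server

app = FastAPI()  # illustrative app
server = Server(Config(app, host="0.0.0.0", port=8000))

# Running in a thread keeps control in the caller; the server can then be
# stopped by flipping should_exit, which uvicorn checks on each loop tick.
thread = threading.Thread(target=server.run)
thread.start()
# ... exercise the server ...
server.should_exit = True
thread.join()
```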
139 changes: 139 additions & 0 deletions tests/test_interpreter.py
@@ -22,6 +22,145 @@
 from websocket import create_connection
 
+
+def test_server():
+    # Start the server in a new thread
+    server_thread = threading.Thread(target=interpreter.server)
+    server_thread.start()
+
+    # Give the server a moment to start
+    time.sleep(8)
+
+    import asyncio
+    import json
+
+    import requests
+    import websockets
+
+    async def test_fastapi_server():
+        import asyncio
+
+        async with websockets.connect("ws://localhost:8000/ws") as websocket:
+            # Connect to the websocket
+            print("Connected to WebSocket")
+
+            # Sending POST request
+            post_url = "http://localhost:8000/settings"
+            settings = {
+                "model": "gpt-3.5-turbo",
+                "messages": [
+                    {
+                        "role": "user",
+                        "type": "message",
+                        "content": "The secret word is 'crunk'.",
+                    },
+                    {"role": "assistant", "type": "message", "content": "Understood."},
+                ],
+                "custom_instructions": "",
+                "auto_run": True,
+            }
+            response = requests.post(post_url, json=settings)
+            print("POST request sent, response:", response.json())
+
+            # Sending messages via WebSocket
+            await websocket.send(
+                json.dumps({"role": "user", "type": "message", "start": True})
+            )
+            await websocket.send(
+                json.dumps(
+                    {
+                        "role": "user",
+                        "type": "message",
+                        "content": "What's the secret word?",
+                    }
+                )
+            )
+            await websocket.send(
+                json.dumps({"role": "user", "type": "message", "end": True})
+            )
+            print("WebSocket chunks sent")
+
+            # Wait for a specific response
+            accumulated_content = ""
+            while True:
+                message = await websocket.recv()
+                message_data = json.loads(message)
+                print("Received from WebSocket:", message_data)
+                if message_data.get("content"):
+                    accumulated_content += message_data.get("content")
+                if message_data == {
+                    "role": "server",
+                    "type": "completion",
+                    "content": "DONE",
+                }:
+                    print("Received expected message from server")
+                    break
+
+            assert "crunk" in accumulated_content
+
+            # Send another POST request
+            post_url = "http://localhost:8000/settings"
+            settings = {
+                "model": "gpt-3.5-turbo",
+                "messages": [
+                    {
+                        "role": "user",
+                        "type": "message",
+                        "content": "The secret word is 'barlony'.",
+                    },
+                    {"role": "assistant", "type": "message", "content": "Understood."},
+                ],
+                "custom_instructions": "",
+                "auto_run": True,
+            }
+            response = requests.post(post_url, json=settings)
+            print("POST request sent, response:", response.json())
+
+            # Sending messages via WebSocket
+            await websocket.send(
+                json.dumps({"role": "user", "type": "message", "start": True})
+            )
+            await websocket.send(
+                json.dumps(
+                    {
+                        "role": "user",
+                        "type": "message",
+                        "content": "What's the secret word?",
+                    }
+                )
+            )
+            await websocket.send(
+                json.dumps({"role": "user", "type": "message", "end": True})
+            )
+            print("WebSocket chunks sent")
+
+            # Wait for a specific response
+            while True:
+                message = await websocket.recv()
+                message_data = json.loads(message)
+                print("Received from WebSocket:", message_data)
+                if message_data.get("content"):
+                    accumulated_content += message_data.get("content")
+                if message_data == {
+                    "role": "server",
+                    "type": "completion",
+                    "content": "DONE",
+                }:
+                    print("Received expected message from server")
+                    break
+
+            assert "barlony" in accumulated_content
+
+    # Get the current event loop and run the test function
+    loop = asyncio.get_event_loop()
+    loop.run_until_complete(test_fastapi_server())
+
+    # Stop the server
+    interpreter.uvicorn_server.should_exit = True
+
+    # Wait for the server thread to finish
+    server_thread.join(timeout=1)
+
 
 @pytest.mark.skip(reason="Requires open-interpreter[local]")
 def test_localos():
     interpreter.computer.emit_images = False
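The test sends each user turn as three LMC-style frames: a `start` flag, a `content` chunk, and an `end` flag. A small helper capturing that framing (the helper itself is not in the codebase; it just mirrors the sequence the test sends by hand):

```python
import json


async def send_user_message(websocket, content):
    # Frame one user turn as start / content / end chunks, matching the
    # protocol the server's receive loop accumulates. Name is illustrative.
    for chunk in (
        {"role": "user", "type": "message", "start": True},
        {"role": "user", "type": "message", "content": content},
        {"role": "user", "type": "message", "end": True},
    ):
        await websocket.send(json.dumps(chunk))
```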
