Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: handle OasstError in OasstApiClient #300

Merged
merged 11 commits into from
Jan 3, 2023
7 changes: 6 additions & 1 deletion backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,13 @@
@app.exception_handler(OasstError)
async def oasst_exception_handler(request: fastapi.Request, ex: OasstError):
logger.error(f"{request.method} {request.url} failed: {repr(ex)}")

return fastapi.responses.JSONResponse(
status_code=int(ex.http_status_code), content={"message": ex.message, "error_code": ex.error_code}
status_code=int(ex.http_status_code),
content=protocol_schema.OasstErrorResponse(
message=ex.message,
error_code=OasstErrorCode(ex.error_code),
).dict(),
)


Expand Down
2 changes: 0 additions & 2 deletions discord-bot/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
aiohttp # http client
aiohttp[speedups] # speedups for aiohttp
aiosqlite # database
hikari # discord framework
hikari-lightbulb # command handler
Expand Down
41 changes: 37 additions & 4 deletions oasst-shared/oasst_shared/api_client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""API Client for interacting with the OASST backend."""
import enum
import typing as t
from http import HTTPStatus
from typing import Optional, Type
from uuid import UUID

import aiohttp
from loguru import logger
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from pydantic import ValidationError


# TODO: Move to `protocol`?
Expand All @@ -27,16 +30,20 @@ class TaskType(str, enum.Enum):
class OasstApiClient:
"""API Client for interacting with the OASST backend."""

def __init__(self, backend_url: str, api_key: str):
def __init__(self, backend_url: str, api_key: str, session: Optional[aiohttp.ClientSession] = None):
"""Create a new OasstApiClient.

Args:
----
backend_url (str): The base backend URL.
api_key (str): The API key to use for authentication.
"""
logger.debug("Opening OasstApiClient session")
self.session = aiohttp.ClientSession()

if session is None:
logger.debug("Opening OasstApiClient session")
session = aiohttp.ClientSession()

self.session = session
self.backend_url = backend_url
self.api_key = api_key

Expand All @@ -56,7 +63,33 @@ async def post(self, path: str, data: dict[str, t.Any]) -> Optional[dict[str, t.
"""Make a POST request to the backend."""
logger.debug(f"POST {self.backend_url}{path} DATA: {data}")
response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key})
response.raise_for_status()

# If the response is not a 2XX, check to see
# if the json has the fields to create an
# OasstError.
if response.status >= 300:
data = await response.json()
try:
oasst_error = protocol_schema.OasstErrorResponse(**(data or {}))
raise OasstError(
error_code=oasst_error.error_code,
message=oasst_error.message,
)
except ValidationError as e:
logger.debug(f"Got error from API but could not parse: {e}")

raw_response = await response.text()
logger.debug(f"Raw response: {raw_response}")

raise OasstError(
raw_response,
OasstErrorCode.GENERIC_ERROR,
HTTPStatus(response.status),
)

if response.status == 204:
# No content
return None
return await response.json()

def _parse_task(self, data: Optional[dict[str, t.Any]]) -> protocol_schema.Task:
Expand Down
8 changes: 8 additions & 0 deletions oasst-shared/oasst_shared/schemas/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from uuid import UUID, uuid4

import pydantic
from oasst_shared.exceptions import OasstErrorCode
from pydantic import BaseModel, Field


Expand Down Expand Up @@ -293,3 +294,10 @@ class UserScore(BaseModel):

class LeaderboardStats(BaseModel):
leaderboard: List[UserScore]


class OasstErrorResponse(BaseModel):
"""The format of an error response from the OASST API."""

error_code: OasstErrorCode
message: str
2 changes: 2 additions & 0 deletions oasst-shared/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,7 @@
author="OASST Team",
install_requires=[
"pydantic==1.9.1",
"aiohttp==3.8.3",
"aiohttp[speedups]",
],
)
77 changes: 77 additions & 0 deletions oasst-shared/tests/test_oasst_api_client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
from typing import Any
from unittest import mock
from uuid import uuid4

import aiohttp
import pytest
from oasst_shared.api_client import OasstApiClient
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema


@pytest.fixture
def oasst_api_client_mocked():
"""
A an oasst_api_client pointed at the mocked backend.
Relies on ./scripts/backend-development/start-mock-server.sh
being run.
"""
client = OasstApiClient(backend_url="http://localhost:8080", api_key="123")
yield client
# TODO The fixture should close this connection, but there seems to be a bug
Expand All @@ -15,6 +24,30 @@ def oasst_api_client_mocked():
# await client.close()


class MockClientSession(aiohttp.ClientSession):
response: Any

def set_response(self, response: Any):
self.response = response

async def post(self, *args, **kwargs):
return self.response


@pytest.fixture
def mock_http_session():
yield MockClientSession()


@pytest.fixture
def oasst_api_client_fake_http(mock_http_session):
"""
An oasst_api_client that uses a mocked http session. No real requests are made.
"""
client = OasstApiClient(backend_url="http://localhost:8080", api_key="123", session=mock_http_session)
yield client


@pytest.mark.asyncio
@pytest.mark.parametrize("task_type", protocol_schema.TaskRequestType)
async def test_can_fetch_task(task_type: protocol_schema.TaskRequestType, oasst_api_client_mocked: OasstApiClient):
Expand Down Expand Up @@ -49,3 +82,47 @@ async def test_can_post_interaction(oasst_api_client_mocked: OasstApiClient):
)
is not None
)


@pytest.mark.asyncio
async def test_can_handle_oasst_error_from_api(
oasst_api_client_fake_http: OasstApiClient,
mock_http_session: MockClientSession,
):
# Return a 400 response with an OasstErrorResponse body
response_body = protocol_schema.OasstErrorResponse(
error_code=OasstErrorCode.GENERIC_ERROR,
message="Some error",
)
status_code = 400

mock_http_session.set_response(
mock.AsyncMock(
status=status_code,
text=mock.AsyncMock(return_value=response_body.json()),
json=mock.AsyncMock(return_value=response_body.dict()),
)
)

with pytest.raises(OasstError):
await oasst_api_client_fake_http.post("/some-path", data={})


@pytest.mark.asyncio
async def test_can_handle_unknown_error_from_api(
oasst_api_client_fake_http: OasstApiClient,
mock_http_session: MockClientSession,
):
response_body = "Internal Server Error"
status_code = 500

mock_http_session.set_response(
mock.AsyncMock(
status=status_code,
text=mock.AsyncMock(return_value=response_body),
json=mock.AsyncMock(return_value=None),
)
)

with pytest.raises(OasstError):
await oasst_api_client_fake_http.post("/some-path", data={})
2 changes: 2 additions & 0 deletions scripts/oasst-shared-development/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
# switch to backend directory
pushd "$parent_path/../../oasst-shared"

set -xe

pytest .

popd