Skip to content

Commit

Permalink
♻️ refactor to use pytest_asyncio for the fixture
Browse files Browse the repository at this point in the history
  • Loading branch information
agn-7 committed Dec 1, 2023
1 parent 5424607 commit 3db4423
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 47 deletions.
8 changes: 5 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import pytest
import pytest_asyncio

from sqlalchemy.ext.asyncio import AsyncSession

from . import client


@pytest.fixture
async def db():
@pytest_asyncio.fixture()
async def db() -> AsyncSession:
async with client.async_session() as session:
await client.create_tables()
yield session
Expand Down
48 changes: 23 additions & 25 deletions tests/test_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,30 @@

@pytest.mark.asyncio
async def test_get_interactions(db):
async for db in db: # TODO
interaction1 = models.Interaction(
settings=dict(model="model1", role="role1", prompt="prompt1"),
)
interaction2 = models.Interaction(
settings=dict(model="model2", role="role2", prompt="prompt2"),
)
db.add(interaction1)
db.add(interaction2)
await db.commit()

interactions = await crud.get_interactions(db)
assert len(interactions) == 2
assert interactions[0].settings["model"] == "model1"
assert interactions[1].settings["model"] == "model2"
interaction1 = models.Interaction(
settings=dict(model="model1", role="role1", prompt="prompt1"),
)
interaction2 = models.Interaction(
settings=dict(model="model2", role="role2", prompt="prompt2"),
)
db.add(interaction1)
db.add(interaction2)
await db.commit()

interactions = await crud.get_interactions(db)
assert len(interactions) == 2
assert interactions[0].settings["model"] == "model1"
assert interactions[1].settings["model"] == "model2"


@pytest.mark.asyncio
async def test_get_interaction(db):
async for db in db: # TODO
interaction = models.Interaction(
settings=dict(model="model", role="role", prompt="prompt"),
)
db.add(interaction)
await db.commit()

retrieved_interaction = await crud.get_interaction(db, interaction.id)
assert retrieved_interaction.id == interaction.id
assert retrieved_interaction.settings["model"] == "model"
interaction = models.Interaction(
settings=dict(model="model", role="role", prompt="prompt"),
)
db.add(interaction)
await db.commit()

retrieved_interaction = await crud.get_interaction(db, interaction.id)
assert retrieved_interaction.id == interaction.id
assert retrieved_interaction.settings["model"] == "model"
33 changes: 14 additions & 19 deletions tests/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,24 @@ async def test_get_all_interactions(db):
interaction2 = models.Interaction(settings={"prompt": "something else"})
db.add(interaction1)
db.add(interaction2)
yield db.commit() # TODO
await db.commit()

response = client.client.get("/api/interactions")
assert response.status_code == 200
assert len(response.json()) == 2


@pytest.mark.asyncio
async def test_create_interaction():
async with client.async_session() as db:
try:
await client.create_tables()
response = client.client.post(
"/api/interactions",
json={
"prompt": "something",
},
)
assert response.status_code == 200
assert response.json()["settings"] == {
"prompt": "something",
"model": "gpt-3.5-turbo",
"role": "System",
}
finally:
await client.drop_tables()
async def test_create_interaction(db):
response = client.client.post(
"/api/interactions",
json={
"prompt": "something",
},
)
assert response.status_code == 200
assert response.json()["settings"] == {
"prompt": "something",
"model": "gpt-3.5-turbo",
"role": "System",
}

0 comments on commit 3db4423

Please sign in to comment.