diff --git a/tests/test_crud.py b/tests/test_crud.py index 264396a..ad634eb 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -28,3 +28,16 @@ def test_get_interactions(db): assert interactions[1].settings["model_name"] == "model2" +def test_get_interaction(db): + timestamp = utils.convert_timezone(datetime.now()) + interaction = models.Interaction( + settings=dict(model_name="model", role="role", prompt="prompt"), + created_at=timestamp, + updated_at=timestamp, + ) + db.add(interaction) + db.commit() + + retrieved_interaction = crud.get_interaction(db, interaction.id) + assert retrieved_interaction.id == interaction.id + assert retrieved_interaction.settings["model_name"] == "model"