diff --git a/paperqa/agents/task.py b/paperqa/agents/task.py index c2033e1bf..bebae1868 100644 --- a/paperqa/agents/task.py +++ b/paperqa/agents/task.py @@ -86,6 +86,8 @@ def __init__( ) self._evaluation_callback = evaluation_callback self._rewards = rewards + self.answer = "" + self.ideal = "" async def validate_sources( self, manifest_or_index: dict[str, DocDetails] | SearchIndex | None = None @@ -141,10 +143,19 @@ async def step( evaluation = await self._evaluation_from_answer(self.state.session.answer) if evaluation_callback := self._evaluation_callback: await evaluation_callback(evaluation) + self.answer = evaluation.answer or "" + self.ideal = evaluation.ideal or "" return messages, reward + self._rewards[evaluation.value], done, truncated def export_frame(self) -> Frame: - raise NotImplementedError("Didn't yet need to export a frame.") + return Frame( + state=self.state, + info={ + "query": self._query, + "answer": self.answer, + "ideal": self.ideal, + }, + ) def __deepcopy__(self, memo) -> Self: copy_state = deepcopy(self.state, memo)