From 4d29d76bc81c4115f743867424700a828dae17e0 Mon Sep 17 00:00:00 2001 From: James Braza Date: Thu, 9 Jan 2025 16:40:49 -0800 Subject: [PATCH] Added get_index to LitQAv2TaskSplit for convenience --- paperqa/agents/task.py | 26 ++++++++++++++++---------- tests/test_task.py | 10 ++++++++-- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/paperqa/agents/task.py b/paperqa/agents/task.py index 39acf9eca..4d9f2285d 100644 --- a/paperqa/agents/task.py +++ b/paperqa/agents/task.py @@ -411,6 +411,20 @@ class LitQAv2TaskSplit(StrEnum): EVAL = "eval" TEST = "test" + def get_index(self) -> int: + """ + Get the index of the train (0), eval (1), or test (2) split. + + NOTE: the value matches the index in read_litqa_v2_from_hub's returned splits. + """ + if self == self.TRAIN: + return 0 + if self == self.EVAL: + return 1 + if self == self.TEST: + return 2 + assert_never(self) # type: ignore[arg-type] + class LitQAv2TaskDataset(LitQATaskDataset): """Task dataset of LitQA v2 questions.""" @@ -425,18 +439,10 @@ def __init__( **kwargs, ): super().__init__(*args, **kwargs) - train_df, eval_df, test_df = read_litqa_v2_from_hub( + split_dfs = read_litqa_v2_from_hub( train_eval_dataset, test_dataset, **(read_data_kwargs or {}) ) - split = LitQAv2TaskSplit(split) - if split == LitQAv2TaskSplit.TRAIN: - self.data = train_df - elif split == LitQAv2TaskSplit.EVAL: - self.data = eval_df - elif split == LitQAv2TaskSplit.TEST: - self.data = test_df - else: - assert_never(split) + self.data = split_dfs[LitQAv2TaskSplit(split).get_index()] def get_new_env_by_idx(self, idx: int) -> GradablePaperQAEnvironment: sources = [] diff --git a/tests/test_task.py b/tests/test_task.py index af845c66e..523fd9464 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -1,6 +1,7 @@ import asyncio from collections.abc import Iterable from copy import deepcopy +from typing import ClassVar from unittest.mock import patch import pytest @@ -95,6 +96,7 @@ async def before_rollout(self, traj_id: str, env) -> None: # noqa: ARG002 class TestTaskDataset: + EXPECTED_LENGTHS: ClassVar[tuple[int, ...]] = (159, 40, 49) @pytest.mark.parametrize( ("split", "expected_length"), @@ -107,7 +109,7 @@ class TestTaskDataset: @pytest.mark.asyncio async def test___len__( self, - split: str | LitQAv2TaskSplit, + split: LitQAv2TaskSplit, expected_length: int, base_query_request: QueryRequest, ) -> None: @@ -117,7 +119,11 @@ async def test___len__( read_data_kwargs={"seed": 42}, split=split, ) - assert len(task_dataset) == expected_length + assert ( + len(task_dataset) + == expected_length + == self.EXPECTED_LENGTHS[split.get_index()] + ) # Now let's check we could use the sources in a validation for i in range(len(task_dataset)):