Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions paperqa/agents/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,20 @@ class LitQAv2TaskSplit(StrEnum):
EVAL = "eval"
TEST = "test"

def get_index(self) -> int:
"""
Get the index of the train (0), eval (1), or test (2) split.

NOTE: the value matches the index in read_litqa_v2_from_hub's returned splits.
"""
if self == self.TRAIN:
return 0
if self == self.EVAL:
return 1
if self == self.TEST:
return 2
assert_never(self) # type: ignore[arg-type]


class LitQAv2TaskDataset(LitQATaskDataset):
"""Task dataset of LitQA v2 questions."""
Expand All @@ -425,18 +439,10 @@ def __init__(
**kwargs,
):
super().__init__(*args, **kwargs)
train_df, eval_df, test_df = read_litqa_v2_from_hub(
split_dfs = read_litqa_v2_from_hub(
train_eval_dataset, test_dataset, **(read_data_kwargs or {})
)
split = LitQAv2TaskSplit(split)
if split == LitQAv2TaskSplit.TRAIN:
self.data = train_df
elif split == LitQAv2TaskSplit.EVAL:
self.data = eval_df
elif split == LitQAv2TaskSplit.TEST:
self.data = test_df
else:
assert_never(split)
self.data = split_dfs[LitQAv2TaskSplit(split).get_index()]

def get_new_env_by_idx(self, idx: int) -> GradablePaperQAEnvironment:
sources = []
Expand Down
10 changes: 8 additions & 2 deletions tests/test_task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from collections.abc import Iterable
from copy import deepcopy
from typing import ClassVar
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -95,6 +96,7 @@ async def before_rollout(self, traj_id: str, env) -> None: # noqa: ARG002


class TestTaskDataset:
EXPECTED_LENGTHS: ClassVar[tuple[int, ...]] = (159, 40, 49)

@pytest.mark.parametrize(
("split", "expected_length"),
Expand All @@ -107,7 +109,7 @@ class TestTaskDataset:
@pytest.mark.asyncio
async def test___len__(
self,
split: str | LitQAv2TaskSplit,
split: LitQAv2TaskSplit,
expected_length: int,
base_query_request: QueryRequest,
) -> None:
Expand All @@ -117,7 +119,11 @@ async def test___len__(
read_data_kwargs={"seed": 42},
split=split,
)
assert len(task_dataset) == expected_length
assert (
len(task_dataset)
== expected_length
== self.EXPECTED_LENGTHS[split.get_index()]
)

# Now let's check we could use the sources in a validation
for i in range(len(task_dataset)):
Expand Down
Loading