4 changes: 4 additions & 0 deletions .github/workflows/tests.yml
@@ -49,6 +49,10 @@ jobs:
           enable-cache: true
       - run: uv python pin ${{ matrix.python-version }}
       - run: uv sync --python-preference=only-managed
+      - name: Login to Hugging Face Hub
+        run: uv run --no-project huggingface-cli login --token $HUGGINGFACE_HUB_ACCESS_TOKEN
+        env:
+          HUGGINGFACE_HUB_ACCESS_TOKEN: ${{ secrets.HUGGINGFACE_HUB_ACCESS_TOKEN }}
       - name: Cache datasets
         uses: actions/cache@v4
         with:
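This step authenticates the CI runner against the Hugging Face Hub so the later dataset downloads are authorized. A minimal local equivalent in Python, assuming your token is exported under the same name as the workflow secret (`HUGGINGFACE_HUB_ACCESS_TOKEN`):

```python
import os

from huggingface_hub import login

# Authenticate this machine against the Hugging Face Hub, mirroring the
# CI login step, so subsequent load_dataset(...) calls are authorized.
login(token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"])
```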
11 changes: 8 additions & 3 deletions paperqa/agents/task.py
@@ -39,6 +39,7 @@
 )
 from paperqa.docs import Docs
 from paperqa.litqa import (
+    DEFAULT_AVIARY_PAPER_HF_HUB_NAME,
     DEFAULT_LABBENCH_HF_HUB_NAME,
     DEFAULT_REWARD_MAPPING,
     read_litqa_v2_from_hub,
@@ -408,6 +409,7 @@ def compute_trajectory_metrics(
 class LitQAv2TaskSplit(StrEnum):
     TRAIN = "train"
     EVAL = "eval"
+    TEST = "test"
 
 
 class LitQAv2TaskDataset(LitQATaskDataset):
@@ -416,20 +418,23 @@ class LitQAv2TaskDataset(LitQATaskDataset):
     def __init__(
         self,
         *args,
-        labbench_dataset: str = DEFAULT_LABBENCH_HF_HUB_NAME,
+        train_eval_dataset: str = DEFAULT_LABBENCH_HF_HUB_NAME,
+        test_dataset: str = DEFAULT_AVIARY_PAPER_HF_HUB_NAME,
         read_data_kwargs: Mapping[str, Any] | None = None,
         split: str | LitQAv2TaskSplit = LitQAv2TaskSplit.EVAL,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
-        train_df, eval_df = read_litqa_v2_from_hub(
-            labbench_dataset, **(read_data_kwargs or {})
+        train_df, eval_df, test_df = read_litqa_v2_from_hub(
+            train_eval_dataset, test_dataset, **(read_data_kwargs or {})
         )
         split = LitQAv2TaskSplit(split)
         if split == LitQAv2TaskSplit.TRAIN:
             self.data = train_df
         elif split == LitQAv2TaskSplit.EVAL:
             self.data = eval_df
+        elif split == LitQAv2TaskSplit.TEST:
+            self.data = test_df
         else:
             assert_never(split)
 
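For orientation, here is how a caller would select the new split; a minimal sketch, assuming the base `LitQATaskDataset` needs no additional required constructor arguments in this context:

```python
from paperqa.agents.task import LitQAv2TaskDataset, LitQAv2TaskSplit

# A plain string ("test") also works, since __init__ coerces the value
# via LitQAv2TaskSplit(split) before branching on it.
dataset = LitQAv2TaskDataset(split=LitQAv2TaskSplit.TEST)
```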
32 changes: 21 additions & 11 deletions paperqa/litqa.py
@@ -37,26 +37,30 @@ def make_discounted_returns(


 DEFAULT_LABBENCH_HF_HUB_NAME = "futurehouse/lab-bench"
+# Test split from Aviary paper's section 4.3: https://doi.org/10.48550/arXiv.2412.21154
+DEFAULT_AVIARY_PAPER_HF_HUB_NAME = "futurehouse/aviary-paper-data"
 
 
 def read_litqa_v2_from_hub(
-    labbench_dataset: str = DEFAULT_LABBENCH_HF_HUB_NAME,
+    train_eval_dataset: str = DEFAULT_LABBENCH_HF_HUB_NAME,
+    test_dataset: str = DEFAULT_AVIARY_PAPER_HF_HUB_NAME,
     randomize: bool = True,
     seed: int | None = None,
     train_eval_split: float = 0.8,
-) -> tuple[pd.DataFrame, pd.DataFrame]:
+) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
     """
-    Read LitQA v2 JSONL into train and eval DataFrames.
+    Read LitQA v2 JSONL into train, eval, and test DataFrames.
 
     Args:
-        labbench_dataset: The Hugging Face Hub dataset's name corresponding with the
-            LAB-Bench dataset.
+        train_eval_dataset: Hugging Face Hub dataset's name corresponding with train
+            and eval splits.
+        test_dataset: Hugging Face Hub dataset's name corresponding with a test split.
         randomize: Opt-out flag to shuffle the dataset after loading in by question.
         seed: Random seed to use for the shuffling.
         train_eval_split: Train/eval split fraction, default is 80% train 20% eval.
 
     Raises:
-        DatasetNotFoundError: If the LAB-Bench dataset is not found, or the
+        DatasetNotFoundError: If any of the datasets are not found, or the
             user is unauthenticated.
     """
     try:
@@ -67,9 +71,15 @@ def read_litqa_v2_from_hub(
" `pip install paper-qa[datasets]`."
         ) from exc
 
-    litqa_v2 = load_dataset(labbench_dataset, "LitQA2")["train"].to_pandas()
-    litqa_v2["distractors"] = litqa_v2["distractors"].apply(list)
+    train_eval = load_dataset(train_eval_dataset, "LitQA2")["train"].to_pandas()
+    test = load_dataset(test_dataset, "LitQA2")["test"].to_pandas()

[Inline review comment, Collaborator]
More of a question for me -- do we validate columns here? (i.e., is `distractors` checked for explicitly?)

[Reply, Collaborator (Author)]
Thanks to https://github.com/Future-House/paper-qa/blob/v5.9.2/paperqa/agents/task.py#L450-L455, we get:

- Confirmation that column attributes are present
- typeguard can check the types of the values during unit testing, for our defaults

To be clear, I added this comment:

    Let downstream usage in the TaskDataset's environment factories check for the
    presence of other DataFrame columns

+    # Convert to list so it's not unexpectedly a numpy array
+    train_eval["distractors"] = train_eval["distractors"].apply(list)
+    test["distractors"] = test["distractors"].apply(list)
+    # Let downstream usage in the TaskDataset's environment factories check for the
+    # presence of other DataFrame columns
     if randomize:
-        litqa_v2 = litqa_v2.sample(frac=1, random_state=seed)
-    num_train = int(len(litqa_v2) * train_eval_split)
-    return litqa_v2[:num_train], litqa_v2[num_train:]
+        train_eval = train_eval.sample(frac=1, random_state=seed)
+        test = test.sample(frac=1, random_state=seed)
+    num_train = int(len(train_eval) * train_eval_split)
+    return train_eval[:num_train], train_eval[num_train:], test
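On the review question above: a sketch of the attribute-style access that makes a missing column fail fast downstream. The factory internals and any column name other than `distractors` are assumptions here, not the actual task.py code:

```python
import pandas as pd


def build_environments(df: pd.DataFrame) -> None:
    # Attribute access on itertuples() rows raises AttributeError as soon as
    # an expected column is absent, so no explicit schema check is needed.
    for row in df.itertuples():
        distractors = row.distractors  # guaranteed a list by .apply(list) above
        question = row.question  # hypothetical column name, for illustration
```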
2 changes: 1 addition & 1 deletion tests/test_litqa.py
@@ -25,7 +25,7 @@ def test_make_discounted_returns(

 def test_creating_litqa_questions() -> None:
     """Test making LitQA eval questions after downloading from Hugging Face Hub."""
-    _, eval_split = read_litqa_v2_from_hub(seed=42)
+    eval_split = read_litqa_v2_from_hub(seed=42)[1]
     assert len(eval_split) > 3
     assert [
         MultipleChoiceQuestion(
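The test indexes the returned tuple; callers can equally unpack all three splits. A minimal sketch, assuming an authenticated Hub session and the default dataset names:

```python
from paperqa.litqa import read_litqa_v2_from_hub

# New return shape: (train, eval, test) DataFrames. eval_df here is what
# the test grabs via read_litqa_v2_from_hub(seed=42)[1].
train_df, eval_df, test_df = read_litqa_v2_from_hub(seed=42)
```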
6 changes: 5 additions & 1 deletion tests/test_task.py
@@ -98,7 +98,11 @@ class TestTaskDataset:

     @pytest.mark.parametrize(
         ("split", "expected_length"),
-        [(LitQAv2TaskSplit.TRAIN, 159), (LitQAv2TaskSplit.EVAL, 40)],
+        [
+            (LitQAv2TaskSplit.TRAIN, 159),
+            (LitQAv2TaskSplit.EVAL, 40),
+            (LitQAv2TaskSplit.TEST, 49),
+        ],
     )
     @pytest.mark.asyncio
     async def test___len__(
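The expected lengths are consistent with the default `train_eval_split=0.8` applied to the LitQA2 train/eval questions; a quick check, where the 199-question total is inferred from these very numbers:

```python
num_total = 159 + 40              # LitQA2 train + eval questions combined
num_train = int(num_total * 0.8)  # int(159.2) == 159, matching TRAIN
assert (num_train, num_total - num_train) == (159, 40)
# The 49 TEST questions come from the separate futurehouse/aviary-paper-data set.
```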