Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option in TaskManager to not index library default tasks; tests for include_path (#1856)

Merged
merged 1 commit into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions lm_eval/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,43 @@ class TaskManager:

"""

def __init__(
    self,
    verbosity="INFO",
    include_path: Optional[Union[str, List]] = None,
    include_defaults: bool = True,
) -> None:
    """Build the task index for this manager.

    :param verbosity: logging level name passed to the eval logger
        (e.g. "INFO", "DEBUG").
    :param include_path: an additional path (or list of paths) to be
        searched recursively for task configs.
    :param include_defaults: if False, the library's default tasks
        (those under lm_eval/tasks/) are not indexed.
    """
    self.verbosity = verbosity
    self.include_path = include_path
    self.logger = utils.eval_logger
    self.logger.setLevel(getattr(logging, f"{verbosity}"))

    # Index maps task name -> task metadata; built once at construction.
    self._task_index = self.initialize_tasks(
        include_path=include_path, include_defaults=include_defaults
    )
    self._all_tasks = sorted(list(self._task_index.keys()))

    self.task_group_map = collections.defaultdict(list)

def initialize_tasks(self, include_path: Optional[str] = None):
def initialize_tasks(
self,
include_path: Optional[Union[str, List]] = None,
include_defaults: bool = True,
):
"""Creates a dictionary of tasks index.

:param include_path: str = None
An additional path to be searched for tasks

:param include_path: Union[str, List] = None
An additional path to be searched for tasks recursively.
Can provide more than one such path as a list.
:param include_defaults: bool = True
If set to false, default tasks (those in lm_eval/tasks/) are not indexed.
:return
Dictionary of task names as key and task metadata
"""
all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
if include_defaults:
all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"]
else:
all_paths = []
if include_path is not None:
if isinstance(include_path, str):
include_path = [include_path]
Expand Down
93 changes: 93 additions & 0 deletions tests/test_include_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import os

import pytest

import lm_eval.api as api
import lm_eval.evaluator as evaluator
from lm_eval import tasks


@pytest.mark.parametrize(
    "limit,model,model_args",
    [
        (
            10,
            "hf",
            "pretrained=EleutherAI/pythia-160m,dtype=float32,device=cpu",
        ),
    ],
)
def test_include_correctness(limit: int, model: str, model_args: str):
    """Evaluating a task loaded via ``include_path`` (with defaults
    disabled) must produce the same results as the library's built-in
    task of the same name.
    """
    task_name = ["arc_easy"]

    # Run with simple_evaluate() using the library-default "arc_easy" config.
    # (simple_evaluate takes task names directly; no TaskManager needed here.)
    e1 = evaluator.simple_evaluate(
        model=model,
        tasks=task_name,
        limit=limit,
        model_args=model_args,
    )
    assert e1 is not None

    # Run with evaluate() and the "arc_easy" config included from ./testconfigs.
    lm = api.registry.get_model(model).create_from_arg_string(
        model_args,
        {
            "batch_size": None,
            "max_batch_size": None,
            "device": None,
        },
    )

    task_manager = tasks.TaskManager(
        include_path=os.path.dirname(os.path.abspath(__file__)) + "/testconfigs",
        include_defaults=False,
    )
    task_dict = tasks.get_task_dict(task_name, task_manager)

    e2 = evaluator.evaluate(
        lm=lm,
        task_dict=task_dict,
        limit=limit,
    )
    assert e2 is not None

    # The included config is a copy of the default one, so the per-metric
    # results should be identical — compare keys AND values, not just a
    # positional zip of values (which would hide a key mismatch).
    def r(x):
        return x["results"]["arc_easy"]

    assert r(e1) == r(e2)


def test_no_include_defaults():
    """With ``include_defaults=False``, only tasks found under
    ``include_path`` are indexed: the included 'arc_easy' resolves,
    while the library-default 'arc_challenge' does not.
    """
    manager = tasks.TaskManager(
        include_path=os.path.dirname(os.path.abspath(__file__)) + "/testconfigs",
        include_defaults=False,
    )

    # Succeeds: ./testconfigs provides an 'arc_easy' config.
    tasks.get_task_dict(["arc_easy"], manager)

    # Fails: ./testconfigs has no 'arc_challenge' config and defaults
    # are excluded from the index.
    with pytest.raises(KeyError):
        tasks.get_task_dict(["arc_challenge"], manager)


# test that include_path containing a task shadowing another task's name fails
# def test_shadowed_name_fails():

# task_name = ["arc_easy"]

# task_manager = tasks.TaskManager(include_path=os.path.dirname(os.path.abspath(__file__)) + "/testconfigs")
# task_dict = tasks.get_task_dict(task_name, task_manager)
21 changes: 21 additions & 0 deletions tests/testconfigs/arc_test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Test fixture: an 'arc_easy' task config placed under tests/testconfigs/
# so tests can exercise TaskManager's include_path / include_defaults
# options. Appears to mirror the library's default arc_easy config —
# TODO(review): confirm it stays in sync with lm_eval/tasks/.
task: arc_easy
dataset_path: allenai/ai2_arc
dataset_name: ARC-Easy
output_type: multiple_choice
training_split: train
validation_split: validation
test_split: test
doc_to_text: "Question: {{question}}\nAnswer:"
# Target is the index of the correct choice within the label list.
doc_to_target: "{{choices.label.index(answerKey)}}"
doc_to_choice: "{{choices.text}}"
should_decontaminate: true
doc_to_decontamination_query: "Question: {{question}}\nAnswer:"
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true
  - metric: acc_norm
    aggregation: mean
    higher_is_better: true
metadata:
  version: 1.0
Loading