Skip to content

Commit

Permalink
Merge pull request #468 from allenai/mmlu-downstream
Browse files Browse the repository at this point in the history
Add MMLU downstream tasks
  • Loading branch information
OyvindTafjord committed Feb 28, 2024
2 parents 0c58bee + 079616b commit 67d24f5
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Added `output_hidden_states` argument and associated functionality to `OLMo` and `OLMoForCausalLM` to return model intermediate hidden states.
- Added MMLU downstream evaluation tasks.

## [v0.2.4](https://github.com/allenai/OLMo/releases/tag/v0.2.4) - 2024-02-02

Expand Down
5 changes: 4 additions & 1 deletion docs/Kempner.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,8 @@ Getting started
from olmo.eval.downstream import *
tokenizer = Tokenizer.from_file("tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json")
for x in label_to_task_map.values():
x(tokenizer=tokenizer)
kwargs = {}
if isinstance(x, tuple):
x, kwargs = x
x(tokenizer=tokenizer, **kwargs)
```
5 changes: 4 additions & 1 deletion olmo/eval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,11 @@ def build_downstream_evaluator(
device: torch.device,
is_unit_test=False,
) -> Evaluator:
task_kwargs = {}
task_class = label_to_task_map[eval_cfg.label]
ds_eval_dataset = task_class(tokenizer=tokenizer) # type: ignore
if isinstance(task_class, tuple):
task_class, task_kwargs = task_class
ds_eval_dataset = task_class(tokenizer=tokenizer, **task_kwargs) # type: ignore
data_config = eval_cfg.data
if is_unit_test:
ds_eval_sampler = None
Expand Down
150 changes: 141 additions & 9 deletions olmo/eval/downstream.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
import re
from typing import Any, ClassVar, Dict, List, Optional
from typing import Any, ClassVar, Dict, List, Optional, Sequence, Union

import datasets
import torch
Expand Down Expand Up @@ -149,8 +149,9 @@ def __init__(
self,
tokenizer: Tokenizer,
dataset_path: str,
dataset_name: Optional[str] = None,
dataset_name: Union[str, Sequence[str], None] = None,
model_ctx_len: int = 2048,
split="validation",
):
super().__init__()

Expand All @@ -160,11 +161,22 @@ def __init__(
self.model_ctx_len = model_ctx_len

self.samples: List[Dict[str, Any]] = []
self.dataset = datasets.load_dataset(
path=self.dataset_path,
name=self.dataset_name,
split="validation",
)
dataset_names: Sequence[Optional[str]]
if isinstance(dataset_name, str) or dataset_name is None:
dataset_names = [dataset_name]
else:
dataset_names = dataset_name

dataset_list = []
for ds_name in dataset_names:
dataset_list.append(
datasets.load_dataset(
path=self.dataset_path,
name=ds_name,
split=split,
)
)
self.dataset = datasets.concatenate_datasets(dataset_list)

# prep examples
self.prep_examples()
Expand Down Expand Up @@ -588,7 +600,7 @@ class BoolQ(ICLMultiChoiceTaskDataset):
}
"""

metric_type = "pmi_dc"
metric_type = "acc"

def __init__(self, tokenizer, dataset_path="boolq", dataset_name=None):
super().__init__(
Expand Down Expand Up @@ -710,7 +722,7 @@ class ArcChallenge(ArcEasy):
implement PMI_DC
"""

metric_type = "pmi_dc"
metric_type = "len_norm" # Ideally "pmi_dc"

def __init__(self, tokenizer, dataset_path="ai2_arc", dataset_name="ARC-Challenge"):
super().__init__(
Expand Down Expand Up @@ -962,6 +974,118 @@ def doc_to_domain_conditional(self, doc):
return "Answer:"


class MMLU(ICLMultiChoiceTaskDataset):
    """MMLU multiple-choice evaluation task.

    Each question is rendered as the context "Question: QUESTION" followed by a
    newline and "Answer:"; every entry of ``choices`` is sent as a continuation
    with a single space prepended.

    Example document:
    {
        'question': "Which of the following terms describes the body's ability to maintain its normal state?",
        'subject': 'anatomy',
        'choices': ['Anabolism', 'Catabolism', 'Tolerance', 'Homeostasis'],
        'answer': 3
    }

    ``dataset_name`` selects which subject configs are loaded and may be:
      * a top-level category (a key of ``_categories``, e.g. "stem"),
      * a single subject (a key of ``_subcategories``, e.g. "anatomy"),
      * a field shared by several subjects (a value in ``_subcategories``, e.g. "math"),
      * ``None`` to load every subject.
    """

    metric_type = "len_norm"  # Ideally pmi_dc

    # Subject (dataset config name) -> list of fields that subject belongs to.
    _subcategories = {
        "abstract_algebra": ["math"],
        "anatomy": ["health"],
        "astronomy": ["physics"],
        "business_ethics": ["business"],
        "clinical_knowledge": ["health"],
        "college_biology": ["biology"],
        "college_chemistry": ["chemistry"],
        "college_computer_science": ["computer science"],
        "college_mathematics": ["math"],
        "college_medicine": ["health"],
        "college_physics": ["physics"],
        "computer_security": ["computer science"],
        "conceptual_physics": ["physics"],
        "econometrics": ["economics"],
        "electrical_engineering": ["engineering"],
        "elementary_mathematics": ["math"],
        "formal_logic": ["philosophy"],
        "global_facts": ["other"],
        "high_school_biology": ["biology"],
        "high_school_chemistry": ["chemistry"],
        "high_school_computer_science": ["computer science"],
        "high_school_european_history": ["history"],
        "high_school_geography": ["geography"],
        "high_school_government_and_politics": ["politics"],
        "high_school_macroeconomics": ["economics"],
        "high_school_mathematics": ["math"],
        "high_school_microeconomics": ["economics"],
        "high_school_physics": ["physics"],
        "high_school_psychology": ["psychology"],
        "high_school_statistics": ["math"],
        "high_school_us_history": ["history"],
        "high_school_world_history": ["history"],
        "human_aging": ["health"],
        "human_sexuality": ["culture"],
        "international_law": ["law"],
        "jurisprudence": ["law"],
        "logical_fallacies": ["philosophy"],
        "machine_learning": ["computer science"],
        "management": ["business"],
        "marketing": ["business"],
        "medical_genetics": ["health"],
        "miscellaneous": ["other"],
        "moral_disputes": ["philosophy"],
        "moral_scenarios": ["philosophy"],
        "nutrition": ["health"],
        "philosophy": ["philosophy"],
        "prehistory": ["history"],
        "professional_accounting": ["other"],
        "professional_law": ["law"],
        "professional_medicine": ["health"],
        "professional_psychology": ["psychology"],
        "public_relations": ["politics"],
        "security_studies": ["politics"],
        "sociology": ["culture"],
        "us_foreign_policy": ["politics"],
        "virology": ["health"],
        "world_religions": ["philosophy"],
    }

    # Top-level category -> list of fields (values used in ``_subcategories``).
    _categories = {
        "stem": ["physics", "chemistry", "biology", "computer science", "math", "engineering"],
        "humanities": ["history", "philosophy", "law"],
        "social_sciences": ["politics", "culture", "economics", "geography", "psychology"],
        "other": ["other", "business", "health"],
    }

    def __init__(self, tokenizer, dataset_path="hails/mmlu_no_train", dataset_name=None, split="validation"):
        """Resolve ``dataset_name`` to subject configs and hand them to the base class.

        Args:
            tokenizer: Tokenizer forwarded unchanged to the base class.
            dataset_path: Dataset path forwarded unchanged to the base class.
            dataset_name: Category, subject, or field name (see the class
                docstring); ``None`` selects every subject.
            split: Dataset split forwarded unchanged to the base class.

        Raises:
            ValueError: If ``dataset_name`` matches no category, subject or field.
        """
        super().__init__(
            tokenizer=tokenizer,
            dataset_path=dataset_path,
            dataset_name=self._resolve_dataset_names(dataset_name),
            split=split,
        )

    @classmethod
    def _resolve_dataset_names(cls, dataset_name):
        """Return the list of subject config names selected by ``dataset_name``."""
        if dataset_name is None:
            # No selection given: evaluate on every subject instead of handing
            # an empty list of configs to the base class.
            return list(cls._subcategories)

        if dataset_name in cls._categories:
            # Category names win over field names ("other" is both); expand the
            # category into its fields.
            fields = cls._categories[dataset_name]
        elif dataset_name in cls._subcategories:
            # Exact subject names win over field names ("philosophy" is both).
            return [dataset_name]
        else:
            # Otherwise treat the name as a single field, e.g. "math".
            fields = [dataset_name]

        # Keep field-major order, then table order within a field, so the
        # resulting list (and hence the order configs are passed on) is stable.
        names = [
            subject
            for field in fields
            for subject, subject_fields in cls._subcategories.items()
            if field in subject_fields
        ]
        if not names:
            # Fail here with the offending name rather than passing an empty
            # selection downstream, where the error would not mention it.
            raise ValueError(
                f"Unknown MMLU dataset_name {dataset_name!r}: expected one of the categories "
                f"{sorted(cls._categories)}, a subject name, or a subject field."
            )
        return names

    def doc_to_text(self, doc):
        """Build the prompt context from the document's question."""
        return "Question: " + doc["question"] + "\nAnswer:"

    def doc_to_continuations(self, doc):
        """Return the answer choices, each prefixed with a space."""
        # add spaces in front of continuation
        return [" " + choice for choice in doc["choices"]]

    def doc_to_label(self, doc):
        """Return the document's ``answer`` field as the label."""
        return doc["answer"]

    def doc_to_domain_conditional(self, doc):
        """Return the fixed domain-conditional prefix; ``doc`` is unused."""
        del doc
        return "Answer:"


label_to_task_map = {
"piqa": PIQA,
"hellaswag": HellaSwag,
Expand All @@ -976,4 +1100,12 @@ def doc_to_domain_conditional(self, doc):
"commitment_bank": CommitmentBank,
"mrpc": MRPC,
"sst2": SST2,
"mmlu_stem_test": (MMLU, {"dataset_name": "stem", "split": "test"}),
"mmlu_humanities_test": (MMLU, {"dataset_name": "humanities", "split": "test"}),
"mmlu_social_sciences_test": (MMLU, {"dataset_name": "social_sciences", "split": "test"}),
"mmlu_other_test": (MMLU, {"dataset_name": "other", "split": "test"}),
"mmlu_stem": (MMLU, {"dataset_name": "stem"}),
"mmlu_humanities": (MMLU, {"dataset_name": "humanities"}),
"mmlu_social_sciences": (MMLU, {"dataset_name": "social_sciences"}),
"mmlu_other": (MMLU, {"dataset_name": "other"}),
}

0 comments on commit 67d24f5

Please sign in to comment.