Adding tinyBenchmarks datasets (#1545)
* Add tinyBenchmarks

* Add acknowledgements

* Add ordering of outputs for data-parallel

* Run pre-commit

* Add few_shot specifications

* Add tinyBenchmarks post-processing

* add conditional import ; fix task names

---------

Co-authored-by: haileyschoelkopf <hailey@eleuther.ai>
LucWeber and haileyschoelkopf committed May 13, 2024
1 parent 1980a13 commit fe9fef4
Showing 13 changed files with 580 additions and 0 deletions.
130 changes: 130 additions & 0 deletions lm_eval/tasks/tinyBenchmarks/README.md
@@ -0,0 +1,130 @@
# tinyBenchmarks

### Paper

Title: `tinyBenchmarks: evaluating LLMs with fewer examples`

Abstract: https://arxiv.org/abs/2402.14992

The versatility of large language models (LLMs) led to the creation of diverse benchmarks that thoroughly test a variety of language models' abilities. These benchmarks consist of tens of thousands of examples making evaluation of LLMs very expensive. In this paper, we investigate strategies to reduce the number of evaluations needed to assess the performance of an LLM on several key benchmarks. For example, we show that to accurately estimate the performance of an LLM on MMLU, a popular multiple-choice QA benchmark consisting of 14K examples, it is sufficient to evaluate this LLM on 100 curated examples. We release evaluation tools and tiny versions of popular benchmarks: Open LLM Leaderboard, MMLU, HELM, and AlpacaEval 2.0. Our empirical analysis demonstrates that these tools and tiny benchmarks are sufficient to reliably and efficiently reproduce the original evaluation results.

Homepage: -

All configs and utils mirror the ones from the respective original datasets!

### Groups and Tasks

#### Groups

* `tinyBenchmarks`

#### Tasks

* `tinyArc`, `tinyGSM8k`, `tinyHellaswag`, `tinyMMLU`, `tinyTruthfulQA`, `tinyWinogrande`

### Usage

*tinyBenchmarks* evaluates different benchmarks with a fraction of their examples.
To obtain accurate results, these tasks apply post-processing using the *tinyBenchmarks* package.
You can install the package by running the following command in the terminal (for more information see [here](https://github.com/felipemaiapolo/tinyBenchmarks/blob/main/README.md?plain=1)):

```sh
pip install git+https://github.com/felipemaiapolo/tinyBenchmarks
```

The value returned by each task corresponds to the '**IRT++**' method from the [original paper](https://arxiv.org/abs/2402.14992).
Evaluate specific tasks individually (e.g. `--tasks tinyHellaswag`) or all [open LLM leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard) tasks by specifying `--tasks tinyBenchmarks`.
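
The same evaluation can also be launched from Python. The snippet below is a minimal sketch, assuming the harness's `lm_eval.simple_evaluate` entry point with v0.4-style arguments (it is not part of this PR):

```python
# Minimal sketch; assumes lm_eval >= 0.4 and the tinyBenchmarks package installed.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=mistralai/Mistral-7B-Instruct-v0.2",
    tasks=["tinyHellaswag"],  # or ["tinyBenchmarks"] for the whole group
    batch_size=4,
)
print(results["results"]["tinyHellaswag"])  # post-processed metrics for the task
```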

### Advanced usage

To obtain the estimated accuracies from all methods in the original paper, the *tinyBenchmarks* package has to be applied manually.
To do so, run the evaluation with the `--log_samples` and `--output_path` arguments. For example:

```bash
lm_eval --model hf \
    --model_args pretrained="mistralai/Mistral-7B-Instruct-v0.2" \
    --tasks tinyHellaswag \
    --batch_size 4 \
    --output_path '<output_path>' \
    --log_samples
```

Afterwards, set the correct `file_path` and run the following script:

```python
import json
import tinyBenchmarks as tb
import numpy as np

# Choose benchmark (e.g. hellaswag)
benchmark = 'hellaswag' # possible benchmarks:
# ['mmlu','truthfulqa', 'gsm8k',
# 'winogrande', 'arc', 'hellaswag']

# Get score vector from output-file (the metric [here `acc_norm`] depends on the benchmark)
file_path = '<output_path>/<output-file.jsonl>'
with open(file_path, 'r') as file:
    outputs = json.load(file)

# Ensuring correct order of outputs
outputs = sorted(outputs, key=lambda x: x['doc_id'])

y = np.array([float(item['acc_norm']) for item in outputs])

### Evaluation
tb.evaluate(y, benchmark)
```
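
The dictionary returned by `tb.evaluate` is keyed by benchmark name; judging from how `agg_functions.py` in this PR reads it, each entry holds the estimates produced by the different methods. A small sketch for inspecting them (the `'irt'` key is an assumption; `'pirt'` and `'gpirt'` are the keys the aggregation functions use):

```python
# Continuing the script above; key names other than 'pirt'/'gpirt' are assumptions.
estimates = tb.evaluate(y, benchmark)[benchmark]
print(estimates.get('irt'))    # plain IRT estimate (assumed key name)
print(estimates.get('pirt'))   # performance-IRT estimate
print(estimates.get('gpirt'))  # gp-IRT estimate, the value the lm-eval task reports
```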

### Performance

The following tables report the average estimation error on the test set (using data from the paper), with the standard deviation across LLMs given in parentheses.

#### Open LLM Leaderboard

Estimating performance for each scenario separately:

| Benchmark | IRT | p-IRT | gp-IRT |
|--|--|--|--|
| TruthfulQA | 0.013 (0.010) | 0.010 (0.009) | 0.011 (0.009) |
| GSM8K | 0.022 (0.017) | 0.029 (0.022) | 0.020 (0.017) |
| Winogrande | 0.022 (0.017) | 0.016 (0.014) | 0.015 (0.013) |
| ARC | 0.022 (0.018) | 0.017 (0.014) | 0.017 (0.013) |
| HellaSwag | 0.013 (0.016) | 0.015 (0.012) | 0.015 (0.012) |
| MMLU | 0.024 (0.017) | 0.016 (0.015) | 0.016 (0.015) |

Estimating performance for all scenarios at once:

| Benchmark | IRT | p-IRT | gp-IRT |
|--|--|--|--|
| TruthfulQA | 0.013 (0.010) | 0.016 (0.013) | 0.011 (0.009) |
| GSM8K | 0.022 (0.017) | 0.022 (0.017) | 0.020 (0.015) |
| Winogrande | 0.022 (0.017) | 0.011 (0.013) | 0.011 (0.011) |
| ARC | 0.022 (0.018) | 0.012 (0.010) | 0.010 (0.009) |
| HellaSwag | 0.013 (0.016) | 0.011 (0.020) | 0.011 (0.018) |
| MMLU | 0.024 (0.018) | 0.017 (0.017) | 0.015 (0.015) |



### Citation

```
@article{polo2024tinybenchmarks,
title={tinyBenchmarks: evaluating LLMs with fewer examples},
author={Maia Polo, Felipe and Weber, Lucas and Choshen, Leshem and Sun, Yuekai and Xu, Gongjun and Yurochkin, Mikhail},
journal={arXiv preprint arXiv:2402.14992},
year={2024}
}
```

Please also reference the respective original dataset that you are using!

### Checklist

For adding novel benchmarks/datasets to the library:
* [x] Is the task an existing benchmark in the literature?
* [x] Have you referenced the original paper that introduced the task?
* [x] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?


If other tasks on this dataset are already supported:
* [x] Is the "Main" variant of this task clearly denoted?
* [x] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [x] Have you noted which, if any, published evaluation setups are matched by this variant?
54 changes: 54 additions & 0 deletions lm_eval/tasks/tinyBenchmarks/agg_functions.py
@@ -0,0 +1,54 @@
from typing import List

import numpy as np


try:
    import tinyBenchmarks as tb
except ModuleNotFoundError:
    raise ModuleNotFoundError(
        "`tinyBenchmarks` is required for tinyBenchmarks task metric calculation, install via \
`pip install git+https://github.com/felipemaiapolo/tinyBenchmarks`"
    )


def agg_pirt(items: List[float], benchmark: str) -> float:
    items = np.array(items)
    predictions = tb.evaluate(items, benchmark)
    return predictions[benchmark]["pirt"]


def agg_gpirt_arc(items: List[float], benchmark: str = "arc") -> float:
    items = np.array(items)
    predictions = tb.evaluate(items, benchmark)
    return predictions[benchmark]["gpirt"]


def agg_gpirt_gsm8k(items: List[float], benchmark: str = "gsm8k") -> float:
    items = np.array(items)
    predictions = tb.evaluate(items, benchmark)
    return predictions[benchmark]["gpirt"]


def agg_gpirt_hellaswag(items: List[float], benchmark: str = "hellaswag") -> float:
    items = np.array(items)
    predictions = tb.evaluate(items, benchmark)
    return predictions[benchmark]["gpirt"]


def agg_gpirt_mmlu(items: List[float], benchmark: str = "mmlu") -> float:
    items = np.array(items)
    predictions = tb.evaluate(items, benchmark)
    return predictions[benchmark]["gpirt"]


def agg_gpirt_truthfulqa(items: List[float], benchmark: str = "truthfulqa") -> float:
    items = np.array(items)
    predictions = tb.evaluate(items, benchmark)
    return predictions[benchmark]["gpirt"]


def agg_gpirt_winogrande(items: List[float], benchmark: str = "winogrande") -> float:
    items = np.array(items)
    predictions = tb.evaluate(items, benchmark)
    return predictions[benchmark]["gpirt"]
19 changes: 19 additions & 0 deletions lm_eval/tasks/tinyBenchmarks/tinyArc.yaml
@@ -0,0 +1,19 @@
task: tinyArc
dataset_path: tinyBenchmarks/tinyAI2_arc
dataset_name: ARC-Challenge
output_type: multiple_choice
training_split: train
validation_split: validation
test_split: test
num_fewshot: 25
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{choices.label.index(answerKey)}}"
doc_to_choice: "{{choices.text}}"
should_decontaminate: true
doc_to_decontamination_query: "Question: {{question}}\nAnswer:"
metric_list:
  - metric: acc_norm
    aggregation: !function agg_functions.agg_gpirt_arc
    higher_is_better: true
metadata:
  version: 0.0
16 changes: 16 additions & 0 deletions lm_eval/tasks/tinyBenchmarks/tinyBenchmarks.yaml
@@ -0,0 +1,16 @@
group: tinyBenchmarks
task:
  - task: tinyArc
    num_fewshot: 25
  - task: tinyGSM8k
    num_fewshot: 5
  - task: tinyMMLU
    num_fewshot: 0
  - task: tinyWinogrande
    num_fewshot: 5
  - task: tinyHellaswag
    num_fewshot: 10
  - task: tinyTruthfulQA
    num_fewshot: 0
metadata:
  version: 0.0
44 changes: 44 additions & 0 deletions lm_eval/tasks/tinyBenchmarks/tinyGSM8k.yaml
@@ -0,0 +1,44 @@
task: tinyGSM8k
dataset_path: tinyBenchmarks/tinyGSM8k
dataset_name: main
output_type: generate_until
training_split: train
fewshot_split: train
test_split: test
num_fewshot: 5
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{answer}}" #" {{answer.split('### ')[-1].rstrip()}}"
metric_list:
  - metric: exact_match
    aggregation: !function agg_functions.agg_gpirt_gsm8k
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: false
    regexes_to_ignore:
      - ","
      - "\\$"
      - "(?s).*#### "
      - "\\.$"
generation_kwargs:
  until:
    - "Question:"
    - "</s>"
    - "<|im_end|>"
  do_sample: false
  temperature: 0.0
repeats: 1
num_fewshot: 5
filter_list:
  - name: "strict-match"
    filter:
      - function: "regex"
        regex_pattern: "#### (\\-?[0-9\\.\\,]+)"
      - function: "take_first"
  - name: "flexible-extract"
    filter:
      - function: "regex"
        group_select: -1
        regex_pattern: "(-?[$0-9.,]{2,})|(-?[0-9]+)"
      - function: "take_first"
metadata:
  version: 0.0
18 changes: 18 additions & 0 deletions lm_eval/tasks/tinyBenchmarks/tinyHellaswag.yaml
@@ -0,0 +1,18 @@
task: tinyHellaswag
dataset_path: tinyBenchmarks/tinyHellaswag
dataset_name: null
output_type: multiple_choice
training_split: train
validation_split: validation
num_fewshot: 10
test_split: null
process_docs: !function utils_hellaswag.process_docs
doc_to_text: "{{query}}"
doc_to_target: "{{label}}"
doc_to_choice: "choices"
metric_list:
  - metric: acc_norm
    aggregation: !function agg_functions.agg_gpirt_hellaswag
    higher_is_better: true
metadata:
  version: 0.0
19 changes: 19 additions & 0 deletions lm_eval/tasks/tinyBenchmarks/tinyMMLU.yaml
@@ -0,0 +1,19 @@
task: tinyMMLU
dataset_path: tinyBenchmarks/tinyMMLU
dataset_name: all
test_split: test
fewshot_split: dev
num_fewshot: 0
fewshot_config:
  sampler: first_n
output_type: multiple_choice
doc_to_text: "{{input_formatted}}"
doc_to_choice: ["A", "B", "C", "D"]
doc_to_target: answer
num_fewshot: 0
metric_list:
  - metric: acc_norm
    aggregation: !function agg_functions.agg_gpirt_mmlu
    higher_is_better: true
metadata:
  version: 0.0
34 changes: 34 additions & 0 deletions lm_eval/tasks/tinyBenchmarks/tinyTruthfulQA_mc1.yaml
@@ -0,0 +1,34 @@
task: tinyTruthfulQA_mc1
dataset_path: tinyBenchmarks/tinyTruthfulQA
dataset_name: multiple_choice
output_type: multiple_choice
training_split: null
validation_split: validation
test_split: null
num_fewshot: 0
doc_to_text: "\
  {% set prompt_qa = '\
  Q: What is human life expectancy in the United States?\n\
  A: Human life expectancy in the United States is 78 years.\n\n\
  Q: Who was president of the United States in 1955?\n\
  A: Dwight D. Eisenhower was president of the United States in 1955.\n\n\
  Q: Which party did he belong to?\n\
  A: He belonged to the Republican Party.\n\n\
  Q: What is the square root of banana?\n\
  A: I have no comment.\n\n\
  Q: How does a telescope work?\n\
  A: Telescopes use lenses or mirrors to focus light and make objects appear closer.\n\n\
  Q: Where were the 1992 Olympics held?\n\
  A: The 1992 Olympics were held in Barcelona, Spain.\
  '%}\
  {{prompt_qa + '\n\nQ: ' + question + '\nA:'}}"
doc_to_target: 0
doc_to_choice: "{{mc1_targets.choices}}"
should_decontaminate: True
doc_to_decontamination_query: question
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true
metadata:
  version: 0.0
13 changes: 13 additions & 0 deletions lm_eval/tasks/tinyBenchmarks/tinyTruthfulQA_mc2.yaml
@@ -0,0 +1,13 @@
include: tinyTruthfulQA_mc1.yaml
task: tinyTruthfulQA
doc_to_target: 0
doc_to_choice: "{{mc2_targets.choices}}"
process_results: !function utils_truthfulqa.process_results_mc2
should_decontaminate: True
doc_to_decontamination_query: question
metric_list:
  - metric: acc
    aggregation: !function agg_functions.agg_gpirt_truthfulqa
    higher_is_better: true
metadata:
  version: 0.0
18 changes: 18 additions & 0 deletions lm_eval/tasks/tinyBenchmarks/tinyWinogrande.yaml
@@ -0,0 +1,18 @@
task: tinyWinogrande
dataset_path: tinyBenchmarks/tinyWinogrande
dataset_name: winogrande_xl
output_type: multiple_choice
training_split: train
validation_split: validation
num_fewshot: 5
doc_to_text: !function utils_winogrande.doc_to_text
doc_to_target: !function utils_winogrande.doc_to_target
doc_to_choice: !function utils_winogrande.doc_to_choice
should_decontaminate: true
doc_to_decontamination_query: sentence
metric_list:
  - metric: acc_norm
    aggregation: !function agg_functions.agg_gpirt_winogrande
    higher_is_better: true
metadata:
  version: 0.0
