Add Support for Evaluation with HELM #370

Closed · wants to merge 4 commits
82 changes: 82 additions & 0 deletions eval/helm.py
@@ -0,0 +1,82 @@
from typing import Optional, Tuple, List

import torch

# acknowledgement to
# https://github.com/llm-efficiency-challenge/neurips_llm_efficiency_challenge/blob/master/toy-submission/helper.py


@torch.no_grad()
def generate(
    model: torch.nn.Module,
    idx: torch.Tensor,
    max_returned_tokens: int,
    max_seq_length: int,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    eos_id: Optional[int] = None,
) -> Tuple[torch.Tensor, List[float], List[Tuple[int, float]]]:
    """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    The implementation of this function is modified from A. Karpathy's nanoGPT.

    Args:
        model: The model to use.
        idx: Tensor of shape (T) with indices of the prompt sequence.
        max_returned_tokens: The maximum number of tokens to return (given plus generated).
        max_seq_length: The maximum sequence length allowed. Should be less than or equal to the block size.
        temperature: Scales the predicted logits by 1 / temperature.
        top_k: If specified, only sample among the tokens with the k highest probabilities.
        eos_id: If specified, stop generating any more tokens once the <eos> token is triggered.

    Returns:
        A tuple of the generated token indices (prompt plus generated tokens), the log probability of each sampled
        token, and a list of (token id, log probability) pairs for the most likely token at each step.
    """
    T = idx.size(0)
    assert max_returned_tokens > T
    device, dtype = idx.device, idx.dtype
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(max_returned_tokens, dtype=dtype, device=device)
    empty[:T] = idx
    idx = empty
    input_pos = torch.arange(0, T, device=device)

    top_logprob = []
    logprob = []

    # generate up to a fixed number of tokens
    for _ in range(max_returned_tokens - T):
        x = idx.index_select(0, input_pos).view(1, -1)

        # forward
        logits = model(x, max_seq_length, input_pos)
        logits = logits[0, -1] / temperature

        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits = torch.where(logits < v[[-1]], -float("Inf"), logits)

        probs = torch.nn.functional.softmax(logits, dim=-1)

        idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)

        # append the logprob of the selected token
        logprob.append(torch.log(probs[idx_next]).item())

        # append the idx and logprob of the top token
        top_logprob.append((torch.argmax(probs).item(), torch.log(probs).max().item()))

        # advance
        input_pos = input_pos[-1:] + 1

        # concatenate the new generation
        idx = idx.index_copy(0, input_pos, idx_next)

        # if <eos> token is triggered, return the output (stop generation)
        if idx_next == eos_id:
            return idx[: input_pos + 1], logprob, top_logprob  # include the EOS token

    return idx, logprob, top_logprob
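
For reviewers, here is a small self-contained sketch of the contract `generate` exposes: one log probability and one (top token, log probability) pair per generated token. The `UniformToyModel` stand-in and the import path are illustrative assumptions, not part of this PR.

```python
# Toy check of the generate() contract; the "model" below returns uniform logits
# over a tiny vocabulary and only mimics the (x, max_seq_length, input_pos) call signature.
import torch

from eval.helm import generate  # assumes the repo root is on sys.path


class UniformToyModel(torch.nn.Module):
    vocab_size = 16

    def forward(self, x, max_seq_length, input_pos):
        # (batch=1, seq, vocab) logits; every token is equally likely
        return torch.zeros(1, x.size(1), self.vocab_size)


prompt = torch.tensor([1, 2, 3])
tokens, logprobs, top_logprobs = generate(
    UniformToyModel(),
    prompt,
    max_returned_tokens=8,
    max_seq_length=8,
    temperature=1.0,
    top_k=4,
    eos_id=UniformToyModel.vocab_size,  # unreachable id, so generation never stops early
)
# one (logprob, top-token) entry per generated token
assert len(logprobs) == len(top_logprobs) == tokens.numel() - prompt.numel()
```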
67 changes: 58 additions & 9 deletions tutorials/evaluation.md
@@ -23,14 +23,63 @@ python eval/lm_eval_harness.py \
--save_filepath "results.json"
```

### To evaluate LoRA finetuned LLMs:

```bash
python eval/lm_eval_harness_lora.py \
--lora_path "lit_model_lora_finetuned.pth" \
--checkpoint_dir "checkpoints/Llama-2-7b-hf/" \
--precision "bf16-true" \
--eval_tasks "[truthfulqa_mc,hellaswag]" \
--batch_size 4 \
--save_filepath "results.json"
```

## Using HELM

> [!NOTE]\
> Acknowledgements to the NeurIPS Challenge organizers and the HELM authors for the instructions shown below.

### Installing HELM

> [!WARNING]\
> HELM requires Python 3.8.\
> It is recommended to install HELM into a virtual environment with Python 3.8 to avoid dependency conflicts.

To create and activate a Python virtual environment with Python >= 3.8, follow the instructions below.

To install HELM with conda or miniconda, run:

```sh
conda create -n crfm-helm python=3.8 pip -y
conda activate crfm-helm
pip install crfm-helm
```
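
Alternatively, if you prefer not to use conda, a plain `venv` should work as well. This is a minimal sketch that assumes a `python3.8` interpreter is already available on your `PATH`:

```sh
# create and activate a virtual environment, then install HELM into it
python3.8 -m venv .venv-helm
source .venv-helm/bin/activate
pip install crfm-helm
```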

### Configure HELM

You can configure which datasets HELM (Holistic Evaluation of Language Models) runs on by editing a `run_specs.conf` file.

Here's how you can create a simple `run_specs.conf` for local testing:

```sh
echo 'entries: [{description: "mmlu:model=neurips/local,subject=college_computer_science", priority: 4}]' > run_specs.conf
```

> [!NOTE]\
>
> To run your model on a large set of datasets, take a look at the [official example](https://github.com/stanford-crfm/helm/blob/main/src/helm/benchmark/presentation/run_specs_lite.conf) for inspiration.
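
For example, a `run_specs.conf` with more than one entry might look like the following sketch; the second MMLU subject is only an illustration, and exact scenario descriptions should be taken from the official example above:

```sh
cat > run_specs.conf <<'EOF'
entries: [
  {description: "mmlu:model=neurips/local,subject=college_computer_science", priority: 4},
  {description: "mmlu:model=neurips/local,subject=us_foreign_policy", priority: 4}
]
EOF
```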

### Run and Analyze Your Results

After creating `run_specs.conf`, you can run a quick local test with:

```sh
helm-run --conf-paths run_specs.conf --suite v1 --max-eval-instances 10
```

After running the above command, HELM will create a directory named `benchmark_output`. This directory will contain several subdirectories, which are listed below:

- `runs/`
  - `eval_cache/`
  - `mmlu:{SUBJECT},{METHOD},{MODEL}/`
- `scenario_instances/`
- `scenarios/`
  - `mmlu/`

You can then analyze the results with:

```sh
helm-summarize --suite v1
helm-server
```

This will analyze the results and then launch a server on localhost; if you're working on a remote machine, you might need to set up port forwarding. If everything worked correctly, you should see a page that looks like [this](https://user-images.githubusercontent.com/3282513/249620854-080f4d77-c5fd-4ea4-afa4-cf6a9dceb8c9.png).
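
For example, assuming `helm-server` is listening on its default port (8000) and your remote machine is reachable as `user@remote-host` (both placeholders, not values from this PR), an SSH tunnel from your local machine might look like:

```sh
# forward local port 8000 to port 8000 on the remote machine running helm-server
ssh -L 8000:localhost:8000 user@remote-host
```

You can then open http://localhost:8000 in your local browser.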