feat: add support for the original mt-bench #21
Merged
Changes from all 42 commits
All 42 commits are by ErlisLushtaku:

- ba4220d Add llamacpp dependency and update gitignore with generated directories
- d2a5a42 Add documentation for llamacpp in Readme
- a828adb Document direnv usage for environment variables management
- 0dcebf9 narrow down transformers dependency to fix version mismatch
- d60073b Add max_model_len param for VLLM in order to prevent OOM errors
- 38f63ee Fix completion loading and EuroLLM-9B example
- 6f5e0fc Remove `direnv` documentation
- 42ff2ae Revert stylistic (formatting) changes and add more documentation for …
- 8fcb032 Rename OPENJURY_EVAL_DATA to OPENJURY_DATA
- df958af Merge main
- 35856f2 Revert changes in gitignore
- 6a11182 Handle models with max_position_embeddings when we pass max_model_len
- fecd3ed Revert EuroLLM-9B-Instruct to EuroLLM-9B since there is a default cha…
- 0b4eaec fix tests
- 29340b0 Change test github workflow to use uv instead of pip for a more robus…
- 2c294f1 Move dev dependencies to dependency-group
- 4be61bf Revert comment removal
- 51d2597 Add pre-commit hook
- 8dee7b2 add project scripts and move slurmpilot to dev group
- fdc9410 fix LlamaCpp bug with ChatTemplate
- 48c5373 Add MT-Bench multi-turn evaluation support
- 648a9be Merge branch 'main' into erlislushtaku/feat/add-mt-bench-support
- 14f747e fix result formatting
- e67ea79 remove double environment variable
- 4089be8 remove accidental duplications
- 03f5cce Refactor
- 8ffe3a6 Remove duplication between prompt templates
- b877f11 add temperature argument
- c2056b5 add option for making mt-bench consistent with the original one from …
- 41cd15d Merge branch 'main' into erlislushtaku/feat/add-mt-bench-support
- 0ca66c5 remove redundant print statement
- a295305 move mt-bench logic from the entrypoint
- 0fb9700 Remove stale unused entries for fastchat mode
- e5670ea Merge origin/main into erlislushtaku/feat/add-mt-bench-support
- 6dd78fd Refactor mt-bench eval helpers into shared runtime module
- 0094eea move cli args and parsing to separate util to remove dependencies on …
- f522e5b refactor to address comments on PR
- 6a851c3 remove openjury mode for mt-bench keeping only the original version
- caaa079 Merge remote-tracking branch 'origin/main' into erlislushtaku/feat/ad…
- 2e8e04e Restore code and fix after merge/refactor
- 5a314a7 format
- 8c91606 fix ci
First new file (+181 lines), the CLI argument configuration module:

```python
"""CLI argument configuration for generation and evaluation entrypoints."""

import argparse
import json
from dataclasses import dataclass, field


@dataclass
class CliArgs:
    dataset: str
    model_A: str
    model_B: str
    judge_model: str

    n_instructions: int | None = None
    provide_explanation: bool = False
    swap_mode: str = "fixed"
    ignore_cache: bool = False
    use_tqdm: bool = False
    truncate_all_input_chars: int = 8192
    max_out_tokens_models: int = 32768
    max_out_tokens_judge: int = 32768
    max_model_len: int | None = None
    chat_template: str | None = None
    result_folder: str = "results"
    engine_kwargs: dict = field(default_factory=dict)

    def __post_init__(self):
        supported_modes = ["fixed", "both"]
        assert self.swap_mode in supported_modes, (
            f"Only {supported_modes} modes are supported but got {self.swap_mode}."
        )

    @classmethod
    def parse_args(cls):
        parser = argparse.ArgumentParser(
            prog="Generate completion and evaluate with a judge",
        )
        parser.add_argument(
            "--dataset",
            help="The dataset to use. For instance `alpaca-eval`, `arena-hard`, `m-arena-hard-EU` for instruction "
            "tuning cases or `french-contexts`, `spanish-contexts` for base models.",
        )
        parser.add_argument(
            "--model_A",
            required=True,
            help="Name of the LLM to use for generation; must be a valid choice for `generation_provider`.",
        )
        parser.add_argument(
            "--model_B",
            required=True,
            help="Name of the LLM to use for generation; must be a valid choice for `generation_provider`.",
        )
        parser.add_argument(
            "--judge_model",
            required=True,
            help="Name of the LLM to use, for instance `Together/meta-llama/Meta-Llama-3-70B-Instruct-Turbo`, "
            "`VLLM/meta-llama/Meta-Llama-3-70B-Instruct-Turbo`, `LangChain/LocalPath`, etc.",
        )
        parser.add_argument(
            "--n_instructions",
            type=int,
            required=False,
        )
        parser.add_argument(
            "--provide_explanation",
            action="store_true",
            help="If specified, the judge will provide an explanation before making a judgement. Does not "
            "necessarily improve the accuracy of the judge but enables some result interpretation.",
        )
        parser.add_argument(
            "--swap_mode",
            type=str,
            choices=["fixed", "both"],
            default="fixed",
            help="Model comparison order mode. 'fixed': always use model order A-B. 'both': correct for model order "
            "bias by evaluating each instruction twice, once as A-B and once as B-A, and average. This helps account "
            "for judge position bias. Default is 'fixed'.",
        )
        parser.add_argument(
            "--ignore_cache",
            action="store_true",
            help="If specified, ignore the cache of previous completions.",
        )
        parser.add_argument(
            "--use_tqdm",
            action="store_true",
            help="If specified, use tqdm; does not work with all model providers, vLLM in particular.",
        )
        parser.add_argument(
            "--result_folder",
            type=str,
            required=False,
            default="results",
            help="The folder to save the results. Defaults to `results`. Evaluation results will be saved in "
            "`[result_folder]/[evaluation_name]`.",
        )
        parser.add_argument(
            "--truncate_all_input_chars",
            type=int,
            required=False,
            default=8192,
            help="Character-level truncation applied before tokenization: truncates each instruction "
            "before model A/B generation and truncates each completion before judge evaluation.",
        )
        parser.add_argument(
            "--max_out_tokens_models",
            type=int,
            required=False,
            default=32768,
            help=(
                "Generation token budget for each model A/B response. For VLLM, keep this <= "
                "--max_model_len (if provided)."
            ),
        )
        parser.add_argument(
            "--max_out_tokens_judge",
            type=int,
            required=False,
            default=32768,
            help=(
                "Generation token budget for the judge response (reasoning + scores). For "
                "VLLM, keep this <= --max_model_len (if provided)."
            ),
        )
        parser.add_argument(
            "--max_model_len",
            type=int,
            required=False,
            default=None,
            help=(
                "Optional total context window for VLLM models (prompt + generation). This is "
                "independent from --max_out_tokens_models/--max_out_tokens_judge, which only cap "
                "generated tokens. This is useful on smaller GPUs to avoid OOM."
            ),
        )
        parser.add_argument(
            "--chat_template",
            type=str,
            required=False,
            default=None,
            help="Jinja2 chat template string to use instead of the model's tokenizer template. "
            "If not provided, ChatML is used as a fallback for models without a chat template.",
        )
        parser.add_argument(
            "--engine_kwargs",
            type=str,
            required=False,
            default="{}",
            help=(
                "JSON dict of engine-specific kwargs forwarded to the underlying engine. "
                'Example for vLLM: \'{"tensor_parallel_size": 2, "gpu_memory_utilization": 0.9}\'.'
            ),
        )
        args = parser.parse_args()

        try:
            engine_kwargs = json.loads(args.engine_kwargs) if args.engine_kwargs else {}
            if not isinstance(engine_kwargs, dict):
                raise ValueError("engine_kwargs must be a JSON object")
        except Exception as e:
            raise SystemExit(f"Failed to parse --engine_kwargs: {e}") from e

        return cls(
            dataset=args.dataset,
            model_A=args.model_A,
            model_B=args.model_B,
            judge_model=args.judge_model,
            n_instructions=args.n_instructions,
            provide_explanation=args.provide_explanation,
            swap_mode=args.swap_mode,
            ignore_cache=args.ignore_cache,
            use_tqdm=args.use_tqdm,
            truncate_all_input_chars=args.truncate_all_input_chars,
            max_out_tokens_models=args.max_out_tokens_models,
            max_out_tokens_judge=args.max_out_tokens_judge,
            max_model_len=args.max_model_len,
            chat_template=args.chat_template,
            result_folder=args.result_folder,
            engine_kwargs=engine_kwargs,
        )
```
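The `--engine_kwargs` flag takes a JSON object as a string and rejects anything that does not decode to a dict. A minimal standalone sketch of that validation logic (the helper name `parse_engine_kwargs` is hypothetical; the real code inlines this in `CliArgs.parse_args`):

```python
import json


def parse_engine_kwargs(raw: str) -> dict:
    # Hypothetical helper mirroring CliArgs' validation: the flag value
    # must decode to a JSON object; anything else aborts with SystemExit.
    try:
        engine_kwargs = json.loads(raw) if raw else {}
        if not isinstance(engine_kwargs, dict):
            raise ValueError("engine_kwargs must be a JSON object")
    except Exception as e:
        raise SystemExit(f"Failed to parse --engine_kwargs: {e}") from e
    return engine_kwargs


print(parse_engine_kwargs('{"tensor_parallel_size": 2}'))
# → {'tensor_parallel_size': 2}
```

An empty string falls back to `{}`, while a JSON array such as `'[1, 2]'` parses but fails the dict check and exits.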
Second new file (+155 lines), the shared evaluation runtime helpers:

```python
"""Shared evaluation runtime helpers used by entrypoints and benchmark pipelines."""

from __future__ import annotations

from dataclasses import dataclass

import pandas as pd

from judgearena.evaluate import PairScore, annotate_battles
from judgearena.utils import compute_pref_summary


def print_results(results):
    """Print battle results in a readable format."""
    print("\n" + "=" * 60)
    print("🏆 MODEL BATTLE RESULTS 🏆".center(60))
    print(f"📊 Dataset: {results['dataset']}")
    print(
        f"🤖 Competitors: Model A: {results['model_A']} vs Model B: {results['model_B']}"
    )
    print(f"⚖️ Judge: {results['judge_model']}")
    print("📈 Results Summary:")
    print(f"   Total Battles: {results['num_battles']}")
    print(f"   Win Rate (A): {results['winrate']:.1%}")
    print(f"   ✅ Wins: {results['num_wins']}")
    print(f"   ❌ Losses: {results['num_losses']}")
    print(f"   🤝 Ties: {results['num_ties']}")
    if results.get("num_missing", 0) > 0:
        print(f"   ❓ Missing: {results['num_missing']}")

    per_category = results.get("per_category")
    if per_category:
        print("\nPer-Category Breakdown:")
        print(
            f"  {'Category':<14} | {'Win Rate(A)':>11} | {'Wins':>4} | {'Losses':>6} | {'Ties':>4}"
        )
        print(f"  {'-' * 14}-+-{'-' * 11}-+-{'-' * 4}-+-{'-' * 6}-+-{'-' * 4}")
        for cat, stats in sorted(per_category.items()):
            print(
                f"  {cat:<14} | {stats['winrate']:>11.1%} | "
                f"{stats['num_wins']:>4} | {stats['num_losses']:>6} | {stats['num_ties']:>4}"
            )

    per_turn = results.get("per_turn")
    if per_turn:
        print("\nPer-Turn Breakdown:")
        for turn, stats in sorted(per_turn.items()):
            print(
                f"  Turn {turn} Win Rate(A): {stats['winrate']:.1%} "
                f"(W:{stats['num_wins']} L:{stats['num_losses']} T:{stats['num_ties']})"
            )
    print("=" * 60 + "\n")


def _compute_grouped_stats(
    preferences: pd.Series,
    metadata: list[dict[str, object]],
    group_by: str,
) -> dict[object, dict[str, float | int]]:
    grouped: dict[object, list[float]] = {}
    for meta, pref in zip(metadata, preferences, strict=True):
        key = meta.get(group_by)
        if key is None:
            continue
        grouped.setdefault(key, []).append(pref)
    return {key: compute_pref_summary(pd.Series(vals)) for key, vals in grouped.items()}


def _parse_preferences_from_annotations(
    annotations: list,
    score_parser: PairScore,
) -> pd.Series:
    return pd.Series(
        [
            score_parser.parse_model_raw(annotation.judge_completion)
            for annotation in annotations
        ]
    )


@dataclass
class JudgeAnnotationResult:
    annotations: list
    annotations_reversed: list
    metadata_for_annotations: list[dict[str, object]]
    metadata_for_reversed_annotations: list[dict[str, object]]
    preferences: pd.Series
    combined_metadata: list[dict[str, object]]


def _make_judge_annotation(
    *,
    judge_chat_model,
    instructions: list[str],
    completions_A: list[str],
    completions_B: list[str],
    metadata: list[dict[str, object]],
    score_parser: PairScore,
    provide_explanation: bool,
    swap_mode: str,
    truncate_input_chars: int | None,
    use_tqdm: bool,
    system_prompt: str | None = None,
    user_prompt_template: str | None = None,
) -> JudgeAnnotationResult:
    if not instructions:
        raise ValueError("instructions must be non-empty")

    annotations = annotate_battles(
        judge_chat_model=judge_chat_model,
        instructions=instructions,
        completions_A=completions_A,
        completions_B=completions_B,
        provide_explanation=provide_explanation,
        system_prompt=system_prompt,
        user_prompt_template=user_prompt_template,
        truncate_input_chars=truncate_input_chars,
        use_tqdm=use_tqdm,
    )
    preference_parts = [_parse_preferences_from_annotations(annotations, score_parser)]

    annotations_reversed: list = []
    metadata_for_reversed_annotations: list[dict[str, object]] = []
    combined_metadata = list(metadata)

    if swap_mode == "both":
        print("Correction for judge bias towards a certain model position is set.")
        print("Evaluating completions with models reversed.")
        annotations_reversed = annotate_battles(
            judge_chat_model=judge_chat_model,
            instructions=instructions,
            completions_A=completions_B,
            completions_B=completions_A,
            provide_explanation=provide_explanation,
            system_prompt=system_prompt,
            user_prompt_template=user_prompt_template,
            truncate_input_chars=truncate_input_chars,
            use_tqdm=use_tqdm,
        )
        prefs_reversed = _parse_preferences_from_annotations(
            annotations_reversed, score_parser
        )
        preference_parts.append(1 - prefs_reversed)
        metadata_for_reversed_annotations = list(metadata)
        combined_metadata.extend(metadata)

    preferences = pd.concat(preference_parts).reset_index(drop=True)
    return JudgeAnnotationResult(
        annotations=annotations,
        annotations_reversed=annotations_reversed,
        metadata_for_annotations=list(metadata),
        metadata_for_reversed_annotations=metadata_for_reversed_annotations,
        preferences=preferences,
        combined_metadata=combined_metadata,
    )
```
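With `swap_mode="both"`, each battle is judged twice (A-B and B-A), the reversed-order preferences are flipped as `1 - p`, and the pooled series is averaged, so a judge that always favors one position cancels out. A minimal pure-Python sketch of that averaging (the helper name `combined_winrate` and the toy scores are illustrative, not from the PR; the real code uses pandas Series):

```python
def combined_winrate(prefs, prefs_reversed):
    # prefs: preference for model A in original A-B order (1.0 = A wins,
    # 0.0 = B wins, 0.5 = tie). prefs_reversed: same judge scores with the
    # models presented in B-A order, so they are flipped before pooling.
    combined = list(prefs) + [1 - p for p in prefs_reversed]
    return sum(combined) / len(combined)


prefs = [1.0, 0.5, 0.0]           # judged in A-B order
prefs_reversed = [0.0, 0.5, 0.0]  # judged in B-A order; 0.0 means the
                                  # second position (model A) was preferred
print(combined_winrate(prefs, prefs_reversed))
# → 0.6666666666666666
```

A position-biased judge that scores 1.0 for the first slot in both passes contributes 1.0 and 0.0 to the pool, averaging to a neutral 0.5.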
Review comment: The name of the file is a bit misleading; `eval_utils.py` is probably better.