diff --git a/README.md b/README.md
index a8ef59646b..a5a0220e43 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@
 | cyberagent-open-calm-7b | 38.8 | 24.22 | 37.63 | 74.12 | 45.79 | 60.74 | 2.04 | 65.07 | 0.8 | models/cyberagent/cyberagent-open-calm-7b/harness.sh |
 | cyberagent-open-calm-3b | 38.61 | 27.79 | 40.35 | 86.21 | 40.45 | 46.91 | 1.95 | 63.61 | 1.6 | models/cyberagent/cyberagent-open-calm-3b/harness.sh |
 | rinna-japanese-gpt-1b | 36.92 | 34.76 | 37.67 | 87.86 | 26.18 | 37.03 | 5.34 | 64.55 | 2 | models/rinna/rinna-japanese-gpt-1b/harness.sh |
-
+| rinna-japanese-gpt-neox-small | 31.12 | 34.22 | 30.11 | 83.35 | 5.80 | 31.78 | 3.85 | 57.24 | 1.6 | models/rinna/rinna-japanese-gpt-neox-small/harness.sh |
 
 ## How to evaluate your model
 1. git clone https://github.com/Stability-AI/lm-evaluation-harness/tree/jp-stable
diff --git a/lm_eval/base.py b/lm_eval/base.py
index 3749289f68..483b456a3f 100644
--- a/lm_eval/base.py
+++ b/lm_eval/base.py
@@ -695,6 +695,9 @@ def process_results(self, doc, results):
         return {
             "acc": acc,
             "acc_norm": acc_norm,
+            "details": {
+                "scores": results,
+            },
         }
 
     def higher_is_better(self):
@@ -722,6 +725,12 @@ class BalancedMultipleChoiceTask(MultipleChoiceTask):
     def process_results(self, doc, results):
         gold = doc["gold"]
 
+        # This isn't very clean, but it may be the best we can do since lm ops
+        # are submitted as an iterator for batching
+        response = None
+        if isinstance(results[-1], str):
+            response = results.pop()
+
         pred = np.argmax(results)
         acc = 1.0 if np.argmax(results) == gold else 0.0
         completion_len = np.array([float(len(i)) for i in doc["choices"]])
@@ -733,6 +742,11 @@ def process_results(self, doc, results):
             "balanced_acc": (acc, gold),
             "mcc": (gold, pred),
             "macro_f1": (gold, pred),
+            "details": {
+                "question": self.doc_to_text(doc),
+                "response": response,
+                "scores": results,
+            },
         }
 
     def higher_is_better(self):
diff --git a/lm_eval/tasks/ja/jaqket_v2.py b/lm_eval/tasks/ja/jaqket_v2.py
index f7f4cf1ecd..b6ba389003 100644
--- a/lm_eval/tasks/ja/jaqket_v2.py
+++ b/lm_eval/tasks/ja/jaqket_v2.py
@@ -138,125 +138,41 @@ def doc_to_target(self, doc):
         answer = answer_list[0]
         return answer
 
-    def fewshot_context(
-        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
-    ):
-        """Returns a fewshot context string that is made up of a prepended description
-        (if provided), the `num_fewshot` number of examples, and an appended prompt example.
-
-        :param doc: str
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param num_fewshot: int
-            The number of fewshot examples to provide in the returned context string.
-        :param provide_description: bool
-            Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
-        :param rnd: random.Random
-            The pseudo-random number generator used to randomly sample examples.
-            WARNING: This is currently a required arg although it's optionalized with a default `None`.
-        :param description: str
-            The task's description that will be prepended to the fewshot examples.
-        :returns: str
-            The fewshot context.
-        """
-        assert (
-            rnd is not None
-        ), "A `random.Random` generator argument must be provided to `rnd`"
-        assert not provide_description, (
-            "The `provide_description` arg will be removed in future versions. To prepend "
-            "a custom description to the context, supply the corresponding string via the "
-            "`description` arg."
+    def fewshot_context(self, doc, num_fewshot, **kwargs):
+        max_num_tokens = max(
+            [len(self._tokenize(answer)) for answer in doc["answers"]["text"]]
+        )
+        max_length = self.max_length - max_num_tokens
+
+        # If the prompt is too long with fewshot examples, reduce the number of
+        # examples until it fits.
+        while num_fewshot >= 0:
+            ctx = super().fewshot_context(doc, num_fewshot, **kwargs)
+            if len(self._tokenize(ctx)) <= max_length:
+                doc["context"] = ctx
+                return ctx
+            num_fewshot -= 1
+
+        # if we got here then even 0 fewshot is too long
+        raise ValueError(
+            f"0-shot prompt is too long for max length {max_length}:\n{ctx}"
         )
-        if provide_description is not None:
-            # nudge people to not specify it at all
-            print(
-                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
-            )
-
-        if hasattr(self, "FEWSHOT_SEP"):
-            FEWSHOT_SEP = self.FEWSHOT_SEP
-        elif hasattr(self, "SEP"):
-            FEWSHOT_SEP = f"{self.SEP}{self.SEP}"
-        else:
-            FEWSHOT_SEP = "\n\n"
-
-        if description:
-            description += FEWSHOT_SEP
-        elif hasattr(self, "DESCRIPTION"):
-            description = self.DESCRIPTION
-        else:
-            description = ""
-
-        if num_fewshot == 0:
-            labeled_examples = ""
+    def _tokenize(self, text, **kwargs):
+        encode_fn = self.tokenizer.encode
+        if "add_special_tokens" in inspect.getfullargspec(encode_fn).args:
+            encode_params = dict(add_special_tokens=False)
         else:
-            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
-            if self.has_training_docs():
-                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
-            else:
-                if self._fewshot_docs is None:
-                    self._fewshot_docs = list(
-                        self.validation_docs()
-                        if self.has_validation_docs()
-                        else self.test_docs()
-                    )
-
-                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
-
-                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
-                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
-
-            labeled_examples = (
-                FEWSHOT_SEP.join(
-                    [
-                        self.doc_to_answering_text(doc) + self.doc_to_target(doc)
-                        for doc in fewshotex
-                    ]
-                )
-                + FEWSHOT_SEP
-            )
-
-        example = self.doc_to_text(doc)
-        return description + labeled_examples + example
-
-    def preprocess_ctx(self, ctx, max_length):
-        # if ctx fits in max length, return
-        if len(self.tokenizer.encode(ctx)) <= max_length:
-            return ctx
-
-        # if ctx is too long, split on a tag that separates each example
-        description, remainder = ctx.split(self.FEWSHOT_SEP, 1)
-        ctxs = remainder.split(self.FEWSHOT_SEP)
-
-        # if there is no example and still the description + QA prompt is too long, fail
-        if len(ctxs) < 2:
-            raise ValueError(
-                f"description + QA prompt with no example (0-shot) doesn't fit in max_length. ctx: {ctx}"
-            )
-
-        # delete the first example, the last includes QA prompt to be answered by lm
-        del ctxs[0]
-
-        # recur
-        return self.preprocess_ctx(
-            self.FEWSHOT_SEP.join([description, *ctxs]), max_length
-        )
+            encode_params = {}
+        return encode_fn(text, **encode_params, **kwargs)
 
     def construct_requests(self, doc, ctx):
         if DYNAMIC_MAX_LENGTH == "false" or not hasattr(self.tokenizer, "encode"):
             continuation = rf.greedy_until(ctx, [self.SEP])
         else:
-            encode_fn = self.tokenizer.encode
-            if "add_special_tokens" in inspect.getfullargspec(encode_fn).args:
-                encode_params = dict(add_special_tokens=False)
-            else:
-                encode_params = {}
             max_num_tokens = max(
-                [
-                    len(encode_fn(answer, **encode_params))
-                    for answer in doc["answers"]["text"]
-                ]
+                [len(self._tokenize(answer)) for answer in doc["answers"]["text"]]
             )
-            ctx = self.preprocess_ctx(ctx, max_length=self.max_length - max_num_tokens)
             continuation = rf.greedy_until(ctx, [self.SEP], max_num_tokens)
         return continuation
@@ -433,30 +349,6 @@ def doc_to_answering_text(self, doc):
         qa_prompt = self.doc_to_qa_prompt(doc)
         return f"ユーザー: {answer_candidate}{self.SEP}{qa_prompt}{self.SEP}システム: "
 
-    def preprocess_ctx(self, ctx, max_length):
-        # if ctx fits in max length, return
-        if len(self.tokenizer.encode(ctx)) <= max_length:
-            return ctx
-
-        # if ctx is too long, split on a tag that separates each example
-        description, remainder = ctx.split(self.END_OF_DESCRIPTION, 1)
-        ctxs = remainder.split(self.START_OF_FEWSHOT)
-
-        # if there is no example and still the description + QA prompt is too long, fail
-        if len(ctxs) < 2:
-            raise ValueError(
-                f"description + QA prompt with no example (0-shot) doesn't fit in max_length. ctx: {ctx}"
-            )
-
-        # delete the first example, the last includes QA prompt to be answered by lm
-        del ctxs[1]
-
-        new_ctx = self.END_OF_DESCRIPTION.join(
-            [description, self.START_OF_FEWSHOT.join(ctxs)]
-        )
-        # recur
-        return self.preprocess_ctx(new_ctx, max_length)
-
 
 class JAQKETV2WithRinnaBilingualInstructionSFT(JAQKETV2WithRinnaInstructionSFT):
     """
diff --git a/lm_eval/tasks/ja/jcommonsenseqa.py b/lm_eval/tasks/ja/jcommonsenseqa.py
index 0cc6e91bd6..c7d94dc95a 100644
--- a/lm_eval/tasks/ja/jcommonsenseqa.py
+++ b/lm_eval/tasks/ja/jcommonsenseqa.py
@@ -8,6 +8,9 @@
 Homepage: https://github.com/yahoojapan/JGLUE
 """
 import os
+import warnings
+import time
+
 from lm_eval.base import MultipleChoiceTask, rf
 import numpy as np
 
@@ -118,6 +121,7 @@ class JCommonsenseQAWithFintanPrompt(JCommonsenseQA):
     DESCRIPTION = (
         "質問と回答の選択肢を入力として受け取り、選択肢から回答を選択してください。なお、回答は選択肢の番号(例:0)でするものとします。 \n\n"
     )
+    DID_WARNING = False
 
     def doc_to_text(self, doc):
         """
@@ -125,6 +129,14 @@ def doc_to_text(self, doc):
         質問:question
         選択肢:0.choice0,1.choice1, ...,4.choice4
         回答:
         """
+        if not self.DID_WARNING:
+            warnings.warn(
+                "#" * 100
+                + "\n\nprompt version `0.2` for JCommonsenseQA tends to output low scores! We highly recommend using `0.2.1` instead!\n\n"
+                + "#" * 100
+            )
+            self.DID_WARNING = True
+            time.sleep(5)
         choices = ",".join(
             [f"{idx}.{choice}" for idx, choice in enumerate(doc["choices"])]
         )
@@ -134,6 +146,26 @@ def doc_to_target(self, doc):
         return f"{doc['gold']}"
 
 
+class JCommonsenseQAWithFintanPromptV21(JCommonsenseQA):
+    VERSION = 1.1
+    PROMPT_VERSION = "0.2.1"
+    DESCRIPTION = "与えられた選択肢の中から、最適な答えを選んでください。 \n\n"
+
+    def doc_to_text(self, doc):
+        """
+        与えられた選択肢の中から、最適な答えを選んでください。
+
+        質問:{question}
+        選択肢:
+        - {choice0}
+        - {choice4}
+        回答:
+        """
+        choices = "\n".join([f"- {choice}" for choice in doc["choices"]])
+        input_text = f"質問:{doc['goal']}\n選択肢:\n{choices}\n回答:"
+        return input_text
+
+
 class JCommonsenseQAWithJAAlpacaPrompt(JCommonsenseQA):
     """
     This prompt format was inspired by the below data in fujiki/japanese_alpaca_data.
@@ -246,6 +278,7 @@ def doc_to_text(self, doc):
 VERSIONS = [
     JCommonsenseQA,
     JCommonsenseQAWithFintanPrompt,
+    JCommonsenseQAWithFintanPromptV21,
     JCommonsenseQAWithJAAlpacaPrompt,
     JCommonsenseQAWithRinnaInstructionSFT,
     JCommonsenseQAWithRinnaBilingualInstructionSFT,
diff --git a/lm_eval/tasks/ja/jnli.py b/lm_eval/tasks/ja/jnli.py
index 2f474d14f5..39ca927837 100644
--- a/lm_eval/tasks/ja/jnli.py
+++ b/lm_eval/tasks/ja/jnli.py
@@ -45,6 +45,7 @@ class JNLIWithFintanPrompt(BalancedMultipleChoiceTask):
         + "- そのいずれでもない場合はneutralと出力\n\n"
     )
     CHOICES = ["entailment", "contradiction", "neutral"]
+    SEP = "\n"
 
     def has_training_docs(self):
         return True
@@ -86,6 +87,9 @@ def construct_requests(self, doc, ctx):
         lls = [
             rf.loglikelihood(ctx, "{}".format(choice))[0] for choice in doc["choices"]
         ]
+        # this is only used for error analysis
+        if os.environ.get("DEBUG_MULTIPLECHOICE"):
+            lls.append(rf.greedy_until(ctx, [self.SEP]))
         return lls
 
diff --git a/lm_eval/tasks/ja/marc_ja.py b/lm_eval/tasks/ja/marc_ja.py
index 46f4261843..67ca22b048 100644
--- a/lm_eval/tasks/ja/marc_ja.py
+++ b/lm_eval/tasks/ja/marc_ja.py
@@ -39,6 +39,7 @@ class MARCJaWithFintanPrompt(BalancedMultipleChoiceTask):
     DATASET_NAME = "MARC-ja"
     DESCRIPTION = "製品レビューをnegativeかpositiveのいずれかのセンチメントに分類してください。出力は小文字化してください。 \n\n"
     CHOICES = ["positive", "negative"]
+    SEP = "\n"
 
     def has_training_docs(self):
         return True
@@ -80,7 +81,8 @@ def construct_requests(self, doc, ctx):
         ]
 
         # this is only used for error analysis
-        # lls.append(rf.greedy_until(ctx, [self.SEP]))
+        if os.environ.get("DEBUG_MULTIPLECHOICE"):
+            lls.append(rf.greedy_until(ctx, [self.SEP]))
         return lls
 
diff --git a/models/rinna/rinna-japanese-gpt-neox-small/harness.sh b/models/rinna/rinna-japanese-gpt-neox-small/harness.sh
new file mode 100644
index 0000000000..19cd67bd11
--- /dev/null
+++ b/models/rinna/rinna-japanese-gpt-neox-small/harness.sh
@@ -0,0 +1,3 @@
+MODEL_ARGS="pretrained=rinna/japanese-gpt-neox-small,use_fast=False"
+TASK="jcommonsenseqa-1.1-0.2,jnli-1.1-0.2,marc_ja-1.1-0.2,jsquad-1.1-0.2,jaqket_v2-0.2-0.2,xlsum_ja,xwinograd_ja,mgsm"
+python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "3,3,3,2,1,1,0,5" --device "cuda" --output_path "models/rinna/rinna-japanese-gpt-neox-small/result.json"
diff --git a/models/rinna/rinna-japanese-gpt-neox-small/result.json b/models/rinna/rinna-japanese-gpt-neox-small/result.json
new file mode 100644
index 0000000000..3fc4f5fb1d
--- /dev/null
+++ b/models/rinna/rinna-japanese-gpt-neox-small/result.json
@@ -0,0 +1,71 @@
+{
+  "results": {
+    "jcommonsenseqa-1.1-0.2": {
+      "acc": 0.34226988382484363,
+      "acc_stderr": 0.014190160441497086,
+      "acc_norm": 0.25022341376228774,
"acc_norm_stderr": 0.012954152571429026 + }, + "jnli-1.1-0.2": { + "acc": 0.3011503697617091, + "acc_stderr": 0.00930063317508552, + "acc_norm": 0.30156121610517667, + "acc_norm_stderr": 0.009304239098715018 + }, + "marc_ja-1.1-0.2": { + "acc": 0.8335691545808277, + "acc_stderr": 0.0049539113964134655, + "acc_norm": 0.8335691545808277, + "acc_norm_stderr": 0.0049539113964134655 + }, + "xwinograd_ja": { + "acc": 0.5724713242961418, + "acc_stderr": 0.0159836786259061 + }, + "jsquad-1.1-0.2": { + "exact_match": 5.808194506978839, + "f1": 19.590196190503015 + }, + "jaqket_v2-0.2-0.2": { + "exact_match": 3.178694158075601, + "f1": 15.319117213447132 + }, + "xlsum_ja": { + "rouge2": 3.8544806224917294 + }, + "mgsm": { + "acc": 0.016, + "acc_stderr": 0.007951661188874339 + } + }, + "versions": { + "jcommonsenseqa-1.1-0.2": 1.1, + "jnli-1.1-0.2": 1.1, + "marc_ja-1.1-0.2": 1.1, + "jsquad-1.1-0.2": 1.1, + "jaqket_v2-0.2-0.2": 0.2, + "xlsum_ja": 1.0, + "xwinograd_ja": 1.0, + "mgsm": 1.0 + }, + "config": { + "model": "hf-causal", + "model_args": "pretrained=rinna/japanese-gpt-neox-small,use_fast=False", + "num_fewshot": [ + 3, + 3, + 3, + 2, + 1, + 1, + 0, + 5 + ], + "batch_size": null, + "device": "cuda", + "no_cache": false, + "limit": null, + "bootstrap_iters": 100000, + "description_dict": {} + } +}