diff --git a/README.md b/README.md
index a8ef59646b..a5a0220e43 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@
| cyberagent-open-calm-7b | 38.8 | 24.22 | 37.63 | 74.12 | 45.79 | 60.74 | 2.04 | 65.07 | 0.8 | models/cyberagent/cyberagent-open-calm-7b/harness.sh |
| cyberagent-open-calm-3b | 38.61 | 27.79 | 40.35 | 86.21 | 40.45 | 46.91 | 1.95 | 63.61 | 1.6 | models/cyberagent/cyberagent-open-calm-3b/harness.sh |
| rinna-japanese-gpt-1b | 36.92 | 34.76 | 37.67 | 87.86 | 26.18 | 37.03 | 5.34 | 64.55 | 2 | models/rinna/rinna-japanese-gpt-1b/harness.sh |
-
+| rinna-japanese-gpt-neox-small | 31.12 | 34.22 | 30.11 | 83.35 | 5.80 | 31.78 | 3.85 | 57.24 | 1.6 | models/rinna/rinna-japanese-gpt-neox-small/harness.sh |
## How to evaluate your model
1. git clone https://github.com/Stability-AI/lm-evaluation-harness/tree/jp-stable
diff --git a/lm_eval/base.py b/lm_eval/base.py
index 3749289f68..483b456a3f 100644
--- a/lm_eval/base.py
+++ b/lm_eval/base.py
@@ -695,6 +695,9 @@ def process_results(self, doc, results):
return {
"acc": acc,
"acc_norm": acc_norm,
+ "details": {
+ "scores": results,
+ },
}
def higher_is_better(self):
@@ -722,6 +725,12 @@ class BalancedMultipleChoiceTask(MultipleChoiceTask):
def process_results(self, doc, results):
gold = doc["gold"]
+ # This isn't very clean, but it may be the best we can do since lm ops
+ # are submitted as an iterator for batching
+ response = None
+ if isinstance(results[-1], str):
+ response = results.pop()
+
pred = np.argmax(results)
acc = 1.0 if np.argmax(results) == gold else 0.0
completion_len = np.array([float(len(i)) for i in doc["choices"]])
@@ -733,6 +742,11 @@ def process_results(self, doc, results):
"balanced_acc": (acc, gold),
"mcc": (gold, pred),
"macro_f1": (gold, pred),
+ "details": {
+ "question": self.doc_to_text(doc),
+ "response": response,
+ "scores": results,
+ },
}
def higher_is_better(self):
diff --git a/lm_eval/tasks/ja/jaqket_v2.py b/lm_eval/tasks/ja/jaqket_v2.py
index f7f4cf1ecd..b6ba389003 100644
--- a/lm_eval/tasks/ja/jaqket_v2.py
+++ b/lm_eval/tasks/ja/jaqket_v2.py
@@ -138,125 +138,41 @@ def doc_to_target(self, doc):
answer = answer_list[0]
return answer
- def fewshot_context(
- self, doc, num_fewshot, provide_description=None, rnd=None, description=None
- ):
- """Returns a fewshot context string that is made up of a prepended description
- (if provided), the `num_fewshot` number of examples, and an appended prompt example.
-
- :param doc: str
- The document as returned from training_docs, validation_docs, or test_docs.
- :param num_fewshot: int
- The number of fewshot examples to provide in the returned context string.
- :param provide_description: bool
- Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
- :param rnd: random.Random
- The pseudo-random number generator used to randomly sample examples.
- WARNING: This is currently a required arg although it's optionalized with a default `None`.
- :param description: str
- The task's description that will be prepended to the fewshot examples.
- :returns: str
- The fewshot context.
- """
- assert (
- rnd is not None
- ), "A `random.Random` generator argument must be provided to `rnd`"
- assert not provide_description, (
- "The `provide_description` arg will be removed in future versions. To prepend "
- "a custom description to the context, supply the corresponding string via the "
- "`description` arg."
+ def fewshot_context(self, doc, num_fewshot, **kwargs):
+ max_num_tokens = max(
+ [len(self._tokenize(answer)) for answer in doc["answers"]["text"]]
+ )
+ max_length = self.max_length - max_num_tokens
+
+ # If the prompt is too long with fewshot examples, reduce the number of
+ # examples until it fits.
+ while num_fewshot >= 0:
+ ctx = super().fewshot_context(doc, num_fewshot, **kwargs)
+ if len(self._tokenize(ctx)) <= max_length:
+ doc["context"] = ctx
+ return ctx
+ num_fewshot -= 1
+
+ # if we got here then even 0 fewshot is too long
+        raise ValueError(
+ f"0-shot prompt is too long for max length {max_length}:\n{ctx}"
)
- if provide_description is not None:
- # nudge people to not specify it at all
- print(
- "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
- )
-
- if hasattr(self, "FEWSHOT_SEP"):
- FEWSHOT_SEP = self.FEWSHOT_SEP
- elif hasattr(self, "SEP"):
- FEWSHOT_SEP = f"{self.SEP}{self.SEP}"
- else:
- FEWSHOT_SEP = "\n\n"
-
- if description:
- description += FEWSHOT_SEP
- elif hasattr(self, "DESCRIPTION"):
- description = self.DESCRIPTION
- else:
- description = ""
- if num_fewshot == 0:
- labeled_examples = ""
+ def _tokenize(self, text, **kwargs):
+ encode_fn = self.tokenizer.encode
+ if "add_special_tokens" in inspect.getfullargspec(encode_fn).args:
+ encode_params = dict(add_special_tokens=False)
else:
- # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
- if self.has_training_docs():
- fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
- else:
- if self._fewshot_docs is None:
- self._fewshot_docs = list(
- self.validation_docs()
- if self.has_validation_docs()
- else self.test_docs()
- )
-
- fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
-
- # get rid of the doc that's the one we're evaluating, if it's in the fewshot
- fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
-
- labeled_examples = (
- FEWSHOT_SEP.join(
- [
- self.doc_to_answering_text(doc) + self.doc_to_target(doc)
- for doc in fewshotex
- ]
- )
- + FEWSHOT_SEP
- )
-
- example = self.doc_to_text(doc)
- return description + labeled_examples + example
-
- def preprocess_ctx(self, ctx, max_length):
- # if ctx fits in max length, return
- if len(self.tokenizer.encode(ctx)) <= max_length:
- return ctx
-
- # if ctx is too long, split on a tag that separates each example
- description, remainder = ctx.split(self.FEWSHOT_SEP, 1)
- ctxs = remainder.split(self.FEWSHOT_SEP)
-
- # if there is no example and still the description + QA prompt is too long, fail
- if len(ctxs) < 2:
- raise ValueError(
- f"description + QA prompt with no example (0-shot) doesn't fit in max_length. ctx: {ctx}"
- )
-
- # delete the first example, the last includes QA prompt to be answered by lm
- del ctxs[0]
-
- # recur
- return self.preprocess_ctx(
- self.FEWSHOT_SEP.join([description, *ctxs]), max_length
- )
+ encode_params = {}
+ return encode_fn(text, **encode_params, **kwargs)
def construct_requests(self, doc, ctx):
if DYNAMIC_MAX_LENGTH == "false" or not hasattr(self.tokenizer, "encode"):
continuation = rf.greedy_until(ctx, [self.SEP])
else:
- encode_fn = self.tokenizer.encode
- if "add_special_tokens" in inspect.getfullargspec(encode_fn).args:
- encode_params = dict(add_special_tokens=False)
- else:
- encode_params = {}
max_num_tokens = max(
- [
- len(encode_fn(answer, **encode_params))
- for answer in doc["answers"]["text"]
- ]
+ [len(self._tokenize(answer)) for answer in doc["answers"]["text"]]
)
- ctx = self.preprocess_ctx(ctx, max_length=self.max_length - max_num_tokens)
continuation = rf.greedy_until(ctx, [self.SEP], max_num_tokens)
return continuation
@@ -433,30 +349,6 @@ def doc_to_answering_text(self, doc):
qa_prompt = self.doc_to_qa_prompt(doc)
return f"ユーザー: {answer_candidate}{self.SEP}{qa_prompt}{self.SEP}システム: "
- def preprocess_ctx(self, ctx, max_length):
- # if ctx fits in max length, return
- if len(self.tokenizer.encode(ctx)) <= max_length:
- return ctx
-
- # if ctx is too long, split on a tag that separates each example
- description, remainder = ctx.split(self.END_OF_DESCRIPTION, 1)
- ctxs = remainder.split(self.START_OF_FEWSHOT)
-
- # if there is no example and still the description + QA prompt is too long, fail
- if len(ctxs) < 2:
- raise ValueError(
- f"description + QA prompt with no example (0-shot) doesn't fit in max_length. ctx: {ctx}"
- )
-
- # delete the first example, the last includes QA prompt to be answered by lm
- del ctxs[1]
-
- new_ctx = self.END_OF_DESCRIPTION.join(
- [description, self.START_OF_FEWSHOT.join(ctxs)]
- )
- # recur
- return self.preprocess_ctx(new_ctx, max_length)
-
class JAQKETV2WithRinnaBilingualInstructionSFT(JAQKETV2WithRinnaInstructionSFT):
"""
diff --git a/lm_eval/tasks/ja/jcommonsenseqa.py b/lm_eval/tasks/ja/jcommonsenseqa.py
index 0cc6e91bd6..c7d94dc95a 100644
--- a/lm_eval/tasks/ja/jcommonsenseqa.py
+++ b/lm_eval/tasks/ja/jcommonsenseqa.py
@@ -8,6 +8,9 @@
Homepage: https://github.com/yahoojapan/JGLUE
"""
import os
+import warnings
+import time
+
from lm_eval.base import MultipleChoiceTask, rf
import numpy as np
@@ -118,6 +121,7 @@ class JCommonsenseQAWithFintanPrompt(JCommonsenseQA):
DESCRIPTION = (
"質問と回答の選択肢を入力として受け取り、選択肢から回答を選択してください。なお、回答は選択肢の番号(例:0)でするものとします。 \n\n"
)
+ DID_WARNING = False
def doc_to_text(self, doc):
"""
@@ -125,6 +129,14 @@ def doc_to_text(self, doc):
選択肢:0.choice0,1.choice1, ...,4.choice4
回答:
"""
+ if not self.DID_WARNING:
+ warnings.warn(
+ "#" * 100
+ + "\n\nprompt version `0.2` for JCommonsenseQA tends to output low scores! We highly recommend using `0.2.1` instead!\n\n"
+ + "#" * 100
+ )
+            type(self).DID_WARNING = True
+ time.sleep(5)
choices = ",".join(
[f"{idx}.{choice}" for idx, choice in enumerate(doc["choices"])]
)
@@ -134,6 +146,26 @@ def doc_to_target(self, doc):
return f"{doc['gold']}"
+class JCommonsenseQAWithFintanPromptV21(JCommonsenseQA):
+ VERSION = 1.1
+ PROMPT_VERSION = "0.2.1"
+ DESCRIPTION = "与えられた選択肢の中から、最適な答えを選んでください。 \n\n"
+
+ def doc_to_text(self, doc):
+ """
+ 与えられた選択肢の中から、最適な答えを選んでください。
+
+ 質問:{question}
+ 選択肢:
+ - {choice0}
+ - {choice4}
+ 回答:
+ """
+ choices = "\n".join([f"- {choice}" for choice in doc["choices"]])
+ input_text = f"質問:{doc['goal']}\n選択肢:\n{choices}\n回答:"
+ return input_text
+
+
class JCommonsenseQAWithJAAlpacaPrompt(JCommonsenseQA):
"""
This prompt format was inspired by the below data in fujiki/japanese_alpaca_data.
@@ -246,6 +278,7 @@ def doc_to_text(self, doc):
VERSIONS = [
JCommonsenseQA,
JCommonsenseQAWithFintanPrompt,
+ JCommonsenseQAWithFintanPromptV21,
JCommonsenseQAWithJAAlpacaPrompt,
JCommonsenseQAWithRinnaInstructionSFT,
JCommonsenseQAWithRinnaBilingualInstructionSFT,
diff --git a/lm_eval/tasks/ja/jnli.py b/lm_eval/tasks/ja/jnli.py
index 2f474d14f5..39ca927837 100644
--- a/lm_eval/tasks/ja/jnli.py
+++ b/lm_eval/tasks/ja/jnli.py
@@ -45,6 +45,7 @@ class JNLIWithFintanPrompt(BalancedMultipleChoiceTask):
+ "- そのいずれでもない場合はneutralと出力\n\n"
)
CHOICES = ["entailment", "contradiction", "neutral"]
+ SEP = "\n"
def has_training_docs(self):
return True
@@ -86,6 +87,9 @@ def construct_requests(self, doc, ctx):
lls = [
rf.loglikelihood(ctx, "{}".format(choice))[0] for choice in doc["choices"]
]
+ # this is only used for error analysis
+ if os.environ.get("DEBUG_MULTIPLECHOICE"):
+ lls.append(rf.greedy_until(ctx, [self.SEP]))
return lls
diff --git a/lm_eval/tasks/ja/marc_ja.py b/lm_eval/tasks/ja/marc_ja.py
index 46f4261843..67ca22b048 100644
--- a/lm_eval/tasks/ja/marc_ja.py
+++ b/lm_eval/tasks/ja/marc_ja.py
@@ -39,6 +39,7 @@ class MARCJaWithFintanPrompt(BalancedMultipleChoiceTask):
DATASET_NAME = "MARC-ja"
DESCRIPTION = "製品レビューをnegativeかpositiveのいずれかのセンチメントに分類してください。出力は小文字化してください。 \n\n"
CHOICES = ["positive", "negative"]
+ SEP = "\n"
def has_training_docs(self):
return True
@@ -80,7 +81,8 @@ def construct_requests(self, doc, ctx):
]
# this is only used for error analysis
- # lls.append(rf.greedy_until(ctx, [self.SEP]))
+ if os.environ.get("DEBUG_MULTIPLECHOICE"):
+ lls.append(rf.greedy_until(ctx, [self.SEP]))
return lls
diff --git a/models/rinna/rinna-japanese-gpt-neox-small/harness.sh b/models/rinna/rinna-japanese-gpt-neox-small/harness.sh
new file mode 100644
index 0000000000..19cd67bd11
--- /dev/null
+++ b/models/rinna/rinna-japanese-gpt-neox-small/harness.sh
@@ -0,0 +1,3 @@
+MODEL_ARGS="pretrained=rinna/japanese-gpt-neox-small,use_fast=False"
+TASK="jcommonsenseqa-1.1-0.2,jnli-1.1-0.2,marc_ja-1.1-0.2,jsquad-1.1-0.2,jaqket_v2-0.2-0.2,xlsum_ja,xwinograd_ja,mgsm"
+python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "3,3,3,2,1,1,0,5" --device "cuda" --output_path "models/rinna/rinna-japanese-gpt-neox-small/result.json"
diff --git a/models/rinna/rinna-japanese-gpt-neox-small/result.json b/models/rinna/rinna-japanese-gpt-neox-small/result.json
new file mode 100644
index 0000000000..3fc4f5fb1d
--- /dev/null
+++ b/models/rinna/rinna-japanese-gpt-neox-small/result.json
@@ -0,0 +1,71 @@
+{
+ "results": {
+ "jcommonsenseqa-1.1-0.2": {
+ "acc": 0.34226988382484363,
+ "acc_stderr": 0.014190160441497086,
+ "acc_norm": 0.25022341376228774,
+ "acc_norm_stderr": 0.012954152571429026
+ },
+ "jnli-1.1-0.2": {
+ "acc": 0.3011503697617091,
+ "acc_stderr": 0.00930063317508552,
+ "acc_norm": 0.30156121610517667,
+ "acc_norm_stderr": 0.009304239098715018
+ },
+ "marc_ja-1.1-0.2": {
+ "acc": 0.8335691545808277,
+ "acc_stderr": 0.0049539113964134655,
+ "acc_norm": 0.8335691545808277,
+ "acc_norm_stderr": 0.0049539113964134655
+ },
+ "xwinograd_ja": {
+ "acc": 0.5724713242961418,
+ "acc_stderr": 0.0159836786259061
+ },
+ "jsquad-1.1-0.2": {
+ "exact_match": 5.808194506978839,
+ "f1": 19.590196190503015
+ },
+ "jaqket_v2-0.2-0.2": {
+ "exact_match": 3.178694158075601,
+ "f1": 15.319117213447132
+ },
+ "xlsum_ja": {
+ "rouge2": 3.8544806224917294
+ },
+ "mgsm": {
+ "acc": 0.016,
+ "acc_stderr": 0.007951661188874339
+ }
+ },
+ "versions": {
+ "jcommonsenseqa-1.1-0.2": 1.1,
+ "jnli-1.1-0.2": 1.1,
+ "marc_ja-1.1-0.2": 1.1,
+ "jsquad-1.1-0.2": 1.1,
+ "jaqket_v2-0.2-0.2": 0.2,
+ "xlsum_ja": 1.0,
+ "xwinograd_ja": 1.0,
+ "mgsm": 1.0
+ },
+ "config": {
+ "model": "hf-causal",
+ "model_args": "pretrained=rinna/japanese-gpt-neox-small,use_fast=False",
+ "num_fewshot": [
+ 3,
+ 3,
+ 3,
+ 2,
+ 1,
+ 1,
+ 0,
+ 5
+ ],
+ "batch_size": null,
+ "device": "cuda",
+ "no_cache": false,
+ "limit": null,
+ "bootstrap_iters": 100000,
+ "description_dict": {}
+ }
+}