Merge branch 'jp-stable' into feature/argparse-refactor
polm committed Oct 24, 2023
2 parents 8324a7e + 6af55ef commit 25a9061
Showing 8 changed files with 154 additions and 135 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -20,7 +20,7 @@
| [cyberagent-open-calm-7b](https://huggingface.co/cyberagent/open-calm-7b) | 38.8 | 24.22 | 37.63 | 74.12 | 45.79 | 60.74 | 2.04 | 65.07 | 0.8 | [models/cyberagent/cyberagent-open-calm-7b/harness.sh](https://github.com/Stability-AI/lm-evaluation-harness/tree/jp-stable/models/cyberagent/cyberagent-open-calm-7b/harness.sh) |
| [cyberagent-open-calm-3b](https://huggingface.co/cyberagent/open-calm-3b) | 38.61 | 27.79 | 40.35 | 86.21 | 40.45 | 46.91 | 1.95 | 63.61 | 1.6 | [models/cyberagent/cyberagent-open-calm-3b/harness.sh](https://github.com/Stability-AI/lm-evaluation-harness/tree/jp-stable/models/cyberagent/cyberagent-open-calm-3b/harness.sh) |
| [rinna-japanese-gpt-1b](https://huggingface.co/rinna/japanese-gpt-1b) | 36.92 | 34.76 | 37.67 | 87.86 | 26.18 | 37.03 | 5.34 | 64.55 | 2 | [models/rinna/rinna-japanese-gpt-1b/harness.sh](https://github.com/Stability-AI/lm-evaluation-harness/tree/jp-stable/models/rinna/rinna-japanese-gpt-1b/harness.sh) |
| [rinna-japanese-gpt-neox-small](https://huggingface.co/rinna/japanese-gpt-neox-small) | 31.12 | 34.22 | 30.11 | 83.35 | 5.80 | 31.78 | 3.85 | 57.24 | 1.6 | [models/rinna/rinna-japanese-gpt-neox-small/harness.sh](https://github.com/Stability-AI/lm-evaluation-harness/tree/jp-stable/models/rinna/rinna-japanese-gpt-neox-small/harness.sh) |
## How to evaluate your model

1. git clone https://github.com/Stability-AI/lm-evaluation-harness/tree/jp-stable
14 changes: 14 additions & 0 deletions lm_eval/base.py
@@ -695,6 +695,9 @@ def process_results(self, doc, results):
return {
"acc": acc,
"acc_norm": acc_norm,
"details": {
"scores": results,
},
}

def higher_is_better(self):
@@ -722,6 +725,12 @@ class BalancedMultipleChoiceTask(MultipleChoiceTask):
def process_results(self, doc, results):
gold = doc["gold"]

# This isn't very clean, but it may be the best we can do since lm ops
# are submitted as an iterator for batching
response = None
if isinstance(results[-1], str):
response = results.pop()

pred = np.argmax(results)
acc = 1.0 if pred == gold else 0.0
completion_len = np.array([float(len(i)) for i in doc["choices"]])
@@ -733,6 +742,11 @@ def process_results(self, doc, results):
"balanced_acc": (acc, gold),
"mcc": (gold, pred),
"macro_f1": (gold, pred),
"details": {
"question": self.doc_to_text(doc),
"response": response,
"scores": results,
},
}

def higher_is_better(self):
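A note on the shape this handles: when the optional greedy request is enabled (see the `DEBUG_MULTIPLECHOICE` hunks in jnli.py and marc_ja.py below), `results` arrives as the per-choice loglikelihoods plus one trailing completion string. A minimal sketch with made-up scores:

```python
import numpy as np

# Made-up per-choice loglikelihood scores, with a trailing greedy completion
# string of the kind construct_requests appends when DEBUG_MULTIPLECHOICE is set.
results = [-4.2, -1.3, -7.9, "choice1"]
gold = 1

# Mirror the handling above: split off the trailing string before argmax.
response = results.pop() if isinstance(results[-1], str) else None

pred = int(np.argmax(results))
acc = 1.0 if pred == gold else 0.0
print(pred, acc, response)  # 1 1.0 choice1
```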
158 changes: 25 additions & 133 deletions lm_eval/tasks/ja/jaqket_v2.py
@@ -138,125 +138,41 @@ def doc_to_target(self, doc):
answer = answer_list[0]
return answer

-    def fewshot_context(
-        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
-    ):
-        """Returns a fewshot context string that is made up of a prepended description
-        (if provided), the `num_fewshot` number of examples, and an appended prompt example.
-        :param doc: str
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param num_fewshot: int
-            The number of fewshot examples to provide in the returned context string.
-        :param provide_description: bool
-            Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
-        :param rnd: random.Random
-            The pseudo-random number generator used to randomly sample examples.
-            WARNING: This is currently a required arg although it's optionalized with a default `None`.
-        :param description: str
-            The task's description that will be prepended to the fewshot examples.
-        :returns: str
-            The fewshot context.
-        """
-        assert (
-            rnd is not None
-        ), "A `random.Random` generator argument must be provided to `rnd`"
-        assert not provide_description, (
-            "The `provide_description` arg will be removed in future versions. To prepend "
-            "a custom description to the context, supply the corresponding string via the "
-            "`description` arg."
-        )
-        if provide_description is not None:
-            # nudge people to not specify it at all
-            print(
-                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
-            )
-
-        if hasattr(self, "FEWSHOT_SEP"):
-            FEWSHOT_SEP = self.FEWSHOT_SEP
-        elif hasattr(self, "SEP"):
-            FEWSHOT_SEP = f"{self.SEP}{self.SEP}"
-        else:
-            FEWSHOT_SEP = "\n\n"
-
-        if description:
-            description += FEWSHOT_SEP
-        elif hasattr(self, "DESCRIPTION"):
-            description = self.DESCRIPTION
-        else:
-            description = ""
-
-        if num_fewshot == 0:
-            labeled_examples = ""
-        else:
-            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
-            if self.has_training_docs():
-                fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
-            else:
-                if self._fewshot_docs is None:
-                    self._fewshot_docs = list(
-                        self.validation_docs()
-                        if self.has_validation_docs()
-                        else self.test_docs()
-                    )
-
-                fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
-
-                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
-                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
-
-            labeled_examples = (
-                FEWSHOT_SEP.join(
-                    [
-                        self.doc_to_answering_text(doc) + self.doc_to_target(doc)
-                        for doc in fewshotex
-                    ]
-                )
-                + FEWSHOT_SEP
-            )
-
-        example = self.doc_to_text(doc)
-        return description + labeled_examples + example
-
-    def preprocess_ctx(self, ctx, max_length):
-        # if ctx fits in max length, return
-        if len(self.tokenizer.encode(ctx)) <= max_length:
-            return ctx
-
-        # if ctx is too long, split on a tag that separates each example
-        description, remainder = ctx.split(self.FEWSHOT_SEP, 1)
-        ctxs = remainder.split(self.FEWSHOT_SEP)
-
-        # if there is no example and still the description + QA prompt is too long, fail
-        if len(ctxs) < 2:
-            raise ValueError(
-                f"description + QA prompt with no example (0-shot) doesn't fit in max_length. ctx: {ctx}"
-            )
-
-        # delete the first example, the last includes QA prompt to be answered by lm
-        del ctxs[0]
-
-        # recur
-        return self.preprocess_ctx(
-            self.FEWSHOT_SEP.join([description, *ctxs]), max_length
-        )
+    def fewshot_context(self, doc, num_fewshot, **kwargs):
+        max_num_tokens = max(
+            [len(self._tokenize(answer)) for answer in doc["answers"]["text"]]
+        )
+        max_length = self.max_length - max_num_tokens
+
+        # If the prompt is too long with fewshot examples, reduce the number of
+        # examples until it fits.
+        while num_fewshot >= 0:
+            ctx = super().fewshot_context(doc, num_fewshot, **kwargs)
+            if len(self._tokenize(ctx)) <= max_length:
+                doc["context"] = ctx
+                return ctx
+            num_fewshot -= 1
+
+        # if we got here then even 0 fewshot is too long
+        raise ValueError(
+            f"0-shot prompt is too long for max length {max_length}:\n{ctx}"
+        )
+
+    def _tokenize(self, text, **kwargs):
+        encode_fn = self.tokenizer.encode
+        if "add_special_tokens" in inspect.getfullargspec(encode_fn).args:
+            encode_params = dict(add_special_tokens=False)
+        else:
+            encode_params = {}
+        return encode_fn(text, **encode_params, **kwargs)

     def construct_requests(self, doc, ctx):
         if DYNAMIC_MAX_LENGTH == "false" or not hasattr(self.tokenizer, "encode"):
             continuation = rf.greedy_until(ctx, [self.SEP])
         else:
-            encode_fn = self.tokenizer.encode
-            if "add_special_tokens" in inspect.getfullargspec(encode_fn).args:
-                encode_params = dict(add_special_tokens=False)
-            else:
-                encode_params = {}
             max_num_tokens = max(
-                [
-                    len(encode_fn(answer, **encode_params))
-                    for answer in doc["answers"]["text"]
-                ]
+                [len(self._tokenize(answer)) for answer in doc["answers"]["text"]]
             )
-            ctx = self.preprocess_ctx(ctx, max_length=self.max_length - max_num_tokens)
             continuation = rf.greedy_until(ctx, [self.SEP], max_num_tokens)
         return continuation
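These hunks centralize the dynamic-length logic: budget tokens for the longest gold answer, then drop fewshot examples until the prompt fits. A self-contained sketch of that strategy (`fit_prompt` and the whitespace "tokenizer" are invented stand-ins, not the harness API):

```python
# Illustration only: reserve room for the longest gold answer, then drop
# fewshot examples until the prompt fits the context window.
def fit_prompt(build_ctx, tokenize, answers, max_length, num_fewshot):
    budget = max_length - max(len(tokenize(a)) for a in answers)
    while num_fewshot >= 0:
        ctx = build_ctx(num_fewshot)
        if len(tokenize(ctx)) <= budget:
            return ctx, num_fewshot
        num_fewshot -= 1
    raise ValueError(f"0-shot prompt is too long for max length {budget}")

# Toy usage with whitespace "tokenization".
ctx, k = fit_prompt(
    build_ctx=lambda k: " ".join(["Q: x A: y"] * k + ["Q: z A:"]),
    tokenize=str.split,
    answers=["Tokyo", "Kyoto"],
    max_length=16,
    num_fewshot=5,
)
print(k)  # 3 -- two fewshot examples had to be dropped to fit
```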

@@ -433,30 +349,6 @@ def doc_to_answering_text(self, doc):
qa_prompt = self.doc_to_qa_prompt(doc)
return f"ユーザー: {answer_candidate}{self.SEP}{qa_prompt}{self.SEP}システム: "

-    def preprocess_ctx(self, ctx, max_length):
-        # if ctx fits in max length, return
-        if len(self.tokenizer.encode(ctx)) <= max_length:
-            return ctx
-
-        # if ctx is too long, split on a tag that separates each example
-        description, remainder = ctx.split(self.END_OF_DESCRIPTION, 1)
-        ctxs = remainder.split(self.START_OF_FEWSHOT)
-
-        # if there is no example and still the description + QA prompt is too long, fail
-        if len(ctxs) < 2:
-            raise ValueError(
-                f"description + QA prompt with no example (0-shot) doesn't fit in max_length. ctx: {ctx}"
-            )
-
-        # delete the first example, the last includes QA prompt to be answered by lm
-        del ctxs[1]
-
-        new_ctx = self.END_OF_DESCRIPTION.join(
-            [description, self.START_OF_FEWSHOT.join(ctxs)]
-        )
-        # recur
-        return self.preprocess_ctx(new_ctx, max_length)


class JAQKETV2WithRinnaBilingualInstructionSFT(JAQKETV2WithRinnaInstructionSFT):
"""
33 changes: 33 additions & 0 deletions lm_eval/tasks/ja/jcommonsenseqa.py
@@ -8,6 +8,9 @@
Homepage: https://github.com/yahoojapan/JGLUE
"""
import os
import warnings
import time

from lm_eval.base import MultipleChoiceTask, rf
import numpy as np

@@ -118,13 +121,22 @@ class JCommonsenseQAWithFintanPrompt(JCommonsenseQA):
DESCRIPTION = (
"質問と回答の選択肢を入力として受け取り、選択肢から回答を選択してください。なお、回答は選択肢の番号(例:0)でするものとします。 \n\n"
)
DID_WARNING = False

def doc_to_text(self, doc):
"""
質問:question
選択肢:0.choice0,1.choice1, ...,4.choice4
回答:
"""
if not self.DID_WARNING:
warnings.warn(
"#" * 100
+ "\n\nprompt version `0.2` for JCommonsenseQA tends to output low scores! We highly recommend using `0.2.1` instead!\n\n"
+ "#" * 100
)
self.DID_WARNING = True
time.sleep(5)
choices = ",".join(
[f"{idx}.{choice}" for idx, choice in enumerate(doc["choices"])]
)
@@ -134,6 +146,26 @@ def doc_to_target(self, doc):
return f"{doc['gold']}"


class JCommonsenseQAWithFintanPromptV21(JCommonsenseQA):
VERSION = 1.1
PROMPT_VERSION = "0.2.1"
DESCRIPTION = "与えられた選択肢の中から、最適な答えを選んでください。 \n\n"

def doc_to_text(self, doc):
"""
与えられた選択肢の中から、最適な答えを選んでください。
質問:{question}
選択肢:
- {choice0}
- {choice4}
回答:
"""
choices = "\n".join([f"- {choice}" for choice in doc["choices"]])
input_text = f"質問:{doc['goal']}\n選択肢:\n{choices}\n回答:"
return input_text
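
For concreteness, this is the text `doc_to_text` builds for a made-up doc (question and choices are invented; the DESCRIPTION string is prepended separately by the harness):

```python
# Hypothetical JCommonsenseQA-style doc; the field names follow the code above.
doc = {
    "goal": "電車に乗るために必要なものは何ですか?",
    "choices": ["切符", "傘", "枕", "鍋", "靴下"],
    "gold": 0,
}

choices = "\n".join(f"- {choice}" for choice in doc["choices"])
print(f"質問:{doc['goal']}\n選択肢:\n{choices}\n回答:")
# 質問:電車に乗るために必要なものは何ですか?
# 選択肢:
# - 切符
# - 傘
# - 枕
# - 鍋
# - 靴下
# 回答:
```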


class JCommonsenseQAWithJAAlpacaPrompt(JCommonsenseQA):
"""
This prompt format was inspired by the below data in fujiki/japanese_alpaca_data.
@@ -246,6 +278,7 @@ def doc_to_text(self, doc):
VERSIONS = [
JCommonsenseQA,
JCommonsenseQAWithFintanPrompt,
JCommonsenseQAWithFintanPromptV21,
JCommonsenseQAWithJAAlpacaPrompt,
JCommonsenseQAWithRinnaInstructionSFT,
JCommonsenseQAWithRinnaBilingualInstructionSFT,
4 changes: 4 additions & 0 deletions lm_eval/tasks/ja/jnli.py
@@ -45,6 +45,7 @@ class JNLIWithFintanPrompt(BalancedMultipleChoiceTask):
+ "- そのいずれでもない場合はneutralと出力\n\n"
)
CHOICES = ["entailment", "contradiction", "neutral"]
SEP = "\n"

def has_training_docs(self):
return True
@@ -86,6 +87,9 @@ def construct_requests(self, doc, ctx):
lls = [
rf.loglikelihood(ctx, "{}".format(choice))[0] for choice in doc["choices"]
]
# this is only used for error analysis
if os.environ.get("DEBUG_MULTIPLECHOICE"):
lls.append(rf.greedy_until(ctx, [self.SEP]))
return lls
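
One caveat about this gate: `os.environ.get` returns a string and the check is plain truthiness, so any non-empty value enables the extra request. A quick sketch:

```python
import os

# Any non-empty value enables the debug request -- including "0" or "false".
os.environ["DEBUG_MULTIPLECHOICE"] = "0"
print(bool(os.environ.get("DEBUG_MULTIPLECHOICE")))  # True -- still enabled

# Unset the variable (or set it to the empty string) to disable it.
os.environ["DEBUG_MULTIPLECHOICE"] = ""
print(bool(os.environ.get("DEBUG_MULTIPLECHOICE")))  # False
```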


4 changes: 3 additions & 1 deletion lm_eval/tasks/ja/marc_ja.py
@@ -39,6 +39,7 @@ class MARCJaWithFintanPrompt(BalancedMultipleChoiceTask):
DATASET_NAME = "MARC-ja"
DESCRIPTION = "製品レビューをnegativeかpositiveのいずれかのセンチメントに分類してください。出力は小文字化してください。 \n\n"
CHOICES = ["positive", "negative"]
SEP = "\n"

def has_training_docs(self):
return True
@@ -80,7 +81,8 @@ def construct_requests(self, doc, ctx):
]

         # this is only used for error analysis
-        # lls.append(rf.greedy_until(ctx, [self.SEP]))
+        if os.environ.get("DEBUG_MULTIPLECHOICE"):
+            lls.append(rf.greedy_until(ctx, [self.SEP]))

return lls

3 changes: 3 additions & 0 deletions models/rinna/rinna-japanese-gpt-neox-small/harness.sh
@@ -0,0 +1,3 @@
MODEL_ARGS="pretrained=rinna/japanese-gpt-neox-small,use_fast=False"
TASK="jcommonsenseqa-1.1-0.2,jnli-1.1-0.2,marc_ja-1.1-0.2,jsquad-1.1-0.2,jaqket_v2-0.2-0.2,xlsum_ja,xwinograd_ja,mgsm"
python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "3,3,3,2,1,1,0,5" --device "cuda" --output_path "models/rinna/rinna-japanese-gpt-neox-small/result.json"
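
The comma-separated `--num_fewshot` values appear to pair positionally with the entries in TASK (eight of each here); a small sketch of that pairing, with both lists copied from the script:

```python
# Both strings copied verbatim from harness.sh above.
tasks = (
    "jcommonsenseqa-1.1-0.2,jnli-1.1-0.2,marc_ja-1.1-0.2,jsquad-1.1-0.2,"
    "jaqket_v2-0.2-0.2,xlsum_ja,xwinograd_ja,mgsm"
).split(",")
shots = [int(n) for n in "3,3,3,2,1,1,0,5".split(",")]

assert len(tasks) == len(shots)  # 8 tasks, 8 fewshot counts
for task, k in zip(tasks, shots):
    print(f"{task}: {k}-shot")
```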