diff --git a/docs/model_guide.md b/docs/model_guide.md
index 72a478e0f3..e068d0b950 100644
--- a/docs/model_guide.md
+++ b/docs/model_guide.md
@@ -66,7 +66,7 @@ All three request types take as input `requests` of type `list[Instance]` that h
 
 - It should return `(ll,) : Tuple[float]` , a.k.a. solely the *loglikelihood* of producing each piece of text given no starting input.
 
-To allow a model to be evaluated on all types of tasks, you will need to implement these three types of measurements (note that `loglikelihood_rolling` is a special case of `loglikelihood`). For a reference implementation, check out `lm_eval/models/huggingface.py` !
+To allow a model to be evaluated on all types of tasks, you will need to implement these three types of measurements (note that `loglikelihood_rolling` is a special case of `loglikelihood`). For a reference implementation, check out `lm_eval/models/huggingface.py` ! Additionally, check out `lm_eval.api.model.TemplateLM` for a class that abstracts away some commonly used functions across LM subclasses, or see if your model would lend itself well to subclassing the `lm_eval.models.huggingface.HFLM` class and overriding just the initialization or a couple methods!
 
 **Tip: be careful of indexing in loglikelihood!**
 
diff --git a/lm_eval/api/model.py b/lm_eval/api/model.py
index df829af592..7f93cc4394 100644
--- a/lm_eval/api/model.py
+++ b/lm_eval/api/model.py
@@ -247,3 +247,61 @@ def fn(requests):
 
     def get_cache_hook(self):
         return CacheHook(self)
+
+
+class TemplateLM(LM):
+    """
+    A class acting as intermediary between the LM base class
+    and boilerplate often included in other LM subclasses.
+    """
+
+    @property
+    @abc.abstractmethod
+    def eot_token_id(self):
+        pass
+
+    @abc.abstractmethod
+    def tok_encode(self, string: str, **kwargs):
+        pass
+
+    @abc.abstractmethod
+    def _loglikelihood_tokens(self, requests, **kwargs):
+        pass
+
+    def _encode_pair(self, context, continuation):
+        n_spaces = len(context) - len(context.rstrip())
+        if n_spaces > 0:
+            continuation = context[-n_spaces:] + continuation
+            context = context[:-n_spaces]
+
+        whole_enc = self.tok_encode(context + continuation, add_special_tokens=False)
+        context_enc = self.tok_encode(context, add_special_tokens=False)
+
+        context_enc_len = len(context_enc)
+        continuation_enc = whole_enc[context_enc_len:]
+
+        return context_enc, continuation_enc
+
+    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
+        new_reqs = []
+        for context, continuation in [req.args for req in requests]:
+            if context == "":
+                # end of text as context
+                context_enc, continuation_enc = (
+                    [self.eot_token_id],
+                    self.tok_encode(continuation),
+                )
+            else:
+                context_enc, continuation_enc = self._encode_pair(context, continuation)
+
+            new_reqs.append(((context, continuation), context_enc, continuation_enc))
+
+        return self._loglikelihood_tokens(new_reqs)
+
+    @abc.abstractmethod
+    def loglikelihood_rolling(self, requests) -> List[Tuple[float, bool]]:
+        pass
+
+    @abc.abstractmethod
+    def generate_until(self, requests) -> List[str]:
+        pass
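The `TemplateLM` class added above centralizes the request-preprocessing boilerplate: a subclass supplies `eot_token_id`, `tok_encode`, and `_loglikelihood_tokens` (plus the two remaining abstract request types) and inherits `loglikelihood` for free. As a minimal sketch of what a downstream subclass might look like (the class name `ToyLM` and its stand-in tokenizer are hypothetical, purely for illustration):

```python
from typing import List, Tuple

from lm_eval.api.model import TemplateLM


class ToyLM(TemplateLM):
    """Hypothetical subclass showing the minimal TemplateLM surface."""

    @property
    def eot_token_id(self) -> int:
        # Used as the stand-in context when a request's context is "".
        return 0

    def tok_encode(self, string: str, **kwargs) -> List[int]:
        # **kwargs absorbs the add_special_tokens=False that
        # TemplateLM._encode_pair passes along.
        return [ord(c) for c in string]  # toy byte-level "tokenizer"

    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
        # Each request is ((context, continuation), context_enc, continuation_enc);
        # a real model would score continuation_enc conditioned on context_enc.
        return [(0.0, True) for _ in requests]

    def loglikelihood_rolling(self, requests):
        return [0.0 for _ in requests]

    def generate_until(self, requests) -> List[str]:
        return ["" for _ in requests]
```

With this in place, calls to `ToyLM().loglikelihood(...)` route through the shared `_encode_pair` logic with no per-backend copy of it, which is exactly the duplication the hunks below delete.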
@register_model("hf-auto", "hf", "huggingface") -class HFLM(LM): +class HFLM(TemplateLM): """ An abstracted Huggingface model class. Enables usage with both models of `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes. @@ -780,39 +780,6 @@ def _select_cont_toks( return logits - def _encode_pair( - self, context: str, continuation: str - ) -> Tuple[List[int], List[int]]: - n_spaces = len(context) - len(context.rstrip()) - if n_spaces > 0: - continuation = context[-n_spaces:] + continuation - context = context[:-n_spaces] - - whole_enc = self.tok_encode(context + continuation, add_special_tokens=False) - context_enc = self.tok_encode(context, add_special_tokens=False) - - # whole_enc = self.tok_encode(context + continuation) - # context_enc = self.tok_encode(context, add_special_tokens=False) - context_enc_len = len(context_enc) - continuation_enc = whole_enc[context_enc_len:] - return context_enc, continuation_enc - - def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: - new_reqs = [] - for context, continuation in [req.args for req in requests]: - if context == "": - # end of text as context - context_enc, continuation_enc = ( - [self.eot_token_id], - self.tok_encode(continuation), - ) - else: - context_enc, continuation_enc = self._encode_pair(context, continuation) - - new_reqs.append(((context, continuation), context_enc, continuation_enc)) - - return self._loglikelihood_tokens(requests=new_reqs) - def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: loglikelihoods = [] diff --git a/lm_eval/models/neuron_optimum.py b/lm_eval/models/neuron_optimum.py index d20c3be1ac..ca1421e8fe 100644 --- a/lm_eval/models/neuron_optimum.py +++ b/lm_eval/models/neuron_optimum.py @@ -15,7 +15,7 @@ import lm_eval.models.utils from lm_eval import utils -from lm_eval.api.model import LM +from lm_eval.api.model import TemplateLM from lm_eval.api.registry import register_model from lm_eval.models.utils import stop_sequences_criteria @@ -172,7 +172,7 @@ def generate( @register_model("neuronx") -class NEURON_HF(LM): +class NEURON_HF(TemplateLM): """ Enables usage with on AWS Neuron using the HuggingFace Transformers + Transformers neuronx library. 
diff --git a/lm_eval/models/openai_completions.py b/lm_eval/models/openai_completions.py
index 2497aacb52..1ed09a7a92 100644
--- a/lm_eval/models/openai_completions.py
+++ b/lm_eval/models/openai_completions.py
@@ -8,7 +8,7 @@
 
 import lm_eval.models.utils
 from lm_eval import utils
-from lm_eval.api.model import LM
+from lm_eval.api.model import LM, TemplateLM
 from lm_eval.api.registry import register_model
 from lm_eval.models.utils import retry_on_specific_exceptions
 from lm_eval.utils import eval_logger
@@ -75,7 +75,7 @@ def completion():
 
 
 @register_model("openai-completions", "local-completions")
-class OpenaiCompletionsLM(LM):
+class OpenaiCompletionsLM(TemplateLM):
     _DEFAULT_MAX_LENGTH = 2048
 
     def __init__(
@@ -171,41 +171,12 @@ def device(self):
         # Isn't used because we override _loglikelihood_tokens
         raise NotImplementedError()
 
-    def tok_encode(self, string: str) -> List[int]:
+    def tok_encode(self, string: str, **kwargs) -> List[int]:
         return self.tokenizer.encode(string)
 
     def tok_decode(self, tokens: List[int]) -> str:
         return self.tokenizer.decode(tokens)
 
-    def _encode_pair(
-        self, context: str, continuation: str
-    ) -> Tuple[List[int], List[int]]:
-        n_spaces = len(context) - len(context.rstrip())
-        if n_spaces > 0:
-            continuation = context[-n_spaces:] + continuation
-            context = context[:-n_spaces]
-        whole_enc = self.tok_encode(context + continuation)
-        context_enc = self.tok_encode(context)
-        context_enc_len = len(context_enc)
-        continuation_enc = whole_enc[context_enc_len:]
-        return context_enc, continuation_enc
-
-    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
-        new_reqs = []
-        for context, continuation in [req.args for req in requests]:
-            if context == "":
-                # end of text as context
-                context_enc, continuation_enc = (
-                    [self.eot_token_id],
-                    self.tok_encode(continuation),
-                )
-            else:
-                context_enc, continuation_enc = self._encode_pair(context, continuation)
-
-            new_reqs.append(((context, continuation), context_enc, continuation_enc))
-
-        return self._loglikelihood_tokens(new_reqs)
-
     def _loglikelihood_tokens(
         self, requests, disable_tqdm: bool = False
     ) -> List[Tuple[float, bool]]:
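The one-line signature change to `OpenaiCompletionsLM.tok_encode` above is what lets a non-HuggingFace tokenizer coexist with the shared helper: `TemplateLM._encode_pair` calls `self.tok_encode(..., add_special_tokens=False)`, and the `**kwargs` simply absorbs that flag for encoders that do not understand it. The same pattern should work for any backend whose tokenizer lacks HF-style keyword arguments; a hypothetical sketch (the mixin class and its default model name are illustrative, not part of the library):

```python
from typing import List

import tiktoken


class TiktokenEncoderMixin:
    """Hypothetical mixin showing the **kwargs-absorbing pattern."""

    def __init__(self, model: str = "gpt-3.5-turbo"):
        self.tokenizer = tiktoken.encoding_for_model(model)

    def tok_encode(self, string: str, **kwargs) -> List[int]:
        # tiktoken has no add_special_tokens concept; **kwargs silently
        # swallows the flag that TemplateLM._encode_pair passes along.
        return self.tokenizer.encode(string)
```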
diff --git a/lm_eval/models/vllm_causallms.py b/lm_eval/models/vllm_causallms.py
index e0894befbd..164d38c0e4 100644
--- a/lm_eval/models/vllm_causallms.py
+++ b/lm_eval/models/vllm_causallms.py
@@ -5,7 +5,7 @@
 from tqdm import tqdm
 
 from lm_eval.api.instance import Instance
-from lm_eval.api.model import LM
+from lm_eval.api.model import TemplateLM
 from lm_eval.api.registry import register_model
 from lm_eval.models.utils import Collator, divide
 from lm_eval.utils import (
@@ -35,7 +35,7 @@ def run_inference_one_model(
 
 
 @register_model("vllm")
-class VLLM(LM):
+class VLLM(TemplateLM):
     _DEFAULT_MAX_LENGTH = 2048
 
     def __init__(
@@ -194,37 +194,6 @@ def _model_generate(
         )
         return outputs
 
-    def _encode_pair(
-        self, context: str, continuation: str
-    ) -> Tuple[List[int], List[int]]:
-        n_spaces = len(context) - len(context.rstrip())
-        if n_spaces > 0:
-            continuation = context[-n_spaces:] + continuation
-            context = context[:-n_spaces]
-
-        whole_enc = self.tok_encode(context + continuation, add_special_tokens=False)
-        context_enc = self.tok_encode(context, add_special_tokens=False)
-
-        context_enc_len = len(context_enc)
-        continuation_enc = whole_enc[context_enc_len:]
-        return context_enc, continuation_enc
-
-    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
-        new_reqs = []
-        for context, continuation in [req.args for req in requests]:
-            if context == "":
-                # end of text as context
-                context_enc, continuation_enc = (
-                    [self.eot_token_id],
-                    self.tok_encode(continuation),
-                )
-            else:
-                context_enc, continuation_enc = self._encode_pair(context, continuation)
-
-            new_reqs.append(((context, continuation), context_enc, continuation_enc))
-
-        return self._loglikelihood_tokens(new_reqs)
-
     def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
         loglikelihoods = []
 
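Taken together, the four backend diffs are pure deletions plus a base-class swap; the behavior now comes from one shared implementation. A quick sanity check of that claim, assuming `lm_eval` is importable at this revision:

```python
from lm_eval.api.model import TemplateLM
from lm_eval.models.huggingface import HFLM

# The backends no longer define loglikelihood themselves...
assert "loglikelihood" not in HFLM.__dict__
# ...they resolve it from the shared TemplateLM implementation.
assert HFLM.loglikelihood is TemplateLM.loglikelihood
print(HFLM.loglikelihood.__qualname__)  # -> "TemplateLM.loglikelihood"
```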