diff --git a/README.md b/README.md
index e0f8bf10af..c8c5cc07c3 100644
--- a/README.md
+++ b/README.md
@@ -141,6 +141,15 @@ python main.py \
     --tasks hellaswag
 ```
 
+GGUF or GGML quantized models can be loaded by using the `llama-cpp-python` server:
+
+```bash
+python main.py \
+    --model gguf \
+    --model_args base_url=http://localhost:8000 \
+    --tasks hellaswag
+```
+
 We support wildcards in task names, for example you can run all of the machine-translated lambada tasks via `--task lambada_openai_mt_*`. We currently only support one prompt per task, which we strive to make the "standard" as defined by the benchmark's authors. If you would like to study how varying prompts causes changes in the evaluation score, check out the [BigScience fork](https://github.com/bigscience-workshop/lm-evaluation-harness) of this repo. We are currently working on upstreaming this capability to `main`.
 
diff --git a/lm_eval/models/__init__.py b/lm_eval/models/__init__.py
index 8ca27fac81..21bfb73f52 100644
--- a/lm_eval/models/__init__.py
+++ b/lm_eval/models/__init__.py
@@ -4,6 +4,7 @@
 from . import huggingface
 from . import textsynth
 from . import dummy
+from . import gguf
 
 MODEL_REGISTRY = {
     "hf": gpt2.HFLM,
@@ -15,6 +16,7 @@
     "anthropic": anthropic_llms.AnthropicLM,
     "textsynth": textsynth.TextSynthLM,
     "dummy": dummy.DummyLM,
+    "gguf": gguf.GGUFLM,
 }
 
diff --git a/lm_eval/models/gguf.py b/lm_eval/models/gguf.py
new file mode 100644
index 0000000000..4e72e0fb30
--- /dev/null
+++ b/lm_eval/models/gguf.py
@@ -0,0 +1,149 @@
+import logging
+import time
+
+import requests
+from requests.exceptions import RequestException
+from tqdm import tqdm
+
+from lm_eval.base import BaseLM
+
+logger = logging.getLogger(__name__)
+
+
+def get_result(logprobs, context_length):
+    is_greedy = True
+    offsets = logprobs['text_offset']
+    tokens = logprobs['tokens']
+    tokens_logprobs = logprobs['token_logprobs']
+
+    # Skip past the tokens that belong to the context
+    idx = 0
+    while offsets[idx] < context_length:
+        idx += 1
+    # Exclude the final token: it is the single newly generated token
+    # (max_tokens=1), not part of the echoed continuation
+    continuation_logprobs = sum(tokens_logprobs[idx:-1])
+    for i in range(idx, len(tokens)):
+        token = tokens[i]
+        top_tokens = logprobs["top_logprobs"][i]
+        top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
+        if top_token != token:
+            is_greedy = False
+            break
+
+    return continuation_logprobs, is_greedy
+
+
+class GGUFLM(BaseLM):
+    def __init__(self, base_url, max_length=2048):
+        super().__init__()
+        self.base_url = base_url
+        self.logprobs = 10
+        self.temperature = 0.0
+        self._max_length = max_length
+
+    def gguf_completion(self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs):
+        for _ in range(retries):
+            try:
+                prompt = context
+                request = {'prompt': prompt, 'logprobs': self.logprobs,
+                           'temperature': self.temperature}
+                if continuation:
+                    prompt += continuation
+                    request.update({'prompt': prompt, 'max_tokens': 1, 'echo': True})
+                if stop is not None:
+                    request['stop'] = stop
+                response = requests.post(f"{self.base_url}/v1/completions", json=request)
+                response.raise_for_status()
+                return response.json()
+            except RequestException as e:
+                logger.error(f"RequestException: {e}")
+                time.sleep(delay)  # wait before retrying
+        else:
+            raise Exception(f"Failed to get a valid response after {retries} retries.")
+
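+    # loglikelihood scoring: resend context+continuation with echo=True and
+    # max_tokens=1 so the server returns per-token logprobs for the prompt;
+    # get_result() then isolates the continuation's share of those logprobs.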
response["choices"][0] + logprobs = choice.get("logprobs") + if logprobs and "token_logprobs" in logprobs and logprobs["token_logprobs"]: + logprob, is_greedy = get_result(logprobs, len(context)) + res.append((logprob, is_greedy)) + else: + logger.warning("Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list.") + else: + logger.error(f"Invalid response for loglikelihood. Response: {response}") + assert False + return res + + def greedy_until(self, requests): + if not requests: + return [] + + res = [] + for request in tqdm(requests): + inp = request[0] + request_args = request[1] + until = request_args["until"] + response = self.gguf_completion(context=inp, stop=until) + if response and "choices" in response and response["choices"]: + choice = response["choices"][0] + if "text" in choice: + generated_text = choice["text"].strip() + res.append(generated_text) + else: + logger.error(f"Invalid response for greedy_until. Response: {response}") + res.append(None) # Add default value in case of error + else: + logger.error(f"Invalid response for greedy_until. Response: {response}") + res.append(None) # Add default value in case of error + return res + + def loglikelihood_rolling(self, requests): + raise NotImplementedError("loglikelihood_rolling not yet supported for GGUF models") + + def _model_call(self, inps): + # Placeholder implementation + raise NotImplementedError() + + def _model_generate(self, context, max_length, eos_token_id): + # Placeholder implementation + raise NotImplementedError() + + def tok_encode(self, string: str): + raise NotImplementedError() + + def tok_decode(self, tokens): + raise NotImplementedError() + + @property + def batch_size(self): + # Placeholder implementation + raise NotImplementedError() + + @property + def device(self): + # Placeholder implementation + raise NotImplementedError() + + @property + def eot_token_id(self): + # Placeholder implementation + raise NotImplementedError() + + def max_length(self): + return self.max_length + + @property + def max_gen_toks(self): + # Placeholder implementation + raise NotImplementedError() diff --git a/tests/test_gguf.py b/tests/test_gguf.py new file mode 100644 index 0000000000..6db1556893 --- /dev/null +++ b/tests/test_gguf.py @@ -0,0 +1,79 @@ +import unittest +from unittest.mock import patch +import hashlib +import json +import os +import pickle +from lm_eval.models.gguf import GGUFLM + +base_url = "https://matthoffner-ggml-llm-api.hf.space" + +def gguf_completion_mock(base_url, **kwargs): + # Generate a hash from the parameters + hash_kwargs = {'base_url': base_url, **kwargs} + hash = hashlib.sha256(json.dumps(hash_kwargs, sort_keys=True).encode('utf-8')).hexdigest() + + fname = f"./tests/testdata/ggml_test_{hash}.pkl" + + if os.path.exists(fname): + with open(fname, "rb") as fh: + return pickle.load(fh) + else: + print("The file does not exist, attempting to write...") + if 'stop' in kwargs: + result = {"choices": [{"text": f"generated text until {kwargs['stop']}", "logprobs": {"token_logprobs": [-1.2345]}, "finish_reason": "length"}]} + else: + result = {"choices": [{"logprobs": {"token_logprobs": [-1.2345]}, "finish_reason": "length"}]} + + try: + os.makedirs(os.path.dirname(fname), exist_ok=True) + print('Writing file at', fname) + with open(fname, "wb") as fh: + pickle.dump(result, fh) + print('File written successfully') + except Exception as e: + print('File writing failed:', e) + + return result + + +class GGUFLMTest(unittest.TestCase): + 
+class GGUFLMTest(unittest.TestCase):
+    @patch('lm_eval.models.gguf.GGUFLM.gguf_completion', side_effect=gguf_completion_mock)
+    def test_loglikelihood(self, gguf_completion_mock):
+        lm = GGUFLM(base_url)
+
+        # Test loglikelihood
+        requests = [("context1", "continuation1"), ("context2", "continuation2")]
+        res = lm.loglikelihood(requests)
+
+        # Assert the loglikelihood response is correct
+        expected_res = [(logprob, True) for logprob in [-1.2345, -1.2345]]
+        self.assertEqual(res, expected_res)
+
+    @patch('lm_eval.models.gguf.GGUFLM.gguf_completion', side_effect=gguf_completion_mock)
+    def test_greedy_until(self, gguf_completion_mock):
+        lm = GGUFLM(base_url)
+
+        # Test greedy_until
+        requests = [("input1", {"until": "stop1"}), ("input2", {"until": "stop2"})]
+        res = lm.greedy_until(requests)
+
+        # Assert the greedy_until response is correct
+        expected_res = ["generated text until stop1", "generated text until stop2"]
+        self.assertEqual(res, expected_res)
+
+    @patch('lm_eval.models.gguf.GGUFLM.gguf_completion', side_effect=gguf_completion_mock)
+    def test_loglikelihood_rolling(self, gguf_completion_mock):
+        lm = GGUFLM(base_url)
+
+        # loglikelihood_rolling is not implemented for GGUF models yet,
+        # so assert that it raises instead of returning scores
+        requests = ["input1", "input2"]
+        with self.assertRaises(NotImplementedError):
+            lm.loglikelihood_rolling(requests)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/testdata/ggml_test_01d366e32dd8ae86bd079b6822814dcafad69a9082e4cf4db9633eaad47933c2.pkl b/tests/testdata/ggml_test_01d366e32dd8ae86bd079b6822814dcafad69a9082e4cf4db9633eaad47933c2.pkl
new file mode 100644
index 0000000000..c42af66f2c
Binary files /dev/null and b/tests/testdata/ggml_test_01d366e32dd8ae86bd079b6822814dcafad69a9082e4cf4db9633eaad47933c2.pkl differ
diff --git a/tests/testdata/ggml_test_04e9938f35d50bceb56453089ce5c7a0738ac878d40ded36f8a1fc170ab54b18.pkl b/tests/testdata/ggml_test_04e9938f35d50bceb56453089ce5c7a0738ac878d40ded36f8a1fc170ab54b18.pkl
new file mode 100644
index 0000000000..affc72b00b
Binary files /dev/null and b/tests/testdata/ggml_test_04e9938f35d50bceb56453089ce5c7a0738ac878d40ded36f8a1fc170ab54b18.pkl differ
diff --git a/tests/testdata/ggml_test_18f981234b6b471823bff9f977aad7f72439244c6c4d5f091b1f0984a71c8f11.pkl b/tests/testdata/ggml_test_18f981234b6b471823bff9f977aad7f72439244c6c4d5f091b1f0984a71c8f11.pkl
new file mode 100644
index 0000000000..affc72b00b
Binary files /dev/null and b/tests/testdata/ggml_test_18f981234b6b471823bff9f977aad7f72439244c6c4d5f091b1f0984a71c8f11.pkl differ
diff --git a/tests/testdata/ggml_test_941e4a484a2f5d4d99b45084003946423f63cc2955e9400f7153a51cbed9470a.pkl b/tests/testdata/ggml_test_941e4a484a2f5d4d99b45084003946423f63cc2955e9400f7153a51cbed9470a.pkl
new file mode 100644
index 0000000000..aaa8a2a6f1
Binary files /dev/null and b/tests/testdata/ggml_test_941e4a484a2f5d4d99b45084003946423f63cc2955e9400f7153a51cbed9470a.pkl differ
diff --git a/tests/testdata/ggml_test_c28e46a48a7076dc266da0a1a93be005d91162fb16950e31daee94d23d9e091e.pkl b/tests/testdata/ggml_test_c28e46a48a7076dc266da0a1a93be005d91162fb16950e31daee94d23d9e091e.pkl
new file mode 100644
index 0000000000..affc72b00b
Binary files /dev/null and b/tests/testdata/ggml_test_c28e46a48a7076dc266da0a1a93be005d91162fb16950e31daee94d23d9e091e.pkl differ
diff --git a/tests/testdata/ggml_test_e768167d669b5de84c99743c4b55bc85889d4894d1901d9f955626e0622059a6.pkl b/tests/testdata/ggml_test_e768167d669b5de84c99743c4b55bc85889d4894d1901d9f955626e0622059a6.pkl
new file mode 100644
index 0000000000..affc72b00b
Binary files /dev/null and b/tests/testdata/ggml_test_e768167d669b5de84c99743c4b55bc85889d4894d1901d9f955626e0622059a6.pkl differ
diff --git a/tests/testdata/ggml_test_e7804132de7c4a26d33e65d1853c1a97d3b4b364da6bf7fcee2883953b61e6c6.pkl b/tests/testdata/ggml_test_e7804132de7c4a26d33e65d1853c1a97d3b4b364da6bf7fcee2883953b61e6c6.pkl
new file mode 100644
index 0000000000..825a5b9d49
Binary files /dev/null and b/tests/testdata/ggml_test_e7804132de7c4a26d33e65d1853c1a97d3b4b364da6bf7fcee2883953b61e6c6.pkl differ
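
For reference, the payload that `GGUFLM.gguf_completion` builds for a loglikelihood request can be exercised directly against a running `llama-cpp-python` server (started, for example, with `python -m llama_cpp.server --model <path-to-model.gguf>`, which serves an OpenAI-compatible API on port 8000 by default). A minimal sketch, assuming a server is already listening at `http://localhost:8000`; the prompt strings here are made up:

```python
import requests

# Mirror the loglikelihood payload from GGUFLM.gguf_completion: echo the
# prompt's per-token logprobs back while generating only one new token.
context, continuation = "The capital of France is", " Paris"  # example inputs
payload = {
    "prompt": context + continuation,
    "logprobs": 10,
    "temperature": 0.0,
    "max_tokens": 1,
    "echo": True,
}
response = requests.post("http://localhost:8000/v1/completions", json=payload)
response.raise_for_status()

# These are the fields get_result() consumes to score the continuation.
logprobs = response.json()["choices"][0]["logprobs"]
print(logprobs["text_offset"], logprobs["tokens"], logprobs["token_logprobs"])
```

If the server returns the expected `text_offset`/`tokens`/`token_logprobs`/`top_logprobs` fields here, the harness's `--model gguf` path should work against the same `base_url`.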