Skip to content

Commit

Permalink
Merge pull request #617 from matthoffner/master
Browse files Browse the repository at this point in the history
Add GGML model
  • Loading branch information
haileyschoelkopf committed Nov 3, 2023
2 parents 382af8c + 05c2991 commit 008fc2a
Show file tree
Hide file tree
Showing 11 changed files with 232 additions and 0 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,15 @@ python main.py \
--tasks hellaswag
```

GGUF or GGML quantized models can be loaded by using `llama-cpp-python` server:

```bash
python main.py \
--model gguf \
--model_args base_url=http://localhost:8000 \
--tasks hellaswag
```

We support wildcards in task names, for example you can run all of the machine-translated lambada tasks via `--task lambada_openai_mt_*`.

We currently only support one prompt per task, which we strive to make the "standard" as defined by the benchmark's authors. If you would like to study how varying prompts causes changes in the evaluation score, check out the [BigScience fork](https://github.com/bigscience-workshop/lm-evaluation-harness) of this repo. We are currently working on upstreaming this capability to `main`.
Expand Down
2 changes: 2 additions & 0 deletions lm_eval/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from . import huggingface
from . import textsynth
from . import dummy
from . import gguf

MODEL_REGISTRY = {
"hf": gpt2.HFLM,
Expand All @@ -15,6 +16,7 @@
"anthropic": anthropic_llms.AnthropicLM,
"textsynth": textsynth.TextSynthLM,
"dummy": dummy.DummyLM,
"gguf": gguf.GGUFLM
}


Expand Down
142 changes: 142 additions & 0 deletions lm_eval/models/gguf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import requests
import logging
import time
from tqdm import tqdm
from requests.exceptions import RequestException
import transformers
from lm_eval.utils import Reorderer
from lm_eval.base import BaseLM

logger = logging.getLogger(__name__)


def get_result(logprobs, context_length):
is_greedy = True
offsets = logprobs['text_offset']
tokens = logprobs['tokens']
tokens_logprobs = logprobs['token_logprobs']

idx = 0
while offsets[idx] < context_length:
idx += 1
continuation_logprobs = sum(tokens_logprobs[idx:-1])
for i in range(idx, len(tokens)):
token = tokens[i]
top_tokens = logprobs["top_logprobs"][i]
top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
if top_token != token:
is_greedy = False
break

return continuation_logprobs, is_greedy


class GGUFLM(BaseLM):
def __init__(self, base_url, max_length=2048):
super().__init__()
self.base_url = base_url
self.logprobs = 10
self.temperature = 0.0
self.max_length = max_length

def gguf_completion(self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs):
for _ in range(retries):
try:
prompt = context
request = {'prompt': prompt, 'logprobs': self.logprobs,
'temperature': self.temperature}
if continuation:
prompt += continuation
request.update({'prompt': prompt, 'max_tokens': 1, 'echo': True})
if stop is not None:
request['stop'] = stop
response = requests.post(f"{self.base_url}/v1/completions", json=request)
response.raise_for_status()
return response.json()
except RequestException as e:
logger.error(f"RequestException: {e}")
time.sleep(delay) # wait before retrying
else:
raise Exception(f"Failed to get a valid response after {retries} retries.")

def loglikelihood(self, requests):
if not requests:
return []
res = []
for context, continuation in tqdm(requests):
response = self.gguf_completion(context=context, continuation=continuation)
if response and "choices" in response and response["choices"]:
choice = response["choices"][0]
logprobs = choice.get("logprobs")
if logprobs and "token_logprobs" in logprobs and logprobs["token_logprobs"]:
logprob, is_greedy = get_result(logprobs, len(context))
res.append((logprob, is_greedy))
else:
logger.warning("Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list.")
else:
logger.error(f"Invalid response for loglikelihood. Response: {response}")
assert False
return res

def greedy_until(self, requests):
if not requests:
return []

res = []
for request in tqdm(requests):
inp = request[0]
request_args = request[1]
until = request_args["until"]
response = self.gguf_completion(context=inp, stop=until)
if response and "choices" in response and response["choices"]:
choice = response["choices"][0]
if "text" in choice:
generated_text = choice["text"].strip()
res.append(generated_text)
else:
logger.error(f"Invalid response for greedy_until. Response: {response}")
res.append(None) # Add default value in case of error
else:
logger.error(f"Invalid response for greedy_until. Response: {response}")
res.append(None) # Add default value in case of error
return res

def loglikelihood_rolling(self, requests):
raise NotImplementedError("loglikelihood_rolling not yet supported for GGUF models")

def _model_call(self, inps):
# Placeholder implementation
raise NotImplementedError()

def _model_generate(self, context, max_length, eos_token_id):
# Placeholder implementation
raise NotImplementedError()

def tok_encode(self, string: str):
raise NotImplementedError()

def tok_decode(self, tokens):
raise NotImplementedError()

@property
def batch_size(self):
# Placeholder implementation
raise NotImplementedError()

@property
def device(self):
# Placeholder implementation
raise NotImplementedError()

@property
def eot_token_id(self):
# Placeholder implementation
raise NotImplementedError()

def max_length(self):
return self.max_length

@property
def max_gen_toks(self):
# Placeholder implementation
raise NotImplementedError()
79 changes: 79 additions & 0 deletions tests/test_gguf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import unittest
from unittest.mock import patch
import hashlib
import json
import os
import pickle
from lm_eval.models.gguf import GGUFLM

base_url = "https://matthoffner-ggml-llm-api.hf.space"

def gguf_completion_mock(base_url, **kwargs):
# Generate a hash from the parameters
hash_kwargs = {'base_url': base_url, **kwargs}
hash = hashlib.sha256(json.dumps(hash_kwargs, sort_keys=True).encode('utf-8')).hexdigest()

fname = f"./tests/testdata/ggml_test_{hash}.pkl"

if os.path.exists(fname):
with open(fname, "rb") as fh:
return pickle.load(fh)
else:
print("The file does not exist, attempting to write...")
if 'stop' in kwargs:
result = {"choices": [{"text": f"generated text until {kwargs['stop']}", "logprobs": {"token_logprobs": [-1.2345]}, "finish_reason": "length"}]}
else:
result = {"choices": [{"logprobs": {"token_logprobs": [-1.2345]}, "finish_reason": "length"}]}

try:
os.makedirs(os.path.dirname(fname), exist_ok=True)
print('Writing file at', fname)
with open(fname, "wb") as fh:
pickle.dump(result, fh)
print('File written successfully')
except Exception as e:
print('File writing failed:', e)

return result


class GGUFLMTest(unittest.TestCase):
@patch('lm_eval.models.gguf.GGUFLM.gguf_completion', side_effect=gguf_completion_mock)
def test_loglikelihood(self, gguf_completion_mock):
lm = GGUFLM(base_url)

# Test loglikelihood
requests = [("context1", "continuation1"), ("context2", "continuation2")]
res = lm.loglikelihood(requests)

# Assert the loglikelihood response is correct
expected_res = [(logprob, True) for logprob in [-1.2345, -1.2345]]
self.assertEqual(res, expected_res)

@patch('lm_eval.models.gguf.GGUFLM.gguf_completion', side_effect=gguf_completion_mock)
def test_greedy_until(self, gguf_completion_mock):
lm = GGUFLM(base_url)

# Test greedy_until
requests = [("input1", {"until": "stop1"}), ("input2", {"until": "stop2"})]
res = lm.greedy_until(requests)

# Assert the greedy_until response is correct
expected_res = ["generated text until stop1", "generated text until stop2"]
self.assertEqual(res, expected_res)

@patch('lm_eval.models.gguf.GGUFLM.gguf_completion', side_effect=gguf_completion_mock)
def test_loglikelihood_rolling(self, gguf_completion_mock):
lm = GGUFLM(base_url)

# Test loglikelihood_rolling
requests = ["input1", "input2"]
res = lm.loglikelihood_rolling(requests)

# Assert the loglikelihood_rolling response is correct
expected_res = [(-1.2345, True), (-1.2345, True)]
self.assertEqual(res, expected_res)


if __name__ == "__main__":
unittest.main()
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 comments on commit 008fc2a

Please sign in to comment.