From 5535d2ae7a98376284f3ace04a1ea53026761570 Mon Sep 17 00:00:00 2001 From: Tanner Hobson Date: Sat, 9 Dec 2023 15:33:46 -0500 Subject: [PATCH] Replace logits_to_logprobs implementation with numpy equivalent to llama.cpp See #990. This change makes the logits_to_logprobs function equivalent to the version in the llama.cpp repository. It uses numpy so it's much faster than the previous version. --- llama_cpp/llama.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 807654889..69f9ed9cc 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -2280,10 +2280,14 @@ def token_nl(self) -> int: return self._model.token_nl() @staticmethod - def logits_to_logprobs(logits: List[float]) -> List[float]: - exps = [math.exp(float(x)) for x in logits] - sum_exps = sum(exps) - return [math.log(x / sum_exps) for x in exps] + def logits_to_logprobs(logits: npt.NDArray[np.single]) -> npt.NDArray[np.single]: + maximum = np.max(logits) + tmp = np.subtract(logits, maximum, dtype=np.single) + np.exp(tmp, out=tmp) + normalizer = 1.0 / np.sum(tmp) + np.multiply(normalizer, tmp, out=tmp) + np.log(tmp, out=tmp) + return tmp @staticmethod def longest_token_prefix(a: Sequence[int], b: Sequence[int]):