From ba35aef093fc07e4141852bae6de4ee657a0ff2f Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Tue, 13 Feb 2024 14:59:21 -0600 Subject: [PATCH 1/4] handle batched embeddings --- llama_cpp/_internals.py | 22 +++++++ llama_cpp/llama.py | 124 ++++++++++++++++++++++++++++------------ 2 files changed, 111 insertions(+), 35 deletions(-) diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py index 3a71ef0fa..911ac1f57 100644 --- a/llama_cpp/_internals.py +++ b/llama_cpp/_internals.py @@ -506,6 +506,14 @@ def __del__(self): self._llama_batch_free(self.batch) self.batch = None + def n_tokens(self) -> int: + assert self.batch is not None + return self.batch.n_tokens + + def reset(self): + assert self.batch is not None + self.batch.n_tokens = 0 + def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool): assert self.batch is not None n_tokens = len(batch) @@ -518,6 +526,20 @@ def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool): self.batch.logits[i] = logits_all self.batch.logits[n_tokens - 1] = True + def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool): + assert self.batch is not None + n_tokens = len(batch) + n_tokens0 = self.batch.n_tokens + self.batch.n_tokens += n_tokens + for i in range(n_tokens): + j = n_tokens0 + i + self.batch.token[j] = batch[i] + self.batch.pos[j] = i + self.batch.seq_id[j][0] = seq_id + self.batch.n_seq_id[j] = 1 + self.batch.logits[j] = logits_all + self.batch.logits[n_tokens - 1] = True + class _LlamaTokenDataArray: def __init__(self, *, n_vocab: int): diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 8d726d38f..3e448ea71 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -717,10 +717,44 @@ def create_embedding( Returns: An embedding object. """ - assert self._ctx.ctx is not None assert self._model.model is not None model_name: str = model if model is not None else self.model_path + # get numeric embeddings + embeds, total_tokens = self.embed(input, return_count=True) + + # convert to CreateEmbeddingResponse + data = [ + { + "object": "embedding", + "embedding": emb, + "index": idx, + } for idx, emb in enumerate(embeds) + ] + + return { + "object": "list", + "data": data, + "model": model_name, + "usage": { + "prompt_tokens": total_tokens, + "total_tokens": total_tokens, + }, + } + + def embed(self, input: str, normalize: bool = True, truncate: bool = True, return_count: bool = False) -> List[float]: + """Embed a string. + + Args: + input: The utf-8 encoded string to embed. 
+ + Returns: + A list of embeddings + """ + assert self._ctx.ctx is not None + n_embd = self.n_embd() + n_ctx = self.n_ctx() + if self.context_params.embedding == False: raise RuntimeError( "Llama model must be created with embedding=True to call this method" @@ -734,48 +768,68 @@ def create_embedding( else: inputs = input + def normalize(x): + norm = np.linalg.norm(x) + return [v/norm for v in x] + + # reset batch + self._batch.reset() + + # decode and fetch embeddings data: List[Embedding] = [] + def decode_batch(n_seq): + llama_cpp.llama_kv_cache_clear(self._ctx.ctx) + self._ctx.decode(self._batch) + self._batch.reset() + + # store embeddings + for i in range(n_seq): + embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[:n_embd] + if normalize: + embedding = normalize(embedding) + data.append(embedding) + + # init state total_tokens = 0 - for index, input in enumerate(inputs): - tokens = self.tokenize(input.encode("utf-8"), special=True) - self.reset() - self.eval(tokens) + p_batch = 0 + t_batch = 0 + + # accumulate batches and encode + for text in inputs: + tokens = self.tokenize(text.encode("utf-8")) + if truncate: + tokens = tokens[:n_ctx] n_tokens = len(tokens) - total_tokens += n_tokens - embedding = llama_cpp.llama_get_embeddings(self._ctx.ctx)[ - : llama_cpp.llama_n_embd(self._model.model) - ] - data.append( - { - "object": "embedding", - "embedding": embedding, - "index": index, - } - ) - if self.verbose: - llama_cpp.llama_print_timings(self._ctx.ctx) + # check for overrun + if n_tokens > n_ctx: + raise ValueError( + f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}" + ) - return { - "object": "list", - "data": data, - "model": model_name, - "usage": { - "prompt_tokens": total_tokens, - "total_tokens": total_tokens, - }, - } + # time to eval batch + if n_tokens + t_batch > self._n_ctx: + decode_batch(p_batch) + total_tokens += t_batch + p_batch = 0 + t_batch = 0 - def embed(self, input: str) -> List[float]: - """Embed a string. + # add to batch + self._batch.add_sequence(tokens, p_batch, False) + p_batch += 1 + t_batch += n_tokens - Args: - input: The utf-8 encoded string to embed. 
+ # hanlde last batch + decode_batch(p_batch) + total_tokens += t_batch - Returns: - A list of embeddings - """ - return list(map(float, self.create_embedding(input)["data"][0]["embedding"])) + if self.verbose: + llama_cpp.llama_print_timings(self._ctx.ctx) + + if return_count: + return data, total_tokens + else: + return data def _create_completion( self, From d331b29f40ea3fc0fc47c045990f269320a8350f Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Tue, 13 Feb 2024 23:06:08 -0600 Subject: [PATCH 2/4] fix normalization issue --- llama_cpp/llama.py | 30 +++++++++++++----------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 3e448ea71..9cf7a5b9c 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -768,38 +768,36 @@ def embed(self, input: str, normalize: bool = True, truncate: bool = True, retur else: inputs = input - def normalize(x): - norm = np.linalg.norm(x) - return [v/norm for v in x] - # reset batch self._batch.reset() # decode and fetch embeddings data: List[Embedding] = [] - def decode_batch(n_seq): + def decode_batch(sizes): llama_cpp.llama_kv_cache_clear(self._ctx.ctx) self._ctx.decode(self._batch) self._batch.reset() # store embeddings - for i in range(n_seq): + for i, s in enumerate(sizes): embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[:n_embd] - if normalize: - embedding = normalize(embedding) + norm = np.linalg.norm(embedding) if normalize else s + embedding = [v/norm for v in embedding] data.append(embedding) # init state total_tokens = 0 - p_batch = 0 t_batch = 0 + s_sizes = [] # accumulate batches and encode for text in inputs: tokens = self.tokenize(text.encode("utf-8")) if truncate: tokens = tokens[:n_ctx] + n_tokens = len(tokens) + total_tokens += n_tokens # check for overrun if n_tokens > n_ctx: @@ -808,20 +806,18 @@ def decode_batch(n_seq): ) # time to eval batch - if n_tokens + t_batch > self._n_ctx: - decode_batch(p_batch) - total_tokens += t_batch - p_batch = 0 + if t_batch + n_tokens > self._n_ctx: + decode_batch(s_sizes) t_batch = 0 + s_sizes = [] # add to batch - self._batch.add_sequence(tokens, p_batch, False) - p_batch += 1 + self._batch.add_sequence(tokens, len(s_sizes), False) t_batch += n_tokens + s_sizes.append(n_tokens) # hanlde last batch - decode_batch(p_batch) - total_tokens += t_batch + decode_batch(s_sizes) if self.verbose: llama_cpp.llama_print_timings(self._ctx.ctx) From afc819da81695c7d3ca92eb3072cdfc18c206c9d Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 14 Feb 2024 04:16:29 -0500 Subject: [PATCH 3/4] fix type hints, ensure no breaking changes to embed --- llama_cpp/llama.py | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 9cf7a5b9c..b5fba6da3 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -721,15 +721,18 @@ def create_embedding( model_name: str = model if model is not None else self.model_path # get numeric embeddings - embeds, total_tokens = self.embed(input, return_count=True) + embeds: List[List[float]] + total_tokens: int + embeds, total_tokens = self.embed(input, return_count=True) # type: ignore # convert to CreateEmbeddingResponse - data = [ + data: List[Embedding] = [ { "object": "embedding", "embedding": emb, "index": idx, - } for idx, emb in enumerate(embeds) + } + for idx, emb in enumerate(embeds) ] return { @@ -742,7 +745,13 @@ def create_embedding( }, } - def embed(self, input: str, normalize: bool = True, truncate: 
bool = True, return_count: bool = False) -> List[float]: + def embed( + self, + input: Union[str, List[str]], + normalize: bool = True, + truncate: bool = True, + return_count: bool = False, + ): """Embed a string. Args: @@ -772,23 +781,26 @@ def embed(self, input: str, normalize: bool = True, truncate: bool = True, retur self._batch.reset() # decode and fetch embeddings - data: List[Embedding] = [] - def decode_batch(sizes): + data: List[List[float]] = [] + def decode_batch(sizes: List[int]): + assert self._ctx.ctx is not None llama_cpp.llama_kv_cache_clear(self._ctx.ctx) self._ctx.decode(self._batch) self._batch.reset() # store embeddings for i, s in enumerate(sizes): - embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[:n_embd] + embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[ + :n_embd + ] norm = np.linalg.norm(embedding) if normalize else s - embedding = [v/norm for v in embedding] + embedding: List[float] = [v / float(norm) for v in embedding] data.append(embedding) # init state total_tokens = 0 t_batch = 0 - s_sizes = [] + s_sizes: List[int] = [] # accumulate batches and encode for text in inputs: @@ -822,10 +834,12 @@ def decode_batch(sizes): if self.verbose: llama_cpp.llama_print_timings(self._ctx.ctx) + output = data[0] if isinstance(input, str) else data + if return_count: - return data, total_tokens + return output, total_tokens else: - return data + return output def _create_completion( self, From e3e04f85669200d15aeaca61e68d1a69ad633703 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 14 Feb 2024 04:24:03 -0500 Subject: [PATCH 4/4] Clear kv cache / reset internal state after embedding complete --- llama_cpp/llama.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index b5fba6da3..3e09a20b5 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -836,6 +836,9 @@ def decode_batch(sizes: List[int]): output = data[0] if isinstance(input, str) else data + llama_cpp.llama_kv_cache_clear(self._ctx.ctx) + self.reset() + if return_count: return output, total_tokens else:
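
A note on the scaling applied in decode_batch() as of [PATCH 2/4]: with
normalize=True each returned vector is divided by its L2 norm, while with
normalize=False it is divided by the sequence's token count. The helper below
is only an illustrative restatement of that post-processing, not code from the
patches; the standalone function and its name are assumptions made for the
example.

    import numpy as np

    def scale_embedding(raw, n_tokens, normalize=True):
        # Mirrors the per-sequence scaling in decode_batch() ([PATCH 2/4]):
        # normalize=True  -> divide by the L2 norm (unit-length vector),
        # normalize=False -> divide by the sequence length (token count).
        denom = np.linalg.norm(raw) if normalize else n_tokens
        return [v / float(denom) for v in raw]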
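
For reference, a minimal usage sketch of the embedding API as it stands after
[PATCH 4/4]. The model path and input strings are placeholders; the one hard
requirement from these patches is that the model is created with
embedding=True, otherwise embed() raises a RuntimeError.

    from llama_cpp import Llama

    # Placeholder GGUF path; any embedding-capable model applies.
    llm = Llama(model_path="./models/embedding-model.gguf", embedding=True)

    # A single string returns one vector (a list of floats), L2-normalized by default.
    vec = llm.embed("hello world")

    # A list of strings returns one vector per input; sequences are packed into
    # shared llama_batch decodes of up to n_ctx tokens each.
    vecs = llm.embed(["first document", "second document"], normalize=True, truncate=True)

    # create_embedding() wraps embed() in an OpenAI-style response with token usage.
    resp = llm.create_embedding(["first document", "second document"])
    print(resp["usage"]["total_tokens"])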