Commit ba35aef

handle batched embeddings
1 parent 07a7837 commit ba35aef

File tree: 2 files changed (+111, -35 lines)


llama_cpp/_internals.py

Lines changed: 22 additions & 0 deletions
@@ -506,6 +506,14 @@ def __del__(self):
         self._llama_batch_free(self.batch)
         self.batch = None

+    def n_tokens(self) -> int:
+        assert self.batch is not None
+        return self.batch.n_tokens
+
+    def reset(self):
+        assert self.batch is not None
+        self.batch.n_tokens = 0
+
     def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
         assert self.batch is not None
         n_tokens = len(batch)
@@ -518,6 +526,20 @@ def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
             self.batch.logits[i] = logits_all
         self.batch.logits[n_tokens - 1] = True

+    def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
+        assert self.batch is not None
+        n_tokens = len(batch)
+        n_tokens0 = self.batch.n_tokens
+        self.batch.n_tokens += n_tokens
+        for i in range(n_tokens):
+            j = n_tokens0 + i
+            self.batch.token[j] = batch[i]
+            self.batch.pos[j] = i
+            self.batch.seq_id[j][0] = seq_id
+            self.batch.n_seq_id[j] = 1
+            self.batch.logits[j] = logits_all
+        self.batch.logits[n_tokens0 + n_tokens - 1] = True  # last token of this sequence

 class _LlamaTokenDataArray:
     def __init__(self, *, n_vocab: int):
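The new add_sequence packs several prompts into one llama_batch: tokens from all sequences sit flat in the same buffer, pos restarts at zero for each sequence, and seq_id keeps the sequences apart for the decoder. A minimal pure-Python sketch of that packing logic follows; ToyBatch is a hypothetical stand-in for the C struct, not part of the library:

    # ToyBatch mimics only the llama_batch fields that add_sequence touches.
    from typing import Sequence

    class ToyBatch:
        def __init__(self, capacity: int):
            self.n_tokens = 0
            self.token = [0] * capacity                    # all sequences, packed flat
            self.pos = [0] * capacity                      # position within each sequence
            self.seq_id = [[0] for _ in range(capacity)]
            self.n_seq_id = [0] * capacity
            self.logits = [False] * capacity               # which tokens produce output

        def reset(self) -> None:
            self.n_tokens = 0

        def add_sequence(self, tokens: Sequence[int], seq_id: int) -> None:
            start = self.n_tokens
            self.n_tokens += len(tokens)
            for i, tok in enumerate(tokens):
                j = start + i
                self.token[j] = tok
                self.pos[j] = i                            # restarts at 0 per sequence
                self.seq_id[j][0] = seq_id
                self.n_seq_id[j] = 1
                self.logits[j] = False
            self.logits[start + len(tokens) - 1] = True    # output only the last token

    batch = ToyBatch(capacity=8)
    batch.add_sequence([11, 12, 13], seq_id=0)
    batch.add_sequence([21, 22], seq_id=1)
    print(batch.pos[:batch.n_tokens])     # [0, 1, 2, 0, 1]
    print(batch.logits[:batch.n_tokens])  # [False, False, True, False, True]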

llama_cpp/llama.py

Lines changed: 89 additions & 35 deletions
@@ -717,10 +717,44 @@ def create_embedding(
         Returns:
             An embedding object.
         """
-        assert self._ctx.ctx is not None
         assert self._model.model is not None
         model_name: str = model if model is not None else self.model_path

+        # get numeric embeddings
+        embeds, total_tokens = self.embed(input, return_count=True)
+
+        # convert to CreateEmbeddingResponse
+        data = [
+            {
+                "object": "embedding",
+                "embedding": emb,
+                "index": idx,
+            }
+            for idx, emb in enumerate(embeds)
+        ]
+
+        return {
+            "object": "list",
+            "data": data,
+            "model": model_name,
+            "usage": {
+                "prompt_tokens": total_tokens,
+                "total_tokens": total_tokens,
+            },
+        }
+
+    def embed(
+        self,
+        input: Union[str, List[str]],
+        normalize: bool = True,
+        truncate: bool = True,
+        return_count: bool = False,
+    ):
+        """Embed a string or a list of strings.
+
+        Args:
+            input: The utf-8 encoded string (or list of strings) to embed.
+            normalize: Whether to L2-normalize each embedding.
+            truncate: Whether to truncate each input to the context window.
+            return_count: Whether to also return the total number of tokens processed.
+
+        Returns:
+            A list of embeddings, plus the token count when return_count is True.
+        """
+        assert self._ctx.ctx is not None
+        n_embd = self.n_embd()
+        n_ctx = self.n_ctx()
+
         if self.context_params.embedding == False:
             raise RuntimeError(
                 "Llama model must be created with embedding=True to call this method"
@@ -734,48 +768,68 @@ def create_embedding(
         else:
             inputs = input

+        # L2-normalization helper (named so it cannot shadow the `normalize` flag)
+        def _normalize(x):
+            norm = np.linalg.norm(x)
+            return [v / norm for v in x]
+
+        # reset batch
+        self._batch.reset()
+
+        # decode and fetch embeddings
-        data: List[Embedding] = []
+        data: List[List[float]] = []
+
+        def decode_batch(n_seq: int):
+            llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
+            self._ctx.decode(self._batch)
+            self._batch.reset()
+
+            # store embeddings
+            for i in range(n_seq):
+                embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[:n_embd]
+                if normalize:
+                    embedding = _normalize(embedding)
+                data.append(embedding)
+
+        # init state
         total_tokens = 0
-        for index, input in enumerate(inputs):
-            tokens = self.tokenize(input.encode("utf-8"), special=True)
-            self.reset()
-            self.eval(tokens)
+        p_batch = 0  # sequences in the pending batch
+        t_batch = 0  # tokens in the pending batch
+
+        # accumulate batches and encode
+        for text in inputs:
+            tokens = self.tokenize(text.encode("utf-8"))
+            if truncate:
+                tokens = tokens[:n_ctx]
             n_tokens = len(tokens)
-            total_tokens += n_tokens
-            embedding = llama_cpp.llama_get_embeddings(self._ctx.ctx)[
-                : llama_cpp.llama_n_embd(self._model.model)
-            ]

-            data.append(
-                {
-                    "object": "embedding",
-                    "embedding": embedding,
-                    "index": index,
-                }
-            )
-        if self.verbose:
-            llama_cpp.llama_print_timings(self._ctx.ctx)
+            # check for overrun
+            if n_tokens > n_ctx:
+                raise ValueError(
+                    f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
+                )

-        return {
-            "object": "list",
-            "data": data,
-            "model": model_name,
-            "usage": {
-                "prompt_tokens": total_tokens,
-                "total_tokens": total_tokens,
-            },
-        }
+            # flush the pending batch if this sequence will not fit
+            if n_tokens + t_batch > n_ctx:
+                decode_batch(p_batch)
+                total_tokens += t_batch
+                p_batch = 0
+                t_batch = 0

-    def embed(self, input: str) -> List[float]:
-        """Embed a string.
+            # add to batch
+            self._batch.add_sequence(tokens, p_batch, False)
+            p_batch += 1
+            t_batch += n_tokens

-        Args:
-            input: The utf-8 encoded string to embed.
+        # handle last batch
+        decode_batch(p_batch)
+        total_tokens += t_batch

-        Returns:
-            A list of embeddings
-        """
-        return list(map(float, self.create_embedding(input)["data"][0]["embedding"]))
+        if self.verbose:
+            llama_cpp.llama_print_timings(self._ctx.ctx)
+
+        if return_count:
+            return data, total_tokens
+        else:
+            return data

     def _create_completion(
         self,
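The accumulate-and-flush loop above is greedy: a sequence joins the pending batch unless it would push the batch past n_ctx tokens, in which case the pending batch is decoded first; a final decode handles whatever is left. A self-contained sketch of just that scheduling decision (plan_batches is illustrative, not a library function):

    from typing import List

    def plan_batches(lengths: List[int], n_ctx: int) -> List[List[int]]:
        batches: List[List[int]] = []
        current: List[int] = []
        used = 0
        for n in lengths:
            if n > n_ctx:  # mirrors the overrun check in embed()
                raise ValueError(f"Requested tokens ({n}) exceed context window of {n_ctx}")
            if used + n > n_ctx:  # flush before adding, like decode_batch(p_batch)
                batches.append(current)
                current, used = [], 0
            current.append(n)
            used += n
        batches.append(current)  # handle last batch
        return batches

    print(plan_batches([100, 200, 300, 250], n_ctx=512))
    # [[100, 200], [300], [250]]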
