
Commit d331b29

fix normalization issue
1 parent ba35aef

1 file changed: +13, -17 lines

llama_cpp/llama.py

@@ -768,38 +768,36 @@ def embed(self, input: str, normalize: bool = True, truncate: bool = True, retur
         else:
             inputs = input
 
-        def normalize(x):
-            norm = np.linalg.norm(x)
-            return [v/norm for v in x]
-
         # reset batch
         self._batch.reset()
 
         # decode and fetch embeddings
         data: List[Embedding] = []
-        def decode_batch(n_seq):
+        def decode_batch(sizes):
             llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
             self._ctx.decode(self._batch)
             self._batch.reset()
 
             # store embeddings
-            for i in range(n_seq):
+            for i, s in enumerate(sizes):
                 embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[:n_embd]
-                if normalize:
-                    embedding = normalize(embedding)
+                norm = np.linalg.norm(embedding) if normalize else s
+                embedding = [v/norm for v in embedding]
                 data.append(embedding)
 
         # init state
         total_tokens = 0
-        p_batch = 0
         t_batch = 0
+        s_sizes = []
 
         # accumulate batches and encode
         for text in inputs:
             tokens = self.tokenize(text.encode("utf-8"))
             if truncate:
                 tokens = tokens[:n_ctx]
+
             n_tokens = len(tokens)
+            total_tokens += n_tokens
 
             # check for overrun
             if n_tokens > n_ctx:
@@ -808,20 +806,18 @@ def decode_batch(n_seq):
                 )
 
             # time to eval batch
-            if n_tokens + t_batch > self._n_ctx:
-                decode_batch(p_batch)
-                total_tokens += t_batch
-                p_batch = 0
+            if t_batch + n_tokens > self._n_ctx:
+                decode_batch(s_sizes)
                 t_batch = 0
+                s_sizes = []
 
             # add to batch
-            self._batch.add_sequence(tokens, p_batch, False)
-            p_batch += 1
+            self._batch.add_sequence(tokens, len(s_sizes), False)
             t_batch += n_tokens
+            s_sizes.append(n_tokens)
 
         # hanlde last batch
-        decode_batch(p_batch)
-        total_tokens += t_batch
+        decode_batch(s_sizes)
 
         if self.verbose:
             llama_cpp.llama_print_timings(self._ctx.ctx)
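The second hunk replaces the bare sequence counter p_batch with the list s_sizes, which does double duty: len(s_sizes) supplies the next sequence id, and the stored per-sequence token counts flow into decode_batch so each embedding can be scaled by its own length. (total_tokens is now accumulated as each text is tokenized, so dropping the "total_tokens += t_batch" lines does not change the count.) The loop itself is a greedy packing pattern: queue sequences until the next one would overflow the context window, flush, then flush once more for the trailing partial batch. A sketch of that pattern with the llama.cpp calls stubbed out; pack_sequences, flush, and capacity are hypothetical names:

def pack_sequences(token_lists, capacity, flush):
    # Greedy batch packing as in embed(): decode the pending batch
    # before a new sequence would push it past the context window.
    t_batch = 0   # tokens queued in the current batch
    s_sizes = []  # token count of each queued sequence
    for tokens in token_lists:
        n_tokens = len(tokens)
        if t_batch + n_tokens > capacity:
            flush(s_sizes)
            t_batch, s_sizes = 0, []
        # real code also calls self._batch.add_sequence(tokens, len(s_sizes), False)
        t_batch += n_tokens
        s_sizes.append(n_tokens)
    flush(s_sizes)  # handle the final partial batch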
