@@ -768,38 +768,36 @@ def embed(self, input: str, normalize: bool = True, truncate: bool = True, retur
768768 else :
769769 inputs = input
770770
771- def normalize (x ):
772- norm = np .linalg .norm (x )
773- return [v / norm for v in x ]
774-
775771 # reset batch
776772 self ._batch .reset ()
777773
778774 # decode and fetch embeddings
779775 data : List [Embedding ] = []
780- def decode_batch (n_seq ):
776+ def decode_batch (sizes ):
781777 llama_cpp .llama_kv_cache_clear (self ._ctx .ctx )
782778 self ._ctx .decode (self ._batch )
783779 self ._batch .reset ()
784780
785781 # store embeddings
786- for i in range ( n_seq ):
782+ for i , s in enumerate ( sizes ):
787783 embedding = llama_cpp .llama_get_embeddings_ith (self ._ctx .ctx , i )[:n_embd ]
788- if normalize :
789- embedding = normalize ( embedding )
784+ norm = np . linalg . norm ( embedding ) if normalize else s
785+ embedding = [ v / norm for v in embedding ]
790786 data .append (embedding )
791787
792788 # init state
793789 total_tokens = 0
794- p_batch = 0
795790 t_batch = 0
791+ s_sizes = []
796792
797793 # accumulate batches and encode
798794 for text in inputs :
799795 tokens = self .tokenize (text .encode ("utf-8" ))
800796 if truncate :
801797 tokens = tokens [:n_ctx ]
798+
802799 n_tokens = len (tokens )
800+ total_tokens += n_tokens
803801
804802 # check for overrun
805803 if n_tokens > n_ctx :
@@ -808,20 +806,18 @@ def decode_batch(n_seq):
808806 )
809807
810808 # time to eval batch
811- if n_tokens + t_batch > self ._n_ctx :
812- decode_batch (p_batch )
813- total_tokens += t_batch
814- p_batch = 0
809+ if t_batch + n_tokens > self ._n_ctx :
810+ decode_batch (s_sizes )
815811 t_batch = 0
812+ s_sizes = []
816813
817814 # add to batch
818- self ._batch .add_sequence (tokens , p_batch , False )
819- p_batch += 1
815+ self ._batch .add_sequence (tokens , len (s_sizes ), False )
820816 t_batch += n_tokens
817+ s_sizes .append (n_tokens )
821818
822819 # handle last batch
823- decode_batch (p_batch )
824- total_tokens += t_batch
820+ decode_batch (s_sizes )
825821
826822 if self .verbose :
827823 llama_cpp .llama_print_timings (self ._ctx .ctx )
0 commit comments