@@ -717,10 +717,44 @@ def create_embedding(
717717        Returns: 
718718            An embedding object. 
719719        """ 
720-         assert  self ._ctx .ctx  is  not   None 
721720        assert  self ._model .model  is  not   None 
722721        model_name : str  =  model  if  model  is  not   None  else  self .model_path 
723722
723+         # get numeric embeddings 
724+         embeds , total_tokens  =  self .embed (input , return_count = True )
725+ 
726+         # convert to CreateEmbeddingResponse 
727+         data  =  [
728+             {
729+                 "object" : "embedding" ,
730+                 "embedding" : emb ,
731+                 "index" : idx ,
732+             } for  idx , emb  in  enumerate (embeds )
733+         ]
734+ 
735+         return  {
736+             "object" : "list" ,
737+             "data" : data ,
738+             "model" : model_name ,
739+             "usage" : {
740+                 "prompt_tokens" : total_tokens ,
741+                 "total_tokens" : total_tokens ,
742+             },
743+         }
744+ 
745+     def  embed (self , input : str , normalize : bool  =  True , truncate : bool  =  True , return_count : bool  =  False ) ->  List [float ]:
746+         """Embed a string. 
747+ 
748+         Args: 
749+             input: The utf-8 encoded string to embed. 
750+ 
751+         Returns: 
752+             A list of embeddings 
753+         """ 
754+         assert  self ._ctx .ctx  is  not   None 
755+         n_embd  =  self .n_embd ()
756+         n_ctx  =  self .n_ctx ()
757+ 
724758        if  self .context_params .embedding  ==  False :
725759            raise  RuntimeError (
726760                "Llama model must be created with embedding=True to call this method" 
@@ -734,48 +768,68 @@ def create_embedding(
734768        else :
735769            inputs  =  input 
736770
771+         def  normalize (x ):
772+             norm  =  np .linalg .norm (x )
773+             return  [v / norm  for  v  in  x ]
774+ 
775+         # reset batch 
776+         self ._batch .reset ()
777+ 
778+         # decode and fetch embeddings 
737779        data : List [Embedding ] =  []
780+         def  decode_batch (n_seq ):
781+             llama_cpp .llama_kv_cache_clear (self ._ctx .ctx )
782+             self ._ctx .decode (self ._batch )
783+             self ._batch .reset ()
784+ 
785+             # store embeddings 
786+             for  i  in  range (n_seq ):
787+                 embedding  =  llama_cpp .llama_get_embeddings_ith (self ._ctx .ctx , i )[:n_embd ]
788+                 if  normalize :
789+                     embedding  =  normalize (embedding )
790+                 data .append (embedding )
791+ 
792+         # init state 
738793        total_tokens  =  0 
739-         for  index , input  in  enumerate (inputs ):
740-             tokens  =  self .tokenize (input .encode ("utf-8" ), special = True )
741-             self .reset ()
742-             self .eval (tokens )
794+         p_batch  =  0 
795+         t_batch  =  0 
796+ 
797+         # accumulate batches and encode 
798+         for  text  in  inputs :
799+             tokens  =  self .tokenize (text .encode ("utf-8" ))
800+             if  truncate :
801+                 tokens  =  tokens [:n_ctx ]
743802            n_tokens  =  len (tokens )
744-             total_tokens  +=  n_tokens 
745-             embedding  =  llama_cpp .llama_get_embeddings (self ._ctx .ctx )[
746-                 : llama_cpp .llama_n_embd (self ._model .model )
747-             ]
748803
749-             data .append (
750-                 {
751-                     "object" : "embedding" ,
752-                     "embedding" : embedding ,
753-                     "index" : index ,
754-                 }
755-             )
756-         if  self .verbose :
757-             llama_cpp .llama_print_timings (self ._ctx .ctx )
804+             # check for overrun 
805+             if  n_tokens  >  n_ctx :
806+                 raise  ValueError (
807+                     f"Requested tokens ({ n_tokens }  ) exceed context window of { n_ctx }  " 
808+                 )
758809
759-         return  {
760-             "object" : "list" ,
761-             "data" : data ,
762-             "model" : model_name ,
763-             "usage" : {
764-                 "prompt_tokens" : total_tokens ,
765-                 "total_tokens" : total_tokens ,
766-             },
767-         }
810+             # time to eval batch 
811+             if  n_tokens  +  t_batch  >  self ._n_ctx :
812+                 decode_batch (p_batch )
813+                 total_tokens  +=  t_batch 
814+                 p_batch  =  0 
815+                 t_batch  =  0 
768816
769-     def  embed (self , input : str ) ->  List [float ]:
770-         """Embed a string. 
817+             # add to batch 
818+             self ._batch .add_sequence (tokens , p_batch , False )
819+             p_batch  +=  1 
820+             t_batch  +=  n_tokens 
771821
772-         Args: 
773-             input: The utf-8 encoded string to embed. 
822+         # handle last batch 
823+         decode_batch (p_batch )
824+         total_tokens  +=  t_batch 
774825
775-         Returns: 
776-             A list of embeddings 
777-         """ 
778-         return  list (map (float , self .create_embedding (input )["data" ][0 ]["embedding" ]))
826+         if  self .verbose :
827+             llama_cpp .llama_print_timings (self ._ctx .ctx )
828+ 
829+         if  return_count :
830+             return  data , total_tokens 
831+         else :
832+             return  data 
779833
780834    def  _create_completion (
781835        self ,
0 commit comments