## 采用 PaddleNLP 中的 Embedding

In [2]:
import os
from copy import copy
from typing import Dict, List, Optional, Tuple, Union
import numpy as np

from dotenv import load_dotenv, find_dotenv
# Populate os.environ from the .env file located by find_dotenv(), so the
# os.getenv(...) API-key lookups in the classes below can find their values.
_ = load_dotenv(dotenv_path=find_dotenv())


class BaseEmbeddings:
    """
    Base class for embeddings.

    Stores the model ``path`` and whether the backend is a remote API
    (``is_api``). Subclasses override :meth:`get_embedding` (and optionally
    :meth:`get_embeddings`); the base versions only raise ``NotImplementedError``.
    """
    def __init__(self, path: str, is_api: bool) -> None:
        self.path = path
        self.is_api = is_api

    def get_embedding(self, text: str, model: str) -> List[float]:
        """Return the embedding of a single ``text``. Must be overridden."""
        raise NotImplementedError

    def get_embeddings(self, text: List[str], model: str) -> List[List[float]]:
        """Return one embedding per string in ``text``. Must be overridden."""
        raise NotImplementedError

    @classmethod
    def cosine_similarity(
        cls,
        vector1: Union[List[float], np.ndarray],
        vector2: Union[List[float], np.ndarray],
    ) -> float:
        """
        Calculate the cosine similarity between two vectors.

        Always returns a plain Python ``float``. When either vector has zero
        norm the similarity is undefined; ``0.0`` is returned instead of
        dividing by zero.
        """
        dot_product = np.dot(vector1, vector2)
        magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2)
        if not magnitude:
            # Zero-norm input: return a float (not the int 0) to honour `-> float`.
            return 0.0
        # Convert the NumPy scalar so callers always get a builtin float.
        return float(dot_product / magnitude)
    

class OpenAIEmbedding(BaseEmbeddings):
    """
    Embeddings backed by the OpenAI embeddings API.

    Only API mode is implemented: with ``is_api=False`` no client is created
    and :meth:`get_embedding` raises ``NotImplementedError``.
    """
    def __init__(self, path: str = '', is_api: bool = True) -> None:
        super().__init__(path, is_api)
        if self.is_api:
            from openai import OpenAI
            # Pass the credentials to the constructor instead of assigning them
            # afterwards. Assigning `client.base_url = None` (OPENAI_BASE_URL
            # unset) is rejected by the client's URL setter. Passing None to the
            # constructor makes the client fall back to its default endpoint.
            self.client = OpenAI(
                api_key=os.getenv("OPENAI_API_KEY"),
                base_url=os.getenv("OPENAI_BASE_URL"),
            )

    def get_embedding(self, text: str, model: str = "text-embedding-3-large") -> List[float]:
        """
        Return the embedding of ``text`` computed by ``model``.

        Newlines are replaced by spaces before the request is sent.
        Raises ``NotImplementedError`` when the instance is not in API mode.
        """
        if self.is_api:
            text = text.replace("\n", " ")
            return self.client.embeddings.create(input=[text], model=model).data[0].embedding
        else:
            raise NotImplementedError

class JinaEmbedding(BaseEmbeddings):
    """
    Local embeddings from a Jina model loaded through ``transformers``.

    The model is loaded once, at construction time, and placed on the GPU
    when CUDA is available (otherwise on the CPU).
    """
    def __init__(self, path: str = 'jinaai/jina-embeddings-v2-base-zh', is_api: bool = False) -> None:
        super().__init__(path, is_api)
        self._model = self.load_model()

    def get_embedding(self, text: str) -> List[float]:
        """Encode a single ``text`` and return its vector as a list of floats."""
        # The model's encode() is called with a one-element batch; take row 0.
        batch_vectors = self._model.encode([text])
        single_vector = batch_vectors[0]
        return single_vector.tolist()

    def load_model(self):
        """Load the model at ``self.path`` and move it to the preferred device."""
        import torch
        from transformers import AutoModel
        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): trust_remote_code=True executes model code fetched with
        # the checkpoint; only point `path` at repositories you trust.
        pretrained = AutoModel.from_pretrained(self.path, trust_remote_code=True)
        return pretrained.to(target)

class ZhipuEmbedding(BaseEmbeddings):
    """
    Embeddings backed by the ZhipuAI embeddings API.

    Only API mode is implemented: with ``is_api=False`` no client is created
    and :meth:`get_embedding` raises ``NotImplementedError``.
    """
    def __init__(self, path: str = '', is_api: bool = True, embedding_dim: int = 1024) -> None:
        super().__init__(path, is_api)
        if self.is_api:
            from zhipuai import ZhipuAI
            self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))
        # Stored for callers; this class itself does not read it.
        # NOTE(review): presumably the vector size of the default model — confirm.
        self.embedding_dim = embedding_dim

    def get_embedding(self, text: str, model: str = "embedding-2") -> List[float]:
        """
        Return the embedding of ``text`` computed by ``model``.

        ``model`` defaults to ``"embedding-2"``, the previously hard-coded value.
        Raises ``NotImplementedError`` when the instance is not in API mode.
        """
        if not self.is_api:
            # No client exists in non-API mode; fail clearly instead of with an
            # AttributeError on `self.client`. This mirrors OpenAIEmbedding.
            raise NotImplementedError
        response = self.client.embeddings.create(
            model=model,
            input=text,
        )
        return response.data[0].embedding



In [None]:
from typing import List
from paddlenlp import Taskflow

class PaddleEmbedding(BaseEmbeddings):
    """
    Local embeddings computed with a PaddleNLP ``feature_extraction`` Taskflow.

    The Taskflow is created with ``return_tensors='np'``, so its ``"features"``
    output is converted with ``.tolist()``. Extra keyword arguments are
    forwarded to ``Taskflow``.
    """
    def __init__(
            self,
            model: str = "rocketqa-zh-base-query-encoder",
            batch_size: int = 1,
            **kwargs
    ) -> None:
        # Initialise the BaseEmbeddings attributes (`path`, `is_api`), which the
        # original never set. The model name plays the role of `path`; inference
        # runs locally, so `is_api` is False.
        super().__init__(path=model, is_api=False)
        self.client = Taskflow(
            "feature_extraction",
            model=model,
            batch_size=batch_size,
            return_tensors='np',
            **kwargs,
        )

    def get_embedding(self, text: str) -> List[float]:
        """Embed a single ``text`` and return its vector as a list of floats."""
        text_embeds = self.client([text])
        # One input -> take the first row of the features array.
        return text_embeds["features"][0].tolist()

    def get_embeddings(self, text: List[str]) -> List[List[float]]:
        """Embed every string in ``text``; returns one vector per input, in order."""
        if not text:
            # Nothing to embed: avoid handing an empty batch to the Taskflow.
            return []
        text_embeds = self.client(text)
        return text_embeds["features"].tolist()

In [None]:
# Build the PaddleNLP embedder, spelling out its default model and batch size.
embedding = PaddleEmbedding(model="rocketqa-zh-base-query-encoder", batch_size=1)

In [7]:
# Embed a short sample sentence ("Hello, I am 程序锅") and display the vector.
output = embedding.get_embedding(text="你好，我是程序锅")
output

[0.05002579465508461,
 0.17876456677913666,
 -0.03920114040374756,
 -0.25773388147354126,
 0.5590400099754333,
 -0.03011620044708252,
 -0.17774274945259094,
 0.25186994671821594,
 0.14149263501167297,
 -0.10307693481445312,
 0.046508852392435074,
 -0.028397087007761,
 -0.11078149825334549,
 -0.07999950647354126,
 0.00031697750091552734,
 0.032515719532966614,
 -0.0244218111038208,
 0.02001810073852539,
 0.2677673101425171,
 0.08093968033790588,
 0.012369967997074127,
 0.3121291399002075,
 -0.18647229671478271,
 0.027275685220956802,
 -0.151125967502594,
 0.1030336394906044,
 0.1875016689300537,
 -0.03331463038921356,
 0.06458021700382233,
 0.3084167242050171,
 -0.15596449375152588,
 -0.24720197916030884,
 -0.04501324146986008,
 0.3177567720413208,
 -0.2251533418893814,
 0.08213544636964798,
 -0.09071619063615799,
 -0.010918959975242615,
 -0.376307874917984,
 3.258854627609253,
 -0.1362328827381134,
 -0.1020917296409607,
 0.14859209954738617,
 0.324101984500885,
 0.2453329712152481,
 -0

In [8]:
# Dimensionality of the embedding computed above (this cell's recorded output is 768).
len(output)

768