# Customizing LlamaIndex for OpenAI-Compatible Embeddings

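LlamaIndex validates the embedding model name against a lookup dictionary, so other cloud models cannot be used directly with `llama-index-embeddings-openai`, even when they are compatible with the OpenAI API.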

The code below defines a custom Embedding class that overrides the LlamaIndex implementation and removes the model name check.
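To make the failure mode concrete, here is a minimal, self-contained sketch of the kind of check being removed. The names `Mode`, `ModelType`, `_ENGINES`, `strict_get_engine`, and `relaxed_get_engine` are hypothetical stand-ins, not the library's own identifiers; the logic is modeled on the commented-out lines in `get_engine` further down and may differ from the installed LlamaIndex version.

```python
from enum import Enum


class Mode(str, Enum):
    """Hypothetical stand-in for the embedding mode enum."""
    TEXT_SEARCH = "text_search"


class ModelType(str, Enum):
    """Hypothetical stand-in for the list of known OpenAI model names."""
    ADA_002 = "text-embedding-ada-002"


# Hypothetical lookup table keyed by (mode, model), as in the original check.
_ENGINES = {(Mode.TEXT_SEARCH, ModelType.ADA_002): "text-embedding-ada-002"}


def strict_get_engine(mode: str, model: str) -> str:
    # Building the enum raises ValueError for any name that is not a member,
    # so third-party model names never reach the dictionary lookup.
    key = (Mode(mode), ModelType(model))
    if key not in _ENGINES:
        raise ValueError(f"Invalid mode, model combination: {key}")
    return _ENGINES[key]


def relaxed_get_engine(mode: str, model: str) -> str:
    # The customization: pass the model name through unchanged.
    return model


for name in ("text-embedding-ada-002", "embedding-2"):
    try:
        print(name, "->", strict_get_engine("text_search", name))
    except ValueError as exc:
        print(name, "rejected:", exc)
    print(name, "->", relaxed_get_engine("text_search", name), "(relaxed)")
```

The strict version accepts only names it already knows, while the relaxed version leaves validation to the remote API.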

Cloud models tested for compatibility:

- Zhipu AI ✅
- Alibaba Cloud Bailian ✅
- Doubao ✅

In [3]:
%%time

from dotenv import load_dotenv

load_dotenv()

CPU times: user 3.75 ms, sys: 489 µs, total: 4.24 ms
Wall time: 3.71 ms


True

## Custom Embedding

In [2]:
%%time

from llama_index.embeddings.openai import (
    OpenAIEmbedding,
    OpenAIEmbeddingMode,
    OpenAIEmbeddingModelType,
)
import httpx
from typing import Any, Dict, List, Optional, Tuple
from llama_index.core.callbacks.base import CallbackManager

def get_engine(
    mode: str,
    model: str,
    # mode_model_dict: Dict[Tuple[OpenAIEmbeddingMode, str], OpenAIEmbeddingModeModel],
) -> str:
    """Return the model name unchanged, skipping the (mode, model) lookup."""
    # Original LlamaIndex check, disabled so that any model name is accepted:
    # key = (OpenAIEmbeddingMode(mode), OpenAIEmbeddingModelType(model))
    # if key not in mode_model_dict:
    #     raise ValueError(f"Invalid mode, model combination: {key}")
    # return mode_model_dict[key].value
    return model


class MyOpenAIEmbedding(OpenAIEmbedding):
    """OpenAIEmbedding variant that accepts arbitrary model names."""

    def __init__(
        self,
        mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
        model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
        embed_batch_size: int = 100,
        dimensions: Optional[int] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        max_retries: int = 10,
        timeout: float = 60.0,
        reuse_client: bool = True,
        callback_manager: Optional[CallbackManager] = None,
        default_headers: Optional[Dict[str, str]] = None,
        http_client: Optional[httpx.Client] = None,
        async_http_client: Optional[httpx.AsyncClient] = None,
        num_workers: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        additional_kwargs = additional_kwargs or {}
        if dimensions is not None:
            additional_kwargs["dimensions"] = dimensions

        api_key, api_base, api_version = self._resolve_credentials(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
        )

        query_engine = get_engine(mode, model)
        text_engine = get_engine(mode, model)

        if "model_name" in kwargs:
            model_name = kwargs.pop("model_name")
            query_engine = text_engine = model_name
        else:
            model_name = model

        super().__init__(
            embed_batch_size=embed_batch_size,
            dimensions=dimensions,
            callback_manager=callback_manager,
            model_name=model_name,
            additional_kwargs=additional_kwargs,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            max_retries=max_retries,
            reuse_client=reuse_client,
            timeout=timeout,
            default_headers=default_headers,
            num_workers=num_workers,
            **kwargs,
        )
        self._query_engine = query_engine
        self._text_engine = text_engine

        self._client = None
        self._aclient = None
        self._http_client = http_client
        self._async_http_client = async_http_client

CPU times: user 3.68 ms, sys: 0 ns, total: 3.68 ms
Wall time: 3.64 ms


## Testing Cloud Embedding Models

### Zhipu AI

In [4]:
import os

api_key = os.getenv("ZHIPU_API_KEY")

embed_model = MyOpenAIEmbedding(
    api_key=api_key,
    model="embedding-2",
    api_base="https://open.bigmodel.cn/api/paas/v4"
)

# Input text: "Cauliflower, also known as caihua or huacai, is a common vegetable."
embeddings = embed_model.get_text_embedding(
    "花椰菜又称菜花、花菜，是一种常见的蔬菜。"
)
embeddings[:5]

[-0.030153707, 0.019338187, -0.059128232, -0.004209153, -0.013122211]

### Alibaba Cloud Bailian

In [5]:
%%time

api_key = os.getenv("ALIYUN_API_KEY")

embed_model = MyOpenAIEmbedding(
    api_key=api_key,
    model="text-embedding-v1",
    api_base="https://dashscope.aliyuncs.com/compatible-mode/v1"
)

embeddings = embed_model.get_text_embedding(
    "花椰菜又称菜花、花菜，是一种常见的蔬菜。"
)
embeddings[:5]

CPU times: user 50.9 ms, sys: 2.38 ms, total: 53.3 ms
Wall time: 178 ms


[-1.0765721797943115,
 -8.908791542053223,
 0.5645343661308289,
 1.4505490064620972,
 -0.9708723425865173]

### Doubao

In [6]:
%%time

api_key = os.getenv("DOUBAO_API_KEY")

embed_model = MyOpenAIEmbedding(
    api_key=api_key,
    model="ep-20240918183239-l2jqq",
    api_base="https://ark.cn-beijing.volces.com/api/v3"
)

embeddings = embed_model.get_text_embedding(
    "花椰菜又称菜花、花菜，是一种常见的蔬菜。"
)
embeddings[:5]

CPU times: user 53.8 ms, sys: 243 µs, total: 54 ms
Wall time: 148 ms


[-0.337890625, -1.890625, -1.9140625, -1.6875, -1.921875]
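
The three test cells above differ only in the environment variable, model name, and base URL. As a possible refactoring, those values could be collected in one table. This is only a sketch: `PROVIDERS` and `embedding_kwargs` are hypothetical names introduced here, and the values are copied from the cells above.

```python
import os
from typing import Dict

# Hypothetical provider table; values are taken from the test cells above.
PROVIDERS: Dict[str, Dict[str, str]] = {
    "zhipu": {
        "env_var": "ZHIPU_API_KEY",
        "model": "embedding-2",
        "api_base": "https://open.bigmodel.cn/api/paas/v4",
    },
    "aliyun": {
        "env_var": "ALIYUN_API_KEY",
        "model": "text-embedding-v1",
        "api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1",
    },
    "doubao": {
        "env_var": "DOUBAO_API_KEY",
        "model": "ep-20240918183239-l2jqq",
        "api_base": "https://ark.cn-beijing.volces.com/api/v3",
    },
}


def embedding_kwargs(provider: str) -> Dict[str, str]:
    """Build the keyword arguments for one provider, failing early on a missing key."""
    cfg = PROVIDERS[provider]
    api_key = os.getenv(cfg["env_var"])
    if not api_key:
        raise RuntimeError(f"Environment variable {cfg['env_var']} is not set")
    return {"api_key": api_key, "model": cfg["model"], "api_base": cfg["api_base"]}
```

With the class defined earlier, a cell would then reduce to `embed_model = MyOpenAIEmbedding(**embedding_kwargs("zhipu"))`. Checking for the missing key up front surfaces a configuration problem before any network request is made.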