# 使用 one-api embedding

测试包括下面几种情况：

- 直接使用 OpenAI API
    - 智谱ai ✅
- LlamaIndex OpenAI embedding
    - 智谱ai ❌ - LlamaIndex 会校验模型名是否属于 `OpenAIEmbeddingModelType`，非 OpenAI 模型名（如 `embedding-2`）直接报 `ValueError`
- LangChain OpenAI embedding
    - 智谱ai ❌ - one-api 返回 500（`zhipu only support one input`），应是 one-api 到智谱 AI 的接口兼容问题
- 继承并改写 LlamaIndex 的 OpenAIEmbedding，绕过模型名校验
    - 智谱ai ✅ - 可行

## OpenAI API

In [1]:
%%time

import os
from openai import OpenAI

# Credentials are read from the environment instead of being stored in the
# notebook. The key that used to be hard-coded here has been published with
# the notebook and should be rotated.
#   ONE_API_KEY  - one-api token (required)
#   ONE_API_BASE - one-api OpenAI-compatible endpoint (optional)
client = OpenAI(
    api_key=os.environ["ONE_API_KEY"],
    base_url=os.environ.get("ONE_API_BASE", "http://ape:3000/v1"),
)

print("----- embeddings request -----")
# Call the embeddings endpoint directly through the OpenAI SDK, routing the
# Zhipu "embedding-2" model via one-api.
resp = client.embeddings.create(
    model="embedding-2",
    input=["花椰菜又称菜花、花菜，是一种常见的蔬菜。"],
    encoding_format="float",
)
# Show only the first few components of the returned vector.
resp.data[0].embedding[:5]

----- embeddings request -----
CPU times: user 367 ms, sys: 36.4 ms, total: 403 ms
Wall time: 585 ms


[-0.030153707, 0.019338187, -0.059128232, -0.004209153, -0.013122211]

## LlamaIndex

In [2]:
%%time

import os

from llama_index.embeddings.openai import OpenAIEmbedding

# Expected to fail: the stock OpenAIEmbedding rejects "embedding-2" with
# "ValueError: 'embedding-2' is not a valid OpenAIEmbeddingModelType", i.e. it
# only accepts model names from its OpenAI model enum.
# Credentials come from the environment (ONE_API_KEY / ONE_API_BASE) instead
# of being hard-coded in the notebook.
embed_model = OpenAIEmbedding(
    model="embedding-2",
    api_key=os.environ["ONE_API_KEY"],
    api_base=os.environ.get("ONE_API_BASE", "http://ape:3000/v1"),
)

ValueError: 'embedding-2' is not a valid OpenAIEmbeddingModelType

## LangChain

In [6]:
%%time
%%capture

!pip install langchain-openai

# from langchain_openai import OpenAIEmbeddings

CPU times: user 14.9 ms, sys: 401 µs, total: 15.3 ms
Wall time: 2.67 s


In [8]:
%%time

import os

from langchain_openai import OpenAIEmbeddings
from llama_index.embeddings.langchain import LangchainEmbedding

# LangChain's OpenAI embeddings client pointed at the one-api endpoint.
# Credentials come from the environment (ONE_API_KEY / ONE_API_BASE) instead
# of being hard-coded in the notebook.
lc_embeddings = OpenAIEmbeddings(
    model="embedding-2",
    openai_api_base=os.environ.get("ONE_API_BASE", "http://ape:3000/v1"),
    openai_api_key=os.environ["ONE_API_KEY"],
    dimensions=1024,
)

# Wrap it so it can be used wherever LlamaIndex expects an embedding model.
embed_model = LangchainEmbedding(lc_embeddings)

# NOTE(review): this request currently fails with HTTP 500 from one-api
# ("invalid input length, zhipu only support one input"). A plausible cause is
# that OpenAIEmbeddings, with its default check_embedding_ctx_length=True,
# sends pre-tokenized token-id arrays rather than raw strings, which one-api's
# Zhipu adapter does not accept. Trying check_embedding_ctx_length=False is
# worthwhile. This is unverified.
embeddings = embed_model.get_text_embedding(
    "花椰菜又称菜花、花菜，是一种常见的蔬菜。"
)

embeddings[:10]

InternalServerError: Error code: 500 - {'error': {'message': 'invalid input length, zhipu only support one input (request id: 2024101123024045701951934009301)', 'type': 'one_api_error', 'param': '', 'code': 'convert_request_failed'}}

## workaround

尝试继承 LlamaIndex 的 `OpenAIEmbedding` 并重写 `__init__`，绕过模型名校验（`OpenAIEmbeddingModelType`）。

In [18]:
%%time

from llama_index.embeddings.openai import OpenAIEmbedding, OpenAIEmbeddingMode,OpenAIEmbeddingModelType
import httpx
from typing import Any, Dict, List, Optional, Tuple
from llama_index.core.callbacks.base import CallbackManager

def get_engine(
    mode: str,
    model: str,
) -> str:
    """Return the engine (model) name to request, without any validation.

    Stand-in for a (mode, model) lookup-table resolution: instead of mapping
    the pair through a table and rejecting unknown combinations, the model
    name is passed through verbatim. This lets names served by an
    OpenAI-compatible proxy, such as ``"embedding-2"``, be used.

    Args:
        mode: Embedding mode. Accepted only for signature compatibility and
            ignored.
        model: Model name to use.

    Returns:
        ``model``, unchanged.
    """
    return model


class MyOpenAIEmbedding(OpenAIEmbedding):
    """``OpenAIEmbedding`` that accepts model names outside the OpenAI enum.

    The stock class rejects ``model="embedding-2"`` with
    ``ValueError: 'embedding-2' is not a valid OpenAIEmbeddingModelType``.
    This subclass resolves its engines with the notebook-level ``get_engine``,
    which returns ``model`` unchanged. It hands the model to the parent
    constructor as ``model_name`` rather than ``model``.
    """

    def __init__(
        self,
        mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
        model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
        embed_batch_size: int = 100,
        dimensions: Optional[int] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        max_retries: int = 10,
        timeout: float = 60.0,
        reuse_client: bool = True,
        callback_manager: Optional[CallbackManager] = None,
        default_headers: Optional[Dict[str, str]] = None,
        http_client: Optional[httpx.Client] = None,
        async_http_client: Optional[httpx.AsyncClient] = None,
        num_workers: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Build the embedding model.

        Parameters mirror ``OpenAIEmbedding.__init__``, with these differences:

        - ``model`` may be any name the backend (e.g. one-api) accepts.
        - A ``model_name`` keyword in ``kwargs`` takes precedence over ``model``.
        - ``mode`` is only passed to ``get_engine``, which ignores it.
        """
        # Copy so the caller's dict is not mutated when "dimensions" is added.
        additional_kwargs = dict(additional_kwargs or {})
        if dimensions is not None:
            additional_kwargs["dimensions"] = dimensions

        api_key, api_base, api_version = self._resolve_credentials(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
        )

        # get_engine ignores mode, so the query and text engines always
        # resolve to the same name; resolve it once.
        query_engine = text_engine = get_engine(mode, model)

        if "model_name" in kwargs:
            # An explicit model_name overrides `model` for both engines.
            model_name = kwargs.pop("model_name")
            query_engine = text_engine = model_name
        else:
            model_name = model

        # `mode`/`model` are deliberately not forwarded; the model travels as
        # `model_name` instead.
        # NOTE(review): this relies on OpenAIEmbedding.__init__ honouring a
        # `model_name` kwarg and not validating its own default model.
        # Confirm against the installed llama-index-embeddings-openai version.
        super().__init__(
            embed_batch_size=embed_batch_size,
            dimensions=dimensions,
            callback_manager=callback_manager,
            model_name=model_name,
            additional_kwargs=additional_kwargs,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            max_retries=max_retries,
            reuse_client=reuse_client,
            timeout=timeout,
            default_headers=default_headers,
            num_workers=num_workers,
            **kwargs,
        )
        # Set after super().__init__ so these values win over whatever the
        # parent constructor assigned.
        self._query_engine = query_engine
        self._text_engine = text_engine

        # http_client / async_http_client are not passed to the parent above,
        # so clear any client objects and store the caller-supplied httpx
        # clients here.
        self._client = None
        self._aclient = None
        self._http_client = http_client
        self._async_http_client = async_http_client

import os

# Credentials come from the environment (ONE_API_KEY / ONE_API_BASE) instead
# of being hard-coded in the notebook.
embed_model = MyOpenAIEmbedding(
    model="embedding-2",
    api_key=os.environ["ONE_API_KEY"],
    api_base=os.environ.get("ONE_API_BASE", "http://ape:3000/v1"),
)

CPU times: user 4.47 ms, sys: 56 µs, total: 4.52 ms
Wall time: 4.46 ms


In [19]:
%%time

# Embed the same sentence used in the direct OpenAI API cell, so the leading
# components of both vectors can be compared side by side.
embeddings = embed_model.get_text_embedding("花椰菜又称菜花、花菜，是一种常见的蔬菜。")
embeddings[:5]

CPU times: user 55.6 ms, sys: 293 µs, total: 55.9 ms
Wall time: 249 ms


[-0.030153707, 0.019338187, -0.059128232, -0.004209153, -0.013122211]