Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support ollama ai model #1001

Merged
merged 19 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion plugins/wasm-go/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ ENV GOPROXY=${GOPROXY}
ARG EXTRA_TAGS=""
ENV EXTRA_TAGS=${EXTRA_TAGS}

ARG PLUGIN_NAME=hello-world
ARG PLUGIN_NAME=ai-proxy

WORKDIR /workspace

Expand Down
2 changes: 1 addition & 1 deletion plugins/wasm-go/Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
PLUGIN_NAME ?= hello-world
PLUGIN_NAME ?= ai-proxy
Claire-w marked this conversation as resolved.
Show resolved Hide resolved
BUILDER_REGISTRY ?= higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/
REGISTRY ?= higress-registry.cn-hangzhou.cr.aliyuncs.com/plugins/
GO_VERSION ?= 1.19
Expand Down
9 changes: 9 additions & 0 deletions plugins/wasm-go/extensions/ai-proxy/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ Anthropic Claude 所对应的 `type` 为 `claude`。它特有的配置字段如
|-----------|--------|-----|-----|-------------------|
| `version` | string | 必填 | - | Claude 服务的 API 版本 |

#### Ollama

Ollama 所对应的 `type` 为 `ollama`。它特有的配置字段如下:

| 名称 | 数据类型 | 填写要求 | 默认值 | 描述 |
|-------------------|--------|------|-----|----------------------------------------------|
| `ollamaServerIP` | string | 必填 | - | Ollama 服务器的 IP 地址 |
CH3CHO marked this conversation as resolved.
Show resolved Hide resolved
| `ollamaServerPort` | string | 必填 | - | Ollama 服务器的端口号(Ollama 默认监听 11434,但此处必须显式填写) |
CH3CHO marked this conversation as resolved.
Show resolved Hide resolved

## 用法示例

### 使用 OpenAI 协议代理 Azure OpenAI 服务
Expand Down
2 changes: 1 addition & 1 deletion plugins/wasm-go/extensions/ai-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func main() {
}

func parseConfig(json gjson.Result, pluginConfig *config.PluginConfig, log wrapper.Log) error {
//log.Debugf("loading config: %s", json.String())
// log.Debugf("loading config: %s", json.String())

pluginConfig.FromJson(json)
if err := pluginConfig.Validate(); err != nil {
Expand Down
112 changes: 112 additions & 0 deletions plugins/wasm-go/extensions/ai-proxy/provider/ollama.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package provider

import (
"errors"
"fmt"

"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
)

// ollamaProvider is the provider for Ollama service.

const (
	// ollamaChatCompletionPath is the OpenAI-compatible chat-completion
	// endpoint exposed by the Ollama server; the upstream host:port comes
	// from the provider config (ollamaServerIP / ollamaServerPort).
	ollamaChatCompletionPath = "/v1/chat/completions"
)

// ollamaProviderInitializer builds ollamaProvider instances from the plugin config.
type ollamaProviderInitializer struct {
}

// ValidateConfig checks that the Ollama-specific settings required to reach
// the server (IP and port) are present in the provider config.
func (m *ollamaProviderInitializer) ValidateConfig(config ProviderConfig) error {
	switch {
	case config.ollamaServerIP == "":
		return errors.New("missing ollamaServerIP in provider config")
	case config.ollamaServerPort == "":
		return errors.New("missing ollamaServerPort in provider config")
	default:
		return nil
	}
}

// CreateProvider constructs an ollamaProvider whose upstream authority is
// "<ollamaServerIP>:<ollamaServerPort>" taken from the validated config.
func (m *ollamaProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
	provider := &ollamaProvider{
		config:        config,
		serviceDomain: config.ollamaServerIP + ":" + config.ollamaServerPort,
		contextCache:  createContextCache(&config),
	}
	return provider, nil
}

// ollamaProvider proxies OpenAI-protocol chat-completion requests to an
// Ollama server's OpenAI-compatible endpoint.
type ollamaProvider struct {
	// config holds the plugin-level provider configuration (model mapping, etc.).
	config ProviderConfig
	// serviceDomain is the upstream authority ("ip:port") built from
	// ollamaServerIP and ollamaServerPort at provider-creation time.
	serviceDomain string
	// contextCache caches the optional context file content; nil when no
	// context is configured.
	contextCache *contextCache
}

// GetProviderType returns the provider type identifier ("ollama").
func (m *ollamaProvider) GetProviderType() string {
	return providerTypeOllama
}

// OnRequestHeaders rewrites the request so it targets the configured Ollama
// server's OpenAI-compatible chat-completion endpoint. Only the chat
// completion API is supported; other APIs are rejected.
func (m *ollamaProvider) OnRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, log wrapper.Log) (types.Action, error) {
	if apiName != ApiNameChatCompletion {
		return types.ActionContinue, errUnsupportedApiName
	}
	// Point the request at the Ollama server; drop Content-Length because the
	// body may be rewritten in OnRequestBody.
	_ = util.OverwriteRequestHost(m.serviceDomain)
	_ = util.OverwriteRequestPath(ollamaChatCompletionPath)
	_ = proxywasm.RemoveHttpRequestHeader("Content-Length")
	return types.ActionContinue, nil
}

func (m *ollamaProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) (types.Action, error) {
if apiName != ApiNameChatCompletion {
return types.ActionContinue, errUnsupportedApiName
}

request := &chatCompletionRequest{}
if err := decodeChatCompletionRequest(body, request); err != nil {
CH3CHO marked this conversation as resolved.
Show resolved Hide resolved
return types.ActionContinue, err
}

model := request.Model
if model == "" {
return types.ActionContinue, errors.New("missing model in chat completion request")
}
mappedModel := getMappedModel(model, m.config.modelMapping, log)
if mappedModel == "" {
return types.ActionContinue, errors.New("model becomes empty after applying the configured mapping")
}
request.Model = mappedModel

var getcontenterr error
CH3CHO marked this conversation as resolved.
Show resolved Hide resolved
if m.contextCache != nil {
err := m.contextCache.GetContent(func(content string, err error) {
defer func() {
_ = proxywasm.ResumeHttpRequest()
}()
if err != nil {
log.Errorf("failed to load context file: %v", err)
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to load context file: %v", err))
}
insertContextMessage(request, content)
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
}, log)
if err == nil {
return types.ActionPause, nil
}
getcontenterr = err
} else {
if err := replaceJsonRequestBody(request, log); err != nil {
_ = util.SendResponse(500, util.MimeTypeTextPlain, fmt.Sprintf("failed to replace request body: %v", err))
}
_ = proxywasm.ResumeHttpRequest()
return types.ActionPause, nil
}


return types.ActionContinue, getcontenterr
}
12 changes: 11 additions & 1 deletion plugins/wasm-go/extensions/ai-proxy/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ const (
providerTypeBaichuan = "baichuan"
providerTypeYi = "yi"
providerTypeDeepSeek = "deepseek"
providerTypeOllama = "ollama"

protocolOpenAI = "openai"
protocolOriginal = "original"
Expand Down Expand Up @@ -60,6 +61,7 @@ var (
providerTypeBaichuan: &baichuanProviderInitializer{},
providerTypeYi: &yiProviderInitializer{},
providerTypeDeepSeek: &deepseekProviderInitializer{},
providerTypeOllama: &ollamaProviderInitializer{},
}
)

Expand Down Expand Up @@ -89,7 +91,7 @@ type ResponseBodyHandler interface {

type ProviderConfig struct {
// @Title zh-CN AI服务提供商
// @Description zh-CN AI服务提供商类型,目前支持的取值为:"moonshot"、"qwen"、"openai"、"azure"、"baichuan"、"yi"
// @Description zh-CN AI服务提供商类型,目前支持的取值为:"moonshot"、"qwen"、"openai"、"azure"、"baichuan"、"yi"、"ollama"
typ string `required:"true" yaml:"type" json:"type"`
// @Title zh-CN API Tokens
// @Description zh-CN 在请求AI服务时用于认证的API Token列表。不同的AI服务提供商可能有不同的名称。部分供应商只支持配置一个API Token(如Azure OpenAI)。
Expand All @@ -109,6 +111,12 @@ type ProviderConfig struct {
// @Title zh-CN 启用通义千问搜索服务
// @Description zh-CN 仅适用于通义千问服务,表示是否启用通义千问的互联网搜索功能。
qwenEnableSearch bool `required:"false" yaml:"qwenEnableSearch" json:"qwenEnableSearch"`
// @Title zh-CN Ollama Server IP/Domain
// @Description zh-CN 仅适用于 Ollama 服务。Ollama 服务器的 IP 地址。
ollamaServerIP string `required:"false" yaml:"ollamaServerIP" json:"ollamaServerIP"`
// @Title zh-CN Ollama Server Port
// @Description zh-CN 仅适用于 Ollama 服务。Ollama 服务器的端口号。
ollamaServerPort string `required:"false" yaml:"ollamaServerPort" json:"ollamaServerPort"`
// @Title zh-CN 模型名称映射表
// @Description zh-CN 用于将请求中的模型名称映射为目标AI服务商支持的模型名称。支持通过“*”来配置全局映射
modelMapping map[string]string `required:"false" yaml:"modelMapping" json:"modelMapping"`
Expand Down Expand Up @@ -137,6 +145,8 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
c.qwenFileIds = append(c.qwenFileIds, fileId.String())
}
c.qwenEnableSearch = json.Get("qwenEnableSearch").Bool()
c.ollamaServerIP = json.Get("ollamaServerIP").String()
c.ollamaServerPort = json.Get("ollamaServerPort").String()
c.modelMapping = make(map[string]string)
for k, v := range json.Get("modelMapping").Map() {
c.modelMapping[k] = v.String()
Expand Down