Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add plugin: ai-rag #1038

Merged
merged 4 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions plugins/wasm-go/extensions/ai-rag/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
config.yaml
main.wasm
tmp/
22 changes: 22 additions & 0 deletions plugins/wasm-go/extensions/ai-rag/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# 简介
通过对接阿里云向量检索服务实现LLM-RAG,流程如图所示:

![](https://img.alicdn.com/imgextra/i1/O1CN01LuRVs41KhoeuzakeF_!!6000000001196-0-tps-1926-1316.jpg)

rinfx marked this conversation as resolved.
Show resolved Hide resolved
# 示例
[CEC-Corpus](https://github.com/shijiebei2009/CEC-Corpus) 数据集包含 332 篇突发事件的新闻报道的语料和标注数据,提取其原始的新闻稿文本,将其向量化后添加到阿里云向量检索服务。文本向量化的教程可以参考[《基于向量检索服务与灵积实现语义搜索》](https://help.aliyun.com/document_detail/2510234.html)。

以下为使用RAG进行增强的例子,原始请求为:
```
海南追尾事故,发生在哪里?原因是什么?人员伤亡情况如何?
```

未经过RAG插件处理LLM返回的结果为:
```
抱歉,作为AI模型,我无法实时获取和更新新闻事件的具体信息,包括地点、原因、人员伤亡等细节。对于此类具体事件,建议您查阅最新的新闻报道或官方通报以获取准确信息。您可以访问主流媒体网站、使用新闻应用或者关注相关政府部门的公告来获取这类动态资讯。
```

经过RAG插件处理后LLM返回的结果为:
```
海南追尾事故发生在海文高速公路文昌至海口方向37公里处。关于事故的具体原因,交警部门当时仍在进一步调查中,所以根据提供的信息无法确定事故的确切原因。人员伤亡情况是1人死亡(司机当场死亡),另有8人受伤(包括2名儿童和6名成人),所有受伤人员都被解救并送往医院进行治疗。
```
36 changes: 36 additions & 0 deletions plugins/wasm-go/extensions/ai-rag/dashscope/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package dashscope

// DashScope embedding service: Request
type Request struct {
Model string `json:"model"`
Input Input `json:"input"`
Parameter Parameter `json:"parameters"`
}

type Input struct {
Texts []string `json:"texts"`
}

type Parameter struct {
TextType string `json:"text_type"`
}

// DashScope embedding service: Response
type Response struct {
Output Output `json:"output"`
Usage Usage `json:"usage"`
RequestID string `json:"request_id"`
}

type Output struct {
Embeddings []Embedding `json:"embeddings"`
}

type Embedding struct {
Embedding []float32 `json:"embedding"`
TextIndex int32 `json:"text_index"`
}

type Usage struct {
TotalTokens int32 `json:"total_tokens"`
}
26 changes: 26 additions & 0 deletions plugins/wasm-go/extensions/ai-rag/dashvector/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package dashvector

// DashVecotor document search: Request
type Request struct {
TopK int32 `json:"topk"`
OutputFileds []string `json:"output_fileds"`
Vector []float32 `json:"vector"`
}

// DashVecotor document search: Response
type Response struct {
Code int32 `json:"code"`
RequestID string `json:"request_id"`
Message string `json:"message"`
Output []OutputObject `json:"output"`
}

type OutputObject struct {
ID string `json:"id"`
Fields FieldObject `json:"fields"`
Score float32 `json:"score"`
}

type FieldObject struct {
Raw string `json:"raw"`
}
19 changes: 19 additions & 0 deletions plugins/wasm-go/extensions/ai-rag/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module myplugin
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

module ai-rag


go 1.18

require (
github.com/alibaba/higress/plugins/wasm-go v1.3.5
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a
github.com/tidwall/gjson v1.14.3
)

require (
github.com/google/uuid v1.3.0 // indirect
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 // indirect
github.com/magefile/mage v1.14.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/resp v0.1.1 // indirect
github.com/tidwall/sjson v1.2.5
)
25 changes: 25 additions & 0 deletions plugins/wasm-go/extensions/ai-rag/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
github.com/alibaba/higress/plugins/wasm-go v1.3.5 h1:VOLL3m442IHCSu8mR5AZ4sc6LVT9X0w1hdqDI7oB9jY=
github.com/alibaba/higress/plugins/wasm-go v1.3.5/go.mod h1:kr3V9Ntbspj1eSrX8rgjBsdMXkGupYEf+LM72caGPQc=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520 h1:IHDghbGQ2DTIXHBHxWfqCYQW1fKjyJ/I7W1pMyUDeEA=
github.com/higress-group/nottinygc v0.0.0-20231101025119-e93c4c2f8520/go.mod h1:Nz8ORLaFiLWotg6GeKlJMhv8cci8mM43uEnLA5t8iew=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a h1:luYRvxLTE1xYxrXYj7nmjd1U0HHh8pUPiKfdZ0MhCGE=
github.com/higress-group/proxy-wasm-go-sdk v0.0.0-20240226064518-b3dc4646a35a/go.mod h1:hNFjhrLUIq+kJ9bOcs8QtiplSQ61GZXtd2xHKx4BYRo=
github.com/magefile/mage v1.14.0 h1:6QDX3g6z1YvJ4olPhT1wksUcSa/V0a1B+pJb73fBjyo=
github.com/magefile/mage v1.14.0/go.mod h1:z5UZb/iS3GoOSn0JgWuiw7dxlurVYTu+/jHXqQg881A=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw=
github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/resp v0.1.1 h1:Ly20wkhqKTmDUPlyM1S7pWo5kk0tDu8OoC/vFArXmwE=
github.com/tidwall/resp v0.1.1/go.mod h1:3/FrruOBAxPTPtundW0VXgmsQ4ZBA0Aw714lVYgwFa0=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
124 changes: 124 additions & 0 deletions plugins/wasm-go/extensions/ai-rag/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package main

import (
"encoding/json"
"fmt"
"net/http"

"myplugin/dashscope"
"myplugin/dashvector"

"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
"github.com/tidwall/gjson"
)

func main() {
wrapper.SetCtx(
"ai-rag",
wrapper.ParseConfigBy(parseConfig),
wrapper.ProcessRequestHeadersBy(onHttpRequestHeaders),
wrapper.ProcessRequestBodyBy(onHttpRequestBody),
)
}

type AIRagConfig struct {
DashScopeClient wrapper.HttpClient
DashScopeAPIKey string
DashVectorClient wrapper.HttpClient
DashVectorAPIKey string
DashVectorCollection string
}

type Request struct {
Model string `json:"model"`
Input Input `json:"input"`
Parameter Parameter `json:"parameters"`
}

type Input struct {
Messages []Message `json:"messages"`
}

type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}

type Parameter struct {
ResultFormat string `json:"result_format"`
}

func parseConfig(json gjson.Result, config *AIRagConfig, log wrapper.Log) error {
config.DashScopeAPIKey = json.Get("dashscope.apiKey").String()

config.DashScopeClient = wrapper.NewClusterClient(wrapper.DnsCluster{
ServiceName: json.Get("dashscope.serviceName").String(),
Port: json.Get("dashscope.servicePort").Int(),
Domain: json.Get("dashscope.domain").String(),
})
config.DashVectorAPIKey = json.Get("dashvector.apiKey").String()
config.DashVectorCollection = json.Get("dashvector.collection").String()
config.DashVectorClient = wrapper.NewClusterClient(wrapper.DnsCluster{
ServiceName: json.Get("dashvector.serviceName").String(),
Port: json.Get("dashvector.servicePort").Int(),
Domain: json.Get("dashvector.domain").String(),
})
return nil
}

func onHttpRequestHeaders(ctx wrapper.HttpContext, config AIRagConfig, log wrapper.Log) types.Action {
proxywasm.RemoveHttpRequestHeader("content-length")
return types.ActionContinue
}

func onHttpRequestBody(ctx wrapper.HttpContext, config AIRagConfig, body []byte, log wrapper.Log) types.Action {
var rawRequest Request
_ = json.Unmarshal(body, &rawRequest)
rawContent := rawRequest.Input.Messages[0].Content
requestEmbedding := dashscope.Request{
Model: "text-embedding-v1",
Input: dashscope.Input{
Texts: []string{rawContent},
},
Parameter: dashscope.Parameter{
TextType: "query",
},
}
headers := [][2]string{{"Content-Type", "application/json"}, {"Authorization", "Bearer " + config.DashScopeAPIKey}}
reqEmbeddingSerialized, _ := json.Marshal(requestEmbedding)
// log.Info(string(reqEmbeddingSerialized))
config.DashScopeClient.Post(
"/api/v1/services/embeddings/text-embedding/text-embedding",
headers,
reqEmbeddingSerialized,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
var responseEmbedding dashscope.Response
_ = json.Unmarshal(responseBody, &responseEmbedding)
requestQuery := dashvector.Request{
TopK: 1,
OutputFileds: []string{"raw"},
Vector: responseEmbedding.Output.Embeddings[0].Embedding,
}
requestQuerySerialized, _ := json.Marshal(requestQuery)
config.DashVectorClient.Post(
fmt.Sprintf("/v1/collections/%s/query", config.DashVectorCollection),
[][2]string{{"Content-Type", "application/json"}, {"dashvector-auth-token", config.DashVectorAPIKey}},
requestQuerySerialized,
func(statusCode int, responseHeaders http.Header, responseBody []byte) {
var response dashvector.Response
_ = json.Unmarshal(responseBody, &response)
doc := response.Output[0].Fields.Raw
rawRequest.Input.Messages[0].Content = fmt.Sprintf("%s\n参考以上信息,对以下问题做出回答:%s", doc, rawContent)
newBody, _ := json.Marshal(rawRequest)
// log.Info(string(newBody))
proxywasm.ReplaceHttpRequestBody(newBody)
proxywasm.ResumeHttpRequest()
},
)
},
50000,
)
return types.ActionPause
}
Loading