Skip to content

Commit

Permalink
Merge pull request #265 from ConnectAI-E/support_vision
Browse files Browse the repository at this point in the history
feat: 支持gpt4v 「WIP」
  • Loading branch information
Leizhenpeng committed Nov 20, 2023
2 parents 20f1ea2 + c1d811e commit 79dd756
Show file tree
Hide file tree
Showing 14 changed files with 521 additions and 24 deletions.
6 changes: 3 additions & 3 deletions code/handlers/card_common_action.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"context"
"encoding/json"
"fmt"
"start-feishubot/logger"

larkcard "github.com/larksuite/oapi-sdk-go/v3/card"
)

Expand All @@ -20,11 +18,13 @@ func NewCardHandler(m MessageHandler) CardHandlerFunc {
handlers := []CardHandlerMeta{
NewClearCardHandler,
NewPicResolutionHandler,
NewVisionResolutionHandler,
NewPicTextMoreHandler,
NewPicModeChangeHandler,
NewRoleTagCardHandler,
NewRoleCardHandler,
NewAIModeCardHandler,
NewVisionModeChangeHandler,
}

return func(ctx context.Context, cardAction *larkcard.CardAction) (interface{}, error) {
Expand All @@ -35,7 +35,7 @@ func NewCardHandler(m MessageHandler) CardHandlerFunc {
return nil, err
}
//pp.Println(cardMsg)
logger.Debug("cardMsg ", cardMsg)
//logger.Debug("cardMsg ", cardMsg)
for _, handler := range handlers {
h := handler(cardMsg, m)
i, err := h(ctx, cardAction)
Expand Down
74 changes: 74 additions & 0 deletions code/handlers/card_vision_action.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package handlers

import (
"context"
"fmt"
larkcard "github.com/larksuite/oapi-sdk-go/v3/card"
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
"start-feishubot/services"
)

func NewVisionResolutionHandler(cardMsg CardMsg,
m MessageHandler) CardHandlerFunc {
return func(ctx context.Context, cardAction *larkcard.CardAction) (interface{}, error) {
if cardMsg.Kind == VisionStyleKind {
CommonProcessVisionStyle(cardMsg, cardAction, m.sessionCache)
return nil, nil
}
return nil, ErrNextHandler
}
}
func NewVisionModeChangeHandler(cardMsg CardMsg,
m MessageHandler) CardHandlerFunc {
return func(ctx context.Context, cardAction *larkcard.CardAction) (interface{}, error) {
if cardMsg.Kind == VisionModeChangeKind {
newCard, err, done := CommonProcessVisionModeChange(cardMsg, m.sessionCache)
if done {
return newCard, err
}
return nil, nil
}
return nil, ErrNextHandler
}
}

func CommonProcessVisionStyle(msg CardMsg,
cardAction *larkcard.CardAction,
cache services.SessionServiceCacheInterface) {
option := cardAction.Action.Option
fmt.Println(larkcore.Prettify(msg))
cache.SetVisionDetail(msg.SessionId, services.VisionDetail(option))
//send text
replyMsg(context.Background(), "图片解析度调整为:"+option,
&msg.MsgId)
}

func CommonProcessVisionModeChange(cardMsg CardMsg,
session services.SessionServiceCacheInterface) (
interface{}, error, bool) {
if cardMsg.Value == "1" {

sessionId := cardMsg.SessionId
session.Clear(sessionId)
session.SetMode(sessionId,
services.ModeVision)
session.SetVisionDetail(sessionId,
services.VisionDetailLow)

newCard, _ :=
newSendCard(
withHeader("🕵️️ 已进入图片推理模式", larkcard.TemplateBlue),
withVisionDetailLevelBtn(&sessionId),
withNote("提醒:回复图片,让LLM和你一起推理图片的内容。"))
return newCard, nil, true
}
if cardMsg.Value == "0" {
newCard, _ := newSendCard(
withHeader("️🎒 机器人提醒", larkcard.TemplateGreen),
withMainMd("依旧保留此话题的上下文信息"),
withNote("我们可以继续探讨这个话题,期待和您聊天。如果您有其他问题或者想要讨论的话题,请告诉我哦"),
)
return newCard, nil, true
}
return nil, nil, false
}
28 changes: 27 additions & 1 deletion code/handlers/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ func msgFilter(msg string) string {
//replace @到下一个非空的字段 为 ''
regex := regexp.MustCompile(`@[^ ]*`)
return regex.ReplaceAllString(msg, "")

}

// Parse rich text json to text
Expand Down Expand Up @@ -47,6 +46,33 @@ func parsePostContent(content string) string {
return msgFilter(text)
}

func parsePostImageKeys(content string) []string {
var contentMap map[string]interface{}
err := json.Unmarshal([]byte(content), &contentMap)

if err != nil {
fmt.Println(err)
return nil
}

var imageKeys []string

if contentMap["content"] == nil {
return imageKeys
}

contentList := contentMap["content"].([]interface{})
for _, v := range contentList {
for _, v1 := range v.([]interface{}) {
if v1.(map[string]interface{})["tag"] == "img" {
imageKeys = append(imageKeys, v1.(map[string]interface{})["image_key"].(string))
}
}
}

return imageKeys
}

func parseContent(content, msgType string) string {
//"{\"text\":\"@_user_1 hahaha\"}",
//only get text content hahaha
Expand Down
1 change: 1 addition & 0 deletions code/handlers/event_common_action.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type MsgInfo struct {
qParsed string
fileKey string
imageKey string
imageKeys []string // post 消息卡片中的图片组
sessionId *string
mention []*larkim.MentionEvent
}
Expand Down
17 changes: 17 additions & 0 deletions code/handlers/event_msg_action.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,23 @@ func setDefaultPrompt(msg []openai.Messages) []openai.Messages {
return msg
}

//func setDefaultVisionPrompt(msg []openai.VisionMessages) []openai.VisionMessages {
// if !hasSystemRole(msg) {
// msg = append(msg, openai.VisionMessages{
// Role: "system", Content: []openai.ContentType{
// {Type: "text", Text: "You are ChatGPT4V, " +
// "You are ChatGPT4V, " +
// "a large language and picture model trained by" +
// " OpenAI. " +
// "Answer in user's language as concisely as" +
// " possible. Knowledge cutoff: 20230601 " +
// "Current date" + time.Now().Format("20060102"),
// }},
// })
// }
// return msg
//}

type MessageAction struct { /*消息*/
}

Expand Down
160 changes: 160 additions & 0 deletions code/handlers/event_vision_action.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package handlers

import (
"context"
"fmt"
"os"
"start-feishubot/initialization"
"start-feishubot/services"
"start-feishubot/services/openai"
"start-feishubot/utils"

larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
)

type VisionAction struct { /*图片推理*/
}

func (va *VisionAction) Execute(a *ActionInfo) bool {
if !AzureModeCheck(a) {
return true
}

if isVisionCommand(a) {
initializeVisionMode(a)
sendVisionInstructionCard(*a.ctx, a.info.sessionId, a.info.msgId)
return false
}

mode := a.handler.sessionCache.GetMode(*a.info.sessionId)

if a.info.msgType == "image" {
if mode != services.ModeVision {
sendVisionModeCheckCard(*a.ctx, a.info.sessionId, a.info.msgId)
return false
}

return va.handleVisionImage(a)
}

if a.info.msgType == "post" && mode == services.ModeVision {
return va.handleVisionPost(a)
}

return true
}

func isVisionCommand(a *ActionInfo) bool {
_, foundPic := utils.EitherTrimEqual(a.info.qParsed, "/vision", "图片推理")
return foundPic
}

func initializeVisionMode(a *ActionInfo) {
a.handler.sessionCache.Clear(*a.info.sessionId)
a.handler.sessionCache.SetMode(*a.info.sessionId, services.ModeVision)
a.handler.sessionCache.SetVisionDetail(*a.info.sessionId, services.VisionDetailHigh)
}

func (va *VisionAction) handleVisionImage(a *ActionInfo) bool {
detail := a.handler.sessionCache.GetVisionDetail(*a.info.sessionId)
base64, err := downloadAndEncodeImage(a.info.imageKey, a.info.msgId)
if err != nil {
replyWithErrorMsg(*a.ctx, err, a.info.msgId)
return false
}

return va.processImageAndReply(a, base64, detail)
}

func (va *VisionAction) handleVisionPost(a *ActionInfo) bool {
detail := a.handler.sessionCache.GetVisionDetail(*a.info.sessionId)
var base64s []string

for _, imageKey := range a.info.imageKeys {
if imageKey == "" {
continue
}
base64, err := downloadAndEncodeImage(imageKey, a.info.msgId)
if err != nil {
replyWithErrorMsg(*a.ctx, err, a.info.msgId)
return false
}
base64s = append(base64s, base64)
}

if len(base64s) == 0 {
replyMsg(*a.ctx, "🤖️:请发送一张图片", a.info.msgId)
return false
}

return va.processMultipleImagesAndReply(a, base64s, detail)
}

func downloadAndEncodeImage(imageKey string, msgId *string) (string, error) {
f := fmt.Sprintf("%s.png", imageKey)
defer os.Remove(f)

req := larkim.NewGetMessageResourceReqBuilder().MessageId(*msgId).FileKey(imageKey).Type("image").Build()
resp, err := initialization.GetLarkClient().Im.MessageResource.Get(context.Background(), req)
if err != nil {
return "", err
}

resp.WriteFile(f)
return openai.GetBase64FromImage(f)
}

func replyWithErrorMsg(ctx context.Context, err error, msgId *string) {
replyMsg(ctx, fmt.Sprintf("🤖️:图片下载失败,请稍后再试~\n 错误信息: %v", err), msgId)
}

func (va *VisionAction) processImageAndReply(a *ActionInfo, base64 string, detail string) bool {
msg := createVisionMessages("解释这个图片", base64, detail)
completions, err := a.handler.gpt.GetVisionInfo(msg)
if err != nil {
replyWithErrorMsg(*a.ctx, err, a.info.msgId)
return false
}
sendVisionTopicCard(*a.ctx, a.info.sessionId, a.info.msgId, completions.Content)
return false
}

func (va *VisionAction) processMultipleImagesAndReply(a *ActionInfo, base64s []string, detail string) bool {
msg := createMultipleVisionMessages(a.info.qParsed, base64s, detail)
completions, err := a.handler.gpt.GetVisionInfo(msg)
if err != nil {
replyWithErrorMsg(*a.ctx, err, a.info.msgId)
return false
}
sendVisionTopicCard(*a.ctx, a.info.sessionId, a.info.msgId, completions.Content)
return false
}

func createVisionMessages(query, base64Image, detail string) []openai.VisionMessages {
return []openai.VisionMessages{
{
Role: "user",
Content: []openai.ContentType{
{Type: "text", Text: query},
{Type: "image_url", ImageURL: &openai.ImageURL{
URL: "data:image/jpeg;base64," + base64Image,
Detail: detail,
}},
},
},
}
}

func createMultipleVisionMessages(query string, base64Images []string, detail string) []openai.VisionMessages {
content := []openai.ContentType{{Type: "text", Text: query}}
for _, base64Image := range base64Images {
content = append(content, openai.ContentType{
Type: "image_url",
ImageURL: &openai.ImageURL{
URL: "data:image/jpeg;base64," + base64Image,
Detail: detail,
},
})
}
return []openai.VisionMessages{{Role: "user", Content: content}}
}
5 changes: 3 additions & 2 deletions code/handlers/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ func (m MessageHandler) msgReceivedHandler(ctx context.Context, event *larkim.P2
qParsed: strings.Trim(parseContent(*content, msgType), " "),
fileKey: parseFileKey(*content),
imageKey: parseImageKey(*content),
imageKeys: parsePostImageKeys(*content),
sessionId: sessionId,
mention: mention,
}
Expand All @@ -94,17 +95,17 @@ func (m MessageHandler) msgReceivedHandler(ctx context.Context, event *larkim.P2
&ProcessedUniqueAction{}, //避免重复处理
&ProcessMentionAction{}, //判断机器人是否应该被调用
&AudioAction{}, //语音处理
&EmptyAction{}, //空消息处理
&ClearAction{}, //清除消息处理
&VisionAction{}, //图片推理处理
&PicAction{}, //图片处理
&AIModeAction{}, //模式切换处理
&RoleListAction{}, //角色列表处理
&HelpAction{}, //帮助处理
&BalanceAction{}, //余额处理
&RolePlayAction{}, //角色扮演处理
&MessageAction{}, //消息处理
&EmptyAction{}, //空消息处理
&StreamMessageAction{}, //流式消息处理

}
chain(data, actions...)
return nil
Expand Down
Loading

0 comments on commit 79dd756

Please sign in to comment.