Skip to content

Commit

Permalink
feat: implement label_initial_prompt; better error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
AnotiaWang committed Apr 18, 2023
1 parent 4d4a24a commit f6d3543
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 13 deletions.
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ The configuration file is saved in `config.json`.

## Plans

- [ ] Support task `label_initial_prompt`

This task is a bit hard to find for my language configuration, I'll keep finding it.
- [x] ~~Support task `label_initial_prompt`~~
- [ ] Support other language models (maybe)

There have been many excellent LLMs coming out recently, like GPT4All. I may try to support them in my free time(but no guarantees). if you have ideas, PRs are welcome!
Expand Down
4 changes: 1 addition & 3 deletions README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@

## 计划

- [ ] 支持任务 `label_initial_prompt`

目前中文的任务还是比较少的,尤其这个类别很少出现。等下次我遇到了再做支持,也欢迎大家 PR。
- [x] ~~支持任务 `label_initial_prompt`~~
- [ ] 支持其他语言模型(或许)

最近出现了许多优秀的 LLM,比如 GPT4All,我有空的话可能会尝试支持它们(但不能保证)。如果您有想法,欢迎提 PR!
Expand Down
26 changes: 26 additions & 0 deletions model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,20 @@ type OALabelPrompterReplyTask struct {
UserId string `json:"userId"`
}

type OALabelInitialPromptTask struct {
Id string `json:"id"`
Mode string `json:"mode"`
Type string `json:"type"`
Labels []OARandomTaskLabel `json:"labels"`
Prompt string `json:"prompt"`
MessageId string `json:"message_id"`
Disposition string `json:"disposition"`
Conversation OAConversation `json:"conversation"`
ValidLabels []OALabel `json:"valid_labels"`
MandatoryLabels []OALabel `json:"mandatory_labels"`
UserId string `json:"userId"`
}

type OAAssistantReplyTask struct {
Id string `json:"id"`
Type string `json:"type"`
Expand Down Expand Up @@ -162,3 +176,15 @@ type PostBodyRankAssistant struct {
UpdateType string `json:"update_type"`
Content PostBodyRankAssistantContent `json:"content"`
}

type PostBodyLabelInitialPromptContent struct {
Labels map[OALabel]float32 `json:"labels"`
MessageId string `json:"message_id"`
Text string `json:"text"`
}
type PostBodyLabelInitialPrompt struct {
Id string `json:"id"`
Lang string `json:"lang"`
UpdateType string `json:"update_type"`
Content PostBodyLabelInitialPromptContent `json:"content"`
}
93 changes: 86 additions & 7 deletions open-assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ func RefreshCookie() error {
return err
}
cs := resp.Header()["Set-Cookie"]
if len(cs) == 0 {
return fmt.Errorf("no cookie")
if len(cs) <= 1 {
return fmt.Errorf("no cookie. Please consider going to https://open-assistant.io , login, and update your cookie in config.json")
}
c := strings.Join(cs, "")
return model.UpdateCookie(c)
Expand Down Expand Up @@ -150,7 +150,7 @@ func LabelPrompterReply(id string, task model.OALabelPrompterReplyTask) error {
text += fmt.Sprintf("User: %s", task.Reply)
}

t, err := Complete(text, `You are a powerful fine-tuner of Open Assistant, an open-source LLM. You will be given conversations between a user and the model, and you need to label the user's last reply and return a JSON string.You should evaluate the conversations based on the following criteria:
t, err := Complete(text, `You are a powerful fine-tuner of Open Assistant, an open-source LLM. You will be given conversations between a user and the model, and you need to label the user's last reply and return a JSON string. You should evaluate the conversations based on the following criteria:
- Spam: 0/1, whether the conversation contains spam / ads / porn / politics / etc.
- Not Appropriate: 0/1, whether the response is reasonable for the user's question
- pii: 0/1
Expand Down Expand Up @@ -206,6 +206,67 @@ You must return a JSON string, DO NOT include any other characters, DO NOT expla
return nil
}

func LabelInitialPrompt(id string, task model.OALabelInitialPromptTask) error {
logx.Infof("LabelInitialPrompt")
text := fmt.Sprintf("Prompt: %s\n\nUser's language code: %s", task.Prompt, model.Conf.Language)

t, err := Complete(text, `You are a powerful fine-tuner of Open Assistant, an open-source LLM. You will be given a prompt from the user, and you need to label it and return a JSON string. You should evaluate the prompt based on the following criteria:
- Spam: 0/1, whether the message contains spam / ads / porn / politics / etc.
- Not Appropriate: 0/1, whether the message is offensive / not respectful
- pii: 0/1
- Hate Speech: 0/1, whether the prompt is aggressive / not respectful
- Sexual Content: 0/1
- Quality: 0-1, step 0.25, how well the response is written respecting grammar, spelling, use of words, etc.
- Lang Mismatch: 0-1, step 0.25, whether the prompt is in the same language as the user's language
- Creativity: 0-1, step 0.25, how less is the prompt
- Humor: 0-1, step 0.25
- Toxicity: 0-1, step 0.25, how aggressive is the prompt
- Violence: 0-1, step 0.25
You must return a JSON string, DO NOT include any other characters, DO NOT explain. Use snake_case for the keys.`)

if err != nil {
return err
}

logx.Infof("LabelInitialPrompt: %s", t)
labels, err := GetLabelsFromChatGPT(t)
if err != nil {
return err
}

body := model.PostBodyLabelInitialPrompt{
Id: id,
Lang: model.Conf.Language,
UpdateType: "text_labels",
Content: model.PostBodyLabelInitialPromptContent{
Labels: labels,
Text: "unused?",
MessageId: task.MessageId,
},
}

resp, err := rty.R().
SetHeaders(map[string]string{
"Cookie": model.Conf.OaCookie,
"Content-Type": "application/json",
}).
SetBody(body).
Post("https://open-assistant.io/api/update_task")

if err != nil {
return err
}

respStr := string(resp.Body())
if respStr == "" {
logx.Infof("LabelInitialPrompt: OK!")
} else {
logx.Errorf("LabelInitialPrompt: %s", respStr)
}
return nil
}

func AssistantReply(id string, task model.OAAssistantReplyTask) error {
logx.Infof("AssistantReply")
text := ""
Expand Down Expand Up @@ -374,9 +435,20 @@ func StartTask() error {
}

var task model.OARandomTaskResponse
err = jsonx.Unmarshal(resp.Body(), &task)
if task.Task == nil {
logx.Infof("GetTasks: get task failed: %s", string(resp.Body()))
body := resp.Body()
err = jsonx.Unmarshal(body, &task)
if resp.StatusCode() == 403 {
logx.Errorf("GetTasks: cookie may have expired (403 Forbidden)")
logx.Errorf("Please login to https://open-assistant.io/dashboard and update your cookie in config.json")
return nil
} else if task.Task == nil {
var _json map[string]interface{}
_ = jsonx.Unmarshal(body, &_json)
if strings.Contains(_json["message"].(string), "No tasks") {
logx.Infof("GetTasks: no tasks at this time")
}
logx.Errorf("GetTasks: get task failed: %s", _json["message"])
logx.Errorf("Please check your network, or login to https://open-assistant.io/dashboard and update your cookie in config.json")
return nil
}

Expand All @@ -403,9 +475,16 @@ func StartTask() error {
var ch model.OARankAssistantRepliesTask
_ = jsonx.Unmarshal(j, &ch)
return RankAssistantReplies(task.Id, ch)
} else if t["type"] == "label_initial_prompt" {
var ch model.OALabelInitialPromptTask
_ = jsonx.Unmarshal(j, &ch)
return LabelInitialPrompt(task.Id, ch)
} else {
logx.Infof("GetTasks: unknown task type: %s", t["type"])
CancelTask(task.Id)
err = CancelTask(task.Id)
if err != nil {
return fmt.Errorf("GetTasks: cancel task failed: %s", err)
}
}

return nil
Expand Down

0 comments on commit f6d3543

Please sign in to comment.