Skip to content

Commit

Permalink
feat(openrouter): refactor model handling and error logic (#164)
Browse files Browse the repository at this point in the history
- Update model selection logic to use `IsCustomModel` method for determining the current model in `commit.go` and `review.go`.
- Remove conditional function call logic in `commit.go`.
- Simplify error handling and remove specific model checks in `openai.go`.
- Introduce handling for custom models and update error messages in `options.go` and `options_test.go`.
- Add support for a new provider `OPENROUTER` in `options.go`.

Signed-off-by: appleboy <appleboy.tw@gmail.com>
  • Loading branch information
appleboy committed May 12, 2024
1 parent 967ab5a commit 7b8f448
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 146 deletions.
54 changes: 21 additions & 33 deletions cmd/commit.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ var commitCmd = &cobra.Command{
}

currentModel := viper.GetString("openai.model")
if viper.GetString("openai.provider") == openai.AZURE.String() {
if openai.Provider(viper.GetString("openai.provider")).IsCustomModel() {
currentModel = viper.GetString("openai.model_name")
}

Expand Down Expand Up @@ -206,40 +206,28 @@ var commitCmd = &cobra.Command{
}
message := "We are trying to get conventional commit prefix"
summaryPrix := ""
if client.AllowFuncCall() {
color.Cyan(message + " (Tools)")
resp, err := client.CreateFunctionCall(cmd.Context(), out, openai.SummaryPrefixFunc)
if err != nil || len(resp.Choices) != 1 {
log.Printf("Completion error: err:%v len(choices):%v\n", err,
len(resp.Choices))
return err
}

msg := resp.Choices[0].Message
if len(msg.ToolCalls) == 0 {
color.Red("No tool calls found in the message")
summaryPrix = msg.Content
} else {
args := openai.GetSummaryPrefixArgs(msg.ToolCalls[len(msg.ToolCalls)-1].Function.Arguments)
summaryPrix = args.Prefix
}

color.Magenta("PromptTokens: " + strconv.Itoa(resp.Usage.PromptTokens) +
", CompletionTokens: " + strconv.Itoa(resp.Usage.CompletionTokens) +
", TotalTokens: " + strconv.Itoa(resp.Usage.TotalTokens),
)
color.Cyan(message + " (Tools)")
resp, err := client.CreateFunctionCall(cmd.Context(), out, openai.SummaryPrefixFunc)
if err != nil || len(resp.Choices) != 1 {
log.Printf("Completion error: err:%v len(choices):%v\n", err,
len(resp.Choices))
return err
}

msg := resp.Choices[0].Message
if len(msg.ToolCalls) == 0 {
color.Red("No tool calls found in the message")
summaryPrix = msg.Content
} else {
color.Cyan(message)
resp, err := client.Completion(cmd.Context(), out)
if err != nil {
return err
}
summaryPrix = strings.TrimSpace(resp.Content)
color.Magenta("PromptTokens: " + strconv.Itoa(resp.Usage.PromptTokens) +
", CompletionTokens: " + strconv.Itoa(resp.Usage.CompletionTokens) +
", TotalTokens: " + strconv.Itoa(resp.Usage.TotalTokens),
)
args := openai.GetSummaryPrefixArgs(msg.ToolCalls[len(msg.ToolCalls)-1].Function.Arguments)
summaryPrix = args.Prefix
}

color.Magenta("PromptTokens: " + strconv.Itoa(resp.Usage.PromptTokens) +
", CompletionTokens: " + strconv.Itoa(resp.Usage.CompletionTokens) +
", TotalTokens: " + strconv.Itoa(resp.Usage.TotalTokens),
)

data[prompt.SummarizePrefixKey] = summaryPrix
}

Expand Down
13 changes: 12 additions & 1 deletion cmd/review.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,18 @@ var reviewCmd = &cobra.Command{
return err
}

color.Green("Code review your changes using " + viper.GetString("openai.model") + " model")
// Update the OpenAI client request timeout if the timeout value is greater than the default openai.timeout
if timeout > viper.GetDuration("openai.timeout") ||
timeout != defaultTimeout {
viper.Set("openai.timeout", timeout)
}

currentModel := viper.GetString("openai.model")
if openai.Provider(viper.GetString("openai.provider")).IsCustomModel() {
currentModel = viper.GetString("openai.model_name")
}

color.Green("Code review your changes using " + currentModel + " model")
client, err := openai.New(
openai.WithToken(viper.GetString("openai.api_key")),
openai.WithModel(viper.GetString("openai.model")),
Expand Down
112 changes: 9 additions & 103 deletions openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ type Client struct {
model string
maxTokens int
temperature float32
isFuncCall bool

// An alternative to sampling with temperature, called nucleus sampling,
// where the model considers the results of the tokens with top_p probability mass.
Expand Down Expand Up @@ -156,74 +155,19 @@ func (c *Client) CreateChatCompletion(
return c.client.CreateChatCompletion(ctx, req)
}

// CreateCompletion calls the legacy text-completion endpoint with the
// configured model and sampling parameters and returns the raw API response,
// including per-token probabilities when the server was asked for them.
//
// For a fine-tuned model, supplying the model's ID in the request is enough
// for the server to apply that model's parameters to the completion.
func (c *Client) CreateCompletion(
	ctx context.Context,
	content string,
) (resp openai.CompletionResponse, err error) {
	request := openai.CompletionRequest{
		Prompt:           content,
		Model:            c.model,
		MaxTokens:        c.maxTokens,
		Temperature:      c.temperature,
		TopP:             c.topP,
		FrequencyPenalty: c.frequencyPenalty,
		PresencePenalty:  c.presencePenalty,
	}
	return c.client.CreateCompletion(ctx, request)
}

// Completion is a method on the Client struct that takes a context.Context and a string argument
// and returns a string and an error.
func (c *Client) Completion(
ctx context.Context,
content string,
) (*Response, error) {
resp := &Response{}
switch c.model {
case openai.GPT3Dot5Turbo,
openai.GPT3Dot5Turbo0301,
openai.GPT3Dot5Turbo0613,
openai.GPT3Dot5Turbo16K,
openai.GPT3Dot5Turbo16K0613,
openai.GPT3Dot5Turbo1106,
openai.GPT3Dot5Turbo0125,
openai.GPT4,
openai.GPT40314,
openai.GPT40613,
openai.GPT432K,
openai.GPT432K0314,
openai.GPT432K0613,
openai.GPT4Turbo1106,
openai.GPT4Turbo0125,
openai.GPT4TurboPreview,
openai.GPT4VisionPreview,
openai.GPT4Turbo,
openai.GPT4Turbo20240409,
groq.LLaMA38b.String(),
groq.LLaMA370b.String(),
groq.Mixtral8x7b.String(),
groq.Gemma7b.String():
r, err := c.CreateChatCompletion(ctx, content)
if err != nil {
return nil, err
}
resp.Content = r.Choices[0].Message.Content
resp.Usage = r.Usage
default:
r, err := c.CreateCompletion(ctx, content)
if err != nil {
return nil, err
}
resp.Content = r.Choices[0].Text
resp.Usage = r.Usage
r, err := c.CreateChatCompletion(ctx, content)
if err != nil {
return nil, err
}
resp.Content = r.Choices[0].Message.Content
resp.Usage = r.Usage
return resp, nil
}

Expand Down Expand Up @@ -302,51 +246,13 @@ func New(opts ...Option) (*Client, error) {
if cfg.apiVersion != "" {
c.APIVersion = cfg.apiVersion
}

if cfg.provider.IsCustomModel() {
engine.model = cfg.modelName
}
engine.client = openai.NewClientWithConfig(c)
}

engine.isFuncCall = engine.allowFuncCall(cfg)

// Return the resulting client engine.
return engine, nil
}

// allowFuncCall reports whether the configured provider and model support
// the function-calling (tools) API.
// https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/function-calling
// https://platform.openai.com/docs/guides/function-calling/supported-models
// Not all model versions are trained with function-calling data; only the
// models enumerated below (plus the Azure preview API) qualify.
func (c *Client) allowFuncCall(cfg *config) bool {
	// Azure exposes tool calls starting with the 2023-07-01-preview API version.
	azurePreview := cfg.provider == AZURE && cfg.apiVersion == "2023-07-01-preview"
	if azurePreview {
		return true
	}

	switch c.model {
	case openai.GPT4Turbo,
		openai.GPT4Turbo20240409,
		openai.GPT4TurboPreview,
		openai.GPT4Turbo0125,
		openai.GPT4Turbo1106,
		openai.GPT40613,
		openai.GPT3Dot5Turbo,
		openai.GPT3Dot5Turbo0125,
		openai.GPT3Dot5Turbo0613,
		openai.GPT3Dot5Turbo1106,
		groq.LLaMA38b.String():
		return true
	}
	return false
}

// AllowFuncCall reports whether the client was built with a provider/model
// combination that supports the function-calling (tools) API; the answer is
// computed once in New and cached on the client.
// https://platform.openai.com/docs/guides/gpt/chat-completions-api
func (c *Client) AllowFuncCall() bool {
	return c.isFuncCall
}
21 changes: 13 additions & 8 deletions openai/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ import (
)

// Validation errors returned by config.valid when a required setting is absent.
var (
	// errorsMissingToken is returned when no API key has been configured.
	errorsMissingToken = errors.New("please set OPENAI_API_KEY environment variable")
	// errorsMissingModel is returned when no model has been configured.
	errorsMissingModel = errors.New("missing model")
	// errorsMissingCustomModel is returned when a provider that requires a
	// custom model name (e.g. Azure, OpenRouter) has none configured.
	errorsMissingCustomModel = errors.New("missing custom model name")
)

type Provider string
Expand All @@ -21,16 +21,21 @@ func (p Provider) String() string {

// IsValid reports whether p is one of the supported providers
// (openai, azure, or openrouter).
func (p Provider) IsValid() bool {
	switch p {
	case OPENAI, AZURE, OPENROUTER:
		return true
	default:
		return false
	}
}

// IsCustomModel reports whether the provider expects a user-supplied model
// name (every provider other than the stock OpenAI service).
func (p Provider) IsCustomModel() bool {
	switch p {
	case OPENAI:
		return false
	default:
		return true
	}
}

var (
OPENAI Provider = "openai"
AZURE Provider = "azure"
OPENAI Provider = "openai"
AZURE Provider = "azure"
OPENROUTER Provider = "openrouter"
)

const (
Expand Down Expand Up @@ -235,8 +240,8 @@ func (cfg *config) valid() error {
}

// If the provider uses a custom model name (e.g. Azure, OpenRouter), check that the model name is not empty.
if cfg.provider == AZURE && cfg.modelName == "" {
return errorsMissingAzureModel
if cfg.provider.IsCustomModel() && cfg.modelName == "" {
return errorsMissingCustomModel
}

// If all checks pass, return nil (no error).
Expand Down
11 changes: 10 additions & 1 deletion openai/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,16 @@ func Test_config_valid(t *testing.T) {
WithModel(openai.GPT3Dot5Turbo),
WithProvider(AZURE.String()),
),
wantErr: errorsMissingAzureModel,
wantErr: errorsMissingCustomModel,
},
{
name: "missing OpenRouter Custom model",
cfg: newConfig(
WithToken("test"),
WithModel(openai.GPT3Dot5Turbo),
WithProvider(OPENROUTER.String()),
),
wantErr: errorsMissingCustomModel,
},
}
for _, tt := range tests {
Expand Down

0 comments on commit 7b8f448

Please sign in to comment.