Skip to content

Commit

Permalink
refactor: provider handling and core interfaces (#173)
Browse files Browse the repository at this point in the history
- Add error handling for invalid provider in `commit.go`
- Replace `NewOpenAI` initialization with a switch case for different providers
- Simplify the summary prefix retrieval logic in `commit.go`
- Create a new `core/openai.go` file defining `Usage`, `Response`, and `Generative` interface
- Create a new `core/platform.go` file defining `Platform` type and its methods
- Implement `Completion` and `GetSummaryPrefix` methods in `openai.go` for the `core.Generative` interface
- Refactor `Completion` method to `completion` in `openai.go`

Signed-off-by: appleboy <appleboy.tw@gmail.com>
  • Loading branch information
appleboy authored Jun 1, 2024
1 parent 976e75d commit 6a7b4b2
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 19 deletions.
39 changes: 21 additions & 18 deletions cmd/commit.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
package cmd

import (
"errors"
"html"
"log"
"os"
"path"
"strconv"
"strings"
"time"

"github.com/appleboy/CodeGPT/core"
"github.com/appleboy/CodeGPT/git"
"github.com/appleboy/CodeGPT/openai"
"github.com/appleboy/CodeGPT/prompt"
"github.com/appleboy/CodeGPT/util"

Expand Down Expand Up @@ -88,13 +88,26 @@ var commitCmd = &cobra.Command{
viper.Set("openai.timeout", timeout)
}

currentModel := viper.GetString("openai.model")
color.Green("Summarize the commit message use " + currentModel + " model")
client, err := NewOpenAI()
// check provider
provider := core.Platform(viper.GetString("openai.provider"))
if !provider.IsValid() {
return errors.New("invalid provider")
}

var client core.Generative
switch provider {
case core.Gemini:
// TODO: implement Gemini
case core.OpenAI, core.Azure:
client, err = NewOpenAI()
}
if err != nil && !promptOnly {
return err
}

currentModel := viper.GetString("openai.model")
color.Green("Summarize the commit message use " + currentModel + " model")

data := util.Data{}
// add template vars
if vars := util.ConvertToMap(templateVars); len(vars) > 0 {
Expand Down Expand Up @@ -189,21 +202,11 @@ var commitCmd = &cobra.Command{
message := "We are trying to get conventional commit prefix"
summaryPrix := ""
color.Cyan(message + " (Tools)")
resp, err := client.CreateFunctionCall(cmd.Context(), out, openai.SummaryPrefixFunc)
if err != nil || len(resp.Choices) != 1 {
log.Printf("Completion error: err:%v len(choices):%v\n", err,
len(resp.Choices))
resp, err := client.GetSummaryPrefix(cmd.Context(), out)
if err != nil {
return err
}

msg := resp.Choices[0].Message
if len(msg.ToolCalls) == 0 {
color.Red("No tool calls found in the message")
summaryPrix = msg.Content
} else {
args := openai.GetSummaryPrefixArgs(msg.ToolCalls[len(msg.ToolCalls)-1].Function.Arguments)
summaryPrix = args.Prefix
}
summaryPrix = resp.Content

color.Magenta("PromptTokens: " + strconv.Itoa(resp.Usage.PromptTokens) +
", CompletionTokens: " + strconv.Itoa(resp.Usage.CompletionTokens) +
Expand Down
23 changes: 23 additions & 0 deletions core/openai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package core

import (
"context"
)

type Usage struct {
PromptTokens int
CompletionTokens int
TotalTokens int
}

type Response struct {
Content string
Usage Usage
}

type Generative interface {
// CreateCompletion is an API call to create a completion.
Completion(ctx context.Context, content string) (resp *Response, err error)
// GetSummaryPrefix is an API call to get a summary prefix using function call.
GetSummaryPrefix(ctx context.Context, content string) (resp *Response, err error)
}
23 changes: 23 additions & 0 deletions core/platform.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package core

type Platform string

const (
OpenAI Platform = "openai"
Azure Platform = "azure"
Gemini Platform = "gemini"
)

// String returns the string representation of the Platform.
func (p Platform) String() string {
return string(p)
}

// IsValid returns true if the Platform is valid.
func (p Platform) IsValid() bool {
switch p {
case OpenAI, Azure, Gemini:
return true
}
return false
}
50 changes: 49 additions & 1 deletion openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@ import (
"net/http"
"net/url"

"github.com/appleboy/CodeGPT/core"

openai "github.com/sashabaranov/go-openai"
"golang.org/x/net/proxy"
)

// DefaultModel is the default OpenAI model to use if one is not provided.
var DefaultModel = openai.GPT3Dot5Turbo

var _ core.Generative = (*Client)(nil)

// Client is a struct that represents an OpenAI client.
type Client struct {
client *openai.Client
Expand All @@ -40,6 +44,50 @@ type Response struct {
Usage openai.Usage
}

// Completion is a method on the Client struct that takes a context.Context and a string argument
func (c *Client) Completion(ctx context.Context, content string) (*core.Response, error) {
resp, err := c.completion(ctx, content)
if err != nil {
return nil, err
}

return &core.Response{
Content: resp.Content,
Usage: core.Usage{
PromptTokens: resp.Usage.PromptTokens,
CompletionTokens: resp.Usage.CompletionTokens,
TotalTokens: resp.Usage.TotalTokens,
},
}, nil
}

// GetSummaryPrefix is an API call to get a summary prefix using function call.
func (c *Client) GetSummaryPrefix(ctx context.Context, content string) (*core.Response, error) {
resp, err := c.CreateFunctionCall(ctx, content, SummaryPrefixFunc)
if err != nil || len(resp.Choices) != 1 {
return nil, err
}

msg := resp.Choices[0].Message
usage := core.Usage{
PromptTokens: resp.Usage.PromptTokens,
CompletionTokens: resp.Usage.CompletionTokens,
TotalTokens: resp.Usage.TotalTokens,
}
if len(msg.ToolCalls) == 0 {
return &core.Response{
Content: msg.Content,
Usage: usage,
}, nil
}

args := GetSummaryPrefixArgs(msg.ToolCalls[len(msg.ToolCalls)-1].Function.Arguments)
return &core.Response{
Content: args.Prefix,
Usage: usage,
}, nil
}

// CreateChatCompletion is an API call to create a function call for a chat message.
func (c *Client) CreateFunctionCall(
ctx context.Context,
Expand Down Expand Up @@ -109,7 +157,7 @@ func (c *Client) CreateChatCompletion(

// Completion is a method on the Client struct that takes a context.Context and a string argument
// and returns a string and an error.
func (c *Client) Completion(
func (c *Client) completion(
ctx context.Context,
content string,
) (*Response, error) {
Expand Down

0 comments on commit 6a7b4b2

Please sign in to comment.