Skip to content

Commit

Permalink
feat: add Gemini support. (#177)
Browse files Browse the repository at this point in the history
- Add Gemini support
- Add `gemini` package
- Add `gemini/func.go` file
- Add `gemini/gemini.go` file
- Add `gemini/options.go` file
- Update `go.mod` file
- Add `Float32Ptr` and `Int32Ptr` functions to `util/util.go` file

Signed-off-by: appleboy <appleboy.tw@gmail.com>
  • Loading branch information
appleboy committed Jun 2, 2024
1 parent f208110 commit 398e246
Show file tree
Hide file tree
Showing 7 changed files with 470 additions and 2 deletions.
14 changes: 13 additions & 1 deletion cmd/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"

"github.com/appleboy/CodeGPT/core"
"github.com/appleboy/CodeGPT/gemini"
"github.com/appleboy/CodeGPT/openai"

"github.com/spf13/viper"
Expand All @@ -30,11 +31,22 @@ func NewOpenAI() (*openai.Client, error) {
)
}

// NewGemini returns a new Gemini client
func NewGemini() (*gemini.Client, error) {
return gemini.New(
gemini.WithToken(viper.GetString("openai.api_key")),
gemini.WithModel(viper.GetString("openai.model")),
gemini.WithMaxTokens(viper.GetInt("openai.max_tokens")),
gemini.WithTemperature(float32(viper.GetFloat64("openai.temperature"))),
gemini.WithTopP(float32(viper.GetFloat64("openai.top_p"))),
)
}

// GetClient returns the generative client based on the platform
func GetClient(p core.Platform) (core.Generative, error) {
switch p {
case core.Gemini:
// TODO: implement Gemini
return NewGemini()
case core.OpenAI, core.Azure:
return NewOpenAI()
}
Expand Down
26 changes: 26 additions & 0 deletions gemini/func.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package gemini

import "github.com/google/generative-ai-go/genai"

var summaryPrefixFunc = &genai.Tool{
FunctionDeclarations: []*genai.FunctionDeclaration{{
Name: "get_summary_prefix",
Description: "Get a summary prefix using function call",
Parameters: &genai.Schema{
Type: genai.TypeObject,
Properties: map[string]*genai.Schema{
"prefix": {
Type: genai.TypeString,
Description: "The prefix to use for the summary",
Enum: []string{
"build", "chore", "ci",
"docs", "feat", "fix",
"perf", "refactor", "style",
"test",
},
},
},
Required: []string{"prefix"},
},
}},
}
114 changes: 114 additions & 0 deletions gemini/gemini.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package gemini

import (
"context"
"fmt"
"strings"

"github.com/appleboy/CodeGPT/core"
"github.com/appleboy/CodeGPT/util"

"github.com/google/generative-ai-go/genai"
"google.golang.org/api/option"
)

type Client struct {
client *genai.GenerativeModel
model string
maxTokens int
temperature float32
topP float32
debug bool
}

// Completion is a method on the Client struct that takes a context.Context and a string argument
func (c *Client) Completion(ctx context.Context, content string) (*core.Response, error) {
resp, err := c.client.GenerateContent(ctx, genai.Text(content))
if err != nil {
return nil, err
}

var ret string

for _, cand := range resp.Candidates {
for _, part := range cand.Content.Parts {
ret += fmt.Sprintf("%v", part)
}
}

return &core.Response{
Content: ret,
Usage: core.Usage{
PromptTokens: int(resp.UsageMetadata.PromptTokenCount),
CompletionTokens: int(resp.UsageMetadata.CandidatesTokenCount),
TotalTokens: int(resp.UsageMetadata.TotalTokenCount),
},
}, nil
}

// GetSummaryPrefix is an API call to get a summary prefix using function call.
func (c *Client) GetSummaryPrefix(ctx context.Context, content string) (*core.Response, error) {
c.client.Tools = []*genai.Tool{summaryPrefixFunc}

// Start new chat session.
session := c.client.StartChat()

// Send the message to the generative model.
resp, err := session.SendMessage(ctx, genai.Text(content))
if err != nil {
return nil, err
}

part := resp.Candidates[0].Content.Parts[0]

r := &core.Response{
Content: strings.TrimSpace(strings.TrimSuffix(fmt.Sprintf("%v", part), "\n")),
Usage: core.Usage{
PromptTokens: int(resp.UsageMetadata.PromptTokenCount),
CompletionTokens: int(resp.UsageMetadata.CandidatesTokenCount),
TotalTokens: int(resp.UsageMetadata.TotalTokenCount),
},
}

if c.debug {
// Check that you got the expected function call back.
funcall, ok := part.(genai.FunctionCall)
if !ok {
return nil, fmt.Errorf("expected type FunctionCall, got %T", part)
}
if g, e := funcall.Name, summaryPrefixFunc.FunctionDeclarations[0].Name; g != e {
return nil, fmt.Errorf("expected FunctionCall.Name %q, got %q", e, g)
}
}

return r, nil
}

func New(opts ...Option) (c *Client, err error) {
// Create a new config object with the given options.
cfg := newConfig(opts...)

// Validate the config object, returning an error if it is invalid.
if err := cfg.valid(); err != nil {
return nil, err
}

// Create a new client instance with the necessary fields.
engine := &Client{
model: cfg.model,
maxTokens: cfg.maxTokens,
temperature: cfg.temperature,
}

client, err := genai.NewClient(context.Background(), option.WithAPIKey(cfg.token))
if err != nil {
return nil, err
}

engine.client = client.GenerativeModel(engine.model)
engine.client.MaxOutputTokens = util.Int32Ptr(int32(engine.maxTokens))
engine.client.Temperature = util.Float32Ptr(engine.temperature)
engine.client.TopP = util.Float32Ptr(engine.topP)

return engine, nil
}
123 changes: 123 additions & 0 deletions gemini/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package gemini

import (
"errors"
)

var (
errorsMissingToken = errors.New("missing gemini api key")
errorsMissingModel = errors.New("missing model")
)

const (
defaultMaxTokens = 300
defaultModel = "gemini-1.5-flash-latest"
defaultTemperature = 1.0
defaultTopP = 1.0
)

// Option is an interface that specifies instrumentation configuration options.
type Option interface {
apply(*config)
}

// optionFunc is a type of function that can be used to implement the Option interface.
// It takes a pointer to a config struct and modifies it.
type optionFunc func(*config)

// Ensure that optionFunc satisfies the Option interface.
var _ Option = (*optionFunc)(nil)

// The apply method of optionFunc type is implemented here to modify the config struct based on the function passed.
func (o optionFunc) apply(c *config) {
o(c)
}

// WithToken is a function that returns an Option, which sets the token field of the config struct.
func WithToken(val string) Option {
return optionFunc(func(c *config) {
c.token = val
})
}

// WithModel is a function that returns an Option, which sets the model field of the config struct.
func WithModel(val string) Option {
return optionFunc(func(c *config) {
c.model = val
})
}

// WithMaxTokens returns a new Option that sets the max tokens for the client configuration.
// The maximum number of tokens to generate in the chat completion.
// The total length of input tokens and generated tokens is limited by the model's context length.
func WithMaxTokens(val int) Option {
if val <= 0 {
val = defaultMaxTokens
}
return optionFunc(func(c *config) {
c.maxTokens = val
})
}

// WithTemperature returns a new Option that sets the temperature for the client configuration.
// What sampling temperature to use, between 0 and 2.
// Higher values like 0.8 will make the output more random,
// while lower values like 0.2 will make it more focused and deterministic.
func WithTemperature(val float32) Option {
if val <= 0 {
val = defaultTemperature
}
return optionFunc(func(c *config) {
c.temperature = val
})
}

// WithTopP returns a new Option that sets the topP for the client configuration.
func WithTopP(val float32) Option {
return optionFunc(func(c *config) {
c.topP = val
})
}

// config is a struct that stores configuration options for the instrumentation.
type config struct {
token string
model string
maxTokens int
temperature float32
topP float32
}

// valid checks whether a config object is valid, returning an error if it is not.
func (cfg *config) valid() error {
// Check that the token is not empty.
if cfg.token == "" {
return errorsMissingToken
}

if cfg.model == "" {
return errorsMissingModel
}

// If all checks pass, return nil (no error).
return nil
}

// newConfig creates a new config object with default values, and applies the given options.
func newConfig(opts ...Option) *config {
// Create a new config object with default values.
c := &config{
model: defaultModel,
maxTokens: defaultMaxTokens,
temperature: defaultTemperature,
topP: defaultTopP,
}

// Apply each of the given options to the config object.
for _, opt := range opts {
opt.apply(c)
}

// Return the resulting config object.
return c
}
36 changes: 35 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,21 +1,40 @@
module github.com/appleboy/CodeGPT

go 1.20
go 1.21

toolchain go1.22.2

require (
github.com/appleboy/com v0.1.7
github.com/appleboy/graceful v1.1.1
github.com/fatih/color v1.17.0
github.com/google/generative-ai-go v0.13.0
github.com/joho/godotenv v1.5.1
github.com/rodaine/table v1.2.0
github.com/sashabaranov/go-openai v1.24.0
github.com/spf13/cobra v1.8.0
github.com/spf13/viper v1.18.2
golang.org/x/net v0.25.0
google.golang.org/api v0.178.0
)

require (
cloud.google.com/go v0.113.0 // indirect
cloud.google.com/go/ai v0.5.0 // indirect
cloud.google.com/go/auth v0.4.0 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
cloud.google.com/go/compute/metadata v0.3.0 // indirect
cloud.google.com/go/longrunning v0.5.7 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/go-logr/logr v1.4.1 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.4 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/magiconair/properties v1.8.7 // indirect
Expand All @@ -30,10 +49,25 @@ require (
github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.51.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
go.opentelemetry.io/otel v1.26.0 // indirect
go.opentelemetry.io/otel/metric v1.26.0 // indirect
go.opentelemetry.io/otel/trace v1.26.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/crypto v0.23.0 // indirect
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
golang.org/x/oauth2 v0.20.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.20.0 // indirect
golang.org/x/text v0.15.0 // indirect
golang.org/x/time v0.5.0 // indirect
google.golang.org/genproto v0.0.0-20240401170217-c3f982113cda // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240506185236-b8a5c65736ae // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240506185236-b8a5c65736ae // indirect
google.golang.org/grpc v1.63.2 // indirect
google.golang.org/protobuf v1.34.1 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
Loading

0 comments on commit 398e246

Please sign in to comment.