-
-
Notifications
You must be signed in to change notification settings - Fork 101
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Add Gemini support - Add `gemini` package - Add `gemini/func.go` file - Add `gemini/gemini.go` file - Add `gemini/options.go` file - Update `go.mod` file - Add `Float32Ptr` and `Int32Ptr` functions to `util/util.go` file Signed-off-by: appleboy <appleboy.tw@gmail.com>
- Loading branch information
Showing
7 changed files
with
470 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
package gemini | ||
|
||
import "github.com/google/generative-ai-go/genai" | ||
|
||
var summaryPrefixFunc = &genai.Tool{ | ||
FunctionDeclarations: []*genai.FunctionDeclaration{{ | ||
Name: "get_summary_prefix", | ||
Description: "Get a summary prefix using function call", | ||
Parameters: &genai.Schema{ | ||
Type: genai.TypeObject, | ||
Properties: map[string]*genai.Schema{ | ||
"prefix": { | ||
Type: genai.TypeString, | ||
Description: "The prefix to use for the summary", | ||
Enum: []string{ | ||
"build", "chore", "ci", | ||
"docs", "feat", "fix", | ||
"perf", "refactor", "style", | ||
"test", | ||
}, | ||
}, | ||
}, | ||
Required: []string{"prefix"}, | ||
}, | ||
}}, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
package gemini | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"strings" | ||
|
||
"github.com/appleboy/CodeGPT/core" | ||
"github.com/appleboy/CodeGPT/util" | ||
|
||
"github.com/google/generative-ai-go/genai" | ||
"google.golang.org/api/option" | ||
) | ||
|
||
type Client struct { | ||
client *genai.GenerativeModel | ||
model string | ||
maxTokens int | ||
temperature float32 | ||
topP float32 | ||
debug bool | ||
} | ||
|
||
// Completion is a method on the Client struct that takes a context.Context and a string argument | ||
func (c *Client) Completion(ctx context.Context, content string) (*core.Response, error) { | ||
resp, err := c.client.GenerateContent(ctx, genai.Text(content)) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
var ret string | ||
|
||
for _, cand := range resp.Candidates { | ||
for _, part := range cand.Content.Parts { | ||
ret += fmt.Sprintf("%v", part) | ||
} | ||
} | ||
|
||
return &core.Response{ | ||
Content: ret, | ||
Usage: core.Usage{ | ||
PromptTokens: int(resp.UsageMetadata.PromptTokenCount), | ||
CompletionTokens: int(resp.UsageMetadata.CandidatesTokenCount), | ||
TotalTokens: int(resp.UsageMetadata.TotalTokenCount), | ||
}, | ||
}, nil | ||
} | ||
|
||
// GetSummaryPrefix is an API call to get a summary prefix using function call. | ||
func (c *Client) GetSummaryPrefix(ctx context.Context, content string) (*core.Response, error) { | ||
c.client.Tools = []*genai.Tool{summaryPrefixFunc} | ||
|
||
// Start new chat session. | ||
session := c.client.StartChat() | ||
|
||
// Send the message to the generative model. | ||
resp, err := session.SendMessage(ctx, genai.Text(content)) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
part := resp.Candidates[0].Content.Parts[0] | ||
|
||
r := &core.Response{ | ||
Content: strings.TrimSpace(strings.TrimSuffix(fmt.Sprintf("%v", part), "\n")), | ||
Usage: core.Usage{ | ||
PromptTokens: int(resp.UsageMetadata.PromptTokenCount), | ||
CompletionTokens: int(resp.UsageMetadata.CandidatesTokenCount), | ||
TotalTokens: int(resp.UsageMetadata.TotalTokenCount), | ||
}, | ||
} | ||
|
||
if c.debug { | ||
// Check that you got the expected function call back. | ||
funcall, ok := part.(genai.FunctionCall) | ||
if !ok { | ||
return nil, fmt.Errorf("expected type FunctionCall, got %T", part) | ||
} | ||
if g, e := funcall.Name, summaryPrefixFunc.FunctionDeclarations[0].Name; g != e { | ||
return nil, fmt.Errorf("expected FunctionCall.Name %q, got %q", e, g) | ||
} | ||
} | ||
|
||
return r, nil | ||
} | ||
|
||
func New(opts ...Option) (c *Client, err error) { | ||
// Create a new config object with the given options. | ||
cfg := newConfig(opts...) | ||
|
||
// Validate the config object, returning an error if it is invalid. | ||
if err := cfg.valid(); err != nil { | ||
return nil, err | ||
} | ||
|
||
// Create a new client instance with the necessary fields. | ||
engine := &Client{ | ||
model: cfg.model, | ||
maxTokens: cfg.maxTokens, | ||
temperature: cfg.temperature, | ||
} | ||
|
||
client, err := genai.NewClient(context.Background(), option.WithAPIKey(cfg.token)) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
engine.client = client.GenerativeModel(engine.model) | ||
engine.client.MaxOutputTokens = util.Int32Ptr(int32(engine.maxTokens)) | ||
engine.client.Temperature = util.Float32Ptr(engine.temperature) | ||
engine.client.TopP = util.Float32Ptr(engine.topP) | ||
|
||
return engine, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
package gemini | ||
|
||
import ( | ||
"errors" | ||
) | ||
|
||
var ( | ||
errorsMissingToken = errors.New("missing gemini api key") | ||
errorsMissingModel = errors.New("missing model") | ||
) | ||
|
||
const ( | ||
defaultMaxTokens = 300 | ||
defaultModel = "gemini-1.5-flash-latest" | ||
defaultTemperature = 1.0 | ||
defaultTopP = 1.0 | ||
) | ||
|
||
// Option is an interface that specifies instrumentation configuration options. | ||
type Option interface { | ||
apply(*config) | ||
} | ||
|
||
// optionFunc is a type of function that can be used to implement the Option interface. | ||
// It takes a pointer to a config struct and modifies it. | ||
type optionFunc func(*config) | ||
|
||
// Ensure that optionFunc satisfies the Option interface. | ||
var _ Option = (*optionFunc)(nil) | ||
|
||
// The apply method of optionFunc type is implemented here to modify the config struct based on the function passed. | ||
func (o optionFunc) apply(c *config) { | ||
o(c) | ||
} | ||
|
||
// WithToken is a function that returns an Option, which sets the token field of the config struct. | ||
func WithToken(val string) Option { | ||
return optionFunc(func(c *config) { | ||
c.token = val | ||
}) | ||
} | ||
|
||
// WithModel is a function that returns an Option, which sets the model field of the config struct. | ||
func WithModel(val string) Option { | ||
return optionFunc(func(c *config) { | ||
c.model = val | ||
}) | ||
} | ||
|
||
// WithMaxTokens returns a new Option that sets the max tokens for the client configuration. | ||
// The maximum number of tokens to generate in the chat completion. | ||
// The total length of input tokens and generated tokens is limited by the model's context length. | ||
func WithMaxTokens(val int) Option { | ||
if val <= 0 { | ||
val = defaultMaxTokens | ||
} | ||
return optionFunc(func(c *config) { | ||
c.maxTokens = val | ||
}) | ||
} | ||
|
||
// WithTemperature returns a new Option that sets the temperature for the client configuration. | ||
// What sampling temperature to use, between 0 and 2. | ||
// Higher values like 0.8 will make the output more random, | ||
// while lower values like 0.2 will make it more focused and deterministic. | ||
func WithTemperature(val float32) Option { | ||
if val <= 0 { | ||
val = defaultTemperature | ||
} | ||
return optionFunc(func(c *config) { | ||
c.temperature = val | ||
}) | ||
} | ||
|
||
// WithTopP returns a new Option that sets the topP for the client configuration. | ||
func WithTopP(val float32) Option { | ||
return optionFunc(func(c *config) { | ||
c.topP = val | ||
}) | ||
} | ||
|
||
// config is a struct that stores configuration options for the instrumentation. | ||
type config struct { | ||
token string | ||
model string | ||
maxTokens int | ||
temperature float32 | ||
topP float32 | ||
} | ||
|
||
// valid checks whether a config object is valid, returning an error if it is not. | ||
func (cfg *config) valid() error { | ||
// Check that the token is not empty. | ||
if cfg.token == "" { | ||
return errorsMissingToken | ||
} | ||
|
||
if cfg.model == "" { | ||
return errorsMissingModel | ||
} | ||
|
||
// If all checks pass, return nil (no error). | ||
return nil | ||
} | ||
|
||
// newConfig creates a new config object with default values, and applies the given options. | ||
func newConfig(opts ...Option) *config { | ||
// Create a new config object with default values. | ||
c := &config{ | ||
model: defaultModel, | ||
maxTokens: defaultMaxTokens, | ||
temperature: defaultTemperature, | ||
topP: defaultTopP, | ||
} | ||
|
||
// Apply each of the given options to the config object. | ||
for _, opt := range opts { | ||
opt.apply(c) | ||
} | ||
|
||
// Return the resulting config object. | ||
return c | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.