Skip to content

Commit

Permalink
refactor(provider): handling in config system (#163)
Browse files Browse the repository at this point in the history
- Update condition check to use `String()` method for `openai.provider` in `commit.go`
- Replace constants `OPENAI` and `AZURE` with `Provider` type and add `String()` and `IsValid()` methods in `options.go`
- Modify `WithProvider` function to use `Provider` type and validate with `IsValid()` method in `options.go`
- Change `provider` field in `config` struct from `string` to `Provider` type in `options.go`
- Update `WithProvider` calls in `options_test.go` to use `String()` method for `Provider` type

Signed-off-by: appleboy <appleboy.tw@gmail.com>
  • Loading branch information
appleboy committed May 12, 2024
1 parent e8b58c4 commit 967ab5a
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 20 deletions.
2 changes: 1 addition & 1 deletion cmd/commit.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ var commitCmd = &cobra.Command{
}

currentModel := viper.GetString("openai.model")
if viper.GetString("openai.provider") == openai.AZURE {
if viper.GetString("openai.provider") == openai.AZURE.String() {
currentModel = viper.GetString("openai.model_name")
}

Expand Down
41 changes: 25 additions & 16 deletions openai/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,30 @@ var (
errorsMissingAzureModel = errors.New("missing Azure deployments model name")
)

const (
OPENAI = "openai"
AZURE = "azure"
type Provider string

func (p Provider) String() string {
return string(p)
}

func (p Provider) IsValid() bool {
switch p {
case OPENAI, AZURE:
return true
default:
return false
}
}

var (
OPENAI Provider = "openai"
AZURE Provider = "azure"
)

const (
defaultMaxTokens = 300
defaultModel = openai.GPT3Dot5Turbo
defaultTemperature = 1.0
defaultProvider = OPENAI
defaultTopP = 1.0
)

Expand Down Expand Up @@ -121,20 +135,15 @@ func WithTemperature(val float32) Option {
})
}

// WithProvider sets the `provider` variable based on the value of the `val` parameter.
// If `val` is not set to `OPENAI` or `AZURE`, it will be set to the default value `defaultProvider`.
// This function returns an `Option` object.
// WithProvider returns a new Option that sets the provider for the client configuration.
func WithProvider(val string) Option {
// Check if `val` is set to `OPENAI` or `AZURE`. If not, set it to the default value.
switch val {
case OPENAI, AZURE:
default:
val = defaultProvider
provider := Provider(val)
if !provider.IsValid() {
provider = OPENAI
}

// Return an `optionFunc` object with `c.provider` set to `val`.
return optionFunc(func(c *config) {
c.provider = val
c.provider = provider
})
}

Expand Down Expand Up @@ -205,7 +214,7 @@ type config struct {
presencePenalty float32
frequencyPenalty float32

provider string
provider Provider
modelName string
skipVerify bool
headers []string
Expand Down Expand Up @@ -241,7 +250,7 @@ func newConfig(opts ...Option) *config {
model: defaultModel,
maxTokens: defaultMaxTokens,
temperature: defaultTemperature,
provider: defaultProvider,
provider: OPENAI,
topP: defaultTopP,
}

Expand Down
6 changes: 3 additions & 3 deletions openai/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func Test_config_valid(t *testing.T) {
cfg: newConfig(
WithToken("test"),
WithModel(openai.GPT3Dot5Turbo),
WithProvider(OPENAI),
WithProvider(OPENAI.String()),
),
wantErr: nil,
},
Expand All @@ -31,7 +31,7 @@ func Test_config_valid(t *testing.T) {
cfg: newConfig(
WithToken("test"),
WithModel("test"),
WithProvider(OPENAI),
WithProvider(OPENAI.String()),
),
wantErr: errorsMissingModel,
},
Expand All @@ -40,7 +40,7 @@ func Test_config_valid(t *testing.T) {
cfg: newConfig(
WithToken("test"),
WithModel(openai.GPT3Dot5Turbo),
WithProvider(AZURE),
WithProvider(AZURE.String()),
),
wantErr: errorsMissingAzureModel,
},
Expand Down

0 comments on commit 967ab5a

Please sign in to comment.