Skip to content

Commit

Permalink
feat(groq): update codebase to support new AI models (#155)
Browse files Browse the repository at this point in the history
- Update model references from `LLaMA2-70b-chat`, `Mixtral-8x7b-Instruct-v0.1`, and `Gemma-7b-it` to `llama3-8b-8192`, `llama3-70b-8192`, `mixtral-8x7b-32768`, and `gemma-7b-it` in README.md
- Replace old model constants with new ones in `groq/model.go`
- Remove deprecated `GetModel` function from `groq/model.go`
- Update model validation cases to use new model constants in `groq/model.go`
- Update model mappings in `openai/openai.go` to reflect new model constants
- Adjust completion function in `openai/openai.go` to use new model strings
- Add new models to the allowed function call list in `openai/openai.go`

Signed-off-by: Bo-Yi Wu <appleboy.tw@gmail.com>
  • Loading branch information
appleboy committed Apr 27, 2024
1 parent e6a5a52 commit 15c4919
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 61 deletions.
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,15 @@ Get the `API key` from the Groq API service — please visit [here][31]. Update t
codegpt config set openai.provider openai
codegpt config set openai.base_url https://api.groq.com/openai/v1
codegpt config set openai.api_key gsk_xxxxxxxxxxxxxx
codegpt config set openai.model LLaMA2-70b-chat
codegpt config set openai.model llama3-8b-8192
```

Support the [following models][32]:

1. LLaMA2-70b-chat (Meta) **recommended**
2. Mixtral-8x7b-Instruct-v0.1 (Mistral)
3. Gemma-7b-it (Google)
1. `llama3-8b-8192` (Meta) **recommended**
2. `llama3-70b-8192` (Meta)
3. `mixtral-8x7b-32768` (Mistral)
4. `gemma-7b-it` (Google)

[30]: https://groq.com/
[31]: https://console.groq.com/keys
Expand Down
27 changes: 5 additions & 22 deletions groq/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,21 @@ package groq
// Model is the name of a Groq chat model, as accepted by the Groq API.
type Model string

// Model name constants accepted by the Groq API.
//
// NOTE(review): both the retired model names (LLaMA270bChat,
// Mixtral8x7bInstructV01, Gemma7bIt) and their replacements appear
// here — this looks like a diff-merge artifact. The commit description
// says the legacy trio is replaced; confirm they should be deleted.
const (
	LLaMA270bChat          Model = "LLaMA2-70b-chat"            // legacy — superseded by LLaMA370b
	Mixtral8x7bInstructV01 Model = "Mixtral-8x7b-Instruct-v0.1" // legacy — superseded by Mixtral8x7b
	Gemma7bIt              Model = "Gemma-7b-it"                // legacy — superseded by Gemma7b
	LLaMA38b               Model = "llama3-8b-8192"             // Meta Llama 3 8B, 8192-token context
	LLaMA370b              Model = "llama3-70b-8192"            // Meta Llama 3 70B, 8192-token context
	Mixtral8x7b            Model = "mixtral-8x7b-32768"         // Mistral Mixtral 8x7B, 32768-token context
	Gemma7b                Model = "gemma-7b-it"                // Google Gemma 7B instruction-tuned
)

// String returns the model name as a plain string, satisfying
// fmt.Stringer.
func (m Model) String() string {
	return string(m)
}

// GetModel returns the API model ID for m by delegating to the
// package-level GetModel lookup.
//
// NOTE(review): the commit description says this helper is being
// removed in favor of Model.String — confirm no callers remain before
// deleting it.
func (m Model) GetModel() string {
	return GetModel(m)
}

// IsVaild reports whether m is one of the chat models currently
// supported by the Groq API.
//
// NOTE(review): the exported name's typo ("Vaild") is kept so existing
// callers keep compiling; fixing it would be a breaking rename.
func (m Model) IsVaild() bool {
	// The scraped diff had merged the removed legacy case
	// (LLaMA270bChat, Mixtral8x7bInstructV01, Gemma7bIt) with the new
	// one, leaving an empty case body and a path that fell off the end
	// of the function without a return. Only the current models remain.
	switch m {
	case LLaMA38b, LLaMA370b, Mixtral8x7b, Gemma7b:
		return true
	default:
		return false
	}
}

// model maps the legacy Model constants to the model ID strings the
// Groq API actually accepts (the legacy display names differ from the
// wire IDs; the newer constants already use the wire ID directly).
var model = map[Model]string{
	LLaMA270bChat:          "llama2-70b-4096",
	Mixtral8x7bInstructV01: "mixtral-8x7b-32768",
	Gemma7bIt:              "gemma-7b-it",
}

// GetModel returns the Groq API model ID for modelName. Unknown names
// fall back to the LLaMA270bChat mapping, preserving the original
// default behavior.
//
// NOTE(review): the commit description says this function is being
// removed along with the legacy model table — confirm before deleting.
func GetModel(modelName Model) string {
	// Single comma-ok lookup instead of the original existence check
	// followed by a second (and on the fallback path, third) lookup.
	if id, ok := model[modelName]; ok {
		return id
	}
	return model[LLaMA270bChat]
}
76 changes: 41 additions & 35 deletions openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,37 +18,38 @@ var DefaultModel = openai.GPT3Dot5Turbo

// modelMaps maps model names to their corresponding model ID strings.
// OpenAI names map to the SDK's model constants; Groq models map to
// themselves, since Groq accepts the public model name as the ID.
//
// The scraped diff had merged the pre- and post-commit map literals,
// duplicating every "gpt-…" key — duplicate constant keys in a map
// literal are a compile error in Go. Only the post-commit entries
// (with the llama3-era Groq models) are kept.
var modelMaps = map[string]string{
	"gpt-4-32k-0613":         openai.GPT432K0613,
	"gpt-4-32k-0314":         openai.GPT432K0314,
	"gpt-4-32k":              openai.GPT432K,
	"gpt-4-0613":             openai.GPT40613,
	"gpt-4-0314":             openai.GPT40314,
	"gpt-4-turbo":            openai.GPT4Turbo,
	"gpt-4-turbo-2024-04-09": openai.GPT4Turbo20240409,
	"gpt-4-0125-preview":     openai.GPT4Turbo0125,
	"gpt-4-1106-preview":     openai.GPT4Turbo1106,
	"gpt-4-turbo-preview":    openai.GPT4TurboPreview,
	"gpt-4-vision-preview":   openai.GPT4VisionPreview,
	"gpt-4":                  openai.GPT4,
	"gpt-3.5-turbo-0125":     openai.GPT3Dot5Turbo0125,
	"gpt-3.5-turbo-1106":     openai.GPT3Dot5Turbo1106,
	"gpt-3.5-turbo-0613":     openai.GPT3Dot5Turbo0613,
	"gpt-3.5-turbo-0301":     openai.GPT3Dot5Turbo0301,
	"gpt-3.5-turbo-16k":      openai.GPT3Dot5Turbo16K,
	"gpt-3.5-turbo-16k-0613": openai.GPT3Dot5Turbo16K0613,
	"gpt-3.5-turbo":          openai.GPT3Dot5Turbo,
	"gpt-3.5-turbo-instruct": openai.GPT3Dot5TurboInstruct,
	"davinci":                openai.GPT3Davinci,
	"davinci-002":            openai.GPT3Davinci002,
	"curie":                  openai.GPT3Curie,
	"curie-002":              openai.GPT3Curie002,
	"ada":                    openai.GPT3Ada,
	"ada-002":                openai.GPT3Ada002,
	"babbage":                openai.GPT3Babbage,
	"babbage-002":            openai.GPT3Babbage002,
	groq.LLaMA38b.String():  groq.LLaMA38b.String(),
	groq.LLaMA370b.String(): groq.LLaMA370b.String(),
	groq.Mixtral8x7b.String(): groq.Mixtral8x7b.String(),
	groq.Gemma7b.String():     groq.Gemma7b.String(),
}

// GetModel returns the model ID corresponding to the given model name.
Expand Down Expand Up @@ -194,9 +195,10 @@ func (c *Client) Completion(
openai.GPT4VisionPreview,
openai.GPT4Turbo,
openai.GPT4Turbo20240409,
groq.LLaMA270bChat.GetModel(),
groq.Mixtral8x7bInstructV01.GetModel(),
groq.Gemma7bIt.GetModel():
groq.LLaMA38b.String(),
groq.LLaMA370b.String(),
groq.Mixtral8x7b.String(),
groq.Gemma7b.String():
r, err := c.CreateChatCompletion(ctx, content)
if err != nil {
return nil, err
Expand Down Expand Up @@ -323,7 +325,11 @@ func (c *Client) allowFuncCall(cfg *config) bool {
openai.GPT3Dot5Turbo,
openai.GPT3Dot5Turbo0125,
openai.GPT3Dot5Turbo0613,
openai.GPT3Dot5Turbo1106:
openai.GPT3Dot5Turbo1106,
groq.LLaMA38b.String(),
groq.LLaMA370b.String(),
groq.Mixtral8x7b.String(),
groq.Gemma7b.String():
return true
default:
return false
Expand Down

0 comments on commit 15c4919

Please sign in to comment.