-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1c16f10
commit 3b7dff1
Showing
9 changed files
with
292 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
[ | ||
"claude", | ||
"mistral" | ||
"mistral", | ||
"gpt" | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,6 @@ var variants []byte | |
|
||
var stream bool | ||
var tokens int | ||
var model string | ||
var system string | ||
var clr bool | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
{ | ||
"default_models": { | ||
"claude": "claude-3-sonnet-20240229", | ||
"mistral": "mistral-large-latest" | ||
"mistral": "mistral-large-latest", | ||
"gpt": "gpt-4o" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
package openai | ||
|
||
import ( | ||
"bufio" | ||
"bytes" | ||
"encoding/json" | ||
"fmt" | ||
"github.com/aakashshankar/llm-cli/session" | ||
"io" | ||
"net/http" | ||
"os" | ||
"strings" | ||
) | ||
|
||
const v1apiBaseURL = "https://api.openai.com/v1" | ||
|
||
type Client struct { | ||
config *Config | ||
client *http.Client | ||
} | ||
|
||
func NewClient(config *Config) *Client { | ||
return &Client{ | ||
config: config, | ||
client: &http.Client{}, | ||
} | ||
} | ||
|
||
func (c *Client) Prompt(prompt string, stream bool, tokens int, model string, system string, clear bool) (string, error) { | ||
if clear { | ||
session.ClearSession() | ||
} | ||
s := session.NewSession() | ||
if err := s.LoadLatest(); err != nil { | ||
fmt.Println("Error loading session:", err) | ||
os.Exit(1) | ||
} | ||
req, err := c.MarshalRequest(prompt, stream, tokens, model, system, s) | ||
if err != nil { | ||
return "", err | ||
} | ||
resp, err := c.client.Do(req) | ||
if err != nil { | ||
return "", fmt.Errorf("failed to send request: %w", err) | ||
} | ||
|
||
defer func(Body io.ReadCloser) { | ||
err := Body.Close() | ||
if err != nil { | ||
|
||
} | ||
}(resp.Body) | ||
|
||
if resp.StatusCode != http.StatusOK { | ||
return "", fmt.Errorf("request failed with status: %d %v", resp.StatusCode, resp.Status) | ||
} | ||
var response string | ||
var ok error | ||
if stream { | ||
response, ok = c.ParseStreamingResponse(resp) | ||
if ok != nil { | ||
return "", ok | ||
} | ||
s.AddMessage("assistant", response) | ||
} else { | ||
completion, ok := c.ParseResponse(resp) | ||
if ok != nil { | ||
fmt.Println("Error parsing completion:", ok) | ||
os.Exit(1) | ||
} | ||
fmt.Println(completion) | ||
s.AddMessage("assistant", completion) | ||
} | ||
err = s.Save() | ||
if err != nil { | ||
return "", err | ||
} | ||
|
||
return response, nil | ||
} | ||
|
||
func (c *Client) MarshalRequest(prompt string, stream bool, tokens int, model string, system string, | ||
session *session.Session) (*http.Request, error) { | ||
messages := prependContext(prompt, system, session) | ||
payload := CompletionRequest{ | ||
Model: model, | ||
Messages: messages, | ||
MaxTokens: tokens, | ||
Stream: stream, | ||
} | ||
jsonPayload, err := json.Marshal(payload) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to marshal request payload: %w", err) | ||
} | ||
|
||
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/chat/completions", v1apiBaseURL), bytes.NewBuffer(jsonPayload)) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed to create request: %w", err) | ||
} | ||
req.Header.Set("Content-Type", "application/json") | ||
req.Header.Set("Authorization", "Bearer "+c.config.APIKey) | ||
req.Header.Set("Accept", "application/json") | ||
return req, nil | ||
} | ||
|
||
func prependContext(prompt string, system string, session *session.Session) []Message { | ||
var messages []Message | ||
if system != "" { | ||
messages = append(messages, Message{ | ||
Role: "system", | ||
Content: system, | ||
}) | ||
} | ||
for _, message := range session.Messages { | ||
messages = append(messages, Message{ | ||
Role: message.Role, | ||
Content: message.Content, | ||
}) | ||
} | ||
messages = append(messages, Message{ | ||
Role: "user", | ||
Content: prompt, | ||
}) | ||
session.AddMessage("user", prompt) | ||
return messages | ||
} | ||
|
||
func (c *Client) ParseResponse(resp *http.Response) (string, error) { | ||
var completion CompletionResponse | ||
err := json.NewDecoder(resp.Body).Decode(&completion) | ||
if err != nil { | ||
return "", fmt.Errorf("failed to unmarshal response payload: %w", err) | ||
} | ||
return completion.Choices[0].Message.Content, nil | ||
} | ||
|
||
func (c *Client) ParseStreamingResponse(resp *http.Response) (string, error) { | ||
defer func(Body io.ReadCloser) { | ||
err := Body.Close() | ||
if err != nil { | ||
|
||
} | ||
}(resp.Body) | ||
|
||
var contentBuilder strings.Builder | ||
reader := bufio.NewReader(resp.Body) | ||
for { | ||
line, err := reader.ReadString('\n') | ||
if err != nil { | ||
if err == io.EOF { | ||
break | ||
} | ||
return "", fmt.Errorf("error reading response: %w", err) | ||
} | ||
line = strings.TrimSpace(line) | ||
if line == "" { | ||
continue | ||
} | ||
|
||
if line == "data: [DONE]" { | ||
break | ||
} | ||
|
||
if strings.HasPrefix(line, "data: ") { | ||
data := strings.TrimPrefix(line, "data: ") | ||
|
||
var event CompletionStreamResponse | ||
if err := json.Unmarshal([]byte(data), &event); err != nil { | ||
return "", fmt.Errorf("error unmarshaling event: %w", err) | ||
} | ||
content := event.Choices[0].Delta.Content | ||
fmt.Print(content) | ||
contentBuilder.WriteString(content) | ||
} | ||
} | ||
|
||
return contentBuilder.String(), nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
package openai | ||
|
||
import ( | ||
"fmt" | ||
"os" | ||
) | ||
|
||
type Config struct { | ||
APIKey string | ||
} | ||
|
||
func LoadConfig() (*Config, error) { | ||
apiKey, ok := os.LookupEnv("OPENAI_API_KEY") | ||
if !ok { | ||
return nil, fmt.Errorf("OPENAI_API_KEY environment variable not set") | ||
} | ||
|
||
return &Config{ | ||
APIKey: apiKey, | ||
}, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
package openai | ||
|
||
type CompletionRequest struct { | ||
Model string `json:"model"` | ||
Messages []Message `json:"messages"` | ||
Temperature float64 `json:"temperature,omitempty"` | ||
TopP float64 `json:"top_p,omitempty"` | ||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` | ||
PresencePenalty float64 `json:"presence_penalty,omitempty"` | ||
LogitBias interface{} `json:"logit_bias,omitempty"` | ||
N int `json:"n,omitempty"` | ||
LogProbs bool `json:"logprobs,omitempty"` | ||
TopLogProbs int `json:"top_logprobs,omitempty"` | ||
ResponseFormat string `json:"response_format,omitempty"` | ||
MaxTokens int `json:"max_tokens,omitempty"` | ||
Stream bool `json:"stream,omitempty"` | ||
StreamOptions interface{} `json:"stream_options,omitempty"` | ||
Seed int `json:"seed,omitempty"` | ||
Stop []string `json:"stop,omitempty"` | ||
Tools []struct { | ||
Type string `json:"type"` | ||
Function string `json:"function"` | ||
} `json:"tools,omitempty"` | ||
User string `json:"user,omitempty"` | ||
} | ||
|
||
type Message struct { | ||
Role string `json:"role"` | ||
Content string `json:"content"` | ||
} | ||
|
||
type CompletionResponse struct { | ||
Id string `json:"id"` | ||
Object string `json:"object"` | ||
Created int `json:"created"` | ||
Model string `json:"model"` | ||
Choices []struct { | ||
Message struct { | ||
Content string `json:"content"` | ||
Role string `json:"role"` | ||
} `json:"message"` | ||
FinishReason string `json:"finish_reason"` | ||
Index int `json:"index"` | ||
} | ||
Usage struct { | ||
PromptTokens int `json:"prompt_tokens"` | ||
CompletionTokens int `json:"completion_tokens"` | ||
TotalTokens int `json:"total_tokens"` | ||
} `json:"usage"` | ||
SystemFingerprint string `json:"system_fingerprint"` | ||
} | ||
|
||
type CompletionStreamResponse struct { | ||
ID string `json:"id"` | ||
Object string `json:"object"` | ||
Created int `json:"created"` | ||
Model string `json:"model"` | ||
Choices []struct { | ||
Index int `json:"index"` | ||
Delta struct { | ||
Role string `json:"role,omitempty"` | ||
Content string `json:"content,omitempty"` | ||
} `json:"delta"` | ||
FinishReason interface{} `json:"finish_reason"` | ||
Logprobs interface{} `json:"logprobs"` | ||
} `json:"choices"` | ||
Usage struct { | ||
PromptTokens int `json:"prompt_tokens"` | ||
CompletionTokens int `json:"completion_tokens"` | ||
TotalTokens int `json:"total_tokens"` | ||
} `json:"usage"` | ||
} |