
Commit 3b7dff1: openAI support

aakashshankar committed Jun 3, 2024
1 parent 1c16f10 commit 3b7dff1
Showing 9 changed files with 292 additions and 4 deletions.
10 changes: 9 additions & 1 deletion README.md
@@ -5,7 +5,8 @@ A CLI tool to interact with LLMs. Provide your API keys in this format: `<comple
The supported completers are:
- [x] `ANTHROPIC`
- [x] `MISTRAL`
- [ ] `OPENAI`
- [x] `OPENAI`
- [ ] `PERPLEXITY`

## Installation ⚙️

@@ -28,6 +29,12 @@ Your conversation history can be shared among all supported LLMs. This allows yo
### 3. Terminal assistance
Prefixing any command with `llm` provides a clear and concise explanation of that command, along with examples demonstrating its usage.

### Coming soon 🤫
- [ ] Default models for each variant
- [ ] Carefully crafted system prompts for terminal assistance
- [ ] Multiple sessions
- [ ] UI using [tview](https://github.com/rivo/tview)

## Usage 💻

### 1. CLIs
@@ -50,5 +57,6 @@ And of course, add `/path/to/repo/bin` to your `$PATH` to use the CLI from anywh
Create an environment variable `DEFAULT_COMPLETER` with the name of the completer you want to use (for example, `export DEFAULT_COMPLETER=gpt`). The supported values are:
- `claude`
- `mistral`
- `gpt`

Then, prefix any command you want to understand with `llm` (for example, `llm tar -xzf archive.tar.gz`) to get an explanation of how it works, along with examples.
1 change: 1 addition & 0 deletions api/client.go
@@ -37,6 +37,7 @@ func Chat(llmType string) {
fmt.Println("Error getting default model:", err)
os.Exit(1)
}
fmt.Println("Entering chat mode. Type 'exit' to exit.")
for {
fmt.Print("> ")
text, _ := reader.ReadString('\n')
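For context, a minimal, runnable sketch of the chat loop this hunk lives in. Only the lines shown in the diff are from the commit; the reader setup, the exit check, and the echo placeholder are assumptions:

package main

import (
	"bufio"
	"fmt"
	"os"
	"strings"
)

func main() {
	reader := bufio.NewReader(os.Stdin)
	fmt.Println("Entering chat mode. Type 'exit' to exit.")
	for {
		fmt.Print("> ")
		text, err := reader.ReadString('\n')
		if err != nil {
			break // EOF or a read error ends the session
		}
		if strings.TrimSpace(text) == "exit" {
			break
		}
		// The real client forwards the input to the selected completer;
		// this sketch just echoes it back.
		fmt.Println("(echo)", strings.TrimSpace(text))
	}
}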
3 changes: 2 additions & 1 deletion cmd/static/variants.json
@@ -1,4 +1,5 @@
[
  "claude",
  "mistral"
  "mistral",
  "gpt"
]
1 change: 0 additions & 1 deletion cmd/variants.go
@@ -15,7 +15,6 @@ var variants []byte

var stream bool
var tokens int
var model string
var system string
var clr bool

3 changes: 2 additions & 1 deletion defaults/static/model_configs.json
@@ -1,6 +1,7 @@
{
  "default_models": {
    "claude": "claude-3-sonnet-20240229",
    "mistral": "mistral-large-latest"
    "mistral": "mistral-large-latest",
    "gpt": "gpt-4o"
  }
}
7 changes: 7 additions & 0 deletions llm/factory.go
@@ -5,6 +5,7 @@ import (
"github.com/aakashshankar/llm-cli/session"
"github.com/aakashshankar/llm-cli/variants/anthropic"
"github.com/aakashshankar/llm-cli/variants/mistral"
"github.com/aakashshankar/llm-cli/variants/openai"
"net/http"
)

@@ -29,6 +30,12 @@ func NewLLM(llmType string) (LLM, error) {
return nil, fmt.Errorf("error loading config for Mistral: %w", err)
}
return mistral.NewClient(config), nil
case "gpt":
config, err := openai.LoadConfig()
if err != nil {
return nil, fmt.Errorf("error loading config for OpenAI: %w", err)
}
return openai.NewClient(config), nil
default:
return nil, fmt.Errorf("unknown LLM type: %s", llmType)
}
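A minimal sketch of how a caller might exercise the new branch. The import path follows the module path used elsewhere in this diff, the Prompt signature is taken from variants/openai/client.go below, and the prompt text and token count are placeholders:

package main

import (
	"fmt"
	"log"

	"github.com/aakashshankar/llm-cli/llm"
)

func main() {
	// "gpt" now routes to the OpenAI client added in this commit.
	client, err := llm.NewLLM("gpt")
	if err != nil {
		log.Fatal(err)
	}
	// Arguments: prompt, stream, tokens, model, system, clear.
	reply, err := client.Prompt("summarize this repo in one line", false, 256, "gpt-4o", "", false)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(reply)
}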
178 changes: 178 additions & 0 deletions variants/openai/client.go
@@ -0,0 +1,178 @@
package openai

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"github.com/aakashshankar/llm-cli/session"
	"io"
	"net/http"
	"os"
	"strings"
)

const v1apiBaseURL = "https://api.openai.com/v1"

type Client struct {
	config *Config
	client *http.Client
}

func NewClient(config *Config) *Client {
	return &Client{
		config: config,
		client: &http.Client{},
	}
}

// Prompt sends one completion request, prints the reply, records both sides
// of the exchange in the session, and returns the assistant's text.
func (c *Client) Prompt(prompt string, stream bool, tokens int, model string, system string, clear bool) (string, error) {
	if clear {
		session.ClearSession()
	}
	s := session.NewSession()
	if err := s.LoadLatest(); err != nil {
		fmt.Println("Error loading session:", err)
		os.Exit(1)
	}
	req, err := c.MarshalRequest(prompt, stream, tokens, model, system, s)
	if err != nil {
		return "", err
	}
	resp, err := c.client.Do(req)
	if err != nil {
		return "", fmt.Errorf("failed to send request: %w", err)
	}
	defer func() {
		_ = resp.Body.Close()
	}()

	if resp.StatusCode != http.StatusOK {
		// resp.Status already includes the numeric code, e.g. "401 Unauthorized".
		return "", fmt.Errorf("request failed with status: %s", resp.Status)
	}

	var response string
	if stream {
		response, err = c.ParseStreamingResponse(resp)
		if err != nil {
			return "", err
		}
	} else {
		response, err = c.ParseResponse(resp)
		if err != nil {
			return "", fmt.Errorf("error parsing completion: %w", err)
		}
		fmt.Println(response)
	}
	s.AddMessage("assistant", response)
	if err := s.Save(); err != nil {
		return "", err
	}

	return response, nil
}

// MarshalRequest builds the POST /v1/chat/completions request, folding the
// session history and optional system prompt into the message list.
func (c *Client) MarshalRequest(prompt string, stream bool, tokens int, model string, system string,
	session *session.Session) (*http.Request, error) {
	messages := prependContext(prompt, system, session)
	payload := CompletionRequest{
		Model:     model,
		Messages:  messages,
		MaxTokens: tokens,
		Stream:    stream,
	}
	jsonPayload, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request payload: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/chat/completions", v1apiBaseURL), bytes.NewBuffer(jsonPayload))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+c.config.APIKey)
	req.Header.Set("Accept", "application/json")
	return req, nil
}

// prependContext assembles system prompt + prior session messages + the new
// user prompt, and records the user prompt in the session.
func prependContext(prompt string, system string, session *session.Session) []Message {
	var messages []Message
	if system != "" {
		messages = append(messages, Message{
			Role:    "system",
			Content: system,
		})
	}
	for _, message := range session.Messages {
		messages = append(messages, Message{
			Role:    message.Role,
			Content: message.Content,
		})
	}
	messages = append(messages, Message{
		Role:    "user",
		Content: prompt,
	})
	session.AddMessage("user", prompt)
	return messages
}

func (c *Client) ParseResponse(resp *http.Response) (string, error) {
	var completion CompletionResponse
	if err := json.NewDecoder(resp.Body).Decode(&completion); err != nil {
		return "", fmt.Errorf("failed to unmarshal response payload: %w", err)
	}
	if len(completion.Choices) == 0 {
		return "", fmt.Errorf("response contained no choices")
	}
	return completion.Choices[0].Message.Content, nil
}

// ParseStreamingResponse consumes the server-sent-event stream, printing each
// content delta as it arrives and returning the concatenated reply. The
// response body is closed by the caller.
func (c *Client) ParseStreamingResponse(resp *http.Response) (string, error) {
	var contentBuilder strings.Builder
	reader := bufio.NewReader(resp.Body)
	for {
		line, err := reader.ReadString('\n')
		if err != nil {
			if err == io.EOF {
				break
			}
			return "", fmt.Errorf("error reading response: %w", err)
		}
		line = strings.TrimSpace(line)
		if line == "" {
			continue
		}

		if line == "data: [DONE]" {
			break
		}

		if strings.HasPrefix(line, "data: ") {
			data := strings.TrimPrefix(line, "data: ")

			var event CompletionStreamResponse
			if err := json.Unmarshal([]byte(data), &event); err != nil {
				return "", fmt.Errorf("error unmarshaling event: %w", err)
			}
			if len(event.Choices) == 0 {
				// Defensive: skip chunks that carry no choices.
				continue
			}
			content := event.Choices[0].Delta.Content
			fmt.Print(content)
			contentBuilder.WriteString(content)
		}
	}

	return contentBuilder.String(), nil
}
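
To make the streaming path concrete, here is a standalone sketch that pushes one illustrative `data:` line through the same trim-and-unmarshal steps as ParseStreamingResponse. The event payload is fabricated purely for illustration, and the chunk type is a trimmed copy of CompletionStreamResponse:

package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// Trimmed copy of CompletionStreamResponse with only the fields read here.
type chunk struct {
	Choices []struct {
		Delta struct {
			Content string `json:"content"`
		} `json:"delta"`
	} `json:"choices"`
}

func main() {
	// Illustrative SSE line in the shape the parser expects.
	line := `data: {"choices":[{"delta":{"content":"Hello"}}]}`
	data := strings.TrimPrefix(line, "data: ")
	var ev chunk
	if err := json.Unmarshal([]byte(data), &ev); err != nil {
		panic(err)
	}
	fmt.Print(ev.Choices[0].Delta.Content) // prints "Hello"
}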
21 changes: 21 additions & 0 deletions variants/openai/config.go
@@ -0,0 +1,21 @@
package openai

import (
	"fmt"
	"os"
)

type Config struct {
	APIKey string
}

func LoadConfig() (*Config, error) {
	apiKey, ok := os.LookupEnv("OPENAI_API_KEY")
	if !ok {
		return nil, fmt.Errorf("OPENAI_API_KEY environment variable not set")
	}

	return &Config{
		APIKey: apiKey,
	}, nil
}
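
A quick sketch of the loader's behavior; setting the key inside the program is for demonstration only, and the key value is a placeholder:

package main

import (
	"fmt"
	"os"

	"github.com/aakashshankar/llm-cli/variants/openai"
)

func main() {
	// Normally exported in your shell profile; set here only for the demo.
	_ = os.Setenv("OPENAI_API_KEY", "sk-placeholder")
	cfg, err := openai.LoadConfig()
	if err != nil {
		fmt.Println(err) // printed when OPENAI_API_KEY is unset
		return
	}
	fmt.Println("loaded key of length", len(cfg.APIKey))
}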
72 changes: 72 additions & 0 deletions variants/openai/types.go
@@ -0,0 +1,72 @@
package openai

type CompletionRequest struct {
	Model            string      `json:"model"`
	Messages         []Message   `json:"messages"`
	Temperature      float64     `json:"temperature,omitempty"`
	TopP             float64     `json:"top_p,omitempty"`
	FrequencyPenalty float64     `json:"frequency_penalty,omitempty"`
	PresencePenalty  float64     `json:"presence_penalty,omitempty"`
	LogitBias        interface{} `json:"logit_bias,omitempty"`
	N                int         `json:"n,omitempty"`
	LogProbs         bool        `json:"logprobs,omitempty"`
	TopLogProbs      int         `json:"top_logprobs,omitempty"`
	ResponseFormat   string      `json:"response_format,omitempty"`
	MaxTokens        int         `json:"max_tokens,omitempty"`
	Stream           bool        `json:"stream,omitempty"`
	StreamOptions    interface{} `json:"stream_options,omitempty"`
	Seed             int         `json:"seed,omitempty"`
	Stop             []string    `json:"stop,omitempty"`
	Tools            []struct {
		Type     string      `json:"type"`
		Function interface{} `json:"function"` // function definition object (name, description, parameters)
	} `json:"tools,omitempty"`
	User string `json:"user,omitempty"`
}

type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type CompletionResponse struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int    `json:"created"`
	Model   string `json:"model"`
	Choices []struct {
		Message struct {
			Content string `json:"content"`
			Role    string `json:"role"`
		} `json:"message"`
		FinishReason string `json:"finish_reason"`
		Index        int    `json:"index"`
	} `json:"choices"`
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
	SystemFingerprint string `json:"system_fingerprint"`
}

type CompletionStreamResponse struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int    `json:"created"`
	Model   string `json:"model"`
	Choices []struct {
		Index int `json:"index"`
		Delta struct {
			Role    string `json:"role,omitempty"`
			Content string `json:"content,omitempty"`
		} `json:"delta"`
		FinishReason interface{} `json:"finish_reason"`
		Logprobs     interface{} `json:"logprobs"`
	} `json:"choices"`
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
}
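
Since most request fields carry omitempty, unset knobs disappear from the wire payload. A standalone sketch with trimmed copies of the types above (field values are placeholders):

package main

import (
	"encoding/json"
	"fmt"
)

type message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type completionRequest struct {
	Model     string    `json:"model"`
	Messages  []message `json:"messages"`
	MaxTokens int       `json:"max_tokens,omitempty"`
	Stream    bool      `json:"stream,omitempty"`
}

func main() {
	req := completionRequest{
		Model:     "gpt-4o",
		Messages:  []message{{Role: "user", Content: "hello"}},
		MaxTokens: 64,
		// Stream is false, so omitempty drops it from the JSON.
	}
	b, err := json.Marshal(req)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(b))
	// {"model":"gpt-4o","messages":[{"role":"user","content":"hello"}],"max_tokens":64}
}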
