-
Notifications
You must be signed in to change notification settings - Fork 0
/
openai.go
112 lines (88 loc) · 2.57 KB
/
openai.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
package openai
import (
"fmt"
"net/http"
"time"
"github.com/askasoft/pango/log"
"github.com/askasoft/pango/sdk"
)
// OpenAI holds the connection, authentication, logging and retry settings
// shared by the request helpers and API methods defined on it.
type OpenAI struct {
	// Domain is the API host name; endpoint URLs are built as
	// "https://" + Domain + "/v1" + path.
	Domain string

	// Apikey is sent as the bearer token in the Authorization header.
	Apikey string

	// Transport is the RoundTripper given to the http.Client created for each call.
	Transport http.RoundTripper

	// Timeout is the timeout given to the http.Client created for each call.
	Timeout time.Duration

	// Logger receives the debug line and the HTTP request/response traces,
	// and is also handed to sdk.RetryForError.
	Logger log.Logger

	// MaxRetries is passed to sdk.RetryForError by doPostWithRetry.
	// NOTE(review): presumably the maximum number of retry attempts — confirm in sdk.
	MaxRetries int

	// RetryAfter is passed to sdk.NewNetError and decodeResponse.
	// NOTE(review): presumably the suggested delay before retrying — confirm in sdk.
	RetryAfter time.Duration

	// AbortOnRetry is passed to sdk.RetryForError by doPostWithRetry.
	// NOTE(review): presumably polled to cancel a pending retry — confirm in sdk.
	AbortOnRetry func() bool

	// AbortInterval is passed to sdk.RetryForError by doPostWithRetry.
	// NOTE(review): presumably the polling interval for AbortOnRetry — confirm in sdk.
	AbortInterval time.Duration
}
// endpoint returns the absolute URL of an API resource. The path is produced
// by formatting format with a, and is appended to "https://<Domain>/v1".
func (oai *OpenAI) endpoint(format string, a ...any) string {
	path := fmt.Sprintf(format, a...)
	return fmt.Sprintf("https://%s/v1%s", oai.Domain, path)
}
// call sends req with an http.Client configured from oai.Transport and
// oai.Timeout. The request is logged at debug level (when a Logger is set)
// and traced before it is sent; the response is traced on success.
// A failure from the client is wrapped with sdk.NewNetError together with
// oai.RetryAfter.
func (oai *OpenAI) call(req *http.Request) (res *http.Response, err error) {
	if logger := oai.Logger; logger != nil {
		logger.Debugf("%s %s", req.Method, req.URL)
	}
	rid := log.TraceHttpRequest(oai.Logger, req)

	client := &http.Client{Transport: oai.Transport, Timeout: oai.Timeout}
	res, err = client.Do(req)
	if err != nil {
		return res, sdk.NewNetError(err, oai.RetryAfter)
	}

	log.TraceHttpResponse(oai.Logger, res, rid)
	return res, nil
}
// authAndCall adds the authentication headers to req and then sends it.
func (oai *OpenAI) authAndCall(req *http.Request) (res *http.Response, err error) {
	oai.authenticate(req)
	res, err = oai.call(req)
	return res, err
}
// authenticate prepares the headers of req: Content-Type falls back to
// contentTypeJSON when the request does not carry one, and Authorization is
// always set to the bearer token built from oai.Apikey.
func (oai *OpenAI) authenticate(req *http.Request) {
	hdr := req.Header
	if hdr.Get("Content-Type") == "" {
		hdr.Set("Content-Type", contentTypeJSON)
	}
	hdr.Set("Authorization", "Bearer "+oai.Apikey)
}
// doCall sends the authenticated request and, when the call itself succeeds,
// hands the response to decodeResponse together with result and
// oai.RetryAfter. The first error encountered is returned.
func (oai *OpenAI) doCall(req *http.Request, result any) error {
	res, err := oai.authAndCall(req)
	if err == nil {
		err = decodeResponse(res, result, oai.RetryAfter)
	}
	return err
}
// doPostWithRetry runs doPost through sdk.RetryForError, configured with the
// retry settings (MaxRetries, AbortOnRetry, AbortInterval) and Logger of oai.
func (oai *OpenAI) doPostWithRetry(url string, source, result any) error {
	post := func() error {
		return oai.doPost(url, source, result)
	}
	return sdk.RetryForError(post, oai.MaxRetries, oai.AbortOnRetry, oai.AbortInterval, oai.Logger)
}
// doPost builds a POST request for url whose body is produced from source by
// buildJsonRequest, sets the Content-Type that helper reports (if any), and
// executes the request through doCall, which receives result.
func (oai *OpenAI) doPost(url string, source, result any) error {
	body, contentType, err := buildJsonRequest(source)
	if err != nil {
		return err
	}

	req, err := http.NewRequest(http.MethodPost, url, body)
	if err != nil {
		return err
	}

	if contentType != "" {
		req.Header.Set("Content-Type", contentType)
	}
	return oai.doCall(req, result)
}
// CreateChatCompletion posts req to the /chat/completions endpoint, with
// retries, and returns the response value passed to the call together with
// any error (the response pointer is non-nil even when an error is returned).
//
// https://platform.openai.com/docs/api-reference/chat/create
func (oai *OpenAI) CreateChatCompletion(req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
	out := &ChatCompletionResponse{}
	err := oai.doPostWithRetry(oai.endpoint("/chat/completions"), req, out)
	return out, err
}
// CreateTextEmbeddings posts req to the /embeddings endpoint, with retries,
// and returns the response value passed to the call together with any error
// (the response pointer is non-nil even when an error is returned).
//
// https://platform.openai.com/docs/api-reference/embeddings/create
func (oai *OpenAI) CreateTextEmbeddings(req *TextEmbeddingsRequest) (*TextEmbeddingsResponse, error) {
	out := &TextEmbeddingsResponse{}
	err := oai.doPostWithRetry(oai.endpoint("/embeddings"), req, out)
	return out, err
}