forked from tmc/langchaingo
-
Notifications
You must be signed in to change notification settings - Fork 2
/
palm_llm.go
116 lines (96 loc) · 3.14 KB
/
palm_llm.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
// Package palm implements a langchaingo provider for the legacy Google
// Vertex AI PaLM models. Prefer the newer Gemini models in
// llms/googleai/vertex where possible.
package palm
import (
"context"
"errors"
"github.com/ankit-arora/langchaingo/callbacks"
"github.com/ankit-arora/langchaingo/llms"
"github.com/ankit-arora/langchaingo/llms/googleai/internal/palmclient"
)
var (
ErrEmptyResponse = errors.New("no response")
ErrMissingProjectID = errors.New("missing the GCP Project ID, set it in the GOOGLE_CLOUD_PROJECT environment variable") //nolint:lll
ErrUnexpectedResponseLength = errors.New("unexpected length of response")
ErrNotImplemented = errors.New("not implemented")
)
// LLM is a langchaingo model backed by the Vertex AI PaLM client.
// Construct it with New.
type LLM struct {
	// CallbacksHandler, when non-nil, is notified when GenerateContent starts,
	// when it finishes successfully, and when the completion request fails.
	CallbacksHandler callbacks.Handler
	// client issues the completion and embedding requests; New sets it.
	client *palmclient.PaLMClient
}
// Compile-time check that *LLM satisfies the llms.Model interface.
var _ llms.Model = (*LLM)(nil)
// Call sends prompt to the model and returns the generated completion text.
// It delegates to llms.GenerateFromSinglePrompt with this LLM as the model,
// forwarding any call options unchanged.
func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) {
	completion, err := llms.GenerateFromSinglePrompt(ctx, o, prompt, options...)
	return completion, err
}
// errNoTextContent is returned by GenerateContent when messages contains no
// content part that could be sent as a prompt.
var errNoTextContent = errors.New("no text content in messages")

// GenerateContent implements the Model interface.
//
// Only the first part of the first message is sent, and it must be an
// llms.TextContent. Its text becomes the single prompt of a PaLM completion
// request built from the MaxTokens, Temperature and StopWords call options.
//
// It returns these errors instead of panicking:
//   - errNoTextContent when messages is empty or the first message has no parts;
//   - ErrNotImplemented when the first part is not text;
//   - ErrEmptyResponse when the backend returns no completions.
//
// Every error, including client errors, is reported to CallbacksHandler (if set).
func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop, whitespace
	if o.CallbacksHandler != nil {
		o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages)
	}

	opts := llms.CallOptions{}
	for _, opt := range options {
		opt(&opts)
	}

	prompt, err := firstTextPrompt(messages)
	if err != nil {
		return nil, o.reportError(ctx, err)
	}

	results, err := o.client.CreateCompletion(ctx, &palmclient.CompletionRequest{
		Prompts:       []string{prompt},
		MaxTokens:     opts.MaxTokens,
		Temperature:   opts.Temperature,
		StopSequences: opts.StopWords,
	})
	if err != nil {
		return nil, o.reportError(ctx, err)
	}
	// Guard the results[0] access below: an empty result set used to panic.
	if len(results) == 0 {
		return nil, o.reportError(ctx, ErrEmptyResponse)
	}

	resp := &llms.ContentResponse{
		Choices: []*llms.ContentChoice{
			{
				Content: results[0].Text,
			},
		},
	}
	if o.CallbacksHandler != nil {
		o.CallbacksHandler.HandleLLMGenerateContentEnd(ctx, resp)
	}
	return resp, nil
}

// firstTextPrompt returns the text of the first part of the first message.
// It fails with errNoTextContent or ErrNotImplemented instead of panicking.
func firstTextPrompt(messages []llms.MessageContent) (string, error) {
	if len(messages) == 0 || len(messages[0].Parts) == 0 {
		return "", errNoTextContent
	}
	text, ok := messages[0].Parts[0].(llms.TextContent)
	if !ok {
		// Only plain text prompts are supported by this provider.
		return "", ErrNotImplemented
	}
	return text.Text, nil
}

// reportError forwards err to the callbacks handler, if one is set, and
// returns err unchanged so callers can write `return nil, o.reportError(...)`.
func (o *LLM) reportError(ctx context.Context, err error) error {
	if o.CallbacksHandler != nil {
		o.CallbacksHandler.HandleLLMError(ctx, err)
	}
	return err
}
// CreateEmbedding requests one embedding vector per entry of inputTexts.
//
// Outcomes:
//   - A client error is returned alongside an empty, non-nil slice.
//   - An empty result yields ErrEmptyResponse.
//   - A result whose length differs from len(inputTexts) is returned together
//     with ErrUnexpectedResponseLength.
func (o *LLM) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]float32, error) {
	req := &palmclient.EmbeddingRequest{Input: inputTexts}
	vectors, err := o.client.CreateEmbedding(ctx, req)

	switch {
	case err != nil:
		return [][]float32{}, err
	case len(vectors) == 0:
		return nil, ErrEmptyResponse
	case len(vectors) != len(inputTexts):
		// Hand back what was received so callers may inspect it.
		return vectors, ErrUnexpectedResponseLength
	default:
		return vectors, nil
	}
}
// New returns a new PaLM-backed LLM configured by opts.
//
// When the underlying client cannot be created (for example, when no project
// ID is configured, see ErrMissingProjectID), it returns a nil *LLM and the
// error. It no longer returns a half-initialized LLM with a nil client.
func New(opts ...Option) (*LLM, error) {
	client, err := newClient(opts...)
	if err != nil {
		return nil, err
	}
	return &LLM{client: client}, nil
}
// newClient builds a PaLM client from the package defaults, overridden by
// each Option in order. It returns ErrMissingProjectID when the resulting
// configuration has no project ID.
func newClient(opts ...Option) (*palmclient.PaLMClient, error) {
	// initOptions guards initOpts so the package defaults are set up only once
	// (presumably a sync.Once; verify at its declaration).
	initOptions.Do(initOpts)

	// Work on a private copy so per-call options never mutate the shared
	// defaults. The name cfg also avoids shadowing the options type.
	cfg := *defaultOptions
	for _, apply := range opts {
		apply(&cfg)
	}

	if len(cfg.projectID) == 0 {
		return nil, ErrMissingProjectID
	}
	return palmclient.New(cfg.projectID, cfg.clientOptions...)
}