diff --git a/aiplatform/snippets/embeddings.go b/aiplatform/snippets/embeddings.go index be3f0816db..0209e0bda6 100644 --- a/aiplatform/snippets/embeddings.go +++ b/aiplatform/snippets/embeddings.go @@ -28,9 +28,14 @@ import ( ) func embedTexts( - apiEndpoint, project, model string, texts []string, task string) ([][]float32, error) { + project, location string, texts []string) ([][]float32, error) { ctx := context.Background() + apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location) + model := "text-embedding-004" + task := "QUESTION_ANSWERING" + customOutputDimensionality := 5 + client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint)) if err != nil { return nil, err @@ -38,7 +43,6 @@ func embedTexts( defer client.Close() match := regexp.MustCompile(`^(\w+-\w+)`).FindStringSubmatch(apiEndpoint) - location := "us-central1" if match != nil { location = match[1] } @@ -52,10 +56,17 @@ func embedTexts( }, }) } + outputDimensionality := structpb.NewNullValue() + outputDimensionality = structpb.NewNumberValue(float64(customOutputDimensionality)) + + params := structpb.NewStructValue(&structpb.Struct{ + Fields: map[string]*structpb.Value{"outputDimensionality": outputDimensionality}, + }) req := &aiplatformpb.PredictRequest{ - Endpoint: endpoint, - Instances: instances, + Endpoint: endpoint, + Instances: instances, + Parameters: params, } resp, err := client.Predict(ctx, req) if err != nil { diff --git a/aiplatform/snippets/embeddings_test.go b/aiplatform/snippets/embeddings_test.go index e5ebf7f776..a286947e98 100644 --- a/aiplatform/snippets/embeddings_test.go +++ b/aiplatform/snippets/embeddings_test.go @@ -22,15 +22,20 @@ import ( func TestGenerateEmbeddings(t *testing.T) { tc := testutil.SystemTest(t) - apiEndpoint := "us-central1-aiplatform.googleapis.com:443" - model := "textembedding-gecko@003" texts := []string{"banana muffins? ", "banana bread? banana muffins?"} - embeddings, err := embedTexts(apiEndpoint, tc.ProjectID, model, texts, "RETRIEVAL_DOCUMENT") + dimensionality := 5 + location := "us-central1" + embeddings, err := embedTexts(tc.ProjectID, location, texts) if err != nil { t.Fatal(err) } - if len(embeddings) != len(texts) || len(embeddings[0]) != 768 { - t.Errorf("len(embeddings), len(embeddings[0]) = %d, %d, want %d, 768", len(embeddings), len(embeddings[0]), len(texts)) + + embeddingsLen := len(embeddings) + textsLen := len(texts) + embeddingDimensionality := len(embeddings[0]) + + if embeddingsLen != textsLen || embeddingDimensionality != dimensionality { + t.Errorf("embeddingsLen, embeddingDimensionality = %d, %d, want %d, %d", embeddingsLen, embeddingDimensionality, textsLen, dimensionality) } } diff --git a/aiplatform/text-embeddings/embeddings.go b/aiplatform/text-embeddings/embeddings.go deleted file mode 100644 index 320813e5e4..0000000000 --- a/aiplatform/text-embeddings/embeddings.go +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright 2023 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package snippets - -// [START aiplatform_text_embeddings] -import ( - "context" - "fmt" - "io" - - aiplatform "cloud.google.com/go/aiplatform/apiv1beta1" - "cloud.google.com/go/aiplatform/apiv1beta1/aiplatformpb" - "google.golang.org/api/option" - "google.golang.org/protobuf/types/known/structpb" -) - -// generateEmbeddings creates embeddings from text provided. -func generateEmbeddings(w io.Writer, prompt, project, location, publisher, model string) error { - ctx := context.Background() - - apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location) - - client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint)) - if err != nil { - fmt.Fprintf(w, "unable to create prediction client: %v", err) - return err - } - defer client.Close() - - // PredictRequest requires an endpoint, instances, and parameters - // Endpoint - base := fmt.Sprintf("projects/%s/locations/%s/publishers/%s/models", project, location, publisher) - url := fmt.Sprintf("%s/%s", base, model) - - // Instances: the prompt - promptValue, err := structpb.NewValue(map[string]interface{}{ - "content": prompt, - }) - if err != nil { - fmt.Fprintf(w, "unable to convert prompt to Value: %v", err) - return err - } - - // PredictRequest: create the model prediction request - req := &aiplatformpb.PredictRequest{ - Endpoint: url, - Instances: []*structpb.Value{promptValue}, - } - - // PredictResponse: receive the response from the model - resp, err := client.Predict(ctx, req) - if err != nil { - fmt.Fprintf(w, "error in prediction: %v", err) - return err - } - - fmt.Fprintf(w, "embeddings generated: %v", resp.Predictions[0]) - return nil -} - -// [END aiplatform_text_embeddings] diff --git a/aiplatform/text-embeddings/embeddings_test.go b/aiplatform/text-embeddings/embeddings_test.go deleted file mode 100644 index 21bf6c64f5..0000000000 --- a/aiplatform/text-embeddings/embeddings_test.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2023 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package snippets - -import ( - "bytes" - "strings" - "testing" - - "github.com/GoogleCloudPlatform/golang-samples/internal/testutil" -) - -func TestGenerateEmbeddings(t *testing.T) { - tc := testutil.SystemTest(t) - - prompt := "hello, say something nice." - projectID := tc.ProjectID - location := "us-central1" - publisher := "google" - model := "textembedding-gecko" - - var buf bytes.Buffer - if err := generateEmbeddings(&buf, prompt, projectID, location, publisher, model); err != nil { - t.Fatal(err) - } - - if got := buf.String(); !strings.Contains(got, "embeddings generated:") { - t.Error("generated embeddings content not found in response") - } - -}