```
Copyright (c) 2024, Google Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.
3. Neither the name of Google Inc. nor the names of its contributors
   may be used to endorse or promote products derived from this software without
   specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```


# Imports

In [None]:
import datetime
import os
import random
from typing import List, Optional

import google.auth
import google.auth.transport.requests
from google.cloud import aiplatform
from google.cloud.aiplatform.aiplatform import gapic
from google.cloud.aiplatform.aiplatform import jobs
from google.protobuf import json_format
from google.protobuf import struct_pb2
from google.protobuf import struct_pb2

# Authentication

The JSON credentials file referenced in the cell below is created by running the following command (to authenticate as a service account):

```
gcloud auth application-default login --impersonate-service-account SERVICE_ACCT
```

or the following command

```
gcloud auth application-default login 
```

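to authenticate with your own user account. By default, either command writes the credentials to `~/.config/gcloud/application_default_credentials.json` on Linux and macOS (`%APPDATA%\gcloud\application_default_credentials.json` on Windows); use that path in the cell below.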

This assumes that you have first [installed](https://cloud.google.com/sdk/docs/install) the `gcloud` CLI. For the first command, you also need a service account (see [[1]](https://cloud.google.com/iam/docs/service-account-overview), [[2]](https://cloud.google.com/iam/docs/service-accounts-create)), identified by `SERVICE_ACCT` above.

In [None]:
# Tell the Google auth libraries where the credentials JSON file lives.
# Replace the placeholder path with the file produced by `gcloud` above.
os.environ.update(
    GOOGLE_APPLICATION_CREDENTIALS='/path/to/your/credentials/json/file'
)

# Online predictions

## With raw audio

In [None]:
def predict_endpoint_sample(
    project: str,
    endpoint_id: str,
    raw_audio: list[list[float]],
    location: str = "us-west1",
    api_endpoint: Optional[str] = None,
) -> list[dict[str, float]]:
  """Sends raw audio clips to a Vertex AI endpoint for online prediction.

  Args:
    project: Google Cloud project that owns the endpoint.
    endpoint_id: ID of the deployed Vertex AI endpoint.
    raw_audio: One list of samples per clip; every clip must contain exactly
      32000 samples.
    location: Region of the endpoint.
    api_endpoint: Prediction API host. When omitted, it is derived from
      `location` (e.g. "us-west1-aiplatform.googleapis.com"), so the host and
      the region in the endpoint path cannot disagree.

  Returns:
    The `predictions` field of the prediction response (presumably one entry
    per clip -- confirm against the deployed model).

  Raises:
    ValueError: If `raw_audio` is empty or any clip is not exactly 32000
      samples long.
  """
  # Explicit exceptions rather than `assert`: asserts are stripped under -O.
  if not raw_audio:
    raise ValueError("raw_audio must contain at least one clip.")
  if any(len(clip) != 32000 for clip in raw_audio):
    raise ValueError("All clips should have exactly 32000 steps.")
  if api_endpoint is None:
    api_endpoint = f"{location}-aiplatform.googleapis.com"
  client_options = {'api_endpoint': api_endpoint}
  client = gapic.PredictionServiceClient(client_options=client_options)
  endpoint = client.endpoint_path(
      project=project, location=location, endpoint=endpoint_id
  )
  response = client.predict(endpoint=endpoint, instances=raw_audio)
  return response.predictions


In [None]:
# Smoke test: four clips of uniform noise in [0, 1), each 32000 samples long
# (the clip length predict_endpoint_sample checks for).
endpoint_id, project = '200', '4016704501'
raw_audio = [[random.random() for _ in range(32000)] for _clip in range(4)]
embeddings = predict_endpoint_sample(
    project=project,
    endpoint_id=endpoint_id,
    raw_audio=raw_audio,
)

## With GCS bucket URIs

In [None]:

# Application Default Credentials (see "Authentication" above). Note that this
# also rebinds the notebook-level name `project` to the project reported by
# ADC, replacing any value assigned in an earlier cell.
gcs_creds, project = google.auth.default()


def initial_token_refresh():
  """Obtain short lived credentials for your GCS bucket.

  Refreshes the module-level `gcs_creds` in place so that `gcs_creds.token`
  holds a current bearer token, then prints when that token expires.

  Raises:
    RuntimeError: If the credentials are still invalid after the refresh.
  """
  auth_req = google.auth.transport.requests.Request()
  gcs_creds.refresh(auth_req)
  if not gcs_creds.valid:
    raise RuntimeError('Unexpected error: GCS Credentials are invalid')
  expiry = gcs_creds.expiry
  if expiry is None:
    # Guard against credentials that do not report an expiry time.
    print('Token has no reported expiry time')
    return
  # `expiry` is treated as UTC (it is printed as such below). Compare it with
  # the current UTC time, matching its naive/aware-ness; datetime.utcnow() is
  # deprecated since Python 3.12.
  now = datetime.datetime.now(datetime.timezone.utc)
  if expiry.tzinfo is None:
    now = now.replace(tzinfo=None)
  # Whole minutes, so the message reads "59 minutes" rather than "59.0".
  minutes_until_expiry = int((expiry - now).total_seconds() // 60)
  print(
      'Token will expire at'
      f' {expiry.strftime("%Y-%m-%d %H:%M:%S")} UTC'
      f' ({minutes_until_expiry} minutes)'
  )


initial_token_refresh()


In [None]:

# Short alias for the Vertex AI prediction client class; it is the same
# object as aiplatform.aiplatform.gapic.PredictionServiceClient, reached via
# the `gapic` module imported at the top of the notebook.
PredictionServiceClient = gapic.PredictionServiceClient

# Online-prediction endpoint that reads its inputs from GCS object URIs.
vertex_endpoint_id = '220'
vertex_endpoint_project_id = '4016704501'
vertex_endpoint_location = 'us-west1'
# Replace with the bucket that holds your input files.
gcs_bucket_name = 'YOUR_BUCKET_NAME'


def create_prediction_service_client_and_endpoint_path():
  """Builds a Vertex AI prediction client and the endpoint resource path.

  The regional API host, project, location and endpoint ID all come from the
  module-level `vertex_endpoint_*` settings. The client can be reused for many
  requests, so it only needs to be created once.

  Returns:
    A `(client, endpoint_path)` tuple.
  """
  api_host = f'{vertex_endpoint_location}-aiplatform.googleapis.com'
  client = PredictionServiceClient(client_options={'api_endpoint': api_host})
  path = client.endpoint_path(
      project=vertex_endpoint_project_id,
      location=vertex_endpoint_location,
      endpoint=vertex_endpoint_id,
  )
  return client, path


def get_prediction_instances(image_uris: List[str]):
  """Converts object paths into Vertex PredictionService instances.

  Each instance is a protobuf `Value` parsed from a JSON dict that names the
  bucket (`gcs_bucket_name`), the object within it, and the current bearer
  token held by `gcs_creds`.

  Args:
    image_uris: Object paths inside the bucket. Despite the parameter name,
      the notebook passes audio files here (e.g. 'data/test.wav').

  Returns:
    A list of `struct_pb2.Value`, one per path, in input order.
  """

  def _to_value(object_uri):
    # One JSON payload per object, converted to a protobuf Value.
    payload = {
        'bucket_name': gcs_bucket_name,
        'object_uri': object_uri,
        'bearer_token': gcs_creds.token,
    }
    return json_format.ParseDict(payload, struct_pb2.Value())

  return [_to_value(uri) for uri in image_uris]


def predict(
    client: PredictionServiceClient, endpoint_path: str, image_uris: List[str]
):
  """Sends one online prediction request for the given object paths.

  Args:
    client: Prediction service client used to issue the request.
    endpoint_path: Full resource name of the Vertex endpoint.
    image_uris: Object paths inside `gcs_bucket_name`; each becomes one
      instance.

  Returns:
    Whatever `client.predict` returns for the request.
  """
  request_instances = get_prediction_instances(image_uris)
  # No request-level parameters are sent: an empty struct Value.
  empty_parameters = json_format.ParseDict({}, struct_pb2.Value())
  return client.predict(
      endpoint=endpoint_path,
      instances=request_instances,
      parameters=empty_parameters,
  )



In [None]:

# The same object path twice, to exercise a request with several instances.
client, endpoint_path = create_prediction_service_client_and_endpoint_path()
predictions = predict(
    client,
    endpoint_path=endpoint_path,
    image_uris=['data/test.wav'] * 2,
)

# Batch prediction job

In [None]:


# Configuration for a Vertex AI batch prediction job. Replace the placeholder
# bucket, paths and display name with your own values.
BUCKET_NAME = 'YOUR_BUCKET_NAME'
PROJECT_NAME = '4016704501'
LOCATION = 'us-west1'
# NOTE(review): this ID is interpolated into a *model* resource name below
# ('.../models/{id}'), while the IDs used in the online cells are endpoint
# IDs; confirm which model ID should be used here.
MODEL_ENDPOINT = '200'  # or '220' for the endpoint accepting GCS URIs
JOB_DISPLAY_NAME = 'your_job_display_name'
# Build the model path from LOCATION (previously a hard-coded 'us-west1'), so
# the model's region and the job's location cannot drift apart.
MODEL_NAME = (
    f'projects/{PROJECT_NAME}/locations/{LOCATION}/models/{MODEL_ENDPOINT}'
)
INSTANCES_FORMAT = 'jsonl'
GCS_SOURCE = f'gs://{BUCKET_NAME}/path/to/your/input/jsonl/file'
GCS_DESTINATION_PREFIX = f'gs://{BUCKET_NAME}/path/to/your/output/jsonl/file'
PREDICTIONS_FORMAT = 'jsonl'
ACCELERATOR_TYPE = 'NVIDIA_TESLA_V100'
ACCELERATOR_COUNT = 1
MACHINE_TYPE = 'n1-standard-8'

jobs.BatchPredictionJob.submit(
    job_display_name=JOB_DISPLAY_NAME,
    model_name=MODEL_NAME,
    instances_format=INSTANCES_FORMAT,
    gcs_source=GCS_SOURCE,
    gcs_destination_prefix=GCS_DESTINATION_PREFIX,
    predictions_format=PREDICTIONS_FORMAT,
    accelerator_type=ACCELERATOR_TYPE,
    accelerator_count=ACCELERATOR_COUNT,
    machine_type=MACHINE_TYPE,
    project=PROJECT_NAME,
    location=LOCATION,
)