# PythonプログラムからVertex AI エンドポイントを利用する

## (非常時用) Google Colab で実行する場合のセットアップ

In [None]:
# Environment bootstrap: when running inside Google Colab, install the
# Vertex AI SDK, authenticate, mount Drive and restart the kernel.
# Outside Colab this cell only prints a message.
import sys

# Colab is detected by checking whether the `google.colab` module has
# already been loaded into the interpreter.
module_list = sys.modules
ENV_COLAB = ('google.colab' in module_list)

if ENV_COLAB:
    print("google_colab")
    
    # IPython shell magic: install the Vertex AI SDK quietly.
    !pip install -q google-cloud-aiplatform
    # presumably prompts for Google account credentials -- confirm in Colab docs
    from google.colab import auth
    auth.authenticate_user()
    
    # Mount Google Drive under /content/drive.
    from google.colab import drive
    drive.mount('/content/drive')

    import os
    # The restart is skipped when the IS_TESTING environment variable is set.
    if not os.getenv("IS_TESTING"):
        # Restart the kernel so the freshly installed packages are picked up.
        import IPython
        app = IPython.Application.instance()
        app.kernel.do_shutdown(True)

else:
    print("Not google_colab")

## 必要なライブラリのインストール、インポート

In [None]:
# IPython shell magic: quietly install a pinned OpenCV build.
# NOTE(review): the pin is exact; confirm a wheel for this version exists
# for the Python version of the runtime in use.
!pip install -q opencv-python==4.1.2.30

In [None]:
import base64

import cv2
import matplotlib.pyplot as plt
from google.cloud import aiplatform
from google.cloud.aiplatform.gapic.schema import predict

## 関数を定義

In [None]:
# 参考: https://github.com/googleapis/python-aiplatform/blob/main/samples/snippets/prediction_service/predict_image_object_detection_sample.py

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# [START aiplatform_predict_image_object_detection_sample]
def predict_image_object_detection(
    project: str,
    endpoint_id: str,
    filename: str,
    location: str = "us-central1",
    api_endpoint: "str | None" = None,
    confidence_threshold: float = 0.5,
    max_predictions: int = 10,
) -> dict:
    """Send one local image to an object-detection endpoint and return its prediction.

    Args:
        project: Project name or ID that owns the endpoint.
        endpoint_id: ID of the deployed endpoint.
        filename: Path of the local image file to send.
        location: Region of the endpoint.
        api_endpoint: Host name of the regional API service. When omitted
            (``None``) it is derived from ``location`` as
            ``"<location>-aiplatform.googleapis.com"``, so that passing only
            ``location`` targets the matching regional service. With the
            default ``location`` this yields the same host as before.
        confidence_threshold: Forwarded as the ``confidence_threshold``
            prediction parameter.
        max_predictions: Forwarded as the ``max_predictions`` prediction
            parameter.

    Returns:
        The prediction for the submitted image converted to a ``dict``.

    Raises:
        RuntimeError: If the response contains no predictions.
    """
    # The AI Platform services require regional API endpoints, so the host
    # must agree with `location` unless the caller overrides it explicitly.
    if api_endpoint is None:
        api_endpoint = f"{location}-aiplatform.googleapis.com"
    client_options = {"api_endpoint": api_endpoint}

    # This client only needs to be created once, and can be reused for
    # multiple requests.
    client = aiplatform.gapic.PredictionServiceClient(client_options=client_options)

    with open(filename, "rb") as f:
        file_content = f.read()

    # The format of each instance should conform to the deployed model's
    # prediction input schema.
    encoded_content = base64.b64encode(file_content).decode("utf-8")
    instance = predict.instance.ImageObjectDetectionPredictionInstance(
        content=encoded_content,
    ).to_value()
    instances = [instance]

    # See gs://google-cloud-aiplatform/schema/predict/params/image_object_detection_1.0.0.yaml
    # for the format of the parameters.
    parameters = predict.params.ImageObjectDetectionPredictionParams(
        confidence_threshold=confidence_threshold,
        max_predictions=max_predictions,
    ).to_value()

    endpoint = client.endpoint_path(
        project=project, location=location, endpoint=endpoint_id
    )
    response = client.predict(
        endpoint=endpoint, instances=instances, parameters=parameters
    )

    # See gs://google-cloud-aiplatform/schema/predict/prediction/image_object_detection_1.0.0.yaml
    # for the format of the predictions.
    predictions = list(response.predictions)
    if not predictions:
        # Previously this path failed with an UnboundLocalError on return.
        raise RuntimeError(
            f"Endpoint {endpoint} returned no predictions for {filename}"
        )
    # Exactly one instance is sent; like the original loop, the last
    # prediction in the response is the one returned.
    return dict(predictions[-1])

# [END aiplatform_predict_image_object_detection_sample]

In [None]:
def _label_layout(x1, y1, x2, y2, txt_w, txt_h, img_width, pad_x=2, pad_y=4):
    """Compute where a bbox label goes.

    Returns ``((left, top), (right, bottom), (text_x, text_y))``: the two
    corners of the filled background rectangle and the text origin passed
    to ``cv2.putText``.
    """
    # Vertical placement: above the box, unless that would leave the image.
    top = y1 - txt_h - pad_y * 2
    if top >= 0:
        bottom = y1
        text_y = y1 - pad_y
    else:
        top = y2
        bottom = y2 + txt_h + pad_y * 2
        text_y = y2 + txt_h + pad_y

    # Horizontal placement: left-aligned with the box, unless the label
    # would cross the right image edge, in which case it is right-aligned.
    left = x1
    right = x1 + txt_w + pad_x * 2
    if right > img_width:
        right = x2
        left = x2 - txt_w - pad_x * 2
    # The text always starts one padding inside the background rectangle.
    # (The right-aligned case used to add the padding to x2 - txt_w, which
    # pushed the text past the right edge of its background.)
    text_x = left + pad_x

    return (left, top), (right, bottom), (text_x, text_y)


def show_image_with_bbox(filename, displayNames, confidences, bboxes, confidence_threshold=0):
    """Draw detection boxes with labels on an image and display it.

    Args:
        filename: Path of the image file, read with ``cv2.imread``.
        displayNames: Label text per detection.
        confidences: Confidence per detection.
        bboxes: One 4-item sequence per detection; items 0 and 1 are scaled
            by the image width (x1, x2) and items 2 and 3 by the image
            height (y1, y2).
        confidence_threshold: Detections with a confidence below this value
            are not drawn or counted.

    Raises:
        OSError: If ``cv2.imread`` could not load ``filename``.
    """
    text_color = (0, 0, 0)
    box_color = (0, 255, 0)
    font = cv2.FONT_HERSHEY_SIMPLEX
    text_size = 0.5

    img = cv2.imread(filename)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"cv2.imread could not load image: {filename}")
    height, width = img.shape[:2]

    ship_count = 0
    for name, confidence, bbox in zip(displayNames, confidences, bboxes):
        # Skip detections below the confidence threshold.
        if confidence < confidence_threshold:
            continue
        ship_count += 1

        x1 = int(bbox[0] * width)
        x2 = int(bbox[1] * width)
        y1 = int(bbox[2] * height)
        y2 = int(bbox[3] * height)

        # Bounding box outline.
        cv2.rectangle(img, (x1, y1), (x2, y2), box_color, 2)

        conf_str = '{:.3f}'.format(confidence)
        txt = f'{name} {conf_str}'
        txt_size_x, txt_size_y = cv2.getTextSize(txt, font, text_size, 0)[0]

        rect_pt1, rect_pt2, text_org = _label_layout(
            x1, y1, x2, y2, txt_size_x, txt_size_y, width
        )
        # Filled background behind the label, then the label text itself.
        cv2.rectangle(img, rect_pt1, rect_pt2, box_color, -1)
        cv2.putText(img, txt, text_org, font, text_size, text_color,
                    thickness=1, lineType=cv2.LINE_AA)

    fig, ax = plt.subplots()
    ax.axis("off")
    # OpenCV images are BGR; convert to RGB for matplotlib.
    ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
    plt.title(f'{filename} => ship: {ship_count}')
    plt.show()

## 推論の実行（Vertex AI エンドポイントへリクエスト）

### リクエストを送信

In [None]:
# Request settings. Every string below is a placeholder (project name or ID,
# endpoint ID, region, local path of the image to run inference on) and must
# be replaced before running this cell.
project = "プロジェクト名 or ID"
endpoint_id = "エンドポイントID"
location = "リージョン"
filename = "推論したい画像ファイルパス（ローカル）"
confidence_threshold = 0.5
max_predictions = 50

# Send the image to the endpoint and keep the returned prediction in `result`.
result = predict_image_object_detection(
    project,
    endpoint_id,
    filename,
    location=location,
    confidence_threshold=confidence_threshold,
    max_predictions=max_predictions,
)
# Bare expression so the notebook displays the prediction.
result

### 結果を画像上に表示

In [None]:
# Overlay the detections from `result` on the image that was sent.
show_image_with_bbox(
    filename,
    result['displayNames'],
    result['confidences'],
    result['bboxes'],
)