In [None]:
import logging
import json
import subprocess
import time
from collections import namedtuple
from pathlib import Path

import google.cloud.aiplatform as aiplatform
from google.cloud import storage

# Lower the root logger's threshold so INFO-level records are emitted.
logging.root.setLevel(logging.INFO)

In [None]:
# --- Naming / versioning -------------------------------------------------
# Base name for the Vertex AI model display name, endpoint name and predict route.
APP_NAME = 'ViT'
# --model-name passed to torch-model-archiver, and the sub-directory under
# <MODEL_PT_FILEPATH>/model/ that holds ViT.pt.
# NOTE(review): the predict route is built from APP_NAME, not from this name;
# confirm the serving container registers the archive under APP_NAME.
MODEL_DISPLAY_NAME = 'Vit-model'
# Archive --version and the "-v<N>" suffix of the Vertex AI model display name.
model_version = 1

# --- Local / mounted paths (fill in before running) ----------------------
MODEL_PT_FILEPATH = ''   # root containing model/<MODEL_DISPLAY_NAME>/ViT.pt
MAR_MODEL_OUT_PATH = ''  # the .mar archive is exported to <MAR_MODEL_OUT_PATH>/model-store
handler = ''             # a gs:// prefix (mapped under /gcs/) or a handler value used as-is

# --- Google Cloud (fill in before running) -------------------------------
PROJECT_ID = ''
BUCKET_NAME = ''         # staging bucket for aiplatform.init; NOTE(review): the SDK expects a gs:// URI
CUSTOM_PREDICTOR_IMAGE_URI = ''  # image tag used by docker build/push and by the Vertex AI model

In [None]:
# Point the Vertex AI SDK at the target project and staging bucket used by later API calls.
aiplatform.init(
    staging_bucket=BUCKET_NAME,
    project=PROJECT_ID,
)

In [None]:
# Package the serialized ViT weights and the TorchServe handler into a model
# archive (.mar) with torch-model-archiver.

# Create the directory the archive is exported into.
model_output_root = MODEL_PT_FILEPATH
mar_output_root = MAR_MODEL_OUT_PATH
export_path = f"{mar_output_root}/model-store"
try:
    Path(export_path).mkdir(parents=True, exist_ok=True)
except Exception as e:
    logging.warning(e)
    # Retry once after a short pause, presumably to ride out a transient
    # filesystem error (e.g. on a /gcs mount) — TODO confirm the intent.
    time.sleep(2)
    Path(export_path).mkdir(parents=True, exist_ok=True)

# Resolve the handler. A gs:// value is treated as a prefix: it is mapped to its
# /gcs/ mount path and the handler script is expected at predictor/handler.py
# beneath it. Any other value is passed to the archiver unchanged.
if handler.startswith("gs://"):
    # rstrip + explicit "/" so "gs://bucket" and "gs://bucket/" resolve identically
    # (plain concatenation would yield "/gcs/bucketpredictor/handler.py").
    handler_path = handler.replace("gs://", "/gcs/", 1).rstrip("/") + "/predictor/handler.py"
else:
    handler_path = handler
model_artifacts_dir = f"{model_output_root}/model/{MODEL_DISPLAY_NAME}"

# Model archive config; the optional keys EXTRA_FILES / REQUIREMENTS_FILE are
# forwarded to the archiver only when present.
mar_config = {
    "MODEL_NAME": MODEL_DISPLAY_NAME,
    "HANDLER": handler_path,
    "SERIALIZED_FILE": f"{model_artifacts_dir}/ViT.pt",
    "VERSION": model_version,
    # Reuse the directory created above so the two can never drift apart.
    "EXPORT_PATH": export_path,
}

# Build the command as an argv list (no shell) so paths containing spaces or
# shell metacharacters reach the archiver intact.
archiver_args = [
    "torch-model-archiver",
    "--force",
    "--model-name", str(mar_config["MODEL_NAME"]),
    "--serialized-file", str(mar_config["SERIALIZED_FILE"]),
    "--handler", str(mar_config["HANDLER"]),
    "--version", str(mar_config["VERSION"]),
]
for config_key, flag in (
    ("EXPORT_PATH", "--export-path"),
    ("EXTRA_FILES", "--extra-files"),
    ("REQUIREMENTS_FILE", "--requirements-file"),
):
    if config_key in mar_config:
        archiver_args += [flag, str(mar_config[config_key])]

# Run the archiver and judge success by its exit status. stderr can contain
# warnings on a successful run, and a failed run need not write to stderr.
logging.info("Running archiver command: %s", " ".join(archiver_args))
archiver_result = subprocess.run(archiver_args, capture_output=True, text=True)
if archiver_result.returncode != 0:
    raise ValueError(
        f"torch-model-archiver exited with status {archiver_result.returncode}: "
        f"{archiver_result.stderr.strip()}"
    )
if archiver_result.stderr:
    logging.warning("torch-model-archiver stderr: %s", archiver_result.stderr.strip())


In [None]:
# Build the custom predictor container image from the ./predictor directory,
# tagged with CUSTOM_PREDICTOR_IMAGE_URI.
# NOTE(review): assumes the .mar produced by the archive step is made available to
# this build context (not visible in this notebook) — confirm.
# NOTE(review): `!` shell escapes do not stop "Run All" on a non-zero exit status;
# check the output for build errors before continuing.
!docker build --tag=$CUSTOM_PREDICTOR_IMAGE_URI ./predictor

In [None]:
# Push the image so it can be referenced as the serving container image when the
# model is uploaded to Vertex AI.
# NOTE(review): `!` shell escapes do not stop "Run All" on failure — check the output.
!docker push $CUSTOM_PREDICTOR_IMAGE_URI

In [None]:
# Naming metadata for the Vertex AI model resource.
model_display_name = f"{APP_NAME}-v{model_version}"
# Describes the actual model (ViT image model); the previous text was a
# copy-paste of a text-classifier sample.
model_description = "PyTorch based Vision Transformer (ViT) image model with custom container"

# Routes on the custom container that Vertex AI calls for health checks and predictions.
MODEL_NAME = APP_NAME
health_route = "/ping"
# NOTE(review): the archive is built with --model-name MODEL_DISPLAY_NAME, while this
# route uses APP_NAME; confirm the container registers the model under APP_NAME.
predict_route = f"/predictions/{MODEL_NAME}"
# NOTE(review): presumably matches the inference port configured inside ./predictor — confirm.
serving_container_ports = [7080]

In [None]:

# Register the custom serving container as a Vertex AI Model resource.
model = aiplatform.Model.upload(
    display_name=model_display_name,
    description=model_description,
    serving_container_image_uri=CUSTOM_PREDICTOR_IMAGE_URI,
    serving_container_health_route=health_route,
    serving_container_predict_route=predict_route,
    serving_container_ports=serving_container_ports,
)
# Wait for the upload to finish before the resource is used below.
model.wait()

# Report the display name, then the fully-qualified resource name.
for model_attr in ("display_name", "resource_name"):
    print(getattr(model, model_attr))

In [None]:
# Create a Vertex AI endpoint for the model.
# Note: Endpoint.create makes a new resource on every call, so re-running this
# cell adds another endpoint with the same display name.
endpoint_display_name = f"{APP_NAME}-endpoint"
endpoint = aiplatform.Endpoint.create(
    display_name=endpoint_display_name,
)

In [None]:
# Deployment settings for serving the model on the endpoint.
traffic_percentage = 100  # share of endpoint traffic routed to this deployment
machine_type = "n1-standard-4"
deployed_model_display_name = model_display_name
min_replica_count = 1
max_replica_count = 3
sync = True  # wait for the deployment to finish before the cell returns

# Deploy the uploaded model to the endpoint. The replica bounds are passed
# explicitly; previously they were defined but never forwarded, so the SDK
# defaults applied instead of the values above.
model.deploy(
    endpoint=endpoint,
    deployed_model_display_name=deployed_model_display_name,
    machine_type=machine_type,
    min_replica_count=min_replica_count,
    max_replica_count=max_replica_count,
    traffic_percentage=traffic_percentage,
    sync=sync,
)

In [None]:
# Look up the endpoint by display name, e.g. to reconnect from a fresh kernel
# without re-running the create/deploy cells.
endpoint_display_name = f"{APP_NAME}-endpoint"
endpoint_filter = f'display_name="{endpoint_display_name}"'

# Several endpoints can share a display name (each Endpoint.create call adds one),
# so list newest first and use the most recently created match.
matching_endpoints = aiplatform.Endpoint.list(
    filter=endpoint_filter, order_by="create_time desc"
)
for endpoint_info in matching_endpoints:
    print(
        f"Endpoint display name = {endpoint_info.display_name} resource id ={endpoint_info.resource_name} "
    )

# Fail with a clear message instead of a NameError / stale variable when nothing matches.
if not matching_endpoints:
    raise ValueError(f"No Vertex AI endpoint found with display name {endpoint_display_name!r}")

endpoint = aiplatform.Endpoint(matching_endpoints[0].resource_name)

In [None]:
# Show the models deployed to the endpoint (rendered as this cell's output).
deployed_models = endpoint.list_models()
deployed_models

In [None]:
# Prediction inputs, one request per element; fill in before running.
# Each element must be JSON-serializable, because the prediction loop prints it with json.dumps.
# NOTE(review): the per-instance payload format is defined by the predictor's
# handler (not visible here) — confirm it before filling this in.
test_images = []

In [None]:
# Send each test instance to the endpoint as its own prediction request and show the result.
print("=" * 100)
for image in test_images:
    print(f"Formatted input: \n{json.dumps(image, indent=4)}\n")
    # Endpoint.predict takes a list of instances, so wrap the single instance in a
    # one-element list.
    # NOTE(review): assumes each element of test_images is one instance, not a batch.
    prediction = endpoint.predict(instances=[image])
    print(f"Prediction response: \n\t{prediction}")
    print("=" * 100)