In [None]:
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TF-Keras Image Classification Distributed Multi-Worker Training on GPU using Vertex Training with Custom Container

<table align="left">
  <td>
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/master/community-content/tf_keras_image_classification_distributed_multi_worker_with_vertex_sdk/multi_worker_vertex_training_on_gpu_with_custom_container.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo">
      View on GitHub
    </a>
  </td>
</table>

## Setup

In [None]:
# Deployment settings: replace every placeholder before running the notebook.
# Where the work runs.
PROJECT_ID = "YOUR PROJECT ID"  # also embedded in the container image URI
REGION = "YOUR REGION"
# Where artifacts go and which identity the training job runs as.
BUCKET_NAME = "gs://YOUR BUCKET NAME"  # staging bucket and root of the output prefix
SERVICE_ACCOUNT = "YOUR SERVICE ACCOUNT"

In [None]:
# Shared base name: reused for the container image name, the TensorBoard
# display name, the training job display name and the GCS output folder.
content_name = "tf-keras-img-cls-dist-multi-worker-gpu-cust-cont"

## Vertex Training using Vertex SDK and Custom Container

### Build Custom Container

In [None]:
# Container image coordinates, assembled as <registry host>/<project>/<image>:<tag>.
tag = "latest"
image_name = content_name  # the image is named after the shared base name
hostname = "gcr.io"  # registry host; reused when listing the pushed images

custom_container_image_uri = f"{hostname}/{PROJECT_ID}/{image_name}:{tag}"

In [None]:
# Build the GPU training image from gpu.Dockerfile, with trainer/ as the build context.
! docker build -t {custom_container_image_uri} -f trainer/gpu.Dockerfile trainer

In [None]:
# Smoke-test the image locally: a short 2-epoch run with the --local-mode flag.
! docker run --rm {custom_container_image_uri} --epochs 2 --local-mode

In [None]:
# Publish the locally built image to the registry named in its URI.
! docker push {custom_container_image_uri}

In [None]:
# List the images under this project's registry path to confirm the push.
! gcloud container images list --repository={hostname}/{PROJECT_ID}

### Initialize Vertex SDK

In [None]:
# Install the notebook's Python dependencies with the %pip magic so they land
# in the environment of the running kernel; a "! pip" shell call runs whichever
# pip is first on PATH, which may belong to a different interpreter.
# Restart the kernel afterwards if an already-imported package was upgraded.
%pip install -r requirements.txt

In [None]:
from google.cloud import aiplatform

# Point the Vertex SDK at the project, region and staging bucket configured above.
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=BUCKET_NAME)

### Create a Vertex TensorBoard Instance

In [None]:
# Create a new TensorBoard instance, labelled with the shared base name.
tensorboard = aiplatform.Tensorboard.create(display_name=content_name)

#### Option: Use a Previously Created Vertex TensorBoard Instance

```
tensorboard_name = "Your Tensorboard Resource Name or Tensorboard ID"
tensorboard = aiplatform.Tensorboard(tensorboard_name=tensorboard_name)
```

### Run a Vertex SDK CustomContainerTrainingJob

In [None]:
# Job identity and the GCS folder handed to the job as its base output directory.
display_name = content_name
gcs_output_uri_prefix = f"{BUCKET_NAME}/{display_name}"

# Cluster shape: machine type, accelerator settings and replica count.
# NOTE(review): presumably one replica acts as chief and the rest as workers;
# confirm against the CustomContainerTrainingJob.run documentation.
# NOTE(review): confirm this accelerator type is still offered in REGION.
machine_type = "n1-standard-4"
accelerator_type = "NVIDIA_TESLA_K80"
accelerator_count = 1
replica_count = 4

# Command-line flags forwarded to the training container.
container_args = ["--epochs", "50", "--batch-size", "32"]

In [None]:
# Describe the training job around the pushed image; the separate run() call starts it.
custom_container_training_job = aiplatform.CustomContainerTrainingJob(
    container_uri=custom_container_image_uri,
    display_name=display_name,
)

In [None]:
# Launch the job with the cluster shape and container flags configured earlier.
custom_container_training_job.run(
    # What the container receives.
    args=container_args,
    base_output_dir=gcs_output_uri_prefix,
    # Hardware for the replicas.
    machine_type=machine_type,
    replica_count=replica_count,
    accelerator_type=accelerator_type,
    accelerator_count=accelerator_count,
    # Identity and experiment tracking.
    service_account=SERVICE_ACCOUNT,
    tensorboard=tensorboard.resource_name,
)

In [None]:
# Report the job's resource name and the GCS prefix used as its output directory.
print(
    f"Custom Training Job Name: {custom_container_training_job.resource_name}",
    f"GCS Output URI Prefix: {gcs_output_uri_prefix}",
    sep="\n",
)

### Training Output Artifact

In [None]:
# Show what the training job wrote under its output prefix.
! gsutil ls {gcs_output_uri_prefix}

## Clean Up Artifact

In [None]:
# Recursively delete everything under the training output prefix.
! gsutil rm -rf {gcs_output_uri_prefix}