# Vertex SDK Custom Training with Custom Container for PyTorch Image Classification Distributed Multi Replica

## Setup

In [None]:
# Google Cloud settings shared by the rest of the notebook: the project,
# the region, and the Cloud Storage bucket used for staging and output.
PROJECT_ID = "aiplatform-dev"
REGION = "us-central1"
BUCKET_NAME = "gs://aiplatform-dev"

In [None]:
# Create the Cloud Storage bucket BUCKET_NAME in location REGION.
! gsutil mb -l $REGION $BUCKET_NAME

In [None]:
# List the bucket in long format (-l), including all object versions (-a).
! gsutil ls -al $BUCKET_NAME

In [None]:
# Base name for this tutorial; reused below as the container image name and
# as the training job's display name.
tutorial_name = "pytorch-img-cls-distributed-multi-replica-cpu"

## Local Training

In [None]:
# Show the files in the local trainer directory.
!ls trainer

In [None]:
# Print the contents of trainer/requirements.txt.
!cat trainer/requirements.txt

In [None]:
!pip install -r trainer/requirements.txt

In [None]:
# Print the training script that is run locally and packaged into the image.
!cat trainer/task.py

In [None]:
# Execute the training script inside this IPython session, passing
# `--epochs 5` on its command line.
%run trainer/task.py \
  --epochs 5

In [None]:
# List ./tmp.
# NOTE(review): presumably where the local run wrote its output; confirm
# against trainer/task.py.
!ls ./tmp

In [None]:
# Remove ./tmp and everything under it.
!rm -rf ./tmp

## Vertex SDK Custom Training using Custom Container

### Custom PyTorch Container for Training

In [None]:
# Pieces of the training image URI: <hostname>/<project>/<image>:<tag>.
tag = "latest"
image_name = tutorial_name
hostname = "gcr.io"

custom_container_image_uri = "/".join(
    [hostname, PROJECT_ID, f"{image_name}:{tag}"]
)

In [None]:
# Build the image from trainer/Dockerfile, using trainer/ as the build
# context, and tag it with custom_container_image_uri.
!cd trainer && docker build -t $custom_container_image_uri -f Dockerfile .

In [None]:
# Run the image locally, passing `--epochs 5` to the container; --rm removes
# the container once it exits.
!docker run --rm $custom_container_image_uri --epochs 5

In [None]:
# Push the image to the registry named in its URI.
!docker push $custom_container_image_uri

In [None]:
# List the images stored under <hostname>/<project>.
!gcloud container images list --repository $hostname/$PROJECT_ID

### Configs

In [None]:
# Print the contents of the notebook-level requirements.txt.
!cat requirements.txt

In [None]:
!pip install -r requirements.txt

In [None]:
# Import the Vertex AI SDK and initialize it with the project, staging
# bucket, and region defined in the Setup section.
from google.cloud import aiplatform

aiplatform.init(
    project=PROJECT_ID,
    staging_bucket=BUCKET_NAME,
    location=REGION,
)

In [None]:
# Job name and the Cloud Storage prefix passed to the job as its base
# output directory.
display_name = tutorial_name
gcs_output_uri_prefix = "/".join((BUCKET_NAME, display_name))

# Compute settings passed to the job: four n1-standard-4 replicas with no
# accelerators attached.
machine_type = "n1-standard-4"
replica_count = 4
accelerator_count = 0

# Command-line arguments forwarded to the training container.
container_args = ["--backend", "gloo", "--no-cuda", "--batch-size", "128"]

### Run a CustomContainerTrainingJob

In [None]:
# Define a training job that runs the container image referenced by
# custom_container_image_uri.
custom_container_training_job = aiplatform.CustomContainerTrainingJob(
    display_name=display_name,
    container_uri=custom_container_image_uri,
)

In [None]:
# Launch the job: container_args are passed to the container,
# gcs_output_uri_prefix is the base output directory, and the replica count,
# machine type, and accelerator count come from the Configs section.
custom_container_training_job.run(
    args=container_args,
    base_output_dir=gcs_output_uri_prefix,
    replica_count=replica_count,
    machine_type=machine_type,
    accelerator_count=accelerator_count,
)

In [None]:
# Report the job's resource name and the output location.
print(f"Custom Training Job Name: {custom_container_training_job.resource_name}")
print(f"GCS Output URI Prefix: {gcs_output_uri_prefix}")

### Training Artifact

In [None]:
# List what is stored under the job's output prefix.
!gsutil ls $gcs_output_uri_prefix

In [None]:
# Clean up: recursively delete everything under the output prefix (-r),
# continuing past errors on individual objects (-f).
!gsutil rm -rf $gcs_output_uri_prefix