# Vertex SDK Custom Training with Training Script for TF-Keras Image Classification Distributed Multi Replica

## Setup

In [None]:
# Google Cloud project, staging bucket, and region used by every later cell.
PROJECT_ID = "aiplatform-dev"
BUCKET_NAME = "gs://aiplatform-dev"
REGION = "us-central1"

In [None]:
# Create the Cloud Storage bucket BUCKET_NAME in location REGION.
! gsutil mb -l $REGION $BUCKET_NAME

In [None]:
# Long-list the bucket's contents, including all object versions (-a -l).
! gsutil ls -al $BUCKET_NAME

In [None]:
# Identifier reused below as the job display name and the output folder name.
tutorial_name = "tf-keras-img-cls-distributed-multi-replica-cpu"

## Local Training

In [None]:
# Show the files in the local trainer directory.
!ls trainer

In [None]:
# Print the trainer's dependency list.
!cat trainer/requirements.txt

In [None]:
# Install the trainer's dependencies locally so the script can be run from this notebook.
!pip install -r trainer/requirements.txt

In [None]:
# Print the training script that is run locally below and submitted to Vertex AI later.
!cat trainer/task.py

In [None]:
# Run the training script locally with --epochs 5 before submitting it as a job.
%run trainer/task.py \
  --epochs 5

In [None]:
# List ./tmp — presumably where the local run wrote its output; confirm against trainer/task.py.
!ls ./tmp

In [None]:
# Delete the local ./tmp directory and everything in it.
!rm -rf ./tmp

## Vertex SDK Custom Training using Training Script

### Pre-Built TensorFlow Container for Training

In [None]:
# Container image passed as container_uri when the training job is defined below.
prebuilt_container_image_uri = "us-docker.pkg.dev/vertex-ai/training/tf-cpu.2-4:latest"

### Configs

In [None]:
# Print the notebook-level dependency list (separate from trainer/requirements.txt).
!cat requirements.txt

In [None]:
# Install the notebook-level dependencies before the SDK import below.
!pip install -r requirements.txt

In [None]:
from google.cloud import aiplatform

# Initialise the Vertex AI SDK once with this notebook's project, region, and
# staging bucket; the job cells below do not pass these values again.
aiplatform.init(
    project=PROJECT_ID,
    location=REGION,
    staging_bucket=BUCKET_NAME,
)

In [None]:
# Local training script and the extra pip packages handed to CustomTrainingJob below.
script_path = "./trainer/task.py"
requirements = ["tensorflow_datasets"]

In [None]:
# Job name, and the Cloud Storage prefix later passed to run() as base_output_dir.
display_name = tutorial_name
gcs_output_uri_prefix = f"{BUCKET_NAME}/{display_name}"

# Cluster shape: four n1-standard-4 replicas with no accelerators.
machine_type = "n1-standard-4"
replica_count = 4
accelerator_count = 0

# Command-line arguments passed to the job through run(args=...);
# --machine-count is kept in step with replica_count.
container_args = [
    "--machine-count", str(replica_count),
    "--epochs", "10",
]

### Run a CustomTrainingJob

In [None]:
# Define the training job from the local script, the container image, and the
# extra pip requirements; nothing is submitted until run() is called.
custom_training_job = aiplatform.CustomTrainingJob(
    display_name=display_name,
    container_uri=prebuilt_container_image_uri,
    script_path=script_path,
    requirements=requirements,
)

In [None]:
# Submit the job with the cluster shape, script arguments, and output prefix
# configured above.
custom_training_job.run(
    args=container_args,
    base_output_dir=gcs_output_uri_prefix,
    machine_type=machine_type,
    replica_count=replica_count,
    accelerator_count=accelerator_count,
)

In [None]:
# Report the job's resource name and the output prefix it was given.
print(f"Custom Training Job Name: {custom_training_job.resource_name}")
print(f"GCS Output URI Prefix: {gcs_output_uri_prefix}")

### Training Artifact

In [None]:
# List what is stored under the job's output prefix.
!gsutil ls $gcs_output_uri_prefix

In [None]:
# Clean up: recursively delete everything under the job's output prefix (irreversible).
!gsutil rm -rf $gcs_output_uri_prefix