In [None]:
# Google Cloud project settings. PROJECT_ID is an empty placeholder here and
# has to be filled in before the cells that talk to Google Cloud are run.
PROJECT_ID = ""
REGION = "us-central1"

# Cloud Storage location derived from the project id; it is used below as the
# SDK staging bucket and as the prefix for the uploaded training package.
BUCKET_URI = "gs://your-bucket-name-{}-unique".format(PROJECT_ID)

In [None]:
from google.cloud import aiplatform

In [None]:
# Configure the Vertex AI SDK once with the project, region and staging bucket
# defined in the configuration cell above.
aiplatform.init(
    project=PROJECT_ID,
    location=REGION,
    staging_bucket=BUCKET_URI,
)

In [None]:
# Label used to namespace this experiment's artifacts inside the bucket.
APP_NAME = "ViT-model"

# Container image the training job runs in (the image name indicates the
# pre-built PyTorch 1.9 GPU training image).
PRE_BUILT_TRAINING_CONTAINER_IMAGE_URI = (
    "us-docker.pkg.dev/vertex-ai/training/pytorch-gpu.1-9:latest"
)

# Local path of the source distribution to upload. `setup.py sdist` writes its
# archive into a `dist/` directory relative to where it is run, so this must be
# a relative path; the previous "/dist/..." pointed at the filesystem root.
source_package_file_name = "dist/trainer-0.1.tar.gz"

# Destination of the package in Cloud Storage; the training job reads it from here.
python_package_gcs_uri = (
    f"{BUCKET_URI}/pytorch-on-gcp/{APP_NAME}/train/python_package/trainer-0.1.tar.gz"
)

# Module the training job executes as its entry point.
# NOTE(review): the archive is named "trainer", so this may need to be
# "trainer.task" depending on how setup.py lays out the package - confirm.
python_module_name = "task"

In [None]:
! cd python3 setup.py sdist --formats=gztar

! gsutil cp {source_package_file_name} {python_package_gcs_uri}

! gsutil ls -l {python_package_gcs_uri}

In [None]:
# Echo the training-package configuration, one NAME=value pair per line, so the
# effective values are visible in the notebook output before submitting the job.
print(
    f"APP_NAME={APP_NAME}",
    f"PRE_BUILT_TRAINING_CONTAINER_IMAGE_URI={PRE_BUILT_TRAINING_CONTAINER_IMAGE_URI}",
    f"python_package_gcs_uri={python_package_gcs_uri}",
    f"python_module_name={python_module_name}",
    sep="\n",
)

In [None]:
# Display name under which the training job is created.
JOB_NAME = "ViT-model-server"
print("JOB_NAME=" + JOB_NAME)

# Describe the custom training job: which Python package to fetch from Cloud
# Storage, which module inside it to execute, and which container image to use.
# Nothing is launched here; the job is started by the later `job.run(...)` cell.
job = aiplatform.CustomPythonPackageTrainingJob(
    display_name=JOB_NAME,
    container_uri=PRE_BUILT_TRAINING_CONTAINER_IMAGE_URI,
    python_package_gcs_uri=python_package_gcs_uri,
    python_module_name=python_module_name,
)

In [None]:
# --- Training worker pool -----------------------------------------------------
# Number of training replicas, the machine shape of each one, and the kind and
# number of accelerators attached to each replica.
REPLICA_COUNT = 3
MACHINE_TYPE = "n1-standard-16"
ACCELERATOR_COUNT = 2
ACCELERATOR_TYPE = "NVIDIA_TESLA_V100"

# --- Reduction Server pool ----------------------------------------------------
# Number of reduction-server replicas, their machine shape, and the container
# image they run.
REDUCTION_SERVER_COUNT = 4
REDUCTION_SERVER_MACHINE_TYPE = "n1-highcpu-16"
REDUCTION_SERVER_IMAGE_URI = "us-docker.pkg.dev/vertex-ai-restricted/training/reductionserver:latest"

# Environment variables handed to the training job; NCCL_DEBUG is set to INFO.
ENVIRONMENT_VARIABLES = dict(NCCL_DEBUG="INFO")

In [None]:
# Settings for the training worker pool, taken from the configuration cell.
worker_pool_args = {
    "replica_count": REPLICA_COUNT,
    "machine_type": MACHINE_TYPE,
    "accelerator_type": ACCELERATOR_TYPE,
    "accelerator_count": ACCELERATOR_COUNT,
}

# Settings for the Reduction Server pool, taken from the configuration cell.
reduction_server_args = {
    "reduction_server_replica_count": REDUCTION_SERVER_COUNT,
    "reduction_server_machine_type": REDUCTION_SERVER_MACHINE_TYPE,
    "reduction_server_container_uri": REDUCTION_SERVER_IMAGE_URI,
}

# Launch the training job. sync=True asks the SDK to execute the call
# synchronously rather than returning immediately.
# NOTE(review): no model-serving arguments are passed here, so `model` may end
# up as None rather than a Model object - confirm before using it downstream.
model = job.run(
    **worker_pool_args,
    **reduction_server_args,
    environment_variables=ENVIRONMENT_VARIABLES,
    sync=True,
)