In [None]:
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI SDK: Using PyTorch torchrun to simplify multi-node training with custom containers
<table align="left">

  <td>
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/official/training/sdk_pytorch_torchrun_custom_container_training_imagenet.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Colab logo"> Run in Colab
    </a>
  </td>
  <td>
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/official/training/sdk_pytorch_torchrun_custom_container_training_imagenet.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo">
      View on GitHub
    </a>
  </td>
  <td>
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/vertex-ai-samples/main/notebooks/official/training/sdk_pytorch_torchrun_custom_container_training_imagenet.ipynb">
      <img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo">
      Open in Vertex AI Workbench
    </a>
  </td>                                                                                               
</table>

## Overview

This tutorial uses the Tiny ImageNet dataset to run multi-node distributed training on Vertex AI with Torchrun. It will run distributed training on multiple nodes with GPUs.

Learn more about [Distributed training](https://cloud.google.com/vertex-ai/docs/training/distributed-training).

### Objective

In this tutorial, you will learn how to train an Imagenet model using PyTorch's Torchrun on multiple nodes.

This tutorial uses the following Google Cloud ML services:

- Vertex AI `Training` (custom container training)

The steps performed include:

* Create a shell script to start an ETCD cluster on the master node
* Create a training script using code from PyTorch Elastic's GitHub repository
* Create containers that download the data and start an ETCD cluster on the host
* Train the model using multiple nodes with GPUs

### Dataset

For the sake of training time, the Tiny ImageNet dataset is used in this tutorial: https://image-net.org/data/tiny-imagenet-200.zip

This dataset consists of many small (~2KB) images. To avoid network bottlenecks from the large volume of small transfers between Cloud Storage and the GPUs, the dataset is downloaded into the training container image when the image is built.

The training code is based on this PyTorch Torchrun example for ImageNet: https://github.com/pytorch/elastic/blob/master/examples/imagenet/main.py

### Costs 

This tutorial uses billable components of Google Cloud:

* Vertex AI Training w/ GPUs
* Vertex AI TensorBoard
* Cloud Storage

Learn about [Vertex AI
pricing](https://cloud.google.com/vertex-ai/pricing) and [Cloud Storage
pricing](https://cloud.google.com/storage/pricing), and use the [Pricing
Calculator](https://cloud.google.com/products/calculator/)
to generate a cost estimate based on your projected usage.

## Installation

Install the following packages required to execute this notebook. 

In [None]:
# Install the packages
! pip3 install --upgrade --quiet google-cloud-aiplatform \
                                 python-etcd

### Colab only: Uncomment the following cell to restart the kernel.

In [None]:
# import IPython

# app = IPython.Application.instance()
# app.kernel.do_shutdown(True)

## Before you begin

### Set up your Google Cloud project

**The following steps are required, regardless of your notebook environment.**

1. [Select or create a Google Cloud project](https://console.cloud.google.com/cloud-resource-manager). When you first create an account, you get a $300 free credit towards your compute/storage costs.

2. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

3. [Enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).

4. If you are running this notebook locally, you need to install the [Cloud SDK](https://cloud.google.com/sdk).

#### Set your project ID

**If you don't know your project ID**, try the following:
* Run `gcloud config list`.
* Run `gcloud projects list`.
* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)

In [None]:
PROJECT_ID = "[your-project-id]"  # @param {type:"string"}

# Set the project id
! gcloud config set project {PROJECT_ID}

#### Region

You can also change the `REGION` variable used by Vertex AI. Learn more about [Vertex AI regions](https://cloud.google.com/vertex-ai/docs/general/locations).

In [None]:
REGION = "us-central1"  # @param {type: "string"}

### Authenticate your Google Cloud account

Depending on your Jupyter environment, you may have to manually authenticate. Follow the relevant instructions below.

**1. Vertex AI Workbench**
* Do nothing as you are already authenticated.

**2. Local JupyterLab instance, uncomment and run:**

In [None]:
# ! gcloud auth login

**3. Colab, uncomment and run:**

In [None]:
# from google.colab import auth
# auth.authenticate_user()

**4. Service account or other**
* See how to grant Cloud Storage permissions to your service account at https://cloud.google.com/storage/docs/gsutil/commands/iam#ch-examples.

### Create a Cloud Storage bucket

Create a storage bucket to store intermediate artifacts such as datasets.

- *{Note to notebook author: For any user-provided strings that need to be unique (like bucket names or model ID's), append "-unique" to the end so proper testing can occur}*

In [None]:
BUCKET_URI = f"gs://your-bucket-name-{PROJECT_ID}-unique"  # noqa

**Only if your bucket doesn't already exist**: Run the following cell to create your Cloud Storage bucket.

In [None]:
! gsutil mb -l {REGION} -p {PROJECT_ID} {BUCKET_URI}

### Import libraries

In [None]:
import os

from google.cloud import aiplatform

### Initialize Vertex AI SDK for Python

Initialize the Vertex AI SDK for Python for your project.

In [None]:
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=BUCKET_URI)

### Service Account

You use a service account to create the Vertex AI Training job. If you do not want to use your project's Compute Engine service account, set SERVICE_ACCOUNT to another service account ID.

In [None]:
SERVICE_ACCOUNT = "[your-service-account]"

If you do not provide a service account, run the code below to get the Compute Engine service account

In [None]:
# The Google Cloud Notebook product has specific requirements
IS_GOOGLE_CLOUD_NOTEBOOK = os.path.exists("/opt/deeplearning/metadata/env_version")

if (
    SERVICE_ACCOUNT == ""
    or SERVICE_ACCOUNT is None
    or SERVICE_ACCOUNT == "[your-service-account]"
):
    # Get your service account from gcloud
    if IS_GOOGLE_CLOUD_NOTEBOOK:
        shell_output = !gcloud auth list 2>/dev/null
        SERVICE_ACCOUNT = shell_output[2].replace("*", "").strip()

    if not IS_GOOGLE_CLOUD_NOTEBOOK:
        shell_output = ! gcloud projects describe  $PROJECT_ID
        project_number = shell_output[-1].split(":")[1].strip().replace("'", "")
        SERVICE_ACCOUNT = f"{project_number}-compute@developer.gserviceaccount.com"

    print("Service Account:", SERVICE_ACCOUNT)

### Enable Artifact Registry API

First, you must enable the Artifact Registry API service for your project.

Learn more about [Enabling service](https://cloud.google.com/artifact-registry/docs/enable-service).

In [None]:
! gcloud services enable artifactregistry.googleapis.com

if os.getenv("IS_TESTING"):
    ! sudo apt-get update --yes && sudo apt-get --only-upgrade --yes install google-cloud-sdk-cloud-run-proxy google-cloud-sdk-harbourbridge google-cloud-sdk-cbt google-cloud-sdk-gke-gcloud-auth-plugin google-cloud-sdk-kpt google-cloud-sdk-local-extract google-cloud-sdk-minikube google-cloud-sdk-app-engine-java google-cloud-sdk-app-engine-go google-cloud-sdk-app-engine-python google-cloud-sdk-spanner-emulator google-cloud-sdk-bigtable-emulator google-cloud-sdk-nomos google-cloud-sdk-package-go-module google-cloud-sdk-firestore-emulator kubectl google-cloud-sdk-datastore-emulator google-cloud-sdk-app-engine-python-extras google-cloud-sdk-cloud-build-local google-cloud-sdk-kubectl-oidc google-cloud-sdk-anthos-auth google-cloud-sdk-app-engine-grpc google-cloud-sdk-pubsub-emulator google-cloud-sdk-datalab google-cloud-sdk-skaffold google-cloud-sdk google-cloud-sdk-terraform-tools google-cloud-sdk-config-connector
    ! gcloud components update --quiet

### Create a private Docker repository

Your first step is to create your own Docker repository in Artifact Registry.

1. Run the `gcloud artifacts repositories create` command to create a new Docker repository with your region with the description "docker repository".

2. Run the `gcloud artifacts repositories list` command to verify that your repository was created.

In [None]:
REPOSITORY = "torchrun-imagenet-repo"

In [None]:
! gcloud artifacts repositories create {REPOSITORY} --repository-format=docker --location={REGION} --description="Docker repository"

! gcloud artifacts repositories list

### Configure authentication to your private repo

Before you push or pull container images, configure Docker to use the `gcloud` command-line tool to authenticate requests to `Artifact Registry` for your region.

In [None]:
! gcloud auth configure-docker {REGION}-docker.pkg.dev --quiet

## Vertex AI Training with GPUs

### Create files for the host container

In [None]:
%mkdir -p trainer
%cat /dev/null > trainer/__init__.py

#### Create the Dockerfile
Installs necessary libraries, and downloads the tiny ImageNet data for training

In [None]:
%%writefile trainer/Dockerfile
# Base image: Deep Learning Containers PyTorch 1.13 GPU image.
FROM gcr.io/deeplearning-platform-release/pytorch-gpu.1-13:m102

RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add - && \
    # Install reduction server plugin on GPU containers. google-fast-socket is
    # previously installed in GPU dlenv containers only and it is not compatible
    # with google-reduction-server.
    if dpkg -s google-fast-socket; then \
      apt remove -y google-fast-socket && \
      apt install -y google-reduction-server; \
    fi

# Drop the CUDA / NVIDIA ML apt source lists before refreshing the signing keys.
RUN rm -f /etc/apt/sources.list.d/cuda.list && \
    rm -f /etc/apt/sources.list.d/nvidia-ml.list

# Replace the old NVIDIA signing key with the current repository keys.
RUN apt-key del 7fa2af80 && \
    apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub && \
    apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub

# Networking and debugging utilities (main.sh uses ping).
RUN apt-get update -y && \
    apt-get install -y curl gnupg telnet nano net-tools iputils-ping

# Set ETCD version
ARG ETCD_VER=v2.3.0
# Choose either URL
ARG GOOGLE_URL=https://storage.googleapis.com/etcd
ARG GITHUB_URL=https://github.com/etcd-io/etcd/releases/download
# Set ETCD URL to download from
ARG DOWNLOAD_URL=$GOOGLE_URL

# Install ETCD: the release tarball is unpacked into /tmp/etcd-download-test,
# which is the path main.sh starts the server from.
RUN mkdir -p /tmp/etcd-download-test && \
    curl -L ${DOWNLOAD_URL}/${ETCD_VER}/etcd-${ETCD_VER}-linux-amd64.tar.gz -o /tmp/etcd-${ETCD_VER}-linux-amd64.tar.gz && \
    tar xzvf /tmp/etcd-${ETCD_VER}-linux-amd64.tar.gz -C /tmp/etcd-download-test --strip-components=1 && \
    rm -f /tmp/etcd-${ETCD_VER}-linux-amd64.tar.gz

# Copy training application code
COPY . /trainer

WORKDIR /trainer

# Install dependencies
RUN pip install -r requirements.txt

RUN chmod 777 main.sh

# Download the Tiny ImageNet archive, extract it into the working directory
# (/trainer) and delete the archive in the same layer, so the zip file is not
# kept in an intermediate image layer.
RUN wget -q -P /trainer/data https://image-net.org/data/tiny-imagenet-200.zip && \
    unzip -q /trainer/data/tiny-imagenet-200.zip && \
    rm /trainer/data/tiny-imagenet-200.zip

CMD ["/bin/bash", "main.sh"]

In [None]:
%%writefile trainer/requirements.txt
torch==1.13.0
torchvision==0.14.0
tensorboard==2.5.0
protobuf==3.20.*
python-etcd
python-json-logger

#### Create the main.sh file 
Starts the ETCD server on the host, saves the host IP to Cloud Storage (for the workers), and calls torchrun

In [None]:
%%writefile trainer/main.sh
#!/bin/bash
# Copyright 2022 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#            http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

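# Entry point executed on every node of the multi-node training job.
# The primary node (worker pool type "workerpool0" or "chief") starts a local
# ETCD server and writes its IP address to Cloud Storage; the other nodes read
# that address. Every node then launches main.py through torchrun.
# USAGE:  ./main.sh -e EPOCHS -a ARCH -b BATCH_SIZE -d DIST_BACKEND -t DATA_DIR
#         -w DATA_WORKERS -v HOST_IP_FILE -u RDZV_BACKEND -i RDZV_ID -p ENDPOINT
#         -n NNODES -r NPROC_PER_NODE -c IS_CHIEF
# HOST_IP_FILE is a /gcs/... path; it is rewritten to a gs:// URI below.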

# Set up a global error handler
err_handler() {
    # Report where the script failed and stop.
    #   $1 - line number of the failing command
    #   $2 - text of the failing command
    #   $3 - exit status of the failing command (re-used as the script's status)
    printf 'Error on line: %s\n' "$1"
    printf 'Caused by: %s\n' "$2"
    printf 'That returned exit status: %s\n' "$3"
    printf 'Aborting...\n'
    exit $3
}

trap 'err_handler "$LINENO" "$BASH_COMMAND" "$?"' ERR

setup_etcd() {
    # Start a single-member ETCD server in the background on this node.
    #   $1 - IP address advertised to clients and peers
    HOST_IP=$1
    # Start a local instance of ETCD with the v2 API enabled
    export ETCD_ENABLE_V2=true
    export ETCDCTL_API=2

    # Clients are served on port 2379 and peers on port 2380. All server output
    # is redirected to node.log and the process is left running in the background.
    /tmp/etcd-download-test/etcd --name s1 --data-dir /tmp/etcd-download-test/s1  \
    --listen-client-urls http://0.0.0.0:2379 --advertise-client-urls http://$HOST_IP:2379 \
    --listen-peer-urls http://0.0.0.0:2380 --initial-advertise-peer-urls http://$HOST_IP:2380 \
    --initial-cluster s1=http://$HOST_IP:2380 --initial-cluster-token tkn \
    --initial-cluster-state new &> /tmp/etcd-download-test/node.log &

    # Print the versions of the server and client binaries to the job log.
    /tmp/etcd-download-test/etcd --version
    /tmp/etcd-download-test/etcdctl --version
}


# Process and print the command-line options. Every option takes a value:
#   -e epochs          -a model architecture   -b batch size
#   -d dist backend    -t dataset directory    -w data loader workers
#   -v /gcs/ path of the file used to share the host IP
#   -u rendezvous backend   -i rendezvous id   -p endpoint (only echoed below)
#   -n number of nodes      -r processes per node   -c chief flag
while getopts e:a:b:d:t:w:v:u:i:p:n:r:c: option
do 
    case "${option}"
        in
        e)epochs=${OPTARG};;
        a)arch=${OPTARG};;
        b)batchsize=${OPTARG};;
        d)distbackend=${OPTARG};;
        t)data=${OPTARG};;
        w)workers=${OPTARG};;
        v)env=${OPTARG};;
        u)rdvzbackend=${OPTARG};;
        i)rdvzid=${OPTARG};;
        p)endpoint=${OPTARG};;
        n)nnodes=${OPTARG};;
        r)nprocpernode=${OPTARG};;
        c)ischief=${OPTARG};;
    esac
done

# Echo the parsed values so they show up in the job log.
echo "epochs : $epochs"
echo "arch : $arch"
echo "batchsize : $batchsize"
echo "distbackend : $distbackend"
echo "data : $data"
echo "workers : $workers"
echo "env : $env"
echo "rdvzbackend : $rdvzbackend"
echo "rdvzid : $rdvzid"
echo "endpoint : $endpoint"
echo "nnodes : $nnodes"
echo "nprocpernode : $nprocpernode"
echo "ischief : $ischief"

# Parse the cluster config: parse_cluster_config.py prints the task type found
# in the CLUSTER_SPEC environment variable; the first word of its output is
# taken as this node's worker pool type.
IFS=' ' read -a conf <<< $(python parse_cluster_config.py)
WORKERPOOL_TYPE="${conf[0]}"

echo "WORKERPOOL_TYPE=${WORKERPOOL_TYPE}"
echo "CLUSTER_SPEC=${CLUSTER_SPEC}"

# Rewrite every "/gcs/" in the -v value to "gs://" so gsutil can address the file.
gcsfilepath="${env//\/gcs\//gs://}"

if [ "$WORKERPOOL_TYPE" == "workerpool0" ] || [ "$WORKERPOOL_TYPE" == "chief" ]; then
    # Primary node: publish this host's IP to Cloud Storage, then start ETCD.
    HOST_IP=$(hostname -i)
    echo "HOST_IP="$HOST_IP
    echo "Writing host IP address to "$gcsfilepath
    echo $HOST_IP| gsutil cp - $gcsfilepath
    setup_etcd $HOST_IP
else
    # Other nodes: wait a fixed 60 seconds, then read the IP the primary wrote.
    # NOTE(review): there is no retry here; if the file is missing, gsutil fails
    # and the ERR trap aborts the script.
    echo "Wait 60s for the host server to come online"
    sleep 60
    echo "reading host IP address from "$gcsfilepath
    HOST_IP=$(gsutil cat $gcsfilepath)
    echo "HOST_IP="$HOST_IP
fi

# From here on the trainer always receives "env://"; the -v value was only
# needed to derive gcsfilepath above.
env="env://"
# Check that the host is reachable; a failing ping triggers the ERR trap.
ping -c 1 $HOST_IP

# Trace the torchrun command line in the job log.
set -x

# Launch the trainer. The rendezvous endpoint and master address both use
# HOST_IP on port 2379, the ETCD client port started by setup_etcd.
torchrun --rdzv_backend $rdvzbackend --rdzv_id $rdvzid --rdzv_endpoint $HOST_IP:2379 \
--nnodes $nnodes --nproc_per_node $nprocpernode --master_addr $HOST_IP --master_port 2379 \
main.py --epochs $epochs --arch $arch --batch-size $batchsize --dist-backend $distbackend \
--data $data \
--env $env \
--hostip $HOST_IP \
--hostipport 2379 \
--workers $workers \
--ischief $ischief

In [None]:
%%writefile trainer/parse_cluster_config.py
"""Print the task type of this node as described by the CLUSTER_SPEC variable."""
import json
import os
import sys


def get_workerpool_type(environ=os.environ):
    """Return the ``task.type`` entry of the JSON stored in ``CLUSTER_SPEC``.

    Args:
        environ: Mapping to read ``CLUSTER_SPEC`` from; defaults to the
            process environment.

    Returns:
        The value found at ``["task"]["type"]`` of the decoded JSON document.

    Raises:
        SystemExit: If ``CLUSTER_SPEC`` is missing or empty. The message goes to
            stderr, so stdout (which main.sh captures) stays empty.
    """
    cluster_config_str = environ.get("CLUSTER_SPEC")
    if not cluster_config_str:
        # Fail with a readable message instead of the TypeError that
        # json.loads(None) would raise.
        sys.exit(
            "CLUSTER_SPEC environment variable is not set; "
            "cannot determine the worker pool type"
        )
    cluster_config_dict = json.loads(cluster_config_str)
    return cluster_config_dict["task"]["type"]


if __name__ == "__main__":
    print(get_workerpool_type())

#### Create the main.py file 
Main trainer for the ImageNet training job

In [None]:
%%writefile trainer/main.py
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

r"""
Source: `pytorch imagenet example <https://github.com/pytorch/examples/blob/master/imagenet/main.py>`_ # noqa B950
Modified and simplified to make the original pytorch example compatible with
torchelastic.distributed.launch.
Changes:
1. Removed ``rank``, ``gpu``, ``multiprocessing-distributed``, ``dist_url`` options.
   These are obsolete parameters when using ``torchelastic.distributed.launch``.
2. Removed ``seed``, ``evaluate``, ``pretrained`` options for simplicity.
3. Removed ``resume``, ``start-epoch`` options.
   Loads the most recent checkpoint by default.
4. ``batch-size`` is now per GPU (worker) batch size rather than for all GPUs.
5. Defaults ``workers`` (num data loader workers) to ``0``.
Usage
::
 >>> python -m torchelastic.distributed.launch
        --nnodes=$NUM_NODES
        --nproc_per_node=$WORKERS_PER_NODE
        --rdzv_id=$JOB_ID
        --rdzv_backend=etcd
        --rdzv_endpoint=$ETCD_HOST:$ETCD_PORT
        main.py
        --arch resnet18
        --epochs 20
        --batch-size 32
        <DATA_DIR>
"""

import traceback
import argparse
import io
import os
import shutil
import time
from contextlib import contextmanager
from datetime import timedelta
from typing import List, Tuple

import numpy
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from torch.distributed.elastic.utils.data import ElasticDistributedSampler
from torch.nn.parallel import DistributedDataParallel
from torch.optim import SGD
from torch.utils.data import DataLoader


# Sorted names of every lowercase, non-dunder callable found in the
# torchvision ``models`` namespace; used as the valid choices for --arch.
model_names = sorted(
    candidate
    for candidate, member in models.__dict__.items()
    if candidate.islower() and not candidate.startswith("__") and callable(member)
)

# Command-line options of the trainer; main() reads them via parser.parse_args().
parser = argparse.ArgumentParser(description="PyTorch Elastic ImageNet Training")

# --- dataset and model ---
parser.add_argument("--data", help="path to dataset", metavar="DIR")
parser.add_argument(
    "-a", "--arch",
    choices=model_names,
    default="resnet18",
    metavar="ARCH",
    help="model architecture: " + " | ".join(model_names) + " (default: resnet18)",
)

# --- data loading and schedule ---
parser.add_argument(
    "-j", "--workers",
    type=int, default=0, metavar="N",
    help="number of data loading workers",
)
parser.add_argument(
    "--epochs",
    type=int, default=90, metavar="N",
    help="number of total epochs to run",
)
parser.add_argument(
    "-b", "--batch-size",
    type=int, default=32, metavar="N",
    help="mini-batch size (default: 32), per worker (GPU)",
)

# --- optimizer ---
parser.add_argument(
    "--lr", "--learning-rate",
    dest="lr", type=float, default=0.1, metavar="LR",
    help="initial learning rate",
)
parser.add_argument("--momentum", type=float, default=0.9, metavar="M", help="momentum")
parser.add_argument(
    "--wd", "--weight-decay",
    dest="weight_decay", type=float, default=1e-4, metavar="W",
    help="weight decay (default: 1e-4)",
)

# --- logging ---
parser.add_argument(
    "-p", "--print-freq",
    type=int, default=10, metavar="N",
    help="print frequency (default: 10)",
)

# --- distributed setup and checkpointing ---
parser.add_argument(
    "--dist-backend",
    type=str, default="nccl", choices=["nccl", "gloo"],
    help="distributed backend",
)
parser.add_argument(
    "--checkpoint-file",
    type=str, default="/tmp/checkpoint.pth.tar",
    help="checkpoint file path, to load and save to",
)
parser.add_argument(
    "--env",
    type=str, default="env://",
    help="setting for init_method for torch.distributed.init_process_group. Leave default unless you want to pass a shared gcs path",
)
parser.add_argument(
    "--hostip", type=str, default="localhost", help="setting for etcd host ip"
)
parser.add_argument(
    "--hostipport", type=int, default=2379, help="setting for etcd host ip port"
)
parser.add_argument(
    "--ischief", type=str, default="n", help='is this cheif or worker'
)

def main():
    """Train and validate the selected model as one worker process of the job.

    The function reads ``LOCAL_RANK``, ``RANK``, ``WORLD_SIZE``, ``MASTER_ADDR``
    and ``MASTER_PORT`` from the environment, overrides ``MASTER_ADDR`` with
    ``--hostip``, initializes the default process group, restores the most
    recent checkpoint (if any) and then runs the train/validate loop, saving a
    checkpoint after each epoch from the process whose local rank is 0.

    Raises:
        KeyError: If one of the environment variables listed above is missing.
    """
    args = parser.parse_args()
    print(args)
    device_id = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(device_id)
    print(f"=> set cuda device = {device_id}")

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    print (f"LOCAL_RANK={os.environ['LOCAL_RANK']} RANK={os.environ['RANK']} WORLD_SIZE={os.environ['WORLD_SIZE']}")
    print (f"args env= {args.env}")

    print (f"Host address: {os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}")
    # Point MASTER_ADDR at the host IP passed on the command line before the
    # process group is created.
    os.environ['MASTER_ADDR'] = args.hostip
    print (f"Updated IPv4 host address: {os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}")

    print ('Initialize process group')
    # The default process group may only be initialized once per process, so
    # exactly one of the three branches below must run.
    if args.env == "env://":
        dist.init_process_group(
            backend=args.dist_backend,
            init_method=f"{args.env}",
            timeout=timedelta(seconds=120),
        )
    elif args.ischief.lower() == 'y':
        # Chief: host the TCP store itself and initialize through it.
        print ('Setting store')
        store = dist.TCPStore(
            host_name=args.hostip,
            port=args.hostipport,
            world_size=world_size,
            is_master=True,
            timeout=timedelta(seconds=30),
        )
        print (f'Store set = {store}')
        dist.init_process_group(
            backend=args.dist_backend,
            store=store,
            timeout=timedelta(seconds=30),
            rank=rank,
            world_size=world_size,
        )
    else:
        # Non-chief: connect to the store hosted at hostip:hostipport.
        dist.init_process_group(
            backend=args.dist_backend,
            init_method=f"tcp://{args.hostip}:{args.hostipport}",
            timeout=timedelta(seconds=120),
            rank=rank,
            world_size=world_size,
        )
    print ('Process initialized')

    model, criterion, optimizer = initialize_model(
        args.arch, args.lr, args.momentum, args.weight_decay, device_id
    )

    train_loader, val_loader = initialize_data_loader(
        args.data, args.batch_size, args.workers
    )

    # resume from checkpoint if one exists;
    state = load_checkpoint(
        args.checkpoint_file, device_id, args.arch, model, optimizer
    )

    # A fresh State has epoch == -1, so training starts at epoch 0.
    start_epoch = state.epoch + 1
    print(f"=> start_epoch: {start_epoch}, best_acc1: {state.best_acc1}")

    print_freq = args.print_freq
    for epoch in range(start_epoch, args.epochs):
        state.epoch = epoch
        train_loader.batch_sampler.sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args.lr)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, device_id, print_freq)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, device_id, print_freq)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > state.best_acc1
        state.best_acc1 = max(acc1, state.best_acc1)

        # Only the process with local rank 0 writes the checkpoint file.
        if device_id == 0:
            save_checkpoint(state, is_best, args.checkpoint_file)


class State:
    """
    Mutable bundle of everything that gets checkpointed for a worker:
    the epoch counter, the best top-1 accuracy so far, the architecture
    name, the model and the optimizer.
    """

    def __init__(self, arch, model, optimizer):
        self.arch = arch
        self.model = model
        self.optimizer = optimizer
        # -1 marks "no epoch finished yet"; 0 is the initial best accuracy.
        self.epoch = -1
        self.best_acc1 = 0

    def capture_snapshot(self):
        """
        Serialize this state into a plain dict that ``torch.save()`` accepts.
        The result can be handed to ``apply_snapshot()`` of another instance
        to copy this state over.
        """
        model_weights = self.model.state_dict()
        optimizer_state = self.optimizer.state_dict()
        return {
            "epoch": self.epoch,
            "best_acc1": self.best_acc1,
            "arch": self.arch,
            "state_dict": model_weights,
            "optimizer": optimizer_state,
        }

    def apply_snapshot(self, obj, device_id):
        """
        Inverse of ``capture_snapshot()``: overwrite this state in place with
        the values stored in ``obj``. ``device_id`` is accepted but not used.
        """
        self.epoch = obj["epoch"]
        self.best_acc1 = obj["best_acc1"]
        weights = obj["state_dict"]
        # The raw weights dict is also kept as an attribute on the instance.
        self.state_dict = weights
        self.model.load_state_dict(weights)
        self.optimizer.load_state_dict(obj["optimizer"])

    def save(self, f):
        """Write the current snapshot to ``f`` with ``torch.save()``."""
        snapshot = self.capture_snapshot()
        torch.save(snapshot, f)

    def load(self, f, device_id):
        """Read a snapshot from ``f`` onto GPU ``device_id`` and apply it."""
        # Map the stored tensors onto the single GPU this worker owns.
        target_device = f"cuda:{device_id}"
        self.apply_snapshot(torch.load(f, map_location=target_device), device_id)


def initialize_model(
    arch: str, lr: float, momentum: float, weight_decay: float, device_id: int
):
    """
    Build the model named ``arch`` on GPU ``device_id``, wrap it in
    ``DistributedDataParallel`` and create the matching loss and optimizer.

    Returns a ``(model, criterion, optimizer)`` tuple: the DDP-wrapped model,
    a ``CrossEntropyLoss`` placed on the same GPU, and an ``SGD`` optimizer
    over the wrapped model's parameters.
    """
    print(f"=> creating model: {arch}")
    network = models.__dict__[arch]()
    network.cuda(device_id)
    cudnn.benchmark = True
    # Pin DistributedDataParallel to this worker's single device; without
    # device_ids it would use all available devices.
    ddp_network = DistributedDataParallel(network, device_ids=[device_id])

    loss_fn = nn.CrossEntropyLoss().cuda(device_id)
    sgd = SGD(
        ddp_network.parameters(), lr, momentum=momentum, weight_decay=weight_decay
    )
    return ddp_network, loss_fn, sgd


def initialize_data_loader(
    data_dir, batch_size, num_data_workers
) -> Tuple[DataLoader, DataLoader]:
    """
    Create the training and validation loaders for the image folders found at
    ``<data_dir>/train`` and ``<data_dir>/val``.

    The training loader draws its samples through an
    ``ElasticDistributedSampler`` and applies random crop/flip augmentation;
    the validation loader reads its folder in order (``shuffle=False``) with a
    resize + center crop. Both use ``batch_size``, ``num_data_workers`` loader
    workers and pinned memory.
    """
    train_root = os.path.join(data_dir, "train")
    val_root = os.path.join(data_dir, "val")

    # Same per-channel normalization for both splits.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    train_transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    )

    train_dataset = datasets.ImageFolder(train_root, train_transform)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_data_workers,
        pin_memory=True,
        sampler=ElasticDistributedSampler(train_dataset),
    )

    val_transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]
    )
    val_dataset = datasets.ImageFolder(val_root, val_transform)
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_data_workers,
        pin_memory=True,
    )
    return train_loader, val_loader


def load_checkpoint(
    checkpoint_file: str,
    device_id: int,
    arch: str,
    model: DistributedDataParallel,
    optimizer,  # SGD
) -> State:
    """
    Loads a local checkpoint (if any). Otherwise, checks to see if any of
    the neighbors have a non-zero state. If so, restore the state
    from the rank that has the most up-to-date checkpoint.
    .. note:: when your job has access to a globally visible persistent storage
              (e.g. nfs mount, S3) you can simply have all workers load
              from the most recent checkpoint from such storage. Since this
              example is expected to run on vanilla hosts (with no shared
              storage) the checkpoints are written to local disk, hence
              we have the extra logic to broadcast the checkpoint from a
              surviving node.

    Args:
        checkpoint_file: Path of the local checkpoint file to look for.
        device_id: GPU index the local checkpoint is mapped to when loaded.
        arch: Architecture name stored in the returned state.
        model: Model whose weights are restored in place.
        optimizer: Optimizer whose state is restored in place.

    Returns:
        The ``State`` wrapping ``model`` and ``optimizer``; its ``epoch`` stays
        ``-1`` when no rank had a checkpoint.
    """

    state = State(arch, model, optimizer)

    if os.path.isfile(checkpoint_file):
        print(f"=> loading checkpoint file: {checkpoint_file}")
        state.load(checkpoint_file, device_id)
        print(f"=> loaded checkpoint file: {checkpoint_file}")

    # logic below is unnecessary when the checkpoint is visible on all nodes!
    # create a temporary cpu pg to broadcast most up-to-date checkpoint
    with tmp_process_group(backend="gloo") as pg:
        rank = dist.get_rank(group=pg)

        # get rank that has the largest state.epoch
        # Each rank fills only its own slot; since every other slot is zero,
        # the SUM all-reduce leaves every rank holding the full vector of epochs.
        epochs = torch.zeros(dist.get_world_size(), dtype=torch.int32)
        epochs[rank] = state.epoch
        dist.all_reduce(epochs, op=dist.ReduceOp.SUM, group=pg)
        # torch.max over dim 0 yields both the largest epoch and the rank index
        # that holds it.
        t_max_epoch, t_max_rank = torch.max(epochs, dim=0)
        max_epoch = t_max_epoch.item()
        max_rank = t_max_rank.item()

        # max_epoch == -1 means no one has checkpointed return base state
        if max_epoch == -1:
            print(f"=> no workers have checkpoints, starting from epoch 0")
            return state

        # broadcast the state from max_rank (which has the most up-to-date state)
        # pickle the snapshot, convert it into a byte-blob tensor
        # then broadcast it, unpickle it and apply the snapshot
        print(f"=> using checkpoint from rank: {max_rank}, max_epoch: {max_epoch}")

        # Every rank serializes its own snapshot here; only the one produced on
        # max_rank is actually sent by the broadcasts below.
        with io.BytesIO() as f:
            torch.save(state.capture_snapshot(), f)
            raw_blob = numpy.frombuffer(f.getvalue(), dtype=numpy.uint8)

        # First broadcast the blob length so receivers can size their buffer.
        blob_len = torch.tensor(len(raw_blob))
        dist.broadcast(blob_len, src=max_rank, group=pg)
        print(f"=> checkpoint broadcast size is: {blob_len}")

        if rank != max_rank:
            # Receivers allocate an empty byte buffer of the announced length.
            # pyre-fixme[6]: For 1st param expected `Union[List[int], Size,
            #  typing.Tuple[int, ...]]` but got `Union[bool, float, int]`.
            blob = torch.zeros(blob_len.item(), dtype=torch.uint8)
        else:
            blob = torch.as_tensor(raw_blob, dtype=torch.uint8)

        dist.broadcast(blob, src=max_rank, group=pg)
        print(f"=> done broadcasting checkpoint")

        if rank != max_rank:
            # NOTE(review): torch.load is called without map_location here —
            # confirm the tensors end up on the intended device.
            with io.BytesIO(blob.numpy()) as f:
                snapshot = torch.load(f)
            state.apply_snapshot(snapshot, device_id)

        # wait till everyone has loaded the checkpoint
        dist.barrier(group=pg)

    print(f"=> done restoring from previous checkpoint")
    return state


@contextmanager
def tmp_process_group(backend):
    """
    Context manager that creates a new process group with the given
    ``backend``, yields it, and destroys it again when the block exits
    (including on early return or exception).
    """
    group = dist.new_group(backend=backend)
    try:
        yield group
    finally:
        dist.destroy_process_group(group)


def save_checkpoint(state: "State", is_best: bool, filename: str):
    """
    Write ``state``'s snapshot to ``filename`` and, when ``is_best`` is true,
    copy it to ``model_best.pth.tar`` in the same directory.

    Args:
        state: Object providing ``capture_snapshot()`` and ``epoch``.
        is_best: Whether this checkpoint is the best one seen so far.
        filename: Destination path. It may be a bare file name, in which case
            the files are written to the current working directory.
    """
    checkpoint_dir = os.path.dirname(filename)
    # os.makedirs("") raises, so only create the directory when the path
    # actually has a directory component.
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    # save to tmp, then commit by moving the file in case the job
    # gets interrupted while writing the checkpoint
    tmp_filename = filename + ".tmp"
    torch.save(state.capture_snapshot(), tmp_filename)
    # os.replace overwrites an existing checkpoint on every platform.
    os.replace(tmp_filename, filename)
    print(f"=> saved checkpoint for epoch {state.epoch} at {filename}")
    if is_best:
        best = os.path.join(checkpoint_dir, "model_best.pth.tar")
        print(f"=> best model found at epoch {state.epoch} saving to {best}")
        shutil.copyfile(filename, best)


def train(
    train_loader: DataLoader,
    model: DistributedDataParallel,
    criterion,  # nn.CrossEntropyLoss
    optimizer,  # SGD,
    epoch: int,
    device_id: int,
    print_freq: int,
):
    """
    Run one pass over ``train_loader``, updating ``model`` after every batch.

    Each batch is moved to CUDA device ``device_id``, pushed through the model,
    scored with ``criterion`` and used for one optimizer step. Timing, loss and
    top-1/top-5 accuracy meters are updated per batch, and a progress line is
    printed every ``print_freq`` batches. ``epoch`` is only used (as
    ``epoch + 1``) in the progress prefix. Nothing is returned.
    """
    time_per_batch = AverageMeter("Time", ":6.3f")
    time_for_data = AverageMeter("Data", ":6.3f")
    loss_meter = AverageMeter("Loss", ":.4e")
    acc1_meter = AverageMeter("Acc@1", ":6.2f")
    acc5_meter = AverageMeter("Acc@5", ":6.2f")
    reporter = ProgressMeter(
        len(train_loader),
        [time_per_batch, time_for_data, loss_meter, acc1_meter, acc5_meter],
        prefix="Epoch: [{}]".format(epoch + 1),
    )

    # put the model into training mode
    model.train()

    tick = time.time()
    for step, (images, target) in enumerate(train_loader):
        # time spent waiting for the loader to produce this batch
        time_for_data.update(time.time() - tick)

        images = images.cuda(device_id, non_blocking=True)
        target = target.cuda(device_id, non_blocking=True)

        # forward pass
        output = model(images)
        loss = criterion(output, target)

        # bookkeeping: loss and accuracies weighted by the batch size
        n_samples = images.size(0)
        top1_acc, top5_acc = accuracy(output, target, topk=(1, 5))
        loss_meter.update(loss.item(), n_samples)
        acc1_meter.update(top1_acc[0], n_samples)
        acc5_meter.update(top5_acc[0], n_samples)

        # backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # wall-clock time for the whole iteration
        time_per_batch.update(time.time() - tick)
        tick = time.time()

        if step % print_freq == 0:
            reporter.display(step)


def validate(
    val_loader: DataLoader,
    model: DistributedDataParallel,
    criterion,  # nn.CrossEntropyLoss
    device_id: int,
    print_freq: int,
):
    """
    Evaluate ``model`` on ``val_loader`` with gradients disabled.

    Loss and top-1/top-5 accuracy are accumulated over all batches, a progress
    line is printed every ``print_freq`` batches, and a summary line is printed
    at the end.

    Returns:
        The running average held by the top-1 accuracy meter.
    """
    time_per_batch = AverageMeter("Time", ":6.3f")
    loss_meter = AverageMeter("Loss", ":.4e")
    acc1_meter = AverageMeter("Acc@1", ":6.2f")
    acc5_meter = AverageMeter("Acc@5", ":6.2f")
    reporter = ProgressMeter(
        len(val_loader),
        [time_per_batch, loss_meter, acc1_meter, acc5_meter],
        prefix="Test: ",
    )

    # put the model into evaluation mode
    model.eval()

    with torch.no_grad():
        tick = time.time()
        for step, (images, target) in enumerate(val_loader):
            # images are only moved explicitly when a device id is given;
            # the target is always moved
            if device_id is not None:
                images = images.cuda(device_id, non_blocking=True)
            target = target.cuda(device_id, non_blocking=True)

            # forward pass
            output = model(images)
            loss = criterion(output, target)

            # bookkeeping: loss and accuracies weighted by the batch size
            n_samples = images.size(0)
            top1_acc, top5_acc = accuracy(output, target, topk=(1, 5))
            loss_meter.update(loss.item(), n_samples)
            acc1_meter.update(top1_acc[0], n_samples)
            acc5_meter.update(top5_acc[0], n_samples)

            # wall-clock time for the whole iteration
            time_per_batch.update(time.time() - tick)
            tick = time.time()

            if step % print_freq == 0:
                reporter.display(step)

        # TODO: this should also be done with the ProgressMeter
        print(
            " * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}".format(
                top1=acc1_meter, top5=acc5_meter
            )
        )

    return acc1_meter.avg


class AverageMeter(object):
    """
    Tracks the latest value and the weighted running average of a metric.

    Instance fields:
      * ``val``   - value passed to the most recent ``update`` call
      * ``sum``   - accumulated ``val * n`` over all updates
      * ``count`` - accumulated ``n`` over all updates
      * ``avg``   - ``sum / count`` (0 while ``count`` is 0)
    """

    def __init__(self, name: str, fmt: str = ":f"):
        """
        Args:
            name: label printed in front of the values by ``__str__``.
            fmt: format spec (including the leading ``:``) applied to both
                ``val`` and ``avg`` when the meter is rendered.
        """
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self) -> None:
        """Zero out the current value, sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1) -> None:
        """
        Record ``val`` as the newest value with weight ``n``.

        Args:
            val: the new measurement.
            n: weight of the measurement (e.g. the number of samples it
                summarizes).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. an update with n=0 on a fresh meter) used
        # to raise ZeroDivisionError; keep the average at 0 in that case.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        """Render as ``"<name> <val> (<avg>)"`` using ``fmt`` for both numbers."""
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    """
    Prints one tab-separated status line made of a ``[batch/total]`` counter
    followed by the string form of each meter.
    """

    def __init__(self, num_batches: int, meters: List[AverageMeter], prefix: str = ""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch: int) -> None:
        """Print the prefix, the counter for ``batch`` and every meter."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        columns = [header, *map(str, self.meters)]
        print("\t".join(columns))

    def _get_batch_fmtstr(self, num_batches: int) -> str:
        """Build the ``[{:Nd}/total]`` template, N being the digit count of the total."""
        width = len(str(num_batches // 1))
        slot = "{:" + str(width) + "d}"
        return "".join(["[", slot, "/", slot.format(num_batches), "]"])


def adjust_learning_rate(optimizer, epoch: int, lr: float) -> None:
    """
    Step-decay schedule: write ``lr`` scaled by 0.1 once per completed block
    of 30 epochs into the ``"lr"`` entry of every optimizer param group.
    """
    completed_decays = epoch // 30
    decayed_lr = lr * (0.1 ** completed_decays)
    for group in optimizer.param_groups:
        group["lr"] = decayed_lr


def accuracy(output, target, topk=(1,)):
    """
    Percentage of samples whose target is among the k highest-scoring entries
    (taken along dimension 1 of ``output``), for every k in ``topk``.

    Returns a list with one single-element tensor per requested k, in the
    same order as ``topk``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # indices of the largest_k best scores per sample, transposed so that
        # row i holds every sample's (i+1)-th best guess
        _, top_indices = output.topk(largest_k, dim=1, largest=True, sorted=True)
        top_indices = top_indices.t()
        hits = top_indices.eq(target.view(1, -1).expand_as(top_indices))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
            for k in topk
        ]


if __name__ == "__main__":
    try:
        main()
    except Exception:
        # Print the complete traceback to stdout: the stack frames plus the
        # exception type and message (format_tb alone drops the latter two).
        print(traceback.format_exc())
        # Exit non-zero so the launcher sees the failure instead of a clean
        # exit; swallowing the error made a crashed worker look successful.
        raise SystemExit(1)


### Build custom container

In [None]:
# Pieces of the Artifact Registry image URI used for the training container.
CONTENT_NAME = "pytorch-torchrun-imagenet-multi-node"
CONTAINER_NAME = f"{CONTENT_NAME}-gpu"
TAG = "latest"

# <region>-docker.pkg.dev/<project>/<repository>/<image>:<tag>
custom_container_host_image_uri = "".join(
    [f"{REGION}-docker.pkg.dev/", f"{PROJECT_ID}/{REPOSITORY}/", f"{CONTAINER_NAME}:{TAG}"]
)

In [None]:
# Submit the local `trainer` directory to Cloud Build in $REGION and tag the
# resulting image with custom_container_host_image_uri. The build is given a
# 2 hour timeout and runs on an e2-highcpu-32 build machine.
!gcloud builds submit \
   --region $REGION \
   --tag $custom_container_host_image_uri \
   --timeout "2h" \
   --machine-type=e2-highcpu-32 \
   trainer

### Run training on Vertex AI using `torchrun` with ETCD on host

In [None]:
from datetime import datetime

# --- Machine shapes and training hyperparameters -----------------------------
BUCKET_NAME = BUCKET_URI.replace("gs://", "")
TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
PRIMARY_COMPUTE = "n1-highmem-16"
TRAIN_COMPUTE = "n1-highmem-16"
NUM_CPUS = 14  # Set to a few less than max CPUs per instance for parallel data loading
TRAIN_GPU = "NVIDIA_TESLA_T4"
TRAIN_NGPU = 1
BATCH_SIZE = 512
REPLICAS = 2
EPOCHS = 5
ARCH = "resnet18"
BACKEND = "nccl"  # gloo for CPU only, nccl for GPUs
TRAIN_DATA_LOCATION = (
    "/trainer/tiny-imagenet-200"  # Data location of files downloaded in Dockerfile
)

# --- Job naming and output location ------------------------------------------
display_name = (
    CONTAINER_NAME
    + "-LOCAL-ETCD-"
    + f"{REPLICAS}workers-{TRAIN_NGPU}{TRAIN_GPU}-{BATCH_SIZE}batch-"
    + TIMESTAMP
)
gcs_output_uri_prefix = f"{BUCKET_URI}/{display_name}"

# --- Rendezvous settings handed to the container -----------------------------
RDZV_BACKEND = "etcd-v2"
RDZV_BACKEND_STORE = f"/gcs/{BUCKET_NAME}/sharedfile-{display_name}"
RDZV_ENDPOINT = "localhost:2379"

# Use letters for each parameter to be processed in the shell script.
# Letter-to-variable mapping as noted for the script (presumably the option
# `case` branches of main.sh -- verify against the script):
#   e)epochs=${OPTARG};;
#   a)arch=${OPTARG};;
#   b)batchsize=${OPTARG};;
#   d)distbackend=${OPTARG};;
#   t)data=${OPTARG};;
#   w)workers=${OPTARG};;
#   v)env=${OPTARG};;
#   u)rdvzbackend=${OPTARG};;
#   i)rdvzid=${OPTARG};;
#   p)endpoint=${OPTARG};;
#   n)nnodes=${OPTARG};;
#   r)nprocpernode=${OPTARG};;
#   c)ischief=${OPTARG};;

# Arguments shared by the chief and the worker containers. Only the trailing
# "-c" flag differs between them, so both commands are derived from this one
# list to keep them from drifting apart.
MAIN_SH_COMMON_ARGS = [
    "/bin/bash",
    "main.sh",
    f"-e {EPOCHS}",
    f"-a {ARCH}",
    f"-b {BATCH_SIZE}",
    f"-d {BACKEND}",
    f"-t {TRAIN_DATA_LOCATION}",
    f"-w {NUM_CPUS}",
    f"-v {RDZV_BACKEND_STORE}",
    f"-u {RDZV_BACKEND}",
    f"-i {display_name}",
    f"-p {RDZV_ENDPOINT}",
    f"-n {REPLICAS+1}",  # one primary replica plus REPLICAS workers
    f"-r {TRAIN_NGPU}",
]

CONTAINER_SPEC = {
    "image_uri": custom_container_host_image_uri,
    "command": MAIN_SH_COMMON_ARGS + ["-c y"],
}

CONTAINER_WORKER_SPEC = {
    "image_uri": custom_container_host_image_uri,
    "command": MAIN_SH_COMMON_ARGS + ["-c n"],
}

# --- Worker pools: pool 0 is the single primary, pool 1 the workers ----------
PRIMARY_WORKER_POOL = {
    "replica_count": 1,
    "machine_spec": {
        "machine_type": PRIMARY_COMPUTE,
        "accelerator_count": TRAIN_NGPU,
        "accelerator_type": TRAIN_GPU,
    },
    "container_spec": CONTAINER_SPEC,
}

TRAIN_WORKER_POOL = {
    "replica_count": REPLICAS,
    "machine_spec": {
        "machine_type": TRAIN_COMPUTE,
        "accelerator_count": TRAIN_NGPU,
        "accelerator_type": TRAIN_GPU,
    },
    "container_spec": CONTAINER_WORKER_SPEC,
}

WORKER_POOL_SPECS = [PRIMARY_WORKER_POOL, TRAIN_WORKER_POOL]

job = aiplatform.CustomJob(
    display_name=display_name,
    base_output_dir=gcs_output_uri_prefix,
    worker_pool_specs=WORKER_POOL_SPECS,
)

In [None]:
# Launch the job and block until it finishes; skipped when IS_TESTING is set.
if not os.getenv("IS_TESTING"):
    job.run(
        sync=True,
        # comment out the line below to turn off interactive debug
        enable_web_access=True,
        service_account=SERVICE_ACCOUNT,
    )

### Run training on Vertex AI using `torchrun` with ETCD on host and reduction server

In [None]:
from datetime import datetime

# --- Machine shapes and training hyperparameters -----------------------------
BUCKET_NAME = BUCKET_URI.replace("gs://", "")
TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
PRIMARY_COMPUTE = "n1-highmem-16"
TRAIN_COMPUTE = "n1-highmem-16"
REDUCTION_COMPUTE = "n1-highcpu-16"
NUM_CPUS = 14  # Set to a few less than max CPUs per instance for parallel data loading
TRAIN_GPU = "NVIDIA_TESLA_T4"
TRAIN_NGPU = 1
BATCH_SIZE = 512
REPLICAS = 2
EPOCHS = 5
ARCH = "resnet18"
BACKEND = "nccl"  # gloo for CPU only, nccl for GPUs
TRAIN_DATA_LOCATION = (
    "/trainer/tiny-imagenet-200"  # Data location of files downloaded in Dockerfile
)

# --- Job naming and output location ------------------------------------------
display_name = (
    CONTAINER_NAME
    + "-LOCAL-ETCD-reduc-server-"
    + f"{REPLICAS}workers-{TRAIN_NGPU}{TRAIN_GPU}-{BATCH_SIZE}batch-"
    + TIMESTAMP
)
gcs_output_uri_prefix = f"{BUCKET_URI}/{display_name}"

# --- Rendezvous settings handed to the container -----------------------------
RDZV_BACKEND = "etcd-v2"
RDZV_BACKEND_STORE = f"/gcs/{BUCKET_NAME}/sharedfile-{display_name}"
RDZV_ENDPOINT = "localhost:2379"

# Use letters for each parameter to be processed in the shell script.
# Letter-to-variable mapping as noted for the script (presumably the option
# `case` branches of main.sh -- verify against the script):
#   e)epochs=${OPTARG};;
#   a)arch=${OPTARG};;
#   b)batchsize=${OPTARG};;
#   d)distbackend=${OPTARG};;
#   t)data=${OPTARG};;
#   w)workers=${OPTARG};;
#   v)env=${OPTARG};;
#   u)rdvzbackend=${OPTARG};;
#   i)rdvzid=${OPTARG};;
#   p)endpoint=${OPTARG};;
#   n)nnodes=${OPTARG};;
#   r)nprocpernode=${OPTARG};;
#   c)ischief=${OPTARG};;

# Arguments shared by the chief and the worker containers. Only the trailing
# "-c" flag differs between them, so both commands are derived from this one
# list to keep them from drifting apart.
MAIN_SH_COMMON_ARGS = [
    "/bin/bash",
    "main.sh",
    f"-e {EPOCHS}",
    f"-a {ARCH}",
    f"-b {BATCH_SIZE}",
    f"-d {BACKEND}",
    f"-t {TRAIN_DATA_LOCATION}",
    f"-w {NUM_CPUS}",
    f"-v {RDZV_BACKEND_STORE}",
    f"-u {RDZV_BACKEND}",
    f"-i {display_name}",
    f"-p {RDZV_ENDPOINT}",
    f"-n {REPLICAS+1}",  # one primary replica plus REPLICAS workers
    f"-r {TRAIN_NGPU}",
]

CONTAINER_SPEC = {
    "image_uri": custom_container_host_image_uri,
    "command": MAIN_SH_COMMON_ARGS + ["-c y"],
}

CONTAINER_WORKER_SPEC = {
    "image_uri": custom_container_host_image_uri,
    "command": MAIN_SH_COMMON_ARGS + ["-c n"],
}

# --- Worker pools: pool 0 is the single primary, pool 1 the workers ----------
PRIMARY_WORKER_POOL = {
    "replica_count": 1,
    "machine_spec": {
        "machine_type": PRIMARY_COMPUTE,
        "accelerator_count": TRAIN_NGPU,
        "accelerator_type": TRAIN_GPU,
    },
    "container_spec": CONTAINER_SPEC,
}

TRAIN_WORKER_POOL = {
    "replica_count": REPLICAS,
    "machine_spec": {
        "machine_type": TRAIN_COMPUTE,
        "accelerator_count": TRAIN_NGPU,
        "accelerator_type": TRAIN_GPU,
    },
    "container_spec": CONTAINER_WORKER_SPEC,
}

# Add Reduction Server worker pool (third pool, no accelerators attached)
REDUCTION_SERVER_REPLICAS = 3
REDUCTION_SERVER_IMAGE_URI = (
    "us-docker.pkg.dev/vertex-ai-restricted/training/reductionserver:latest"
)

CONTAINER_REDUCTION_SPEC = {"image_uri": REDUCTION_SERVER_IMAGE_URI}

REDUCTION_WORKER_POOL = {
    "replica_count": REDUCTION_SERVER_REPLICAS,
    "machine_spec": {
        "machine_type": REDUCTION_COMPUTE,
    },
    "container_spec": CONTAINER_REDUCTION_SPEC,
}

WORKER_POOL_SPECS = [PRIMARY_WORKER_POOL, TRAIN_WORKER_POOL, REDUCTION_WORKER_POOL]

job = aiplatform.CustomJob(
    display_name=display_name,
    base_output_dir=gcs_output_uri_prefix,
    worker_pool_specs=WORKER_POOL_SPECS,
)

In [None]:
# Launch the job and block until it finishes; skipped when IS_TESTING is set.
if not os.getenv("IS_TESTING"):
    job.run(
        sync=True,
        # comment out the line below to turn off interactive debug
        enable_web_access=True,
        service_account=SERVICE_ACCOUNT,
    )

## Cleaning up

To clean up all Google Cloud resources used in this project, you can [delete the Google Cloud
project](https://cloud.google.com/resource-manager/docs/creating-managing-projects#shutting_down_projects) you used for the tutorial.

Otherwise, you can delete the individual resources you created in this tutorial:

In [None]:
# Best-effort removal of the CustomJob currently bound to `job`; a failure is
# printed instead of raised so the rest of the cleanup still runs.
# NOTE(review): `job` is assigned in both training sections above, so only the
# most recently created job handle is deleted here -- confirm whether the
# earlier job should be deleted as well.
try:
    job.delete()
except Exception as e:
    print(e)
    
# Capture (as a list of lines) the packages in the Artifact Registry repository
# whose name matches CONTAINER_NAME (anything after a ':' is stripped first).
gar_images = ! gcloud artifacts docker images list $REGION-docker.pkg.dev/$PROJECT_ID/$REPOSITORY \
      --filter="package~"$(echo $CONTAINER_NAME | sed 's/:.*//') \
      --format="get(package)"

# Set to False to keep the container images (they are still deleted when
# IS_TESTING is set).
delete_image = True
try:
    if delete_image or os.getenv("IS_TESTING"):
        for image in gar_images:
            # delete only if image name starts with valid region
            if image.startswith(f'{REGION}-docker.pkg.dev'):
                print(f"Deleting image {image} including all tags")
                ! gcloud artifacts docker images delete $image --delete-tags --quiet
except Exception as e:
    print(e)

# Delete Cloud Storage objects that were created
# Set to True to recursively remove $BUCKET_URI (always done when IS_TESTING is set).
delete_bucket = False
if delete_bucket or os.getenv("IS_TESTING"):
    ! gsutil -m rm -r $BUCKET_URI