In [None]:
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Custom training and batch prediction

<table align="left">
  <td>
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/ai-platform-samples/blob/master/ai-platform-unified/notebooks/official/custom-training-batch-prediction.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Colab logo"> Run in Colab
    </a>
  </td>
  <td>
    <a href="https://github.com/GoogleCloudPlatform/ai-platform-samples/blob/master/ai-platform-unified/notebooks/official/custom-training-batch-prediction.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo">
      View on GitHub
    </a>
  </td>
</table>
<br/><br/><br/>

## Overview


This tutorial demonstrates how to use the Vertex SDK for Python to train and deploy a custom image classification model for batch prediction.

### Dataset

The dataset used for this tutorial is the [cifar10 dataset](https://www.tensorflow.org/datasets/catalog/cifar10) from [TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/overview). The version of the dataset you will use is built into TensorFlow. The trained model predicts which type of class an image is from ten classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck.

### Objective

In this notebook, you create a custom-trained model from a Python script in a Docker container using the Vertex SDK for Python, and then do a prediction on the deployed model by sending data. Alternatively, you can create custom-trained models using the `gcloud` command-line tool, or online using the Cloud Console.

The steps performed include:

- Create a Vertex AI custom job for training a model.
- Train a PyTorch model.
- Make a batch prediction.
- Cleanup resources.

### Costs

This tutorial uses billable components of Google Cloud (GCP):

* Vertex AI
* Cloud Storage

Learn about [Vertex AI
pricing](https://cloud.google.com/vertex-ai/pricing) and [Cloud Storage
pricing](https://cloud.google.com/storage/pricing), and use the [Pricing
Calculator](https://cloud.google.com/products/calculator/)
to generate a cost estimate based on your projected usage.

## Installation

Install the latest (preview) version of Vertex SDK for Python.

In [None]:
import os

# A Google Cloud Notebook instance is recognised by the presence of this metadata file.
IS_GOOGLE_CLOUD_NOTEBOOK = os.path.exists("/opt/deeplearning/metadata/env_version")

# Google Cloud Notebook requires dependencies to be installed with '--user';
# everywhere else no extra pip flag is passed.
USER_FLAG = "--user" if IS_GOOGLE_CLOUD_NOTEBOOK else ""

In [None]:
! pip3 install {USER_FLAG} --upgrade google-cloud-aiplatform

Install the latest GA version of *google-cloud-storage* library as well.

In [None]:
! pip3 install {USER_FLAG} --upgrade google-cloud-storage

Install the *pillow* library for loading images.

In [None]:
! pip3 install {USER_FLAG} --upgrade pillow

Install the *numpy* library for manipulation of image data.

In [None]:
! pip3 install {USER_FLAG} --upgrade numpy

### Restart the kernel

Once you've installed everything, you need to restart the notebook kernel so it can find the packages.

In [None]:
import os

# The restart is skipped whenever the IS_TESTING environment variable is set
# to a non-empty value.
if os.getenv("IS_TESTING") in (None, ""):
    import IPython

    # Automatically restart the kernel so the freshly installed packages are found.
    app = IPython.Application.instance()
    app.kernel.do_shutdown(True)

## Before you begin

### Select a GPU runtime

**Make sure you're running this notebook in a GPU runtime if you have that option. In Colab, select "Runtime > Change runtime type > GPU".**

### Set up your Google Cloud project

**The following steps are required, regardless of your notebook environment.**

1. [Select or create a Google Cloud project](https://console.cloud.google.com/cloud-resource-manager). When you first create an account, you get a $300 free credit towards your compute/storage costs.

2. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

3. [Enable the Vertex AI API and Compute Engine API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com,compute_component).

4. If you are running this notebook locally, you will need to install the [Cloud SDK](https://cloud.google.com/sdk).

5. Enter your project ID in the cell below. Then run the cell to make sure the
Cloud SDK uses the right project for all the commands in this notebook.

**Note**: Jupyter runs lines prefixed with `!` as shell commands, and it interpolates Python variables prefixed with `$` into these commands.

#### Set your project ID

**If you don't know your project ID**, you may be able to get your project ID using `gcloud`.

In [None]:
import os

PROJECT_ID = ""

if not os.getenv("IS_TESTING"):
    # Get your Google Cloud project ID from gcloud (stderr is discarded so only
    # the value itself is captured).
    shell_output=!gcloud config list --format 'value(core.project)' 2>/dev/null
    # The captured output can be empty (for example when gcloud produced no
    # output at all); indexing it unconditionally would raise IndexError. In that
    # case PROJECT_ID stays "" so the fallback cell below can fill it in.
    if shell_output:
        PROJECT_ID = shell_output[0]
    print("Project ID: ", PROJECT_ID)

Otherwise, set your project ID here.

In [None]:
# Use the manually entered project ID only when PROJECT_ID is still unset.
if PROJECT_ID is None or PROJECT_ID == "":
    PROJECT_ID = "[your-project-id]"  # @param {type:"string"}

### Authenticate your Google Cloud account

**If you are using Google Cloud Notebooks**, your environment is already
authenticated. Skip this step.

**If you are using Colab**, run the cell below and follow the instructions
when prompted to authenticate your account via OAuth.

**Otherwise**, follow these steps:

1. In the Cloud Console, go to the [**Create service account key**
   page](https://console.cloud.google.com/apis/credentials/serviceaccountkey).

2. Click **Create service account**.

3. In the **Service account name** field, enter a name, and
   click **Create**.

4. In the **Grant this service account access to project** section, click the **Role** drop-down list. Type "Vertex AI"
into the filter box, and select
   **Vertex AI Administrator**. Type "Storage Object Admin" into the filter box, and select **Storage Object Admin**.

5. Click *Create*. A JSON file that contains your key downloads to your
local environment.

6. Enter the path to your service account key as the
`GOOGLE_APPLICATION_CREDENTIALS` variable in the cell below and run the cell.

In [None]:
import sys

# If you are running this notebook in Colab, run this cell and follow the
# instructions to authenticate your GCP account. This provides access to your
# Cloud Storage bucket and lets you submit training jobs and prediction
# requests.

# A Google Cloud Notebook instance is recognised by the presence of this
# metadata file. `os` is not imported in this cell; it must already be in the
# namespace from an earlier cell.
IS_GOOGLE_CLOUD_NOTEBOOK = os.path.exists("/opt/deeplearning/metadata/env_version")

# If on Google Cloud Notebooks, then don't execute this code
if not IS_GOOGLE_CLOUD_NOTEBOOK:
    # Colab is detected by the google.colab package already being loaded.
    if "google.colab" in sys.modules:
        from google.colab import auth as google_auth

        google_auth.authenticate_user()

    # If you are running this notebook locally, replace the string below with the
    # path to your service account key and run this cell to authenticate your GCP
    # account. The branch is skipped when the IS_TESTING environment variable is set.
    elif not os.getenv("IS_TESTING"):
        %env GOOGLE_APPLICATION_CREDENTIALS ''

In [None]:
import os

# Point Application Default Credentials at a local service-account key, but
# only when that key file is actually present. Setting the variable
# unconditionally to a path that does not exist would override the
# authentication already performed on Colab / Google Cloud Notebooks.
# NOTE(review): the path below is specific to one machine; change it to the
# location of your own key file.
CREDENTIALS_PATH = "/home/jupyter/credentials.json"
if os.path.isfile(CREDENTIALS_PATH):
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = CREDENTIALS_PATH

### Create a Cloud Storage bucket

**The following steps are required, regardless of your notebook environment.**

When you submit a training job using the Cloud SDK, you upload a Python package
containing your training code to a Cloud Storage bucket. Vertex AI runs
the code from this package. In this tutorial, Vertex AI also saves the
trained model that results from your job in the same bucket. Using this model artifact, you can then create Vertex AI model resources.

Set the name of your Cloud Storage bucket below. It must be unique across all
Cloud Storage buckets.

You may also change the `REGION` variable, which is used for operations
throughout the rest of this notebook. Make sure to [choose a region where Vertex AI services are
available](https://cloud.google.com/vertex-ai/docs/general/locations#available_regions). You may
not use a Multi-Regional Storage bucket for training with Vertex AI.

In [None]:
# NOTE(review): this looks like a bucket name left over from the author's own
# project — replace it with your own bucket, or set it to
# "gs://[your-bucket-name]" / "" so the next cell generates a default name.
BUCKET_NAME = "gs://kubeflow-cloudinsight-dev-vertexai"  # @param {type:"string"}
# Region passed to `gsutil mb -l` and to `aiplatform.init` further down.
REGION = "us-central1"  # @param {type:"string"}

In [None]:
from datetime import datetime

# Timestamp suffix that makes the generated bucket name unique per session.
# No earlier cell defines TIMESTAMP, so it is created here; without it the
# branch below would raise NameError.
TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")

# Generate a default bucket name when none (or only the placeholder) was given.
if BUCKET_NAME == "" or BUCKET_NAME is None or BUCKET_NAME == "gs://[your-bucket-name]":
    BUCKET_NAME = "gs://" + PROJECT_ID + "aip-" + TIMESTAMP

**Only if your bucket doesn't already exist**: Run the following cell to create your Cloud Storage bucket.

In [None]:
! gsutil mb -l $REGION $BUCKET_NAME

Finally, validate access to your Cloud Storage bucket by examining its contents:

In [None]:
! gsutil ls -al $BUCKET_NAME

### Set up variables

Next, set up some variables used throughout the tutorial.

#### Import Vertex SDK for Python

Import the Vertex SDK for Python into your Python environment and initialize it.

In [None]:
import os
import sys

from google.cloud import aiplatform
from google.cloud.aiplatform import gapic as aip

# Initialise the SDK with the project, region and staging bucket chosen in the
# cells above.
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=BUCKET_NAME)

# Tutorial

Now you are ready to start creating your own custom-trained model with CIFAR10.

## Train a model

There are two ways you can train a custom model using a container image:

- **Use a Google Cloud prebuilt container**. If you use a prebuilt container, you will additionally specify a Python package to install into the container image. This Python package contains your code for training a custom model.

- **Use your own custom container image**. If you use your own container, the container needs to contain your code for training a custom model.

### Define the command args for the training script

Prepare the command-line arguments to pass to your training script.
- `args`: The command line arguments to pass to the corresponding Python module. In this example, they will be:
  - `"--epochs=" + EPOCHS`: The number of epochs for training.
  - `"--steps=" + STEPS`: The number of steps (batches) per epoch.
  - `"--distribute=" + TRAIN_STRATEGY`: The training distribution strategy to use for single or distributed training.
     - `"single"`: single device.
     - `"mirror"`: all GPU devices on a single compute instance.
     - `"multi"`: all GPU devices on all compute instances.

#### Training script

In the next cell, you will write the contents of the training script, `task.py`. In summary, the script:

- Parses command-line arguments such as `--data`, `--loader`, `--epochs`, `--batch_size` and `--lr`.
- Builds a model with PyTorch: a small convolutional network (`Cifar10Model`) by default, or a pre-trained torchvision model when `--pretrained` is passed.
- Loads the training and validation images either from an image folder (`--loader orig`) or from WebDataset shards (`--loader wds`).
- Sets up single-GPU, `DataParallel` or `DistributedDataParallel` execution according to `--gpu`, `--world-size` and `--multiprocessing-distributed`.
- Trains the model with SGD and a cross-entropy loss for `--epochs` epochs, validating after every epoch.
- Saves a checkpoint (`checkpoint.pth.tar`) and keeps a copy of the best one as `model_best.pth.tar`.

In [None]:
%%writefile task.py
import sys
import argparse
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import webdataset as wds

# Every lowercase, non-dunder, callable attribute of torchvision.models; these
# are the values accepted by --arch.
model_names = sorted(name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name]))

# Command-line interface of the training script.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

# --- data location and loading -------------------------------------------
parser.add_argument('--data', metavar='DIR', default='./home/jupyter/data',
                    help='path to dataset')
parser.add_argument('--loader', default='orig', help='loader to use: orig, wds')
parser.add_argument('--shuffle', type=int, default=1000, help='shuffle buffer size for WebDataset')
parser.add_argument('--trainshards', default='./cifar-shards/train/cifar-train-{000000..000694}.tar', help='path/URL for ImageNet shards',
)
parser.add_argument('--valshards', default='./cifar-shards/val/cifar-val-{000000..000490}.tar', help='path/URL for ImageNet shards',
)
parser.add_argument('--trainsize', type=int, default=1300000, help='Training set size')
parser.add_argument('--valsize', type=int, default=100000, help='Validation set size')
parser.add_argument('--augmentation', default='full')

# --- model and optimisation ----------------------------------------------
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
# The help text previously advertised a default of 256 although the actual
# default is 16.
parser.add_argument('--batch_size', default=16, type=int, metavar='N',
                    help='mini-batch size (default: 16), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
# store_true: the flag is False unless passed, so no pre-trained model is used by default.
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')

# --- distributed training ------------------------------------------------
# World size and rank default to the WORLD_SIZE / RANK environment variables
# when present, and to -1 otherwise.
parser.add_argument('--world-size', default=int(os.getenv('WORLD_SIZE', -1)), type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=int(os.getenv('RANK', -1)), type=int,
                    help='node rank for distributed training')
# The URL is handed to torch.distributed.init_process_group as init_method,
# which understands tcp://, file:// and env:// URLs; the previous http://
# default is not one of those schemes.
# See https://cloud.google.com/ai-platform/training/docs/distributed-pytorch#updating_your_training_code
parser.add_argument('--dist-url', default='tcp://localhost:8082', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')

best_acc1 = 0

def main():
    """Parse the command line and start training.

    Optionally seeds the random number generators, decides whether the run is
    distributed, and then either spawns one ``main_worker`` process per visible
    GPU (``--multiprocessing-distributed``) or calls ``main_worker`` directly
    in this process.
    """
    args = parser.parse_args()
    # --pretrained is a store_true flag, so it is already False unless it is
    # passed explicitly. (The former `args.pretrained == False` statement was a
    # comparison with no effect and has been removed, together with a block of
    # abandoned code that was kept in a string literal.)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    # With env:// initialisation the world size has to come from the environment.
    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed
    # debugging
    print (args.distributed)

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)


class Cifar10Model(nn.Module):
    """Small convolutional classifier with ten output classes.

    Two 5x5 convolutions, each followed by ReLU and 2x2 max pooling, feed three
    fully connected layers. The first linear layer expects 16 * 5 * 5 features,
    which is what the convolution/pooling stack produces for 3-channel 32x32
    inputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for a batch of images."""
        # The file never imports torch.nn.functional as ``F``, so the previous
        # ``F.relu`` calls raised NameError; use the module through ``nn``.
        relu = nn.functional.relu
        x = self.pool(relu(self.conv1(x)))
        x = self.pool(relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = relu(self.fc1(x))
        x = relu(self.fc2(x))
        x = self.fc3(x)
        return x
        
def main_worker(gpu, ngpus_per_node, args):
    """Build model, optimizer and data loaders, then train or evaluate.

    Args:
        gpu: Index of the GPU this worker should use, or None.
        ngpus_per_node: Number of GPUs on this node.
        args: Parsed command-line namespace. It is modified in place
            (``gpu``, and depending on the mode ``rank``, ``batch_size``,
            ``workers`` and ``start_epoch``).

    Raises:
        ValueError: If ``args.loader`` does not name a pair of
            ``make_train_loader_<name>`` / ``make_val_loader_<name>`` functions.
    """
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        # Note: the custom Cifar10Model is built here regardless of args.arch.
        print("=> creating model '{}'".format(args.arch))
        model = Cifar10Model()

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            # From https://cloud.google.com/ai-platform/training/docs/distributed-pytorch#updating_your_training_code
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs.
        # Only models that really expose a ``features`` sub-module get the
        # alexnet/vgg treatment; Cifar10Model has none, so it previously failed
        # with AttributeError when --arch started with alexnet or vgg.
        if args.arch.startswith(('alexnet', 'vgg')) and hasattr(model, 'features'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Resolve the loader factories by name instead of eval()-ing a string that
    # contains a command-line value: eval would execute arbitrary expressions
    # smuggled in through --loader.
    make_train_loader = globals().get("make_train_loader_{}".format(args.loader))
    make_val_loader = globals().get("make_val_loader_{}".format(args.loader))
    if not callable(make_train_loader) or not callable(make_val_loader):
        raise ValueError("unknown --loader value: {!r}".format(args.loader))

    train_loader = make_train_loader(args)
    val_loader = make_val_loader(args)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Not every loader has a sampler (or one with set_epoch), so probe for it.
            sampler = getattr(train_loader, "sampler", None)
            if sampler is not None and hasattr(sampler, "set_epoch"):
                sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        # With multiprocessing only the first process of each node writes the checkpoint.
        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
            }, is_best)

def make_train_transform(args):
    """Return the preprocessing pipeline applied to training images.

    Images are converted to tensors and normalised per channel with mean 0.5
    and std 0.5. ``args`` is accepted for interface symmetry but is not read.
    """
    pipeline_steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    return transforms.Compose(pipeline_steps)

def make_val_transform(args):
    """Return the preprocessing pipeline applied to validation images.

    Converts to a tensor, then normalises every channel with mean 0.5 and
    std 0.5. ``args`` is not read.
    """
    to_tensor = transforms.ToTensor()
    channel_normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    return transforms.Compose([to_tensor, channel_normalize])


def make_train_loader_orig(args):
    """Create the training DataLoader from an image folder.

    Reads images from ``<args.data>/train`` and stores the dataset length in
    ``args.trainsize``. In distributed mode a DistributedSampler is used;
    otherwise the loader shuffles itself.
    """
    print("=> using file based loader")
    train_dataset = datasets.ImageFolder(
        os.path.join(args.data, "train"), make_train_transform(args)
    )
    args.trainsize = len(train_dataset)

    train_sampler = (
        torch.utils.data.distributed.DistributedSampler(train_dataset)
        if args.distributed
        else None
    )

    # shuffle and sampler are mutually exclusive, so shuffle only without a sampler.
    return torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=train_sampler is None,
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler,
    )


def identity(x):
    """Return ``x`` unchanged (used as the no-op transform for labels)."""
    return x


def worker_urls(urls):
    """Delegate to ``wds.worker_urls`` and report on stderr how many URLs were kept."""
    selected = wds.worker_urls(urls)
    print("worker_urls returning", len(selected), "of", len(urls), "urls", file=sys.stderr)
    return selected


def make_train_loader_wds(args):
    """Create the training loader from WebDataset shards (``args.trainshards``).

    Samples are shuffled, decoded as PIL images, reduced to ``(ppm, cls)``
    tuples, transformed and batched inside the dataset; the WebLoader itself is
    created with ``batch_size=None``. In distributed mode partial batches are
    dropped and every node gets the same number of batches.
    """
    print("=> using WebDataset loader")
    image_transform = make_train_transform(args)
    batches_per_epoch = args.trainsize // args.batch_size

    pipeline = wds.WebDataset(args.trainshards, length=batches_per_epoch)
    pipeline = pipeline.shuffle(args.shuffle)
    pipeline = pipeline.decode("pil")
    # Tuple keys match the image and metadata entries stored in the shards.
    pipeline = pipeline.to_tuple("ppm cls")
    pipeline = pipeline.map_tuple(image_transform, identity)

    if args.distributed:
        # It's good to avoid partial batches when using DistributedDataParallel.
        pipeline = pipeline.batched(args.batch_size, partial=False)
    else:
        pipeline = pipeline.batched(args.batch_size)

    # WebLoader is just the regular DataLoader with the same convenience methods
    # that WebDataset has.
    loader = wds.WebLoader(
        pipeline, batch_size=None, shuffle=False, num_workers=args.workers,
    )

    if args.distributed:
        # With DDP, all nodes must see the same number of batches; this is
        # achieved by reusing a little bit of data (repeat, then slice).
        batches_per_node = args.trainsize // (args.batch_size * args.world_size)
        print("# batches per node = ", batches_per_node)
        loader = loader.repeat(2).slice(batches_per_node)
        # This only sets the value returned by len(); some frameworks care about it.
        loader.length = batches_per_node

    print ('returning dataset')
    return loader


def make_val_loader_orig(args):
    """Create the (unshuffled) validation DataLoader from ``<args.data>/val``."""
    val_dataset = datasets.ImageFolder(
        os.path.join(args.data, "val"), make_val_transform(args)
    )
    return torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )


def make_val_loader_wds(args):
    """Create the validation loader from WebDataset shards (``args.valshards``).

    Uses the same stages as the WebDataset training loader — shuffle, PIL
    decoding, ``(ppm, cls)`` tuples, transform, in-dataset batching — but with
    the validation transform and ``args.valsize``.
    """
    print("=> using WebDataset loader")
    image_transform = make_val_transform(args)
    batches_per_epoch = args.valsize // args.batch_size

    samples = wds.WebDataset(args.valshards, length=batches_per_epoch)
    samples = samples.shuffle(args.shuffle)
    samples = samples.decode("pil")
    # Tuple keys match the image and metadata entries stored in the shards.
    samples = samples.to_tuple("ppm cls")
    samples = samples.map_tuple(image_transform, identity)

    if args.distributed:
        # It's good to avoid partial batches when using DistributedDataParallel.
        samples = samples.batched(args.batch_size, partial=False)
    else:
        samples = samples.batched(args.batch_size)

    # WebLoader is just the regular DataLoader with the same convenience methods
    # that WebDataset has.
    loader = wds.WebLoader(
        samples, batch_size=None, shuffle=False, num_workers=args.workers,
    )

    if args.distributed:
        # With DDP, all nodes must see the same number of batches; this is
        # achieved by reusing a little bit of data (repeat, then slice).
        batches_per_node = args.valsize // (args.batch_size * args.world_size)
        print("# batches per node = ", batches_per_node)
        loader = loader.repeat(2).slice(batches_per_node)
        # This only sets the value returned by len(); some frameworks care about it.
        loader.length = batches_per_node
    return loader


def train(train_loader, model, criterion, optimizer, epoch, args):
    """Train ``model`` for one epoch.

    Args:
        train_loader: Iterable yielding ``(images, target)`` batches; must
            support ``len()``.
        model: Network to optimise; switched to train mode here.
        criterion: Loss callable applied to ``(output, target)``.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch index, used only in the progress prefix.
        args: Namespace providing ``gpu`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        # Targets are always moved to the GPU, exactly as validate() does: with
        # DataParallel (args.gpu is None) the outputs are on the GPU as well.
        target = target.cuda(args.gpu, non_blocking=True)

        # Forward pass. This has to happen before the statistics below; the
        # previous ordering read `output` and `loss` before they were assigned.
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

def validate(val_loader, model, criterion, args):
    """Run one pass over ``val_loader`` in eval mode and return the mean top-1 accuracy.

    Args:
        val_loader: iterable yielding ``(images, target)`` batches.
        model: network to evaluate; it is switched to eval mode and left there.
        criterion: loss function called as ``criterion(output, target)``.
        args: namespace providing ``gpu`` (device index or None) and ``print_freq``.

    Returns:
        The running average stored in the top-1 accuracy meter.
    """
    timer = AverageMeter('Time', ':6.3f')
    loss_meter = AverageMeter('Loss', ':.4e')
    acc1_meter = AverageMeter('Acc@1', ':6.2f')
    acc5_meter = AverageMeter('Acc@5', ':6.2f')
    meters = [timer, loss_meter, acc1_meter, acc5_meter]
    progress = ProgressMeter(len(val_loader), meters, prefix='Test: ')

    # Evaluation mode: no gradient tracking for the whole pass.
    model.eval()

    with torch.no_grad():
        end = time.time()
        for step, batch in enumerate(val_loader):
            images, target = batch
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            # Targets are always moved to CUDA, whether or not a GPU index was given.
            target = target.cuda(args.gpu, non_blocking=True)

            # Forward pass and loss.
            output = model(images)
            loss = criterion(output, target)

            # Update the running statistics, weighted by the batch size.
            num_images = images.size(0)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            loss_meter.update(loss.item(), num_images)
            acc1_meter.update(acc1[0], num_images)
            acc5_meter.update(acc5[0], num_images)

            # Time spent on this batch.
            timer.update(time.time() - end)
            end = time.time()

            if step % args.print_freq == 0:
                progress.display(step)

        # TODO: this should also be done with the ProgressMeter
        summary = ' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
        print(summary.format(top1=acc1_meter, top5=acc5_meter))

    return acc1_meter.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar',
                    best_filename='model_best.pth.tar'):
    """Serialize ``state`` to ``filename`` and optionally keep a copy as the best model.

    Args:
        state: object handed to ``torch.save`` (typically a checkpoint dict).
        is_best: when truthy, the file just written is also copied to ``best_filename``.
        filename: path of the checkpoint file to write.
        best_filename: destination of the "best so far" copy. The default keeps
            the previously hard-coded ``model_best.pth.tar`` in the working directory.
    """
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)


class AverageMeter(object):
    """Tracks the most recent value of a metric and its weighted running mean."""

    def __init__(self, name, fmt=':f'):
        # `name` labels the metric; `fmt` is the format spec used when printing it.
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Clear the last value, the running sum, the sample count and the mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted ``n`` times in the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        # Produces "<name> <val> (<avg>)", with both numbers rendered through `fmt`.
        template = ''.join(['{name} {val', self.fmt, '} ({avg', self.fmt, '})'])
        return template.format(**vars(self))


class ProgressMeter(object):
    """Prints a tab-separated progress line: a batch counter followed by each meter."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, the ``[batch/total]`` counter and ``str()`` of every meter."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the current batch index to the width of the total, e.g. "[{:3d}/100]".
        width = len(str(num_batches // 1))
        field = '{:%sd}' % width
        return '[{}/{}]'.format(field, field.format(num_batches))


def adjust_learning_rate(optimizer, epoch, args):
    """Apply a step schedule: ``args.lr`` scaled by 0.1 for every 30 completed epochs.

    The resulting rate is written into the ``'lr'`` entry of every parameter
    group of ``optimizer``.
    """
    completed_steps = epoch // 30
    lr = args.lr * (0.1 ** completed_steps)
    for group in optimizer.param_groups:
        group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy (in percent) of ``output`` against ``target``.

    One single-element tensor is returned per entry of ``topk``, in the same order.
    """
    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)

        # Indices of the `largest_k` highest scores per sample, transposed so that
        # row r holds every sample's r-th best guess.
        ranked = output.topk(largest_k, 1, True, True)[1].t()
        hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))

        # A sample counts for top-k when any of its first k guesses is correct.
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / num_samples)
            for k in topk
        ]


# Script entry point: call main() only when the file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

In [None]:
# Local smoke test of the task.py written by the previous cell.
# MASTER_ADDR / MASTER_PORT are exported first so that the `!` subprocess inherits
# them (presumably for torch.distributed rendezvous; confirm against task.py).
os.environ["MASTER_ADDR"]="localhost"
os.environ["MASTER_PORT"]="8082"
# Single process (world size 1, rank 0) with a small batch size.
# NOTE(review): --trainsize/--valsize look like dataset-size limits and --data is a
# machine-specific absolute path; verify both against the script's argument parser.
!python3 task.py --dist-backend 'nccl' \
--world-size 1 \
--rank 0 \
--batch_size 16 \
--trainsize 3000 \
--valsize 1000 \
--data '/home/jupyter/data'
# Options left disabled for this run (sharded datasets, multi-process launch):
#--trainshards 'https://storage.cloud.google.com/kubeflow-cloudinsight-dev-vertexai/cifar-shards/train/cifar-train-{000000..000694}.tar' \
#--valshards 'https://storage.cloud.google.com/kubeflow-cloudinsight-dev-vertexai/cifar-shards/val/cifar-val-{000000..000049}.tar' \
#--multiprocessing-distributed \
#--dist-url "env://"

In [None]:
%%writefile task.py
import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.multiprocessing as mp
import torch.distributed as dist

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

import argparse
import os
import random
import numpy as np
from datetime import datetime
from google.cloud import storage


def set_random_seeds(random_seed=0):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    cuDNN is additionally switched to deterministic mode with autotuning
    (``benchmark``) disabled.
    """
    # Python and NumPy generators.
    random.seed(random_seed)
    np.random.seed(random_seed)

    # PyTorch generator and cuDNN flags.
    torch.manual_seed(random_seed)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

def evaluate(model, device, test_loader):
    """Compute the top-1 accuracy of ``model`` over ``test_loader``.

    Args:
        model: network to evaluate. It is switched to eval mode and left in that
            mode; callers must call ``model.train()`` before resuming training.
        device: device the input and label tensors are moved to.
        test_loader: iterable of batches where ``batch[0]`` holds the inputs and
            ``batch[1]`` the integer class labels.

    Returns:
        Fraction of correctly classified samples in ``[0, 1]``, or ``0.0`` when
        the loader yields no samples (instead of raising ``ZeroDivisionError``).
    """
    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for batch in test_loader:
            images, labels = batch[0].to(device), batch[1].to(device)
            outputs = model(images)
            # Index of the highest score per sample is the predicted class.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        return 0.0

    return correct / total

def _build_arg_parser():
    """Build the command-line parser for the distributed training script."""
    # Lower-case callables exported by torchvision.models are the valid --arch choices.
    model_names = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith("__")
                         and callable(models.__dict__[name]))

    # Each process runs on GPU devices specified by the local_rank argument.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--local_rank", type=int, help="Local rank. Necessary for using the torch.distributed.launch utility.")
    parser.add_argument("--num_epochs", type=int, help="Number of training epochs.", default=100)
    parser.add_argument("--batch_size", type=int, help="Training batch size for one process.", default=1024)
    parser.add_argument("--learning_rate", dest='learning_rate', type=float, help="Learning rate.", default=0.1)
    parser.add_argument("--random_seed", type=int, help="Random seed.", default=0)
    parser.add_argument("--model_dir", type=str, help="Directory for saving models.", default=os.environ.get('AIP_MODEL_DIR', ""))
    parser.add_argument("--model_filename", type=str, help="Model filename.", default="resnet_distributed.pth")
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=model_names,
                        help='model architecture: ' +
                             ' | '.join(model_names) +
                             ' (default: resnet18)')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='use pre-trained model')
    parser.add_argument('--local_training', dest='local_training', action='store_true',
                        help='use local machine for training')
    parser.add_argument('--rank', default=-1, type=int,
                        help='node rank for distributed training')
    parser.add_argument("--resume", action="store_true", help="Resume training from saved checkpoint.")
    parser.add_argument('--world-size', default=int(os.getenv('WORLD_SIZE', -1)), type=int,
                        help='number of nodes for distributed training')
    # torch.distributed only accepts tcp://, file:// or env:// rendezvous URLs, so the
    # default uses tcp:// (an http:// URL is rejected by init_process_group).
    parser.add_argument('--dist-url', default='tcp://localhost:8082', type=str,
                        help='url used to set up distributed training') # From https://cloud.google.com/ai-platform/training/docs/distributed-pytorch#updating_your_training_code
    parser.add_argument('--dist-backend', default='nccl', type=str,
                        help='distributed backend')
    parser.add_argument('--multiprocessing-distributed', action='store_true',
                        help='Use multi-processing distributed training to launch '
                             'N processes per node, which has N GPUs. This is the '
                             'fastest way to use PyTorch for either single node or '
                             'multi node data parallel training')
    parser.add_argument('--gpu', default=None, type=int,
                        help='GPU id to use.')
    parser.add_argument('--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    return parser


def main():
    """Parse the command line and launch training in one or several processes.

    With ``--multiprocessing-distributed`` one worker process is spawned per
    visible GPU and the world size is scaled by the GPU count; otherwise
    ``main_worker`` runs once in the current process.

    Raises:
        RuntimeError: if ``--multiprocessing-distributed`` is requested but no
            CUDA device is visible (spawning zero workers would silently train
            nothing).
    """
    argv = _build_arg_parser().parse_args()

    if argv.dist_url == "env://" and argv.world_size == -1:
        argv.world_size = int(os.environ["WORLD_SIZE"])

    argv.distributed = argv.world_size > 1 or argv.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()

    # debugging
    print (f"os WORLD_SIZE={os.getenv('WORLD_SIZE', -1)}")
    print (f"os RANK={os.getenv('RANK', 0)}")
    print (f"os MASTER_ADDR={os.getenv('MASTER_ADDR', 'localhost')}")
    print (f"os MASTER_PORT={os.getenv('MASTER_PORT', '8082')}")
    print (f'Arg - distributed={argv.distributed}')
    print (f'Arg - multiprocessing_distributed={argv.multiprocessing_distributed}')
    print (f'Arg - dist_backend={argv.dist_backend}')
    print (f'Arg - dist_url={argv.dist_url}')
    print (f'ngpus_per_node={ngpus_per_node}')

    start = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    print (f'Starting training: {start}')
    if argv.multiprocessing_distributed:
        if ngpus_per_node < 1:
            raise RuntimeError(
                "--multiprocessing-distributed needs at least one CUDA device, "
                "but torch.cuda.device_count() returned 0"
            )
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        argv.world_size = ngpus_per_node * argv.world_size
        print ('GPU x WORLD SIZE = {}'.format(argv.world_size))
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, argv))
    else:
        # Simply call main_worker function
        main_worker(argv.gpu, ngpus_per_node, argv)

    end = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    print (f'Training complete: {end}')

def _save_model(model, args, model_filepath, model_filename):
    """Write the model weights to a local path or upload them to Cloud Storage."""
    if args.local_training:
        # Create the target directory on first use so torch.save cannot fail on it.
        target_dir = os.path.dirname(model_filepath)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        torch.save(model.state_dict(), model_filepath)
        print ('saving model to local folders')
    else:
        # Save locally, then copy to GCS - https://cloud.google.com/vertex-ai/docs/training/exporting-model-artifacts
        torch.save(model.state_dict(), model_filename)
        # Upload model artifact to Cloud Storage
        model_directory = os.environ['AIP_MODEL_DIR']
        print (f"AIP_MODEL_DIR={model_directory}")
        storage_path = os.path.join(model_directory,'model', model_filename)
        print (f"storage_path={storage_path}")
        blob = storage.blob.Blob.from_string(storage_path, client=storage.Client())
        blob.upload_from_filename(model_filename)


def _evaluate_and_save(model, eval_model, device, test_loader, args,
                       model_filepath, model_filename, epoch):
    """Report test accuracy for ``epoch`` and persist the current weights."""
    accuracy = evaluate(model=eval_model, device=device, test_loader=test_loader)
    _save_model(model, args, model_filepath, model_filename)
    print("-" * 75)
    epoch_middle = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    print(f"Epoch: {epoch}, Accuracy: {accuracy}, Time: {epoch_middle}")
    print("-" * 75)


def main_worker(gpu, ngpus_per_node, args):
    """Train the selected torchvision model on CIFAR-10 in the current process.

    Args:
        gpu: index of the GPU assigned to this process, or None to use all
            visible GPUs through (Distributed)DataParallel.
        ngpus_per_node: number of GPUs on this node; used to derive the global
            rank and to split the data-loading workers between processes.
        args: parsed command-line namespace produced by ``main``.

    Only the main process (rank 0, or the single process of a non-distributed
    run) evaluates the model and saves its weights: every 10 epochs and once
    more after the last epoch.
    """
    args.gpu = gpu

    if args.gpu is not None:
        print(f"Use GPU: {args.gpu} for training")

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
            print (f"Distributed and getting rank from os.environ: rank={args.rank}")
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
            print (f"Distributed and Multiprocesing. Setting rank for each worker. rank={args.rank}")
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
        print ("Process group initialized")

    num_epochs = args.num_epochs
    # Per-process batch size, as documented by the --batch_size help text.
    batch_size = args.batch_size
    random_seed = args.random_seed
    model_dir = args.model_dir
    model_filename = args.model_filename
    model_filepath = os.path.join(model_dir, model_filename)

    # We need to use seeds to make sure that the models initialized in different processes are the same
    set_random_seeds(random_seed=random_seed)
    print ("random seeds")

    # create model
    if args.pretrained:
        print(f"=> using pre-trained model '{args.arch}'")
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print(f"=> creating model '{args.arch}'")
        model = models.__dict__[args.arch]()

    # `device` is always derived from the local GPU index (never from the global
    # rank, which exceeds the local device count on multi-node jobs and is -1 for
    # non-distributed runs) and is defined on every branch.
    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # Share the requested data-loading workers between the processes of this node.
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)

            # Encapsulate the model on the GPU assigned to the current process
            device = torch.device("cuda:{}".format(args.gpu))
            print (f"Distributed GPU device={device}")
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            device = torch.device("cuda")
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
            print (f"Distributed CPU device used")
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        device = torch.device(f"cuda:{args.gpu}")
        print (f"Non-distributed GPU device id ={args.gpu}")
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
        device = torch.device("cuda")
        print (f"Non distributed CPU device used")

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    optimizer = torch.optim.SGD(model.parameters(), args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        # --resume is a boolean flag, so the weights are read back from the path
        # this script saves to; an explicit path string is honoured as well.
        resume_path = args.resume if isinstance(args.resume, str) else model_filepath
        if os.path.isfile(resume_path):
            print("=> loading checkpoint '{}'".format(resume_path))
            # Remap tensors saved from another GPU onto this process' device.
            checkpoint = torch.load(resume_path, map_location=device)
            if 'state_dict' in checkpoint:
                # Full training checkpoint: model weights plus optimizer state.
                model.load_state_dict(checkpoint['state_dict'])
                if 'optimizer' in checkpoint:
                    optimizer.load_state_dict(checkpoint['optimizer'])
            else:
                # Plain state_dict, which is the format this script writes.
                model.load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}'".format(resume_path))
        else:
            print(f"=> no checkpoint found at '{resume_path}'")

    # NOTE: this re-enables cuDNN autotuning after set_random_seeds() disabled it,
    # trading bit-exact reproducibility for speed.
    cudnn.benchmark = True

    # Prepare dataset and dataloader: random augmentation for training only, so
    # that the reported test accuracy is deterministic.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # Downloading is not multiprocess safe: one process per node (GPU 0, or a
    # process without an assigned GPU index) fetches the data first while the
    # others wait at a barrier and then reuse the files already on disk.
    downloads_first = args.gpu is None or args.gpu == 0
    if args.distributed and not downloads_first:
        dist.barrier()
    train_set = torchvision.datasets.CIFAR10(root="data", train=True, download=True, transform=train_transform)
    test_set = torchvision.datasets.CIFAR10(root="data", train=False, download=True, transform=test_transform)
    if args.distributed and downloads_first:
        dist.barrier()

    # Restricts data loading to a subset of the dataset exclusive to the current
    # process. The sampler needs an initialized process group, so plain shuffling
    # is used for non-distributed runs.
    train_sampler = DistributedSampler(dataset=train_set) if args.distributed else None

    # Load training data - set num_workers to turn multi-process data loading
    train_loader = DataLoader(dataset=train_set, batch_size=batch_size,
                              shuffle=(train_sampler is None), sampler=train_sampler,
                              num_workers=args.workers)

    # Test loader does not have to follow distributed sampling strategy
    test_loader = DataLoader(dataset=test_set, batch_size=128, shuffle=False, num_workers=args.workers)

    # Non-distributed runs keep the default rank of -1, so they count as "main" too.
    is_main_process = (not args.distributed) or args.rank == 0
    # Evaluate the wrapped module directly: a DistributedDataParallel forward pass
    # synchronizes buffers across ranks, which must not happen on rank 0 alone.
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        eval_model = model.module
    else:
        eval_model = model

    # Loop over the dataset multiple times
    for epoch in range(num_epochs):

        epoch_start = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
        print("Rank: {}, Epoch: {}, Training start: {}".format(args.rank, epoch, epoch_start))

        # Evaluate model routinely
        if epoch % 10 == 0 and is_main_process:
            _evaluate_and_save(model, eval_model, device, test_loader, args,
                               model_filepath, model_filename, epoch)

        # Re-seed the sampler so every epoch sees a different shuffling.
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        # Switch to training mode
        model.train()

        for data in train_loader:
            # The model lives on CUDA in every configuration; cuda(None) targets
            # the current device.
            inputs = data[0].cuda(args.gpu, non_blocking=True)
            labels = data[1].cuda(args.gpu, non_blocking=True)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    # Persist the final weights; the periodic save above only runs at the start
    # of every 10th epoch and would otherwise miss the last epochs of training.
    if is_main_process:
        _evaluate_and_save(model, eval_model, device, test_loader, args,
                           model_filepath, model_filename, num_epochs)

    epoch_end = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    print (f'Epoch complete: {epoch_end}')
            
# Script entry point: call main() only when the file is executed directly.
# The guard also keeps main() from re-running in child processes that import
# this module (mp.spawn is used above).
if __name__ == "__main__":
    main()

#### Set hardware accelerators

You can set hardware accelerators for both training and prediction.

Set the variables `TRAIN_GPU/TRAIN_NGPU` and `DEPLOY_GPU/DEPLOY_NGPU` to use a container image supporting a GPU and the number of GPUs allocated to the virtual machine (VM) instance. For example, to use a GPU container image with 4 Nvidia Tesla K80 GPUs allocated to each VM, you would specify:

    (aip.AcceleratorType.NVIDIA_TESLA_K80, 4)

See the [locations where accelerators are available](https://cloud.google.com/vertex-ai/docs/general/locations#accelerators).

Otherwise specify `(None, None)` to use a container image to run on a CPU.

*Note*: TensorFlow releases earlier than 2.3 for GPU support fail to load the custom model in this tutorial. This issue is caused by static graph operations that are generated in the serving function. This is a known issue, which is fixed in TensorFlow 2.3. If you encounter this issue with your own custom models, use a container image for TensorFlow 2.3 or later with GPU support.

In [None]:
# Accelerator attached to each training VM, and how many of them per VM.
TRAIN_GPU = aip.AcceleratorType.NVIDIA_TESLA_V100
TRAIN_NGPU = 2

# Deployment accelerators are left unset in this notebook:
#DEPLOY_GPU, DEPLOY_NGPU = (aip.AcceleratorType.NVIDIA_TESLA_K80, 1)

#### Set pre-built containers

Vertex AI provides pre-built containers to run training and prediction.

For the latest list, see [Pre-built containers for training](https://cloud.google.com/vertex-ai/docs/training/pre-built-containers) and [Pre-built containers for prediction](https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers)

In [None]:
# Training runs in the pre-built PyTorch GPU container. No deployment image is
# configured (no pytorch prebuilt deployment images), so DEPLOY_* stays disabled.
TRAIN_VERSION = "pytorch-gpu.1-7"
TRAIN_IMAGE = f"us-docker.pkg.dev/vertex-ai/training/{TRAIN_VERSION}:latest"
print("Training:", TRAIN_IMAGE, TRAIN_GPU, TRAIN_NGPU)

#DEPLOY_VERSION = "tf2-gpu.2-1"
#DEPLOY_IMAGE = "gcr.io/cloud-aiplatform/prediction/{}:latest".format(DEPLOY_VERSION)
#print("Deployment:", DEPLOY_IMAGE, DEPLOY_GPU, DEPLOY_NGPU)

#### Set machine types

Next, set the machine types to use for training and prediction.

- Set the variables `TRAIN_COMPUTE` and `DEPLOY_COMPUTE` to configure your compute resources for training and prediction.
 - `machine type`
     - `n1-standard`: 3.75GB of memory per vCPU
     - `n1-highmem`: 6.5GB of memory per vCPU
     - `n1-highcpu`: 0.9 GB of memory per vCPU
 - `vCPUs`: number of \[2, 4, 8, 16, 32, 64, 96 \]

*Note: The following is not supported for training:*

 - `standard`: 2 vCPUs
 - `highcpu`: 2, 4 and 8 vCPUs

*Note: You may also use n2 and e2 machine types for training and deployment, but they do not support GPUs*.

In [None]:
# Training and deployment use the same machine family and vCPU count.
MACHINE_TYPE = "n1-standard"
VCPU = "16"

# Training: one master replica plus two worker replicas.
TRAIN_COMPUTE = f"{MACHINE_TYPE}-{VCPU}"
print("Train machine type", TRAIN_COMPUTE)
TRAIN_NCOMPUTE_MASTER = 1
TRAIN_NCOMPUTE_WORKER = 2

# Deployment: a single node.
DEPLOY_COMPUTE = f"{MACHINE_TYPE}-{VCPU}"
print("Deploy machine type", DEPLOY_COMPUTE)
DEPLOY_NCOMPUTE = 1

### Train the model

Define your custom training job on Vertex AI.

Use the `CustomTrainingJob` class to define the job, which takes the following parameters:

- `display_name`: The user-defined name of this training pipeline.
- `script_path`: The local path to the training script.
- `container_uri`: The URI of the training container image.
- `requirements`: The list of Python package dependencies of the script.
- `model_serving_container_image_uri`: The URI of a container that can serve predictions for your model — either a prebuilt container or a custom container.

Use the `run` function to start training, which takes the following parameters:

- `args`: The command line arguments to be passed to the Python script.
- `replica_count`: The number of worker replicas.
- `model_display_name`: The display name of the `Model` if the script produces a managed `Model`.
- `machine_type`: The type of machine to use for training.
- `accelerator_type`: The hardware accelerator type.
- `accelerator_count`: The number of accelerators to attach to a worker replica.

The `run` function creates a training pipeline that trains and creates a `Model` object. After the training pipeline completes, the `run` function returns the `Model` object.

In [None]:
from datetime import datetime

# Second-resolution timestamp used as a suffix for job and model names.
TIMESTAMP = f"{datetime.now():%Y_%m_%d_%H_%M_%S}"
print (TIMESTAMP)

In [None]:
# Name of the training job, made unique by the timestamp suffix.
JOB_NAME = f"cifar10_resnet_custom_job_{TIMESTAMP}"

# Command-line arguments forwarded to task.py.
ARGS = ["--dist-url=env://", "--multiprocessing-distributed", "--num_epochs=100"]

In [None]:
# Local smoke test of the task.py written above: a single-node, one-epoch run.
# With --dist-url "env://" the script reads RANK and WORLD_SIZE from the
# environment; the variables set here are inherited by the `!` subprocess.
os.environ["MASTER_ADDR"]="localhost"
os.environ["MASTER_PORT"]="8082"
os.environ["RANK"]="0"
os.environ['WORLD_SIZE']="1"
# --local_training makes the script save the weights under --model_dir instead of
# uploading them to Cloud Storage.
!python3 task.py --rank 0 \
--dist-url "env://" \
--multiprocessing-distributed \
--num_epochs 1 \
--model_dir "saved_models" \
--local_training

In [None]:
# Staging location for the packaged training script and job outputs.
base_output_dir = '{}/jobs/{}'.format(BUCKET_NAME, JOB_NAME)

# The job is created without a serving container (no pre-built PyTorch prediction
# image is configured), so it trains only and does not upload a Model resource.
job = aiplatform.CustomTrainingJob(
    display_name=JOB_NAME,
    script_path="task.py",
    container_uri=TRAIN_IMAGE,
    staging_bucket=base_output_dir,
)

MODEL_DISPLAY_NAME = "cifar10-pytorch-" + TIMESTAMP

# Start the training. `model_display_name` is deliberately not passed in either
# branch: the SDK rejects it when the job has no serving container image.
# Without a serving container, run() returns None instead of a Model.
if TRAIN_GPU:
    model = job.run(
        args=ARGS,
        replica_count=TRAIN_NCOMPUTE_WORKER + TRAIN_NCOMPUTE_MASTER,
        machine_type=TRAIN_COMPUTE,
        accelerator_type=TRAIN_GPU.name,
        accelerator_count=TRAIN_NGPU,
    )
else:
    model = job.run(
        args=ARGS,
        replica_count=1,
        machine_type=TRAIN_COMPUTE,
        accelerator_count=0,
    )

In [None]:
# Using containers to further customize worker pools
def _make_worker_pool_spec(replica_count):
    """Return one worker-pool spec with `replica_count` replicas of the training VM."""
    return {
        "machine_spec": {
            # Full machine type such as "n1-standard-16"; the bare family name
            # held by MACHINE_TYPE is not a valid machine type.
            "machine_type": TRAIN_COMPUTE,
            "accelerator_type": TRAIN_GPU,
            "accelerator_count": TRAIN_NGPU,
        },
        "replica_count": replica_count,
        "container_spec": {
            "image_uri": TRAIN_IMAGE,
        },
    }


# First pool: master replica(s); second pool: worker replicas.
WORKER_POOL_SPECS = [
    _make_worker_pool_spec(TRAIN_NCOMPUTE_MASTER),
    _make_worker_pool_spec(TRAIN_NCOMPUTE_WORKER),
]

# Arguments for task.py. The flag and its value must be joined with "=":
# plain concatenation would produce the unknown option "--dist-urlenv://".
ARGS = [
    "--pretrained",
    "--dist-url=env://",
    "--multiprocessing-distributed",
]

In [None]:
JOB_NAME = "cifar10_resnet_custom_job_" + TIMESTAMP
MODEL_DIR = "{}/{}".format(BUCKET_NAME, JOB_NAME)
BASE_OUTPUT_DIR = "{}/{}/staging".format(BUCKET_NAME, JOB_NAME)
MODEL_DISPLAY_NAME = "cifar10-pytorch-" + TIMESTAMP
# task.py is written into the notebook's working directory by the %%writefile
# cell, so a relative path works on any machine.
SCRIPT_PATH = "task.py"

# Command-line arguments forwarded to task.py on every replica.
CMDARGS = [
    "--pretrained",
    "--dist-url=env://",
    "--multiprocessing-distributed",
]

# Compute resources: GPU replicas when an accelerator is configured, otherwise a
# single CPU-only replica.
if TRAIN_GPU:
    job_resources = dict(
        replica_count=TRAIN_NCOMPUTE_MASTER + TRAIN_NCOMPUTE_WORKER,
        machine_type=TRAIN_COMPUTE,
        accelerator_type=TRAIN_GPU.name,
        accelerator_count=TRAIN_NGPU,
    )
else:
    job_resources = dict(
        replica_count=1,
        machine_type=TRAIN_COMPUTE,
        accelerator_count=0,
    )

# from_local_script packages the script and builds the worker pool specs itself
# from the container image and the resource arguments; script arguments and
# resources are fixed here, not in run().
job = aiplatform.CustomJob.from_local_script(
    display_name=JOB_NAME,
    script_path=SCRIPT_PATH,
    container_uri=TRAIN_IMAGE,
    args=CMDARGS,
    staging_bucket=BASE_OUTPUT_DIR,
    **job_resources,
)

# Start the training (once). A CustomJob only runs the script; it does not
# create a Model resource.
job.run()

## Make a batch prediction request

Send a batch prediction request to your deployed model.

### Get test data

Download images from the CIFAR dataset and preprocess them.

#### Download the test images

Download the provided set of images from the CIFAR dataset:

In [None]:
# Download the images: recursively copy the sample CIFAR test images from the
# public Cloud Storage bucket into the current directory (-m copies in parallel).
! gsutil -m cp -r gs://cloud-samples-data/ai-platform-unified/cifar_test_images .

#### Preprocess the images
Before you can run the data through batch prediction, you need to preprocess it to match the format that your custom model defined in `task.py` expects.

`x_test`:
Normalize (rescale) the pixel data by dividing each pixel by 255. This replaces each single byte integer pixel with a 32-bit floating point number between 0 and 1.

`y_test`:
You can extract the labels from the image filenames. Each image's filename format is "image_{LABEL}_{IMAGE_NUMBER}.jpg"

In [None]:
import numpy as np
from PIL import Image

# Directory the test images were downloaded to.
IMAGE_DIRECTORY = "cifar_test_images"

# Only the JPEG files in that directory are used.
image_files = [fname for fname in os.listdir(IMAGE_DIRECTORY) if fname.endswith(".jpg")]

# Decode every JPEG into a numpy array.
image_paths = [os.path.join(IMAGE_DIRECTORY, fname) for fname in image_files]
image_data = [np.asarray(Image.open(path)) for path in image_paths]

# Rescale pixels by 1/255, cast to float32 and convert to nested lists.
x_test = [(pixels / 255.0).astype(np.float32).tolist() for pixels in image_data]

# The label is the second "_"-separated field of the filename.
y_test = [int(fname.split("_", 2)[1]) for fname in image_files]

#### Prepare data for batch prediction
Before you can run the data through batch prediction, you need to save the data into one of a few possible formats.

For this tutorial, use JSONL as it's compatible with the 3-dimensional list that each image is currently represented in. To do this:

1. In a file, write each instance as JSON on its own line.
2. Upload this file to Cloud Storage.

For more details on batch prediction input formats: https://cloud.google.com/vertex-ai/docs/predictions/batch-predictions#batch_request_input

In [None]:
import json

BATCH_PREDICTION_INSTANCES_FILE = "batch_prediction_instances.jsonl"

BATCH_PREDICTION_GCS_SOURCE = (
    BUCKET_NAME + "/batch_prediction_instances/" + BATCH_PREDICTION_INSTANCES_FILE
)

# Write instances at JSONL
with open(BATCH_PREDICTION_INSTANCES_FILE, "w") as f:
    for x in x_test:
        f.write(json.dumps(x) + "\n")

# Upload to Cloud Storage bucket
! gsutil cp $BATCH_PREDICTION_INSTANCES_FILE $BATCH_PREDICTION_GCS_SOURCE

print("Uploaded instances to: ", BATCH_PREDICTION_GCS_SOURCE)

### Send the prediction request

To make a batch prediction request, call the model object's `batch_predict` method with the following parameters: 
- `instances_format`: The format of the batch prediction request file: "jsonl", "csv", "bigquery", "tf-record", "tf-record-gzip" or "file-list"
- `predictions_format`: The format of the batch prediction response file: "jsonl", "csv", "bigquery", "tf-record", "tf-record-gzip" or "file-list"
- `job_display_name`: The human readable name for the prediction job.
- `gcs_source`: A list of one or more Cloud Storage paths to your batch prediction requests.
- `gcs_destination_prefix`: The Cloud Storage path that the service will write the predictions to.
- `model_parameters`: Additional filtering parameters for serving prediction results.
- `machine_type`: The type of machine to use for batch prediction.
- `accelerator_type`: The hardware accelerator type.
- `accelerator_count`: The number of accelerators to attach to a worker replica.
- `starting_replica_count`: The number of compute instances to initially provision.
- `max_replica_count`: The maximum number of compute instances to scale to. In this tutorial, only one instance is provisioned.

### Compute instance scaling

You can specify a single instance (or node) to process your batch prediction request. This tutorial uses a single node, so the variables `MIN_NODES` and `MAX_NODES` are both set to `1`.

If you want to use multiple nodes to process your batch prediction request, set `MAX_NODES` to the maximum number of nodes you want to use. Vertex AI autoscales the number of nodes used to serve your predictions, up to the maximum number you set. Refer to the [pricing page](https://cloud.google.com/vertex-ai/pricing#prediction-prices) to understand the costs of autoscaling with multiple nodes.


In [None]:
# Start with, and scale up to, a single node
MIN_NODES = 1
MAX_NODES = 1

# Display name of the batch prediction job
BATCH_PREDICTION_JOB_NAME = f"cifar10_batch-{TIMESTAMP}"

# Name of the folder in the bucket that receives the results
DESTINATION_FOLDER = "batch_prediction_results"

# Cloud Storage prefix the results are written under
BATCH_PREDICTION_GCS_DEST_PREFIX = f"{BUCKET_NAME}/{DESTINATION_FOLDER}"

# Submit the batch prediction job through the SDK
batch_prediction_job = model.batch_predict(
    job_display_name=BATCH_PREDICTION_JOB_NAME,
    # Input and output locations and formats
    gcs_source=BATCH_PREDICTION_GCS_SOURCE,
    instances_format="jsonl",
    gcs_destination_prefix=BATCH_PREDICTION_GCS_DEST_PREFIX,
    predictions_format="jsonl",
    model_parameters=None,
    # Compute resources
    machine_type=DEPLOY_COMPUTE,
    accelerator_type=DEPLOY_GPU,
    accelerator_count=DEPLOY_NGPU,
    starting_replica_count=MIN_NODES,
    max_replica_count=MAX_NODES,
    # Run synchronously
    sync=True,
)

### Retrieve batch prediction results
When the batch prediction is done processing, you can finally view the predictions stored at the Cloud Storage path you set as output. The predictions will be in a JSONL format, which you indicated when you created the batch prediction job. The predictions are located in a subdirectory starting with the name prediction. Within that directory, there is a file named prediction.results-xxxx-of-xxxx.

Let's display the contents. You will get a row for each prediction. The row is the softmax probability distribution for the corresponding CIFAR10 classes.

In [None]:
RESULTS_DIRECTORY = "prediction_results"
RESULTS_DIRECTORY_FULL = RESULTS_DIRECTORY + "/" + DESTINATION_FOLDER

# Create missing directories
os.makedirs(RESULTS_DIRECTORY, exist_ok=True)

# Get the Cloud Storage paths for each result
! gsutil -m cp -r $BATCH_PREDICTION_GCS_DEST_PREFIX $RESULTS_DIRECTORY

# Get most recently modified directory
latest_directory = max(
    [
        os.path.join(RESULTS_DIRECTORY_FULL, d)
        for d in os.listdir(RESULTS_DIRECTORY_FULL)
    ],
    key=os.path.getmtime,
)

# Get downloaded results in directory
results_files = []
for dirpath, subdirs, files in os.walk(latest_directory):
    for file in files:
        if file.startswith("prediction.results"):
            results_files.append(os.path.join(dirpath, file))

# Consolidate all the results into a list
results = []
for results_file in results_files:
    # Download each result
    with open(results_file, "r") as file:
        results.extend([json.loads(line) for line in file.readlines()])

### Evaluate results

You can then run a quick evaluation on the prediction results:

1. `np.argmax`: Convert each list of confidence levels to a label
2. Compare the predicted labels to the actual labels
3. Calculate `accuracy` as `correct/total`

To improve the accuracy, try training for a higher number of epochs.

In [None]:
# Convert each list of confidence levels to the label with the highest confidence
y_predicted = [np.argmax(result["prediction"]) for result in results]

# The comparison below is positional, so every test image needs exactly one prediction.
# NOTE(review): this also assumes the results are in the same order as the
# instances file — confirm against the batch prediction output.
if len(y_predicted) != len(y_test):
    raise ValueError(
        f"Expected {len(y_test)} predictions but found {len(y_predicted)}"
    )

correct = int(np.sum(np.asarray(y_predicted) == np.asarray(y_test)))
total = len(y_predicted)
# Avoid dividing by zero when no predictions were loaded
accuracy = correct / total if total else 0.0
print(
    f"Correct predictions = {correct}, Total predictions = {total}, Accuracy = {accuracy}"
)

# Cleaning up

To clean up all Google Cloud resources used in this project, you can [delete the Google Cloud project](https://cloud.google.com/resource-manager/docs/creating-managing-projects#shutting_down_projects) you used for the tutorial.

Otherwise, you can delete the individual resources you created in this tutorial:

- Training Job
- Model
- Cloud Storage Bucket

In [None]:
delete_training_job = True
delete_model = True

# Warning: Setting this to true will delete everything in your bucket
delete_bucket = False

# Delete the training job
job.delete()

# Delete the model
model.delete()

if delete_bucket and "BUCKET_NAME" in globals():
    ! gsutil -m rm -r $BUCKET_NAME