# ResNet152 PyTorch Training on Gaudi

In this notebook we demonstrate how to train the ResNet152 image classifier with PyTorch on Habana Gaudi HPUs. We first train on a single HPU, then on 8 HPUs.

## Setup

Set the `PYTHONPATH` environment variable and change into the directory that contains the torchvision classification scripts.

In [None]:
%set_env PYTHONPATH=/home/ubuntu/work/Model-References/PyTorch/computer_vision/classification/torchvision:/root/examples/models:/usr/lib/habanalabs/:/root

In [None]:
%cd /home/ubuntu/work/Model-References/PyTorch/computer_vision/classification/torchvision
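`%set_env` only changes `os.environ`. Child processes started afterwards inherit it, but it does not update `sys.path` in the kernel that is already running; `import utils` below typically resolves because `%cd` puts you in the script directory and IPython usually keeps the current directory on `sys.path`. If you ever need the `PYTHONPATH` entries importable in the current kernel, here is a minimal sketch (the helper `prepend_pythonpath` is hypothetical, not part of Model-References):

```python
import os
import sys


def prepend_pythonpath(pythonpath, path_list=None):
    """Prepend each PYTHONPATH entry to path_list (sys.path by default),
    skipping empty entries and entries that are already present."""
    if path_list is None:
        path_list = sys.path
    new_entries = []
    for entry in pythonpath.split(os.pathsep):
        if entry and entry not in path_list and entry not in new_entries:
            new_entries.append(entry)
    # Slice assignment mutates the list in place, so sys.path itself changes.
    path_list[:0] = new_entries
    return path_list


# In the notebook, after %set_env has run:
# prepend_pythonpath(os.environ.get("PYTHONPATH", ""))
```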

Import the libraries needed for PyTorch training and define the argument parser. (`utils` is the `utils.py` helper module in the Model-References directory above.)

In [None]:
from __future__ import print_function

import datetime
import os
import time
import sys

import torch
import torch.utils.data
from torch import nn
import torchvision
import torchvision.datasets as datasets
from torchvision import transforms
import random
import utils

def get_resnet152_argparser():
    import argparse
    import sys
    parser = argparse.ArgumentParser(description='PyTorch Classification Training')
    parser.add_argument('--dl-time-exclude', default='True', type=lambda x: x.lower() == 'true', help='Set to False to include data load time')
    parser.add_argument('-b', '--batch-size', default=128, type=int)
    parser.add_argument('--device', default='hpu', help='device')
    parser.add_argument('--epochs', default=90, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-j', '--workers', default=10, type=int, metavar='N',
                        help='number of data loading workers (default: 10)')
    parser.add_argument('--process-per-node', default=8, type=int, metavar='N',
                        help='Number of processes per node')
    parser.add_argument('--hls_type', default='HLS1', help='Node type')
    parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')

    parser.add_argument('--print-freq', default=1, type=int, help='print frequency')
    parser.add_argument('--output-dir', default='.', help='path where to save')

    parser.add_argument('--channels-last', default='True', type=lambda x: x.lower() == 'true',
                        help='Whether input is in channels last format. '
                        'Any value other than True (case insensitive) disables channels-last')
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--seed', type=int, default=123, help='random seed')
    parser.add_argument('--world-size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--num-train-steps', type=int, default=sys.maxsize, metavar='T',
                        help='number of steps a.k.a iterations to run in training phase')
    parser.add_argument('--num-eval-steps', type=int, default=sys.maxsize, metavar='E',
                        help='number of steps a.k.a iterations to run in evaluation phase')
    parser.add_argument('--save-checkpoint', action="store_true",
                        help='Save the model checkpoint if set; omit the flag to avoid saving')
    parser.add_argument('--run-lazy-mode', action='store_true',
                        help='run model in lazy execution mode')
    parser.add_argument('--deterministic', action="store_true",
                        help='Whether or not to make data loading deterministic; this does not make execution deterministic')
    parser.add_argument('--hmp', dest='is_hmp', action='store_true', help='enable hmp mode')
    parser.add_argument('--hmp-bf16', default='', help='path to bf16 ops list in hmp O1 mode')
    parser.add_argument('--hmp-fp32', default='', help='path to fp32 ops list in hmp O1 mode')
    parser.add_argument('--hmp-opt-level', default='O1', help='choose optimization level for hmp')
    parser.add_argument('--hmp-verbose', action='store_true', help='enable verbose mode for hmp')

    return parser
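The `--dl-time-exclude` and `--channels-last` options use `type=lambda x: x.lower() == 'true'` instead of `type=bool`. The standalone parser below (not the notebook's parser; `str2bool` and `--naive-flag` are illustrative names) shows why: `bool('False')` is `True`, because any non-empty string is truthy. It also shows that argparse runs a string `default` through `type`, so `default='True'` ends up as the boolean `True`.

```python
import argparse


def str2bool(x):
    # Same rule as the notebook's lambda: only "true" (any case) is True.
    return x.lower() == 'true'


def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--channels-last', default='True', type=str2bool)
    # Anti-pattern for comparison: bool() of any non-empty string is True.
    parser.add_argument('--naive-flag', default=False, type=bool)
    return parser


parser = build_parser()
defaults = parser.parse_args([])
explicit = parser.parse_args(['--channels-last', 'False', '--naive-flag', 'False'])
print(defaults.channels_last)  # True: the string default goes through type
print(explicit.channels_last)  # False
print(explicit.naive_flag)     # True, which is almost certainly not intended
```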


## Main training function

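Uncomment the two commented-out `htcore.mark_step()` blocks below: one after `loss.backward()` and one after `optimizer.step()`. In lazy mode (`--run-lazy-mode`), `mark_step()` tells the Habana PyTorch bridge to execute the operations accumulated since the previous call as a graph.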
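As a rough mental model only (a toy sketch, not how the Habana bridge is implemented), lazy mode behaves like a queue that records operations cheaply and runs them in a batch when `mark_step()` is called. The class `ToyLazyQueue` is hypothetical:

```python
class ToyLazyQueue:
    """Toy model of lazy execution: ops are recorded now and run later."""

    def __init__(self):
        self.pending = []  # ops recorded since the last mark_step()
        self.log = []      # (name, result) of ops that have actually run

    def record(self, name, fn):
        # Recording is cheap: fn is stored, not called.
        self.pending.append((name, fn))

    def mark_step(self):
        # Run everything recorded since the previous mark_step(), in order.
        executed = len(self.pending)
        for name, fn in self.pending:
            self.log.append((name, fn()))
        self.pending.clear()
        return executed


q = ToyLazyQueue()
q.record("forward", lambda: "logits")
q.record("backward", lambda: "grads")
print(q.mark_step())  # 2: forward and backward run together
q.record("optimizer.step", lambda: "new weights")
print(q.mark_step())  # 1: the update runs as a separate batch
```

In this toy, placing a `mark_step()` after the backward pass and another after the optimizer step splits each training iteration into two batches of work, mirroring where the notebook asks you to uncomment the calls.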

In [None]:
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq):
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ",device=device)
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}'))

    header = 'Epoch: [{}]'.format(epoch)
    step_count = 0
    last_print_time= time.time()

    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        image, target = image.to(device, non_blocking=True), target.to(device, non_blocking=True)

        dl_ex_start_time=time.time()

        if args.channels_last:
            image = image.contiguous(memory_format=torch.channels_last)

        output = model(image)
        loss = criterion(output, target)
        optimizer.zero_grad(set_to_none=True)

        loss.backward()
        
        #if args.run_lazy_mode:
        #    import habana_frameworks.torch.core as htcore
        #    htcore.mark_step()

        optimizer.step()

        #if args.run_lazy_mode:
        #    import habana_frameworks.torch.core as htcore
        #    htcore.mark_step()

        if step_count % print_freq == 0:
            output_cpu = output.detach().to('cpu')
            acc1, acc5 = utils.accuracy(output_cpu, target, topk=(1, 5))
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size*print_freq)
            metric_logger.meters['acc5'].update(acc5.item(), n=batch_size*print_freq)
            current_time = time.time()
            last_print_time = dl_ex_start_time if args.dl_time_exclude else last_print_time
            metric_logger.meters['img/s'].update(batch_size*print_freq / (current_time - last_print_time))
            last_print_time = time.time()

        step_count = step_count + 1
        if step_count >= args.num_train_steps:
            break
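The `img/s` meter divides the number of images processed since the last report (`batch_size * print_freq`) by the elapsed wall-clock time. With `--dl-time-exclude False`, as in the single-HPU run below, the interval starts at the previous report, so it includes data-loading time. A standalone version of that arithmetic (the function name and timings are hypothetical):

```python
def images_per_second(batch_size, print_freq, start_time, end_time):
    """Images since the last report divided by elapsed seconds."""
    elapsed = end_time - start_time
    if elapsed <= 0:
        raise ValueError("end_time must be later than start_time")
    return batch_size * print_freq / elapsed


# Hypothetical timing: 20 steps of 256 images, reported 4 seconds apart.
print(images_per_second(256, 20, start_time=100.0, end_time=104.0))  # 1280.0
```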

Set up the environment variables and command-line arguments for single-HPU ResNet152 training:

In [None]:
os.environ["MAX_WAIT_ATTEMPTS"] = "50"
os.environ['HCL_CPU_AFFINITY'] = '1'
os.environ['PT_HPU_ENABLE_SYNC_OUTPUT_HOST'] = 'false'
parser = get_resnet152_argparser()

args = parser.parse_args([
    "--batch-size", "256", "--epochs", "20", "--workers", "8",
    "--dl-time-exclude", "False", "--print-freq", "20",
    "--channels-last", "False", "--seed", "123",
    "--run-lazy-mode",
    "--hmp",
    "--hmp-bf16", "/home/ubuntu/work/Model-References/PyTorch/computer_vision/classification/torchvision/ops_bf16_Resnet.txt",
    "--hmp-fp32", "/home/ubuntu/work/Model-References/PyTorch/computer_vision/classification/torchvision/ops_fp32_Resnet.txt",
    "--deterministic",
])

Main training code for single-HPU training. It trains on synthetic images from `torchvision.datasets.FakeData` rather than a real dataset, so no data download is needed.
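`torchvision.datasets.FakeData` defaults to 1,000 random images of shape `(3, 224, 224)` with 10 classes, so each epoch here is short. Because the `DataLoader` keeps the last partial batch (`drop_last` defaults to `False`), the number of steps per epoch follows from simple arithmetic. The helpers below are illustrative, not part of the notebook:

```python
import math


def steps_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields for a map-style dataset."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


def last_batch_size(num_samples, batch_size):
    """Size of the final batch when drop_last=False."""
    remainder = num_samples % batch_size
    return remainder if remainder else batch_size


n, bs = 1000, 256  # FakeData default size, notebook batch size
print(steps_per_epoch(n, bs))       # 4
print(last_batch_size(n, bs))       # 232
print(steps_per_epoch(n, bs) * 20)  # 80 steps over 20 epochs
```

With `--print-freq 20` and only 4 steps per epoch, `step_count % print_freq == 0` holds only for the first step of each epoch, so the accuracy and throughput meters are updated once per epoch.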

In [None]:
try:
    # Default 'fork' doesn't work with synapse. Use 'forkserver' or 'spawn'
    torch.multiprocessing.set_start_method('spawn')
except RuntimeError:
    pass

if args.is_hmp:
    from habana_frameworks.torch.hpex import hmp
    hmp.convert(opt_level=args.hmp_opt_level, bf16_file_path=args.hmp_bf16,
                fp32_file_path=args.hmp_fp32, isVerbose=args.hmp_verbose)

torch.manual_seed(args.seed)

if args.deterministic:
    seed = args.seed
    random.seed(seed)
else:
    seed = None

device = torch.device('hpu')

torch.backends.cudnn.benchmark = True  # cuDNN autotuning; only affects CUDA devices, not HPU

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

dataset = datasets.FakeData(transform=transforms.Compose([transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,]))
dataset_test = datasets.FakeData(transform=transforms.Compose([transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,]))


train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)

if args.workers > 0:
    # patch torch cuda functions that are being unconditionally invoked
    # in the multiprocessing data loader
    torch.cuda.current_device = lambda: None
    torch.cuda.set_device = lambda x: None

data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=args.batch_size, sampler=train_sampler,
    num_workers=args.workers, pin_memory=True, pin_memory_device='hpu')

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=args.batch_size, sampler=test_sampler,
    num_workers=args.workers, pin_memory=True, pin_memory_device='hpu')


print("Creating model")
model = torchvision.models.__dict__['resnet152'](pretrained=False)
model.to(device)

criterion = nn.CrossEntropyLoss()

if args.run_lazy_mode:
    from habana_frameworks.torch.hpex.optimizers import FusedSGD
    sgd_optimizer = FusedSGD
else:
    sgd_optimizer = torch.optim.SGD
optimizer = sgd_optimizer(
    model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

model_for_train = model


print("Start training")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
    train_one_epoch(model_for_train, criterion, optimizer, data_loader,
            device, epoch, print_freq=args.print_freq)

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
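The single-HPU run does not override `--lr`, `--momentum`, or `--wd`, so the optimizer uses `lr=0.1`, `momentum=0.9`, and `weight_decay=1e-4`. For `torch.optim.SGD` (no dampening, no Nesterov), each parameter is updated as below. We assume, without having verified it, that `FusedSGD` implements the same math; the scalar function `sgd_momentum_step` is a hypothetical illustration, not either optimizer's code:

```python
def sgd_momentum_step(w, grad, buf, lr=0.1, momentum=0.9, weight_decay=1e-4):
    """One scalar SGD step with L2 weight decay and (non-Nesterov) momentum.

    buf is the momentum buffer; pass None before the first step.
    Returns the updated (w, buf).
    """
    d_p = grad + weight_decay * w  # weight decay is added to the gradient
    buf = d_p if buf is None else momentum * buf + d_p
    w = w - lr * buf
    return w, buf


w, buf = 1.0, None
for _ in range(2):
    w, buf = sgd_momentum_step(w, grad=0.5, buf=buf)
    print(round(w, 10))
# 0.94999, then 0.8549715001
```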

## Distributed Training

**Restart the kernel before running the next section of the notebook**

We will use the training script from the Model-References repository, launched from the command line, to demonstrate distributed training on 8 HPUs.

Compared with the single-HPU code above, distributed training differs in three ways:

1. [Initialize the process group with the HCCL backend](https://github.com/HabanaAI/Model-References/blob/1.6.0/PyTorch/computer_vision/classification/torchvision/utils.py#L249)
```python
    from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
    args.world_size, args.rank, args.local_rank = initialize_distributed_hpu()
    ...
    dist.init_process_group(backend='hccl', rank=args.rank, world_size=args.world_size)
```

2. [Use the PyTorch distributed data sampler](https://github.com/HabanaAI/Model-References/blob/1.6.0/PyTorch/computer_vision/classification/torchvision/train.py#L179) so that each process reads a different shard of the data
```python
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
```

3. [Distributed data parallel PyTorch model initialization](https://github.com/HabanaAI/Model-References/blob/1.6.0/PyTorch/computer_vision/classification/torchvision/train.py#L328)
```python
    model = torch.nn.parallel.DistributedDataParallel(model, broadcast_buffers=False,
            gradient_as_bucket_view=True)
```

__Note__: for step 3, you must use the `DistributedDataParallel` API; the `DataParallel` API is not supported on Habana Gaudi.
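Step 2 matters because each of the 8 processes must see a different slice of the dataset. The sketch below mirrors how `DistributedSampler` assigns indices when `shuffle=False` and `drop_last=False`: it pads the index list by wrapping around to the start until its length divides evenly by the number of replicas, then gives rank `r` every `num_replicas`-th index starting at `r`. It is a simplified re-implementation for illustration (`shard_indices` is a hypothetical name); the real sampler shuffles by default, seeding with its `seed` plus the epoch set through `set_epoch()`.

```python
import math


def shard_indices(dataset_len, num_replicas, rank):
    """Indices a given rank would receive (shuffle=False, drop_last=False)."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    # Pad by wrapping around so every rank gets the same number of samples.
    padding = total_size - len(indices)
    if padding > 0:
        indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Strided split: rank r takes positions r, r + num_replicas, ...
    return indices[rank:total_size:num_replicas]


for r in range(4):
    print(r, shard_indices(10, 4, r))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```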

In [None]:
%set_env PYTHONPATH=/home/ubuntu/work/Model-References/PyTorch/computer_vision/classification/torchvision:/root/examples/models:/usr/lib/habanalabs/:/root

In [None]:
%cd /home/ubuntu/work/Model-References/PyTorch/computer_vision/classification/torchvision

Apply the following patch so that `train.py` uses fake data and skips evaluation. To revert it later, run the same `git apply` command with `-R`.

In [None]:
! git apply /home/ubuntu/work/DL1-Workshop/PyTorch-ResNet152/fake_data_no_eval.patch

The final cell runs the shell script `demo_resnet.sh`, which contains the following commands, to start multi-HPU training:

  ```bash
  export MASTER_ADDR=localhost
  export MASTER_PORT=12355
  mpirun -n 8 --bind-to core --map-by slot:PE=6 --rank-by core --report-bindings --allow-run-as-root \
    python3 train.py --model=resnet152 --device=hpu --batch-size=256 --epochs=90 --workers=10 \
    --dl-worker-type=MP --print-freq=10 --output-dir=. --seed=123 --hmp --hmp-bf16 ./ops_bf16_Resnet.txt \
    --hmp-fp32 ./ops_fp32_Resnet.txt --custom-lr-values 0.275 0.45 0.625 0.8 0.08 0.008 0.0008 \
    --custom-lr-milestones 1 2 3 4 30 60 80 --deterministic --dl-time-exclude=False
  ```
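The `--custom-lr-values` and `--custom-lr-milestones` flags describe a piecewise-constant learning rate per epoch: a ramp from 0.275 to 0.8 over the first epochs, then 10x decays at epochs 30, 60, and 80. The helper below is a hypothetical sketch of one plausible reading, assuming that each value takes effect at the epoch given by the matching milestone and that a base learning rate applies before the first milestone; check `train.py` for the exact semantics.

```python
import bisect


def lr_for_epoch(epoch, base_lr, values, milestones):
    """Piecewise-constant LR: values[i] applies from milestones[i] onward.

    Assumption: base_lr applies to epochs before milestones[0].
    """
    if len(values) != len(milestones):
        raise ValueError("values and milestones must have the same length")
    # Number of milestones already reached at this epoch.
    i = bisect.bisect_right(milestones, epoch)
    return base_lr if i == 0 else values[i - 1]


values = [0.275, 0.45, 0.625, 0.8, 0.08, 0.008, 0.0008]
milestones = [1, 2, 3, 4, 30, 60, 80]
for epoch in (0, 1, 4, 29, 30, 60, 89):
    print(epoch, lr_for_epoch(epoch, 0.1, values, milestones))
```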

In [None]:
! sh /home/ubuntu/work/DL1-Workshop/PyTorch-ResNet152/demo_resnet.sh
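`mpirun -n 8` starts eight copies of `train.py`, and Open MPI tells each copy its position through environment variables such as `OMPI_COMM_WORLD_SIZE`, `OMPI_COMM_WORLD_RANK`, and `OMPI_COMM_WORLD_LOCAL_RANK`. `initialize_distributed_hpu()` returns a `(world_size, rank, local_rank)` triple; the sketch below is a hypothetical stand-in that reads the Open MPI variables directly (we have not confirmed which variables the Habana helper consults) and falls back to a single process when they are absent.

```python
import os


def read_mpi_ranks(env=None):
    """Return (world_size, rank, local_rank) from Open MPI variables.

    Hypothetical helper; defaults to a single process (1, 0, 0).
    """
    if env is None:
        env = os.environ
    world_size = int(env.get("OMPI_COMM_WORLD_SIZE", 1))
    rank = int(env.get("OMPI_COMM_WORLD_RANK", 0))
    local_rank = int(env.get("OMPI_COMM_WORLD_LOCAL_RANK", 0))
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} out of range for world size {world_size}")
    return world_size, rank, local_rank


fake_env = {"OMPI_COMM_WORLD_SIZE": "8", "OMPI_COMM_WORLD_RANK": "3",
            "OMPI_COMM_WORLD_LOCAL_RANK": "3"}
print(read_mpi_ranks(fake_env))  # (8, 3, 3)
print(read_mpi_ranks({}))        # (1, 0, 0)
```

Note also that `--map-by slot:PE=6` binds each of the 8 ranks to 6 cores, so the host needs at least 8 × 6 = 48 cores for this launch.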