In [None]:
from __future__ import division  # Must be the first statement in the cell; a no-op on Python 3, where `/` is already true division.
import numpy as np
import os, sys, shutil, time, random
from tqdm import tqdm

# Plummer Other Lib Imports
import copy
import argparse
import warnings
import contextlib
import matplotlib.pyplot as plt

In [None]:
# Plummer Imports for Images and Statistical ML
import PIL
from sklearn.cluster import KMeans

In [None]:
# Torch Stuff
import torch
import torchvision

# Plummer Torch Imports
import torch.backends.cudnn as cudnn
from torch.utils.data.sampler import SubsetRandomSampler, RandomSampler
import torchvision.datasets as dset
import torchvision.transforms as transforms


# Local Imports
#import models # Unsure what this is...

In [None]:
# Tensorflow Stuff
import tensorflow as tf

# Citation

```
@InProceedings{
    plummerNPAS2022,
    author={ Bryan A. Plummer and Nikoli Dryden and Julius Frost and Torsten Hoefler and Kate Saenko },
    title={Neural Parameter Allocation Search},
    booktitle={International Conference on Learning Representations (ICLR)},
    year={2022}
}
```

# Utilities (ported from `util.py`)

In [None]:
class Cutout(object):
    """Randomly zero out square patches of an image tensor.

    Args:
        n_holes: number of patches cut out on every call.
        length: nominal patch side in pixels; each patch spans ``length // 2``
            pixels on either side of a random centre and is clipped at the
            image border.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """Return ``img`` multiplied by a 0/1 mask containing ``n_holes`` zeroed patches.

        Dims 1 and 2 of ``img`` are taken as height and width; the 2-D mask is
        expanded over the remaining (leading) dimension with ``expand_as``.
        """
        height, width = img.size(1), img.size(2)
        half = self.length // 2
        keep = np.ones((height, width), np.float32)

        for _ in range(self.n_holes):
            # Row centre is drawn before the column centre.
            cy = np.random.randint(height)
            cx = np.random.randint(width)
            top, bottom = np.clip([cy - half, cy + half], 0, height)
            left, right = np.clip([cx - half, cx + half], 0, width)
            keep[top:bottom, left:right] = 0.

        return img * torch.from_numpy(keep).expand_as(img)


# Lighting data augmentation take from here - https://github.com/eladhoffer/convNet.pytorch/blob/master/preprocess.py
class Lighting(object):
    """Lighting noise(AlexNet - style PCA - based noise).

    Each call draws ``alpha ~ N(0, alphastd**2)`` (three values) and adds the
    per-channel offset ``sum_j eigvec[:, j] * alpha[j] * eigval[j]`` to ``img``.
    One draw is made per call and broadcast over ``img`` with ``expand_as``.
    NOTE(review): a batched input therefore shares a single draw across all
    of its images.

    Args:
        alphastd: standard deviation of ``alpha``; 0 disables the noise.
        eigval: tensor viewable as (1, 3) of PCA eigenvalues.
        eigvec: (3, 3) tensor of PCA eigenvectors.
    """

    def __init__(self, alphastd, eigval, eigvec):
        self.alphastd = alphastd
        self.eigval = eigval
        self.eigvec = eigvec

    def __call__(self, img):
        # A zero std means no noise: hand back the very same tensor.
        if self.alphastd == 0:
            return img

        # Three Gaussian draws with img's dtype/device.
        alpha = img.new().resize_(3).normal_(0, self.alphastd)
        scaled_vecs = (self.eigvec.type_as(img)
                       * alpha.view(1, 3).expand(3, 3)
                       * self.eigval.view(1, 3).expand(3, 3))
        offset = scaled_vecs.sum(1).squeeze()
        return img.add(offset.view(3, 1, 1).expand_as(img))


# Adapted from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Classification/ConvNets/image_classification/smoothing.py
class LabelSmoothingNLLLoss(torch.nn.Module):
    """NLL loss with label smoothing, computed from raw scores.

    ``forward`` applies ``log_softmax`` itself, so it takes unnormalised
    scores, not log-probabilities. Per sample, the loss is
    ``(1 - smoothing) * nll + smoothing * uniform``, where ``nll`` is the
    negative log-probability of the target class and ``uniform`` is the mean
    negative log-probability over all classes. The batch mean is returned.
    """

    def __init__(self, smoothing=0.0):
        super().__init__()
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing

    def forward(self, x, target):
        # target holds one class index per row of x (it is unsqueezed to (N, 1) for gather).
        log_probs = torch.nn.functional.log_softmax(x, dim=-1)
        picked = log_probs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        nll = -picked
        uniform = -log_probs.mean(dim=-1)
        return (self.confidence * nll + self.smoothing * uniform).mean()


class RandomDataset(torch.utils.data.Dataset):
    """Dataset that just returns a random tensor for debugging.

    A single ``torch.rand(sample_shape)`` sample is generated at construction
    and returned for every index, optionally passed through ``transform``.
    It is stored as a PIL image when ``pil`` is true. With ``label`` the item
    is ``(sample, 0)``; otherwise it is the bare sample.
    """

    def __init__(self, sample_shape, dataset_size, label=True, pil=False,
                 transform=None):
        super().__init__()
        self.sample_shape = sample_shape
        self.dataset_size = dataset_size
        self.label = label
        self.transform = transform
        sample = torch.rand(sample_shape)
        # Convert once up front so PIL-based transforms can consume it directly.
        self.d = torchvision.transforms.functional.to_pil_image(sample) if pil else sample

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, index):
        # `index` is ignored: every position yields the same stored sample.
        sample = self.d if self.transform is None else self.transform(self.d)
        return (sample, 0) if self.label else sample


# Adapted from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Classification/ConvNets/image_classification/dataloaders.py#L250
class PrefetchWrapper:
    """Fetch ahead and do some asynchronous processing.

    Wraps a loader yielding ``(input, target)`` batches. While the caller
    consumes batch ``k``, batch ``k + 1`` is copied to the GPU and normalized
    on a side CUDA stream.

    Args:
        data_loader: underlying loader; its ``sampler`` is re-exposed so
            callers can call ``wrapper.sampler.set_epoch(...)``.
        mean, stdev: three per-channel statistics in [0, 1] units.
        lighting: optional callable applied to the batch after scaling it by
            1/255 and before normalization. When ``None``, the batch is not
            rescaled and ``mean``/``stdev`` are multiplied by 255 instead.
            NOTE(review): this assumes raw 0-255 inputs, e.g. the uint8
            batches built by ``fast_collate``.
    """

    def __init__(self, data_loader, mean, stdev, lighting):
        self.data_loader = data_loader
        self.mean = mean
        self.stdev = stdev
        self.lighting = lighting
        self.stream = torch.cuda.Stream()
        self.sampler = data_loader.sampler  # To simplify set_epoch.

    @staticmethod
    def prefetch_loader(data_loader, mean, stdev, lighting, stream):
        """Yield normalized CUDA ``(input, target)`` batches, one batch ahead.

        An empty ``data_loader`` yields nothing. Previously it raised
        ``UnboundLocalError`` on the final yield.
        """
        if lighting is not None:
            mean = torch.tensor(mean).cuda().view(1, 3, 1, 1)
            stdev = torch.tensor(stdev).cuda().view(1, 3, 1, 1)
        else:
            # Inputs stay in 0-255 units here, so scale the statistics instead.
            mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
            stdev = torch.tensor([x * 255 for x in stdev]).cuda().view(1, 3, 1, 1)

        pending = None  # Batch ready to hand out; None until the first one is staged.
        for next_input, next_target in data_loader:
            with torch.cuda.stream(stream):
                next_target = next_target.cuda(non_blocking=True)
                next_input = next_input.cuda(non_blocking=True).float()
                if lighting is not None:
                    # Scale and apply lighting first.
                    next_input = next_input.div_(255.0)
                    next_input = lighting(next_input).sub_(mean).div_(stdev)
                else:
                    next_input = next_input.sub_(mean).div_(stdev)

            # Hand out the previous batch while this one is still in flight.
            if pending is not None:
                yield pending

            # Default stream must wait for the side stream before using the batch.
            torch.cuda.current_stream().wait_stream(stream)
            pending = (next_input, next_target)

        if pending is not None:
            yield pending

    def __iter__(self):
        return PrefetchWrapper.prefetch_loader(
            self.data_loader, self.mean, self.stdev, self.lighting, self.stream)

    def __len__(self):
        return len(self.data_loader)


def fast_collate(batch):
    """Collate ``(image, label)`` pairs into a uint8 (N, 3, H, W) tensor and int64 labels.

    Each image must expose ``.size`` as ``(width, height)`` and be convertible
    with ``np.asarray`` (e.g. a PIL image). The output is sized from the first
    image. Single-channel images are broadcast across all 3 channels.
    """
    images = [sample[0] for sample in batch]
    labels = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
    width, height = images[0].size[0], images[0].size[1]
    out = torch.zeros((len(images), 3, height, width), dtype=torch.uint8)
    for slot, image in enumerate(images):
        pixels = np.asarray(image, dtype=np.uint8)
        if pixels.ndim < 3:
            pixels = pixels[..., np.newaxis]
        # HWC -> CHW
        pixels = np.moveaxis(pixels, 2, 0)
        # Suppress warnings (presumably torch.from_numpy's non-writable-array
        # warning for image-backed arrays).
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            out[slot] += torch.from_numpy(pixels)
    return out, labels


class AverageMeter(object):
  """Track the latest value and the running count-weighted mean of a metric."""

  def __init__(self):
    self.reset()

  def reset(self):
    """Forget everything recorded so far."""
    self.val = self.avg = self.sum = self.count = 0

  def update(self, val, n=1):
    """Record ``val`` as the mean of ``n`` new observations."""
    self.val = val
    self.count += n
    self.sum += val * n
    self.avg = self.sum / self.count


class RecorderMeter(object):
  """Per-epoch record of train/val loss and accuracy, with a plotting helper.

  ``epoch_losses`` and ``epoch_accuracy`` are float32 arrays of shape
  ``(total_epoch, 2)`` whose columns are [train, val]. Losses of epochs not
  recorded yet are -1; their accuracies are 0.
  """
  def __init__(self, total_epoch):
    self.reset(total_epoch)

  def reset(self, total_epoch):
    """Allocate empty records for ``total_epoch`` (> 0) epochs."""
    assert total_epoch > 0
    self.total_epoch   = total_epoch
    self.current_epoch = 0
    # -1 marks "not recorded yet" (plotted below the visible 0-100 range).
    self.epoch_losses  = np.zeros((self.total_epoch, 2), dtype=np.float32) - 1 # [epoch, train/val]
    self.epoch_accuracy= np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val]

  def refresh(self, epochs):
    """Grow the records to ``epochs`` epochs (e.g. resuming with a longer schedule).

    Existing entries are kept.

    Raises:
      ValueError: if ``epochs`` is smaller than the current ``total_epoch``.
    """
    if epochs == self.total_epoch: return
    if epochs < self.total_epoch:
      raise ValueError('cannot shrink RecorderMeter from {} to {} epochs'.format(self.total_epoch, epochs))
    extra = epochs - self.total_epoch
    self.epoch_losses = np.vstack( (self.epoch_losses, np.zeros((extra, 2), dtype=np.float32) - 1) )
    self.epoch_accuracy = np.vstack( (self.epoch_accuracy, np.zeros((extra, 2), dtype=np.float32)) )
    self.total_epoch = epochs

  def update(self, idx, train_loss, train_acc, val_loss, val_acc):
    """Record epoch ``idx`` and return True if its val accuracy is the best so far."""
    assert idx >= 0 and idx < self.total_epoch, 'total_epoch : {} , but update with the {} index'.format(self.total_epoch, idx)
    self.epoch_losses  [idx, 0] = train_loss
    self.epoch_losses  [idx, 1] = val_loss
    self.epoch_accuracy[idx, 0] = train_acc
    self.epoch_accuracy[idx, 1] = val_acc
    self.current_epoch = idx + 1
    # Compare against the stored float32 value. Comparing the float32 maximum
    # with the original (float64) val_acc is False for most non-integral values.
    return bool(self.max_accuracy(False) == self.epoch_accuracy[idx, 1])

  def max_accuracy(self, istrain):
    """Best recorded train (``istrain``) or val accuracy so far; 0 if nothing recorded."""
    if self.current_epoch <= 0: return 0
    if istrain: return self.epoch_accuracy[:self.current_epoch, 0].max()
    else:       return self.epoch_accuracy[:self.current_epoch, 1].max()

  def plot_curve(self, save_path):
    """Plot accuracy and 50x-scaled loss curves; save to ``save_path`` unless it is None.

    The figure is always closed afterwards.
    """
    title = 'the accuracy/loss curve of train/val'
    dpi = 80
    width, height = 1200, 800
    legend_fontsize = 10
    interval_x, interval_y = 5, 5

    fig, ax = plt.subplots(figsize=(width / float(dpi), height / float(dpi)))
    x_axis = np.arange(self.total_epoch)  # epochs

    ax.set_xlim(0, self.total_epoch)
    ax.set_ylim(0, 100)
    ax.set_xticks(np.arange(0, self.total_epoch + interval_x, interval_x))
    ax.set_yticks(np.arange(0, 100 + interval_y, interval_y))
    ax.grid()
    ax.set_title(title, fontsize=20)
    ax.set_xlabel('the training epoch', fontsize=16)
    ax.set_ylabel('accuracy', fontsize=16)

    ax.plot(x_axis, self.epoch_accuracy[:, 0], color='g', linestyle='-', label='train-accuracy', lw=2)
    ax.plot(x_axis, self.epoch_accuracy[:, 1], color='y', linestyle='-', label='valid-accuracy', lw=2)
    # Losses are scaled by 50 so they share the 0-100 accuracy axis.
    ax.plot(x_axis, self.epoch_losses[:, 0] * 50, color='g', linestyle=':', label='train-loss-x50', lw=2)
    ax.plot(x_axis, self.epoch_losses[:, 1] * 50, color='y', linestyle=':', label='valid-loss-x50', lw=2)
    ax.legend(loc=4, fontsize=legend_fontsize)

    if save_path is not None:
      fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
      print('---- save figure {} into {}'.format(title, save_path))
    plt.close(fig)


def time_string():
    """Current UTC time as '[YYYY-MM-DD <%X time>]', used as a log-line prefix."""
    return '[' + time.strftime('%Y-%m-%d %X', time.gmtime()) + ']'


def convert_secs2time(epoch_time):
    """Split a duration in seconds into truncated ``(hours, minutes, seconds)`` ints."""
    hours = int(epoch_time / 3600)
    remaining = epoch_time - 3600 * hours
    minutes = int(remaining / 60)
    seconds = int(remaining - 60 * minutes)
    return hours, minutes, seconds


def time_file_str():
    """UTC date plus a random 1-10000 suffix (e.g. '2024-01-31-4821'), for file names."""
    today = time.strftime('%Y-%m-%d', time.gmtime())
    return '{}-{}'.format(today, random.randint(1, 10000))


# Utilities for distributed training.

def get_num_gpus():
    """Return the number of CUDA devices PyTorch reports on this node."""
    visible_devices = torch.cuda.device_count()
    return visible_devices


def get_local_rank():
    """Get local rank from environment.

    Checks MVAPICH2, then Open MPI, then SLURM variables and returns the first
    one present as an int; defaults to 0.
    """
    for var in ('MV2_COMM_WORLD_LOCAL_RANK', 'OMPI_COMM_WORLD_LOCAL_RANK',
                'SLURM_LOCALID'):
        if var in os.environ:
            return int(os.environ[var])
    return 0


def get_local_size():
    """Get local size from environment.

    Checks MVAPICH2, then Open MPI, then SLURM variables and returns the first
    one present as an int; defaults to 1.
    """
    candidates = ('MV2_COMM_WORLD_LOCAL_SIZE', 'OMPI_COMM_WORLD_LOCAL_SIZE',
                  'SLURM_NTASKS_PER_NODE')
    found = next((name for name in candidates if name in os.environ), None)
    return 1 if found is None else int(os.environ[found])


def get_world_rank():
    """Get rank in world from environment.

    Checks MVAPICH2, then Open MPI, then SLURM variables and returns the first
    one present as an int; defaults to 0.
    """
    for var in ('MV2_COMM_WORLD_RANK', 'OMPI_COMM_WORLD_RANK', 'SLURM_PROCID'):
        if var in os.environ:
            return int(os.environ[var])
    return 0


def get_world_size():
    """Get world size from environment.

    Checks MVAPICH2, then Open MPI, then SLURM variables and returns the first
    one present as an int; defaults to 1.
    """
    candidates = ('MV2_COMM_WORLD_SIZE', 'OMPI_COMM_WORLD_SIZE', 'SLURM_NTASKS')
    found = next((name for name in candidates if name in os.environ), None)
    return 1 if found is None else int(os.environ[found])


def initialize_dist(init_file):
    """Initialize PyTorch distributed backend.

    Binds this process to CUDA device ``get_local_rank()`` and joins an NCCL
    process group of ``get_world_size()`` ranks at rank ``get_world_rank()``.
    The rendezvous is file-based (``file://`` init method) at the absolute
    path of ``init_file``. After a barrier, rank 0 deletes that file if it
    still exists.

    Args:
        init_file: rendezvous file path. NOTE(review): presumably it must be
            on a filesystem visible to every rank — confirm for multi-node runs.
    """
    torch.cuda.init()
    torch.cuda.set_device(get_local_rank())
    init_file = os.path.abspath(init_file)
    torch.distributed.init_process_group(
        backend='nccl', init_method=f'file://{init_file}',
        rank=get_world_rank(), world_size=get_world_size()
    )
    # Wait for every rank before rank 0 removes the rendezvous file
    # (presumably so no rank is still using it).
    torch.distributed.barrier()
    # Ensure the init file is removed.
    if get_world_rank() == 0 and os.path.exists(init_file):
        os.unlink(init_file)

def get_cuda_device():
    """Return this rank's CUDA ``torch.device``, indexed by the local rank."""
    return torch.device('cuda', get_local_rank())


def allreduce_tensor(t):
    """Allreduce and average tensor t.

    Works on a detached copy, so ``t`` itself is left untouched. The copy is
    all-reduced with ``torch.distributed.all_reduce`` (default reduce op) and
    then divided in place by ``get_world_size()``. The in-place division
    means ``t`` should be a floating-point tensor.
    """
    reduced = t.detach().clone()
    torch.distributed.all_reduce(reduced)
    reduced /= get_world_size()
    return reduced

# Dataset Loading

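This cell is long because `load_dataset` handles every supported dataset (CIFAR-10/100, ImageNet, and a synthetic random ImageNet). For each one it builds the preprocessing/augmentation transforms as well as the data loaders.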

In [None]:
def load_dataset():
    """Build the training and validation data loaders selected by the global ``args``.

    Reads ``args.dataset``, ``args.data_path``, ``args.batch_size``,
    ``args.workers``, ``args.dist`` and ``args.evaluate``. Depending on the
    dataset it also reads ``args.cutout``, ``args.arch``, ``args.effnet_arch``
    and ``args.drop_last``.

    Supported datasets are ``cifar10``, ``cifar100``, ``imagenet`` and
    ``rand_imagenet`` (synthetic). Any other name fails an ``assert``.

    Returns:
        ``(num_classes, train_loader, test_loader)``. For CIFAR without
        ``args.evaluate``, ``test_loader`` iterates a held-out 10% of the
        *training* set, not the official test split.
    """
    if args.dataset == 'cifar10':
        mean, std = [x / 255 for x in [125.3, 123.0, 113.9]],  [x / 255 for x in [63.0, 62.1, 66.7]] 
        dataset = dset.CIFAR10
        num_classes = 10
    elif args.dataset == 'cifar100':
        mean, std = [x / 255 for x in [129.3, 124.1, 112.4]], [x / 255 for x in [68.2, 65.4, 70.4]]
        dataset = dset.CIFAR100
        num_classes = 100
    elif args.dataset not in ['imagenet', 'rand_imagenet']:
        assert False, "Unknown dataset : {}".format(args.dataset) # only the datasets handled below are supported

    if args.dataset == 'cifar10' or args.dataset == 'cifar100':
        #train_transform = transforms.Compose([transforms.Scale(256), transforms.RandomHorizontalFlip(), transforms.RandomCrop(224), transforms.ToTensor(), transforms.Normalize(mean, std)])
        train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)])
        # Cutout is appended last, so it masks the already-normalized tensor.
        if args.cutout: train_transform.transforms.append(Cutout(n_holes=1, length=16))
        #test_transform = transforms.Compose([transforms.Scale(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean, std)])
        test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

        # Ensure only one rank downloads: non-zero ranks block here until
        # rank 0 has built (downloading if needed) the datasets and reaches
        # the matching barrier after the loaders are created.
        if args.dist and get_world_rank() != 0:
            torch.distributed.barrier()

        if args.evaluate:
            # Evaluation: the full training set feeds train_loader, and the
            # official test split feeds test_loader. Plain shuffling is used
            # here (no distributed sampler).
            train_data = dataset(args.data_path, train=True,
                                 transform=train_transform, download=True)
            test_data = dataset(args.data_path, train=False,
                                transform=test_transform, download=True)

            train_loader = torch.utils.data.DataLoader(
                train_data, batch_size=args.batch_size, shuffle=True,
                num_workers=args.workers, pin_memory=True)
            test_loader = torch.utils.data.DataLoader(
                test_data, batch_size=args.batch_size, shuffle=False,
                num_workers=args.workers, pin_memory=True)
        else:
            # partition training set into two instead.
            # note that test_data is defined using train=True
            # A random 90% of the training indices are used for training. The
            # remaining 10%, seen through test_transform, form the validation split.
            # NOTE(review): the split comes from np.random.shuffle, so under
            # args.dist every rank must share the same NumPy seed, or ranks
            # will disagree on the split — confirm how seeds are set.
            train_data = dataset(args.data_path, train=True,
                                 transform=train_transform, download=True)
            test_data = dataset(args.data_path, train=True,
                                transform=test_transform, download=True)

            indices = list(range(len(train_data)))
            np.random.shuffle(indices)
            split = int(0.9 * len(train_data))
            train_indices, test_indices = indices[:split], indices[split:]
            if args.dist:
                # Use the distributed sampler here.
                train_subset = torch.utils.data.Subset(
                    train_data, train_indices)
                train_sampler = torch.utils.data.distributed.DistributedSampler(
                    train_subset, num_replicas=get_world_size(),
                    rank=get_world_rank())
                train_loader = torch.utils.data.DataLoader(
                    train_subset, batch_size=args.batch_size,
                    sampler=train_sampler, num_workers=args.workers,
                    pin_memory=True)
                test_subset = torch.utils.data.Subset(test_data, test_indices)
                test_sampler = torch.utils.data.distributed.DistributedSampler(
                    test_subset, num_replicas=get_world_size(),
                    rank=get_world_rank())
                test_loader = torch.utils.data.DataLoader(
                    test_subset, batch_size=args.batch_size,
                    sampler=test_sampler, num_workers=args.workers,
                    pin_memory=True)
            else:
                train_sampler = SubsetRandomSampler(train_indices)
                train_loader = torch.utils.data.DataLoader(
                    train_data, batch_size=args.batch_size,
                    num_workers=args.workers, pin_memory=True,
                    sampler=train_sampler)
                test_sampler = SubsetRandomSampler(test_indices)
                test_loader = torch.utils.data.DataLoader(
                    test_data, batch_size=args.batch_size,
                    num_workers=args.workers, pin_memory=True,
                    sampler=test_sampler)

        # Let ranks through: rank 0 reaches this barrier only after building
        # everything, which releases the ranks waiting at the barrier above.
        if args.dist and get_world_rank() == 0:
            torch.distributed.barrier()

    elif args.dataset == 'imagenet':
        if args.dist:
            imagenet_means = [0.485, 0.456, 0.406]
            imagenet_stdevs = [0.229, 0.224, 0.225]

            # Can just read off SSDs.
            # These transforms end with PIL images (no ToTensor/Normalize).
            # fast_collate packs them into uint8 batches, and PrefetchWrapper
            # normalizes on the GPU, adding Lighting noise for training only.
            if 'efficientnet' in args.arch:
                # NOTE(review): `models` is not imported in the visible import
                # cells (that import is commented out) — confirm it is defined.
                image_size = models.efficientnet.EfficientNet.get_image_size(
                    args.effnet_arch)
                train_transform = transforms.Compose([
                    models.efficientnet.augmentations.Augmentation(
                        models.efficientnet.augmentations.get_fastautoaugment_policy()),
                    models.efficientnet.augmentations.EfficientNetRandomCrop(
                        image_size),
                    transforms.Resize((image_size, image_size),
                                      PIL.Image.BICUBIC),
                    transforms.RandomHorizontalFlip(),
                    transforms.ColorJitter(0.4, 0.4, 0.4),
                ])
                test_transform = transforms.Compose([
                    models.efficientnet.augmentations.EfficientNetCenterCrop(
                        image_size),
                    transforms.Resize((image_size, image_size),
                                      PIL.Image.BICUBIC)
                ])
            else:
                # Transforms adapted from imagenet_seq's, except that color jitter
                # and lighting are not applied in random orders, and that resizing
                # is done with bilinear instead of cubic interpolation.
                train_transform = transforms.Compose([
                    transforms.RandomResizedCrop((224, 224)),
                    # transforms.ColorJitter(0.4, 0.4, 0.4),
                    transforms.RandomHorizontalFlip()])
                test_transform = transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop((224, 224))])
            train_data = dset.ImageFolder(
                args.data_path + '/train', transform=train_transform)
            test_data = dset.ImageFolder(
                args.data_path + '/val', transform=test_transform)
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_data, num_replicas=get_world_size(),
                rank=get_world_rank())
            train_loader = torch.utils.data.DataLoader(
                train_data, batch_size=args.batch_size, sampler=train_sampler,
                num_workers=args.workers, pin_memory=True,
                collate_fn=fast_collate, drop_last=args.drop_last)
            # AlexNet-style PCA lighting with std 0.1. The eigval/eigvec values are
            # presumably the standard ImageNet RGB PCA statistics — confirm.
            train_loader = PrefetchWrapper(
                train_loader, imagenet_means, imagenet_stdevs,
                Lighting(0.1,
                         torch.Tensor([0.2175, 0.0188, 0.0045]).cuda(),
                         torch.Tensor([
                             [-0.5675, 0.7192, 0.4009],
                             [-0.5808, -0.0045, -0.8140],
                             [-0.5836, -0.6948, 0.4203],
                         ]).cuda()))
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_data, num_replicas=get_world_size(),
                rank=get_world_rank())
            test_loader = torch.utils.data.DataLoader(
                test_data, batch_size=args.batch_size, sampler=test_sampler,
                num_workers=args.workers, pin_memory=True,
                collate_fn=fast_collate)
            test_loader = PrefetchWrapper(
                test_loader, imagenet_means, imagenet_stdevs, None)
        else:
            # Non-distributed ImageNet uses the external imagenet_seq package's loaders.
            import imagenet_seq
            train_loader = imagenet_seq.data.Loader(
                'train', batch_size=args.batch_size, num_workers=args.workers)
            test_loader = imagenet_seq.data.Loader(
                'val', batch_size=args.batch_size, num_workers=args.workers)
        num_classes = 1000
    elif args.dataset == 'rand_imagenet':
        # Synthetic ImageNet-shaped data. RandomDataset returns one fixed random
        # 3x256x256 image (label 0) per split, intended for debugging.
        imagenet_means = [0.485, 0.456, 0.406]
        imagenet_stdevs = [0.229, 0.224, 0.225]
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop((224, 224)),
            transforms.ColorJitter(0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip()])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224)])
        train_data = RandomDataset((3, 256, 256), 1200000, pil=True,
                                   transform=train_transform)
        test_data = RandomDataset((3, 256, 256), 50000, pil=True,
                                  transform=test_transform)
        if args.dist:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_data, num_replicas=get_world_size(),
                rank=get_world_rank())
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_data, num_replicas=get_world_size(),
                rank=get_world_rank())
        else:
            train_sampler = RandomSampler(train_data)
            test_sampler = RandomSampler(test_data)
        train_loader = torch.utils.data.DataLoader(
            train_data, batch_size=args.batch_size, num_workers=args.workers,
            pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)
        train_loader = PrefetchWrapper(
            train_loader, imagenet_means, imagenet_stdevs,
            Lighting(0.1,
                     torch.Tensor([0.2175, 0.0188, 0.0045]).cuda(),
                     torch.Tensor([
                         [-0.5675, 0.7192, 0.4009],
                         [-0.5808, -0.0045, -0.8140],
                         [-0.5836, -0.6948, 0.4203],
                     ]).cuda()))
        test_loader = torch.utils.data.DataLoader(
            test_data, batch_size=args.batch_size, num_workers=args.workers,
            pin_memory=True, sampler=test_sampler, collate_fn=fast_collate)
        test_loader = PrefetchWrapper(
            test_loader, imagenet_means, imagenet_stdevs, None)
        num_classes = 1000
    else:
        assert False, 'Do not support dataset : {}'.format(args.dataset)

    return num_classes, train_loader, test_loader

# Some Definitions

**Neural Parameter Allocation Search (NPAS)**: Given a neural network architecture with layers
$l_1, \ldots, l_L$, which require weights $w_1, \ldots, w_L$ respectively, and a fixed parameter budget $\theta$, train a
high-performing neural network using the given architecture and parameter budget.

# Model

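Builds the NPAS model selected by `args.arch`, moves it to the GPU(s), and logs its trainable-parameter count.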

In [None]:
def load_model(num_classes, log, max_params, share_type, upsample_type,
               groups=None):
    """Build the network named by ``args.arch``, move it to the GPU, and log its size.

    ``efficientnet_imagenet`` is built with ``models.efficientnet_imagenet``.
    Any other arch name is looked up in ``models.__dict__`` and called with
    the NPAS sharing options, ``num_classes`` and ``groups``. Under
    ``args.dist`` the model is moved to this rank's CUDA device. Otherwise it
    is wrapped in ``DataParallel`` over GPUs ``0 .. args.ngpu - 1``.

    Returns:
        The (possibly ``DataParallel``-wrapped) network.
    """
    print_log("=> creating model '{}'".format(args.arch), log)

    if args.arch == 'efficientnet_imagenet':
        model = models.efficientnet_imagenet(
            args.effnet_arch, share_type, upsample_type, args.upsample_window,
            args.bank_size, max_params, groups)
    else:
        builder = models.__dict__[args.arch]
        model = builder(
            share_type, upsample_type, args.upsample_window, args.depth,
            args.wide, args.bank_size, max_params, num_classes, groups)

    print_log("=> network :\n {}".format(model), log)

    if args.dist:
        model = model.to(get_cuda_device())
    else:
        model = torch.nn.DataParallel(
            model.cuda(), device_ids=list(range(args.ngpu)))

    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print_log("Number of parameters: {}".format(n_trainable), log)
    return model

# Main Loop

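Lots and lots of code. `main()` performs these steps:
- Set up logging, data and parameter groups.
- Build the model, loss, optimizer and scheduler.
- Handle optional EMA and checkpoint resume.
- Either validate once, or run the per-epoch train/validate/checkpoint loop.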

In [None]:
def main():
    """Run NPAS training (or a single evaluation) as configured by the global ``args``.

    Steps:
    - Rank 0 creates ``args.save_path`` and a per-seed log file there.
    - Load the data and, when ``args.param_groups > 1``, compute or load
      parameter groups (saved as ``groups.npy``).
    - Build the model, loss, optimizer and optional step scheduler; wrap the
      model in DDP under ``args.dist``.
    - Set up optional AMP/EMA and optionally resume from a checkpoint.
    - Either validate once (``args.evaluate``) or run the epoch loop. Rank 0
      checkpoints and plots curves after every epoch.

    NOTE(review): this relies on names not defined in this cell (``args``,
    ``print_log``, ``get_parameter_groups``, ``group_weight_decay``, ``apex``,
    ``models``, ``GradScaler``, ``train``, ``validate``,
    ``adjust_learning_rate``, ``save_checkpoint``, ``result_png_path`` and
    initial ``best_acc``/``best_los``) — confirm they exist before running.
    """
    # Module-level so the best-so-far values persist outside this call
    # (presumably read elsewhere in the notebook).
    global best_acc, best_los

    if get_world_rank() == 0:
        if not os.path.isdir(args.save_path):
            os.makedirs(args.save_path)
        log = open(os.path.join(
            args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    else:
        log = None  # Only rank 0 writes a log file.
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log)
    print_log("PyTorch  version : {}".format(torch.__version__), log)
    print_log("CuDNN  version : {}".format(torch.backends.cudnn.version()), log)
    print_log(f'Ranks: {get_world_size()}', log)
    print_log(f'Global batch size: {args.batch_size*get_world_size()}', log)

    if get_world_rank() == 0 and not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    num_classes, train_loader, test_loader = load_dataset()

    groups = args.param_groups
    if args.param_groups > 1:
        fn = os.path.join(args.save_path, 'groups.npy')
        if args.evaluate or args.resume:
            groups = np.load(fn)
            assert len(set(groups)) == args.param_groups
        else:
            groups = get_parameter_groups(train_loader, state, num_classes, log)
            if args.param_group_type != 'reload' and get_world_rank() == 0:
                np.save(fn, groups)
            if args.param_group_type == 'learned':
                print_log('Must restart after learning parameter groups', log)
                if log is not None:
                    log.close()
                return
            if args.param_group_type == 'random':
                # Need to load this from rank 0 to get consistent view.
                # Only synchronize when a process group exists. A single
                # non-distributed process has nothing to wait for, and
                # barrier() would fail without an initialized group.
                if args.dist:
                    torch.distributed.barrier()
                if get_world_rank() != 0:
                    groups = np.load(fn)
        print_log('groups- ' + ', '.join(
            [str(i) + ':' + str(g) for i, g in enumerate(groups)]), log)

    net = load_model(num_classes, log, args.max_params, args.share_type,
                     args.upsample_type, groups=groups)

    if args.label_smoothing > 0.0:
        criterion = LabelSmoothingNLLLoss(args.label_smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss().cuda()

    # Names handed to group_weight_decay as weight-decay exclusions
    # (presumably matched against parameter names).
    decay_skip = ['coefficients']
    if args.no_bn_decay:
        decay_skip.append('bn')
    params = group_weight_decay(net, state['decay'], decay_skip)
    if args.optimizer == 'sgd':
        nesterov = not args.no_nesterov and state['momentum'] > 0.0
        if args.dist:
            optimizer = apex.optimizers.FusedSGD(
                params, state['learning_rate'], momentum=state['momentum'],
                nesterov=nesterov)
        else:
            optimizer = torch.optim.SGD(
                params, state['learning_rate'], momentum=state['momentum'],
                nesterov=nesterov)
    else:
        optimizer = models.efficientnet.RMSpropTF(
            params, state['learning_rate'], alpha=0.9, eps=1e-3,
            momentum=state['momentum'])

    if args.step_size:
        if args.schedule:
            raise ValueError('Cannot combine regular and step schedules')
        step_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, args.step_size, args.step_gamma)
        if args.step_warmup:
            step_scheduler = models.efficientnet.GradualWarmupScheduler(
                optimizer, multiplier=1.0, warmup_epoch=args.step_warmup,
                after_scheduler=step_scheduler)
    else:
        step_scheduler = None

    if args.dist:
        net = torch.nn.parallel.DistributedDataParallel(
            net,
            device_ids=[get_local_rank()],
            output_device=get_local_rank(),
            find_unused_parameters=True)
    scaler = GradScaler(enabled=args.amp)

    if args.ema_decay:
        ema_model = copy.deepcopy(net).to(get_cuda_device())
        ema_manager = models.efficientnet.EMA(args.ema_decay)
    else:
        ema_model, ema_manager = None, None

    recorder = RecorderMeter(args.epochs)
    if args.resume:
        if args.resume == 'auto':
            args.resume = os.path.join(args.save_path, 'checkpoint.pth.tar')
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            # NOTE: torch.load unpickles the checkpoint (the recorder is a
            # pickled Python object) — only resume from trusted files.
            checkpoint = torch.load(
                args.resume,
                map_location=get_cuda_device() if args.ngpu else 'cpu')
            recorder = checkpoint['recorder']
            recorder.refresh(args.epochs)
            args.start_epoch = checkpoint['epoch']
            # Hack to load models that were wrapped in (D)DP.
            if args.no_dp:
                net = torch.nn.DataParallel(net, device_ids=[get_local_rank()])
            net.load_state_dict(checkpoint['state_dict'])
            if args.no_dp:
                net = net.module
            optimizer.load_state_dict(checkpoint['optimizer'])
            if step_scheduler:
                step_scheduler.load_state_dict(checkpoint['scheduler'])
            if ema_manager is not None:
                # NOTE(review): saved via ema_manager.state_dict() but restored
                # by assigning .shadow — confirm the two formats match.
                ema_manager.shadow = checkpoint['ema']
            if args.amp:
                scaler.load_state_dict(checkpoint['amp'])
            best_acc = recorder.max_accuracy(False)
            # Also restore the best validation loss. Otherwise the first resumed
            # epoch is compared against the initial best_los instead of the real best.
            if args.best_loss and recorder.current_epoch > 0:
                best_los = float(recorder.epoch_losses[:recorder.current_epoch, 1].min())
            print_log(
                "=> loaded checkpoint '{}' accuracy={} (epoch {})".format(
                    args.resume, best_acc, checkpoint['epoch']), log)
        else:
            print_log(
                "=> no checkpoint found at '{}'".format(args.resume), log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch), log)

    if args.evaluate:
        if get_world_size() > 1:
            raise RuntimeError('Do not validate with distributed training')
        validate(test_loader, net, criterion, log)
        if log is not None:
            log.close()
        return

    start_time = time.time()
    epoch_time = AverageMeter()
    train_los = -1

    for epoch in range(args.start_epoch, args.epochs):
        if step_scheduler:
            current_learning_rate = step_scheduler.get_last_lr()[0]
        else:
            current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule, train_los)

        # ETA for the remaining epochs, from the average epoch duration so far.
        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs-epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                    + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        if args.dist:
            # Reseed the distributed samplers' shuffling for this epoch.
            train_loader.sampler.set_epoch(epoch)
            test_loader.sampler.set_epoch(epoch)
        train_acc, train_los = train(train_loader, net, criterion, optimizer,
                                     scaler, epoch, log, step_scheduler,
                                     ema_manager)
        torch.cuda.synchronize()

        val_acc, val_los = validate(test_loader, net, criterion, log,
                                    ema_model, ema_manager)
        recorder.update(epoch, train_los, train_acc, val_los, val_acc)

        is_best = False
        if args.best_loss:
            if val_los < best_los:
                is_best = True
                best_los = val_los
        else:
            if val_acc > best_acc:
                is_best = True
                best_acc = val_acc

        if get_world_rank() == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
                'scheduler': step_scheduler.state_dict() if step_scheduler else None,
                'ema': ema_manager.state_dict() if ema_manager is not None else None,
                'amp': scaler.state_dict() if args.amp else None
            }, is_best, args.save_path, 'checkpoint.pth.tar')

        epoch_time.update(time.time() - start_time)
        start_time = time.time()

        if get_world_rank() == 0:
            recorder.plot_curve(result_png_path)

    if log is not None:
        log.close()