Skip to content
This repository has been archived by the owner on May 1, 2023. It is now read-only.

Commit

Permalink
Modify image size and training for Inception Models (#425)
Browse files Browse the repository at this point in the history
* Merge pytorch 1.3 commits

This PR is a fix for issue #422.

1. ImageNet models usually use input size [batch, 3, 224, 224], but all Inception models require an input image size of [batch, 3, 299, 299].

2. Inception models have auxiliary branches which contribute to the loss only during training.  The reported classification loss only considers the main classification loss.

3. Inception_V3 normalizes the input inside the network itself.  More details can be found in @soumendukrg's PR #425 [comments](#425 (comment)).

NOTE: Training using Inception_V3 is only possible on a single GPU as of now. This issue talks about this problem. I have checked and this problem persists in torch 1.3.0:
[inception_v3 of vision 0.3.0 does not fit in DataParallel of torch 1.1.0 #1048](pytorch/vision#1048)

Co-authored-by: Neta Zmora <neta.zmora@intel.com>
  • Loading branch information
soumendukrg and nzmora committed Apr 27, 2020
1 parent 301484c commit 99afd8e
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 20 deletions.
41 changes: 26 additions & 15 deletions distiller/apputils/data_loaders.py
Expand Up @@ -24,6 +24,7 @@
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data.sampler import Sampler
from functools import partial
import numpy as np
import distiller

Expand Down Expand Up @@ -58,19 +59,21 @@ def classification_get_input_shape(dataset):
raise ValueError("dataset %s is not supported" % dataset)


def __dataset_factory(dataset):
def __dataset_factory(dataset, arch):
return {'cifar10': cifar10_get_datasets,
'mnist': mnist_get_datasets,
'imagenet': imagenet_get_datasets}.get(dataset, None)
'imagenet': partial(imagenet_get_datasets, arch=arch)}.get(dataset, None)


def load_data(dataset, data_dir, batch_size, workers, validation_split=0.1, deterministic=False,
def load_data(dataset, arch, data_dir,
batch_size, workers, validation_split=0.1, deterministic=False,
effective_train_size=1., effective_valid_size=1., effective_test_size=1.,
fixed_subset=False, sequential=False, test_only=False):
"""Load a dataset.
Args:
dataset: a string with the name of the dataset to load (cifar10/imagenet)
arch: a string with the name of the model architecture
data_dir: the directory where the dataset resides
batch_size: the batch size
workers: the number of worker threads to use for loading the data
Expand All @@ -86,12 +89,12 @@ def load_data(dataset, data_dir, batch_size, workers, validation_split=0.1, dete
"""
if dataset not in DATASETS_NAMES:
raise ValueError('load_data does not support dataset %s" % dataset')
datasets_fn = __dataset_factory(dataset)
return get_data_loaders(datasets_fn, data_dir, batch_size, workers,
datasets_fn = __dataset_factory(dataset, arch)
return get_data_loaders(datasets_fn, data_dir, batch_size, workers,
validation_split=validation_split,
deterministic=deterministic,
deterministic=deterministic,
effective_train_size=effective_train_size,
effective_valid_size=effective_valid_size,
effective_valid_size=effective_valid_size,
effective_test_size=effective_test_size,
fixed_subset=fixed_subset,
sequential=sequential,
Expand Down Expand Up @@ -163,20 +166,29 @@ def cifar10_get_datasets(data_dir, load_train=True, load_test=True):

return train_dataset, test_dataset


def imagenet_get_datasets(data_dir, load_train=True, load_test=True):
def imagenet_get_datasets(data_dir, arch, load_train=True, load_test=True):
"""
Load the ImageNet dataset.
"""
# Inception Network accepts image of size 3, 299, 299
if distiller.models.is_inception(arch):
resize, crop = 336, 299
else:
resize, crop = 256, 224
if arch == 'googlenet':
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5])
else:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
train_dir = os.path.join(data_dir, 'train')
test_dir = os.path.join(data_dir, 'val')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])

train_dataset = None
if load_train:
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomResizedCrop(crop),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
Expand All @@ -187,8 +199,8 @@ def imagenet_get_datasets(data_dir, load_train=True, load_test=True):
test_dataset = None
if load_test:
test_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.Resize(resize),
transforms.CenterCrop(crop),
transforms.ToTensor(),
normalize,
])
Expand All @@ -197,7 +209,6 @@ def imagenet_get_datasets(data_dir, load_train=True, load_test=True):

return train_dataset, test_dataset


def __image_size(dataset):
# un-squeeze is used here to add the batch dimension (value=1), which is missing
return dataset[0][0].unsqueeze(0).size()
Expand Down
56 changes: 52 additions & 4 deletions distiller/apputils/image_classifier.py
Expand Up @@ -472,7 +472,7 @@ def save_collectors_data(collectors, directory):
def load_data(args, fixed_subset=False, sequential=False, load_train=True, load_val=True, load_test=True):
test_only = not load_train and not load_val

train_loader, val_loader, test_loader, _ = apputils.load_data(args.dataset,
train_loader, val_loader, test_loader, _ = apputils.load_data(args.dataset, args.arch,
os.path.expanduser(args.data), args.batch_size,
args.workers, args.validation_split, args.deterministic,
args.effective_train_size, args.effective_valid_size, args.effective_test_size,
Expand All @@ -488,7 +488,7 @@ def load_data(args, fixed_subset=False, sequential=False, load_train=True, load_
loaders = [loaders[i] for i, flag in enumerate(flags) if flag]

if len(loaders) == 1:
# Unpack the list for convinience
# Unpack the list for convenience
loaders = loaders[0]
return loaders

Expand Down Expand Up @@ -579,9 +579,19 @@ def _log_training_progress():
output = args.kd_policy.forward(inputs)

if not early_exit_mode(args):
loss = criterion(output, target)
# Handle loss calculation for inception models separately due to auxiliary outputs
# if user turned off auxiliary classifiers by hand, then loss should be calculated normally,
# so, we have this check to ensure we only call this function when output is a tuple
if models.is_inception(args.arch) and isinstance(output, tuple):
loss = inception_training_loss(output, target, criterion, args)
else:
loss = criterion(output, target)
# Measure accuracy
classerr.add(output.detach(), target)
# For inception models, we only consider accuracy of main classifier
if isinstance(output, tuple):
classerr.add(output[0].detach(), target)
else:
classerr.add(output.detach(), target)
acc_stats.append([classerr.value(1), classerr.value(5)])
else:
# Measure accuracy and record loss
Expand Down Expand Up @@ -741,6 +751,44 @@ def _log_validation_progress():
return total_top1, total_top5, losses_exits_stats[args.num_exits-1]


def inception_training_loss(output, target, criterion, args):
"""Compute weighted loss for Inception networks as they have auxiliary classifiers
Auxiliary classifiers were added to inception networks to tackle the vanishing gradient problem
They apply softmax to outputs of one or more intermediate inception modules and compute auxiliary
loss over same labels.
Note that auxiliary loss is purely used for training purposes, as they are disabled during inference.
GoogleNet has 2 auxiliary classifiers, hence two 3 outputs in total, output[0] is main classifier output,
output[1] is aux2 classifier output and output[2] is aux1 classifier output and the weights of the
aux losses are weighted by 0.3 according to the paper (C. Szegedy et al., "Going deeper with convolutions,"
2015 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Boston, MA, 2015, pp. 1-9.)
All other versions of Inception networks have only one auxiliary classifier, and the auxiliary loss
is weighted by 0.4 according to PyTorch documentation
# From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
"""
weighted_loss = 0
if args.arch == 'googlenet':
# DEFAULT, aux classifiers are NOT included in PyTorch Pretrained googlenet model as they are NOT trained,
# they are only present if network is trained from scratch. If you need to fine tune googlenet (e.g. after
# pruning a pretrained model), then you have to explicitly enable aux classifiers when creating the model
# DEFAULT, in case of pretrained model, output length is 1, so loss will be calculated in main training loop
# instead of here, as we enter this function only if output is a tuple (len>1)
# TODO: Enable user to feed some input to add aux classifiers for pretrained googlenet model
outputs, aux2_outputs, aux1_outputs = output # extract all 3 outputs
loss0 = criterion(outputs, target)
loss1 = criterion(aux1_outputs, target)
loss2 = criterion(aux2_outputs, target)
weighted_loss = loss0 + 0.3*loss1 + 0.3*loss2
else:
outputs, aux_outputs = output # extract two outputs
loss0 = criterion(outputs, target)
loss1 = criterion(aux_outputs, target)
weighted_loss = loss0 + 0.4*loss1
return weighted_loss


def earlyexit_loss(output, target, criterion, args):
"""Compute the weighted sum of the exits losses
Expand Down
13 changes: 12 additions & 1 deletion distiller/models/__init__.py
Expand Up @@ -158,6 +158,13 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
return model.to(device)


def is_inception(arch):
return arch in [ # Torchvision architectures
'inception_v3', 'googlenet',
# Cadene architectures
'inceptionv3', 'inceptionv4', 'inceptionresnetv2']


def _create_imagenet_model(arch, pretrained):
dataset = "imagenet"
cadene = False
Expand All @@ -166,9 +173,13 @@ def _create_imagenet_model(arch, pretrained):
model = imagenet_extra_models.__dict__[arch](pretrained=pretrained)
elif arch in TORCHVISION_MODEL_NAMES:
try:
model = getattr(torch_models, arch)(pretrained=pretrained)
if is_inception(arch):
model = getattr(torch_models, arch)(pretrained=pretrained, transform_input=False)
else:
model = getattr(torch_models, arch)(pretrained=pretrained)
if arch == "mobilenet_v2":
patch_torchvision_mobilenet_v2(model)

except NotImplementedError:
# In torchvision 0.3, trying to download a model that has no
# pretrained image available will raise NotImplementedError
Expand Down

0 comments on commit 99afd8e

Please sign in to comment.