This repository has been archived by the owner on May 1, 2023. It is now read-only.

Modify image size and training for Inception Models #425

Merged
merged 12 commits into from Apr 27, 2020

Conversation

soumendukrg
Contributor

This PR is a fix for issue #422.

The data_loader file had a fixed classification image size for ImageNet of [1, 3, 224, 224]. However, all Inception models require an input image size of [1, 3, 299, 299].

To fix this issue, I modified the apputils/image_classifier.py file to add a new parameter to the load_data function. This function calls apputils.load_data, so I changed the corresponding function in apputils/data_loader.py.
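For illustration, a minimal sketch of what threading an image-size parameter through the data-loading path might look like (hypothetical helper name and simplified pipeline; the actual apputils signatures differ):

    import torch
    import torchvision.datasets as datasets
    import torchvision.transforms as transforms

    def load_imagenet_train_data(train_dir, batch_size, image_size=224):
        """Build an ImageNet training loader whose crop size is model-dependent:
        224 for most classifiers, 299 for the Inception family."""
        transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        dataset = datasets.ImageFolder(train_dir, transform)
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                           shuffle=True, num_workers=4)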

Also, image_classifier.py is modified to combine the losses from the main classifier and the aux_logits classifier of the Inception network, while the classification accuracy is computed only from the main classifier.
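A rough sketch of that loss handling (illustrative only, not the exact PR code; the 0.4 auxiliary-loss weight is the one questioned later in this thread):

    def compute_loss_and_logits(model, criterion, inputs, target):
        """Combine main and auxiliary losses for Inception-style models during
        training; accuracy should be measured on the main logits only."""
        output = model(inputs)
        if isinstance(output, tuple):  # e.g. (logits, aux_logits) in training mode
            logits, aux_logits = output
            loss = criterion(logits, target) + 0.4 * criterion(aux_logits, target)
        else:
            logits = output
            loss = criterion(logits, target)
        return logits, loss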

@nzmora : Please review the changes.

PS: My fork is based on PyTorch 1.3, so all additional changes for PyTorch 1.3 support are present in this PR.

@nzmora
Contributor

nzmora commented Nov 14, 2019

Thanks @soumendukrg!

Please rebase the PR on top of the 'master' branch which now supports PyTorch 1.3.
This will remove the non-relevant parts.

Cheers
Neta

pytorch 1.3 and torchvision 0.4: initial adaptations

1. change requirements.txt (new version numbers for pytorch and torchvision)
2. Change onnx op.uniqueName to op.debugName
See: lanpa/tensorboardX@9084ab8
lanpa/tensorboardX#483

Fixes in SummaryGraph and related tests for PyTorch 1.3

* Naming of trace entries changed in 1.3. One such change is that
  the "root" input of the model is now named 'input.1' instead of
  just '0'. Fixed test that checked that.
* One of the workarounds for scope names after ONNX pass is not
  needed anymore. Removed it and updated relevant test.

adjust full_flow_tests.py for pytorch 1.3 results

Move to PyTorch 1.3.1 and torchvision 0.4.2 + fix full_flow tests

Unit tests: Adjust tolerance in test_sim_bn_fold + filter some warnings

Updated expected acc for system tests

fixed image_size and training loss, accuracy for inception models

Revert "fixed image_size and training loss, accuracy for inception models"

This reverts commit fbbd351.

Revert "Revert "fixed image_size and training loss, accuracy for inception models""

This reverts commit ed895e6.

	new file:   tests/layer_quant_params.yaml
	new file:   tests/quant_stats_after_prepare_model.yaml

delete full_system_log generated yaml files

fix typos
@soumendukrg
Contributor Author

@nzmora I have rebased the PR as you suggested. Please review the changes. Thanks.

NOTE: Training with Inception_V3 is currently only possible on a single GPU. The issue below describes this problem; I have checked, and it persists in torch 1.3.0:
inception_v3 of vision 0.3.0 does not fit in DataParallel of torch 1.1.0 #1048

@nzmora
Contributor

nzmora commented Nov 15, 2019

Thanks @soumendukrg!

We've looked at your PR and we are thinking about it. On the one hand, adding is_inception to several functions is not what we'd prefer. It couples an API that is meant to be somewhat generic (at least for image classification models) to this one specific model.
On the other hand, implementing something more generic will take more work.

The fundamental problem that you've uncovered in our API is that our assumption that data loading is dependent purely on the dataset is... wrong. When we load a DS, we also pre-process it - and the pre-processing is model-specific in corner-cases (as you've shown with inception v3). So we have to go and change the relevant API functions (incl. related functions such as classification_get_input_shape) and add the architecture string as a parameter.

A rather small functional change can create a large code-change ripple when the API is not right :-(.
So that's where we are - I wanted to update you. Next week we're in a conference, so I'm not sure what we'll get done. Also I wanted to let you know that because the fix might involve a large deviation from the original PR, I'm considering refactoring the code outside the PR because it'll be easier (but @guyjacob assured me that if we go this way, github has means to give you due credit for the PR in our commit). Of course, if you make changes in the PR until then, we'll consider those instead.

Thanks!
Neta

@soumendukrg
Contributor Author

@nzmora
Thank you for your response. I appreciate that you will be giving credit for the PR in your commit. I am also currently busy with a deadline; I will follow up if I come up with additional changes that generalize this fix.

Further, I solved distributed training of the Inception and GoogLeNet networks using the solution posted in the PyTorch issue I referred to earlier. In addition, image_classifier.py needs some corresponding modifications, which I can share if you are interested.

@soumendukrg
Contributor Author

soumendukrg commented Nov 24, 2019

I discovered another flaw in inception retraining using compress_classifier.py.

Inception_V3 normalizes the input inside the network itself:

            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)

One can find more details in a PyTorch forum post as to why this is done.
However, this is only applied if pretrained=True when you create the model (in image_classifier.py), that is, when you are training from scratch or fine-tuning from the pretrained network. But let's say you saved your training at some point and are now resuming from a saved checkpoint: you do not use the --pretrained flag, you use --resume-from or --exp-load-weights-from. This ends up using the default PyTorch normalization that you use for other networks, and now your images are trained with a different normalization factor.

This needs to be reviewed. I am thinking of an efficient way to solve this. Any thoughts @nzmora @guyjacob ?

@nzmora
Contributor

nzmora commented Nov 26, 2019

Hi @soumendukrg ,
Thanks for bringing this to our attention. This problem that you raise is also present in GoogLeNet (https://github.com/pytorch/vision/blob/537f0df79c464983dc1f32b8697dbc42d3f4872b/torchvision/models/googlenet.py#L123-L130).

So how do we handle this issue?
If we zoom-out and restate the problem you raise, it translates to this requirement: there are corner cases where there is a dependency between the way we originally create the model and the way we train (and evaluate) a resumed model. Between creation and resuming, there is checkpoint creation and checkpoint loading. Checkpointing breaks the flow in time, and the checkpoint file itself is the only way to pass (state) information between the creation time and the resumption time. That's a convoluted way of stating that we need to save the model's transform_input attribute (if it exists) in the checkpoint file, and correctly assign this attribute in the model when we load it from a checkpoint. Sort of (but not exactly) like we do here:
https://github.com/NervanaSystems/distiller/blob/8b341593e8fe71919de149cdd00e269061cecaed/distiller/apputils/checkpoint.py#L65-L70
and here:
https://github.com/NervanaSystems/distiller/blob/8b341593e8fe71919de149cdd00e269061cecaed/distiller/apputils/checkpoint.py#L172-L173
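A minimal sketch of that idea, with hypothetical helper names (the real checkpoint.py code is organized differently):

    def save_model_attributes(model, checkpoint_dict):
        # Persist the creation-time attribute so a resumed run normalizes the same way
        if hasattr(model, 'transform_input'):
            checkpoint_dict['transform_input'] = model.transform_input

    def restore_model_attributes(model, checkpoint_dict):
        if 'transform_input' in checkpoint_dict:
            model.transform_input = checkpoint_dict['transform_input']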

I need to sleep on it. Let me know what you think.
Thanks,
Neta

@nzmora nzmora requested a review from guyjacob November 26, 2019 00:16
@nzmora
Contributor

nzmora commented Nov 26, 2019

@barrh - please take a look and give us your opinion.

@barrh
Contributor

barrh commented Nov 26, 2019

First, thanks for bringing this up @soumendukrg. It is very problematic that these issues with using Inception models go unnoticed, without any useful warning.

To my understanding, the following are model-dependent:

  1. input size
  2. input normalized or not, and
  3. loss computation method

The dataset-loading functionality should be kept generic; therefore, I would pass (1) and (2) as orthogonal arguments. We can write a small lookup utility, based on model name/type, to find the matching values for those, e.g. get_input_shape('inception') and/or get_model_input_attr('inception'), and feed those into load_data(). Regarding the 'is_inception' argument, as @nzmora commented before me, we should be careful not to make this fix inception-specific.
Another vector to tackle this is adding the relevant attributes to the corresponding torchvision modules, or at least throwing a warning when a tensor of the wrong size is fed in.
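For concreteness, a small sketch of such a lookup utility could look like this (hypothetical function; only the usual default shapes are shown):

    def classification_get_input_shape(arch, dataset='imagenet'):
        """Return the expected (N, C, H, W) input shape for a given architecture."""
        if dataset == 'imagenet':
            if arch.lower().startswith('inception'):
                return 1, 3, 299, 299   # Inception-family models expect 299x299 crops
            return 1, 3, 224, 224
        if dataset == 'cifar10':
            return 1, 3, 32, 32
        raise ValueError('Unsupported dataset: ' + dataset)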

Regarding the loss computation (3), since it's done in compress_classifier.py, I think the implementation is fine (great documentation!), but moving it into a separate function could be better to clean up the main loop.

Other notes:

  1. Is there a reference for the loss computation on Inception v3? (Why 40%?)
  2. PEP8: Never compare boolean values with '==' (line 164 in data_loaders.py); use if is_inception:.

nzmora added a commit that referenced this pull request Nov 26, 2019
This is based on PR #425 (@soumendukrg).
The file data_loader had fixed classification image size for ImageNet
as [1, 3, 224, 224]. However, all Inception models require an input image
size of [1, 3, 299, 299].  Thus, we need to change the pre-processing.

This commit does not include other required fixes (loss handling;
specialized input normalization).

Co-authored-by: Soumendu Kumar Ghosh <soumendu@ece.iitkgp.ernet.in>
@nzmora
Contributor

nzmora commented Nov 26, 2019

Hi @soumendukrg, @barrh
I felt like we were possibly confusing you with all these words, so I quickly coded what I think is a good solution for the first issue you raised.
I placed it in this branch: https://github.com/NervanaSystems/distiller/tree/inception_support
We can switch to this new branch, or you can rewrite your branch based on this new code - whatever you prefer, but notice that this does not contain all of the fixes from your PR.

It is basically a slightly more generic way to replace is_inception.

Thanks!
Neta

@nzmora
Contributor

nzmora commented Nov 27, 2019

Hi @soumendukrg ,

Some further thoughts (just doing the math, so we have this on record): From looking at the Inception papers I could not deduce the preprocessing they recommend. However, looking at some Keras and TF code, I understand that the standard way of preprocessing Inception input is: x = (x - 0.5) / 0.5 where they assume mean and std-dev of 0.5.

https://github.com/tensorflow/models/blob/30165f86e890cf84dd59d1d3fed6e66ea5b44c78/research/slim/preprocessing/inception_preprocessing.py#L304-L305

Let

inception_pp(x) := (x - 0.5) / 0.5
torchvision_pp(x) = (x - mean) / std
inv_torchvision_pp(x) = x * std + mean

where inv_torchvision_pp is the inverse of torchvision_pp.

Assuming inputs x have already been pre-processed by torchvision_pp(x), we need to apply the inverse of that pre-processing, and then perform the expected Inception pre-processing:

x' = torchvision_pp(x)
x'' = inception_pp(inv_torchvision_pp(x'))

If we expand this we get:

x'' = (x' * std + mean - 0.5) / 0.5 = x' * (std / 0.5) + (mean - 0.5) / 0.5

And this formula matches the code you gave above:

 x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
 x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
 x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
 x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
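A quick numerical check of this equivalence (a standalone sketch, assuming only that torch is installed):

    import torch

    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    x = torch.rand(3, 299, 299)                      # raw image in [0, 1]
    x_tv = (x - mean) / std                          # torchvision_pp
    x_inception_direct = (x - 0.5) / 0.5             # inception_pp applied to the raw image
    x_inception_fixed = x_tv * (std / 0.5) + (mean - 0.5) / 0.5   # the in-network correction

    print(torch.allclose(x_inception_direct, x_inception_fixed, atol=1e-6))  # True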

Despite what I wrote above about the "standard" Inception preprocessing being different from that of other models, I think this is not true for TorchVision's Inception.
For one, the documentation states:

The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]

This is quite explicit.
But the Cadene documentation says the opposite. E.g. about the mean-subtraction they write:

[0.5, 0.5, 0.5] for inception* networks,
[0.485, 0.456, 0.406] for resnet* networks.

And if you reverse-engineer the Cadene code you see that 'inceptionv3' from Cadene uses 'inception_v3' from TorchVision! One of them is incorrect...

So I ran two evaluations on inception_v3 (classifier_compression/compress_classifier.py /datasets/imagenet/ -e --pretrained --arch=inception_v3): once using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225] (this is the torchvision preprocessing) and once using mean = [0.5, 0.5, 0.5] and std = [0.5, 0.5, 0.5] (this is the Inception preprocessing).
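(For reference, the two settings correspond to the following standard torchvision normalization transforms; the rest of the pipeline was left unchanged:)

    import torchvision.transforms as T

    torchvision_normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    inception_normalize = T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])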

| architecture | torchvision preprocessing | Inception preprocessing | TorchVision reported |
|---|---|---|---|
| inception_v3 (torchvision) | Top1: 77.598, Top5: 93.572 | Top1: 77.318, Top5: 93.402 | Top1: 77.45, Top5: 93.56 |
| inceptionv3 (cadene) | Top1: 77.598, Top5: 93.572 | Top1: 77.318, Top5: 93.402 | Top1: 77.45, Top5: 93.56 |
| googlenet (torchvision) | Top1: 64.776, Top5: 86.252 | Top1: 71.912, Top5: 90.728 | Top1: 69.78, Top5: 89.53 |

The TorchVision documentation claims we should be expecting Top1: 77.45 and Top5: 93.56.

This is the second indication that we should be using only the torchvision preprocessing. So I changed the invocation of Inception models to pass transform_input=False explicitly, and I kept only the torchvision preprocessing.
See:
https://github.com/NervanaSystems/distiller/compare/inception_support#diff-0da6e8c73a1ed6ebf39beae6843daa8bR164-R174
and
https://github.com/NervanaSystems/distiller/compare/inception_support#diff-42913a804e812de234e5389dd3dbafd5R173-R176
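In torchvision terms, the relevant change is essentially the following (an illustrative sketch; the branch wires this through Distiller's model-creation code):

    import torchvision.models as models

    # Disable the in-network input re-normalization and rely solely on the
    # standard torchvision preprocessing applied by the data loader.
    model = models.inception_v3(pretrained=True, transform_input=False)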

Cheers
Neta

@soumendukrg
Contributor Author

soumendukrg commented Jan 31, 2020

Hi @nzmora, sorry I took so long to reply. I have reviewed the branch you created, and it definitely handles the Inception models in an efficient way. I am rewriting my code using your branch and will make a separate function for the loss computation, per @barrh's suggestion.

Regarding the normalization, thanks for taking a deeper look into this issue. I got exactly the same results as you did when I used the two different types of preprocessing. In fact, the googlenet results match exactly the ones you reported with Inception preprocessing. I will update the PR soon.

@soumendukrg
Contributor Author

@nzmora, I have updated the PR with all the Inception-related changes you made in your inception_support branch. Additionally, I have created a separate function for the training-loss calculation for Inception networks, with plenty of documentation. I have also merged it with the current master, as you can see. Please review and let me know if I need to make any changes. Thanks.

@nzmora nzmora merged commit 9e78723 into IntelLabs:master Apr 27, 2020
OZA15015 pushed a commit to OZA15015/pruning that referenced this pull request Sep 6, 2020
* Merge pytorch 1.3 commits

This PR is a fix for issue #422.

1. ImageNet models usually use input size [batch, 3, 224, 224], but all Inception models require an input image size of [batch, 3, 299, 299].

2. Inception models have auxiliary branches which contribute to the loss only during training.  The reported classification loss only considers the main classification loss.

3. Inception_V3 normalizes the input inside the network itself.  More details can be found in @soumendukrg's PR #425 [comments](IntelLabs/distiller#425 (comment)).

NOTE: Training using Inception_V3 is only possible on a single GPU as of now. This issue talks about this problem. I have checked and this problem persists in torch 1.3.0:
[inception_v3 of vision 0.3.0 does not fit in DataParallel of torch 1.1.0 #1048](pytorch/vision#1048)

Co-authored-by: Neta Zmora <neta.zmora@intel.com>