
Low Test Accuracy using Inception_V3 #422

Closed
soumendukrg opened this issue Nov 7, 2019 · 5 comments

@soumendukrg
Contributor

soumendukrg commented Nov 7, 2019

I found that the test accuracy of the pretrained Inception_V3 model on the ImageNet dataset in Distiller is not the same as reported in the paper and in the torchvision (PyTorch) documentation. Here are the results:
Top1: 69.538, Top5: 88.654
Expected Top1: 77.45, Top5: 93.56

I am using the following code to evaluate:
$ python compress_classifier.py /path_to_imagenet2012/ -a=inception_v3 --gpu 0 -e --pretrained

However, the accuracy for other networks such as ResNet, DenseNet, and AlexNet is close to the numbers reported in the PyTorch docs. Do I have to specify any additional argument for Inception_V3?

@soumendukrg
Contributor Author

I found that the data_loaders file hard-codes the classification input size for ImageNet as (1, 3, 224, 224). However, PyTorch's inception_v3 requires an input size of (1, 3, 299, 299).

When I modified the file to use the input size that Inception expects, the reported test accuracy is:
Top1: 77.318, Top5: 93.402, which is almost the same as reported for the torchvision classification models.

This file and the files linked to it need to be modified, @nzmora, unless there is another way to fix this issue.
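
For reference, here is a minimal sketch of how one could verify the pretrained inception_v3 accuracy outside Distiller with a 299x299 input pipeline. The 342-resize / 299-crop values follow common torchvision evaluation practice for this model, and the dataset path is a placeholder, so treat this as an assumption-laden check rather than Distiller's actual code:

    # Standalone evaluation sketch (assumes torchvision is installed and
    # /path_to_imagenet2012/val holds the standard ImageNet validation layout).
    import torch
    import torchvision
    from torchvision import transforms

    model = torchvision.models.inception_v3(pretrained=True)
    model.eval().cuda()

    # Inception_V3 expects 299x299 inputs; resize to 342 then center-crop
    # to 299, instead of the usual 256 -> 224 pipeline.
    preprocess = transforms.Compose([
        transforms.Resize(342),
        transforms.CenterCrop(299),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    val_set = torchvision.datasets.ImageFolder('/path_to_imagenet2012/val', preprocess)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=64, num_workers=4)

    correct = total = 0
    with torch.no_grad():
        for images, targets in val_loader:
            outputs = model(images.cuda())
            correct += (outputs.argmax(dim=1).cpu() == targets).sum().item()
            total += targets.size(0)
    print('Top1: {:.3f}%'.format(100.0 * correct / total))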

@nzmora
Contributor

nzmora commented Nov 7, 2019

Hi @soumendukrg,

Thanks - this is an important bug to fix!
Would you care to issue a PR with the fix? It would help speed up delivering the fix to others.
Thanks!
Neta

@soumendukrg
Contributor Author

Thanks for responding. I will submit a PR with the fix. In fact, for training we need to change both the data size and the loss computation in image_classifier.py, at line 501.

    if not early_exit_mode(args):
        # Inception models return a tuple (main logits, aux logits) in
        # training mode; sum the criterion over all heads.
        if isinstance(output, tuple):
            loss = sum(criterion(o, target) for o in output)
        else:
            loss = criterion(output, target)

        # Measure accuracy using only the main classifier's output
        if isinstance(output, tuple):
            classerr.add(output[0].data, target)
        else:
            classerr.add(output.data, target)

This adds the auxiliary-logits loss to the main loss, which improves training, but I read somewhere that we can omit that loss, since during validation/test that branch is essentially cut off from the network. In that case, the above code will change. What do you suggest I do for this part?
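
For comparison, here is a sketch of the alternative used in the PyTorch fine-tuning tutorial, where the auxiliary loss is down-weighted by 0.4 instead of being summed at full weight. The unpacking and the 0.4 factor are assumptions drawn from that tutorial, not Distiller's code:

    # Weighted aux-loss variant (loss = loss_main + 0.4 * loss_aux), as in
    # the PyTorch fine-tuning tutorial; the tuple only appears in training mode.
    if isinstance(output, tuple):
        main_out, aux_out = output[0], output[1]
        loss = criterion(main_out, target) + 0.4 * criterion(aux_out, target)
    else:
        loss = criterion(output, target)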

nzmora added a commit that referenced this issue Apr 27, 2020
* Merge pytorch 1.3 commits

This PR is a fix for issue #422.

1. ImageNet models usually use input size [batch, 3, 224, 224], but all Inception models require an input image size of [batch, 3, 299, 299].

2. Inception models have auxiliary branches that contribute to the loss only during training. The reported classification loss considers only the main branch's loss.

3. Inception_V3 normalizes the input inside the network itself.  More details can be found in @soumendukrg's PR #425 [comments](#425 (comment)).

NOTE: Training with Inception_V3 is currently only possible on a single GPU. The issue linked below describes the problem; I have checked, and it persists in torch 1.3.0:
[inception_v3 of vision 0.3.0 does not fit in DataParallel of torch 1.1.0 #1048](pytorch/vision#1048)

Co-authored-by: Neta Zmora <neta.zmora@intel.com>
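
As an illustration of points 1 and 3 in the commit message above, the sketch below constructs the model the way torchvision's inception_v3 signature allows; the exact flag values Distiller ended up using are in PR #425, so this is an assumption-labeled example, not the merged code:

    # With transform_input=True, inception_v3 re-scales inputs that were
    # normalized with the standard ImageNet mean/std into the normalization
    # the original model was trained with (point 3 above).
    import torch
    from torchvision.models import inception_v3

    model = inception_v3(pretrained=True, transform_input=True, aux_logits=True)
    model.eval()

    # 299x299 input, per point 1 above; dummy batch for a quick shape check.
    dummy = torch.randn(1, 3, 299, 299)
    with torch.no_grad():
        out = model(dummy)   # in eval mode, a plain [1, 1000] logits tensor
    print(out.shape)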
@nzmora
Contributor

nzmora commented Apr 27, 2020

Thanks @soumendukrg for the fix! It took some time, but we finally merged the PR.

@nzmora nzmora closed this as completed Apr 27, 2020
@soumendukrg
Contributor Author

Thanks @nzmora for the merge. I have been working on object detection compression and have one suggestion/fix, which I will put up in a separate issue.

michaelbeale-IL pushed a commit that referenced this issue Apr 24, 2023