From e6bfb0c23a146c78ec96c92bfa9c4416edf17758 Mon Sep 17 00:00:00 2001 From: LaserBit <31342033+LaserBit@users.noreply.github.com> Date: Tue, 22 Dec 2020 18:00:29 +0900 Subject: [PATCH 1/3] Change the classifier input from 2048 to 1000. --- docs/source/transfer_learning.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/transfer_learning.rst b/docs/source/transfer_learning.rst index ba44203721b98..d983adda1e13c 100644 --- a/docs/source/transfer_learning.rst +++ b/docs/source/transfer_learning.rst @@ -58,7 +58,7 @@ Example: Imagenet (computer Vision) self.feature_extractor.eval() # use the pretrained model to classify cifar-10 (10 image classes) - self.classifier = nn.Linear(2048, num_target_classes) + self.classifier = nn.Linear(1000, num_target_classes) def forward(self, x): representations = self.feature_extractor(x) From e692b405901ca8cbb77aab74af31fa7fa3d3d027 Mon Sep 17 00:00:00 2001 From: LaserBit <31342033+LaserBit@users.noreply.github.com> Date: Sat, 26 Dec 2020 16:55:58 +0900 Subject: [PATCH 2/3] Update docs for Imagenet example Thanks @rohitgr7 --- docs/source/transfer_learning.rst | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/docs/source/transfer_learning.rst b/docs/source/transfer_learning.rst index d983adda1e13c..2f75ec664a158 100644 --- a/docs/source/transfer_learning.rst +++ b/docs/source/transfer_learning.rst @@ -52,16 +52,23 @@ Example: Imagenet (computer Vision) class ImagenetTransferLearning(LightningModule): def __init__(self): + super().__init__() + # init a pretrained resnet - num_target_classes = 10 - self.feature_extractor = models.resnet50(pretrained=True) - self.feature_extractor.eval() + backbone = models.resnet50(pretrained=True) + num_filters = backbone.fc.in_features + _layers = list(backbone.children())[:-1] + self.feature_extractor = torch.nn.Sequential(*_layers) # use the pretrained model to classify cifar-10 (10 image classes) - self.classifier = nn.Linear(1000, num_target_classes) + num_target_classes = 10 + self.classifier = nn.Linear(num_filters, num_target_classes) def forward(self, x): - representations = self.feature_extractor(x) + self.feature_extractor.eval() + batch_size = x.size(0) + with torch.no_grad(): + representations = self.feature_extractor(x).view(batch_size, -1) x = self.classifier(representations) ... From 06c2df8361a05c0f614fc8aeada6b86a198071c6 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Tue, 5 Jan 2021 01:22:17 +0530 Subject: [PATCH 3/3] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- docs/source/transfer_learning.rst | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/source/transfer_learning.rst b/docs/source/transfer_learning.rst index 2f75ec664a158..3b8f5b004974e 100644 --- a/docs/source/transfer_learning.rst +++ b/docs/source/transfer_learning.rst @@ -57,8 +57,8 @@ Example: Imagenet (computer Vision) # init a pretrained resnet backbone = models.resnet50(pretrained=True) num_filters = backbone.fc.in_features - _layers = list(backbone.children())[:-1] - self.feature_extractor = torch.nn.Sequential(*_layers) + layers = list(backbone.children())[:-1] + self.feature_extractor = torch.nn.Sequential(*layers) # use the pretrained model to classify cifar-10 (10 image classes) num_target_classes = 10 @@ -66,9 +66,8 @@ Example: Imagenet (computer Vision) def forward(self, x): self.feature_extractor.eval() - batch_size = x.size(0) with torch.no_grad(): - representations = self.feature_extractor(x).view(batch_size, -1) + representations = self.feature_extractor(x).flatten(1) x = self.classifier(representations) ...