Skip to content

Commit

Permalink
fix cyclic import (#350)
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Nov 9, 2020
1 parent 6674ece commit b697407
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 20 deletions.
42 changes: 23 additions & 19 deletions pl_bolts/models/self_supervised/resnets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,22 @@
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet50_bn', 'resnet101',
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
'wide_resnet50_2', 'wide_resnet101_2']


model_urls = {
__all__ = [
'ResNet',
'resnet18',
'resnet34',
'resnet50',
'resnet50_bn',
'resnet101',
'resnet152',
'resnext50_32x4d',
'resnext101_32x8d',
'wide_resnet50_2',
'wide_resnet101_2',
]


MODEL_URLS = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50_bn': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
Expand Down Expand Up @@ -265,7 +275,7 @@ def forward(self, x):
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
model = ResNet(block, layers, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
state_dict = load_state_dict_from_url(MODEL_URLS[arch],
progress=progress)
model.load_state_dict(state_dict)
return model
Expand All @@ -279,8 +289,7 @@ def resnet18(pretrained: bool = False, progress: bool = True, **kwargs):
pretrained: If True, returns a model pre-trained on ImageNet
progress: If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
**kwargs)
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
Expand All @@ -291,8 +300,7 @@ def resnet34(pretrained=False, progress=True, **kwargs):
pretrained: If True, returns a model pre-trained on ImageNet
progress: If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
**kwargs)
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet50(pretrained: bool = False, progress: bool = True, **kwargs):
Expand All @@ -303,8 +311,7 @@ def resnet50(pretrained: bool = False, progress: bool = True, **kwargs):
pretrained: If True, returns a model pre-trained on ImageNet
progress: If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
**kwargs)
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet50_bn(pretrained: bool = False, progress: bool = True, **kwargs):
Expand All @@ -315,8 +322,7 @@ def resnet50_bn(pretrained: bool = False, progress: bool = True, **kwargs):
pretrained: If True, returns a model pre-trained on ImageNet
progress: If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet50_bn', BottleneckBN, [3, 4, 6, 3], pretrained, progress,
**kwargs)
return _resnet('resnet50_bn', BottleneckBN, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet101(pretrained: bool = False, progress: bool = True, **kwargs):
Expand All @@ -327,8 +333,7 @@ def resnet101(pretrained: bool = False, progress: bool = True, **kwargs):
pretrained: If True, returns a model pre-trained on ImageNet
progress: If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet101', BottleneckBN, [3, 4, 23, 3], pretrained, progress,
**kwargs)
return _resnet('resnet101', BottleneckBN, [3, 4, 23, 3], pretrained, progress, **kwargs)


def resnet152(pretrained: bool = False, progress: bool = True, **kwargs):
Expand All @@ -339,8 +344,7 @@ def resnet152(pretrained: bool = False, progress: bool = True, **kwargs):
pretrained: If True, returns a model pre-trained on ImageNet
progress: If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
**kwargs)
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)


def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs):
Expand Down
3 changes: 2 additions & 1 deletion pl_bolts/utils/self_supervised.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from pl_bolts.models.self_supervised import resnets
from pl_bolts.utils.semi_supervised import Identity


def torchvision_ssl_encoder(name, pretrained=False, return_all_feature_maps=False):
from pl_bolts.models.self_supervised import resnets

pretrained_model = getattr(resnets, name)(pretrained=pretrained, return_all_feature_maps=return_all_feature_maps)

pretrained_model.fc = Identity()
Expand Down

0 comments on commit b697407

Please sign in to comment.