Skip to content

Commit

Permalink
making loading of nets more robust
Browse files Browse the repository at this point in the history
  • Loading branch information
achaiah authored and achaiah committed Aug 14, 2019
1 parent 8f2c540 commit 5ff3975
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 33 deletions.
4 changes: 2 additions & 2 deletions pywick/models/segmentation/da_basenets/resnetv1b.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,11 @@ def resnet152_v1b(pretrained=False, **kwargs):
return model


def resnet50_v1s(pretrained=False, model_root='~/.torch/models', **kwargs):
    """Construct a ResNet-50 v1s model (deep-stem variant).

    Args:
        pretrained (bool): if True, load pretrained weights cached under
            ``model_root``.
        model_root (str): directory where pretrained weight files are cached.
        **kwargs: forwarded to the ``ResNetV1b`` constructor.

    Returns:
        The (optionally pretrained) ``ResNetV1b`` model.
    """
    model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, **kwargs)
    if pretrained:
        from .model_store import get_resnet_file
        # map_location='cpu' keeps loading robust on machines without a GPU
        # (GPU-saved checkpoints otherwise fail to deserialize);
        # strict=False tolerates minor key mismatches between checkpoints.
        model.load_state_dict(
            torch.load(get_resnet_file('resnet50', root=model_root), map_location='cpu'),
            strict=False)
    return model


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# this is for encoder part
# resnet encoder
class Resnet(nn.Module):
def __init__(self, orig_resnet):
def __init__(self, orig_resnet, **kwargs):
super(Resnet, self).__init__()

# take pretrained resnet, except AvgPool and FC
Expand Down Expand Up @@ -52,18 +52,15 @@ def forward(self, x, return_feature_maps=False):

# dilated resnet encoder
class ResnetDilated(nn.Module):
def __init__(self, orig_resnet, dilate_scale=8):
def __init__(self, orig_resnet, dilate_scale=8, **kwargs):
super(ResnetDilated, self).__init__()
from functools import partial

if dilate_scale == 8:
orig_resnet.layer3.apply(
partial(self._nostride_dilate, dilate=2))
orig_resnet.layer4.apply(
partial(self._nostride_dilate, dilate=4))
orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2))
orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4))
elif dilate_scale == 16:
orig_resnet.layer4.apply(
partial(self._nostride_dilate, dilate=2))
orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2))

# take pretrained resnet, except AvgPool and FC
self.conv1 = orig_resnet.conv1
Expand Down Expand Up @@ -416,25 +413,25 @@ def build_encoder(self, arch='resnet50_dilated8', fc_dim=512, weights='', **kwar
net_encoder = ResnetDilated(orig_resnet,
dilate_scale=16)
elif arch == 'resnet50':
orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
orig_resnet = resnet.resnet50(**kwargs)
net_encoder = Resnet(orig_resnet)
elif arch == 'resnet50_dilated8':
orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
orig_resnet = resnet.resnet50(**kwargs)
net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
elif arch == 'resnet50_dilated16':
orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
orig_resnet = resnet.resnet50(**kwargs)
net_encoder = ResnetDilated(orig_resnet, dilate_scale=16)
elif arch == 'resnet101':
orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
orig_resnet = resnet.resnet101(**kwargs)
net_encoder = Resnet(orig_resnet)
elif arch == 'resnet101_dilated8':
orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
orig_resnet = resnet.resnet101(**kwargs)
net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
elif arch == 'resnet101_dilated16':
orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
orig_resnet = resnet.resnet101(**kwargs)
net_encoder = ResnetDilated(orig_resnet, dilate_scale=16)
elif arch == 'resnext101':
orig_resnext = resnext.__dict__['resnext101'](pretrained=pretrained)
orig_resnext = resnext.resnext101(**kwargs)
net_encoder = Resnet(orig_resnext) # we can still use class Resnet
else:
raise Exception('Architecture undefined!')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def forward(self, x):

class ResNet(nn.Module):

def __init__(self, block, layers, num_classes=1000):
def __init__(self, block, layers, num_classes=1000, **kwargs):
self.inplanes = 128
super(ResNet, self).__init__()
self.conv1 = conv3x3(3, 64, stride=2)
Expand Down Expand Up @@ -166,19 +166,19 @@ def forward(self, x):



def resnet50(pretrained=True, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on Places
        **kwargs: forwarded to the ``ResNet`` constructor and to ``load_url``
            (e.g. ``model_root`` for the weight-cache directory).

    Returns:
        The (optionally pretrained) ``ResNet`` model.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        # strict=False so checkpoints with extra/missing keys still load.
        model.load_state_dict(load_url(model_urls['resnet50'], **kwargs), strict=False)
    return model


def resnet101(pretrained=False, **kwargs):
def resnet101(pretrained=True, **kwargs):
"""Constructs a ResNet-101 model.
Args:
Expand All @@ -190,12 +190,12 @@ def resnet101(pretrained=False, **kwargs):
return model


def load_url(url, model_dir='~/.torch/models', map_location='cpu', **kwargs):
model_dir = os.path.expanduser(model_dir)
if not os.path.exists(model_dir):
os.makedirs(model_dir)
def load_url(url, model_root='~/.torch/models', map_location='cpu', **kwargs):
model_root = os.path.expanduser(model_root)
if not os.path.exists(model_root):
os.makedirs(model_root)
filename = url.split('/')[-1]
cached_file = os.path.join(model_dir, filename)
cached_file = os.path.join(model_root, filename)
if not os.path.exists(cached_file):
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
urlretrieve(url, cached_file)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def forward(self, x):

class ResNeXt(nn.Module):

def __init__(self, block, layers, groups=32, num_classes=1000):
def __init__(self, block, layers, groups=32, num_classes=1000, **kwargs):
self.inplanes = 128
super(ResNeXt, self).__init__()
self.conv1 = conv3x3(3, 64, stride=2)
Expand Down Expand Up @@ -130,23 +130,23 @@ def forward(self, x):
return x


def resnext101(pretrained=True, **kwargs):
    """Constructs a ResNeXt-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on Places
        **kwargs: forwarded to the ``ResNeXt`` constructor and to ``load_url``
            (e.g. ``model_root`` for the weight-cache directory).

    Returns:
        The (optionally pretrained) ``ResNeXt`` model.
    """
    model = ResNeXt(GroupBottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        # strict=False tolerates key mismatches between checkpoint and model.
        model.load_state_dict(load_url(model_urls['resnext101'], **kwargs), strict=False)
    return model


def load_url(url, model_dir='./models/backbones/pretrained', map_location=None):
if not os.path.exists(model_dir):
os.makedirs(model_dir)
def load_url(url, model_root='/models/pytorch', map_location=None, **kwargs):
if not os.path.exists(model_root):
os.makedirs(model_root)
filename = url.split('/')[-1]
cached_file = os.path.join(model_dir, filename)
cached_file = os.path.join(model_root, filename)
if not os.path.exists(cached_file):
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
urlretrieve(url, cached_file)
Expand Down

0 comments on commit 5ff3975

Please sign in to comment.