change builder to static method
hangzhaomit committed Aug 1, 2019
1 parent 2b8f2d4 commit 90f9d6d
Showing 5 changed files with 37 additions and 40 deletions.
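
The change below turns ModelBuilder's helpers (weights_init, build_encoder, build_decoder) into @staticmethods, so call sites no longer instantiate a throwaway builder object. A minimal sketch of the call-site change (argument values here are illustrative; the real calls pass the cfg fields shown in the diffs):

    # before: an instance was created only to reach its methods
    builder = ModelBuilder()
    net_encoder = builder.build_encoder(arch='resnet50dilated', fc_dim=2048)

    # after: the class serves as a plain namespace; no instance needed
    net_encoder = ModelBuilder.build_encoder(arch='resnet50dilated', fc_dim=2048)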
5 changes: 2 additions & 3 deletions eval.py
@@ -111,12 +111,11 @@ def main(cfg, gpu):
     torch.cuda.set_device(gpu)
 
     # Network Builders
-    builder = ModelBuilder()
-    net_encoder = builder.build_encoder(
+    net_encoder = ModelBuilder.build_encoder(
         arch=cfg.MODEL.arch_encoder.lower(),
         fc_dim=cfg.MODEL.fc_dim,
         weights=cfg.MODEL.weights_encoder)
-    net_decoder = builder.build_decoder(
+    net_decoder = ModelBuilder.build_decoder(
         arch=cfg.MODEL.arch_decoder.lower(),
         fc_dim=cfg.MODEL.fc_dim,
         num_class=cfg.DATASET.num_class,
5 changes: 2 additions & 3 deletions eval_multipro.py
@@ -101,12 +101,11 @@ def worker(cfg, gpu_id, start_idx, end_idx, result_queue):
         num_workers=2)
 
     # Network Builders
-    builder = ModelBuilder()
-    net_encoder = builder.build_encoder(
+    net_encoder = ModelBuilder.build_encoder(
         arch=cfg.MODEL.arch_encoder.lower(),
         fc_dim=cfg.MODEL.fc_dim,
         weights=cfg.MODEL.weights_encoder)
-    net_decoder = builder.build_decoder(
+    net_decoder = ModelBuilder.build_decoder(
         arch=cfg.MODEL.arch_decoder.lower(),
         fc_dim=cfg.MODEL.fc_dim,
         num_class=cfg.DATASET.num_class,
57 changes: 29 additions & 28 deletions models/models.py
@@ -3,6 +3,7 @@
 import torchvision
 from . import resnet, resnext, mobilenet, hrnet
 from lib.nn import SynchronizedBatchNorm2d
+BatchNorm2d = SynchronizedBatchNorm2d
 
 
 class SegmentationModuleBase(nn.Module):
@@ -47,23 +48,10 @@ def forward(self, feed_dict, *, segSize=None):
         return pred
 
 
-def conv3x3(in_planes, out_planes, stride=1, has_bias=False):
-    "3x3 convolution with padding"
-    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
-                     padding=1, bias=has_bias)
-
-
-def conv3x3_bn_relu(in_planes, out_planes, stride=1):
-    return nn.Sequential(
-        conv3x3(in_planes, out_planes, stride),
-        SynchronizedBatchNorm2d(out_planes),
-        nn.ReLU(inplace=True),
-    )
-
-
-class ModelBuilder():
+class ModelBuilder:
     # custom weights initialization
-    def weights_init(self, m):
+    @staticmethod
+    def weights_init(m):
         classname = m.__class__.__name__
         if classname.find('Conv') != -1:
             nn.init.kaiming_normal_(m.weight.data)
@@ -73,7 +61,8 @@ def weights_init(self, m):
         #elif classname.find('Linear') != -1:
         #    m.weight.data.normal_(0.0, 0.0001)
 
-    def build_encoder(self, arch='resnet50dilated', fc_dim=512, weights=''):
+    @staticmethod
+    def build_encoder(arch='resnet50dilated', fc_dim=512, weights=''):
         pretrained = True if len(weights) == 0 else False
         arch = arch.lower()
         if arch == 'mobilenetv2dilated':
@@ -113,14 +102,16 @@ def build_encoder(self, arch='resnet50dilated', fc_dim=512, weights=''):
         else:
             raise Exception('Architecture undefined!')
 
-        # net_encoder.apply(self.weights_init)
+        # encoders are usually pretrained
+        # net_encoder.apply(ModelBuilder.weights_init)
         if len(weights) > 0:
             print('Loading weights for net_encoder')
             net_encoder.load_state_dict(
                 torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
         return net_encoder
 
-    def build_decoder(self, arch='ppm_deepsup',
+    @staticmethod
+    def build_decoder(arch='ppm_deepsup',
                       fc_dim=512, num_class=150,
                       weights='', use_softmax=False):
         arch = arch.lower()
@@ -159,14 +150,24 @@ def build_decoder(self, arch='ppm_deepsup',
         else:
             raise Exception('Architecture undefined!')
 
-        net_decoder.apply(self.weights_init)
+        net_decoder.apply(ModelBuilder.weights_init)
         if len(weights) > 0:
             print('Loading weights for net_decoder')
             net_decoder.load_state_dict(
                 torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
         return net_decoder
 
 
+def conv3x3_bn_relu(in_planes, out_planes, stride=1):
+    "3x3 convolution + BN + relu"
+    return nn.Sequential(
+        nn.Conv2d(in_planes, out_planes, kernel_size=3,
+                  stride=stride, padding=1, bias=False),
+        BatchNorm2d(out_planes),
+        nn.ReLU(inplace=True),
+    )
+
+
 class Resnet(nn.Module):
     def __init__(self, orig_resnet):
         super(Resnet, self).__init__()
@@ -397,15 +398,15 @@ def __init__(self, num_class=150, fc_dim=4096,
             self.ppm.append(nn.Sequential(
                 nn.AdaptiveAvgPool2d(scale),
                 nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
-                SynchronizedBatchNorm2d(512),
+                BatchNorm2d(512),
                 nn.ReLU(inplace=True)
             ))
         self.ppm = nn.ModuleList(self.ppm)
 
         self.conv_last = nn.Sequential(
             nn.Conv2d(fc_dim+len(pool_scales)*512, 512,
                       kernel_size=3, padding=1, bias=False),
-            SynchronizedBatchNorm2d(512),
+            BatchNorm2d(512),
             nn.ReLU(inplace=True),
             nn.Dropout2d(0.1),
             nn.Conv2d(512, num_class, kernel_size=1)
@@ -446,7 +447,7 @@ def __init__(self, num_class=150, fc_dim=4096,
             self.ppm.append(nn.Sequential(
                 nn.AdaptiveAvgPool2d(scale),
                 nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
-                SynchronizedBatchNorm2d(512),
+                BatchNorm2d(512),
                 nn.ReLU(inplace=True)
             ))
         self.ppm = nn.ModuleList(self.ppm)
@@ -455,7 +456,7 @@ def __init__(self, num_class=150, fc_dim=4096,
         self.conv_last = nn.Sequential(
             nn.Conv2d(fc_dim+len(pool_scales)*512, 512,
                       kernel_size=3, padding=1, bias=False),
-            SynchronizedBatchNorm2d(512),
+            BatchNorm2d(512),
             nn.ReLU(inplace=True),
             nn.Dropout2d(0.1),
             nn.Conv2d(512, num_class, kernel_size=1)
@@ -511,7 +512,7 @@ def __init__(self, num_class=150, fc_dim=4096,
             self.ppm_pooling.append(nn.AdaptiveAvgPool2d(scale))
             self.ppm_conv.append(nn.Sequential(
                 nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
-                SynchronizedBatchNorm2d(512),
+                BatchNorm2d(512),
                 nn.ReLU(inplace=True)
             ))
         self.ppm_pooling = nn.ModuleList(self.ppm_pooling)
@@ -520,16 +521,16 @@ def __init__(self, num_class=150, fc_dim=4096,
 
         # FPN Module
         self.fpn_in = []
-        for fpn_inplane in fpn_inplanes[:-1]: # skip the top layer
+        for fpn_inplane in fpn_inplanes[:-1]:  # skip the top layer
             self.fpn_in.append(nn.Sequential(
                 nn.Conv2d(fpn_inplane, fpn_dim, kernel_size=1, bias=False),
-                SynchronizedBatchNorm2d(fpn_dim),
+                BatchNorm2d(fpn_dim),
                 nn.ReLU(inplace=True)
             ))
         self.fpn_in = nn.ModuleList(self.fpn_in)
 
         self.fpn_out = []
-        for i in range(len(fpn_inplanes) - 1): # skip the top layer
+        for i in range(len(fpn_inplanes) - 1):  # skip the top layer
             self.fpn_out.append(nn.Sequential(
                 conv3x3_bn_relu(fpn_dim, fpn_dim, 1),
             ))
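
Since weights_init is now a @staticmethod, it can be passed directly to nn.Module.apply, which walks every submodule and calls the function on each one; no builder instance (and no self) is involved. The new BatchNorm2d = SynchronizedBatchNorm2d alias likewise lets the rest of the file swap batch-norm implementations by editing a single line. A minimal, runnable sketch of the staticmethod-plus-apply pattern (Builder here is a stand-in, not the repo's class):

    import torch.nn as nn

    class Builder:
        @staticmethod
        def weights_init(m):
            # apply() hands every submodule to this function in turn
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight.data)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1.)
                m.bias.data.fill_(1e-4)

    net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))
    net.apply(Builder.weights_init)  # no Builder() instance required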
5 changes: 2 additions & 3 deletions test.py
@@ -95,12 +95,11 @@ def main(cfg, gpu):
     torch.cuda.set_device(gpu)
 
     # Network Builders
-    builder = ModelBuilder()
-    net_encoder = builder.build_encoder(
+    net_encoder = ModelBuilder.build_encoder(
         arch=cfg.MODEL.arch_encoder,
         fc_dim=cfg.MODEL.fc_dim,
         weights=cfg.MODEL.weights_encoder)
-    net_decoder = builder.build_decoder(
+    net_decoder = ModelBuilder.build_decoder(
         arch=cfg.MODEL.arch_decoder,
         fc_dim=cfg.MODEL.fc_dim,
         num_class=cfg.DATASET.num_class,
5 changes: 2 additions & 3 deletions train.py
@@ -141,12 +141,11 @@ def adjust_learning_rate(optimizers, cur_iter, cfg):
 
 def main(cfg, gpus):
     # Network Builders
-    builder = ModelBuilder()
-    net_encoder = builder.build_encoder(
+    net_encoder = ModelBuilder.build_encoder(
         arch=cfg.MODEL.arch_encoder.lower(),
         fc_dim=cfg.MODEL.fc_dim,
         weights=cfg.MODEL.weights_encoder)
-    net_decoder = builder.build_decoder(
+    net_decoder = ModelBuilder.build_decoder(
         arch=cfg.MODEL.arch_decoder.lower(),
         fc_dim=cfg.MODEL.fc_dim,
         num_class=cfg.DATASET.num_class,
