MobileNet V2 Implementation for segmentation #19

Merged (4 commits) on Feb 21, 2021
Empty file added model/segmentation/__init__.py
Empty file.
21 changes: 15 additions & 6 deletions model/segmentation/deeplabV3_plus.py
@@ -9,10 +9,11 @@


# add the FedML root directory to the python path
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../")))
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../FedML")))
from fedml_api.model.cv.batchnorm_utils import SynchronizedBatchNorm2d
from fedml_api.model.cv.xception import *
from .resnet import *
from .mobilenet_v2 import *

class _ASPPModule(nn.Module):
def __init__(self, inplanes, planes, dilation, BatchNorm):
@@ -156,20 +157,28 @@ def _init_weight(self):
m.bias.data.zero_()

class FeatureExtractor(nn.Module):
def __init__(self, backbone, n_channels, output_stride, BatchNorm, pretrained):
def __init__(self, backbone, n_channels, output_stride, BatchNorm, pretrained, num_classes):
super(FeatureExtractor, self).__init__()
self.backbone = self.build_backbone(backbone=backbone, n_channels=n_channels, output_stride=output_stride, BatchNorm=BatchNorm, pretrained=pretrained)
self.backbone = self.build_backbone(backbone=backbone, n_channels=n_channels, output_stride=output_stride, BatchNorm=BatchNorm, pretrained=pretrained, num_classes=num_classes)

def forward(self, input):
x, low_level_feat = self.backbone(input)
return x, low_level_feat

@staticmethod
def build_backbone(backbone='xception', n_channels=3, output_stride=16, BatchNorm=nn.BatchNorm2d, pretrained=True):
def build_backbone(backbone='xception', n_channels=3, output_stride=16, BatchNorm=nn.BatchNorm2d, pretrained=True, num_classes=21):
if backbone == 'xception':
return AlignedXception(inplanes = n_channels, output_stride = output_stride, BatchNorm=BatchNorm, pretrained=pretrained)
elif backbone == 'resnet':
return ResNet101(output_stride, BatchNorm, pretrained=pretrained)
elif backbone == 'mobilenet':
backbone_model = MobileNetV2Encoder(output_stride=output_stride, batch_norm=BatchNorm, pretrained=pretrained)
backbone_model.low_level_features = backbone_model.features[0:4]
backbone_model.high_level_features = backbone_model.features[4:-1]
backbone_model.features = None
backbone_model.classifier = None
return_layers = {'high_level_features': 'out', 'low_level_features': 'low_level'}
return IntermediateLayerGetter(backbone_model, return_layers=return_layers)
else:
raise NotImplementedError
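
# Editorial sketch (not part of the original diff): the 'mobilenet' branch above
# splits the encoder so that features[0:4] (through the 24-channel stage, roughly
# 1/4 of the input resolution) feeds the decoder's low-level path, while
# features[4:-1] (everything up to, but excluding, the final 1x1 conv to 1280
# channels) provides the 320-channel ASPP input at roughly 1/16 resolution for
# output_stride=16. A minimal usage sketch, assuming `import torch`:
#   feat = FeatureExtractor(backbone='mobilenet', n_channels=3, output_stride=16, BatchNorm=nn.BatchNorm2d, pretrained=False, num_classes=21)
#   x, low_level = feat(torch.randn(1, 3, 513, 513))
#   # x: 320 channels at ~1/16 resolution; low_level: 24 channels at ~1/4 resolution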

@@ -211,7 +220,7 @@ def __init__(self, backbone='resnet', image_size=torch.Size([513, 513]) , nInput
BatchNorm2d = nn.BatchNorm2d

self.n_classes = n_classes
self.feature_extractor = FeatureExtractor(backbone=backbone, n_channels=nInputChannels, output_stride=output_stride, BatchNorm=BatchNorm2d, pretrained=pretrained)
self.feature_extractor = FeatureExtractor(backbone=backbone, n_channels=nInputChannels, output_stride=output_stride, BatchNorm=BatchNorm2d, pretrained=pretrained, num_classes=n_classes)
self.encoder_decoder = EncoderDecoder(backbone=backbone, image_size = torch.Size([513, 513]), output_stride=output_stride, BatchNorm=BatchNorm2d, num_classes=n_classes)

self.freeze_bn = freeze_bn
@@ -273,7 +282,7 @@ def get_10x_lr_params(self):


if __name__ == "__main__":
model = DeepLabV3_plus(nInputChannels=3, n_classes=3, output_stride=16, pretrained=False, _print=True)
model = DeepLabV3_plus(backbone='mobilenet', nInputChannels=3, n_classes=3, output_stride=16, pretrained=False, _print=True)
image = torch.randn(16,3,513,513)
with torch.no_grad():
output = model.forward(image)
224 changes: 224 additions & 0 deletions model/segmentation/mobilenet_v2.py
@@ -0,0 +1,224 @@
import os
import sys
from collections import OrderedDict

from torch import nn
from torchvision.models.utils import load_state_dict_from_url
import torch.nn.functional as F

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../FedML")))
from fedml_api.model.cv.batchnorm_utils import SynchronizedBatchNorm2d

##############################################################################
# The following implementation was taken from the following repo with slight #
# structural modifications to suit our architecture. #
# Source: https://github.com/VainF/DeepLabV3Plus-Pytorch #
##############################################################################

def _make_divisible(v, divisor, min_value=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v: channel count to be rounded
:param divisor: the result will be a multiple of this value (8 here)
:param min_value: lower bound on the result; defaults to divisor
:return: the adjusted channel count
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
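
# Editorial note (not part of the original diff): a few worked values; the 10%
# guard in the function only triggers when rounding down would shrink the
# channel count too much:
#   _make_divisible(24, 8)  -> 24  (already a multiple of 8)
#   _make_divisible(91, 8)  -> 88  (nearest multiple of 8, within 10% of 91)
#   _make_divisible(20, 16) -> 32  (16 would shrink 20 by more than 10%, so round up)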


def fixed_padding(kernel_size, dilation):
kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
pad_total = kernel_size_effective - 1
pad_beg = pad_total // 2
pad_end = pad_total - pad_beg
return pad_beg, pad_end, pad_beg, pad_end
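
# Editorial note (not part of the original diff): for the 3x3 kernels used below,
# fixed_padding(3, 1) == (1, 1, 1, 1) and fixed_padding(3, 2) == (2, 2, 2, 2),
# i.e. just enough padding for "same"-style behaviour once dilation enlarges the
# effective kernel to kernel_size + (kernel_size - 1) * (dilation - 1).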


class ConvBNReLU(nn.Sequential):
def __init__(self, in_planes, out_planes, batch_norm, kernel_size=3, stride=1, dilation=1, groups=1):
# padding is deliberately 0 here (instead of (kernel_size - 1) // 2): InvertedResidual
# pads its input explicitly via fixed_padding/F.pad so dilation is handled correctly
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, 0, dilation=dilation, groups=groups, bias=False),
batch_norm(out_planes),
nn.ReLU6(inplace=True)
)


class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, dilation, expand_ratio, batch_norm):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]

hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = self.stride == 1 and inp == oup

layers = []
if expand_ratio != 1:
# pw
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, batch_norm=batch_norm))

layers.extend([
# dw
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, dilation=dilation, groups=hidden_dim, batch_norm=batch_norm),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
batch_norm(oup),
])
self.conv = nn.Sequential(*layers)

self.input_padding = fixed_padding( 3, dilation )

def forward(self, x):
x_pad = F.pad(x, self.input_padding)
if self.use_res_connect:
return x + self.conv(x_pad)
else:
return self.conv(x_pad)
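
# Editorial sketch (not part of the original diff): a stride-1 block with
# inp == oup keeps the spatial size and takes the residual path, e.g. (assuming
# `import torch`):
#   block = InvertedResidual(24, 24, stride=1, dilation=1, expand_ratio=6, batch_norm=nn.BatchNorm2d)
#   y = block(torch.randn(1, 24, 64, 64))  # y.shape == (1, 24, 64, 64)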


class MobileNetV2(nn.Module):
def __init__(self, output_stride, batch_norm, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, round_nearest=8, pretrained=True):
"""
MobileNet V2 main class
Args:
output_stride (int): Target ratio of input resolution to feature-map resolution (e.g. 8 or 16); later stages switch to dilation once this stride is reached
batch_norm: Normalization layer class to use (e.g. nn.BatchNorm2d or SynchronizedBatchNorm2d)
num_classes (int): Number of classes
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
inverted_residual_setting: Network structure
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
Set to 1 to turn off rounding
pretrained (bool): If True, load torchvision's ImageNet weights for MobileNetV2
"""
super(MobileNetV2, self).__init__()
block = InvertedResidual
input_channel = 32
last_channel = 1280
self.output_stride = output_stride
current_stride = 1
if inverted_residual_setting is None:
inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]

# only check the first element, assuming user knows t,c,n,s are required
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
raise ValueError("inverted_residual_setting should be a non-empty "
"list of 4-element lists, got {}".format(inverted_residual_setting))

# building first layer
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features = [ConvBNReLU(3, input_channel, batch_norm=batch_norm, stride=2)]
current_stride *= 2
dilation=1
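
# Editorial note (not part of the original diff): once current_stride reaches
# output_stride, the remaining stride-2 stages below run with stride 1 and fold
# their stride into the dilation instead. For output_stride=16 this affects the
# 160-channel stage (dilation 2); for output_stride=8 it also affects the
# 64-channel stage, with dilation growing to 4 by the 160-channel stage.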

# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
previous_dilation = dilation
if current_stride == output_stride:
stride = 1
dilation *= s
else:
stride = s
current_stride *= s
output_channel = int(c * width_mult)

for i in range(n):
if i == 0:
features.append(block(input_channel, output_channel, stride, previous_dilation, expand_ratio=t, batch_norm=batch_norm))
else:
features.append(block(input_channel, output_channel, 1, dilation, expand_ratio=t, batch_norm=batch_norm))
input_channel = output_channel
# building last several layers
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, batch_norm=batch_norm))
# make it nn.Sequential
self.features = nn.Sequential(*features)

# building classifier
self.classifier = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(self.last_channel, num_classes),
)
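# Editorial note (not part of the original diff): this head is only exercised
# when MobileNetV2 is used as an ImageNet classifier; the 'mobilenet' branch in
# deeplabV3_plus.py sets backbone_model.classifier = None and reads the feature
# stages through IntermediateLayerGetter instead.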

self._init_weights()

if pretrained:
self._load_pretrained_model()

def forward(self, x):
x = self.features(x)
x = x.mean([2, 3])
x = self.classifier(x)
return x

def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, SynchronizedBatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)

def _load_pretrained_model(self):
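# Editorial note (not part of the original diff): this is torchvision's ImageNet
# checkpoint for MobileNetV2, loaded strictly, so it only matches the default
# 1000-class head; the segmentation path constructs the encoder with the default
# num_classes and later discards the classifier.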
pretrain_dict = load_state_dict_from_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', progress=True)
self.load_state_dict(pretrain_dict)


class IntermediateLayerGetter(nn.ModuleDict):
def __init__(self, model, return_layers):
if not set(return_layers).issubset([name for name, _ in model.named_children()]):
raise ValueError("return_layers are not present in model")

orig_return_layers = return_layers
return_layers = {k: v for k, v in return_layers.items()}
layers = OrderedDict()
for name, module in model.named_children():
layers[name] = module
if name in return_layers:
del return_layers[name]
if not return_layers:
break

super(IntermediateLayerGetter, self).__init__(layers)
self.return_layers = orig_return_layers

def forward(self, x):
out = OrderedDict()
for name, module in self.named_children():
x = module(x)
if name in self.return_layers:
out_name = self.return_layers[name]
out[out_name] = x
return out['out'], out['low_level']
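
# Editorial sketch (not part of the original diff): this mirrors how the
# 'mobilenet' branch of build_backbone in deeplabV3_plus.py wires the encoder
# (assuming `import torch`):
#   m = MobileNetV2Encoder(output_stride=16, batch_norm=nn.BatchNorm2d, pretrained=False)
#   m.low_level_features = m.features[0:4]
#   m.high_level_features = m.features[4:-1]
#   m.features, m.classifier = None, None
#   backbone = IntermediateLayerGetter(m, {'high_level_features': 'out', 'low_level_features': 'low_level'})
#   out, low_level = backbone(torch.randn(1, 3, 513, 513))
#   # out: 320 channels at ~1/16 resolution; low_level: 24 channels at ~1/4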


def MobileNetV2Encoder(**kwargs):
"""
Constructs a MobileNetV2 architecture from
`"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
"""
model = MobileNetV2(**kwargs)
return model
2 changes: 1 addition & 1 deletion model/segmentation/resnet.py
@@ -2,7 +2,7 @@
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../")))
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../FedML")))
from fedml_api.model.cv.batchnorm_utils import SynchronizedBatchNorm2d

class Bottleneck(nn.Module):