Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…not success yet
  • Loading branch information
RobinDong committed Aug 23, 2019
1 parent b3dd592 commit 1a87991
Show file tree
Hide file tree
Showing 5 changed files with 374 additions and 14 deletions.
14 changes: 7 additions & 7 deletions data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# SSD300 CONFIGS
voc = {
'num_classes': 21,
'lr_steps': (80000, 100000, 120000, 140000),
'max_iter': 140001,
'feature_maps': [38, 19, 10, 5, 3, 1],
'lr_steps': (20000, 40000, 60000, 80000),
'max_iter': 80001,
'feature_maps': [19, 10, 5, 3, 2, 1],
'min_dim': 300,
'steps': [8, 16, 32, 64, 100, 300],
'min_sizes': [30, 60, 111, 162, 213, 264],
'max_sizes': [60, 111, 162, 213, 264, 315],
'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
'steps': [16, 30, 60, 100, 150, 300],
'min_sizes': [45, 90, 135, 180, 225, 270],
'max_sizes': [90, 135, 180, 225, 270, 315],
'aspect_ratios': [[2, 3], [2, 3], [2, 3], [2, 3], [2], [2]],
'variance': [0.1, 0.2],
'clip': True,
'name': 'VOC',
Expand Down
172 changes: 172 additions & 0 deletions nets/mobilenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
from torch import nn


__all__ = ['MobileNetV2', 'mobilenet_v2']


model_urls = {
'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
}


def _make_divisible(v, divisor, min_value=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v:
:param divisor:
:param min_value:
:return:
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v


class ConvBNReLU(nn.Sequential):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
padding = (kernel_size - 1) // 2
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
nn.BatchNorm2d(out_planes),
nn.ReLU6(inplace=True)
)


class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, expand_ratio):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]

hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = self.stride == 1 and inp == oup

layers = []
#if expand_ratio != 1:
# pw
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
layers.extend([
# dw
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
])
self.conv = nn.Sequential(*layers)

def forward(self, x):
if self.use_res_connect:
return x + self.conv(x)
else:
return self.conv(x)


class MobileNetV2(nn.Module):
def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
"""
MobileNet V2 main class
Args:
num_classes (int): Number of classes
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
inverted_residual_setting: Network structure
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
Set to 1 to turn off rounding
"""
super(MobileNetV2, self).__init__()
block = InvertedResidual
input_channel = 32
last_channel = 1280

if inverted_residual_setting is None:
inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]

# only check the first element, assuming user knows t,c,n,s are required
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
raise ValueError("inverted_residual_setting should be non-empty "
"or a 4-element list, got {}".format(inverted_residual_setting))

# building first layer
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features = [ConvBNReLU(3, input_channel, stride=2)]
# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
output_channel = _make_divisible(c * width_mult, round_nearest)
for i in range(n):
stride = s if i == 0 else 1
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
input_channel = output_channel
# building last several layers
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
# make it nn.Sequential
# self.features = nn.Sequential(*features)
self.features = features

# building classifier
self.classifier = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(self.last_channel, num_classes),
)

# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)

def forward(self, x):
endpoints = []
# fetch layer 6 (38x38), 13 (19x19), 17 (10x10) as detection features
layers = []
for index, layer in enumerate(self.features):
x = layer(x)
layers.append(layer)
if index in [13, 17]:
endpoints.append(x)
if index == 17:
break
# x = self.features(x)
# x = x.mean([2, 3])
# x = self.classifier(x)
return x, endpoints


def mobilenet_v2(pretrained=False, progress=True, **kwargs):
"""
Constructs a MobileNetV2 architecture from
`"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = MobileNetV2(**kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
progress=progress)
model.load_state_dict(state_dict)
return model
186 changes: 186 additions & 0 deletions ssd_mobilenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from layers import *
from data import voc, coco, cub
from nets import mobilenet
import os


class SSDMobileNetV2(nn.Module):
"""Single Shot Multibox Architecture
The network is composed of a base MobileNetV2 followed by the
added multibox conv layers. Each multibox layer branches into
1) conv2d for class conf scores
2) conv2d for localization predictions
3) associated priorbox layer to produce default bounding
boxes specific to the layer's feature map size.
See: https://arxiv.org/pdf/1512.02325.pdf for more details.
Args:
phase: (string) Can be "test" or "train"
size: input image size
extras: extra layers that feed to multibox loc and conf layers
head: "multibox head" consists of loc and conf conv layers
"""

def __init__(self, phase, size, extras, head, num_classes):
super(SSDMobileNetV2, self).__init__()
self.phase = phase
self.num_classes = num_classes
switcher = {
2: cub,
21: voc
}
self.cfg = switcher.get(num_classes, coco)
self.priorbox = PriorBox(self.cfg)
self.priors = Variable(self.priorbox.forward(), volatile=True)
self.size = size

# SSD network
self.backbone = mobilenet.MobileNetV2(num_classes=num_classes, width_mult=1.0)
self.norm = L2Norm(96, 20)
self.extras = nn.ModuleList(extras)

self.loc = nn.ModuleList(head[0])
self.conf = nn.ModuleList(head[1])

if phase == 'test':
self.softmax = nn.Softmax(dim=-1)
self.detect = Detect(num_classes, 0, 200, 0.01, 0.45)

def forward(self, x):
"""Applies network layers and ops on input image(s) x.
Args:
x: input image or batch of images. Shape: [batch,3,300,300].
Return:
Depending on phase:
test:
Variable(tensor) of output class label predictions,
confidence score, and corresponding location predictions for
each object detected. Shape: [batch,topk,7]
train:
list of concat outputs from:
1: confidence layers, Shape: [batch*num_priors,num_classes]
2: localization layers, Shape: [batch,num_priors*4]
3: priorbox layers, Shape: [2,num_priors*4]
"""
sources = list()
loc = list()
conf = list()

_, endpoints = self.backbone(x)
sources.append(self.norm(endpoints[0]))
sources.append(endpoints[1])
x = endpoints[1]

# apply extra layers and cache source layer outputs
for k, v in enumerate(self.extras):
x = F.relu(v(x), inplace=True)
#x = v(x)
#if k % 2 == 1:
sources.append(x)

# apply multibox head to source layers
for (x, l, c) in zip(sources, self.loc, self.conf):
loc.append(l(x).permute(0, 2, 3, 1).contiguous())
conf.append(c(x).permute(0, 2, 3, 1).contiguous())

loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
if self.phase == "test":
output = self.detect(
loc.view(loc.size(0), -1, 4), # loc preds
self.softmax(conf.view(conf.size(0), -1,
self.num_classes)), # conf preds
self.priors.type(type(x.data)) # default boxes
)
else:
output = (
loc.view(loc.size(0), -1, 4),
conf.view(conf.size(0), -1, self.num_classes),
self.priors
)
return output

def load_weights(self, base_file):
other, ext = os.path.splitext(base_file)
if ext == '.pkl' or '.pth':
print('Loading weights into state dict...')
self.load_state_dict(torch.load(base_file,
map_location=lambda storage, loc: storage))
print('Finished!')
else:
print('Sorry only .pth and .pkl files supported.')


'''def add_extras(cfg, i, batch_norm=False):
# Extra layers added to VGG for feature scaling
layers = []
in_channels = i
flag = False
for k, v in enumerate(cfg):
if in_channels != 'S':
if v == 'S':
layers += [nn.Conv2d(in_channels, cfg[k + 1],
kernel_size=(1, 3)[flag], stride=2, padding=1)]
else:
layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])]
flag = not flag
in_channels = v
return layers'''

def add_extras(cfg):
layers = []
block = mobilenet.InvertedResidual

layers.append(block(320, 512, 2, 512.0/320.0))
layers.append(block(512, 256, 2, 0.5))
layers.append(block(256, 256, 2, 1))
layers.append(block(256, 128, 2, 0.5))

return layers

def multibox(extra_layers, cfg, num_classes):
loc_layers = []
conf_layers = []
mobilenet_channels = [96, 320]
for k, channel in enumerate(mobilenet_channels):
loc_layers += [nn.Conv2d(channel,
cfg[k] * 4, kernel_size=3, padding=1)]
conf_layers += [nn.Conv2d(channel,
cfg[k] * num_classes, kernel_size=3, padding=1)]
out_channels = [0, 0, 512, 256, 256, 128]
for k, v in enumerate(extra_layers, 2):
loc_layers += [nn.Conv2d(out_channels[k], cfg[k]
* 4, kernel_size=3, padding=1)]
conf_layers += [nn.Conv2d(out_channels[k], cfg[k]
* num_classes, kernel_size=3, padding=1)]
return extra_layers, (loc_layers, conf_layers)


extras = {
'300': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256],
'512': [],
}
mbox = {
'300': [6, 6, 6, 6, 4, 4], # number of boxes per feature map location
'512': [],
}


def build_ssd_mobilenet(phase, size=300, num_classes=21):
if phase != "test" and phase != "train":
print("ERROR: Phase: " + phase + " not recognized")
return
if size != 300:
print("ERROR: You specified size " + repr(size) + ". However, " +
"currently only SSD300 (size=300) is supported!")
return
extras_, head_ = multibox(add_extras(extras[str(size)]),
mbox[str(size)], num_classes)
return SSDMobileNetV2(phase, size, extras_, head_, num_classes)
Loading

0 comments on commit 1a87991

Please sign in to comment.