|
| 1 | +""" |
| 2 | +2017/11/24 ref:https://github.com/Zehaos/MobileNet/blob/master/nets/mobilenet.py |
| 3 | +""" |
| 4 | + |
| 5 | +import tensorflow as tf |
| 6 | +from tensorflow.python.training import moving_averages |
| 7 | + |
| 8 | +UPDATE_OPS_COLLECTION = "_update_ops_" |
| 9 | + |
| 10 | +# create variable |
| 11 | +def create_variable(name, shape, initializer, |
| 12 | + dtype=tf.float32, trainable=True): |
| 13 | + return tf.get_variable(name, shape=shape, dtype=dtype, |
| 14 | + initializer=initializer, trainable=trainable) |
| 15 | + |
| 16 | +# batchnorm layer |
| 17 | +def bacthnorm(inputs, scope, epsilon=1e-05, momentum=0.99, is_training=True): |
| 18 | + inputs_shape = inputs.get_shape().as_list() |
| 19 | + params_shape = inputs_shape[-1:] |
| 20 | + axis = list(range(len(inputs_shape) - 1)) |
| 21 | + |
| 22 | + with tf.variable_scope(scope): |
| 23 | + beta = create_variable("beta", params_shape, |
| 24 | + initializer=tf.zeros_initializer()) |
| 25 | + gamma = create_variable("gamma", params_shape, |
| 26 | + initializer=tf.ones_initializer()) |
| 27 | + # for inference |
| 28 | + moving_mean = create_variable("moving_mean", params_shape, |
| 29 | + initializer=tf.zeros_initializer(), trainable=False) |
| 30 | + moving_variance = create_variable("moving_variance", params_shape, |
| 31 | + initializer=tf.ones_initializer(), trainable=False) |
| 32 | + if is_training: |
| 33 | + mean, variance = tf.nn.moments(inputs, axes=axis) |
| 34 | + update_move_mean = moving_averages.assign_moving_average(moving_mean, |
| 35 | + mean, decay=momentum) |
| 36 | + update_move_variance = moving_averages.assign_moving_average(moving_variance, |
| 37 | + variance, decay=momentum) |
| 38 | + tf.add_to_collection(UPDATE_OPS_COLLECTION, update_move_mean) |
| 39 | + tf.add_to_collection(UPDATE_OPS_COLLECTION, update_move_variance) |
| 40 | + else: |
| 41 | + mean, variance = moving_mean, moving_variance |
| 42 | + return tf.nn.batch_normalization(inputs, mean, variance, beta, gamma, epsilon) |
| 43 | + |
| 44 | +# depthwise conv2d layer |
| 45 | +def depthwise_conv2d(inputs, scope, filter_size=3, channel_multiplier=1, strides=1): |
| 46 | + inputs_shape = inputs.get_shape().as_list() |
| 47 | + in_channels = inputs_shape[-1] |
| 48 | + with tf.variable_scope(scope): |
| 49 | + filter = create_variable("filter", shape=[filter_size, filter_size, |
| 50 | + in_channels, channel_multiplier], |
| 51 | + initializer=tf.truncated_normal_initializer(stddev=0.01)) |
| 52 | + |
| 53 | + return tf.nn.depthwise_conv2d(inputs, filter, strides=[1, strides, strides, 1], |
| 54 | + padding="SAME", rate=[1, 1]) |
| 55 | + |
| 56 | +# conv2d layer |
| 57 | +def conv2d(inputs, scope, num_filters, filter_size=1, strides=1): |
| 58 | + inputs_shape = inputs.get_shape().as_list() |
| 59 | + in_channels = inputs_shape[-1] |
| 60 | + with tf.variable_scope(scope): |
| 61 | + filter = create_variable("filter", shape=[filter_size, filter_size, |
| 62 | + in_channels, num_filters], |
| 63 | + initializer=tf.truncated_normal_initializer(stddev=0.01)) |
| 64 | + return tf.nn.conv2d(inputs, filter, strides=[1, strides, strides, 1], |
| 65 | + padding="SAME") |
| 66 | + |
| 67 | +# avg pool layer |
| 68 | +def avg_pool(inputs, pool_size, scope): |
| 69 | + with tf.variable_scope(scope): |
| 70 | + return tf.nn.avg_pool(inputs, [1, pool_size, pool_size, 1], |
| 71 | + strides=[1, pool_size, pool_size, 1], padding="VALID") |
| 72 | + |
| 73 | +# fully connected layer |
| 74 | +def fc(inputs, n_out, scope, use_bias=True): |
| 75 | + inputs_shape = inputs.get_shape().as_list() |
| 76 | + n_in = inputs_shape[-1] |
| 77 | + with tf.variable_scope(scope): |
| 78 | + weight = create_variable("weight", shape=[n_in, n_out], |
| 79 | + initializer=tf.random_normal_initializer(stddev=0.01)) |
| 80 | + if use_bias: |
| 81 | + bias = create_variable("bias", shape=[n_out,], |
| 82 | + initializer=tf.zeros_initializer()) |
| 83 | + return tf.nn.xw_plus_b(inputs, weight, bias) |
| 84 | + return tf.matmul(inputs, weight) |
| 85 | + |
| 86 | + |
| 87 | +class MobileNet(object): |
| 88 | + def __init__(self, inputs, num_classes=1000, is_training=True, |
| 89 | + width_multiplier=1, scope="MobileNet"): |
| 90 | + """ |
| 91 | + The implement of MobileNet(ref:https://arxiv.org/abs/1704.04861) |
| 92 | + :param inputs: 4-D Tensor of [batch_size, height, width, channels] |
| 93 | + :param num_classes: number of classes |
| 94 | + :param is_training: Boolean, whether or not the model is training |
| 95 | + :param width_multiplier: float, controls the size of model |
| 96 | + :param scope: Optional scope for variables |
| 97 | + """ |
| 98 | + self.inputs = inputs |
| 99 | + self.num_classes = num_classes |
| 100 | + self.is_training = is_training |
| 101 | + self.width_multiplier = width_multiplier |
| 102 | + |
| 103 | + # construct model |
| 104 | + with tf.variable_scope(scope): |
| 105 | + # conv1 |
| 106 | + net = conv2d(inputs, "conv_1", round(32 * width_multiplier), filter_size=3, |
| 107 | + strides=2) # ->[N, 112, 112, 32] |
| 108 | + net = tf.nn.relu(bacthnorm(net, "conv_1/bn", is_training=self.is_training)) |
| 109 | + net = self._depthwise_separable_conv2d(net, 64, self.width_multiplier, |
| 110 | + "ds_conv_2") # ->[N, 112, 112, 64] |
| 111 | + net = self._depthwise_separable_conv2d(net, 128, self.width_multiplier, |
| 112 | + "ds_conv_3", downsample=True) # ->[N, 56, 56, 128] |
| 113 | + net = self._depthwise_separable_conv2d(net, 128, self.width_multiplier, |
| 114 | + "ds_conv_4") # ->[N, 56, 56, 128] |
| 115 | + net = self._depthwise_separable_conv2d(net, 256, self.width_multiplier, |
| 116 | + "ds_conv_5", downsample=True) # ->[N, 28, 28, 256] |
| 117 | + net = self._depthwise_separable_conv2d(net, 256, self.width_multiplier, |
| 118 | + "ds_conv_6") # ->[N, 28, 28, 256] |
| 119 | + net = self._depthwise_separable_conv2d(net, 512, self.width_multiplier, |
| 120 | + "ds_conv_7", downsample=True) # ->[N, 14, 14, 512] |
| 121 | + net = self._depthwise_separable_conv2d(net, 512, self.width_multiplier, |
| 122 | + "ds_conv_8") # ->[N, 14, 14, 512] |
| 123 | + net = self._depthwise_separable_conv2d(net, 512, self.width_multiplier, |
| 124 | + "ds_conv_9") # ->[N, 14, 14, 512] |
| 125 | + net = self._depthwise_separable_conv2d(net, 512, self.width_multiplier, |
| 126 | + "ds_conv_10") # ->[N, 14, 14, 512] |
| 127 | + net = self._depthwise_separable_conv2d(net, 512, self.width_multiplier, |
| 128 | + "ds_conv_11") # ->[N, 14, 14, 512] |
| 129 | + net = self._depthwise_separable_conv2d(net, 512, self.width_multiplier, |
| 130 | + "ds_conv_12") # ->[N, 14, 14, 512] |
| 131 | + net = self._depthwise_separable_conv2d(net, 1024, self.width_multiplier, |
| 132 | + "ds_conv_13", downsample=True) # ->[N, 7, 7, 1024] |
| 133 | + net = self._depthwise_separable_conv2d(net, 1024, self.width_multiplier, |
| 134 | + "ds_conv_14") # ->[N, 7, 7, 1024] |
| 135 | + net = avg_pool(net, 7, "avg_pool_15") |
| 136 | + net = tf.squeeze(net, [1, 2], name="SpatialSqueeze") |
| 137 | + self.logits = fc(net, self.num_classes, "fc_16") |
| 138 | + self.predictions = tf.nn.softmax(self.logits) |
| 139 | + |
| 140 | + def _depthwise_separable_conv2d(self, inputs, num_filters, width_multiplier, |
| 141 | + scope, downsample=False): |
| 142 | + """depthwise separable convolution 2D function""" |
| 143 | + num_filters = round(num_filters * width_multiplier) |
| 144 | + strides = 2 if downsample else 1 |
| 145 | + |
| 146 | + with tf.variable_scope(scope): |
| 147 | + # depthwise conv2d |
| 148 | + dw_conv = depthwise_conv2d(inputs, "depthwise_conv", strides=strides) |
| 149 | + # batchnorm |
| 150 | + bn = bacthnorm(dw_conv, "dw_bn", is_training=self.is_training) |
| 151 | + # relu |
| 152 | + relu = tf.nn.relu(bn) |
| 153 | + # pointwise conv2d (1x1) |
| 154 | + pw_conv = conv2d(relu, "pointwise_conv", num_filters) |
| 155 | + # bn |
| 156 | + bn = bacthnorm(pw_conv, "pw_bn", is_training=self.is_training) |
| 157 | + return tf.nn.relu(bn) |
| 158 | + |
| 159 | +if __name__ == "__main__": |
| 160 | + # test data |
| 161 | + inputs = tf.random_normal(shape=[4, 224, 224, 3]) |
| 162 | + mobileNet = MobileNet(inputs) |
| 163 | + writer = tf.summary.FileWriter("./logs", graph=tf.get_default_graph()) |
| 164 | + init = tf.global_variables_initializer() |
| 165 | + with tf.Session() as sess: |
| 166 | + sess.run(init) |
| 167 | + pred = sess.run(mobileNet.predictions) |
| 168 | + print(pred.shape) |
| 169 | + |
0 commit comments