diff --git a/builders/model_builder.py b/builders/model_builder.py
index 07cd5fc89..a99b017d0 100644
--- a/builders/model_builder.py
+++ b/builders/model_builder.py
@@ -14,9 +14,11 @@ from models.DeepLabV3_plus import build_deeplabv3_plus
 from models.AdapNet import build_adaptnet
 from models.custom_model import build_custom
+from models.DenseASPP import build_dense_aspp
 
 SUPPORTED_MODELS = ["FC-DenseNet56", "FC-DenseNet67", "FC-DenseNet103", "Encoder-Decoder", "Encoder-Decoder-Skip", "RefineNet",
-                    "FRRN-A", "FRRN-B", "MobileUNet", "MobileUNet-Skip", "PSPNet", "GCN", "DeepLabV3", "DeepLabV3_plus", "AdapNet", "custom"]
+                    "FRRN-A", "FRRN-B", "MobileUNet", "MobileUNet-Skip", "PSPNet", "GCN", "DeepLabV3", "DeepLabV3_plus", "AdapNet",
+                    "DenseASPP", "custom"]
 
 SUPPORTED_FRONTENDS = ["ResNet50", "ResNet101", "ResNet152", "MobileNetV2", "InceptionV4"]
 
@@ -32,10 +34,10 @@ def build_model(model_name, net_input, num_classes, crop_width, crop_height, frontend="ResNet101", is_training=True):
     print("Preparing the model ...")
 
     if model_name not in SUPPORTED_MODELS:
-        raise ValueError("The model you selelect is not supported. The following models are currently supported: {0}".format(SUPPORTED_MODELS))
+        raise ValueError("The model you selected is not supported. The following models are currently supported: {0}".format(SUPPORTED_MODELS))
 
     if frontend not in SUPPORTED_FRONTENDS:
-        raise ValueError("The frontend you selelect is not supported. The following models are currently supported: {0}".format(SUPPORTED_FRONTENDS))
+        raise ValueError("The frontend you selected is not supported. The following frontends are currently supported: {0}".format(SUPPORTED_FRONTENDS))
 
     if "ResNet50" == frontend and not os.path.isfile("models/resnet_v2_50.ckpt"):
         download_checkpoints("ResNet50")
@@ -74,6 +76,9 @@ def build_model(model_name, net_input, num_classes, crop_width, crop_height, frontend="ResNet101", is_training=True):
     elif model_name == "DeepLabV3_plus":
         # DeepLabV3+ requires pre-trained ResNet weights
         network, init_fn = build_deeplabv3_plus(net_input, preset_model = model_name, frontend=frontend, num_classes=num_classes, is_training=is_training)
+    elif model_name == "DenseASPP":
+        # DenseASPP requires pre-trained ResNet weights
+        network, init_fn = build_dense_aspp(net_input, preset_model = model_name, frontend=frontend, num_classes=num_classes, is_training=is_training)
     elif model_name == "AdapNet":
         network = build_adaptnet(net_input, num_classes=num_classes)
     elif model_name == "custom":
diff --git a/docs/README.md b/docs/README.md
index 9d7598a62..842f4b55f 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -4,21 +4,21 @@ ## News
 
-**What's New:**
+### What's New
 
-- Plotting for every epoch, similar to Tensorboard
+- Added the DenseASPP network from CVPR 2018!
 
 - Added support for MobileNetV2 and InceptionV4 frontends!
 
 - Code restructuring. Much easier to expand and debug
 
 **You can now set the segmentation model and frontend to use (ResNet50, ResNet101, etc) separately as command line arguments. See the updated usage section below**
 
-**Coming Soon:**
+### Coming Soon
 
-- Anything that comes out at CVPR 2018!
+- Anything that comes out at CVPR 2018 and ECCV 2018!
 
-- Support for exporting inference graph.
+- More network frontends!
 
-Open up an issue to suggest a new feature or improvement!
+**Open up an issue to suggest a new feature or improvement!**
 
 ## Description
 This repository serves as a Semantic Segmentation Suite. The goal is to easily be able to implement, train, and test new Semantic Segmentation models! Complete with the following:
@@ -73,6 +73,8 @@ to obtain robust features for recognition. The two streams are coupled at the full image resolution using residuals.
 
 - [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611). This is the DeepLabV3+ network which adds a Decoder module on top of the regular DeepLabV3 model.
 
+- [DenseASPP for Semantic Segmentation in Street Scenes](http://openaccess.thecvf.com/content_cvpr_2018/html/Yang_DenseASPP_for_Semantic_CVPR_2018_paper.html). Combines many different scales using dilated convolutions, but with dense connections between them
+
 - Or make your own and plug and play!
 
 **Note:** If you are using any of the networks that rely on a pre-trained ResNet, then you will need to download the pre-trained weights using the provided script. These are currently: PSPNet, RefineNet, DeepLabV3, DeepLabV3+, GCN.
diff --git a/iou_vs_epochs.png b/iou_vs_epochs.png
index 3149dff08..a3dc70061 100644
Binary files a/iou_vs_epochs.png and b/iou_vs_epochs.png differ
diff --git a/models/DenseASPP.py b/models/DenseASPP.py
new file mode 100644
index 000000000..a7e67b9ff
--- /dev/null
+++ b/models/DenseASPP.py
@@ -0,0 +1,62 @@
+import tensorflow as tf
+from tensorflow.contrib import slim
+from builders import frontend_builder
+import os, sys
+
+
+def Upsampling(inputs, scale):
+    return tf.image.resize_nearest_neighbor(inputs, size=[tf.shape(inputs)[1]*scale, tf.shape(inputs)[2]*scale])
+
+
+def DilatedConvBlock(inputs, n_filters, rate=1, kernel_size=[3, 3]):
+    """
+    Basic dilated conv block
+    Successively applies BatchNormalization, a ReLU nonlinearity, and a dilated convolution
+    """
+    net = tf.nn.relu(slim.batch_norm(inputs, fused=True))
+    net = slim.conv2d(net, n_filters, kernel_size, rate=rate, activation_fn=None, normalizer_fn=None)
+    return net
+
+
+def build_dense_aspp(inputs, num_classes, preset_model='DenseASPP', frontend="ResNet101", weight_decay=1e-5, is_training=True, pretrained_dir="models"):
+    """
+    Builds the DenseASPP model: a pre-trained frontend followed by a densely connected
+    cascade of dilated convolutions (rates 3, 6, 12, 18, 24), a 1x1 classification
+    layer, and 8x upsampling back to the input resolution
+    """
+    logits, end_points, frontend_scope, init_fn = frontend_builder.build_frontend(inputs, frontend, is_training=is_training)
+
+    # Stride-8 features from the frontend
+    init_features = end_points['pool3']
+
+    ### First block, rate = 3
+    d_3_features = DilatedConvBlock(init_features, n_filters=256, kernel_size=[1, 1])
+    d_3 = DilatedConvBlock(d_3_features, n_filters=64, rate=3, kernel_size=[3, 3])
+
+    ### Second block, rate = 6
+    d_4 = tf.concat([init_features, d_3], axis=-1)
+    d_4 = DilatedConvBlock(d_4, n_filters=256, kernel_size=[1, 1])
+    d_4 = DilatedConvBlock(d_4, n_filters=64, rate=6, kernel_size=[3, 3])
+
+    ### Third block, rate = 12
+    d_5 = tf.concat([init_features, d_3, d_4], axis=-1)
+    d_5 = DilatedConvBlock(d_5, n_filters=256, kernel_size=[1, 1])
+    d_5 = DilatedConvBlock(d_5, n_filters=64, rate=12, kernel_size=[3, 3])
+
+    ### Fourth block, rate = 18
+    d_6 = tf.concat([init_features, d_3, d_4, d_5], axis=-1)
+    d_6 = DilatedConvBlock(d_6, n_filters=256, kernel_size=[1, 1])
+    d_6 = DilatedConvBlock(d_6, n_filters=64, rate=18, kernel_size=[3, 3])
+
+    ### Fifth block, rate = 24
+    d_7 = tf.concat([init_features, d_3, d_4, d_5, d_6], axis=-1)
+    d_7 = DilatedConvBlock(d_7, n_filters=256, kernel_size=[1, 1])
+    d_7 = DilatedConvBlock(d_7, n_filters=64, rate=24, kernel_size=[3, 3])
+
+    full_block = tf.concat([init_features, d_3, d_4, d_5, d_6, d_7], axis=-1)
+
+    net = slim.conv2d(full_block, num_classes, [1, 1], activation_fn=None, scope='logits')
+
+    net = Upsampling(net, scale=8)
+
+    return net, init_fn
\ No newline at end of file
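For reviewers, a minimal sketch of how the new model would be invoked through the suite's `build_model` entry point. The placeholder shape, class count, and session boilerplate are illustrative assumptions, as is the assumption that `build_model` returns the `(network, init_fn)` pair it assembles in the hunks above:

```python
import tensorflow as tf
from builders import model_builder

# Illustrative TF1-style input placeholder: a batch of 512x512 RGB crops.
net_input = tf.placeholder(tf.float32, shape=[None, 512, 512, 3])

# "DenseASPP" is now a valid model_name; it is routed to build_dense_aspp,
# and init_fn restores the pre-trained ResNet frontend weights downloaded
# by the suite's checkpoint script.
network, init_fn = model_builder.build_model(model_name="DenseASPP",
                                             frontend="ResNet101",
                                             net_input=net_input,
                                             num_classes=32,
                                             crop_width=512,
                                             crop_height=512,
                                             is_training=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    init_fn(sess)  # load the ResNet101 checkpoint before training/inference
```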