Skip to content
This repository has been archived by the owner on Jul 5, 2021. It is now read-only.

Commit

Permalink
Added the DenseASPP model from CVPR 2018
Browse files Browse the repository at this point in the history
  • Loading branch information
GeorgeSeif committed Sep 13, 2018
1 parent 702defd commit 791dbc7
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 9 deletions.
11 changes: 8 additions & 3 deletions builders/model_builder.py
Expand Up @@ -14,9 +14,11 @@
from models.DeepLabV3_plus import build_deeplabv3_plus
from models.AdapNet import build_adaptnet
from models.custom_model import build_custom
from models.DenseASPP import build_dense_aspp

SUPPORTED_MODELS = ["FC-DenseNet56", "FC-DenseNet67", "FC-DenseNet103", "Encoder-Decoder", "Encoder-Decoder-Skip", "RefineNet",
"FRRN-A", "FRRN-B", "MobileUNet", "MobileUNet-Skip", "PSPNet", "GCN", "DeepLabV3", "DeepLabV3_plus", "AdapNet", "custom"]
"FRRN-A", "FRRN-B", "MobileUNet", "MobileUNet-Skip", "PSPNet", "GCN", "DeepLabV3", "DeepLabV3_plus", "AdapNet",
"DenseASPP", "custom"]

SUPPORTED_FRONTENDS = ["ResNet50", "ResNet101", "ResNet152", "MobileNetV2", "InceptionV4"]

Expand All @@ -32,10 +34,10 @@ def build_model(model_name, net_input, num_classes, crop_width, crop_height, fro
print("Preparing the model ...")

if model_name not in SUPPORTED_MODELS:
raise ValueError("The model you selelect is not supported. The following models are currently supported: {0}".format(SUPPORTED_MODELS))
raise ValueError("The model you selected is not supported. The following models are currently supported: {0}".format(SUPPORTED_MODELS))

if frontend not in SUPPORTED_FRONTENDS:
raise ValueError("The frontend you selelect is not supported. The following models are currently supported: {0}".format(SUPPORTED_FRONTENDS))
raise ValueError("The frontend you selected is not supported. The following models are currently supported: {0}".format(SUPPORTED_FRONTENDS))

if "ResNet50" == frontend and not os.path.isfile("models/resnet_v2_50.ckpt"):
download_checkpoints("ResNet50")
Expand Down Expand Up @@ -74,6 +76,9 @@ def build_model(model_name, net_input, num_classes, crop_width, crop_height, fro
elif model_name == "DeepLabV3_plus":
# DeepLabV3+ requires pre-trained ResNet weights
network, init_fn = build_deeplabv3_plus(net_input, preset_model = model_name, frontend=frontend, num_classes=num_classes, is_training=is_training)
elif model_name == "DenseASPP":
# DenseASPP+ requires pre-trained ResNet weights
network, init_fn = build_dense_aspp(net_input, preset_model = model_name, frontend=frontend, num_classes=num_classes, is_training=is_training)
elif model_name == "AdapNet":
network = build_adaptnet(net_input, num_classes=num_classes)
elif model_name == "custom":
Expand Down
14 changes: 8 additions & 6 deletions docs/README.md
Expand Up @@ -4,21 +4,21 @@

## News

**What's New:**
### What's New

- Plotting for every epoch, similar to Tensorboard
- Added the DenseASPP network from CVPR 2018!

- Added support for MobileNetV2 and InceptionV4 frontends!

- Code restructuring. Much easier to expand and debug **You can now set the segmentation model and frontend to use (ResNet50, ResNet101, etc) separately as command line arguments. See the updated usage section below**

**Coming Soon:**
### Coming Soon

- Anything that comes out at CVPR 2018!
- Anything that comes out at CVPR 2018 and ECCV 2018!

- Support for exporting inference graph.
- More network frontends!

Open up an issue to suggest a new feature or improvement!
**Open up an issue to suggest a new feature or improvement!**

## Description
This repository serves as a Semantic Segmentation Suite. The goal is to easily be able to implement, train, and test new Semantic Segmentation models! Complete with the following:
Expand Down Expand Up @@ -73,6 +73,8 @@ to obtain robust features for recognition. The two streams are coupled at the fu

- [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611). This is the DeepLabV3+ network which adds a Decoder module on top of the regular DeepLabV3 model.

- [DenseASPP for Semantic Segmentation in Street Scenes](http://openaccess.thecvf.com/content_cvpr_2018/html/Yang_DenseASPP_for_Semantic_CVPR_2018_paper.html). Combines many different scales using dilated convolution but with dense connections

- Or make your own and plug and play!

**Note:** If you are using any of the networks that rely on a pre-trained ResNet, then you will need to download the pre-trained weights using the provided script. These are currently: PSPNet, RefineNet, DeepLabV3, DeepLabV3+, GCN.
Expand Down
Binary file modified iou_vs_epochs.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
60 changes: 60 additions & 0 deletions models/DenseASPP.py
@@ -0,0 +1,60 @@
import tensorflow as tf
from tensorflow.contrib import slim
from builders import frontend_builder
import os, sys


def Upsampling(inputs,scale):
return tf.image.resize_nearest_neighbor(inputs, size=[tf.shape(inputs)[1]*scale, tf.shape(inputs)[2]*scale])



def DilatedConvBlock(inputs, n_filters, rate=1, kernel_size=[3, 3]):
"""
Basic dilated conv block
Apply successivly BatchNormalization, ReLU nonlinearity, dilated convolution
"""
net = tf.nn.relu(slim.batch_norm(inputs, fused=True))
net = slim.conv2d(net, n_filters, kernel_size, rate=rate, activation_fn=None, normalizer_fn=None)
return net



def build_dense_aspp(inputs, num_classes, preset_model='DenseASPP', frontend="ResNet101", weight_decay=1e-5, is_training=True, pretrained_dir="models"):


logits, end_points, frontend_scope, init_fn = frontend_builder.build_frontend(inputs, frontend, is_training=is_training)

init_features = end_points['pool3']

### First block, rate = 3
d_3_features = DilatedConvBlock(init_features, n_filters=256, kernel_size=[1, 1])
d_3 = DilatedConvBlock(d_3_features, n_filters=64, rate=3, kernel_size=[3, 3])

### Second block, rate = 6
d_4 = tf.concat([init_features, d_3], axis=-1)
d_4 = DilatedConvBlock(d_4, n_filters=256, kernel_size=[1, 1])
d_4 = DilatedConvBlock(d_4, n_filters=64, rate=6, kernel_size=[3, 3])

### Third block, rate = 12
d_5 = tf.concat([init_features, d_3, d_4], axis=-1)
d_5 = DilatedConvBlock(d_5, n_filters=256, kernel_size=[1, 1])
d_5 = DilatedConvBlock(d_5, n_filters=64, rate=12, kernel_size=[3, 3])

### Fourth block, rate = 18
d_6 = tf.concat([init_features, d_3, d_4, d_5], axis=-1)
d_6 = DilatedConvBlock(d_6, n_filters=256, kernel_size=[1, 1])
d_6 = DilatedConvBlock(d_6, n_filters=64, rate=18, kernel_size=[3, 3])

### Fifth block, rate = 24
d_7 = tf.concat([init_features, d_3, d_4, d_5, d_6], axis=-1)
d_7 = DilatedConvBlock(d_7, n_filters=256, kernel_size=[1, 1])
d_7 = DilatedConvBlock(d_7, n_filters=64, rate=24, kernel_size=[3, 3])

full_block = tf.concat([init_features, d_3, d_4, d_5, d_6, d_7], axis=-1)

net = slim.conv2d(full_block, num_classes, [1, 1], activation_fn=None, scope='logits')

net = Upsampling(net, scale=8)

return net, init_fn

0 comments on commit 791dbc7

Please sign in to comment.