Permalink
Browse files

Added support for MobileNetV2

  • Loading branch information...
GeorgeSeif committed Aug 26, 2018
1 parent 4600319 commit e61f5dce3093876f687a879ea4552a49f5d418c9
View
@@ -1,6 +1,9 @@
*.pyc
models/*.pyc
models/*.ckpt
models/*.ckpt*
models/*.pb
models/*txt
models/*.tflite
checkpoints/
Test/*.png
accuracy_vs_epochs.png
View
BIN +4.68 MB Images/semseg.gif
Binary file not shown.
View
@@ -1,11 +1,15 @@
# Semantic Segmentation Suite in TensorFlow
![alt-text-10](https://github.com/GeorgeSeif/Semantic-Segmentation-Suite/blob/master/Images/semseg.gif)
## News
**What's New:**
- Plotting for every epoch, similar to Tensorboard
- Added support for MobileNetV2
- Code restructuring. Much easier to expand and debug **You can now set the segmentation model and frontend to use (ResNet50, ResNet101, etc) separately as command line arguments. See the updated usage section below**
- You can also check out my [Transfer Learning Suite](https://github.com/GeorgeSeif/Transfer-Learning-Suite).
@@ -14,7 +18,7 @@
**Coming Soon:**
- Support NASNet, MobileNet, Dilated ResNet for segmentation models that use classification network front-ends
- Support NASNet, Dilated ResNet for segmentation models that use classification network front-ends
- Anything that comes out at CVPR 2018!
@@ -72,41 +72,37 @@
# ResNet V2
###############################
if args.model == "ResNet50" or args.model == "ALL":
subprocess.check_output(['wget','http://download.tensorflow.org/models/resnet_v2_50_2017_04_14.tar.gz'])
subprocess.check_output(['tar', '-xvf', 'resnet_v2_50_2017_04_14.tar.gz'])
subprocess.check_output(['wget','http://download.tensorflow.org/models/resnet_v2_50_2017_04_14.tar.gz', "-P", "models"])
try:
subprocess.check_output(['mv', 'resnet_v2_50.ckpt', 'models'])
subprocess.check_output(['rm', 'resnet_v2_50_2017_04_14.tar.gz'])
subprocess.check_output(['tar', '-xvf', 'models/resnet_v2_50_2017_04_14.tar.gz', "-C", "models"])
subprocess.check_output(['rm', 'models/resnet_v2_50_2017_04_14.tar.gz'])
except Exception as e:
print(e)
pass
if args.model == "ResNet101" or args.model == "ALL":
subprocess.check_output(['wget','http://download.tensorflow.org/models/resnet_v2_101_2017_04_14.tar.gz'])
subprocess.check_output(['tar', '-xvf', 'resnet_v2_101_2017_04_14.tar.gz'])
subprocess.check_output(['wget','http://download.tensorflow.org/models/resnet_v2_101_2017_04_14.tar.gz', "-P", "models"])
try:
subprocess.check_output(['mv', 'resnet_v2_101.ckpt', 'models'])
subprocess.check_output(['rm', 'resnet_v2_101_2017_04_14.tar.gz'])
subprocess.check_output(['tar', '-xvf', 'models/resnet_v2_101_2017_04_14.tar.gz', "-C", "models"])
subprocess.check_output(['rm', 'models/resnet_v2_101_2017_04_14.tar.gz'])
except Exception as e:
print(e)
pass
if args.model == "ResNet152" or args.model == "ALL":
subprocess.check_output(['wget','http://download.tensorflow.org/models/resnet_v2_152_2017_04_14.tar.gz'])
subprocess.check_output(['tar', '-xvf', 'resnet_v2_152_2017_04_14.tar.gz'])
subprocess.check_output(['wget','http://download.tensorflow.org/models/resnet_v2_152_2017_04_14.tar.gz', "-P", "models"])
try:
subprocess.check_output(['mv', 'resnet_v2_152.ckpt', 'models'])
subprocess.check_output(['rm', 'resnet_v2_152_2017_04_14.tar.gz'])
subprocess.check_output(['tar', '-xvf', 'models/resnet_v2_152_2017_04_14.tar.gz', "-C", "models"])
subprocess.check_output(['rm', 'models/resnet_v2_152_2017_04_14.tar.gz'])
except Exception as e:
print(e)
pass
if args.model == "Mobile" or args.model == "ALL":
subprocess.check_output(['wget','https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz'])
subprocess.check_output(['tar', '-xvf', 'mobilenet_v2_1.4_224.tar.gz'])
if args.model == "MobileNetV2" or args.model == "ALL":
subprocess.check_output(['wget','https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz', "-P", "models"])
try:
subprocess.check_output(['mv', 'mobilenet_v2_1.4_224.ckpt', 'models'])
subprocess.check_output(['rm', 'mobilenet_v2_1.4_224.tar.gz'])
subprocess.check_output(['tar', '-xvf', 'models/mobilenet_v2_1.4_224.tgz', "-C", "models"])
subprocess.check_output(['rm', 'models/mobilenet_v2_1.4_224.tgz'])
except Exception as e:
print(e)
pass
View
BIN +2.27 KB (110%) iou_vs_epochs.png
Binary file not shown.
View
@@ -24,7 +24,7 @@ def download_checkpoints(model_name):
def build_model(model_name, net_input, num_classes, frontend="ResNet101"):
def build_model(model_name, net_input, num_classes, frontend="ResNet101", is_training=True):
# Get the selected model.
# Some of them require pre-trained ResNet
@@ -42,14 +42,16 @@ def build_model(model_name, net_input, num_classes, frontend="ResNet101"):
download_checkpoints("ResNet101")
if "ResNet152" == frontend and not os.path.isfile("models/resnet_v2_152.ckpt"):
download_checkpoints("ResNet152")
# if "MobileNetV2" == frontend and not os.path.isfile("models/mobilenet_v2_1.4_224.ckpt"):
# download_checkpoints("MobileNetV2")
network = None
init_fn = None
if model_name == "FC-DenseNet56" or model_name == "FC-DenseNet67" or model_name == "FC-DenseNet103":
network = build_fc_densenet(net_input, preset_model = model_name, num_classes=num_classes)
elif model_name == "RefineNet":
# RefineNet requires pre-trained ResNet weights
network, init_fn = build_refinenet(net_input, preset_model = model_name, frontend=frontend, num_classes=num_classes)
network, init_fn = build_refinenet(net_input, preset_model = model_name, frontend=frontend, num_classes=num_classes, is_training=is_training)
elif model_name == "FRRN-A" or model_name == "FRRN-B":
network = build_frrn(net_input, preset_model = model_name, num_classes=num_classes)
elif model_name == "Encoder-Decoder" or model_name == "Encoder-Decoder-Skip":
@@ -59,16 +61,16 @@ def build_model(model_name, net_input, num_classes, frontend="ResNet101"):
elif model_name == "PSPNet":
# Image size is required for PSPNet
# PSPNet requires pre-trained ResNet weights
network, init_fn = build_pspnet(net_input, label_size=[args.crop_height, args.crop_width], preset_model = model_name, frontend=frontend, num_classes=num_classes)
network, init_fn = build_pspnet(net_input, label_size=[args.crop_height, args.crop_width], preset_model = model_name, frontend=frontend, num_classes=num_classes, is_training=is_training)
elif model_name == "GCN":
# GCN requires pre-trained ResNet weights
network, init_fn = build_gcn(net_input, preset_model = model_name, frontend=frontend, num_classes=num_classes)
network, init_fn = build_gcn(net_input, preset_model = model_name, frontend=frontend, num_classes=num_classes, is_training=is_training)
elif model_name == "DeepLabV3":
# DeepLabV requires pre-trained ResNet weights
network, init_fn = build_deeplabv3(net_input, preset_model = model_name, frontend=frontend, num_classes=num_classes)
network, init_fn = build_deeplabv3(net_input, preset_model = model_name, frontend=frontend, num_classes=num_classes, is_training=is_training)
elif model_name == "DeepLabV3_plus":
# DeepLabV3+ requires pre-trained ResNet weights
network, init_fn = build_deeplabv3_plus(net_input, preset_model = model_name, frontend=frontend, num_classes=num_classes)
network, init_fn = build_deeplabv3_plus(net_input, preset_model = model_name, frontend=frontend, num_classes=num_classes, is_training=is_training)
elif model_name == "AdapNet":
network = build_adaptnet(net_input, num_classes=num_classes)
elif model_name == "custom":
Oops, something went wrong.

0 comments on commit e61f5dc

Please sign in to comment.