Add ICNet for image segmentation. #975

Merged 3 commits on Jun 21, 2018

4luojing
Contributor

No description provided.

@qingqing01 changed the title from "Add icnet." to "Add ICNet." on Jun 11, 2018
@qingqing01 changed the title from "Add ICNet." to "Add ICNet for image segmentation." on Jun 11, 2018
import numpy as np
import paddle.v2 as paddle

DATA_PATH = "../../data/cityscape"
Contributor

Change the default path to "./data/cityscape".

yield image, label
return reader

def load(image_label):
Contributor

Every function should have a brief description.

return image_labels[0].astype("float32"), label_sub1, mask_sub1.astype("int32"), label_sub2, mask_sub2.astype("int32"), label_sub4, mask_sub4.astype("int32")


def train(batch_size=32, random_mirror=False, random_scaling=False):
Contributor

Publicly exposed functions need detailed usage documentation.

_, _, sub124_out = network.icnet(args, images, num_classes, np.array(data_shape[1:]).astype("float32"), is_test=True)
predict = fluid.layers.upsampling_bilinear2d(sub124_out, out_shape=data_shape[1:3])
predict = fluid.layers.transpose(predict, perm=[0,2,3,1])
fluid.layers.Print(predict, summarize=10)
Contributor

Remove the Print statement.

import sys


def conv(input, k_h, k_w, c_o, s_h, s_w, relu=True, padding="VALID", group=1, biased=None, name=None, print_pad=False):
Contributor

If group is not used, consider removing it from the parameter list.

import sys


def conv(input, k_h, k_w, c_o, s_h, s_w, relu=True, padding="VALID", group=1, biased=None, name=None, print_pad=False):
Contributor

print_op also seems to be redundant.


def interp(input, out_shape):
out_shape = list(out_shape.astype("int32"))
return fluid.layers.upsampling_bilinear2d(input, out_shape=out_shape)
Contributor

The latest Paddle API has renamed upsampling_bilinear2d to resize_bilinear.

conv3_3_1_1_increase = conv(conv3_3_3_3_bn, 1, 1, 256, 1, 1, biased=False, relu=False, name="conv3_3_1_1_increase")
conv3_3_1_1_increase_bn = bn(conv3_3_1_1_increase, relu=False, name="conv3_3_1_1_increase_bn", is_test=is_test)


Contributor

Too many blank lines.

@lgone2000
Could we create a subdirectory under image-segmentatation, to make future extension easier? by guoyi

@qingqing01
Collaborator

@lgone2000 good idea.




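This article uses the Cityscape dataset; please register and download it from the [Cityscape website](). After downloading, process the data using the instructions and tools [here](https://github.com/mcordts/cityscapesScripts).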
Contributor

Fill in the complete URL of the official website.

place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
fluid.io.load_persistables(exe, args.model_path)
Contributor

For the evaluator, loading only the parameters is enough.

exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
fluid.io.load_persistables(exe, args.model_path)
print "loaded model from: %s" % args.model_path
Contributor

Check whether model_path exists.
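
A minimal sketch of the kind of check requested here, assuming it is enough to test the path with the standard library before loading. The helper name `check_model_path` is hypothetical and does not come from the PR.

```python
import os
import sys


def check_model_path(model_path):
    """Return True if model_path exists on disk, otherwise report it and return False.

    Hypothetical helper for illustration; not part of the PR.
    """
    if not model_path or not os.path.exists(model_path):
        sys.stderr.write("model path not found: %s\n" % model_path)
        return False
    return True
```

The script could call this before `fluid.io.load_persistables(exe, args.model_path)` and exit early when it returns False.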

fetch_list=fetch_vars)
out_wrong += result[1]
out_right += result[2]
print "count: %s; current iou: %.3f;" % (count, result[0])
Contributor

Suggest printing the output on the same line.

add_arg('use_gpu', bool, True, "Whether use GPU to test.")
# yapf: enable

def cal_mean_iou(wrong, cerroct, num_classes):
Contributor

cerroct -> correct


def cal_mean_iou(wrong, cerroct, num_classes):
sum = wrong + cerroct
return (cerroct.astype("float64") / sum).sum() / num_classes
Contributor

If class i never appears, sum[i] is 0, so you cannot divide directly; the final average should also divide by the number of classes that actually appear.

Also, num_classes == len(num_classes), so the num_classes parameter can be removed.
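
One possible way to address both points, as a sketch rather than the code that was eventually merged. It assumes `wrong` and `correct` are per-class count arrays of equal length, and it reads the remark about `num_classes` as meaning the class count equals the array length. It keeps the original per-class ratio `correct / (wrong + correct)`.

```python
import numpy as np


def cal_mean_iou(wrong, correct):
    """Average correct / (wrong + correct) over the classes that actually occur."""
    wrong = np.asarray(wrong, dtype="float64")
    correct = np.asarray(correct, dtype="float64")
    total = wrong + correct
    present = total > 0  # classes with at least one counted pixel
    if not present.any():
        return 0.0  # no class occurred, so there is nothing to average
    # Divide only where the denominator is non-zero, then average over those classes.
    return float((correct[present] / total[present]).mean())
```

Masking before dividing avoids the division by zero, and taking the mean over the masked entries divides by the number of classes that appeared rather than by a fixed `num_classes`.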

add_arg('model_path', str, "", "Model path.")
add_arg('images_list', str, "", "List file with images to be infered.")
add_arg('images_path', str, "", "The images path.")
add_arg('out_path', str, "", "Output path.")
Contributor

Give out_path a default value, and give the other paths default values as well.
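
As an illustration only, the same options with defaults could look like the following when written with plain `argparse`. The default values (`./model`, `./images_list.txt`, `./images`, `./output`) are hypothetical placeholders, and `build_parser` is not a function from the PR, which uses its own `add_arg` helper.

```python
import argparse


def build_parser():
    """Build an argument parser whose path options all have defaults (illustrative values)."""
    parser = argparse.ArgumentParser(description="ICNet inference options (sketch).")
    # All default paths below are assumptions chosen for illustration.
    parser.add_argument("--model_path", type=str, default="./model", help="Model path.")
    parser.add_argument("--images_list", type=str, default="./images_list.txt",
                        help="List file with images to be inferred.")
    parser.add_argument("--images_path", type=str, default="./images", help="The images path.")
    parser.add_argument("--out_path", type=str, default="./output", help="Output path.")
    return parser
```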

,[119, 10, 32]]
# 18 = bicycle

def color(input):
Contributor

Describe what this function does.

place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
fluid.io.load_persistables(exe, args.model_path)
Contributor

Loading the parameters is enough.

image_t = fluid.core.LoDTensor()
image_t.set(img, place)
result = exe.run(feed={"image": image_t}, fetch_list=[predict])
cv2.imwrite(args.out_path + "/" + filename + "_result.png", color(result[0]))
Contributor

If args.out_path does not exist, it needs to be created.
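
A small sketch of creating the output directory before writing results, using only the standard library. The helper name `ensure_dir` is hypothetical and not taken from the PR.

```python
import os


def ensure_dir(path):
    """Create path (including parent directories) if it is missing; return path."""
    if not os.path.isdir(path):
        os.makedirs(path)
    return path
```

Calling `ensure_dir(args.out_path)` once before the inference loop would make the later `cv2.imwrite` calls safe against a missing directory.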

- https://github.com/hszhao/ICNet
- https://github.com/hellochick/ICNet-tensorflow
- https://github.com/mcordts/cityscapesScripts
- https://zhuanlan.zhihu.com/p/26653218
Collaborator

Citing Zhihu as a reference is not rigorous enough. How about keeping just the one paper?

<p align="center">
<img src="images/train_loss.png" width="620" hspace='10'/> <br/>
<strong>Figure 2</strong>
</p>
Collaborator

Are there mean IoU results?


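## Introduction

Image Cascade Network (ICNet) keeps real-time performance while greatly improving accuracy over earlier fast semantic segmentation methods such as SQ, SegNet, and ENet. It is comparable to Deeplab v2 and makes practical deployment of semantic segmentation feasible.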
Collaborator

This Chinese paragraph is identical to the Chinese text in the Zhihu article linked later; it needs to be rewritten.

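## Introduction

Image Cascade Network (ICNet) keeps real-time performance while greatly improving accuracy over earlier fast semantic segmentation methods such as SQ, SegNet, and ENet. It is comparable to Deeplab v2 and makes practical deployment of semantic segmentation feasible.
ICNet combines the efficiency of processing low-resolution images with the high inference quality of high-resolution images. The main idea is to pass the low-resolution image through the full semantic network to produce a coarse prediction, then use the cascade fusion unit proposed in the paper to bring in features from the medium- and high-resolution images, gradually improving accuracy. The overall network structure is as follows: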
Collaborator

Same as above: this is identical to the Chinese description on Zhihu and needs to be rewritten.




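This article uses the Cityscape dataset; please register and download it from the [Cityscape website](). After downloading, process the data using the instructions and tools [here](https://github.com/mcordts/cityscapesScripts).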
Collaborator

Following the link https://github.com/mcordts/cityscapesScripts, it is not clear at a glance how to process the data. Could we provide a processing script here?

def test():
reader = DataGenerater(TEST_LIST).create_reader()
reader = paddle.reader.map_readers(load, reader)
reader = paddle.reader.map_readers(test_process, reader)
Collaborator

There should be no need for two paddle.reader.map_readers calls.
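
One way to get down to a single mapping step is to compose the per-sample functions first. The sketch below is plain Python; the `compose` helper is hypothetical and not from the PR, and it assumes each function takes one sample and returns one sample.

```python
def compose(*funcs):
    """Return a function that applies funcs to a sample, left to right."""

    def composed(sample):
        for func in funcs:
            sample = func(sample)
        return sample

    return composed


# Assumed usage in the reader code under review (not executed here):
#   reader = paddle.reader.map_readers(compose(load, test_process), reader)
```

With a single input reader, each sample is passed to the mapped function as one argument, so the composed function would replace the two chained calls.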

# define network
images = fluid.layers.data(name='image', shape=data_shape, dtype='float32')
_, _, sub124_out = network.icnet(args, images, num_classes, np.array(data_shape[1:]).astype("float32"), is_test=True)
predict = fluid.layers.upsampling_bilinear2d(sub124_out, out_shape=data_shape[1:3])
Collaborator

upsampling_bilinear2d -> resize_bilinear

fluid.io.load_persistables(exe, args.model_path)
print "loaded model from: %s" % args.model_path
sys.stdout.flush()

Collaborator

Because the network contains BN, you need:

inference_program = fluid.default_main_program().clone(for_test=True)

Then have the exe on line 83 below run this inference_program.


sub4_out = conv(conv5_4_interp, 1, 1, num_classes, 1, 1, biased=True, relu=False, name="sub4_out")

sub24_out = conv(sub24_sum_interp, 1, 1, num_classes, 1, 1, biased=True, relu=False, name="sub24_out")
Collaborator

Could the network above be written in a more modular way? Right now it looks very complicated.

@@ -0,0 +1,337 @@
import paddle.fluid as fluid
Collaborator

Is there a more precise name for network.py?

@4luojing
Contributor Author

The code has been revised according to the review. @wanghaoshuang @qingqing01, please confirm.

Contributor

@wanghaoshuang left a comment

LGTM. @qingqing01, do you have any other comments?

@wanghaoshuang merged commit 2cb27d0 into PaddlePaddle:develop on Jun 21, 2018