This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

How to use mxnet for image segmentation training? #337

Closed
tornadomeet opened this issue Oct 20, 2015 · 34 comments

Comments

@tornadomeet
Contributor

In image segmentation, an image with N pixels also has N labels (not one), so I don't think the current Softmax operator in mxnet can handle it (just my opinion).

I want to solve it as follows:

  1. Add a new operator (e.g. softmaxseg-inl.h, softmaxseg.cu, softmaxseg.cc) in the src/operator directory. SoftmaxSeg would handle the forward pass, the backward pass and the loss calculation for image segmentation.

  2. Change the format of the image list file to: integer_image_index \t label.jpg \t data.jpg (each pixel in label.jpg stands for its class label).

  3. Change the code of iter_image_recordio.cc so that class ImageLabelMap can read label.jpg and store the values in label_.

I am new to mxnet. How should we use mxnet for image segmentation training? Can the approach above solve it, or is there a better solution?

Any advice would be appreciated, thanks.
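The list format proposed in step 2 can be illustrated with a small parser. This is only a sketch of the proposal, not existing mxnet code: the function name `parse_seg_list` and the example paths are hypothetical, and it assumes one tab-separated record per line in the order index, label image, data image.

```python
def parse_seg_list(lines):
    """Parse records of the form: integer_image_index \t label.jpg \t data.jpg."""
    records = []
    for raw in lines:
        line = raw.strip()
        if not line:
            continue  # skip blank lines
        index, label_path, data_path = line.split("\t")
        records.append((int(index), label_path, data_path))
    return records


# Hypothetical list file content with two images.
example = "0\tlabels/0001.jpg\timages/0001.jpg\n1\tlabels/0002.jpg\timages/0002.jpg\n"
records = parse_seg_list(example.splitlines())
print(records)
```

Each record pairs a data image with a label image of the same size, so a reader built on top of it would return a full label map per sample instead of a single number.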

@pluskid
Contributor

pluskid commented Oct 20, 2015

The Caffe / Mocha.jl way of handling this is to allow "multiple dimensions" in the softmax loss layer. For example, the label could be (using Python's row-major ordering) N-by-1-by-P, while the predictions are N-by-K-by-P, where N is the number of samples in the mini-batch, K is the number of classes, and P can be interpreted as the number of positions / pixels. In general, the prediction can be any ND tensor, and the label is a tensor of the corresponding shape, except that one of its dimensions is a singleton (of size 1). For example, pixel-wise prediction would be

  • Label: N-by-1-by-H-by-W
  • Prediction: N-by-K-by-H-by-W
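To make the shape convention concrete, here is a minimal NumPy sketch of a softmax loss evaluated independently at every pixel. It is not the Caffe, Mocha.jl or MXNet implementation; the function name `pixelwise_softmax_loss` and the choice to average the loss over all pixels are assumptions made for illustration.

```python
import numpy as np


def pixelwise_softmax_loss(scores, labels):
    """scores: (N, K, H, W) raw class scores; labels: (N, 1, H, W) integer class ids.

    Returns per-pixel class probabilities (N, K, H, W) and the mean
    negative log-likelihood over all N * H * W pixels.
    """
    # Subtract the per-pixel maximum along the class axis for numerical stability.
    shifted = scores - scores.max(axis=1, keepdims=True)
    log_prob = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class at each pixel: (N, 1, H, W).
    picked = np.take_along_axis(log_prob, labels, axis=1)
    return np.exp(log_prob), -picked.mean()


rng = np.random.default_rng(0)
scores = rng.normal(size=(2, 3, 4, 5))          # N=2, K=3, H=4, W=5
labels = rng.integers(0, 3, size=(2, 1, 4, 5))  # singleton class dimension
prob, loss = pixelwise_softmax_loss(scores, labels)
print(prob.shape, float(loss))
```

The singleton dimension in the label lines up with the class axis of the prediction, which is what lets the same code handle N-by-K-by-P and N-by-K-by-H-by-W inputs.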

In terms of image segmentation, I think this CRF-as-RNN Caffe code would be a very nice one to incorporate into MXNet, and it would serve as a cool demo. I might consider doing it when I can squeeze out some time, but that does not look likely very soon. So @tornadomeet, if you are working on segmentation, you could probably add this.

@pluskid
Contributor

pluskid commented Oct 20, 2015

BTW, just a side comment since our current Softmax operator is being discussed. I found it a bit confusing when I first used the Softmax operator. The forward operation behaves exactly like a softmax, but the backward operation is that of a softmax with multiclass logistic loss. I guess this is due to efficiency considerations or code re-use, but it leads to some inconvenience / inconsistency.

  • The Softmax operator needs both data and label as arguments, even though only the data is needed when people just do prediction.
  • The objective function value of the logistic loss is never computed. This does not really affect learning, and since we run for a fixed number of iterations we do not rely on it as a stopping condition either. But the correct objective function value plays an important role in debugging. For example, when researchers are writing their own layers or new optimizers, watching how the objective function changes at a finer scale can be very helpful for testing whether the implementation is correct and, if not, where the issues might be.
  • People might need to use other losses on top of the Softmax probabilistic output. Caffe / Mocha.jl currently have a Softmax layer that does just softmax, and a SoftmaxLoss layer that combines softmax and logistic loss. I think this might be a viable alternative.
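The forward/backward mismatch described above can be sketched in NumPy. This is an illustration rather than the MXNet operator: the function names are hypothetical, and averaging the loss over the batch (hence the division by N in the gradient) is an assumption. It shows the softmax forward pass, the objective value that is otherwise never reported, and the combined gradient `p - onehot(y)`, checked against finite differences.

```python
import numpy as np


def softmax(x):
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def softmax_loss(x, y):
    """Mean multiclass logistic loss of softmax(x) against integer labels y."""
    p = softmax(x)
    return -np.log(p[np.arange(len(y)), y]).mean()


def softmax_loss_grad(x, y):
    """Gradient of softmax_loss with respect to x: (p - onehot(y)) / N."""
    p = softmax(x)
    p[np.arange(len(y)), y] -= 1.0
    return p / len(y)


x = np.array([[1.0, 2.0, 0.5], [0.1, -1.0, 3.0]])
y = np.array([1, 2])

# Central finite differences as an independent check of the analytic gradient.
eps = 1e-6
numeric = np.zeros_like(x)
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        d = np.zeros_like(x)
        d[i, j] = eps
        numeric[i, j] = (softmax_loss(x + d, y) - softmax_loss(x - d, y)) / (2 * eps)

max_err = np.abs(numeric - softmax_loss_grad(x, y)).max()
print(float(softmax_loss(x, y)), float(max_err))
```

Keeping `softmax_loss` as a separate function is what makes this kind of gradient check, and monitoring of the objective during training, possible.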

The following are just some minor thoughts on the general design:

  • Currently the label argument for Softmax is an implicit argument. If I understand correctly, there is no way to construct a Variable as the label (as we do for data) and compose them like this:
data = mx.symbol.Variable('data-foo')
label = mx.symbol.Variable('label-foo')
net = mx.symbol.Softmax(data=data, label=label, name='out')

This essentially allows us to rename the label variable to whatever we want, as in the data case. Then, when the user constructs a DataIter, he/she can specify which names (data-foo and label-foo in this case) the DataIter provides (this is also the current design used in MXNet.jl). So when we train the network with this DataIter, we actually know which ones are data and which ones are labels. The current way (deciding automatically based on a data or label suffix) is nice, but it might become confusing to figure out the exact correspondence in the multiple-input multiple-output case (the RNN case, not the image segmentation case).

  • Since the loss functions are only used during training, we might use the following convention: when defining the architecture, the user only defines the network up to the output layer. A loss layer is provided when the user calls fit, and the network is composed with the loss criterion on the fly. Otherwise, the loss is not part of the architecture, and the semantics of doing predict with the network are clearer. For example, Lasagne, a lightweight framework based on Theano, has this design.

@tqchen
Member

tqchen commented Oct 20, 2015

@pluskid I agree with your idea that softmax should be solely the softmax transformation, and that softmax-loss should be used as the loss function. Do you want to take a stab at the refactor?

@tqchen
Member

tqchen commented Oct 20, 2015

Care is needed to make a truly valid softmax operator, though, because being able to backpropagate the right gradient through any composition after softmax requires the log-probability instead of the probability for numerical stability. A better approach might be to provide a log-softmax operator to allow arbitrary composition, or simply to restrict softmax so that it can only be composed with SoftmaxLoss.
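The numerical issue can be shown with a small NumPy sketch (an illustration, not MXNet code; both function names are hypothetical). Taking the log of an already computed probability fails when that probability underflows to zero, while a log-softmax computed with the log-sum-exp trick stays finite.

```python
import numpy as np


def log_softmax(x):
    """Stable log-softmax along the last axis using the log-sum-exp trick."""
    shifted = x - x.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def naive_log_softmax(x):
    """Softmax first, log afterwards: the probability may underflow to 0."""
    e = np.exp(x)
    return np.log(e / e.sum(axis=-1, keepdims=True))


x = np.array([-1000.0, 0.0])
with np.errstate(divide="ignore"):  # silence the log(0) warning of the naive version
    naive = naive_log_softmax(x)
stable = log_softmax(x)
print(naive, stable)
```

Any loss composed on top of the naive version would receive an infinite log-probability for the first class, whereas the stable version returns the exact value.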

The current decision was made to make things work well for the restricted case, and it should be changed.

@antinucleon
Contributor

@tqchen We should do the refactor soon after we fix the RNN stuff, for two reasons: 1. compatibility with cuDNN; 2. better support for different losses on Softmax.

@winstywang
Contributor

Hi, all

There is quite a lot of work to be done for the segmentation task. The multi-softmax loss for segmentation is only one issue; several other ops also need to be implemented:

  1. The UpSampling layer used in FCN
  2. IO part for image segmentation

The part I am most concerned about is the UpSampling layer. It may not be easy to implement in the current mshadow framework.

@antinucleon
Contributor

@winstywang

  1. We can enable calling a CUDA kernel directly in an OP instead of writing an mshadow OP
  2. It is not hard to write the IO directly in Python instead of writing special C++ IO for segmentation

@pluskid
Contributor

pluskid commented Oct 20, 2015

@tqchen Sure, I'm glad to help with this, but let me get MXNet.jl into good shape first. We might need an official list of TODO items somewhere for the whole project so that we do not forget things.

@winstywang I heard from a friend working on image segmentation that

although the up-sampling layer can be formulated in various ways (using, for example, fancy un-convolution), in the end a naive up-sampling by an image-resizing kind of operation just works fine. Even the complicated up-sampling layers are initialized in this way, and if you look at the learned filters in those layers, they are not really far from the initialization, which is simple up-resizing.

I have not personally tried to implement and verify this statement, but I think it is worth keeping in mind. If it is true, it will definitely save everybody a lot of extra work.
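As a rough illustration of what "naive up-sampling by resizing" can look like, here is a NumPy sketch. Neither function comes from MXNet or from the thread: `upsample_nearest` repeats pixels, and `bilinear_kernel` builds the bilinear interpolation weights that FCN-style implementations commonly use to initialize deconvolution filters (bringing it in here is my assumption about what "initialized in this way" refers to).

```python
import numpy as np


def upsample_nearest(x, scale):
    """Nearest-neighbour up-sampling: (N, C, H, W) -> (N, C, H*scale, W*scale)."""
    return x.repeat(scale, axis=2).repeat(scale, axis=3)


def bilinear_kernel(size):
    """2D bilinear interpolation weights for a size-by-size deconvolution filter."""
    factor = (size + 1) // 2
    center = factor - 1 if size % 2 == 1 else factor - 0.5
    rows, cols = np.ogrid[:size, :size]
    return (1 - abs(rows - center) / factor) * (1 - abs(cols - center) / factor)


x = np.arange(4.0).reshape(1, 1, 2, 2)
up = upsample_nearest(x, 2)
kernel = bilinear_kernel(4)  # filter for 2x up-sampling
print(up[0, 0])
print(kernel)
```

If the friend's observation holds, a fixed filter like `kernel` would already be close to what a learned up-sampling layer ends up with.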

@pluskid
Contributor

pluskid commented Oct 20, 2015

@winstywang What do you mean by the IO part? If our softmax is extended to be Caffe-like and support evaluation at each pixel, then the whole pipeline is still single-input single-output (though multi-dimensional). It seems the current pipeline fits. Or does the current image-records IO only support single-value labels?

@pluskid
Contributor

pluskid commented Oct 20, 2015

@tqchen I agree that combining softmax and logistic loss is much more numerically stable. The compromise in Caffe/Mocha.jl is to provide a SoftmaxLoss layer that is essentially a combined softmax and logistic loss layer, doing the (backward) computation together. It is more efficient and numerically stable. Meanwhile, they also offer the softmax layer and the logistic regression layer separately in case people need to do something else. Actually, a while ago I was working on a project that used a loss other than the logistic loss on the softmax output.

In fact, doing them separately might not be as bad in practice as we think, if you do softmax carefully. I did some experiments quite a long time ago (scroll down the page for the last figure). It seems that the discrepancy is within a reasonable range for relatively bounded inputs, even for Float32. As far as I know, Theano just provides softmax and logistic loss separately, though I do not know whether this actually gets optimized into a single softmax-loss operation during the compilation stage. If we are still concerned here, the option adopted by Torch might also be possible: they have a LogSoftmax layer. 😄

@tornadomeet
Contributor Author

Thanks all for the discussion, I learned a lot.
@pluskid Yes, I want to do some experiments with mxnet on segmentation, using the deeplab code and the CRF-as-RNN code. If I manage to complete this, I'll share it, but first let me get familiar with mxnet. ^^

@tqchen
Member

tqchen commented Oct 25, 2015

The multi label softmax is merged in #387

@tornadomeet
Contributor Author

@tqchen Yeah, I see it! Thanks.

@futurely

Torch's upsampling implementation is here and here.

@futurely

futurely commented Nov 3, 2015

The current state of the art on the VOC 2012 segmentation competition leaderboard is the deep parsing network [1]. Operations that are not yet implemented include padded convolution filter copy initialization, up-sampling, local convolution and channels block min pooling.

[1] Ziwei Liu, Xiaoxiao Li, Ping Luo, Chen Change Loy, and Xiaoou Tang. Semantic Image Segmentation via Deep Parsing Network. ICCV 2015.

@futurely

@HyeonwooNoh's entry in the PASCAL VOC segmentation benchmark, POSTECH_DeconvNet_CRF_VOC, achieved 74.8% average precision. In his implementation, the pixel-wise classification loss RedSoftmaxWithLossLayer, EltwiseAccuracyLayer and RedAccuracyLayer are related to segmentation. The C++ seg data layers would be better replaced by Python layers like BVLC/caffe#1698 (comment) and #554 (comment), which would be easier to implement.

[1] Hyeonwoo Noh, Seunghoon Hong, Bohyung Han. Learning Deconvolution Network for Semantic Segmentation. ICCV 2015.

@HyeonwooNoh

@futurely
EltwiseAccuracyLayer is used to compute accuracy for segmentation, but
RedSoftmaxWithLossLayer is not used for segmentation (it's implemented for another purpose).
For segmentation, you should refer to the SoftmaxLoss layer.
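For readers unfamiliar with element-wise accuracy, a minimal NumPy sketch follows. It is not the EltwiseAccuracyLayer code: the function name `pixel_accuracy` and the optional `ignore_label` parameter are assumptions, and accuracy is taken to mean the fraction of pixels whose highest-scoring class equals the label.

```python
import numpy as np


def pixel_accuracy(scores, labels, ignore_label=None):
    """scores: (N, K, H, W) class scores; labels: (N, 1, H, W) integer class ids."""
    pred = scores.argmax(axis=1)  # (N, H, W) predicted class per pixel
    target = labels[:, 0]         # drop the singleton class dimension
    if ignore_label is None:
        valid = np.ones(target.shape, dtype=bool)
    else:
        valid = target != ignore_label  # e.g. unlabeled border pixels
    return (pred[valid] == target[valid]).mean()


scores = np.array([[[[1.0, 0.0], [0.0, 1.0]],    # scores for class 0
                    [[0.0, 1.0], [1.0, 0.0]]]])  # scores for class 1
labels = np.array([[[[0, 1], [0, 0]]]])
acc = pixel_accuracy(scores, labels)

labels_with_ignore = np.array([[[[0, 1], [255, 0]]]])
acc_ignored = pixel_accuracy(scores, labels_with_ignore, ignore_label=255)
print(acc, acc_ignored)
```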

@winstywang
Contributor

wow, the authors are here :)
I will work on the segmentation task soon

@playerkk

Is it possible to assign different weights to different ground-truth labels? The weight could be set to 1/f_i, where f_i is the frequency of the i-th label in the training set, to deal with the unbalanced number of training samples across classes.
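The 1/f_i weighting asked about here could look like the following NumPy sketch. It is not an MXNet feature: both function names are hypothetical, classes that never occur get weight 0, and normalizing the loss by the sum of the weights is one possible convention among several.

```python
import numpy as np


def inverse_frequency_weights(labels, num_classes):
    """Weight w_i = 1 / f_i, with f_i the relative frequency of class i."""
    counts = np.bincount(labels.ravel(), minlength=num_classes).astype(np.float64)
    freq = counts / counts.sum()
    weights = np.zeros(num_classes)
    seen = freq > 0
    weights[seen] = 1.0 / freq[seen]  # unseen classes keep weight 0
    return weights


def weighted_nll(log_prob, labels, weights):
    """log_prob: (N, K, H, W) log-probabilities; labels: (N, 1, H, W) class ids."""
    picked = np.take_along_axis(log_prob, labels, axis=1)  # (N, 1, H, W)
    w = weights[labels]                                    # per-pixel weight
    return -(w * picked).sum() / w.sum()


labels = np.array([[[[0, 0], [0, 1]]]])          # class 1 is rare
weights = inverse_frequency_weights(labels, 2)   # frequencies 0.75 and 0.25
log_prob = np.full((1, 2, 2, 2), np.log(0.5))    # uniform predictions
loss = weighted_nll(log_prob, labels, weights)
print(weights, float(loss))
```

With these weights the single pixel of class 1 contributes as much to the loss as the three pixels of class 0 together.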

@futurely

In "Learning Deconvolution Network for Semantic Segmentation", "for both datasets, we maintain the balance for the number of examples across classes by adding redundant examples for the classes with limited number of examples."

@futurely

In more traditional machine learning algorithms, weighted samples are not unusual. But long-tail samples sometimes hurt performance in DL experiments [2]. It's doubtful whether duplicating them will be of any help.

[2] Erjin Zhou, Zhimin Cao, Qi Yin: Naive-Deep Face Recognition: Touching the Limit of LFW Benchmark or Not? CoRR abs/1501.04690 (2015).

@futurely

As you can see on the PASCAL VOC 2012 segmentation leaderboard, most top-performing submissions used the Microsoft COCO dataset to further boost performance. The observation is very similar to what was discovered in "Naive-Deep Face Recognition": big data is more effective than complicated algorithms.

@winstywang
Contributor

I will work on segmentation task from next week. I think the first step is to implement the basic FCN model instead of the second model.

@futurely

@playerkk Class weighting is implemented by the authors of [3][4]. But the experimental result on the PASCAL VOC 2012 segmentation dataset is not very competitive with DeconvNet [1] and CRF-RNN [5], although its network architecture seems to be very similar to DeconvNet. It would be instructive to get some explanations from @alexgkendall.

[3] Alex Kendall, Vijay Badrinarayanan and Roberto Cipolla "Bayesian SegNet: Model Uncertainty in Deep Convolutional Encoder-Decoder Architectures for Scene Understanding." arXiv preprint arXiv:1511.02680, 2015. http://arxiv.org/abs/1511.02680
[4] Vijay Badrinarayanan, Alex Kendall and Roberto Cipolla "SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation." arXiv preprint arXiv:1511.00561, 2015. http://arxiv.org/abs/1511.00561
[5] Shuai Zheng, Sadeep Jayasumana, Bernardino Romera-Paredes, Vibhav Vineet, Zhizhong Su, Dalong Du, Chang Huang and Philip H.S. Torr, Conditional Random Fields as Recurrent Neural Networks. IEEE International Conference on Computer Vision (ICCV), 2015.

@alexgkendall

Hi - Our work [3][4] differs from [1] in class weighting as you pointed out. Another major difference is the much more efficient parametrisation in SegNet [4], which is an order of magnitude faster to run, and can be trained end to end in one step. I believe DeConvNet [1] uses stage wise training and multiple region proposals in inference time.

We found class weighting to be very important for scene understanding tasks. If you are more interested in datasets such as SUN or CamVid then I'd recommend implementing it. Cheers.

@playerkk

Thanks all for your reply.

I am working on predicting pixel-wise labels, not necessarily image segmentation. The label distributions are highly unbalanced. Experimental results based on MatConvNet indicate that a weighted loss is quite helpful. MatConvNet, however, is very slow, so I am considering switching to another package.

@alexgkendall

Cool, SegNet can do that and it is built on Caffe which is pretty fast - you might find this tutorial helpful.

@futurely

Once #640 is completed, most of the published image segmentation models that are based on FCN can be replicated without much difficulty in MXNet.

Maybe most pre-trained models can even be imported directly.

@futurely

Actually, convolutional autoencoder networks don't have to use the pooling and unpooling layers. Simply stacking several conv layers and then a few deconv layers together is alright. Just for a little more fun, they can be arbitrarily interleaved!
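A quick way to check that a stack of conv and deconv layers restores the input resolution is to trace the spatial size through the standard output-size formulas. The sketch below is plain arithmetic, not MXNet code; the layer settings are hypothetical examples, and it assumes no dilation and no extra output padding.

```python
def conv_out(size, kernel, stride=1, pad=0):
    """Spatial output size of a convolution."""
    return (size + 2 * pad - kernel) // stride + 1


def deconv_out(size, kernel, stride=1, pad=0):
    """Spatial output size of a deconvolution (transposed convolution)."""
    return (size - 1) * stride - 2 * pad + kernel


# Hypothetical stack: two stride-2 convs followed by two stride-2 deconvs.
encoder = [(3, 2, 1), (3, 2, 1)]  # (kernel, stride, pad)
decoder = [(4, 2, 1), (4, 2, 1)]

size = 64
trace = [size]
for kernel, stride, pad in encoder:
    size = conv_out(size, kernel, stride, pad)
    trace.append(size)
for kernel, stride, pad in decoder:
    size = deconv_out(size, kernel, stride, pad)
    trace.append(size)
print(trace)
```

The same two functions can be applied in any order, so they also cover the interleaved arrangement mentioned above.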

@mli
Member

mli commented Dec 14, 2015

@tornadomeet Congrats on making it work. I'm going to close this issue now. It would be great if you could PR the example back.

To others: the example is available at https://github.com/tornadomeet/mxnet/tree/seg/example/fcn-xs

@zhangfanqie

@tornadomeet Hi, I'm confused about how to make my own training data. Does im2rec.py support making .rec files when the labels are images as well? What arguments should I use?
Thanks a lot!!!

@great-thoughts

If my data resides in AWS S3, how should I modify the code? Passing s3://bucketname/... doesn't seem to work.

anirudh2290 pushed a commit to anirudh2290/mxnet that referenced this issue Jan 11, 2018
iblislin added a commit to iblislin/incubator-mxnet that referenced this issue Mar 18, 2018