This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Re-implement segnet in MXnet #8423

Closed
wants to merge 16 commits into from

Conversation

solin319
Contributor

Re-implement SegNet in MXNet.
Details in https://github.com/solin319/incubator-mxnet/tree/solin-patch-segnet/example/segnet
The CamVid test accuracy reached 83% with the segnet_basic network.

Does MXNet need this work?

@solin319 solin319 changed the title Solin patch segnet Re-implement segnet in MXnet Oct 25, 2017
@chinakook
Contributor

chinakook commented Oct 26, 2017

Very nice work! At long last, unpooling is implemented in MXNet. SegNet and U-Net are both popular on Kaggle and in many practical applications, more so than FCN-xs and DeepLab, so I think MXNet needs this work.

@piiswrong
Contributor

Does the result match the original paper?

@zhreshold Could you have a look?

@zhreshold
Member

Nice, can you add some instructions on how to use your operators? @solin319
Also, I cannot find the performance numbers in either your repo or the original Caffe one.

@solin319
Contributor Author

solin319 commented Oct 28, 2017

Build MXNet with new pooling and upsampling operators.

# copy new operators to src/operator/
cp segnet/op/* incubator-mxnet/src/operator/
# rebuild MXNet from source
cd incubator-mxnet/
make
cd python/
python setup.py install


mx.metric.check_label_shapes(label, pred_label)

self.sum_metric += (pred_label.flat == label.flat).sum()
Member

If ignore_label is in pred_label, e.g. ignore_label == 0, this is wrong.

Contributor Author

self.sum_metric -= len(pred_label[pred_label == self.ignore_label])

This solves the problem caused by ignore_label appearing in pred_label.
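For illustration, the ignore_label handling can be sketched as a standalone numpy helper (a hypothetical `pixel_accuracy` function, not the PR's actual metric class; it masks on the ground-truth labels up front rather than subtracting afterwards):

```python
import numpy as np

def pixel_accuracy(label, pred_label, ignore_label=11):
    # flatten both arrays to 1-D pixel vectors
    label = np.asarray(label).ravel()
    pred_label = np.asarray(pred_label).ravel()
    # count only pixels whose ground truth is not the ignore label
    keep = label != ignore_label
    correct = (pred_label[keep] == label[keep]).sum()
    return correct / keep.sum()
```

Masking on the label side also excludes ignored pixels from the denominator, which the subtraction-based fix alone does not.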

output_names=output_names, label_names=label_names)
self.eps = eps

def update(self, labels, preds):
Member

Why not use the native one in metric.py?

Contributor Author

1. The native 'CrossEntropy' in metric.py can only be used with 2-d pred_label.
2. In this example, pred_label is 4-d and the softmax result is located on axis 1, so we must swap the softmax result to the last axis.
3. Because ignore_label=11, the maximum label value is 11, but the softmax result only covers classes 0-10. We must add one extra column before we can index with
prob = pred[numpy.arange(label.shape[0]), numpy.int64(label)]
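The three steps above can be sketched in plain numpy (a hypothetical `cross_entropy_4d` helper; the axis layout and column padding follow the description, not the PR's exact code):

```python
import numpy as np

def cross_entropy_4d(label, pred, ignore_label=11, eps=1e-12):
    # pred: (N, C, H, W) softmax output; label: (N, H, W) integer classes
    n, c, h, w = pred.shape
    # step 2: move the class axis to the end, flatten to (num_pixels, C)
    prob = np.moveaxis(pred, 1, -1).reshape(-1, c)
    lab = label.reshape(-1).astype(np.int64)
    # step 3: pad one dummy column so ignore_label indexing stays in bounds
    prob = np.concatenate([prob, np.ones((prob.shape[0], 1))], axis=1)
    p = prob[np.arange(lab.shape[0]), lab]
    # average the negative log-likelihood over non-ignored pixels only
    keep = lab != ignore_label
    return -np.log(p[keep] + eps).mean()
```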

return list(data.items()), list(label.items())

def _read_img(self, img_name, label_name):
img = Image.open(os.path.join(self.root_dir, img_name))
Member

You can use mx.image.imread instead of the PIL package.


DMLC_REGISTER_PARAMETER(PoolingMaskParam);

MXNET_REGISTER_OP_PROPERTY(PoolingMask, PoolingMaskProp)
Member

Do you think this operator is reusable enough to be put into operator/contrib?


DMLC_REGISTER_PARAMETER(UpSamplingMaskParam);

MXNET_REGISTER_OP_PROPERTY(UpSamplingMask, UpSamplingMaskProp)
Member

Do you think this operator is reusable enough to be put into operator/contrib?
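For illustration, the pair of operators under discussion can be sketched in numpy: a max pooling that also records the flat input index of each maximum, and an unpooling that scatters values back through that mask (a simplified single-channel 2-D sketch; the real operators work on batched NDArrays):

```python
import numpy as np

def pool_with_mask(x, k=2):
    # max-pool a 2-D array with a k x k window and stride k,
    # recording the flat input index of each maximum as the mask
    h, w = x.shape
    out = np.zeros((h // k, w // k), dtype=x.dtype)
    mask = np.zeros((h // k, w // k), dtype=np.int64)
    for i in range(h // k):
        for j in range(w // k):
            win = x[i * k:(i + 1) * k, j * k:(j + 1) * k]
            r, c = np.unravel_index(np.argmax(win), win.shape)
            out[i, j] = win[r, c]
            mask[i, j] = (i * k + r) * w + (j * k + c)
    return out, mask

def unpool_with_mask(y, mask, shape):
    # scatter the pooled values back to the recorded positions;
    # every other output position stays zero
    out = np.zeros(shape, dtype=y.dtype).ravel()
    out[mask.ravel()] = y.ravel()
    return out.reshape(shape)
```

This mask-based unpooling is what distinguishes SegNet's decoder from plain nearest-neighbour upsampling: activations return to the exact locations the encoder pooled them from.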

import logging
import pdb
import numpy as np
from PIL import Image
Member

Same here, use mx.image

import os
import logging
import numpy as np
from PIL import Image
Member

Same here.

@zhreshold zhreshold left a comment

See comments.
Remove the result PNGs; you can use a GitHub issue to upload images and get permanent links.

@solin319
Contributor Author

solin319 commented Nov 2, 2017

@zhreshold
Thank you for your suggestions. I have used mx.image to read the pictures and removed the result PNGs. But I still have two questions.

  1. I think the pooling and up-sampling operators with mask may be used in other pixel-wise segmentation networks. Shall I move them to operator/contrib?
  2. I use PIL to save a numpy array as a colorful PNG. I can't find such an interface in mx.image.

@zhreshold
Member

@solin319
Yes, move to operator/contrib if they are reusable.
Do you really need to save the pictures? If so, it's fine to leave that as is.
And BTW, do you think your result can match the paper?

@solin319
Contributor Author

solin319 commented Nov 3, 2017

@zhreshold
Yes, the test result matches the paper.
On the CamVid dataset, we used 367 images for training and 233 for testing, as in the paper.
The train command is python train_segnet.py --gpus 0,1,2,3 --lr=0.01 --network=segnet_basic .
num_epochs = 250
lr_steps = 5000,7000
The result is shown below.
[segnet_basic result image]
We also trained with segnet_basic_with_drop; the result is shown below.
num_epochs = 400
lr_steps = 9000,10000
[segnet_basic_with_drop result image]
The test accuracy was similar to the paper.

@zhreshold
Member

@solin319 Have you finished? Can you rebase to fix the CI?

@piiswrong
Contributor

Results on the official site are:

Model                       | Global Accuracy | Class Accuracy | Mean I/U
Segnet-Basic                | 82.8%           | 62.3%          | 46.3%
SegNet (Pretrained Encoder) | 88.6%           | 65.9%          | 50.2%
SegNet (3.5K dataset)       | 86.8%           | 81.3%          | 69.1%

Which one are you comparing to?

@solin319
Contributor Author

The first one, Segnet-Basic, trained with 367 pictures and tested on 233 pictures.
@piiswrong

@piiswrong
Contributor

@zhreshold @winstywang @solin319 Any updates on this?

@solin319
Contributor Author

solin319 commented Dec 26, 2017

We trained SegNet with vgg16-bn pretrained parameters and reached an accuracy of 0.89.
[segnet_vgg_pretrained result image]

@piiswrong

@piiswrong
Contributor

@zhreshold Could you check if this reproduces the paper?

@zhreshold
Member

[screenshot of a SegNet results table, posted 2018-01-31]

@solin319 Is this the SOTA table your results are being compared with?

@solin319
Contributor Author

solin319 commented Feb 1, 2018

Comparing with:

Model                       | Global Accuracy | Class Accuracy | Mean I/U
Segnet-Basic                | 82.8%           | 62.3%          | 46.3%
SegNet (Pretrained Encoder) | 88.6%           | 65.9%          | 50.2%

http://mi.eng.cam.ac.uk/projects/segnet/tutorial.html
@zhreshold

@CodingCat
Contributor

Hi, the community has passed a vote on associating code changes with JIRA (https://lists.apache.org/thread.html/ab22cf0e35f1bce2c3bf3bec2bc5b85a9583a3fe7fd56ba1bbade55f@%3Cdev.mxnet.apache.org%3E).

We have updated the guidelines for contributors at https://cwiki.apache.org/confluence/display/MXNET/Development+Process. Please ensure that you have created a JIRA issue at https://issues.apache.org/jira/projects/MXNET/issues/ to describe your work in this pull request, and include the JIRA title in your PR as [MXNET-xxxx] your title, where MXNET-xxxx is the JIRA id.

Thanks!

@solin319 solin319 closed this Apr 27, 2018
5 participants