-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-116] object detector class added #10179
Conversation
@@ -0,0 +1,96 @@ | |||
package SSDClassifierExample |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add license and standard preamble docs as to what this does and how you use it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added the License, will add a short description to this example
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add comments on what this does and how it is used.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will add this part in the next day
* Takes input as NDArrays, useful when you want to perform multiple operations on | ||
* the input Array or when you want to pass a batch of input. | ||
* @param input: Indexed Sequence of NDArrays | ||
* @param topK: (Optional) How many top_k(sorting will be based on the last axis) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
space after "k".
* the input Array or when you want to pass a batch of input. | ||
* @param input: Indexed Sequence of NDArrays | ||
* @param topK: (Optional) How many top_k(sorting will be based on the last axis) | ||
* elements to return, if not passed returns unsorted output. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
comma after passed
} | ||
|
||
/** | ||
* Takes input as NDArrays, useful when you want to perform multiple operations on |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Switch comma to period and start new sentence.
class SSDClassifierExample { | ||
@Option(name = "--model-dir", usage = "the input model directory") | ||
private val modelPath: String = "/model" | ||
@Option(name = "--model-prefix", usage = "The prefix of the model") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lower case The... unless you make the others all The
@Option(name = "--model-prefix", usage = "The prefix of the model") | ||
private val modelPrefix: String = "/ssd_resnet50_512" | ||
@Option(name = "--input-image", usage = "the input image") | ||
private val inputImagePath: String = "/images/Cat-hd-wallpapers.jpg" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we stick to kitten.jpg since it's probably around in the other test artifacts?
|
||
import ml.dmlc.mxnet._ | ||
import ml.dmlc.mxnet.infer._ | ||
import org.kohsuke.args4j.{CmdLineParser, Option} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any comment why we need these things (next four lines)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Used the ObjectDetection Class and ImageClassifier Object from mxnet.infer package. Do I need to import them separately in two lines or adding some comments in the code?
Will double check the functions being imported from mxnet._
Args4j used to get the user input args (such as '--model-dir'), same question here, do I need to add comments about this usage?
try { | ||
println(mdDir) | ||
val dType = DType.Float32 | ||
val inputShape = Shape(1, 3, 512, 512) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why isn't this being fetched from the signature.json file?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will add this section on 03/21
val dType = DType.Float32 | ||
val inputShape = Shape(1, 3, 512, 512) | ||
// ssd detections, numpy.array([[id, score, x1, y1, x2, y2]...]) | ||
val outputShape = Shape(1, 6132, 6) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
signature.json?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as above
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice job!.
Please make the changes requested.
class ObjectDetector(modelPathPrefix: String, | ||
inputDescriptors: IndexedSeq[DataDesc]) | ||
extends Classifier(modelPathPrefix: String, | ||
inputDescriptors: IndexedSeq[DataDesc]) {1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is 1 here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will kill that 1 in the next update
|
||
// Considering 'NCHW' as default layout when not provided | ||
// Else get axis according to the layout | ||
val batch = inputShape(if (inputLayout.indexOf('N')<0) 0 else inputLayout.indexOf('N')) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you should just require user to pass the input descriptor. Most models will have batch size in the input so thats the only thing you should default and force the user to pass CHW.
Also these are already checked in the Classifier constructor, you don't need to do it again here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that one is the ImageClassifier part, this is ObjectDetector, we need to do this part again. Will force user input CHW
val width = inputShape(if (inputLayout.indexOf('W')<0) 3 else inputLayout.indexOf('W')) | ||
|
||
/** | ||
* To classify the image according to the provided model |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To Detect bounding boxes and corresponding labels
} | ||
|
||
/** | ||
* Takes input as NDArrays. Useful when you want to perform multiple operations on |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Takes input images as NDArrays. Useful when you want to perform multiple operations on
the input Array, or when you want to pass a batch of input images.
r.dispose() | ||
} | ||
handler.execute(predictResult.dispose()) | ||
batchResult.toList.toIndexedSeq |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can just do batchResult.toIndexedSeq
val predictResult: ListBuffer[Array[Float]] = ListBuffer[Array[Float]]() | ||
val accuracy : ListBuffer[Float] = ListBuffer[Float]() | ||
|
||
// iterating over the individual items(batch size is in axis 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are actually iterating over all results here
r.dispose() | ||
} | ||
|
||
var sortedIndices = accuracy.zipWithIndex.sortBy(-_._1).map(_._2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you don't need to sort if topK is not defined.
val op = NDArray.concatenate(imageBatch) | ||
// printf("concatenateed shape %s", op.shape) | ||
val result = objectDetectWithNDArray(IndexedSeq(op), topK) | ||
handler.execute(op.dispose()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you also have to dispose NDArrays created in imageBatch
|
||
val result = sortedIndices.map(idx | ||
=> (synset(predictResult(idx)(0).toInt), | ||
predictResult(idx).takeRight(5))).toList |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you add a comment on Why takeRight(5)
sortedIndices = sortedIndices.take(topK.get) | ||
} | ||
|
||
val result = sortedIndices.map(idx |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
again no need sort if topK is not defined.
a26c881
to
03a37d5
Compare
*/ | ||
class ObjectDetector(modelPathPrefix: String, | ||
inputDescriptors: IndexedSeq[DataDesc]) | ||
extends Classifier(modelPathPrefix: String, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it does not feel like ObjectDetector "is a" Classifier, this should be a composition rather than inhertiance
wget https://cloud.githubusercontent.com/assets/3307514/20012566/cbb53c76-a27d-11e6-9aaa-91939c9a1cd5.jpg -O 000001.jpg | ||
wget https://cloud.githubusercontent.com/assets/3307514/20012567/cbb60336-a27d-11e6-93ff-cbc3f09f5c9e.jpg -O dog.jpg | ||
wget https://cloud.githubusercontent.com/assets/3307514/20012563/cbb41382-a27d-11e6-92a9-18dab4fd1ad3.jpg -O person.jpg | ||
fi |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove the return char
|
||
java -Xmx8G -cp $CLASS_PATH \ | ||
ml.dmlc.mxnetexamples.inferexample.objectdetector.SSDClassifierExample \ | ||
--model-dir $MODEL_DIR \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you just take model-path-prefix instead of two different directories ?
|
||
// Considering 'NCHW' as default layout when not provided | ||
// Else get axis according to the layout | ||
val batch = inputShape(if (inputLayout.indexOf('N')<0) 0 else inputLayout.indexOf('N')) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you probably can access all the below variables from the ImageClassifier
Also no need for the if else
|
||
val result = objectDetectWithNDArray(IndexedSeq(op), topK) | ||
handler.execute(op.dispose()) | ||
for (ele <- imageBatch) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The beauty of Scala is you can express statements in concise format like this:
imageBatch.foreach(_.dispose())
c109d2b
to
48b077a
Compare
I think that you should have the examples work from a common directory. |
@lanking520 we can put them in a folder like this: |
Yes, Will update that |
changed the println option into logger.info to follow the guideline
48b077a
to
ad68e67
Compare
|
||
test("objectDetectWithInputImage") { | ||
val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, Shape(1, 3, 512, 512))) | ||
val inputImage = new BufferedImage(224, 224, BufferedImage.TYPE_INT_RGB) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what are these hardcode 224 size for? I don't think object detection works on such resolution.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, you are right. This image is a sample image feed into the unit test. Made the change.
Docs GTG for now, but would like to test again after #10054 is merged. |
* Fix mem alignment. * use 64 for alignments. * Update.
changed the println option into logger.info to follow the guideline
…fix the existing issues in the script and update readme
* update windows installation doc to inform user that they need to update the PATH variable * update
* Fix failed test with opencv 3.4.1 * Replace int conversion with `//` * retrigger test
Please move to: #10229. |
Description
Add Object Detection Class for MXNet Inference API, also add test and example to it. This code is fully tested with SSD model.
@nswamy @Roshrini
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments