mx.io.ImageRecordIter does not respect dtype argument / FP16 performance on Volta #9774
I don't think it will perform better than producing fp32 and then casting to fp16 at the beginning of the training.
I see. Could you expand on point 1, please?
The engine does not seem to differentiate between the first layer and subsequent layers, in that it considers the data going into the first layer as being modified by the backward pass of the network (even though that does not actually happen). This means that copying the next batch has to wait until the end of the backward pass, which effectively exposes the copy. Having this double-buffering scheme of either a cast fp32->fp16 or just an identity fp32->fp32 copy makes sure that the ndarray used to copy the next batch to the GPU is returned from the engine before the backward pass ends, which enables the copy to happen while the backward computation takes place.
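For illustration, a minimal sketch of that double-buffering pattern (the shapes and variable names are assumptions; this is not MXNet's internal code):

```python
import mxnet as mx

ctx = mx.gpu(0)
# fp32 staging buffer on the GPU; the next batch is always copied here.
staging = mx.nd.zeros((128, 3, 224, 224), ctx=ctx)

for _ in range(10):
    cpu_batch = mx.nd.ones((128, 3, 224, 224))  # stand-in for iterator output
    cpu_batch.copyto(staging)                   # asynchronous host-to-device copy
    # Casting into a *new* ndarray (fp32->fp16 here; an identity fp32->fp32
    # copy behaves the same) means the network reads net_input rather than
    # staging, so the engine releases staging once the cast completes and the
    # next batch's copy can overlap with the backward pass.
    net_input = staging.astype('float16')
    # ... forward/backward on net_input ...
```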
Thanks for the explanation. By the way, is training in fp16 supposed to be ~2x faster than fp32 for a given batch size? Or is this only about reduced memory usage, so we can use larger batch sizes? I have run the examples you added for fp16 in image classification, and I see maybe ±10-15% speed changes, nowhere close to 2x for a given batch size. Is this normal? I ran these tests on p3.8x and p3.16x EC2 machines, which use the Volta range of GPUs. I have CUDA 9 as well.
There are a few possible explanations.
Thanks. This blog post mentions training ImageNet with ResNet-50 on 8 Volta GPUs. Could you share the performance benefits you observed in such settings? Or, under optimal conditions, roughly how much speedup did you see with MXNet?
@rahul003 If you're only seeing ~15% speedups, I'd recommend you run nvprof before your training. Take a look at the GEMMs and ensure they have s884 in the name. If they don't, then one of these rules is probably not being followed:
(from https://devblogs.nvidia.com/programming-tensor-cores-cuda-9/) The tensor cores are a little tricky to use in a lot of cases. Let us know if nvprof shows that your model isn't being run on tensor cores and we might be able to give you some next steps.
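As a rough illustration of the kind of constraint involved (my paraphrase of the linked post, not a quote from it): fp16 GEMM and convolution dimensions generally need to be multiples of 8 for the s884 tensor-core kernels to be eligible.

```python
# Sketch of a sanity check, assuming the multiple-of-8 rule paraphrased
# above; verify the exact conditions against the linked NVIDIA post.
def tensor_core_friendly(*dims):
    """Return True if every dimension is a multiple of 8."""
    return all(d % 8 == 0 for d in dims)

print(tensor_core_friendly(64, 128, 256))  # True: eligible shapes
print(tensor_core_friendly(3, 64, 256))    # False: e.g. a 3-channel first conv
```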
I ran ResNet-50 with ImageNet and got about a 70% speedup. Some of the top kernels don't seem to have s884 in their names, but some operations do. Can I improve the speed further?
But for ResNet-110 on CIFAR-10, fp16 is much slower. Do you see something fishy here? There are barely any operations with s884 in their names; all the top ones lack it. So fp16 would not help us with small networks/models?
@rahul003 Could you paste here how you invoked the benchmark script? Did you set the autotune env var to 2? Also, which cuDNN version do you use? You should see many more s884 kernels in your profile...
I'm running this command.
Are you sure you have the right variable? The relevant code is
I was asking about the ImageNet script. If you use a smaller batch size like 256 for 8 GPUs (you will see the best results with a larger batch size like 1024 for 8 GPUs), you may consider turning on the NCCL kvstore (
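For reference, a sketch of turning the NCCL kvstore on, assuming an MXNet build compiled with NCCL support (creating the kvstore will fail otherwise):

```python
import mxnet as mx

# 'nccl' kvstore type; requires MXNet built with USE_NCCL=1.
kv = mx.kvstore.create('nccl')
```

In the example image-classification scripts this should correspond to passing `--kv-store nccl` on the command line.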
Okay, cool. I'll try to document that. I was using
I have cuDNN v7005 and CUDA 9.0.
Just in case, try synthetic data with
Unfortunately, neither suggestion helped improve the speed. Using MXNET_CUDNN_AUTOTUNE_DEFAULT=2 helped in some cases, but we can't say this setting helps consistently. If it picks the fastest algorithm, why would it not help in all cases? I understand there are cases where it should be the same speed as other algorithms, but sometimes it is slower than setting the variable to 1. All else should remain the same, right?
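For anyone following along, a sketch of how the variable can be set (the exact semantics of the 0/1/2 values should be checked against MXNet's environment-variable docs):

```python
import os

# Safest to set before mxnet is imported (or export it in the shell
# that launches training instead).
os.environ['MXNET_CUDNN_AUTOTUNE_DEFAULT'] = '2'

import mxnet as mx
```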
Sorry, I was digressing from the topic of the issue. Regarding the iterator issue, we need to document the behavior that it will return fp32 data regardless. Keeping this open until we either fix it or document it.
Description
mx.io.ImageRecordIter (i.e. src/io/iter_image_recordio_2.cc) doesn't respect the dtype parameter it takes. It is designed to only work with float32 because the class is instantiated with the real_t dtype (in src/io/iter_image_recordio_2.cc). Can we make it handle fp16 too? This is important for fp16 training.
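A hypothetical snippet illustrating the behavior (the .rec path, shapes, and batch size are assumptions):

```python
import mxnet as mx

train = mx.io.ImageRecordIter(
    path_imgrec='data/train.rec',   # assumed path to a RecordIO file
    data_shape=(3, 224, 224),
    batch_size=32,
    dtype='float16',                # requested dtype
)
batch = train.next()
print(batch.data[0].dtype)          # numpy.float32: the dtype request is ignored
```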
Also, training in fp16 seems slower than fp32 for some models.
Environment info (Required)
MXNet 1.0
Package used: Python
Error Message:
Silently generates fp32 data
Minimum reproducible example
N/A
Steps to reproduce
N/A
What have you tried to solve it?
Can we come up with a better way than creating a new operator that passes DType as fp16?
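Until the iterator handles fp16 natively, a possible interim workaround is to keep it in fp32 and cast each batch on the GPU. A sketch (NDArrayIter stands in for ImageRecordIter just to keep the snippet self-contained):

```python
import mxnet as mx

ctx = mx.gpu(0)
train_iter = mx.io.NDArrayIter(mx.nd.ones((256, 3, 32, 32)), batch_size=32)

for batch in train_iter:
    # Move the fp32 batch to the GPU, then cast it to fp16 for the network.
    data = batch.data[0].as_in_context(ctx).astype('float16')
    # ... feed 'data' into the fp16 network ...
```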
@ptrendx