
Better FP16 support in pytorch fp16 utils. #144

Merged
merged 1 commit into NVIDIA:master from jma127:master on Feb 6, 2019

Conversation

@jma127 (Contributor) commented Feb 1, 2019

This commit adds an FP16Model class as a successor to network_to_half().

The benefits of this class are:

  • Preservation of single-precision for BatchNorm layers. The models
    generated by network_to_half() convert BatchNorm moment tensors to
    half-precision, then back to single-precision, which hurts the
    accuracy of the moment estimators and occasionally results in NaNs.
  • Support for multi-argument nn.Modules (self-explanatory from code; see the sketch below).
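
A minimal sketch of the idea (illustrative only, not the exact implementation in this PR: the helper names convert_module/convert_network are assumptions, and for simplicity the sketch skips every BatchNorm layer rather than only the affine ones):

```python
import torch
import torch.nn as nn


def convert_module(module, dtype):
    # Cast only this module's own (non-recursive) parameters and buffers.
    for param in module.parameters(recurse=False):
        if param.data.dtype.is_floating_point:
            param.data = param.data.to(dtype)
    for buf in module.buffers(recurse=False):
        if buf.dtype.is_floating_point:
            buf.data = buf.data.to(dtype)


def convert_network(network, dtype):
    # Walk every descendant module, leaving BatchNorm layers (and hence
    # their running_mean/running_var buffers) in single precision.
    for module in network.modules():
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            continue
        convert_module(module, dtype)
    return network


class FP16Model(nn.Module):
    # Wraps the converted network and casts all positional tensor inputs
    # to half, so modules taking several forward arguments work unchanged.
    def __init__(self, network):
        super().__init__()
        self.network = convert_network(network, torch.half)

    def forward(self, *inputs):
        inputs = tuple(x.half() if torch.is_tensor(x) else x for x in inputs)
        return self.network(*inputs)
```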
@jma127 (Contributor, Author) commented Feb 1, 2019

cc @ngimel @mcarilli for review

@ngimel (Contributor) commented Feb 1, 2019

Do you have examples of workloads where converting BatchNorm to half precision hurts the accuracy? Thanks!

@jma127 (Contributor, Author) commented Feb 2, 2019

For ELF OpenGo model training, the roundtrip to half precision straight up results in NaNs (we encounter some rather extreme batch moments in intermediate stages of training). Full FP32 training does not result in this error, nor does training with a version of this fix.
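
As a concrete (made-up) illustration of the failure mode: FP16's largest finite value is 65504, so sufficiently extreme batch moments overflow to inf on the way down, and the normalization then yields NaNs even after converting back to FP32. The values below are hypothetical, not taken from ELF OpenGo:

```python
import torch

mean = torch.tensor(1.0e5)   # hypothetical extreme running mean
var = torch.tensor(1.0e12)   # hypothetical extreme running variance
x = torch.tensor(123.0)

# Staying in FP32: a finite, well-behaved normalized value.
print((x - mean) / torch.sqrt(var + 1e-5))      # tensor(-0.0999)

# Round-tripping the moments through FP16 overflows them to inf...
mean16, var16 = mean.half().float(), var.half().float()
print(mean16, var16)                            # tensor(inf) tensor(inf)

# ...and (x - inf) / sqrt(inf) is -inf / inf, i.e. NaN.
print((x - mean16) / torch.sqrt(var16 + 1e-5))  # tensor(nan)
```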

@jma127 jma127 force-pushed the jma127:master branch 2 times, most recently from 00aa22d to 6f7e115 Feb 2, 2019

@ngimel (Contributor) commented Feb 2, 2019

Ah I see, makes sense.

@jma127 (Contributor, Author) commented Feb 4, 2019

@ngimel Any concerns or does this look good to merge?

@ngimel (Contributor) left a comment

Can you please add some tests to check that the results of conversion match your expectations (e.g. resnet18 from torchvision, some model with nested modules, BatchNorm with affine=True/False)?
This is a useful PR, thanks!

```python
    Converts a network's parameters and buffers to dtype.
    """
    for module in network.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True:
```

@ngimel (Contributor) commented Feb 4, 2019

I don't see it recursively going over modules (e.g. if your network has nested modules).

@jma127 (Contributor, Author) commented Feb 4, 2019

torch.nn.Module.modules() enumerates all descendant modules in a recursive manner.
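
For example (illustrative snippet, not part of the PR):

```python
import torch.nn as nn

# modules() yields the module itself and every descendant, however deeply
# nested, so one loop is enough to reach BatchNorm layers inside
# nn.Sequential blocks or custom submodules.
net = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.Sequential(nn.BatchNorm2d(8), nn.ReLU()),
)
print([type(m).__name__ for m in net.modules()])
# ['Sequential', 'Conv2d', 'Sequential', 'BatchNorm2d', 'ReLU']
```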

@jma127 jma127 force-pushed the jma127:master branch from 6f7e115 to c8d3bee Feb 5, 2019

@jma127 (Contributor, Author) commented Feb 5, 2019

Added a barebones test with a custom dummy net. It probably makes sense to avoid a torchvision dependency here if we can.
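
Roughly along these lines (a hypothetical sketch of such a test, not the one actually added in the PR; the import path, the DummyNet model, and the CUDA forward pass are all assumptions):

```python
import torch
import torch.nn as nn

from apex.fp16_utils import FP16Model  # FP16Model is the class added by this PR; path assumed


class DummyNet(nn.Module):
    # Small nested model with a BatchNorm layer, used only to check dtypes.
    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU()
        )
        self.fc = nn.Linear(8, 4)

    def forward(self, x):
        return self.fc(self.block(x).mean(dim=(2, 3)))


def test_fp16_model_keeps_batchnorm_fp32():
    model = FP16Model(DummyNet())
    for module in model.modules():
        keep_fp32 = isinstance(module, nn.modules.batchnorm._BatchNorm)
        expected = torch.float32 if keep_fp32 else torch.float16
        for p in module.parameters(recurse=False):
            assert p.dtype == expected
    if torch.cuda.is_available():  # FP16 conv kernels need a GPU
        model = model.cuda()
        out = model(torch.randn(2, 3, 16, 16, device="cuda"))
        assert out.dtype == torch.float16
```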

Commit: Better FP16 support in pytorch fp16 utils.
This commit adds an FP16Model class as a successor to network_to_half.

The benefits of this class are:

- Preservation of single-precision for BatchNorm layers. The models
  generated by network_to_half() convert BatchNorm moment tensors to
  half-precision, then back to single-precision, which hurts the
  accuracy of the moment estimators and occasionally results in NaNs.
- Support for multi-argument nn.Modules (self-explanatory from code).

@jma127 jma127 force-pushed the jma127:master branch from c8d3bee to 713e0fb Feb 5, 2019

@ngimel ngimel merged commit 1b90385 into NVIDIA:master Feb 6, 2019
