Disable tensorfloat32 type #2404
Merged
Apparently the "new math" is just sloppiness baked into the hardware instructions: https://blogs.nvidia.com/blog/tensorfloat-32-precision-format/
Note that the above concerns GPUs (or TPUs), but newer CPUs might be affected by the same issue, in case we see it again:
https://www.intel.com/content/dam/develop/external/us/en/documents/lower-numerical-precision-deep-learning-jan2018-754765.pdf
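
For intuition on what the precision loss looks like, here is a hypothetical sketch (not part of this PR) that simulates TF32 by truncating float32's 23-bit mantissa down to TF32's 10 bits; real TF32 hardware rounds to nearest rather than truncating, so this is only an approximation:

```python
import numpy as np

def truncate_to_tf32(x):
    """Approximate TF32 by zeroing the low 13 of float32's 23 mantissa bits."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    bits &= np.uint32(0xFFFFE000)  # keep sign, 8 exponent bits, 10 mantissa bits
    return bits.view(np.float32)

x = np.float32(1.0) + np.float32(1e-5)  # representable in float32
print(x)                    # 1.00001
print(truncate_to_tf32(x))  # 1.0 -- the 1e-5 perturbation is below TF32's resolution
```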
For context, here are the results. `251105-jcm-nnpdf41-mhou-001` is the new baseline, and below are a few comparisons with a computer affected by this problem:

a) Default options: tensorfloat32 is used for some operations, with the consequence of a loss of precision.

b) `double_precision: true` in `n3fit`: by setting the default type to float64, only some operations (it is not clear to me which) are downgraded to these tensorfloat types, but everything looks ok.

c) And finally, a comparison to a fit in which tensorfloat is disabled and the entire fit is done in single precision: results are ok. (The knobs for b) and c) are sketched after this list.)
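
For reference, a minimal sketch of the two TensorFlow knobs involved; the exact call sites inside `n3fit` may differ, so this is not a verbatim excerpt of the PR:

```python
import tensorflow as tf

# Option (b): make float64 the default Keras dtype for the whole fit
tf.keras.backend.set_floatx("float64")

# Option (c): stay in single precision but disable TF32 execution,
# so float32 matmuls/convolutions are computed at full float32 precision
tf.config.experimental.enable_tensor_float_32_execution(False)
```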

I'll benchmark cases b) and c) to check which is faster on Cineca before merging this branch. Or, if you prefer, we can merge it asap and, if during the benchmark we find that option b) is better (run in double precision by default but keep these special features on), we can change the default.
NB: I'm only disabling it in TensorFlow, although in principle this is a hardware-level feature. However, PyTorch is apparently more conservative about which operations have it on by default (perhaps that's why it is much slower...) and its results are actually ok. I haven't tested JAX.
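
For comparison, a sketch of the switches PyTorch exposes for the same hardware feature (assuming recent PyTorch; shown only to illustrate the "more conservative" defaults, not as code from this PR):

```python
import torch

# PyTorch splits the TF32 switch per backend; matmul TF32 has been
# off by default since 1.12, which matches the more conservative behaviour
torch.backends.cuda.matmul.allow_tf32 = False  # default: False
torch.backends.cudnn.allow_tf32 = False        # default: True for convolutions
```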