Skip to content

Commit

Permalink
hot fix trunc_normal_ bug (#9711)
Browse files Browse the repository at this point in the history
修复 trunc_normal_ 实现 bug。

#close https://github.com/Oneflow-Inc/OneTeam/issues/1867

分布测试结果:

torch:


![图片](https://user-images.githubusercontent.com/35585791/211011065-a498dead-ad61-4d8a-a3a8-8966f4b6a513.png)

本pr:


![图片](https://user-images.githubusercontent.com/35585791/211011096-a1da39e1-9745-4336-84b8-0c544cc7034c.png)

oneflow master:

<img width="465" alt="图片"
src="https://user-images.githubusercontent.com/35585791/211011132-40ca92a5-6db8-4db5-9676-c39cce012257.png">

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Juncheng <liujuncheng1022@gmail.com>
  • Loading branch information
3 people committed Jan 7, 2023
1 parent 29bfeab commit 82ce240
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion python/oneflow/nn/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""
import os
import math
import warnings

import numpy as np

Expand Down Expand Up @@ -251,9 +252,42 @@ def kaiming_normal_(
return normal_(tensor, 0.0, std)


# The trunc_normal_ implemention is referenced from https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py#L22
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)

with flow.no_grad():
return tensor.normal_(mean, std).clamp_(a, b)
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)

# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)

# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()

# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)

# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor


def constant_(tensor, val):
Expand Down

0 comments on commit 82ce240

Please sign in to comment.