Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hot fix trunc_normal_ bug #9711

Merged
merged 7 commits into from
Jan 7, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion oneflow/core/profiler/event_recorder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ Maybe<EventRecorder> EventRecorder::CreateKernelEventRecorder(
}
return std::make_shared<EventRecorder>(event);
}
#else // WITH_CUDA
#else // WITH_CUDA
if (pmgr->use_cpu_) {
return std::make_shared<EventRecorder>(
KernelEvent::Create(name, pmgr->record_shapes_ ? shape_getter : nullptr));
Expand Down
36 changes: 35 additions & 1 deletion python/oneflow/nn/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""
import os
import math
import warnings

import numpy as np

Expand Down Expand Up @@ -251,9 +252,42 @@ def kaiming_normal_(
return normal_(tensor, 0.0, std)


# The trunc_normal_ implementation is referenced from https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py#L22
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Fill ``tensor`` in place with samples from a truncated normal distribution.

    Sampling uses the inverse-CDF method: a uniform sample is drawn over the
    CDF interval that corresponds to ``[a, b]``, mapped back through the
    inverse error function, then scaled by ``std`` and shifted by ``mean``.
    A final clamp guards against values that land marginally outside
    ``[a, b]`` after the transform.

    Method based on
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Args:
        tensor: tensor to fill; it is modified in place.
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower cutoff value.
        b: upper cutoff value.

    Returns:
        The same ``tensor`` object that was passed in.

    Warns:
        UserWarning: if ``mean`` lies more than ``2 * std`` outside ``[a, b]``.
    """

    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with flow.no_grad():
        # Do NOT sample with normal_() and clamp: that collapses every
        # out-of-range draw onto a or b instead of truncating the distribution.

        # CDF values of the lower and upper cutoffs.
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [lower, upper], expressed on
        # the [2*lower - 1, 2*upper - 1] scale that erfinv expects.
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)

        # Inverse CDF transform for the normal distribution yields a
        # truncated standard normal (up to the sqrt(2) factor applied below).
        tensor.erfinv_()

        # Transform to the requested mean and std.
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure the result is in the proper range.
        tensor.clamp_(min=a, max=b)
        return tensor


def constant_(tensor, val):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class TestGlobalDivHigherDerivative(flow.unittest.TestCase):
@globaltest
def test_global_div_grad_grad(test_case):
for placement in all_placement():
for i in range(5):
for i in range(500):
_test_global_div_grad_grad_impl(test_case, placement)


Expand Down