diff --git a/apex/fused_dense/fused_dense.py b/apex/fused_dense/fused_dense.py
index 0f50532c3..97377a423 100644
--- a/apex/fused_dense/fused_dense.py
+++ b/apex/fused_dense/fused_dense.py
@@ -128,6 +128,21 @@ def __init__(self, in_features, intermediate_features, out_features, bias=True):
         self.bias1 = nn.Parameter(torch.randn(intermediate_features))
         self.weight2 = nn.Parameter(torch.randn(out_features, intermediate_features))
         self.bias2 = nn.Parameter(torch.randn(out_features))
+        self.reset_parameters()
+
+
+    def reset_parameters(self):
+        nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
+        nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
+        if self.bias1 is not None:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
+            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+            nn.init.uniform_(self.bias1, -bound, bound)
+        if self.bias2 is not None:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
+            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+            nn.init.uniform_(self.bias2, -bound, bound)
+
     def forward(self, input):
         return fused_dense_gelu_dense_function(input, self.weight1, self.bias1, self.weight2, self.bias2)
 
diff --git a/tests/L0/run_fused_dense/test_gelu.py b/tests/L0/run_fused_dense/test_gelu.py
index 9153bd54c..913fec7ab 100644
--- a/tests/L0/run_fused_dense/test_gelu.py
+++ b/tests/L0/run_fused_dense/test_gelu.py
@@ -7,6 +7,8 @@ class FusedDenseGeluDenseTest(unittest.TestCase):
 
 
     def test_fused_dense_gelu_dense(self) :
+        seed = 0
+        torch.manual_seed(seed)
         batch_size = 4
         in_features = 3
         intermediate_features = 3
@@ -16,7 +18,7 @@ def test_fused_dense_gelu_dense(self) :
         # tst_dtype = torch.float8_e5m2
         tst_dtype = torch.float16
 
-        I = torch.randn(batch_size, in_features, dtype=tst_dtype, device='cuda')
+        I = torch.randn(batch_size, in_features, dtype=tst_dtype, device='cuda').requires_grad_(True)
         denseGelu = fused_dense.FusedDenseGeluDense(in_features, intermediate_features, out_features)
         denseGelu.to(dtype=tst_dtype)
 
@@ -28,10 +30,11 @@ def test_fused_dense_gelu_dense(self) :
         W2 = denseGelu.weight2
         b2 = denseGelu.bias2
 
+        y_tst = denseGelu(I.clone().detach().requires_grad_(True))
+
         C1 = torch.matmul(I, W1.t())+b1
         gelu_output = F.gelu(C1)
         y_ref = torch.matmul(gelu_output, W2.t())+b2
-        y_tst = denseGelu(I)
 
         torch.testing.assert_close(y_ref, y_tst, atol=1e-3, rtol=1e-3, equal_nan=True)