Clean up layer norm tests (#418)
* Bug fix for non-affine layer-norm + add backward unit test

* clean up tests and add tests for a large batch
ngimel authored and mcarilli committed Aug 6, 2019
1 parent 37795aa commit 3ef01fa
Showing 2 changed files with 33 additions and 30 deletions.
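The change targets the non-affine code path (elementwise_affine=False), where the layer norm has no weight or bias. As a minimal usage sketch of that path (not part of the commit, and assuming apex is built with the fused layer-norm CUDA extension and a GPU is available):

import torch
import apex

# With elementwise_affine=False the module has no gamma/beta parameters,
# so the CUDA backward is invoked without them; this is the case the
# kernel change below guards with NULL checks.
ln = apex.normalization.FusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=False).cuda()
x = torch.randn(16, 32, 16, device="cuda", requires_grad=True)
out = ln(x)
out.backward(torch.rand_like(out))  # exercises the backward path fixed in this commit
print(x.grad.shape)  # torch.Size([16, 32, 16])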
10 changes: 6 additions & 4 deletions csrc/layer_norm_cuda_kernel.cu
@@ -795,11 +795,13 @@ void cuda_layer_norm_gradient(
         invvar->data<accscalar_t>(),
         input,
         n1,n2,
-        gamma->data<scalar_t_0>(),
-        beta->data<scalar_t_0>(),
+        // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
+        // if gamma Tensor is NULL on input.
+        gamma != NULL ? gamma->data<scalar_t_0>() : NULL,
+        gamma != NULL ? beta->data<scalar_t_0>() : NULL,
         epsilon,
         grad_input->data<scalar_t_0>(),
-        grad_gamma->data<scalar_t_0>(),
-        grad_beta->data<scalar_t_0>());
+        gamma != NULL ? grad_gamma->data<scalar_t_0>() : NULL,
+        gamma != NULL ? grad_beta->data<scalar_t_0>() : NULL);
     )
 }
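The guards above matter only when the module was constructed with elementwise_affine=False: gamma and beta are then NULL on input, so grad_gamma and grad_beta must not be produced either. A rough parity check one could run for that case (a sketch, not taken from the repository; torch.nn.LayerNorm and the fused module are assumed to use the same default eps):

import torch
import apex

shape = [32, 16]
fused = apex.normalization.FusedLayerNorm(normalized_shape=shape, elementwise_affine=False).cuda()
ref = torch.nn.LayerNorm(shape, elementwise_affine=False).cuda()

x_fused = torch.randn(8, *shape, device="cuda", requires_grad=True)
x_ref = x_fused.detach().clone().requires_grad_(True)
grad_out = torch.rand(8, *shape, device="cuda")

fused(x_fused).backward(grad_out)
ref(x_ref).backward(grad_out)

# Input gradients from the fused and reference implementations should agree closely.
torch.testing.assert_allclose(x_fused.grad, x_ref.grad)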
53 changes: 27 additions & 26 deletions tests/L0/run_fused_layer_norm/test_fused_layer_norm.py
@@ -4,38 +4,39 @@
 
 import torch
 import apex
 from torch.autograd import Variable
 
 
 class TestFusedLayerNorm(unittest.TestCase):
     def setUp(self):
-        self.module = apex.normalization.FusedLayerNorm(normalized_shape=[32, 64], elementwise_affine=False)
-        self.input_ = torch.randn(16, 32, 64)
-
-    def forward_cpu(self, input_):
-        self.module.cpu()
-        return self.module(input_.cpu())
-
-    def forward_cuda(self, input_):
-        self.module.cuda()
-        return self.module(input_.cuda())
-
-    def test_forward_cuda(self):
-        out_ = self.forward_cuda(self.input_)
-        assert out_.is_cuda == True
-
-    def test_forward_cpu(self):
-        out_ = self.forward_cpu(self.input_)
-        assert out_.is_cuda == False
-
-    def test_same_output(self):
-        out_cpu = self.forward_cpu(self.input_)
-        out_cuda = self.forward_cuda(self.input_)
-        torch.testing.assert_allclose(out_cpu, out_cuda.cpu())
+        # bias and weight are set to 0 and 1 respectively, so no need to copy parameters from cpu module to the gpu one
+        self.module_cpu_ = apex.normalization.FusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=False).cpu()
+        self.module_cuda_ = apex.normalization.FusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=False).cuda()
+
+    def _test_same_output(self, batch_size):
+        torch.cuda.manual_seed(42)
+        self.input_ = torch.randn((batch_size, *self.module_cpu_.normalized_shape), device="cpu").requires_grad_(True)
+        self.input_cuda_ = self.input_.cuda().detach().requires_grad_(True)
+        out_cpu_ = self.module_cpu_(self.input_)
+        gO = torch.rand_like(out_cpu_)
+        out_cpu_.backward(gO)
+        out_cuda_ = self.module_cuda_(self.input_cuda_)
+        gO = gO.cuda()
+        out_cuda_.backward(gO)
+        assert out_cpu_.is_cuda == False
+        assert out_cuda_.is_cuda == True
+        torch.testing.assert_allclose(out_cpu_, out_cuda_.cpu())
+        torch.testing.assert_allclose(self.input_.grad, self.input_cuda_.grad.cpu())
+
+    def test_layer_norm(self):
+        self._test_same_output(16)
+
+    def test_large_batch(self):
+        self._test_same_output(65536)
 
 
 class TestFusedLayerNormElemWise(TestFusedLayerNorm):
     def setUp(self):
-        self.module = apex.normalization.FusedLayerNorm(normalized_shape=[32, 64], elementwise_affine=True)
-        self.input_ = torch.randn(16, 32, 64)
-        torch.cuda.manual_seed(42)
+        self.module_cpu_ = apex.normalization.FusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=True).cpu()
+        self.module_cuda_ = apex.normalization.FusedLayerNorm(normalized_shape=[32, 16], elementwise_affine=True).cuda()
+
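One way to run just this test file from the repository root (a usage sketch; the project may provide its own test runner under tests/L0):

# Discover and run the fused layer-norm tests with the standard unittest runner.
import unittest

suite = unittest.defaultTestLoader.discover(
    "tests/L0/run_fused_layer_norm", pattern="test_fused_layer_norm.py")
unittest.TextTestRunner(verbosity=2).run(suite)

Note that test_large_batch allocates (65536, 32, 16) tensors, roughly 128 MiB each in fp32, on both CPU and GPU.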