Skip to content

Commit 161ea46

Browse files
Revert "Remove remaining global set_default_dtype calls from tests (pytorch#107246)"
This reverts commit aa8ea1d. Reverted pytorch#107246 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](pytorch#107246 (comment)))
1 parent c68d0a7 commit 161ea46

17 files changed

+874
-900
lines changed

test/distributed/test_data_parallel.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from torch.testing._internal.common_utils import skip_but_pass_in_sandcastle_if
2020
import torch.nn.functional as F
2121

22+
torch.set_default_dtype(torch.double)
23+
2224
NO_NCCL = not hasattr(torch.distributed, "ProcessGroupNCCL")
2325

2426
# batched grad doesn't support data parallel
@@ -38,11 +40,11 @@ def __init__(self, t):
3840
def forward(self, x):
3941
return x * self.t_rg + self.t_not_rg
4042

41-
m = TestModule(torch.randn(100, device='cuda', requires_grad=True, dtype=torch.double))
43+
m = TestModule(torch.randn(100, device='cuda', requires_grad=True))
4244
self.assertTrue(m.t_rg.requires_grad)
4345

4446
dpm = nn.DataParallel(m, [0, 1])
45-
inp = torch.randn(2, 100, device='cuda', dtype=torch.double)
47+
inp = torch.randn(2, 100, device='cuda')
4648

4749
def fn(t):
4850
return dpm(inp)
@@ -510,11 +512,11 @@ def _test_scatter(self, tensor):
510512

511513
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
512514
def test_scatter_cpu(self):
513-
self._test_scatter(torch.randn((4, 4), dtype=torch.double))
515+
self._test_scatter(torch.randn((4, 4)))
514516

515517
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
516518
def test_scatter_gpu(self):
517-
self._test_scatter(torch.randn((4, 4), dtype=torch.double).cuda())
519+
self._test_scatter(torch.randn((4, 4)).cuda())
518520

519521
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
520522
@skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
@@ -537,8 +539,8 @@ def forward(self, x):
537539

538540
def _test_gather(self, output_device):
539541
inputs = (
540-
torch.randn(2, 4, device='cuda:0', requires_grad=True, dtype=torch.double),
541-
torch.randn(2, 4, device='cuda:1', requires_grad=True, dtype=torch.double),
542+
torch.randn(2, 4, device='cuda:0', requires_grad=True),
543+
torch.randn(2, 4, device='cuda:1', requires_grad=True),
542544
)
543545
result = dp.gather(inputs, output_device)
544546
self.assertEqual(result.size(), torch.Size([4, 4]))
@@ -548,7 +550,7 @@ def _test_gather(self, output_device):
548550
self.assertEqual(result.get_device(), output_device)
549551
else:
550552
self.assertFalse(result.is_cuda)
551-
grad = torch.randn((4, 4), dtype=torch.double)
553+
grad = torch.randn((4, 4))
552554
if output_device != -1:
553555
grad = grad.cuda(output_device)
554556
result.backward(grad)
@@ -558,8 +560,8 @@ def _test_gather(self, output_device):
558560

559561
# test scalar inputs, should stack into a vector in this case
560562
inputs = (
561-
torch.randn((), device='cuda:0', requires_grad=True, dtype=torch.double),
562-
torch.randn((), device='cuda:1', requires_grad=True, dtype=torch.double),
563+
torch.randn((), device='cuda:0', requires_grad=True),
564+
torch.randn((), device='cuda:1', requires_grad=True),
563565
)
564566
result = dp.gather(inputs, output_device)
565567
self.assertEqual(result.size(), torch.Size([2]))
@@ -569,7 +571,7 @@ def _test_gather(self, output_device):
569571
self.assertEqual(result.get_device(), output_device)
570572
else:
571573
self.assertFalse(result.is_cuda)
572-
grad = torch.randn(2, dtype=torch.double)
574+
grad = torch.randn(2)
573575
if output_device != -1:
574576
grad = grad.cuda(output_device)
575577
result.backward(grad)

0 commit comments

Comments
 (0)