 from torch.testing._internal.common_utils import skip_but_pass_in_sandcastle_if
 import torch.nn.functional as F
 
+torch.set_default_dtype(torch.double)
+
 NO_NCCL = not hasattr(torch.distributed, "ProcessGroupNCCL")
 
 # batched grad doesn't support data parallel
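Note: the module-level torch.set_default_dtype(torch.double) added above is what allows every explicit dtype=torch.double argument in the hunks below to be dropped, since factory functions such as torch.randn create tensors in the default dtype unless one is passed explicitly. A minimal standalone sketch of that behavior (not part of the diff itself):

import torch

torch.set_default_dtype(torch.double)
t = torch.randn(3)                       # picks up the new default
assert t.dtype == torch.float64

f = torch.randn(3, dtype=torch.float32)  # an explicit dtype still wins
assert f.dtype == torch.float32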
@@ -38,11 +40,11 @@ def __init__(self, t):
             def forward(self, x):
                 return x * self.t_rg + self.t_not_rg
 
-        m = TestModule(torch.randn(100, device='cuda', requires_grad=True, dtype=torch.double))
+        m = TestModule(torch.randn(100, device='cuda', requires_grad=True))
         self.assertTrue(m.t_rg.requires_grad)
 
         dpm = nn.DataParallel(m, [0, 1])
-        inp = torch.randn(2, 100, device='cuda', dtype=torch.double)
+        inp = torch.randn(2, 100, device='cuda')
 
         def fn(t):
             return dpm(inp)
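For context, nn.DataParallel(m, [0, 1]) replicates the wrapped module onto GPUs 0 and 1 and splits each input batch between the replicas on every forward call. A minimal usage sketch (assumes two CUDA devices are available):

import torch
import torch.nn as nn

model = nn.Linear(100, 10).cuda()
dpm = nn.DataParallel(model, device_ids=[0, 1])

inp = torch.randn(2, 100, device='cuda')  # batch of 2: one sample per replica
out = dpm(inp)                            # output gathered back on device_ids[0]
assert out.size() == torch.Size([2, 10])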
@@ -510,11 +512,11 @@ def _test_scatter(self, tensor):
 
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_scatter_cpu(self):
-        self._test_scatter(torch.randn((4, 4), dtype=torch.double))
+        self._test_scatter(torch.randn((4, 4)))
 
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_scatter_gpu(self):
-        self._test_scatter(torch.randn((4, 4), dtype=torch.double).cuda())
+        self._test_scatter(torch.randn((4, 4)).cuda())
 
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed")
     @skip_but_pass_in_sandcastle_if(NO_NCCL, "NCCL needed")
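The scatter helper these tests drive is presumably torch.nn.parallel.scatter (the dp alias used by the gather tests below), which splits a tensor along dim 0 across the listed devices. A minimal sketch of what _test_scatter exercises (assumes two CUDA devices):

import torch
import torch.nn.parallel as dp

full = torch.randn((4, 4))
chunks = dp.scatter(full, [0, 1])  # two (2, 4) chunks, one per GPU
assert [c.get_device() for c in chunks] == [0, 1]
assert all(c.size() == torch.Size([2, 4]) for c in chunks)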
@@ -537,8 +539,8 @@ def forward(self, x):
 
     def _test_gather(self, output_device):
         inputs = (
-            torch.randn(2, 4, device='cuda:0', requires_grad=True, dtype=torch.double),
-            torch.randn(2, 4, device='cuda:1', requires_grad=True, dtype=torch.double),
+            torch.randn(2, 4, device='cuda:0', requires_grad=True),
+            torch.randn(2, 4, device='cuda:1', requires_grad=True),
         )
         result = dp.gather(inputs, output_device)
         self.assertEqual(result.size(), torch.Size([4, 4]))
@@ -548,7 +550,7 @@ def _test_gather(self, output_device):
             self.assertEqual(result.get_device(), output_device)
         else:
             self.assertFalse(result.is_cuda)
-        grad = torch.randn((4, 4), dtype=torch.double)
+        grad = torch.randn((4, 4))
         if output_device != -1:
             grad = grad.cuda(output_device)
         result.backward(grad)
@@ -558,8 +560,8 @@ def _test_gather(self, output_device):
 
         # test scalar inputs, should stack into a vector in this case
         inputs = (
-            torch.randn((), device='cuda:0', requires_grad=True, dtype=torch.double),
-            torch.randn((), device='cuda:1', requires_grad=True, dtype=torch.double),
+            torch.randn((), device='cuda:0', requires_grad=True),
+            torch.randn((), device='cuda:1', requires_grad=True),
         )
         result = dp.gather(inputs, output_device)
         self.assertEqual(result.size(), torch.Size([2]))
@@ -569,7 +571,7 @@ def _test_gather(self, output_device):
             self.assertEqual(result.get_device(), output_device)
         else:
             self.assertFalse(result.is_cuda)
-        grad = torch.randn(2, dtype=torch.double)
+        grad = torch.randn(2)
         if output_device != -1:
             grad = grad.cuda(output_device)
         result.backward(grad)
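The scalar case above depends on torch.nn.parallel.gather stacking zero-dim inputs into a one-dim result instead of concatenating along dim 0, and on -1 denoting a CPU destination. A minimal sketch of that behavior (assumes two CUDA devices):

import torch
import torch.nn.parallel as dp

scalars = (
    torch.randn((), device='cuda:0'),
    torch.randn((), device='cuda:1'),
)
result = dp.gather(scalars, -1)  # zero-dim inputs are stacked into a vector on CPU
assert result.size() == torch.Size([2])
assert not result.is_cuda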