Skip to content

Commit

Permalink
Update test: cast arange to float to avoid division problem
Browse files Browse the repository at this point in the history
  • Loading branch information
bichengying committed Nov 13, 2020
1 parent 59b98a6 commit ce5df1a
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions test/torch_win_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,8 @@ def test_win_put_with_varied_tensor_elements(self):
dims = [1, 2, 3]
for dtype, dim in itertools.product(dtypes, dims):
tensor = torch.FloatTensor(*([DIM_SIZE] * dim)).fill_(1).mul_(rank)
base_tensor = torch.arange(DIM_SIZE**dim).view_as(tensor)/1000
base_tensor = torch.arange(
DIM_SIZE**dim, dtype=torch.float32).view_as(tensor).div(1000)
tensor = self.cast_and_place(tensor, dtype)
base_tensor = self.cast_and_place(base_tensor, dtype)
tensor = tensor + base_tensor
Expand Down Expand Up @@ -461,7 +462,8 @@ def test_win_accumulate_with_varied_tensor_elements(self):
dims = [1, 2, 3]
for dtype, dim in itertools.product(dtypes, dims):
tensor = torch.FloatTensor(*([DIM_SIZE] * dim)).fill_(1).mul_(rank)
base_tensor = torch.arange(DIM_SIZE**dim).view_as(tensor)/1000
base_tensor = torch.arange(
DIM_SIZE**dim, dtype=torch.float32).view_as(tensor).div(1000)
tensor = self.cast_and_place(tensor, dtype)
base_tensor = self.cast_and_place(base_tensor, dtype)
tensor = tensor + base_tensor
Expand Down Expand Up @@ -627,7 +629,8 @@ def test_win_get_with_varied_tensor_elements(self):
dims = [1, 2, 3]
for dtype, dim in itertools.product(dtypes, dims):
tensor = torch.FloatTensor(*([DIM_SIZE] * dim)).fill_(1).mul_(rank)
base_tensor = torch.arange(DIM_SIZE**dim).view_as(tensor)/1000
base_tensor = torch.arange(
DIM_SIZE**dim, dtype=torch.float32).view_as(tensor).div(1000)
tensor = self.cast_and_place(tensor, dtype)
base_tensor = self.cast_and_place(base_tensor, dtype)
tensor = tensor + base_tensor
Expand Down

0 comments on commit ce5df1a

Please sign in to comment.