forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sum_test.py
45 lines (36 loc) · 1.2 KB
/
sum_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import operator_benchmark as op_bench
import torch
"""Microbenchmarks for sum reduction operator."""
# Configs for the PT sum operator (the original comment wrongly said "add").
# R is the length of the dimension being reduced and V the length of the
# other dimension: SumBenchmark.init builds a (R, V) tensor for dim=0 and a
# (V, R) tensor for dim=1, so `dim` always indexes the length-R axis.
# A "short" and a "long" config set are combined with `+`; judging by the
# helper's name, each cross_product_configs call presumably enumerates every
# combination of the listed values -- confirm against operator_benchmark.
sum_configs = op_bench.cross_product_configs(
    R=[64, 256],  # Length of reduced dimension
    V=[32, 512],  # Length of other dimension
    dim=[0, 1],
    contiguous=[True, False],
    device=['cpu', 'cuda'],
    tags=['short'],
) + op_bench.cross_product_configs(
    R=[1024, 8192],  # Length of reduced dimension
    V=[512, 1024],  # Length of other dimension
    dim=[0, 1],
    contiguous=[True, False],
    device=['cpu', 'cuda'],
    tags=['long'],
)
class SumBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark ``Tensor.sum`` reducing one axis of a 2-D random tensor.

    The input is shaped so that the reduced axis (``dim``) has length ``R``
    and the remaining axis has length ``V``. When ``contiguous`` is false,
    the input is instead a stride-2 view (every other row and column) into a
    buffer twice as large along each axis, so the reduction reads
    non-contiguous memory.
    """

    def init(self, R, V, dim, contiguous, device):
        # Presumably invoked by op_bench once per config entry -- confirm
        # against TorchBenchmarkBase.
        # Place the length-R axis at position `dim`.
        if dim == 0:
            shape = (R, V)
        else:
            shape = (V, R)
        values = torch.rand(shape, device=device)

        if contiguous:
            self.input_tensor = values
        else:
            # Over-allocate by 2x per axis and keep a strided view into it;
            # the unused elements of the buffer are never read.
            buffer = torch.empty([2 * extent for extent in shape], device=device)
            strided_view = buffer[::2, ::2]
            strided_view.copy_(values)
            self.input_tensor = strided_view

        self.dim = dim
        # Name under which op_bench presumably reports these results.
        self.set_module_name("sum")

    def forward(self):
        # The operation being measured: reduce along the configured axis.
        return self.input_tensor.sum(dim=self.dim)
# Register SumBenchmark with op_bench using the sum_configs defined above;
# presumably one PyTorch benchmark case is generated per config entry.
op_bench.generate_pt_test(sum_configs, SumBenchmark)

# When executed as a script, hand control to the operator_benchmark runner.
if __name__ == "__main__":
    op_bench.benchmark_runner.main()