diff --git a/README.md b/README.md
index 42839a75..80b14d7b 100644
--- a/README.md
+++ b/README.md
@@ -125,7 +125,17 @@ print(torch.norm(Z))
 ```
 **Note**: you don't need Pytorch geometric to use our kernels. When
 `deterministic=False`, the `sender` and `receiver` indices can have
-arbitrary order.
+arbitrary order.
+
+**New:** If you're working in FP32 precision and want
+higher accuracy during graph convolution, we offer a Kahan
+summation variant of our deterministic algorithm:
+
+```python
+tp_conv_kahan = oeq.TensorProductConv(problem, torch_op=True, deterministic=True, kahan=True)
+Z = tp_conv_kahan.forward(X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm)
+print(torch.norm(Z))
+```
 
 ## Installation
 We currently support Linux systems only.
@@ -172,6 +182,7 @@ python tests/benchmark.py -o outputs/uvu uvu --plot
 python tests/benchmark.py -o outputs/uvw uvw --plot
 python tests/benchmark.py -o outputs/roofline roofline --plot
 python tests/benchmark.py -o outputs/conv conv --plot --data data/molecular_structures
+python tests/benchmark.py -o outputs/kahan_conv kahan_conv --data data/molecular_structures/
 ```
 
 If your GPU has limited memory, you might want to try
diff --git a/openequivariance/benchmark/ConvBenchmarkSuite.py b/openequivariance/benchmark/ConvBenchmarkSuite.py
index cca53f6e..35b17600 100644
--- a/openequivariance/benchmark/ConvBenchmarkSuite.py
+++ b/openequivariance/benchmark/ConvBenchmarkSuite.py
@@ -24,22 +24,22 @@ def __init__(self, configs,
             num_warmup = 10,
             num_iter = 30,
             reference_impl=None,
-            torch_op=True,
             test_name=None,
-            prng_seed = 12345):
+            prng_seed = 12345,
+            correctness_threshold = 1e-5):
         self.configs = configs
         self.num_warmup = num_warmup
         self.num_iter = num_iter
         self.reference_impl = reference_impl
         self.prng_seed = 12345
-        self.correctness_threshold = 1e-5
-        self.torch_op = torch_op
+        self.correctness_threshold = correctness_threshold
         self.exp_count = 0
         self.test_name = test_name
         self.millis_since_epoch = round(time.time() * 1000)
 
-    def run(self, graph, implementations, direction, output_folder=None, correctness=True, double_backward_correctness=False, benchmark=True):
+    def run(self, graph, implementations, direction, output_folder=None,
+            correctness=True, benchmark=True, high_precision_ref=False):
         if output_folder is None:
             if oeq._check_package_editable():
                 output_folder = oeq._editable_install_output_path / f"{self.millis_since_epoch}"
@@ -65,20 +65,15 @@ def run(self, graph, implementations, direction, output_folder=None, correctness
         for impl in implementations:
             tc_name = f"{config}, {impl.name()}"
             logger.info(f'Starting {tc_name}, graph {graph.name}, {direction}')
-            conv = impl(config, torch_op=self.torch_op)
-
-            if double_backward_correctness:
-                double_backward_correctness = conv.test_correctness_double_backward(self.graph,
-                        thresh=self.correctness_threshold,
-                        prng_seed=self.prng_seed,
-                        reference_implementation=self.reference_impl)
+            conv = impl(config)
 
             if direction == "forward":
                 if correctness:
                     correctness = conv.test_correctness_forward(graph,
                             thresh=self.correctness_threshold,
                             prng_seed=self.prng_seed,
-                            reference_implementation=self.reference_impl)
+                            reference_implementation=self.reference_impl,
+                            high_precision_ref=high_precision_ref)
 
                 if benchmark:
                     benchmark = conv.benchmark_forward(self.num_warmup,
@@ -90,23 +85,33 @@ def run(self, graph, implementations, direction, output_folder=None, correctness
                     correctness = conv.test_correctness_backward(graph,
                             thresh=self.correctness_threshold,
                             prng_seed=self.prng_seed,
-                            reference_implementation=self.reference_impl)
+                            reference_implementation=self.reference_impl,
+                            high_precision_ref=high_precision_ref)
 
                 if benchmark:
                     benchmark = conv.benchmark_backward(self.num_warmup,
                             self.num_iter,
                             graph,
                             prng_seed=12345)
+
+            if direction == "double_backward":
+                if correctness:
+                    correctness = conv.test_correctness_double_backward(self.graph,
+                            thresh=self.correctness_threshold,
+                            prng_seed=self.prng_seed,
+                            reference_implementation=self.reference_impl,
+                            high_precision_ref=high_precision_ref)
+
+                assert not benchmark
 
             result = {
                 "config": str(config),
                 "irrep_dtype": str(config.irrep_dtype),
                 "weight_dtype": str(config.weight_dtype),
-                "torch_overhead_included": self.torch_op,
+                "torch_overhead_included": conv.torch_op,
                 "direction": direction,
                 "graph": graph.name,
                 "name": impl.name(),
                 "correctness": correctness,
-                "benchmark": benchmark,
-                "double_backward_correctness": double_backward_correctness
+                "benchmark": benchmark
             }
 
             fname = pathlib.Path(f"{output_folder}/{self.exp_count}_{impl.name()}_{graph.name}.json")
diff --git a/openequivariance/implementations/ComputationSchedule.py b/openequivariance/implementations/ComputationSchedule.py
index ebb513dd..54ae5d9c 100644
--- a/openequivariance/implementations/ComputationSchedule.py
+++ b/openequivariance/implementations/ComputationSchedule.py
@@ -5,6 +5,11 @@ from openequivariance.implementations.TensorProductBase import *
 logger = getLogger()
 
+class SMEMCapacityException(Exception):
+    def __init__(self, message):
+        self.message = message
+        super().__init__(self.message)
+
 class IrrepMapping:
     '''
     Maps irreps from a source to a destination set.
@@ -104,7 +109,7 @@ def create_schedule_case2(instructions, memory_per_warp, calculate_smem, directi
                     segments.append((cL1, cL2, cL3, cinst))
                     cL3, cinst = set(), []
                 else:
-                    raise Exception(f"{direction.title()} scheduling failed, memory allocation too small to accomodate segment!")
+                    raise SMEMCapacityException(f"{direction.title()} scheduling failed, memory allocation too small to accommodate segment!")
             else:
                 cL3.add(w)
                 cinst.append(inst_idx)
@@ -130,7 +135,7 @@ def create_schedule_case3(instructions, memory_per_warp, calculate_smem, directi
                     segments.append((cL1, cL2, cL3, cinst))
                     cL1, cL2, cL3, cinst = set(), set(), set(), []
                 else:
-                    raise Exception(f"{direction.title()} scheduling failed, memory allocation too small to accomodate segment!")
+                    raise SMEMCapacityException(f"{direction.title()} scheduling failed, memory allocation too small to accommodate segment!")
             else:
                 cL1.add(u)
                 cL2.add(v)
@@ -245,10 +250,15 @@ def __init__(self,
             weight_dtype,
             include_scratch=False,
             stream_weights=False,
-            schedule_type=2):
+            schedule_type=2,
+            kahan=False):
         '''
         smem_limit: size of available shared memory in bytes
         '''
+        self.kahan = kahan
+        if kahan:
+            assert irrep_dtype == weight_dtype == np.float32
+            # Note: Kahan accumulation does not yet handle irrep variances; easy to add later.
 
         self.total_warps = warps_per_block * block_count
 
@@ -288,10 +298,16 @@ def calculate_forward_smem(L1_set, L2_set, L3_set, inst_idxs):
                 "L1": {"size": sum([self.L1[el].dim for el in L1_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr},
                 "L2": {"size": sum([self.L2[el].dim for el in L2_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr},
                 "L3": {"size": sum([self.L3[el].dim for el in L3_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr},
+                "L3_kahan": {"size": 0, "dtype": self.irrep_dtype_cstr},
                 "weights": {"size": 0, "dtype": self.weight_dtype_cstr},
                 "scratch": {"size": 0, "dtype": self.weight_dtype_cstr}
             }
 
+            if kahan:
smem["L3_kahan"]["size"] = smem["L3"]["size"] + else: + smem.pop("L3_kahan") + weights_smem = 0 for inst_idx in inst_idxs: inst = self.new_instructions[inst_idx] @@ -325,6 +341,7 @@ def calculate_backward_smem(L1_set, L2_set, L3_set, inst_idxs, smem = { "L1": {"size": sum([self.L1[el].dim for el in L1_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr}, "L1_grad": {"size": sum([self.L1[el].dim for el in L1_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr}, + "L1_kahan": {"size": 0, "dtype": self.irrep_dtype_cstr}, "L2": {"size": sum([self.L2[el].dim for el in L2_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr}, "L2_grad": {"size": sum([self.L2[el].dim for el in L2_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr}, "L3_grad": {"size": sum([self.L3[el].dim for el in L3_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr}, @@ -333,6 +350,11 @@ def calculate_backward_smem(L1_set, L2_set, L3_set, inst_idxs, "scratch": {"size": 0, "dtype": self.weight_dtype_cstr} } + if kahan: + smem["L1_kahan"]["size"] = smem["L1"]["size"] + else: + smem.pop("L1_kahan") + if L2_dgrad: smem["L2_dgrad"] = {"size": smem["L2"]["size"], "dtype": self.irrep_dtype_cstr} @@ -376,11 +398,11 @@ def calculate_backward_smem(L1_set, L2_set, L3_set, inst_idxs, schedule2_succeeded = False try: if schedule_type != 2: - raise Exception("Asked for schedule case 3.") + raise SMEMCapacityException("Asked for schedule case 3.") self.segments = create_schedule_case2(self.new_instructions, self.memory_per_warp, calculate_smem, direction) logger.info(f"{direction.title()} case 2 scheduling succeeded with {len(self.segments)} segments.") schedule2_succeeded = True - except Exception as e: + except SMEMCapacityException as e: self.segments = create_schedule_case3(self.new_instructions, self.memory_per_warp, calculate_smem, direction) logger.info(f"{direction.title()} case 3 scheduling succeeded with {len(self.segments)} segments.") diff --git a/openequivariance/implementations/convolution/ConvolutionBase.py b/openequivariance/implementations/convolution/ConvolutionBase.py index a806850e..b645f896 100644 --- a/openequivariance/implementations/convolution/ConvolutionBase.py +++ b/openequivariance/implementations/convolution/ConvolutionBase.py @@ -1,3 +1,4 @@ +import copy import numpy as np import numpy.linalg as la from openequivariance.extlib import * @@ -11,7 +12,6 @@ def flops_data_per_tp(config, direction): ''' Assumes all interactions are "uvu" for now - Returns (flops_per_tp, data_per_tp, nnz) ''' bytes_per_word = np.dtype(config.irrep_dtype).itemsize @@ -97,7 +97,7 @@ def allocate_workspace(self, size_bytes): else: self.workspace_buffer = DeviceBuffer(size_bytes) self.workspace_ptr = self.workspace_buffer.data_ptr() - logger.info(f"Deterministic Convolution requires {size_bytes // 1000000}MB of workspace.") + logger.info(f"Convolution requires {size_bytes // 1000000}MB of workspace.") @staticmethod def name(): @@ -187,48 +187,37 @@ def backward_cpu(self, def test_correctness_forward(self, graph, thresh, prng_seed, reference_implementation=None, - check_reproducible=True, torch_op=False): - L1, L2, L3 = self.L1, self.L2, self.L3 + check_reproducible=True, high_precision_ref=False): if reference_implementation is None: from openequivariance.implementations.convolution.E3NNConv import E3NNConv reference_implementation = E3NNConv - result = { - "thresh": thresh - } + result = {"thresh": thresh} in1, in2, weights, out = get_random_buffers_forward_conv(self.config, graph.node_count, graph.nnz, prng_seed) + 
+        ref_in1, ref_in2, ref_weights, ref_out = [buf.copy() for buf in [in1, in2, weights, out]]
 
-        ref_tp = reference_implementation(self.config)
-        ref_out = out.copy()
-
-        if ref_tp.torch_op:
-            ref_out_torch = None
-            if not ref_tp.deterministic:
-                ref_out_torch = ref_tp.forward(
-                    L1_in=torch.tensor(in1, device='cuda'),
-                    L2_in=torch.tensor(in2, device='cuda'),
-                    weights=torch.tensor(weights, device='cuda'),
-                    rows = torch.tensor(graph.rows, device='cuda'),
-                    cols = torch.tensor(graph.cols, device='cuda'))
-            else:
-                ref_out_torch = ref_tp.forward(
-                    torch.tensor(in1, device='cuda'),
-                    torch.tensor(in2, device='cuda'),
-                    torch.tensor(weights, device='cuda'),
-                    torch.tensor(graph.rows, device='cuda'),
-                    torch.tensor(graph.cols, device='cuda'),
-                    torch.tensor(graph.transpose_perm, device='cuda'))
-            ref_out[:] = ref_out_torch.cpu().numpy()
-        else:
-            ref_tp.forward_cpu(
-                L1_in=in1.copy(),
-                L2_in=in2.copy(),
-                weights=weights.copy(),
-                L3_out=ref_out,
-                graph=graph)
+        reference_config = self.config
+        if high_precision_ref:
+            reference_config = copy.deepcopy(self.config)
+            reference_config.irrep_dtype = np.float64
+            reference_config.weight_dtype = np.float64
+            ref_in1, ref_in2, ref_weights, ref_out = [np.array(el, dtype=np.float64)
+                                                      for el in [ref_in1, ref_in2, ref_weights, ref_out]]
+
+        args = {"L1_in": ref_in1, "L2_in": ref_in2, "weights": ref_weights,
+                "rows": graph.rows, "cols": graph.cols}
+
+        ref_tp = reference_implementation(reference_config)
+        if ref_tp.deterministic:
+            args["transpose_perm"] = graph.transpose_perm
+
+        for key in args:
+            args[key] = torch.tensor(args[key], device='cuda')
+
+        ref_out[:] = ref_tp.forward(**args).cpu().numpy()
 
         test_out = out.copy()
         self.forward_cpu(
@@ -458,32 +447,36 @@ def calculate_bench_stats(self, direction, ops_per_tp, data_per_tp, time_millis,
         logger.info(f"{bcolors.OKCYAN}Avg. Bandwidth: {bcolors.ENDC} {bcolors.OKGREEN}{np.mean(bandwidth_gbps):.2f} ± {np.std(bandwidth_gbps):.2f} GBPs{bcolors.ENDC}")
         return result
 
-    def test_correctness_backward(self, graph, thresh, prng_seed, reference_implementation=None):
+    def test_correctness_backward(self, graph, thresh, prng_seed, reference_implementation=None, high_precision_ref=False):
         L1, L2, L3 = self.L1, self.L2, self.L3
 
         if reference_implementation is None:
             from openequivariance.implementations.convolution.E3NNConv import E3NNConv
             reference_implementation = E3NNConv
 
-        result = {
-            "thresh": thresh
-        }
+        result = {"thresh": thresh}
 
-        in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad = get_random_buffers_backward_conv(self.config, graph.node_count, graph.nnz, prng_seed)
+        buffers = get_random_buffers_backward_conv(self.config, graph.node_count, graph.nnz, prng_seed)
+        reference_buffers = [buf.copy() for buf in buffers]
+        reference_problem = self.config
 
-        ref_tp = reference_implementation(self.config)
+        if high_precision_ref:
+            reference_problem = copy.deepcopy(self.config)
+            reference_problem.irrep_dtype = np.float64
+            reference_problem.weight_dtype = np.float64
+            reference_buffers = [np.array(el, dtype=np.float64) for el in reference_buffers]
 
-        ref_weights_grad = weights_grad.copy()
-        ref_in1_grad = in1_grad.copy()
-        ref_in2_grad = in2_grad.copy()
+        ref_tp = reference_implementation(reference_problem)
+        in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad = buffers
+        ref_in1, ref_in2, ref_out_grad, ref_weights, ref_weights_grad, ref_in1_grad, ref_in2_grad = reference_buffers
 
         ref_tp.backward_cpu(
-            L1_in=in1.copy(),
+            L1_in=ref_in1,
             L1_grad=ref_in1_grad,
-            L2_in=in2.copy(),
+            L2_in=ref_in2,
             L2_grad=ref_in2_grad,
-            L3_grad=out_grad.copy(),
-            weights=weights.copy(),
+            L3_grad=ref_out_grad,
+            weights=ref_weights,
             weights_grad=ref_weights_grad,
             graph=graph)
 
@@ -497,8 +490,8 @@ def test_correctness_backward(self, graph, thresh, prng_seed, reference_implemen
             L1_grad=test_in1_grad,
             L2_in=in2.copy(),
             L2_grad=test_in2_grad,
-            L3_grad=out_grad.copy(),
-            weights=weights.copy(),
+            L3_grad=out_grad.copy(),
+            weights=weights.copy(),
             weights_grad=test_weights_grad,
             graph=graph)
 
@@ -510,13 +503,13 @@
 
         return result
 
-    def test_correctness_double_backward(self, graph, thresh, prng_seed, reference_implementation=None):
+    def test_correctness_double_backward(self, graph, thresh, prng_seed, reference_implementation=None, high_precision_ref=False):
         global torch
         import torch
         assert(self.torch_op)
 
+        buffers = get_random_buffers_backward_conv(self.config, graph.node_count, graph.nnz, prng_seed)
-        in1, in2, out_grad, weights, _, _, _ = get_random_buffers_backward_conv(self.config, graph.node_count, graph.nnz, prng_seed)
 
         rng = np.random.default_rng(seed=prng_seed * 2)
         dummy_grad_value = rng.standard_normal(1)[0]
 
@@ -524,11 +517,22 @@
         from openequivariance.implementations.convolution.E3NNConv import E3NNConv
         reference_implementation = E3NNConv
 
-        reference_tp = reference_implementation(self.config, torch_op=True)
+        reference_problem = self.config
+        if high_precision_ref:
+            reference_problem = copy.deepcopy(self.config)
+            reference_problem.irrep_dtype = np.float64
+            reference_problem.weight_dtype = np.float64
+
+        reference_tp = reference_implementation(reference_problem, torch_op=True)
 
         result = {"thresh": thresh}
         tensors = []
         for i, tp in enumerate([self, reference_tp]):
+            in1, in2, out_grad, weights, _, _, _ = [buf.copy() for buf in buffers]
+
+            if i == 1 and high_precision_ref:
+                in1, in2, out_grad, weights, _, _, _ = [np.array(el, dtype=np.float64) for el in buffers]
+
             in1_torch = torch.tensor(in1, device='cuda', requires_grad=True)
             in2_torch = torch.tensor(in2, device='cuda', requires_grad=True)
 
@@ -566,10 +570,10 @@ def test_correctness_double_backward(self, graph, thresh, prng_seed, reference_i
                 self.reorder_weights_oeq_to_e3nn(weights_grad_copy, weights_grad, not self.config.shared_weights)
 
             tensors.append((
-                out_grad_torch.grad.detach().cpu().numpy(),
-                in1_torch.grad.detach().cpu().numpy(),
-                in2_torch.grad.detach().cpu().numpy(),
-                weights_grad
+                out_grad_torch.grad.detach().cpu().numpy().copy(),
+                in1_torch.grad.detach().cpu().numpy().copy(),
+                in2_torch.grad.detach().cpu().numpy().copy(),
+                weights_grad.copy()
             ))
 
         for name, to_check, ground_truth in [
diff --git a/openequivariance/implementations/convolution/LoopUnrollConv.py b/openequivariance/implementations/convolution/LoopUnrollConv.py
index 5d6b286c..676397a0 100644
--- a/openequivariance/implementations/convolution/LoopUnrollConv.py
+++ b/openequivariance/implementations/convolution/LoopUnrollConv.py
@@ -1,5 +1,5 @@
 from openequivariance.implementations.convolution.ConvolutionBase import *
-from openequivariance.implementations.ComputationSchedule import ComputationSchedule
+from openequivariance.implementations.ComputationSchedule import ComputationSchedule, SMEMCapacityException
 from openequivariance.implementations.TensorProduct import *
 from openequivariance.templates.jinja_utils import *
 from openequivariance.extlib import *
@@ -8,10 +8,13 @@ class LoopUnrollConv(ConvolutionBase):
     def __init__(self, config,
             idx_dtype=np.int64,
-            torch_op=False, deterministic=False):
+            torch_op=False, deterministic=False, kahan=False):
         super().__init__(config, idx_dtype, torch_op, deterministic)
         L1, L2, L3 = self.L1, self.L2, self.L3
 
+        if kahan:
+            assert deterministic
+
         env = get_jinja_environment()
         template = env.get_template("loop_unroll_conv_atomic.cuh")
         dp = DeviceProp(0)
@@ -36,7 +39,8 @@ def generate_forward_schedule(warps_per_block):
                 schedule_type=forward_schedule_type,
                 warp_size=dp.warpsize,
                 include_scratch=self.is_uvw,
-                stream_weights=self.is_uvw)
+                stream_weights=self.is_uvw,
+                kahan=kahan)
 
         def generate_backward_schedule(warps_per_block):
             self.backward_schedule = ComputationSchedule(self.config,
@@ -48,7 +52,8 @@ def generate_backward_schedule(warps_per_block):
                 schedule_type=backward_schedule_type,
                 warp_size=dp.warpsize,
                 include_scratch=self.is_uvw,
-                stream_weights=self.is_uvw)
+                stream_weights=self.is_uvw,
+                kahan=kahan)
 
         def generate_double_backward_schedule(warps_per_block):
             self.double_backward_schedule = ComputationSchedule(self.config,
@@ -61,7 +66,8 @@ def generate_double_backward_schedule(warps_per_block):
                 weight_dtype = config.weight_dtype,
                 include_scratch=self.is_uvw,
                 stream_weights=self.is_uvw,
-                schedule_type=3)
+                schedule_type=3,
+                kahan=kahan)
 
         scheduler_generators = [generate_forward_schedule, generate_backward_schedule, generate_double_backward_schedule]
 
@@ -71,10 +77,10 @@ def generate_double_backward_schedule(warps_per_block):
             try:
                 generate_schedule(warp_count)
                 break
-            except Exception as e:
+            except SMEMCapacityException as e:
                 warp_count -= 1
                 if warp_count == 0:
-                    raise RuntimeError("Tensor product schedule generation failed, shared memory inadequate!")
+                    raise SMEMCapacityException("Tensor product schedule generation failed, shared memory inadequate!")
 
         if not deterministic:
             for segment in self.forward_schedule.segments:
@@ -148,10 +154,11 @@ def generate_double_backward_schedule(warps_per_block):
         #with open("scratch.txt", "w") as f:
         #    f.write(self.jit_kernel)
 
-        self.reorder_weights_e3nn_to_oeq = lambda input, output, has_batch_dim: \
-            self.forward_schedule.reorder_weights(input, output, "forward", has_batch_dim)
-        self.reorder_weights_oeq_to_e3nn = lambda input, output, has_batch_dim: \
-            self.forward_schedule.reorder_weights(input, output, "backward", has_batch_dim)
+    def reorder_weights_e3nn_to_oeq(self, input, output, has_batch_dim):
+        return self.forward_schedule.reorder_weights(input, output, "forward", has_batch_dim)
+
+    def reorder_weights_oeq_to_e3nn(self, input, output, has_batch_dim):
+        return self.forward_schedule.reorder_weights(input, output, "backward", has_batch_dim)
 
     @staticmethod
     def name():
diff --git a/openequivariance/implementations/convolution/TensorProductConv.py b/openequivariance/implementations/convolution/TensorProductConv.py
index af1ddd32..dca2a479 100644
--- a/openequivariance/implementations/convolution/TensorProductConv.py
+++ b/openequivariance/implementations/convolution/TensorProductConv.py
@@ -4,40 +4,50 @@
 import numpy as np
 from typing import Optional
+import types
 
 class TensorProductConv(torch.nn.Module, LoopUnrollConv):
     '''
     PyTorch-specialized dispatcher class.
     '''
-    def __init__(self, config, idx_dtype=np.int64, torch_op=True, deterministic=False):
+    def __init__(self, config, idx_dtype=np.int64, torch_op=True, deterministic=False, kahan=False):
         torch.nn.Module.__init__(self)
         LoopUnrollConv.__init__(self, config, idx_dtype=np.int64,
-                torch_op=torch_op, deterministic=deterministic)
+                torch_op=torch_op, deterministic=deterministic, kahan=kahan)
 
         self.dummy_transpose_perm = torch.zeros(1, dtype=torch.int64, device='cuda')
         self.weight_numel = self.config.weight_numel
 
-        if extlib.TORCH_COMPILE:
-            self.forward = self.forward_deterministic if deterministic else self.forward_atomic
+        if not extlib.TORCH_COMPILE:
+            self.forward = types.MethodType(LoopUnrollConv.forward, self)
+
+    def forward(self, L1_in: torch.Tensor, L2_in: torch.Tensor, W: torch.Tensor,
+                rows: torch.Tensor, cols: torch.Tensor,
+                sender_perm: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if sender_perm is None:
+            return torch.ops.torch_tp_jit.jit_conv_forward(self.internal, L1_in, L2_in, W, rows, cols, self.workspace_buffer, self.dummy_transpose_perm)
+        else:
+            return torch.ops.torch_tp_jit.jit_conv_forward(self.internal, L1_in, L2_in, W, rows, cols, self.workspace_buffer, sender_perm)
 
     @staticmethod
     def name():
         return LoopUnrollConv.name()
-
-    def forward_deterministic(self, L1_in: torch.Tensor, L2_in:
-        torch.Tensor, W: torch.Tensor,
-        rows: torch.Tensor, cols: torch.Tensor, sender_perm: torch.Tensor) -> torch.Tensor:
-        return torch.ops.torch_tp_jit.jit_conv_forward(self.internal, L1_in, L2_in, W, rows, cols, self.workspace_buffer, sender_perm)
-
-    def forward_atomic(self, L1_in: torch.Tensor, L2_in:
-        torch.Tensor, W: torch.Tensor,
-        rows: torch.Tensor, cols: torch.Tensor) -> torch.Tensor:
-        return torch.ops.torch_tp_jit.jit_conv_forward(self.internal, L1_in, L2_in, W, rows, cols, self.workspace_buffer, self.dummy_transpose_perm)
-
+
 # ==================================================================
 # Reference implementations for benchmarking
 
+class TensorProductConvKahan(TensorProductConv):
+    def __init__(self, config,
+            idx_dtype=np.int64,
+            torch_op=True):
+        super().__init__(config, idx_dtype, torch_op, deterministic=True, kahan=True)
+
+    @staticmethod
+    def name():
+        return "LoopUnrollConvKahan"
+
+
 class TensorProductConvDeterministic(TensorProductConv):
     def __init__(self, config,
             idx_dtype=np.int64,
diff --git a/openequivariance/implementations/utils.py b/openequivariance/implementations/utils.py
index f37b8384..aafb9715 100644
--- a/openequivariance/implementations/utils.py
+++ b/openequivariance/implementations/utils.py
@@ -54,10 +54,13 @@ def filter_and_analyze_problem(problem):
     Centralized function that stops unhandled problem configurations, returns
    a dictionary of useful information about the problem.
     '''
-    for inst in problem.instructions:
+    for i, inst in enumerate(problem.instructions):
         assert inst.connection_mode == problem.instructions[0].connection_mode, \
             f"All instructions must have the same connection mode, got {inst.connection_mode} and {problem.instructions[0].connection_mode}"
 
+        assert inst.has_weight, \
+            f"All instructions must have trainable weights, got {inst.has_weight} at index {i}"
+
     assert problem.instructions[0].connection_mode in ["uvu", "uvw"], \
         f"Connection mode must be 'uvu' or 'uvw', got {problem.instructions[0].connection_mode}"
 
diff --git a/openequivariance/templates/loop_unroll_conv_det.cuh b/openequivariance/templates/loop_unroll_conv_det.cuh
index c7cee536..5392da17 100644
--- a/openequivariance/templates/loop_unroll_conv_det.cuh
+++ b/openequivariance/templates/loop_unroll_conv_det.cuh
@@ -63,6 +63,34 @@ __global__ void {{name}}(void* workspace, IRREP_T* dst_ptr) {
 
 {{ generate_fixup_kernel("fixup_forward", forward_schedule.launch_config.warp_size, forward_schedule.L3.dim, forward_workspace_offset) }}
 
+// Kahan (compensated) summation: folds the values in c_arr into sum_arr,
+// leaving the round-off error of each addition behind in c_arr so it can
+// be re-absorbed on the next call.
+template<int ROW_LEN>
+__device__ __forceinline__ void kahanAdd(IRREP_T* c_arr, IRREP_T* sum_arr, int lane_id) {
+    c_arr += lane_id;
+    sum_arr += lane_id;
+    #pragma unroll
+    for(int j = 0; j < ROW_LEN; j += THREADS_PER_WARP) {
+        if(j >= ROW_LEN - THREADS_PER_WARP) {
+            if(lane_id < ROW_LEN - j) {
+                IRREP_T y = c_arr[j];
+                IRREP_T sum = sum_arr[j];
+                IRREP_T t = sum + y;
+                c_arr[j] = y - (t - sum);
+                sum_arr[j] = t;
+            }
+        }
+        else {
+            IRREP_T y = c_arr[j];
+            IRREP_T sum = sum_arr[j];
+            IRREP_T t = sum + y;
+            c_arr[j] = y - (t - sum);
+            sum_arr[j] = t;
+        }
+    }
+}
+
 __global__ void forward(
     IRREP_T* L1_in,
     IRREP_T* L2_in,
@@ -99,6 +124,12 @@ __global__ void forward(
         bool firstSegment = true;
         ROW_OPERATION({{segment.L3.dim}}, j, L3_smem[j + lane_id] = 0.0f;)
 
+        {%- set ns = namespace(L3_accum="L3_smem") %}
+        {%- if forward_schedule.kahan %}
+        ROW_OPERATION({{segment.L3.dim}}, j, L3_kahan_smem[j + lane_id] = 0.0f;)
+        {%- set ns.L3_accum="L3_kahan_smem" %}
+        {%- endif %}
+
         for(size_t i = start; i < end; i++) {
             {{idx_type}} row = rows[i];
             {{idx_type}} col = cols[i];
@@ -115,9 +146,13 @@ __global__ void forward(
             {%- endif %}
 
             __syncwarp();
-            forward_loop_unroll_{{i}}(L1_smem, L2_smem, w, weights_smem, L3_smem, scratch_smem, lane_id);
+            forward_loop_unroll_{{i}}(L1_smem, L2_smem, w, weights_smem, {{ns.L3_accum}}, scratch_smem, lane_id);
             __syncwarp();
 
+            {%- if forward_schedule.kahan %}
+            kahanAdd<{{segment.L3.dim}}>(L3_kahan_smem, L3_smem, lane_id);
+            {%- endif %}
+
             bool changeRow = (i < end - 1) && (row != rows[i+1]);
 
             if(changeRow || i == end - 1) {
@@ -179,6 +214,12 @@ __global__ void backward(
         bool firstSegment = true;
         ROW_OPERATION({{segment.L1.dim}}, j, L1_grad_smem[j + lane_id] = 0.0f;)
 
+        {%- set ns = namespace(L1_accum="L1_grad_smem") %}
+        {%- if backward_schedule.kahan %}
+        ROW_OPERATION({{segment.L1.dim}}, j, L1_kahan_smem[j + lane_id] = 0.0f;)
+        {%- set ns.L1_accum="L1_kahan_smem" %}
+        {%- endif %}
+
         for(size_t i = start; i < end; i++) {
             {{idx_type}} row = rows[i];
             {{idx_type}} col = cols[i];
             {{idx_type}} tperm_idx = tperm[i];
@@ -214,9 +255,13 @@ __global__ void backward(
 
             __syncwarp();
             backward_loop_unroll_{{i}}(L1_smem, L2_smem, w, weights_smem, L3_grad_smem,
-                L1_grad_smem, L2_grad_smem, wgrad, weights_grad_smem, scratch_smem, lane_id);
+                {{ns.L1_accum}}, L2_grad_smem, wgrad, weights_grad_smem, scratch_smem, lane_id);
             __syncwarp();
 
+            {%- if backward_schedule.kahan %}
+            kahanAdd<{{segment.L1.dim}}>(L1_kahan_smem, L1_grad_smem, lane_id);
+            {%- endif %}
+
             bool changeRow = (i < end - 1) && (col != cols[i+1]);
             if(changeRow || i == end - 1) {
                 IRREP_T* dst = l1_grad_shft;
@@ -274,6 +319,12 @@ __global__ void double_backward_A(
         bool firstSegment = true;
         ROW_OPERATION({{segment.L3.dim}}, j, L3_smem[j + lane_id] = 0.0f;)
 
+        {%- set ns = namespace(L3_accum="L3_smem") %}
+        {%- if forward_schedule.kahan %}
+        ROW_OPERATION({{segment.L3.dim}}, j, L3_kahan_smem[j + lane_id] = 0.0f;)
+        {%- set ns.L3_accum="L3_kahan_smem" %}
+        {%- endif %}
+
         for(size_t i = start; i < end; i++) {
             unsigned {{idx_type}} row = rows[i];
             unsigned {{idx_type}} col = cols[i];
@@ -316,10 +367,14 @@ __global__ void double_backward_A(
             }
 
             __syncwarp();
-            forward_loop_unroll_{{i}}(L1_smem, L2_smem, w_buffer, weights_smem, L3_smem, scratch_smem, lane_id);
+            forward_loop_unroll_{{i}}(L1_smem, L2_smem, w_buffer, weights_smem, {{ns.L3_accum}}, scratch_smem, lane_id);
             __syncwarp();
             }
 
+            {%- if forward_schedule.kahan %}
+            kahanAdd<{{segment.L3.dim}}>(L3_kahan_smem, L3_smem, lane_id);
+            {%- endif %}
+
             bool changeRow = (i < end - 1) && (row != rows[i+1]);
             if(changeRow || i == end - 1) {
                 IRREP_T* dst = l3;
@@ -379,6 +434,12 @@ __global__ void double_backward_B(
         bool firstSegment = true;
         ROW_OPERATION({{segment.L1.dim}}, j, L1_grad_smem[j + lane_id] = 0.0f;)
 
+        {%- set ns = namespace(L1_accum="L1_grad_smem") %}
+        {%- if backward_schedule.kahan %}
+        ROW_OPERATION({{segment.L1.dim}}, j, L1_kahan_smem[j + lane_id] = 0.0f;)
+        {%- set ns.L1_accum="L1_kahan_smem" %}
+        {%- endif %}
+
         for(size_t i = start; i < end; i++) {
             unsigned {{idx_type}} row = rows[i];
             unsigned {{idx_type}} col = cols[i];
             {{idx_type}} tperm_idx = tperm[i];
@@ -437,10 +498,14 @@ __global__ void double_backward_B(
 
             __syncwarp();
             double_backward_loop_unroll_{{i}}(L1_smem, L2_buffer, w_buffer, weights_smem, L3_grad_smem,
-                L1_grad_smem, L2_grad_smem, L2_dgrad_buffer, n, wgrad, weights_grad_smem, scratch_smem, lane_id);
+                {{ns.L1_accum}}, L2_grad_smem, L2_dgrad_buffer, n, wgrad, weights_grad_smem, scratch_smem, lane_id);
             __syncwarp();
             }
 
+            {%- if backward_schedule.kahan %}
+            kahanAdd<{{segment.L1.dim}}>(L1_kahan_smem, L1_grad_smem, lane_id);
+            {%- endif %}
+
             IRREP_T* l1_grad_shft = L1_grad + col * {{schedule.L1.dim}} + lane_id;
             IRREP_T* l2_grad_shft = L2_grad + tperm_idx * {{schedule.L2.dim}} + lane_id;
 
diff --git a/tests/benchmark.py b/tests/benchmark.py
index f04b216e..7cabe628 100644
--- a/tests/benchmark.py
+++ b/tests/benchmark.py
@@ -1,7 +1,7 @@
 import numpy as np
 import numpy.linalg as la
 
-import itertools, logging, argparse, os, copy
+import itertools, logging, argparse, os, copy, gc
 from pathlib import Path
 import urllib.request
 
@@ -121,11 +121,7 @@ def benchmark_roofline(params):
     if params.plot:
         plot({"data_folder": data_folder})
 
-
-def benchmark_convolution(params):
-    filenames = [ "covid_spike_radius3.0.pickle",
-                  "1drf_radius6.0.pickle",
-                  "carbon_lattice_radius6.0.pickle"]
+def download_graphs(params, filenames):
     download_prefix = "https://portal.nersc.gov/project/m1982/equivariant_nn_graphs/"
 
     if not Path(params.data).exists():
@@ -140,10 +136,19 @@ def benchmark_convolution(params):
                 exit(1)
             else:
                 logging.info(f"Downloading {download_prefix + filename}...")
-                urllib.request.urlretrieve(download_prefix + filename, target_path)
-
+                urllib.request.urlretrieve(download_prefix + filename, target_path)
+        graphs.append(load_graph(str(target_path)))
+    return graphs
+
+def benchmark_convolution(params):
+    filenames = [ "covid_spike_radius3.0.pickle",
+                  "1drf_radius6.0.pickle",
+                  "carbon_lattice_radius6.0.pickle"]
+
+    graphs = download_graphs(params, filenames)
+
     if not params.disable_bench:
         configs = [
             ChannelwiseTPP("128x0e+128x1o+128x2e", "1x0e+1x1o+1x2e+1x3o",
@@ -156,7 +161,7 @@ def benchmark_convolution(params):
         configs[1].irrep_dtype = np.float64
         configs[1].weight_dtype = np.float64
 
-        bench = ConvBenchmarkSuite(configs, torch_op=True, test_name="convolution")
+        bench = ConvBenchmarkSuite(configs, test_name="convolution")
 
         implementations = [ TensorProductConvScatterSum,
                             CUEConv,
@@ -177,7 +182,6 @@ def benchmark_convolution(params):
                     graph = graph,
                     direction=direction,
                     correctness=False,
-                    double_backward_correctness=False,
                     benchmark=True,
                     output_folder=params.output_folder)
 
@@ -187,15 +191,10 @@ def benchmark_convolution(params):
     else:
         logger.critical("Cannot plot convolution speedups over cuE with --limited-memory flag enabled.")
 
-def run_paper_hderiv_benchmark(params):
+def benchmark_double_backward(params):
     from openequivariance.benchmark.benchmark_configs import mace_nequip_problems, diffdock_configs
-    implementations = [
-        E3NNTensorProduct,
-        CUETensorProduct,
-        TensorProduct,
-    ]
-
+    implementations = [E3NNTensorProduct, CUETensorProduct, TensorProduct]
     problems = diffdock_configs + mace_nequip_problems
 
     float64_problems = copy.deepcopy(problems)
@@ -216,6 +215,30 @@
     if params.plot:
         plot({"data_folder": data_folder})
 
+def benchmark_kahan_accuracy(params):
+    from openequivariance.benchmark.benchmark_configs import mace_problems
+
+    filenames = ["carbon_lattice_radius6.0.pickle"]
+    graphs = download_graphs(params, filenames)
+    implementations = [TensorProductConvAtomic, TensorProductConvKahan]
+    problems = [mace_problems[0]]
+
+    bench = ConvBenchmarkSuite(problems, test_name="kahan_convolution_accuracy", correctness_threshold=1e-4)
+    directions = ['forward', 'backward']
+    if params.double_backward:
+        directions.append('double_backward')
+
+    for graph in graphs:
+        for direction in directions:
+            output_folder = bench.run(
+                implementations = implementations,
+                graph = graph,
+                direction=direction,
+                correctness=True,
+                benchmark=False,
+                output_folder=params.output_folder,
+                high_precision_ref=True)
+
 def plot(params):
     import openequivariance.benchmark.plotting as plotting
 
@@ -288,9 +311,15 @@ def plot(params):
     parser_uvw.add_argument("--plot", action="store_true", help="Plot the results.")
     parser_uvw.set_defaults(func=run_paper_uvw_benchmark)
 
-    parser_higher_deriv = subparsers.add_parser('double_backward', help='Run the higher derivative kernel benchmark')
-    parser_higher_deriv.add_argument("--batch_size", "-b", type=int, default=50000, help="Batch size for benchmark")
-    parser_higher_deriv.set_defaults(func=run_paper_hderiv_benchmark)
+    parser_double_bwd = subparsers.add_parser('double_backward', help='Run the higher derivative kernel benchmark')
+    parser_double_bwd.add_argument("--batch_size", "-b", type=int, default=50000, help="Batch size for benchmark")
+    parser_double_bwd.set_defaults(func=benchmark_double_backward)
+
+    parser_kahan = subparsers.add_parser('kahan_conv', help='Run the Kahan convolution accuracy benchmark')
+    parser_kahan.add_argument("--data", type=str, help="Folder to download graph data to (or already containing graphs)", required=True)
+    parser_kahan.add_argument("--disable_download", action='store_true', help="Disable downloading data files if they do not exist")
+    parser_kahan.add_argument("--double_backward", action='store_true', help="Run double backward test (high memory usage)")
+    parser_kahan.set_defaults(func=benchmark_kahan_accuracy)
 
     parser_plot = subparsers.add_parser('plot', help="Generate a plot for a folder of benchmarks.")
     parser_plot.add_argument("data_folder", type=str)
diff --git a/tests/conv_test.py b/tests/conv_test.py
index ca3cd2ab..a7e592f6 100644
--- a/tests/conv_test.py
+++ b/tests/conv_test.py
@@ -30,14 +30,23 @@ def graph(self, request):
         #graph = load_graph("data/1drf_radius3.5.pickle")
         return graph
 
-    @pytest.fixture(params=['atomic', 'deterministic'], scope='class')
+    @pytest.fixture(params=['atomic', 'deterministic', 'kahan'], scope='class')
     def conv_object(self, request, problem):
         if request.param == 'atomic':
             return oeq.TensorProductConv(problem, deterministic=False)
         elif request.param == 'deterministic':
             return oeq.TensorProductConv(problem, deterministic=True)
+        elif request.param == 'kahan':
+            if problem.irrep_dtype == np.float32:
+                return oeq.TensorProductConv(problem, deterministic=True, kahan=True)
+            else:
+                return None
 
     def test_tp_fwd(self, conv_object, graph):
+        if conv_object is None:
+            # Kahan variant only exists for float32 problems; nothing to test.
+            return
+
         result = conv_object.test_correctness_forward(graph,
             thresh=3e-05,
             prng_seed=12345,
@@ -46,6 +55,10 @@ def test_tp_fwd(self, conv_object, graph):
         self.check_result(result, "output")
 
     def test_tp_bwd(self, conv_object, graph):
+        if conv_object is None:
+            # Kahan variant only exists for float32 problems; nothing to test.
+            return
+
         result = conv_object.test_correctness_backward(graph,
             thresh=3e-04,
             prng_seed=12345,
@@ -56,6 +69,10 @@ def test_tp_bwd(self, conv_object, graph):
         self.check_result(result, "in2_grad")
 
     def test_tp_double_bwd(self, conv_object, graph):
+        if conv_object is None:
+            # Kahan variant only exists for float32 problems; nothing to test.
+            return
+
         result = conv_object.test_correctness_double_backward(graph,
             thresh=3e-04,
             prng_seed=12345,
@@ -78,6 +95,30 @@ def problem(self, request, dtype):
         request.param.irrep_dtype, request.param.weight_dtype = dtype, dtype
         return request.param
 
+
+class TestUVUSingleIrrep(ConvCorrectness):
+    muls = [
+        (1, 1, 1), (8, 1, 8), (16, 1, 16),
+        (32, 1, 32), (5, 1, 5), (13, 1, 13), (19, 1, 19),
+        (33, 1, 33), (49, 1, 49), (128, 1, 128), (1, 2, 1), (1, 16, 1), (1, 32, 1), (16, 3, 16)
+    ]
+
+    irs = [ (0, 0, 0), (1, 1, 1), (1, 0, 1), (1, 2, 1), (2, 0, 2), (5, 3, 5), (7, 2, 5) ]
+
+    def id_func(m, i):
+        return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e"
+
+    @pytest.fixture(params=product(muls, irs),
+                    ids = lambda x: TestUVUSingleIrrep.id_func(x[0], x[1]),
+                    scope="class")
+    def problem(self, request, dtype):
+        m, i = request.param[0], request.param[1]
+        instructions=[(0, 0, 0, "uvu", True)]
+        return oeq.TPProblem(f"{m[0]}x{i[0]}e", f"{m[1]}x{i[1]}e", f"{m[2]}x{i[2]}e",
+                             instructions, shared_weights=False,
+                             internal_weights=False,
+                             irrep_dtype=dtype, weight_dtype=dtype)
+
 
 class TestUVWSingleIrrep(ConvCorrectness):
     muls = [
@@ -99,4 +140,4 @@ def problem(self, request, dtype):
         return oeq.TPProblem(f"{m[0]}x{i[0]}e", f"{m[1]}x{i[1]}e", f"{m[2]}x{i[2]}e",
                              instructions, shared_weights=False,
                              internal_weights=False,
-                             irrep_dtype=dtype, weight_dtype=dtype)
\ No newline at end of file
+                             irrep_dtype=dtype, weight_dtype=dtype)
\ No newline at end of file
diff --git a/tests/double_backwards_driver.py b/tests/double_backwards_driver.py
deleted file mode 100644
index cabc35f4..00000000
--- a/tests/double_backwards_driver.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from itertools import product
-import logging
-
-import e3nn
-from e3nn import o3
-
-from openequivariance.benchmark.TestBenchmarkSuite import TestBenchmarkSuite, TestDefinition, Direction
-from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct
-from openequivariance.implementations.CUETensorProduct import CUETensorProduct
-from openequivariance.implementations.TensorProduct import TensorProduct
-
-from openequivariance.implementations.e3nn_lite import TPProblem
-from openequivariance.benchmark.tpp_creation_utils import FullyConnectedTPProblem, ChannelwiseTPP
-from openequivariance.benchmark.logging_utils import getLogger
-from openequivariance.benchmark.benchmark_configs import mace_nequip_problems, diffdock_configs
-
-
-implementations = [
-    E3NNTensorProduct,
-    CUETensorProduct,
-    TensorProduct,
-]
-
-problems = diffdock_configs # mace_nequip_problems
-
-directions : list[Direction] = [
-    'double_backward',
-]
-
-tests = [TestDefinition(implementation, problem, direction, correctness=True, benchmark=True) for problem, direction, implementation, in product(problems, directions, implementations)]
-
-if __name__ == "__main__":
-
-    logger = getLogger()
-
-    logger.setLevel(logging.INFO)
-    test_suite = TestBenchmarkSuite(
-        bench_batch_size=50000
-    )
-    test_suite.run(tests)
\ No newline at end of file
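
The `kahanAdd` device function in this patch implements compensated (Kahan) summation: the running sum lives in one shared-memory buffer (`L3_smem` / `L1_grad_smem`) while a second buffer carries the round-off lost by each addition, to be folded back in on the next segment. For intuition, a minimal NumPy sketch of the same update, with illustrative names that are not part of the patch:

```python
import numpy as np

def kahan_accumulate(partials):
    """Accumulate float32 rows the way the kernel pairs a running-sum buffer
    with a compensation buffer (e.g., L3_smem and L3_kahan_smem)."""
    total = np.zeros_like(partials[0])  # running sum
    comp = np.zeros_like(partials[0])   # round-off carried between additions
    for p in partials:
        y = p + comp             # re-inject the previous round-off
        t = total + y
        comp = y - (t - total)   # error lost when forming t = total + y
        total = t
    return total

# Compared against a float64 reference, the compensated float32 sum tracks
# the true value far better than naive float32 accumulation:
rng = np.random.default_rng(0)
parts = [rng.standard_normal(8).astype(np.float32) for _ in range(100_000)]
ref = np.stack(parts).astype(np.float64).sum(axis=0)
print("naive err:", np.abs(np.stack(parts).sum(axis=0, dtype=np.float32) - ref).max())
print("kahan err:", np.abs(kahan_accumulate(parts) - ref).max())
```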
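Likewise, the `high_precision_ref` flag added to the correctness tests deep-copies the problem configuration, upcasts it and the reference buffers to float64, and compares the float32 result against that higher-precision output. A condensed sketch of the pattern, where `run_impl` and `run_ref` are hypothetical callables standing in for the implementation under test and the reference:

```python
import copy
import numpy as np

def check_against_fp64_reference(config, buffers, run_impl, run_ref, thresh=1e-4):
    """Mirror of the high_precision_ref flow: clone the problem, upcast the
    reference inputs to float64, and bound the relative error."""
    ref_config = copy.deepcopy(config)
    ref_config.irrep_dtype = np.float64
    ref_config.weight_dtype = np.float64
    ref_buffers = [np.asarray(b, dtype=np.float64) for b in buffers]

    test_out = run_impl(config, *buffers)         # float32 result under test
    ref_out = run_ref(ref_config, *ref_buffers)   # float64 reference result

    rel_err = np.linalg.norm(test_out.astype(np.float64) - ref_out) / np.linalg.norm(ref_out)
    return rel_err <= thresh
```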