From 0f095fac6bfeabb93400b7cdd59e2acea60b174e Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 9 May 2025 16:17:03 -0700 Subject: [PATCH 01/13] First pass at Kahan summation. --- .../implementations/ComputationSchedule.py | 19 +++++++++- .../convolution/LoopUnrollConv.py | 14 +++++-- .../convolution/TensorProductConv.py | 11 ++++++ openequivariance/implementations/utils.py | 5 ++- .../templates/loop_unroll_conv_det.cuh | 37 ++++++++++++++++++- tests/conv_test.py | 30 ++++++++++++++- 6 files changed, 108 insertions(+), 8 deletions(-) diff --git a/openequivariance/implementations/ComputationSchedule.py b/openequivariance/implementations/ComputationSchedule.py index ebb513dd..bb75c219 100644 --- a/openequivariance/implementations/ComputationSchedule.py +++ b/openequivariance/implementations/ComputationSchedule.py @@ -245,10 +245,15 @@ def __init__(self, weight_dtype, include_scratch=False, stream_weights=False, - schedule_type=2): + schedule_type=2, + kahan=False): ''' smem_limit: size of available shared memory in bytes ''' + self.kahan = kahan + if kahan: + assert irrep_dtype == weight_dtype == np.float32 + # Note: does not work with variances for irreps; easy to add that in self.total_warps = warps_per_block * block_count @@ -288,10 +293,16 @@ def calculate_forward_smem(L1_set, L2_set, L3_set, inst_idxs): "L1": {"size": sum([self.L1[el].dim for el in L1_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr}, "L2": {"size": sum([self.L2[el].dim for el in L2_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr}, "L3": {"size": sum([self.L3[el].dim for el in L3_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr}, + "L3_kahan": {"size": 0, "dtype": self.irrep_dtype_cstr}, "weights": {"size": 0, "dtype": self.weight_dtype_cstr}, "scratch": {"size": 0, "dtype": self.weight_dtype_cstr} } + if kahan: + smem["L3_kahan"]["size"] = smem["L3"]["size"] + else: + smem.pop("L3_kahan") + weights_smem = 0 for inst_idx in inst_idxs: inst = self.new_instructions[inst_idx] @@ -325,6 +336,7 @@ def calculate_backward_smem(L1_set, L2_set, L3_set, inst_idxs, smem = { "L1": {"size": sum([self.L1[el].dim for el in L1_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr}, "L1_grad": {"size": sum([self.L1[el].dim for el in L1_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr}, + "L1_kahan": {"size": 0, "dtype": self.irrep_dtype_cstr}, "L2": {"size": sum([self.L2[el].dim for el in L2_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr}, "L2_grad": {"size": sum([self.L2[el].dim for el in L2_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr}, "L3_grad": {"size": sum([self.L3[el].dim for el in L3_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr}, @@ -333,6 +345,11 @@ def calculate_backward_smem(L1_set, L2_set, L3_set, inst_idxs, "scratch": {"size": 0, "dtype": self.weight_dtype_cstr} } + if kahan: + smem["L1_kahan"]["size"] = smem["L1"]["size"] + else: + smem.pop("L1_kahan") + if L2_dgrad: smem["L2_dgrad"] = {"size": smem["L2"]["size"], "dtype": self.irrep_dtype_cstr} diff --git a/openequivariance/implementations/convolution/LoopUnrollConv.py b/openequivariance/implementations/convolution/LoopUnrollConv.py index 5d6b286c..f5b1cd75 100644 --- a/openequivariance/implementations/convolution/LoopUnrollConv.py +++ b/openequivariance/implementations/convolution/LoopUnrollConv.py @@ -8,10 +8,13 @@ class LoopUnrollConv(ConvolutionBase): def __init__(self, config, idx_dtype=np.int64, - torch_op=False, deterministic=False): + torch_op=False, deterministic=False, kahan=False): super().__init__(config, idx_dtype, torch_op, deterministic) L1, L2, L3 = self.L1, self.L2, self.L3 + if kahan: + assert deterministic + env = get_jinja_environment() template = env.get_template("loop_unroll_conv_atomic.cuh") dp = DeviceProp(0) @@ -36,7 +39,8 @@ def generate_forward_schedule(warps_per_block): schedule_type=forward_schedule_type, warp_size=dp.warpsize, include_scratch=self.is_uvw, - stream_weights=self.is_uvw) + stream_weights=self.is_uvw, + kahan=kahan) def generate_backward_schedule(warps_per_block): self.backward_schedule = ComputationSchedule(self.config, @@ -48,7 +52,8 @@ def generate_backward_schedule(warps_per_block): schedule_type=backward_schedule_type, warp_size=dp.warpsize, include_scratch=self.is_uvw, - stream_weights=self.is_uvw) + stream_weights=self.is_uvw, + kahan=kahan) def generate_double_backward_schedule(warps_per_block): self.double_backward_schedule = ComputationSchedule(self.config, @@ -61,7 +66,8 @@ def generate_double_backward_schedule(warps_per_block): weight_dtype = config.weight_dtype, include_scratch=self.is_uvw, stream_weights=self.is_uvw, - schedule_type=3) + schedule_type=3, + kahan=kahan) scheduler_generators = [generate_forward_schedule, generate_backward_schedule, generate_double_backward_schedule] diff --git a/openequivariance/implementations/convolution/TensorProductConv.py b/openequivariance/implementations/convolution/TensorProductConv.py index af1ddd32..1d3cc99b 100644 --- a/openequivariance/implementations/convolution/TensorProductConv.py +++ b/openequivariance/implementations/convolution/TensorProductConv.py @@ -38,6 +38,17 @@ def forward_atomic(self, L1_in: torch.Tensor, L2_in: # ================================================================== # Reference implementations for benchmarking +class TensorProductConvKahan(TensorProductConv): + def __init__(self, config, + idx_dtype=np.int64, + torch_op=True): + super().__init__(config, idx_dtype, torch_op, deterministic=True, kahan=True) + + @staticmethod + def name(): + return "LoopUnrollConvKahan" + + class TensorProductConvDeterministic(TensorProductConv): def __init__(self, config, idx_dtype=np.int64, diff --git a/openequivariance/implementations/utils.py b/openequivariance/implementations/utils.py index f37b8384..aafb9715 100644 --- a/openequivariance/implementations/utils.py +++ b/openequivariance/implementations/utils.py @@ -54,10 +54,13 @@ def filter_and_analyze_problem(problem): Centralized function that stops unhandled problem configurations, returns a dictionary of useful information about the problem. ''' - for inst in problem.instructions: + for i, inst in enumerate(problem.instructions): assert inst.connection_mode == problem.instructions[0].connection_mode, \ f"All instructions must have the same connection mode, got {inst.connection_mode} and {problem.instructions[0].connection_mode}" + assert inst.has_weight, \ + f"All instructions must have trainable weights, got {inst.has_weight} at index {i}" + assert problem.instructions[0].connection_mode in ["uvu", "uvw"], \ f"Connection mode must be 'uvu' or 'uvw', got {problem.instructions[0].connection_mode}" diff --git a/openequivariance/templates/loop_unroll_conv_det.cuh b/openequivariance/templates/loop_unroll_conv_det.cuh index c7cee536..a0456d21 100644 --- a/openequivariance/templates/loop_unroll_conv_det.cuh +++ b/openequivariance/templates/loop_unroll_conv_det.cuh @@ -63,6 +63,31 @@ __global__ void {{name}}(void* workspace, IRREP_T* dst_ptr) { {{ generate_fixup_kernel("fixup_forward", forward_schedule.launch_config.warp_size, forward_schedule.L3.dim, forward_workspace_offset) }} +template +__device__ __forceinline__ void kahanAdd(IRREP_T* c, IRREP_T* sum) { + c += lane_id; + sum += lane_id; + #pragma unroll + for(int j = 0; j < ROW_LEN; j += THREADS_PER_WARP) { + if(j >= ROW_LEN - THREADS_PER_WARP) { + if(lane_id < ROW_LEN - j) { + IRREP_T y = c[j]; + IRREP_T sum = sum[j]; + IRREP_T t = sum + y; + c[j] = y - (t - sum); + sum[j] = t; + } + } + else { + IRREP_T y = c[j]; + IRREP_T sum = sum[j]; + IRREP_T t = sum + y; + c[j] = y - (t - sum); + sum[j] = t; + } + } +} + __global__ void forward( IRREP_T* L1_in, IRREP_T* L2_in, @@ -99,6 +124,12 @@ __global__ void forward( bool firstSegment = true; ROW_OPERATION({{segment.L3.dim}}, j, L3_smem[j + lane_id] = 0.0f;) + {%- set ns = namespace(L3_accum="L3_smem") %} + {%- if forward_schedule.kahan %} + ROW_OPERATION({{segment.L3.dim}}, j, L3_kahan[j + lane_id] = 0.0f;) + {%- set ns.L3_accum="L3_kahan" %} + {%- endif %} + for(size_t i = start; i < end; i++) { {{idx_type}} row = rows[i]; {{idx_type}} col = cols[i]; @@ -115,9 +146,13 @@ __global__ void forward( {%- endif %} __syncwarp(); - forward_loop_unroll_{{i}}(L1_smem, L2_smem, w, weights_smem, L3_smem, scratch_smem, lane_id); + forward_loop_unroll_{{i}}(L1_smem, L2_smem, w, weights_smem, {{ns.L3_accum}}, scratch_smem, lane_id); __syncwarp(); + {%- if forward_schedule.kahan %} + kahanAdd<{{segment.L3.dim}}>(L3_kahan, L3_smem); + {%- endif %} + bool changeRow = (i < end - 1) && (row != rows[i+1]); if(changeRow || i == end - 1) { diff --git a/tests/conv_test.py b/tests/conv_test.py index ca3cd2ab..298f9354 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -99,4 +99,32 @@ def problem(self, request, dtype): return oeq.TPProblem(f"{m[0]}x{i[0]}e", f"{m[1]}x{i[1]}e", f"{m[2]}x{i[2]}e", instructions, shared_weights=False, internal_weights=False, - irrep_dtype=dtype, weight_dtype=dtype) \ No newline at end of file + irrep_dtype=dtype, weight_dtype=dtype) + +class TestUVUSingleIrrep(ConvCorrectness): + muls = [ (32, 1, 32) ] + irs = [(5, 3, 5)] + + def id_func(m, i): + return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" + + @pytest.fixture(params=product(muls, irs), + ids = lambda x: TestUVUSingleIrrep.id_func(x[0], x[1]), + scope="class") + def problem(self, request, dtype): + m, i = request.param[0], request.param[1] + instructions=[(0, 0, 0, "uvu", True)] + return oeq.TPProblem(f"{m[0]}x{i[0]}e", f"{m[1]}x{i[1]}e", f"{m[2]}x{i[2]}e", + instructions, shared_weights=False, + internal_weights=False, + irrep_dtype=dtype, weight_dtype=dtype) + + + @pytest.fixture(params=['kahan'], scope='class') + def conv_object(self, request, problem): + if request.param == 'atomic': + return oeq.TensorProductConv(problem, deterministic=False) + elif request.param == 'deterministic': + return oeq.TensorProductConv(problem, deterministic=True) + elif request.param == 'kahan': + return oeq.TensorProductConv(problem, deterministic=True, kahan=True) \ No newline at end of file From 812a7c91b2690439961762cef8cb5e0a8110370e Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 9 May 2025 16:38:01 -0700 Subject: [PATCH 02/13] Added SMEM capacity exception. --- .../implementations/ComputationSchedule.py | 13 +++++++++---- .../implementations/convolution/LoopUnrollConv.py | 6 +++--- .../convolution/TensorProductConv.py | 4 ++-- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/openequivariance/implementations/ComputationSchedule.py b/openequivariance/implementations/ComputationSchedule.py index bb75c219..54ae5d9c 100644 --- a/openequivariance/implementations/ComputationSchedule.py +++ b/openequivariance/implementations/ComputationSchedule.py @@ -5,6 +5,11 @@ from openequivariance.implementations.TensorProductBase import * logger = getLogger() +class SMEMCapacityException(Exception): + def __init__(self, message): + self.message = message + super().__init__(self.message) + class IrrepMapping: ''' Maps irreps from a source to a destination set. @@ -104,7 +109,7 @@ def create_schedule_case2(instructions, memory_per_warp, calculate_smem, directi segments.append((cL1, cL2, cL3, cinst)) cL3, cinst = set(), [] else: - raise Exception(f"{direction.title()} scheduling failed, memory allocation too small to accomodate segment!") + raise SMEMCapacityException(f"{direction.title()} scheduling failed, memory allocation too small to accomodate segment!") else: cL3.add(w) cinst.append(inst_idx) @@ -130,7 +135,7 @@ def create_schedule_case3(instructions, memory_per_warp, calculate_smem, directi segments.append((cL1, cL2, cL3, cinst)) cL1, cL2, cL3, cinst = set(), set(), set(), [] else: - raise Exception(f"{direction.title()} scheduling failed, memory allocation too small to accomodate segment!") + raise SMEMCapacityException(f"{direction.title()} scheduling failed, memory allocation too small to accomodate segment!") else: cL1.add(u) cL2.add(v) @@ -393,11 +398,11 @@ def calculate_backward_smem(L1_set, L2_set, L3_set, inst_idxs, schedule2_succeeded = False try: if schedule_type != 2: - raise Exception("Asked for schedule case 3.") + raise SMEMCapacityException("Asked for schedule case 3.") self.segments = create_schedule_case2(self.new_instructions, self.memory_per_warp, calculate_smem, direction) logger.info(f"{direction.title()} case 2 scheduling succeeded with {len(self.segments)} segments.") schedule2_succeeded = True - except Exception as e: + except SMEMCapacityException as e: self.segments = create_schedule_case3(self.new_instructions, self.memory_per_warp, calculate_smem, direction) logger.info(f"{direction.title()} case 3 scheduling succeeded with {len(self.segments)} segments.") diff --git a/openequivariance/implementations/convolution/LoopUnrollConv.py b/openequivariance/implementations/convolution/LoopUnrollConv.py index f5b1cd75..533b4fb4 100644 --- a/openequivariance/implementations/convolution/LoopUnrollConv.py +++ b/openequivariance/implementations/convolution/LoopUnrollConv.py @@ -1,5 +1,5 @@ from openequivariance.implementations.convolution.ConvolutionBase import * -from openequivariance.implementations.ComputationSchedule import ComputationSchedule +from openequivariance.implementations.ComputationSchedule import ComputationSchedule, SMEMCapacityException from openequivariance.implementations.TensorProduct import * from openequivariance.templates.jinja_utils import * from openequivariance.extlib import * @@ -77,10 +77,10 @@ def generate_double_backward_schedule(warps_per_block): try: generate_schedule(warp_count) break - except Exception as e: + except SMEMCapacityException as e: warp_count -= 1 if warp_count == 0: - raise RuntimeError("Tensor product schedule generation failed, shared memory inadequate!") + raise SMEMCapacityException("Tensor product schedule generation failed, shared memory inadequate!") if not deterministic: for segment in self.forward_schedule.segments: diff --git a/openequivariance/implementations/convolution/TensorProductConv.py b/openequivariance/implementations/convolution/TensorProductConv.py index 1d3cc99b..0ab777b7 100644 --- a/openequivariance/implementations/convolution/TensorProductConv.py +++ b/openequivariance/implementations/convolution/TensorProductConv.py @@ -9,10 +9,10 @@ class TensorProductConv(torch.nn.Module, LoopUnrollConv): ''' PyTorch-specialized dispatcher class. ''' - def __init__(self, config, idx_dtype=np.int64, torch_op=True, deterministic=False): + def __init__(self, config, idx_dtype=np.int64, torch_op=True, deterministic=False, kahan=False): torch.nn.Module.__init__(self) LoopUnrollConv.__init__(self, config, idx_dtype=np.int64, - torch_op=torch_op, deterministic=deterministic) + torch_op=torch_op, deterministic=deterministic, kahan=kahan) self.dummy_transpose_perm = torch.zeros(1, dtype=torch.int64, device='cuda') self.weight_numel = self.config.weight_numel From c7cb121d70992918f65aa0398909dccc63d975a1 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 9 May 2025 16:51:20 -0700 Subject: [PATCH 03/13] Kahan summation working for the forward pass. --- .../templates/loop_unroll_conv_det.cuh | 28 +++++++++---------- tests/conv_test.py | 4 +++ 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/openequivariance/templates/loop_unroll_conv_det.cuh b/openequivariance/templates/loop_unroll_conv_det.cuh index a0456d21..389869ae 100644 --- a/openequivariance/templates/loop_unroll_conv_det.cuh +++ b/openequivariance/templates/loop_unroll_conv_det.cuh @@ -64,26 +64,26 @@ __global__ void {{name}}(void* workspace, IRREP_T* dst_ptr) { {{ generate_fixup_kernel("fixup_forward", forward_schedule.launch_config.warp_size, forward_schedule.L3.dim, forward_workspace_offset) }} template -__device__ __forceinline__ void kahanAdd(IRREP_T* c, IRREP_T* sum) { - c += lane_id; - sum += lane_id; +__device__ __forceinline__ void kahanAdd(IRREP_T* c_arr, IRREP_T* sum_arr, int lane_id) { + c_arr += lane_id; + sum_arr += lane_id; #pragma unroll for(int j = 0; j < ROW_LEN; j += THREADS_PER_WARP) { if(j >= ROW_LEN - THREADS_PER_WARP) { if(lane_id < ROW_LEN - j) { - IRREP_T y = c[j]; - IRREP_T sum = sum[j]; + IRREP_T y = c_arr[j]; + IRREP_T sum = sum_arr[j]; IRREP_T t = sum + y; - c[j] = y - (t - sum); - sum[j] = t; + c_arr[j] = y - (t - sum); + sum_arr[j] = t; } } else { - IRREP_T y = c[j]; - IRREP_T sum = sum[j]; + IRREP_T y = c_arr[j]; + IRREP_T sum = sum_arr[j]; IRREP_T t = sum + y; - c[j] = y - (t - sum); - sum[j] = t; + c_arr[j] = y - (t - sum); + sum_arr[j] = t; } } } @@ -126,8 +126,8 @@ __global__ void forward( {%- set ns = namespace(L3_accum="L3_smem") %} {%- if forward_schedule.kahan %} - ROW_OPERATION({{segment.L3.dim}}, j, L3_kahan[j + lane_id] = 0.0f;) - {%- set ns.L3_accum="L3_kahan" %} + ROW_OPERATION({{segment.L3.dim}}, j, L3_kahan_smem[j + lane_id] = 0.0f;) + {%- set ns.L3_accum="L3_kahan_smem" %} {%- endif %} for(size_t i = start; i < end; i++) { @@ -150,7 +150,7 @@ __global__ void forward( __syncwarp(); {%- if forward_schedule.kahan %} - kahanAdd<{{segment.L3.dim}}>(L3_kahan, L3_smem); + kahanAdd<{{segment.L3.dim}}>(L3_kahan_smem, L3_smem, lane_id); {%- endif %} bool changeRow = (i < end - 1) && (row != rows[i+1]); diff --git a/tests/conv_test.py b/tests/conv_test.py index 298f9354..e0b857c8 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -105,6 +105,10 @@ class TestUVUSingleIrrep(ConvCorrectness): muls = [ (32, 1, 32) ] irs = [(5, 3, 5)] + @pytest.fixture(params=[np.float32], ids=['F32'], scope='class') + def dtype(self, request): + return request.param + def id_func(m, i): return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" From e8fa5f40cbdd8450a966da1e9e0a34d427d5128e Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 9 May 2025 17:06:05 -0700 Subject: [PATCH 04/13] Kahan summation for the backward pass is working. --- openequivariance/templates/loop_unroll_conv_det.cuh | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/openequivariance/templates/loop_unroll_conv_det.cuh b/openequivariance/templates/loop_unroll_conv_det.cuh index 389869ae..b0acd364 100644 --- a/openequivariance/templates/loop_unroll_conv_det.cuh +++ b/openequivariance/templates/loop_unroll_conv_det.cuh @@ -214,6 +214,12 @@ __global__ void backward( bool firstSegment = true; ROW_OPERATION({{segment.L1.dim}}, j, L1_grad_smem[j + lane_id] = 0.0f;) + {%- set ns = namespace(L1_accum="L1_grad_smem") %} + {%- if backward_schedule.kahan %} + ROW_OPERATION({{segment.L1.dim}}, j, L1_kahan_smem[j + lane_id] = 0.0f;) + {%- set ns.L1_accum="L1_kahan_smem" %} + {%- endif %} + for(size_t i = start; i < end; i++) { {{idx_type}} row = rows[i]; {{idx_type}} col = cols[i]; {{idx_type}} tperm_idx = tperm[i]; @@ -249,9 +255,13 @@ __global__ void backward( __syncwarp(); backward_loop_unroll_{{i}}(L1_smem, L2_smem, w, weights_smem, L3_grad_smem, - L1_grad_smem, L2_grad_smem, wgrad, weights_grad_smem, scratch_smem, lane_id); + {{ns.L1_accum}}, L2_grad_smem, wgrad, weights_grad_smem, scratch_smem, lane_id); __syncwarp(); + {%- if backward_schedule.kahan %} + kahanAdd<{{segment.L1.dim}}>(L1_kahan_smem, L1_grad_smem, lane_id); + {%- endif %} + bool changeRow = (i < end - 1) && (col != cols[i+1]); if(changeRow || i == end - 1) { IRREP_T* dst = l1_grad_shft; From f3644acdb8af23385f43415744d30306a37ad1f0 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 9 May 2025 19:06:40 -0700 Subject: [PATCH 05/13] Forward A kernel is working with Kahan summation. --- openequivariance/templates/loop_unroll_conv_det.cuh | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/openequivariance/templates/loop_unroll_conv_det.cuh b/openequivariance/templates/loop_unroll_conv_det.cuh index b0acd364..006e1d82 100644 --- a/openequivariance/templates/loop_unroll_conv_det.cuh +++ b/openequivariance/templates/loop_unroll_conv_det.cuh @@ -319,6 +319,12 @@ __global__ void double_backward_A( bool firstSegment = true; ROW_OPERATION({{segment.L3.dim}}, j, L3_smem[j + lane_id] = 0.0f;) + {%- set ns = namespace(L3_accum="L3_smem") %} + {%- if forward_schedule.kahan %} + ROW_OPERATION({{segment.L3.dim}}, j, L3_kahan_smem[j + lane_id] = 0.0f;) + {%- set ns.L3_accum="L3_kahan_smem" %} + {%- endif %} + for(size_t i = start; i < end; i++) { unsigned {{idx_type}} row = rows[i]; unsigned {{idx_type}} col = cols[i]; @@ -361,10 +367,14 @@ __global__ void double_backward_A( } __syncwarp(); - forward_loop_unroll_{{i}}(L1_smem, L2_smem, w_buffer, weights_smem, L3_smem, scratch_smem, lane_id); + forward_loop_unroll_{{i}}(L1_smem, L2_smem, w_buffer, weights_smem, {{ns.L3_accum}}, scratch_smem, lane_id); __syncwarp(); } + {%- if forward_schedule.kahan %} + kahanAdd<{{segment.L3.dim}}>(L3_kahan_smem, L3_smem, lane_id); + {%- endif %} + bool changeRow = (i < end - 1) && (row != rows[i+1]); if(changeRow || i == end - 1) { IRREP_T* dst = l3; From c62fae6afc557544b1da1ca63b65cca8d6147912 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 9 May 2025 19:14:51 -0700 Subject: [PATCH 06/13] Kahan summation passes a simple correctness check. --- openequivariance/templates/loop_unroll_conv_det.cuh | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/openequivariance/templates/loop_unroll_conv_det.cuh b/openequivariance/templates/loop_unroll_conv_det.cuh index 006e1d82..5392da17 100644 --- a/openequivariance/templates/loop_unroll_conv_det.cuh +++ b/openequivariance/templates/loop_unroll_conv_det.cuh @@ -434,6 +434,12 @@ __global__ void double_backward_B( bool firstSegment = true; ROW_OPERATION({{segment.L1.dim}}, j, L1_grad_smem[j + lane_id] = 0.0f;) + {%- set ns = namespace(L1_accum="L1_grad_smem") %} + {%- if backward_schedule.kahan %} + ROW_OPERATION({{segment.L1.dim}}, j, L1_kahan_smem[j + lane_id] = 0.0f;) + {%- set ns.L1_accum="L1_kahan_smem" %} + {%- endif %} + for(size_t i = start; i < end; i++) { unsigned {{idx_type}} row = rows[i]; unsigned {{idx_type}} col = cols[i]; {{idx_type}} tperm_idx = tperm[i]; @@ -492,10 +498,14 @@ __global__ void double_backward_B( __syncwarp(); double_backward_loop_unroll_{{i}}(L1_smem, L2_buffer, w_buffer, weights_smem, L3_grad_smem, - L1_grad_smem, L2_grad_smem, L2_dgrad_buffer, n, wgrad, weights_grad_smem, scratch_smem, lane_id); + {{ns.L1_accum}}, L2_grad_smem, L2_dgrad_buffer, n, wgrad, weights_grad_smem, scratch_smem, lane_id); __syncwarp(); } + {%- if backward_schedule.kahan %} + kahanAdd<{{segment.L1.dim}}>(L1_kahan_smem, L1_grad_smem, lane_id); + {%- endif %} + IRREP_T* l1_grad_shft = L1_grad + col * {{schedule.L1.dim}} + lane_id; IRREP_T* l2_grad_shft = L2_grad + tperm_idx * {{schedule.L2.dim}} + lane_id; From 759a15724805e099a3c5518c0af7447225b22e0a Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 9 May 2025 19:58:45 -0700 Subject: [PATCH 07/13] Added a small suite of UVU tests. --- tests/conv_test.py | 65 ++++++++++++++++++++++++++-------------------- 1 file changed, 37 insertions(+), 28 deletions(-) diff --git a/tests/conv_test.py b/tests/conv_test.py index e0b857c8..a7e592f6 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -30,14 +30,23 @@ def graph(self, request): #graph = load_graph("data/1drf_radius3.5.pickle") return graph - @pytest.fixture(params=['atomic', 'deterministic'], scope='class') + @pytest.fixture(params=['atomic', 'deterministic', 'kahan'], scope='class') def conv_object(self, request, problem): if request.param == 'atomic': return oeq.TensorProductConv(problem, deterministic=False) elif request.param == 'deterministic': return oeq.TensorProductConv(problem, deterministic=True) + elif request.param == 'kahan': + if problem.irrep_dtype == np.float32: + return oeq.TensorProductConv(problem, deterministic=True, kahan=True) + else: + return None def test_tp_fwd(self, conv_object, graph): + if conv_object is None: + assert True + return + result = conv_object.test_correctness_forward(graph, thresh=3e-05, prng_seed=12345, @@ -46,6 +55,10 @@ def test_tp_fwd(self, conv_object, graph): self.check_result(result, "output") def test_tp_bwd(self, conv_object, graph): + if conv_object is None: + assert True + return + result = conv_object.test_correctness_backward(graph, thresh=3e-04, prng_seed=12345, @@ -56,6 +69,10 @@ def test_tp_bwd(self, conv_object, graph): self.check_result(result, "in2_grad") def test_tp_double_bwd(self, conv_object, graph): + if conv_object is None: + assert True + return + result = conv_object.test_correctness_double_backward(graph, thresh=3e-04, prng_seed=12345, @@ -78,57 +95,49 @@ def problem(self, request, dtype): request.param.irrep_dtype, request.param.weight_dtype = dtype, dtype return request.param - -class TestUVWSingleIrrep(ConvCorrectness): + +class TestUVUSingleIrrep(ConvCorrectness): muls = [ - (1, 1, 1), (4, 1, 4), (8, 1, 8), (16, 1, 16), (32, 1, 32), (5, 1, 5), (13, 1, 13), (33, 1, 33), (49, 1, 49), (64, 1, 64), - (1, 2, 1), (1, 4, 1), (1, 16, 1), (1, 32, 1), (16, 3, 16) + (1, 1, 1), (8, 1, 8), (16, 1, 16), + (32, 1, 32), (5, 1, 5), (13, 1, 13), (19, 1, 19), + (33, 1, 33), (49, 1, 49), (128, 1, 128), (1, 2, 1), (1, 16, 1), (1, 32, 1), (16, 3, 16) ] - irs = [(0, 0, 0), (1, 1, 1), (1, 0, 1), (1, 2, 1), (5, 3, 5), (7, 2, 5)] + irs = [ (0, 0, 0), (1, 1, 1), (1, 0, 1), (1, 2, 1), (2, 0, 2), (5, 3, 5), (7, 2, 5) ] def id_func(m, i): return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" @pytest.fixture(params=product(muls, irs), - ids = lambda x: TestUVWSingleIrrep.id_func(x[0], x[1]), + ids = lambda x: TestUVUSingleIrrep.id_func(x[0], x[1]), scope="class") def problem(self, request, dtype): m, i = request.param[0], request.param[1] - instructions=[(0, 0, 0, "uvw", True)] + instructions=[(0, 0, 0, "uvu", True)] return oeq.TPProblem(f"{m[0]}x{i[0]}e", f"{m[1]}x{i[1]}e", f"{m[2]}x{i[2]}e", instructions, shared_weights=False, internal_weights=False, irrep_dtype=dtype, weight_dtype=dtype) - -class TestUVUSingleIrrep(ConvCorrectness): - muls = [ (32, 1, 32) ] - irs = [(5, 3, 5)] - @pytest.fixture(params=[np.float32], ids=['F32'], scope='class') - def dtype(self, request): - return request.param + +class TestUVWSingleIrrep(ConvCorrectness): + muls = [ + (1, 1, 1), (4, 1, 4), (8, 1, 8), (16, 1, 16), (32, 1, 32), (5, 1, 5), (13, 1, 13), (33, 1, 33), (49, 1, 49), (64, 1, 64), + (1, 2, 1), (1, 4, 1), (1, 16, 1), (1, 32, 1), (16, 3, 16) + ] + + irs = [(0, 0, 0), (1, 1, 1), (1, 0, 1), (1, 2, 1), (5, 3, 5), (7, 2, 5)] def id_func(m, i): return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" @pytest.fixture(params=product(muls, irs), - ids = lambda x: TestUVUSingleIrrep.id_func(x[0], x[1]), + ids = lambda x: TestUVWSingleIrrep.id_func(x[0], x[1]), scope="class") def problem(self, request, dtype): m, i = request.param[0], request.param[1] - instructions=[(0, 0, 0, "uvu", True)] + instructions=[(0, 0, 0, "uvw", True)] return oeq.TPProblem(f"{m[0]}x{i[0]}e", f"{m[1]}x{i[1]}e", f"{m[2]}x{i[2]}e", instructions, shared_weights=False, internal_weights=False, - irrep_dtype=dtype, weight_dtype=dtype) - - - @pytest.fixture(params=['kahan'], scope='class') - def conv_object(self, request, problem): - if request.param == 'atomic': - return oeq.TensorProductConv(problem, deterministic=False) - elif request.param == 'deterministic': - return oeq.TensorProductConv(problem, deterministic=True) - elif request.param == 'kahan': - return oeq.TensorProductConv(problem, deterministic=True, kahan=True) \ No newline at end of file + irrep_dtype=dtype, weight_dtype=dtype) \ No newline at end of file From 902cc1a687ffcb6841594802ff2914fa97fea9cf Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 9 May 2025 21:04:03 -0700 Subject: [PATCH 08/13] First step writing code to test Kahan summation accuracy. --- .../benchmark/ConvBenchmarkSuite.py | 6 +- .../convolution/ConvolutionBase.py | 61 ++++++++----------- tests/benchmark.py | 39 ++++++++---- tests/double_backwards_driver.py | 40 ------------ 4 files changed, 55 insertions(+), 91 deletions(-) delete mode 100644 tests/double_backwards_driver.py diff --git a/openequivariance/benchmark/ConvBenchmarkSuite.py b/openequivariance/benchmark/ConvBenchmarkSuite.py index cca53f6e..ffd13583 100644 --- a/openequivariance/benchmark/ConvBenchmarkSuite.py +++ b/openequivariance/benchmark/ConvBenchmarkSuite.py @@ -24,7 +24,6 @@ def __init__(self, configs, num_warmup = 10, num_iter = 30, reference_impl=None, - torch_op=True, test_name=None, prng_seed = 12345): self.configs = configs @@ -33,7 +32,6 @@ def __init__(self, configs, self.reference_impl = reference_impl self.prng_seed = 12345 self.correctness_threshold = 1e-5 - self.torch_op = torch_op self.exp_count = 0 self.test_name = test_name @@ -65,7 +63,7 @@ def run(self, graph, implementations, direction, output_folder=None, correctness for impl in implementations: tc_name = f"{config}, {impl.name()}" logger.info(f'Starting {tc_name}, graph {graph.name}, {direction}') - conv = impl(config, torch_op=self.torch_op) + conv = impl(config) if double_backward_correctness: double_backward_correctness = conv.test_correctness_double_backward(self.graph, @@ -100,7 +98,7 @@ def run(self, graph, implementations, direction, output_folder=None, correctness "config": str(config), "irrep_dtype": str(config.irrep_dtype), "weight_dtype": str(config.weight_dtype), - "torch_overhead_included": self.torch_op, + "torch_overhead_included": conv.torch_op, "direction": direction, "graph": graph.name, "name": impl.name(), diff --git a/openequivariance/implementations/convolution/ConvolutionBase.py b/openequivariance/implementations/convolution/ConvolutionBase.py index a806850e..0536d134 100644 --- a/openequivariance/implementations/convolution/ConvolutionBase.py +++ b/openequivariance/implementations/convolution/ConvolutionBase.py @@ -1,3 +1,4 @@ +import copy import numpy as np import numpy.linalg as la from openequivariance.extlib import * @@ -11,7 +12,6 @@ def flops_data_per_tp(config, direction): ''' Assumes all interactions are "uvu" for now - Returns (flops_per_tp, data_per_tp, nnz) ''' bytes_per_word = np.dtype(config.irrep_dtype).itemsize @@ -187,48 +187,41 @@ def backward_cpu(self, def test_correctness_forward(self, graph, thresh, prng_seed, reference_implementation=None, - check_reproducible=True, torch_op=False): - L1, L2, L3 = self.L1, self.L2, self.L3 + check_reproducible=True, high_precision_ref=False): if reference_implementation is None: from openequivariance.implementations.convolution.E3NNConv import E3NNConv reference_implementation = E3NNConv - result = { - "thresh": thresh - } + result = {"thresh": thresh} in1, in2, weights, out = get_random_buffers_forward_conv(self.config, graph.node_count, graph.nnz, prng_seed) + ref_in1, ref_in2, ref_weights, ref_out = [buf.copy() for buf in [in1, in2, weights, out]] + + reference_config = self.config + if high_precision_ref: + reference_config = copy.deepcopy(self.config) + reference_config.irrep_dtype = np.float64 + reference_config.weight_dtype = np.float64 + reference_buffers = [np.array(el, dtype=np.float64) for el in reference_buffers] + + args = { + "L1_in": ref_in1, + "L2_in": ref_in2, + "weights": ref_weights, + "rows": graph.rows, + "cols": graph.cols + } - ref_tp = reference_implementation(self.config) - ref_out = out.copy() - - if ref_tp.torch_op: - ref_out_torch = None - if not ref_tp.deterministic: - ref_out_torch = ref_tp.forward( - L1_in=torch.tensor(in1, device='cuda'), - L2_in=torch.tensor(in2, device='cuda'), - weights=torch.tensor(weights, device='cuda'), - rows = torch.tensor(graph.rows, device='cuda'), - cols = torch.tensor(graph.cols, device='cuda')) - else: - ref_out_torch = ref_tp.forward( - torch.tensor(in1, device='cuda'), - torch.tensor(in2, device='cuda'), - torch.tensor(weights, device='cuda'), - torch.tensor(graph.rows, device='cuda'), - torch.tensor(graph.cols, device='cuda'), - torch.tensor(graph.transpose_perm, device='cuda')) - ref_out[:] = ref_out_torch.cpu().numpy() - else: - ref_tp.forward_cpu( - L1_in=in1.copy(), - L2_in=in2.copy(), - weights=weights.copy(), - L3_out=ref_out, - graph=graph) + ref_tp = reference_implementation(reference_config) + if ref_tp.deterministic: + args["transpose_perm"] = graph.transpose_perm + + for key in args: + args[key] = torch.tensor(args[key], device='cuda') + + ref_out[:] = ref_tp.forward(**args).cpu().numpy() test_out = out.copy() self.forward_cpu( diff --git a/tests/benchmark.py b/tests/benchmark.py index f04b216e..9fb4e50d 100644 --- a/tests/benchmark.py +++ b/tests/benchmark.py @@ -121,12 +121,12 @@ def benchmark_roofline(params): if params.plot: plot({"data_folder": data_folder}) +def download_graphs(params): + download_prefix = "https://portal.nersc.gov/project/m1982/equivariant_nn_graphs/" -def benchmark_convolution(params): filenames = [ "covid_spike_radius3.0.pickle", "1drf_radius6.0.pickle", "carbon_lattice_radius6.0.pickle"] - download_prefix = "https://portal.nersc.gov/project/m1982/equivariant_nn_graphs/" if not Path(params.data).exists(): os.makedirs(params.data, exist_ok=True) @@ -140,10 +140,15 @@ def benchmark_convolution(params): exit(1) else: logging.info(f"Downloading {download_prefix + filename}...") - urllib.request.urlretrieve(download_prefix + filename, target_path) - + urllib.request.urlretrieve(download_prefix + filename, target_path) + graphs.append(load_graph(str(target_path))) + return graphs + +def benchmark_convolution(params): + graphs = download_graphs(params) + if not params.disable_bench: configs = [ ChannelwiseTPP("128x0e+128x1o+128x2e", "1x0e+1x1o+1x2e+1x3o", @@ -156,7 +161,7 @@ def benchmark_convolution(params): configs[1].irrep_dtype = np.float64 configs[1].weight_dtype = np.float64 - bench = ConvBenchmarkSuite(configs, torch_op=True, test_name="convolution") + bench = ConvBenchmarkSuite(configs, test_name="convolution") implementations = [ TensorProductConvScatterSum, CUEConv, @@ -187,15 +192,10 @@ def benchmark_convolution(params): else: logger.critical("Cannot plot convolution speedups over cuE with --limited-memory flag enabled.") -def run_paper_hderiv_benchmark(params): +def benchmark_double_backward(params): from openequivariance.benchmark.benchmark_configs import mace_nequip_problems, diffdock_configs - implementations = [ - E3NNTensorProduct, - CUETensorProduct, - TensorProduct, - ] - + implementations = [E3NNTensorProduct, CUETensorProduct, TensorProduct] problems = diffdock_configs + mace_nequip_problems float64_problems = copy.deepcopy(problems) @@ -216,6 +216,19 @@ def run_paper_hderiv_benchmark(params): if params.plot: plot({"data_folder": data_folder}) +def benchmark_kahan_accuracy(params): + from openequivariance.benchmark.benchmark_configs import mace_problems + graphs = download_graphs(params)[0] + implementations = [TensorProductConvKahan] + problem = mace_problems[0] + + output_folder = None + for graph in graphs: + for direction in ["forward", "backward"]: + conv_tp = TensorProductConvKahan(problem) + result = conv_tp.test_correctness_forward(graph, 1e-7, check_reproducible=False, high_precision_ref=True) + + print(result) def plot(params): import openequivariance.benchmark.plotting as plotting @@ -290,7 +303,7 @@ def plot(params): parser_higher_deriv = subparsers.add_parser('double_backward', help='Run the higher derivative kernel benchmark') parser_higher_deriv.add_argument("--batch_size", "-b", type=int, default=50000, help="Batch size for benchmark") - parser_higher_deriv.set_defaults(func=run_paper_hderiv_benchmark) + parser_higher_deriv.set_defaults(func=benchmark_double_backward) parser_plot = subparsers.add_parser('plot', help="Generate a plot for a folder of benchmarks.") parser_plot.add_argument("data_folder", type=str) diff --git a/tests/double_backwards_driver.py b/tests/double_backwards_driver.py deleted file mode 100644 index cabc35f4..00000000 --- a/tests/double_backwards_driver.py +++ /dev/null @@ -1,40 +0,0 @@ -from itertools import product -import logging - -import e3nn -from e3nn import o3 - -from openequivariance.benchmark.TestBenchmarkSuite import TestBenchmarkSuite, TestDefinition, Direction -from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct -from openequivariance.implementations.CUETensorProduct import CUETensorProduct -from openequivariance.implementations.TensorProduct import TensorProduct - -from openequivariance.implementations.e3nn_lite import TPProblem -from openequivariance.benchmark.tpp_creation_utils import FullyConnectedTPProblem, ChannelwiseTPP -from openequivariance.benchmark.logging_utils import getLogger -from openequivariance.benchmark.benchmark_configs import mace_nequip_problems, diffdock_configs - - -implementations = [ - E3NNTensorProduct, - CUETensorProduct, - TensorProduct, -] - -problems = diffdock_configs # mace_nequip_problems - -directions : list[Direction] = [ - 'double_backward', -] - -tests = [TestDefinition(implementation, problem, direction, correctness=True, benchmark=True) for problem, direction, implementation, in product(problems, directions, implementations)] - -if __name__ == "__main__": - - logger = getLogger() - - logger.setLevel(logging.INFO) - test_suite = TestBenchmarkSuite( - bench_batch_size=50000 - ) - test_suite.run(tests) \ No newline at end of file From 0bccbb42768ef846324cf42bf746594408ffd661 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 9 May 2025 23:14:19 -0700 Subject: [PATCH 09/13] Kahan summation reduces error by about 5x on Kahan summation. --- .../convolution/ConvolutionBase.py | 5 ++-- tests/benchmark.py | 25 ++++++++++++------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/openequivariance/implementations/convolution/ConvolutionBase.py b/openequivariance/implementations/convolution/ConvolutionBase.py index 0536d134..2d9b8f05 100644 --- a/openequivariance/implementations/convolution/ConvolutionBase.py +++ b/openequivariance/implementations/convolution/ConvolutionBase.py @@ -97,7 +97,7 @@ def allocate_workspace(self, size_bytes): else: self.workspace_buffer = DeviceBuffer(size_bytes) self.workspace_ptr = self.workspace_buffer.data_ptr() - logger.info(f"Deterministic Convolution requires {size_bytes // 1000000}MB of workspace.") + logger.info(f"Convolution requires {size_bytes // 1000000}MB of workspace.") @staticmethod def name(): @@ -204,7 +204,8 @@ def test_correctness_forward(self, reference_config = copy.deepcopy(self.config) reference_config.irrep_dtype = np.float64 reference_config.weight_dtype = np.float64 - reference_buffers = [np.array(el, dtype=np.float64) for el in reference_buffers] + ref_in1, ref_in2, ref_weights, ref_out = [np.array(el, dtype=np.float64) + for el in [ref_in1, ref_in2, ref_weights, ref_out]] args = { "L1_in": ref_in1, diff --git a/tests/benchmark.py b/tests/benchmark.py index 9fb4e50d..19b69039 100644 --- a/tests/benchmark.py +++ b/tests/benchmark.py @@ -218,15 +218,17 @@ def benchmark_double_backward(params): def benchmark_kahan_accuracy(params): from openequivariance.benchmark.benchmark_configs import mace_problems - graphs = download_graphs(params)[0] - implementations = [TensorProductConvKahan] + graphs = download_graphs(params) + implementations = [TensorProductConvAtomic, TensorProductConvKahan] problem = mace_problems[0] - output_folder = None for graph in graphs: - for direction in ["forward", "backward"]: - conv_tp = TensorProductConvKahan(problem) - result = conv_tp.test_correctness_forward(graph, 1e-7, check_reproducible=False, high_precision_ref=True) + for impl in implementations: + conv_tp = impl(problem) + result = conv_tp.test_correctness_forward( graph, 1e-4, + check_reproducible=False, + high_precision_ref=True, + prng_seed=12345) print(result) @@ -301,9 +303,14 @@ def plot(params): parser_uvw.add_argument("--plot", action="store_true", help="Plot the results.") parser_uvw.set_defaults(func=run_paper_uvw_benchmark) - parser_higher_deriv = subparsers.add_parser('double_backward', help='Run the higher derivative kernel benchmark') - parser_higher_deriv.add_argument("--batch_size", "-b", type=int, default=50000, help="Batch size for benchmark") - parser_higher_deriv.set_defaults(func=benchmark_double_backward) + parser_double_bwd = subparsers.add_parser('double_backward', help='Run the higher derivative kernel benchmark') + parser_double_bwd.add_argument("--batch_size", "-b", type=int, default=50000, help="Batch size for benchmark") + parser_double_bwd.set_defaults(func=benchmark_double_backward) + + parser_kahan = subparsers.add_parser('kahan_conv', help='Run the Kahan convolution accuracy benchmark') + parser_kahan.add_argument("--data", type=str, help="Folder to download graph data to (or already containing graphs)", required=True) + parser_kahan.add_argument("--disable_download", action='store_true', help="Disable downloading data files if they do not exist") + parser_kahan.set_defaults(func=benchmark_kahan_accuracy) parser_plot = subparsers.add_parser('plot', help="Generate a plot for a folder of benchmarks.") parser_plot.add_argument("data_folder", type=str) From e914d60b378df634b1e83612558c13772eb02631 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 10 May 2025 01:12:01 -0700 Subject: [PATCH 10/13] We have a dependency cycle that we need to fix. --- .../convolution/ConvolutionBase.py | 66 ++++++++++++------- tests/benchmark.py | 24 +++++-- 2 files changed, 59 insertions(+), 31 deletions(-) diff --git a/openequivariance/implementations/convolution/ConvolutionBase.py b/openequivariance/implementations/convolution/ConvolutionBase.py index 2d9b8f05..8f591593 100644 --- a/openequivariance/implementations/convolution/ConvolutionBase.py +++ b/openequivariance/implementations/convolution/ConvolutionBase.py @@ -207,13 +207,8 @@ def test_correctness_forward(self, ref_in1, ref_in2, ref_weights, ref_out = [np.array(el, dtype=np.float64) for el in [ref_in1, ref_in2, ref_weights, ref_out]] - args = { - "L1_in": ref_in1, - "L2_in": ref_in2, - "weights": ref_weights, - "rows": graph.rows, - "cols": graph.cols - } + args = {"L1_in": ref_in1, "L2_in": ref_in2, "weights": ref_weights, + "rows": graph.rows, "cols": graph.cols} ref_tp = reference_implementation(reference_config) if ref_tp.deterministic: @@ -452,32 +447,36 @@ def calculate_bench_stats(self, direction, ops_per_tp, data_per_tp, time_millis, logger.info(f"{bcolors.OKCYAN}Avg. Bandwidth: {bcolors.ENDC} {bcolors.OKGREEN}{np.mean(bandwidth_gbps):.2f} ± {np.std(bandwidth_gbps):.2f} GBPs{bcolors.ENDC}") return result - def test_correctness_backward(self, graph, thresh, prng_seed, reference_implementation=None): + def test_correctness_backward(self, graph, thresh, prng_seed, reference_implementation=None, high_precision_ref=False): L1, L2, L3 = self.L1, self.L2, self.L3 if reference_implementation is None: from openequivariance.implementations.convolution.E3NNConv import E3NNConv reference_implementation = E3NNConv - result = { - "thresh": thresh - } + result = {"thresh": thresh} - in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad = get_random_buffers_backward_conv(self.config, graph.node_count, graph.nnz, prng_seed) + buffers = get_random_buffers_backward_conv(self.config, graph.node_count, graph.nnz, prng_seed) + reference_buffers = [buf.copy() for buf in buffers] + reference_problem = self.config - ref_tp = reference_implementation(self.config) + if high_precision_ref: + reference_problem = copy.deepcopy(self.config) + reference_problem.irrep_dtype = np.float64 + reference_problem.weight_dtype = np.float64 + reference_buffers = [np.array(el, dtype=np.float64) for el in reference_buffers] - ref_weights_grad = weights_grad.copy() - ref_in1_grad = in1_grad.copy() - ref_in2_grad = in2_grad.copy() + ref_tp = reference_implementation(reference_problem) + in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad = buffers + ref_in1, ref_in2, ref_out_grad, ref_weights, ref_weights_grad, ref_in1_grad, ref_in2_grad = reference_buffers ref_tp.backward_cpu( - L1_in=in1.copy(), + L1_in=ref_in1, L1_grad=ref_in1_grad, - L2_in=in2.copy(), + L2_in=ref_in2, L2_grad=ref_in2_grad, - L3_grad=out_grad.copy(), - weights=weights.copy(), + L3_grad=ref_out_grad, + weights=ref_weights, weights_grad=ref_weights_grad, graph=graph) @@ -491,8 +490,8 @@ def test_correctness_backward(self, graph, thresh, prng_seed, reference_implemen L1_grad=test_in1_grad, L2_in=in2.copy(), L2_grad=test_in2_grad, - L3_grad=out_grad.copy(), - weights=weights.copy(), + L3_grad=out_grad.copy(), + weights=weights.copy(), weights_grad=test_weights_grad, graph=graph) @@ -504,13 +503,13 @@ def test_correctness_backward(self, graph, thresh, prng_seed, reference_implemen return result - def test_correctness_double_backward(self, graph, thresh, prng_seed, reference_implementation=None): + def test_correctness_double_backward(self, graph, thresh, prng_seed, reference_implementation=None, high_precision_ref=False): global torch import torch assert(self.torch_op) + buffers = get_random_buffers_backward_conv(self.config, graph.node_count, graph.nnz, prng_seed) - in1, in2, out_grad, weights, _, _, _ = get_random_buffers_backward_conv(self.config, graph.node_count, graph.nnz, prng_seed) rng = np.random.default_rng(seed=prng_seed * 2) dummy_grad_value = rng.standard_normal(1)[0] @@ -518,11 +517,22 @@ def test_correctness_double_backward(self, graph, thresh, prng_seed, reference_i from openequivariance.implementations.convolution.E3NNConv import E3NNConv reference_implementation = E3NNConv - reference_tp = reference_implementation(self.config, torch_op=True) + reference_problem = self.config + if high_precision_ref: + reference_problem = copy.deepcopy(self.config) + reference_problem.irrep_dtype = np.float64 + reference_problem.weight_dtype = np.float64 + + reference_tp = reference_implementation(reference_problem, torch_op=True) result = {"thresh": thresh} tensors = [] for i, tp in enumerate([self, reference_tp]): + in1, in2, out_grad, weights, _, _, _ = [buf.copy() for buf in buffers] + + if i == 1 and high_precision_ref: + in1, in2, out_grad, weights, _, _, _ = [np.array(el, dtype=np.float64) for el in buffers] + in1_torch = torch.tensor(in1, device='cuda', requires_grad=True) in2_torch = torch.tensor(in2, device='cuda', requires_grad=True) @@ -552,6 +562,12 @@ def test_correctness_double_backward(self, graph, thresh, prng_seed, reference_i dummy = torch.norm(in1_grad) + torch.norm(in2_grad) + torch.norm(w_grad) dummy_grad = torch.tensor(float(dummy_grad_value), device='cuda', requires_grad=True) + + #torch.autograd.grad( + # outputs=[dummy], + # inputs=[out_torch, in1_torch, in2_torch, weights_torch], + # grad_outputs=[dummy_grad], + #) dummy.backward(dummy_grad, inputs=[out_grad_torch, in1_torch, in2_torch, weights_torch]) weights_grad = weights_torch.grad.detach().cpu().numpy() diff --git a/tests/benchmark.py b/tests/benchmark.py index 19b69039..d2f4c980 100644 --- a/tests/benchmark.py +++ b/tests/benchmark.py @@ -1,7 +1,7 @@ import numpy as np import numpy.linalg as la -import itertools, logging, argparse, os, copy +import itertools, logging, argparse, os, copy, gc from pathlib import Path import urllib.request @@ -218,19 +218,31 @@ def benchmark_double_backward(params): def benchmark_kahan_accuracy(params): from openequivariance.benchmark.benchmark_configs import mace_problems - graphs = download_graphs(params) + graphs = download_graphs(params)[-1:] implementations = [TensorProductConvAtomic, TensorProductConvKahan] problem = mace_problems[0] + from torch.utils.viz._cycles import warn_tensor_cycles + warn_tensor_cycles() + for graph in graphs: for impl in implementations: conv_tp = impl(problem) - result = conv_tp.test_correctness_forward( graph, 1e-4, - check_reproducible=False, + #result_fwd = conv_tp.test_correctness_forward( graph, 1e-4, + # check_reproducible=False, + # high_precision_ref=True, + # prng_seed=12345) + + #result_bwd = conv_tp.test_correctness_backward(graph, 1e-4, + # high_precision_ref=True, + # prng_seed=12345) + + result_double_bwd = conv_tp.test_correctness_double_backward(graph, 1e-4, high_precision_ref=True, prng_seed=12345) - - print(result) + #gc.collect() + + print(result_double_bwd) def plot(params): import openequivariance.benchmark.plotting as plotting From 668b974e9d192ff0db64c241c8e28c900c773a74 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 10 May 2025 01:57:18 -0700 Subject: [PATCH 11/13] Removed some reference cycles. --- .../convolution/LoopUnrollConv.py | 9 ++++--- .../convolution/TensorProductConv.py | 25 +++++++++---------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/openequivariance/implementations/convolution/LoopUnrollConv.py b/openequivariance/implementations/convolution/LoopUnrollConv.py index 533b4fb4..676397a0 100644 --- a/openequivariance/implementations/convolution/LoopUnrollConv.py +++ b/openequivariance/implementations/convolution/LoopUnrollConv.py @@ -154,10 +154,11 @@ def generate_double_backward_schedule(warps_per_block): #with open("scratch.txt", "w") as f: # f.write(self.jit_kernel) - self.reorder_weights_e3nn_to_oeq = lambda input, output, has_batch_dim: \ - self.forward_schedule.reorder_weights(input, output, "forward", has_batch_dim) - self.reorder_weights_oeq_to_e3nn = lambda input, output, has_batch_dim: \ - self.forward_schedule.reorder_weights(input, output, "backward", has_batch_dim) + def reorder_weights_e3nn_to_oeq(self, input, output, has_batch_dim): + return self.forward_schedule.reorder_weights(input, output, "forward", has_batch_dim) + + def reorder_weights_oeq_to_e3nn(self, input, output, has_batch_dim): + return self.forward_schedule.reorder_weights(input, output, "backward", has_batch_dim) @staticmethod def name(): diff --git a/openequivariance/implementations/convolution/TensorProductConv.py b/openequivariance/implementations/convolution/TensorProductConv.py index 0ab777b7..dca2a479 100644 --- a/openequivariance/implementations/convolution/TensorProductConv.py +++ b/openequivariance/implementations/convolution/TensorProductConv.py @@ -4,6 +4,7 @@ import numpy as np from typing import Optional +import types class TensorProductConv(torch.nn.Module, LoopUnrollConv): ''' @@ -17,23 +18,21 @@ def __init__(self, config, idx_dtype=np.int64, torch_op=True, deterministic=Fals self.dummy_transpose_perm = torch.zeros(1, dtype=torch.int64, device='cuda') self.weight_numel = self.config.weight_numel - if extlib.TORCH_COMPILE: - self.forward = self.forward_deterministic if deterministic else self.forward_atomic + if not extlib.TORCH_COMPILE: + self.forward = types.MethodType(LoopUnrollConv.forward, self) + + def forward(self, L1_in: torch.Tensor, L2_in: + torch.Tensor, W: torch.Tensor, + rows: torch.Tensor, cols: torch.Tensor, sender_perm: Optional[torch.Tensor]=None) -> torch.Tensor: + if sender_perm is None: + return torch.ops.torch_tp_jit.jit_conv_forward(self.internal, L1_in, L2_in, W, rows, cols, self.workspace_buffer, self.dummy_transpose_perm) + else: + return torch.ops.torch_tp_jit.jit_conv_forward(self.internal, L1_in, L2_in, W, rows, cols, self.workspace_buffer, sender_perm) @staticmethod def name(): return LoopUnrollConv.name() - - def forward_deterministic(self, L1_in: torch.Tensor, L2_in: - torch.Tensor, W: torch.Tensor, - rows: torch.Tensor, cols: torch.Tensor, sender_perm: torch.Tensor) -> torch.Tensor: - return torch.ops.torch_tp_jit.jit_conv_forward(self.internal, L1_in, L2_in, W, rows, cols, self.workspace_buffer, sender_perm) - - def forward_atomic(self, L1_in: torch.Tensor, L2_in: - torch.Tensor, W: torch.Tensor, - rows: torch.Tensor, cols: torch.Tensor) -> torch.Tensor: - return torch.ops.torch_tp_jit.jit_conv_forward(self.internal, L1_in, L2_in, W, rows, cols, self.workspace_buffer, self.dummy_transpose_perm) - + # ================================================================== # Reference implementations for benchmarking From de5ee905caa170d3ddd3b1ce5f6519e0b195be71 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 10 May 2025 02:49:55 -0700 Subject: [PATCH 12/13] Accuracy benchmark completed. --- README.md | 13 ++++- .../benchmark/ConvBenchmarkSuite.py | 33 ++++++----- .../convolution/ConvolutionBase.py | 14 ++--- tests/benchmark.py | 57 +++++++++---------- 4 files changed, 63 insertions(+), 54 deletions(-) diff --git a/README.md b/README.md index 42839a75..80b14d7b 100644 --- a/README.md +++ b/README.md @@ -125,7 +125,17 @@ print(torch.norm(Z)) ``` **Note**: you don't need Pytorch geometric to use our kernels. When `deterministic=False`, the `sender` and `receiver` indices can have -arbitrary order. +arbitrary order. + +**New:** If you're working in FP32 precision and want +higher accuracy during graph convolution, we offer a Kahan +summation variant of our deterministic algorithm: + +```python +tp_conv_kahan = oeq.TensorProductConv(problem, torch_op=True, deterministic=True, kahan=True) +Z = tp_conv_kahan.forward(X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm) +print(torch.norm(Z)) +``` ## Installation We currently support Linux systems only. @@ -172,6 +182,7 @@ python tests/benchmark.py -o outputs/uvu uvu --plot python tests/benchmark.py -o outputs/uvw uvw --plot python tests/benchmark.py -o outputs/roofline roofline --plot python tests/benchmark.py -o outputs/conv conv --plot --data data/molecular_structures +python tests/benchmark.py -o outputs/kahan_conv kahan_conv --data data/molecular_structures/ ``` If your GPU has limited memory, you might want to try diff --git a/openequivariance/benchmark/ConvBenchmarkSuite.py b/openequivariance/benchmark/ConvBenchmarkSuite.py index ffd13583..35b17600 100644 --- a/openequivariance/benchmark/ConvBenchmarkSuite.py +++ b/openequivariance/benchmark/ConvBenchmarkSuite.py @@ -25,19 +25,21 @@ def __init__(self, configs, num_iter = 30, reference_impl=None, test_name=None, - prng_seed = 12345): + prng_seed = 12345, + correctness_threshold = 1e-5): self.configs = configs self.num_warmup = num_warmup self.num_iter = num_iter self.reference_impl = reference_impl self.prng_seed = 12345 - self.correctness_threshold = 1e-5 + self.correctness_threshold = correctness_threshold self.exp_count = 0 self.test_name = test_name self.millis_since_epoch = round(time.time() * 1000) - def run(self, graph, implementations, direction, output_folder=None, correctness=True, double_backward_correctness=False, benchmark=True): + def run(self, graph, implementations, direction, output_folder=None, + correctness=True, benchmark=True, high_precision_ref=False): if output_folder is None: if oeq._check_package_editable(): output_folder = oeq._editable_install_output_path / f"{self.millis_since_epoch}" @@ -65,18 +67,13 @@ def run(self, graph, implementations, direction, output_folder=None, correctness logger.info(f'Starting {tc_name}, graph {graph.name}, {direction}') conv = impl(config) - if double_backward_correctness: - double_backward_correctness = conv.test_correctness_double_backward(self.graph, - thresh=self.correctness_threshold, - prng_seed=self.prng_seed, - reference_implementation=self.reference_impl) - if direction == "forward": if correctness: correctness = conv.test_correctness_forward(graph, thresh=self.correctness_threshold, prng_seed=self.prng_seed, - reference_implementation=self.reference_impl) + reference_implementation=self.reference_impl, + high_precision_ref=high_precision_ref) if benchmark: benchmark = conv.benchmark_forward(self.num_warmup, @@ -88,11 +85,22 @@ def run(self, graph, implementations, direction, output_folder=None, correctness correctness = conv.test_correctness_backward(graph, thresh=self.correctness_threshold, prng_seed=self.prng_seed, - reference_implementation=self.reference_impl) + reference_implementation=self.reference_impl, + high_precision_ref=high_precision_ref) if benchmark: benchmark = conv.benchmark_backward(self.num_warmup, self.num_iter, graph, prng_seed=12345) + + if direction == "double_backward": + if correctness: + correctness = conv.test_correctness_double_backward(self.graph, + thresh=self.correctness_threshold, + prng_seed=self.prng_seed, + reference_implementation=self.reference_impl, + high_precision_ref=high_precision_ref) + + assert not benchmark result = { "config": str(config), @@ -103,8 +111,7 @@ def run(self, graph, implementations, direction, output_folder=None, correctness "graph": graph.name, "name": impl.name(), "correctness": correctness, - "benchmark": benchmark, - "double_backward_correctness": double_backward_correctness + "benchmark": benchmark } fname = pathlib.Path(f"{output_folder}/{self.exp_count}_{impl.name()}_{graph.name}.json") diff --git a/openequivariance/implementations/convolution/ConvolutionBase.py b/openequivariance/implementations/convolution/ConvolutionBase.py index 8f591593..b645f896 100644 --- a/openequivariance/implementations/convolution/ConvolutionBase.py +++ b/openequivariance/implementations/convolution/ConvolutionBase.py @@ -562,12 +562,6 @@ def test_correctness_double_backward(self, graph, thresh, prng_seed, reference_i dummy = torch.norm(in1_grad) + torch.norm(in2_grad) + torch.norm(w_grad) dummy_grad = torch.tensor(float(dummy_grad_value), device='cuda', requires_grad=True) - - #torch.autograd.grad( - # outputs=[dummy], - # inputs=[out_torch, in1_torch, in2_torch, weights_torch], - # grad_outputs=[dummy_grad], - #) dummy.backward(dummy_grad, inputs=[out_grad_torch, in1_torch, in2_torch, weights_torch]) weights_grad = weights_torch.grad.detach().cpu().numpy() @@ -576,10 +570,10 @@ def test_correctness_double_backward(self, graph, thresh, prng_seed, reference_i self.reorder_weights_oeq_to_e3nn(weights_grad_copy, weights_grad, not self.config.shared_weights) tensors.append(( - out_grad_torch.grad.detach().cpu().numpy(), - in1_torch.grad.detach().cpu().numpy(), - in2_torch.grad.detach().cpu().numpy(), - weights_grad + out_grad_torch.grad.detach().cpu().numpy().copy(), + in1_torch.grad.detach().cpu().numpy().copy(), + in2_torch.grad.detach().cpu().numpy().copy(), + weights_grad.copy() )) for name, to_check, ground_truth in [ diff --git a/tests/benchmark.py b/tests/benchmark.py index d2f4c980..3284348b 100644 --- a/tests/benchmark.py +++ b/tests/benchmark.py @@ -121,13 +121,9 @@ def benchmark_roofline(params): if params.plot: plot({"data_folder": data_folder}) -def download_graphs(params): +def download_graphs(params, filenames): download_prefix = "https://portal.nersc.gov/project/m1982/equivariant_nn_graphs/" - filenames = [ "covid_spike_radius3.0.pickle", - "1drf_radius6.0.pickle", - "carbon_lattice_radius6.0.pickle"] - if not Path(params.data).exists(): os.makedirs(params.data, exist_ok=True) @@ -146,8 +142,12 @@ def download_graphs(params): return graphs -def benchmark_convolution(params): - graphs = download_graphs(params) +def benchmark_convolution(params): + filenames = [ "covid_spike_radius3.0.pickle", + "1drf_radius6.0.pickle", + "carbon_lattice_radius6.0.pickle"] + + graphs = download_graphs(params, filenames) if not params.disable_bench: configs = [ ChannelwiseTPP("128x0e+128x1o+128x2e", @@ -182,7 +182,6 @@ def benchmark_convolution(params): graph = graph, direction=direction, correctness=False, - double_backward_correctness=False, benchmark=True, output_folder=params.output_folder) @@ -218,31 +217,28 @@ def benchmark_double_backward(params): def benchmark_kahan_accuracy(params): from openequivariance.benchmark.benchmark_configs import mace_problems - graphs = download_graphs(params)[-1:] + + filenames = ["carbon_lattice_radius6.0.pickle"] + graphs = download_graphs(params, filenames) implementations = [TensorProductConvAtomic, TensorProductConvKahan] - problem = mace_problems[0] + problems = [mace_problems[0]] - from torch.utils.viz._cycles import warn_tensor_cycles - warn_tensor_cycles() + bench = ConvBenchmarkSuite(problems, test_name="convolution", correctness_threshold=1e-4) + directions = ['forward', 'backward'] + if params.double_backward: + directions.append('double_backward') for graph in graphs: - for impl in implementations: - conv_tp = impl(problem) - #result_fwd = conv_tp.test_correctness_forward( graph, 1e-4, - # check_reproducible=False, - # high_precision_ref=True, - # prng_seed=12345) - - #result_bwd = conv_tp.test_correctness_backward(graph, 1e-4, - # high_precision_ref=True, - # prng_seed=12345) - - result_double_bwd = conv_tp.test_correctness_double_backward(graph, 1e-4, - high_precision_ref=True, - prng_seed=12345) - #gc.collect() - - print(result_double_bwd) + for direction in directions: + output_folder = bench.run( + implementations = implementations, + graph = graph, + direction=direction, + correctness=True, + benchmark=False, + output_folder=params.output_folder, + high_precision_ref=True) + def plot(params): import openequivariance.benchmark.plotting as plotting @@ -321,7 +317,8 @@ def plot(params): parser_kahan = subparsers.add_parser('kahan_conv', help='Run the Kahan convolution accuracy benchmark') parser_kahan.add_argument("--data", type=str, help="Folder to download graph data to (or already containing graphs)", required=True) - parser_kahan.add_argument("--disable_download", action='store_true', help="Disable downloading data files if they do not exist") + parser_kahan.add_argument("--disable_download", action='store_true', help="Disable downloading data files if they do not exist") + parser_kahan.add_argument("--double_backward", action='store_true', help="Run double backward test (high memory usage)") parser_kahan.set_defaults(func=benchmark_kahan_accuracy) parser_plot = subparsers.add_parser('plot', help="Generate a plot for a folder of benchmarks.") From 4ebf9586fdf9ba71a97bd83b721234508d20a5ea Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 10 May 2025 03:21:11 -0700 Subject: [PATCH 13/13] Kahan summation complete! --- tests/benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/benchmark.py b/tests/benchmark.py index 3284348b..7cabe628 100644 --- a/tests/benchmark.py +++ b/tests/benchmark.py @@ -223,7 +223,7 @@ def benchmark_kahan_accuracy(params): implementations = [TensorProductConvAtomic, TensorProductConvKahan] problems = [mace_problems[0]] - bench = ConvBenchmarkSuite(problems, test_name="convolution", correctness_threshold=1e-4) + bench = ConvBenchmarkSuite(problems, test_name="kahan_convolution_accuracy", correctness_threshold=1e-4) directions = ['forward', 'backward'] if params.double_backward: directions.append('double_backward')