From fe911871ea3a1e31f352570004ed5c1681fca3f3 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Tue, 29 Apr 2025 01:30:29 -0700 Subject: [PATCH 01/12] Making progress on double backward pass. --- openequivariance/extension/convolution.hpp | 26 +++++++++- openequivariance/extension/torch_tp_jit.cpp | 51 +++++++++++++++++++ .../implementations/ComputationSchedule.py | 2 +- 3 files changed, 76 insertions(+), 3 deletions(-) diff --git a/openequivariance/extension/convolution.hpp b/openequivariance/extension/convolution.hpp index 0e412bce..948d0ee9 100644 --- a/openequivariance/extension/convolution.hpp +++ b/openequivariance/extension/convolution.hpp @@ -109,7 +109,7 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI backward_config(backward_config_i), is_uvw(is_uvw_i) { - vector kernels = {"forward", "backward", "fixup_forward", "fixup_backward"}; + vector kernels = {"forward", "backward", "fixup_forward", "fixup_backward", "double_backward_A", "double_backward_B"}; int opt_level = 3; #ifdef HIP_BACKEND @@ -117,7 +117,7 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI opt_level = 1; } #endif - jit.compile(kernels, {{}, {}, {}, {}}, opt_level); + jit.compile(kernels, {{}, {}, {}, {}, {}, {}}, opt_level); if(forward_config.smem > 0) { jit.set_max_smem(0, forward_config.smem); @@ -201,5 +201,27 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI } } + void double_backward( + void* L1_in, void* L2_in, void* W, void* L3_grad, + void* L1_dgrad, void* L2_dgrad, void* w_dgrad, + void* L1_grad, void* L2_grad, void* W_grad, void* L3_dgrad, + void* rows_contig, void* cols_contig, + uint64_t nnz, uint64_t node_count, + void* wspace, void* transpose_perm) { + + ConvData conv_data = {rows, cols, nnz, node_count}; + void* args[] = { + &L1_in, &L2_in, &W, &L3_grad, &L1_dgrad, &L2_dgrad, &w_dgrad, + &L1_grad, &L2_grad, &W_grad, &L3_dgrad, &conv_data, &wspace, &transpose_perm + }; + jit.execute(4, args, forward_config); + + // Execute forward fixup kernel here + + jit.execute(5, args, backward_config); + + // Execute backward fixup kernel here + } + ~JITConvImpl() = default; }; \ No newline at end of file diff --git a/openequivariance/extension/torch_tp_jit.cpp b/openequivariance/extension/torch_tp_jit.cpp index 353a527d..640e54cb 100644 --- a/openequivariance/extension/torch_tp_jit.cpp +++ b/openequivariance/extension/torch_tp_jit.cpp @@ -347,6 +347,57 @@ tuple jit_conv_backward( return tuple(L1_grad, L2_grad, W_grad); } + +tuple jit_tp_double_backward( + const c10::intrusive_ptr &jit_instance, + const torch::Tensor &L1_in, + const torch::Tensor &L2_in, + const torch::Tensor &W, + const torch::Tensor &L3_grad, + const torch::Tensor &L1_dgrad, + const torch::Tensor &L2_dgrad, + const torch::Tensor &W_dgrad, + const torch::Tensor &rows, + const torch::Tensor &cols, + const torch::Tensor &workspace, + const torch::Tensor &transpose_perm) { + + int64_t nnz = rows.sizes()[0]; + int64_t node_count = L1_in.sizes()[0]; + torch::Tensor L1_grad = torch::zeros(L1_in.sizes(), L1_in.options()); + torch::Tensor L2_grad = torch::empty(L2_in.sizes(), L2_in.options()); + torch::Tensor W_grad = torch::empty(W.sizes(), W.options()); + torch::Tensor L3_dgrad = torch::zeros(L3_grad.sizes(), L3_grad.options()); + + torch::Tensor L1_in_contig = L1_in.contiguous(); + torch::Tensor L2_in_contig = L2_in.contiguous(); + torch::Tensor W_contig = W.contiguous(); + torch::Tensor L3_grad_contig = L3_grad.contiguous(); + torch::Tensor 
L1_dgrad_contig = L1_dgrad.contiguous(); + torch::Tensor L2_dgrad_contig = L2_dgrad.contiguous(); + torch::Tensor W_dgrad_contig = W_dgrad.contiguous(); + + torch::Tensor rows_contig = rows.contiguous(); + torch::Tensor cols_contig = cols.contiguous(); + torch::Tensor workspace_contig = workspace.contiguous(); + torch::Tensor transpose_perm_contig = transpose_perm.contiguous(); + + jit_instance->internal.double_backward( + num_batch, + data_ptr(L1_in_contig), data_ptr(L2_in_contig), + data_ptr(W_contig), data_ptr(L3_grad_contig), + data_ptr(L1_dgrad_contig), data_ptr(L2_dgrad_contig), + data_ptr(W_dgrad_contig), + data_ptr(L1_grad), data_ptr(L2_grad), + data_ptr(W_grad), data_ptr(L3_dgrad), + data_ptr(rows_contig), data_ptr(cols_contig), + nnz, node_count, + data_ptr(workspace_contig), data_ptr(transpose_perm_contig) + ); + + return tuple(L1_grad, L2_grad, W_grad, L3_dgrad); +} + // =========================================================== TORCH_LIBRARY_FRAGMENT(torch_tp_jit, m) { diff --git a/openequivariance/implementations/ComputationSchedule.py b/openequivariance/implementations/ComputationSchedule.py index 58a2fdab..ebb513dd 100644 --- a/openequivariance/implementations/ComputationSchedule.py +++ b/openequivariance/implementations/ComputationSchedule.py @@ -276,7 +276,7 @@ def __init__(self, smem_limit -= 1 self.memory_per_warp = smem_limit // warps_per_block - self.memory_per_warp -= self.memory_per_warp % 4 + self.memory_per_warp -= self.memory_per_warp % 8 # ===================================================================== # Shared memory partitioning functions From 5cd781f12738fbc863d29b3153cd14131230e429 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Tue, 29 Apr 2025 16:10:39 -0700 Subject: [PATCH 02/12] More progress on double backward setup. 
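Notation for the sketch below: fwd(L1, L2, W) denotes the forward tensor-product
convolution and bwd(L1, L2, W, L3_grad) its backward pass returning
(L1_grad, L2_grad, W_grad). Following the Python reference in LoopUnrollConv.py
below, A = L1_in, B = L2_in, C = L3_grad, D = W, and (E, F, G) are the incoming
gradients of (L1_grad, L2_grad, W_grad). Because the tensor product is linear in
each of its three arguments, the double backward pass decomposes as (sketch only,
restating the op1..op7 composition kept in the diff):

    grad L1_in   = bwd(E, F, D, C)[0] + bwd(A, B, G, C)[0]
    grad L2_in   = bwd(E, F, D, C)[1] + bwd(A, B, G, C)[1]
    grad W       = bwd(E, B, D, C)[2] + bwd(A, F, D, C)[2]
    grad L3_grad = fwd(E, B, D) + fwd(A, F, D) + fwd(A, B, G)

The double_backward_A / double_backward_B kernels introduced in this series are
meant to fuse these seven separate launches into two.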
--- openequivariance/extension/convolution.hpp | 10 ++++- openequivariance/extension/torch_tp_jit.cpp | 3 +- .../convolution/LoopUnrollConv.py | 42 +++++++++++++++---- .../templates/loop_unroll_conv_atomic.cuh | 23 +++++++++- .../templates/loop_unroll_conv_det.cuh | 20 +++++++++ 5 files changed, 86 insertions(+), 12 deletions(-) diff --git a/openequivariance/extension/convolution.hpp b/openequivariance/extension/convolution.hpp index 948d0ee9..d184d052 100644 --- a/openequivariance/extension/convolution.hpp +++ b/openequivariance/extension/convolution.hpp @@ -97,16 +97,19 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI JIT_IMPL jit; KernelLaunchConfig forward_config; KernelLaunchConfig backward_config; + KernelLaunchConfig double_backward_config; bool is_uvw; JITConvImpl( std::string jit_kernel, KernelLaunchConfig forward_config_i, KernelLaunchConfig backward_config_i, + KernelLaunchConfig double_backward_config_i, bool is_uvw_i) : jit(jit_kernel), forward_config(forward_config_i), backward_config(backward_config_i), + double_backward_config(double_backward_config_i), is_uvw(is_uvw_i) { vector kernels = {"forward", "backward", "fixup_forward", "fixup_backward", "double_backward_A", "double_backward_B"}; @@ -121,11 +124,16 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI if(forward_config.smem > 0) { jit.set_max_smem(0, forward_config.smem); + jit.set_max_smem(4, forward_config.smem); } if(backward_config.smem > 0) { jit.set_max_smem(1, backward_config.smem); } + + if(double_backward_config.smem > 0) { + jit.set_max_smem(5, double_backward_config.smem); + } } JITConvImpl( @@ -218,7 +226,7 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI // Execute forward fixup kernel here - jit.execute(5, args, backward_config); + jit.execute(5, args, double_backward_config); // Execute backward fixup kernel here } diff --git a/openequivariance/extension/torch_tp_jit.cpp b/openequivariance/extension/torch_tp_jit.cpp index 640e54cb..9106205c 100644 --- a/openequivariance/extension/torch_tp_jit.cpp +++ b/openequivariance/extension/torch_tp_jit.cpp @@ -383,7 +383,6 @@ tuple torch::Tensor transpose_perm_contig = transpose_perm.contiguous(); jit_instance->internal.double_backward( - num_batch, data_ptr(L1_in_contig), data_ptr(L2_in_contig), data_ptr(W_contig), data_ptr(L3_grad_contig), data_ptr(L1_dgrad_contig), data_ptr(L2_dgrad_contig), @@ -448,6 +447,7 @@ TORCH_LIBRARY_FRAGMENT(torch_tp_jit, m) { m.def("jit_conv_forward(__torch__.torch.classes.torch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> Tensor"); m.def("jit_conv_backward(__torch__.torch.classes.torch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor)"); + m.def("jit_conv_double_backward(__torch__.torch.classes.torch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"); }; @@ -458,6 +458,7 @@ TORCH_LIBRARY_IMPL(torch_tp_jit, CUDA, m) { m.impl("jit_conv_forward", &jit_conv_forward); m.impl("jit_conv_backward", &jit_conv_backward); + m.impl("jit_conv_double_backward", &jit_conv_double_backward); }; PYBIND11_MODULE(torch_tp_jit, m) {} \ No newline at end of 
file diff --git a/openequivariance/implementations/convolution/LoopUnrollConv.py b/openequivariance/implementations/convolution/LoopUnrollConv.py index 4b6cbc67..5cdb58d7 100644 --- a/openequivariance/implementations/convolution/LoopUnrollConv.py +++ b/openequivariance/implementations/convolution/LoopUnrollConv.py @@ -49,8 +49,21 @@ def generate_backward_schedule(warps_per_block): warp_size=dp.warpsize, include_scratch=self.is_uvw, stream_weights=self.is_uvw) + + def generate_double_backward_schedule(warps_per_block): + self.double_backward_schedule = ComputationSchedule(self.config, + smem_limit=dp.maxSharedMemPerBlock, + warps_per_block=warps_per_block, + warp_size=dp.warpsize, + block_count=dp.multiprocessorCount, + direction = "double_backward", + irrep_dtype = config.irrep_dtype, + weight_dtype = config.weight_dtype, + include_scratch=self.is_uvw, + stream_weights=self.is_uvw, + schedule_type=3) - scheduler_generators = [generate_forward_schedule, generate_backward_schedule] + scheduler_generators = [generate_forward_schedule, generate_backward_schedule, generate_double_backward_schedule] for generate_schedule in scheduler_generators: warp_count = 6 @@ -72,6 +85,11 @@ def generate_backward_schedule(warps_per_block): for key in segment.L1Map.storeback_procedure: segment.L1Map.storeback_procedure[key] = "atomic_accumulate" + for segment in self.double_backward_schedule.segments: + for key in segment.L1Map.storeback_procedure: + segment.L1Map.storeback_procedure[key] = "atomic_accumulate" + + idx_type_map = {np.int32: "int", np.int64: "long"} if self.torch_op: @@ -98,6 +116,7 @@ def generate_backward_schedule(warps_per_block): self.jit_kernel = template.render( forward_schedule=self.forward_schedule, backward_schedule=self.backward_schedule, + double_backward_schedule=self.double_backward_schedule, idx_type=idx_type_map[idx_dtype], forward_workspace_offset=self.forward_workspace_offset, backward_workspace_offset=self.backward_workspace_offset) @@ -115,6 +134,7 @@ def generate_backward_schedule(warps_per_block): self.internal = internal_cls(self.jit_kernel, vars(self.forward_schedule.launch_config), vars(self.backward_schedule.launch_config), + vars(self.double_backward_schedule.launch_config), {"L3_dim": self.L3.dim, "is_uvw": int(self.is_uvw)}) logger.info("Kernel compiled!") @@ -163,6 +183,7 @@ def fake_backward(jit, L1_in, L2_in, W, L3_grad, rows, cols, workspace_buffer, s def register_autograd(cls): forward_op = torch.ops.torch_tp_jit.jit_conv_forward backward_op = torch.ops.torch_tp_jit.jit_conv_backward + double_backward_op = torch.ops.torch_tp_jit.jit_conv_double_backward def setup_context(ctx, inputs, output): ctx.jit, ctx.L1_in, ctx.L2_in, ctx.W, ctx.rows, ctx.cols, ctx.workspace_buffer, ctx.sender_perm = inputs @@ -180,15 +201,18 @@ def setup_context_double_backward(ctx, inputs, output): def double_backward(ctx, E, F, G): jit, A, B, C, D, rows, cols, wspace, sender_perm = ctx.jit, ctx.L1_in, ctx.L2_in, ctx.grad_output, ctx.W, ctx.rows, ctx.cols, ctx.workspace_buffer, ctx.sender_perm - op1 = backward_op(jit, E, F, D, C, rows, cols, wspace, sender_perm) - op2 = backward_op(jit, A, B, G, C, rows, cols, wspace, sender_perm) - op3 = forward_op(jit, E, B, D, rows, cols, wspace, sender_perm) - op4 = backward_op(jit, E, B, D, C, rows, cols, wspace, sender_perm) # op4 and op5 could be combined with op3 and op6 - op5 = backward_op(jit, A, F, D, C, rows, cols, wspace, sender_perm) - op6 = forward_op(jit, A, F, D, rows, cols, wspace, sender_perm) - op7 = forward_op(jit, A, B, G, rows, 
cols, wspace, sender_perm) + #op1 = backward_op(jit, E, F, D, C, rows, cols, wspace, sender_perm) + #op2 = backward_op(jit, A, B, G, C, rows, cols, wspace, sender_perm) + #op3 = forward_op(jit, E, B, D, rows, cols, wspace, sender_perm) + #op4 = backward_op(jit, E, B, D, C, rows, cols, wspace, sender_perm) # op4 and op5 could be combined with op3 and op6 + #op5 = backward_op(jit, A, F, D, C, rows, cols, wspace, sender_perm) + #op6 = forward_op(jit, A, F, D, rows, cols, wspace, sender_perm) + #op7 = forward_op(jit, A, B, G, rows, cols, wspace, sender_perm) + + #return None, op1[0] + op2[0], op1[1] + op2[1], op4[2] + op5[2], (op3 + op6 + op7), None, None, None, None - return None, op1[0] + op2[0], op1[1] + op2[1], op4[2] + op5[2], (op3 + op6 + op7), None, None, None, None + result = double_backward_op(ctx.jit, ctx.L1_in, ctx.L2_in, ctx.W, ctx.grad_output, E, F, G, rows, cols, wspace, sender_perm) + return None, result[0], result[1], result[2], result[3], None, None, None, None torch.library.register_autograd("torch_tp_jit::jit_conv_backward", double_backward, setup_context=setup_context_double_backward) diff --git a/openequivariance/templates/loop_unroll_conv_atomic.cuh b/openequivariance/templates/loop_unroll_conv_atomic.cuh index 34541b28..6887142b 100644 --- a/openequivariance/templates/loop_unroll_conv_atomic.cuh +++ b/openequivariance/templates/loop_unroll_conv_atomic.cuh @@ -155,4 +155,25 @@ __global__ void backward( {%- endif %} } {%- endfor %} } -} \ No newline at end of file +} + + +__global__ void double_backward_A( + IRREP_T* L1_in, IRREP_T* L2_in, WEIGHT_T* W, IRREP_T* L3_grad, + IRREP_T* L1_dgrad, IRREP_T* L2_dgrad, IRREP_T* W_dgrad, + IRREP_T* L1_grad, IRREP_T* L2_grad, WEIGHT_T* W_grad, IRREP_T* L3_dgrad + ConvData c, void* workspace, unsigned {{idx_type}}* transpose_perm) { + + printf("Hello world!"); +} + + +__global__ void double_backward_B( + IRREP_T* L1_in, IRREP_T* L2_in, WEIGHT_T* W, IRREP_T* L3_grad, + IRREP_T* L1_dgrad, IRREP_T* L2_dgrad, IRREP_T* W_dgrad, + IRREP_T* L1_grad, IRREP_T* L2_grad, WEIGHT_T* W_grad, IRREP_T* L3_dgrad + ConvData c, void* workspace, unsigned {{idx_type}}* transpose_perm) { + + printf("Hello world!"); +} + diff --git a/openequivariance/templates/loop_unroll_conv_det.cuh b/openequivariance/templates/loop_unroll_conv_det.cuh index 507c7435..12e2c6ce 100644 --- a/openequivariance/templates/loop_unroll_conv_det.cuh +++ b/openequivariance/templates/loop_unroll_conv_det.cuh @@ -237,4 +237,24 @@ __global__ void backward( {%- endif %} } } {%- endfor %} +} + + +__global__ void double_backward_A( + IRREP_T* L1_in, IRREP_T* L2_in, WEIGHT_T* W, IRREP_T* L3_grad, + IRREP_T* L1_dgrad, IRREP_T* L2_dgrad, IRREP_T* W_dgrad, + IRREP_T* L1_grad, IRREP_T* L2_grad, WEIGHT_T* W_grad, IRREP_T* L3_dgrad + ConvData c, void* workspace, unsigned {{idx_type}}* transpose_perm) { + + printf("Hello world!"); +} + + +__global__ void double_backward_B( + IRREP_T* L1_in, IRREP_T* L2_in, WEIGHT_T* W, IRREP_T* L3_grad, + IRREP_T* L1_dgrad, IRREP_T* L2_dgrad, IRREP_T* W_dgrad, + IRREP_T* L1_grad, IRREP_T* L2_grad, WEIGHT_T* W_grad, IRREP_T* L3_dgrad + ConvData c, void* workspace, unsigned {{idx_type}}* transpose_perm) { + + printf("Hello world!"); } \ No newline at end of file From e2ae747dc9efded09ac6c41baaa9025e7acf10de Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Tue, 29 Apr 2025 16:51:18 -0700 Subject: [PATCH 03/12] Custom double backward pass is compiling. 
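The fused op is exposed to PyTorch as torch_tp_jit::jit_conv_double_backward and
returns (L1_grad, L2_grad, W_grad, L3_dgrad). A minimal invocation sketch, with
the surrounding tensor variables as placeholders that mirror the registered
schema rather than code from this patch:

    L1_grad, L2_grad, W_grad, L3_dgrad = torch.ops.torch_tp_jit.jit_conv_double_backward(
        jit, L1_in, L2_in, W, L3_grad,
        L1_dgrad, L2_dgrad, W_dgrad,
        rows, cols, workspace, transpose_perm)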
--- openequivariance/extension/convolution.hpp | 12 ++++++-- openequivariance/extension/generic_module.cpp | 1 + openequivariance/extension/torch_tp_jit.cpp | 28 +++++++++++-------- .../templates/loop_unroll_conv_atomic.cuh | 7 ++--- .../templates/loop_unroll_conv_det.cuh | 4 +-- tests/conv_test.py | 11 +++++--- 6 files changed, 39 insertions(+), 24 deletions(-) diff --git a/openequivariance/extension/convolution.hpp b/openequivariance/extension/convolution.hpp index d184d052..a41b786e 100644 --- a/openequivariance/extension/convolution.hpp +++ b/openequivariance/extension/convolution.hpp @@ -140,6 +140,7 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI std::string jit_kernel, std::unordered_map fwd_dict, std::unordered_map bwd_dict, + std::unordered_map dbl_bwd_dict, std::unordered_map kernel_dims ) : JITConvImpl( jit_kernel, @@ -153,6 +154,11 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI bwd_dict["num_threads"], bwd_dict["smem"] ), + KernelLaunchConfig( + dbl_bwd_dict["num_blocks"], + dbl_bwd_dict["num_threads"], + dbl_bwd_dict["smem"] + ), kernel_dims["is_uvw"] == 1) { } void exec_conv( @@ -213,7 +219,7 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI void* L1_in, void* L2_in, void* W, void* L3_grad, void* L1_dgrad, void* L2_dgrad, void* w_dgrad, void* L1_grad, void* L2_grad, void* W_grad, void* L3_dgrad, - void* rows_contig, void* cols_contig, + void* rows, void* cols, uint64_t nnz, uint64_t node_count, void* wspace, void* transpose_perm) { @@ -224,11 +230,11 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI }; jit.execute(4, args, forward_config); - // Execute forward fixup kernel here + // TODO: Execute forward fixup kernel here jit.execute(5, args, double_backward_config); - // Execute backward fixup kernel here + // TODO: Execute backward fixup kernel here } ~JITConvImpl() = default; diff --git a/openequivariance/extension/generic_module.cpp b/openequivariance/extension/generic_module.cpp index 9b5b3ae6..4231f02d 100644 --- a/openequivariance/extension/generic_module.cpp +++ b/openequivariance/extension/generic_module.cpp @@ -49,6 +49,7 @@ PYBIND11_MODULE(generic_module, m) { .def("backward_rawptrs", &ConvolutionImpl::backward_rawptrs); py::class_, ConvolutionImpl>(m, "JITConvImpl") .def(py::init< std::string, + std::unordered_map, std::unordered_map, std::unordered_map, std::unordered_map>()); diff --git a/openequivariance/extension/torch_tp_jit.cpp b/openequivariance/extension/torch_tp_jit.cpp index 9106205c..3754ad4e 100644 --- a/openequivariance/extension/torch_tp_jit.cpp +++ b/openequivariance/extension/torch_tp_jit.cpp @@ -215,25 +215,32 @@ tuple jit_tp_double_ class TorchJITConv : public torch::CustomClassHolder { public: - Map_t fwd_dict, bwd_dict, kernel_dims; + Map_t fwd_dict, bwd_dict, dbl_bwd_dict, kernel_dims; JITConvImpl internal; int64_t L3_dim; - TorchJITConv(string kernel_plaintext, Map_t fwd_dict_i, Map_t bwd_dict_i, Map_t kernel_dims_i) : + TorchJITConv(string kernel_plaintext, Map_t fwd_dict_i, Map_t bwd_dict_i, Map_t dbl_bwd_dict_i, Map_t kernel_dims_i) : fwd_dict(fwd_dict_i.copy()), bwd_dict(bwd_dict_i.copy()), + dbl_bwd_dict(bwd_dict_i.copy()), kernel_dims(kernel_dims_i.copy()), internal(kernel_plaintext, to_map(fwd_dict_i), to_map(bwd_dict_i), + to_map(dbl_bwd_dict_i), to_map(kernel_dims_i) ), L3_dim(kernel_dims.at("L3_dim")) { } - tuple, tuple, tuple, tuple> __obj_flatten__() { + tuple, + tuple, + tuple, + tuple, + 
tuple> __obj_flatten__() { return tuple(tuple("kernel_plaintext", internal.jit.kernel_plaintext), tuple("fwd_config", fwd_dict), tuple("bwd_config", bwd_dict), + tuple("dbl_bwd_config", dbl_bwd_dict), tuple("kernel_dims", kernel_dims)); } @@ -347,8 +354,7 @@ tuple jit_conv_backward( return tuple(L1_grad, L2_grad, W_grad); } - -tuple jit_tp_double_backward( +tuple jit_conv_double_backward( const c10::intrusive_ptr &jit_instance, const torch::Tensor &L1_in, const torch::Tensor &L2_in, @@ -426,7 +432,7 @@ TORCH_LIBRARY_FRAGMENT(torch_tp_jit, m) { m.class_("TorchJITConv") - .def(torch::init()) + .def(torch::init()) .def("__obj_flatten__", &TorchJITConv::__obj_flatten__) .def("exec_conv_rawptrs", &TorchJITConv::exec_conv_rawptrs) .def("backward_rawptrs", &TorchJITConv::backward_rawptrs) @@ -436,18 +442,18 @@ TORCH_LIBRARY_FRAGMENT(torch_tp_jit, m) { .def_pickle( // __getstate__ [](const c10::intrusive_ptr& self) - -> tuple { - return tuple(self->internal.jit.kernel_plaintext, self->fwd_dict, self->bwd_dict, self->kernel_dims); + -> tuple { + return tuple(self->internal.jit.kernel_plaintext, self->fwd_dict, self->bwd_dict, self->dbl_bwd_dict, self->kernel_dims); }, // __setstate__ - [](tuple state) + [](tuple state) -> c10::intrusive_ptr { - return c10::make_intrusive(get<0>(state), get<1>(state), get<2>(state), get<3>(state)); + return c10::make_intrusive(get<0>(state), get<1>(state), get<2>(state), get<3>(state), get<4>(state)); }); m.def("jit_conv_forward(__torch__.torch.classes.torch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> Tensor"); m.def("jit_conv_backward(__torch__.torch.classes.torch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor)"); - m.def("jit_conv_double_backward(__torch__.torch.classes.torch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"); + m.def("jit_conv_double_backward(__torch__.torch.classes.torch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor)"); }; diff --git a/openequivariance/templates/loop_unroll_conv_atomic.cuh b/openequivariance/templates/loop_unroll_conv_atomic.cuh index 6887142b..f9256464 100644 --- a/openequivariance/templates/loop_unroll_conv_atomic.cuh +++ b/openequivariance/templates/loop_unroll_conv_atomic.cuh @@ -161,7 +161,7 @@ __global__ void backward( __global__ void double_backward_A( IRREP_T* L1_in, IRREP_T* L2_in, WEIGHT_T* W, IRREP_T* L3_grad, IRREP_T* L1_dgrad, IRREP_T* L2_dgrad, IRREP_T* W_dgrad, - IRREP_T* L1_grad, IRREP_T* L2_grad, WEIGHT_T* W_grad, IRREP_T* L3_dgrad + IRREP_T* L1_grad, IRREP_T* L2_grad, WEIGHT_T* W_grad, IRREP_T* L3_dgrad, ConvData c, void* workspace, unsigned {{idx_type}}* transpose_perm) { printf("Hello world!"); @@ -171,9 +171,8 @@ __global__ void double_backward_A( __global__ void double_backward_B( IRREP_T* L1_in, IRREP_T* L2_in, WEIGHT_T* W, IRREP_T* L3_grad, IRREP_T* L1_dgrad, IRREP_T* L2_dgrad, IRREP_T* W_dgrad, - IRREP_T* L1_grad, IRREP_T* L2_grad, WEIGHT_T* W_grad, IRREP_T* L3_dgrad + IRREP_T* L1_grad, IRREP_T* L2_grad, WEIGHT_T* W_grad, 
IRREP_T* L3_dgrad, ConvData c, void* workspace, unsigned {{idx_type}}* transpose_perm) { printf("Hello world!"); -} - +} \ No newline at end of file diff --git a/openequivariance/templates/loop_unroll_conv_det.cuh b/openequivariance/templates/loop_unroll_conv_det.cuh index 12e2c6ce..4215764c 100644 --- a/openequivariance/templates/loop_unroll_conv_det.cuh +++ b/openequivariance/templates/loop_unroll_conv_det.cuh @@ -243,7 +243,7 @@ __global__ void backward( __global__ void double_backward_A( IRREP_T* L1_in, IRREP_T* L2_in, WEIGHT_T* W, IRREP_T* L3_grad, IRREP_T* L1_dgrad, IRREP_T* L2_dgrad, IRREP_T* W_dgrad, - IRREP_T* L1_grad, IRREP_T* L2_grad, WEIGHT_T* W_grad, IRREP_T* L3_dgrad + IRREP_T* L1_grad, IRREP_T* L2_grad, WEIGHT_T* W_grad, IRREP_T* L3_dgrad, ConvData c, void* workspace, unsigned {{idx_type}}* transpose_perm) { printf("Hello world!"); @@ -253,7 +253,7 @@ __global__ void double_backward_A( __global__ void double_backward_B( IRREP_T* L1_in, IRREP_T* L2_in, WEIGHT_T* W, IRREP_T* L3_grad, IRREP_T* L1_dgrad, IRREP_T* L2_dgrad, IRREP_T* W_dgrad, - IRREP_T* L1_grad, IRREP_T* L2_grad, WEIGHT_T* W_grad, IRREP_T* L3_dgrad + IRREP_T* L1_grad, IRREP_T* L2_grad, WEIGHT_T* W_grad, IRREP_T* L3_dgrad, ConvData c, void* workspace, unsigned {{idx_type}}* transpose_perm) { printf("Hello world!"); diff --git a/tests/conv_test.py b/tests/conv_test.py index 782021ee..26249174 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -81,12 +81,15 @@ def problem(self, request, dtype): class TestUVWSingleIrrep(ConvCorrectness): muls = [ - (1, 1, 1), (4, 1, 4), (8, 1, 8), (16, 1, 16), - (32, 1, 32), (5, 1, 5), (13, 1, 13), (33, 1, 33), (49, 1, 49), (64, 1, 64), - (1, 2, 1), (1, 4, 1), (1, 16, 1), (1, 32, 1), (16, 3, 16) + (1, 1, 1), #(4, 1, 4), (8, 1, 8), (16, 1, 16), + #(32, 1, 32), (5, 1, 5), (13, 1, 13), (33, 1, 33), (49, 1, 49), (64, 1, 64), + #(1, 2, 1), (1, 4, 1), (1, 16, 1), (1, 32, 1), (16, 3, 16) ] - irs = [ (0, 0, 0), (1, 1, 1), (1, 0, 1), (1, 2, 1), (5, 3, 5), (7, 2, 5) ] + irs = [ #(0, 0, 0), (1, 1, 1), (1, 0, 1), (1, 2, 1), + (5, 3, 5), + #(7, 2, 5) + ] def id_func(m, i): return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" From 89c8c713fadc0250b8cccdc908eb4a2afe48594c Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 30 Apr 2025 11:35:38 -0700 Subject: [PATCH 04/12] Wrote the A kernel. 
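The A kernel produces the L3_dgrad output of the double backward pass. Per
nonzero edge (row, col) it keeps the L1/L2/weight tiles in shared memory and
runs the existing forward_loop_unroll body three times, swapping one operand per
pass. A sketch of the intended accumulation, mirroring op3 + op6 + op7 of the
reference composition in PATCH 02:

    L3_dgrad[row] += TP(L1_in[col],    L2_in[i],    W_dgrad[i])   // n = 0
                   + TP(L1_in[col],    L2_dgrad[i], W[i])         // n = 1
                   + TP(L1_dgrad[col], L2_in[i],    W[i])         // n = 2

This issues one accumulation into global L3_dgrad per edge instead of three
separate forward launches; the per-pass weight buffer selection is still being
adjusted in this commit.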
--- openequivariance/extension/torch_tp_jit.cpp | 1 - .../templates/loop_unroll_conv_atomic.cuh | 66 +++++++++++++++++-- tests/conv_test.py | 22 +++++-- 3 files changed, 77 insertions(+), 12 deletions(-) diff --git a/openequivariance/extension/torch_tp_jit.cpp b/openequivariance/extension/torch_tp_jit.cpp index 3754ad4e..6711f3f1 100644 --- a/openequivariance/extension/torch_tp_jit.cpp +++ b/openequivariance/extension/torch_tp_jit.cpp @@ -456,7 +456,6 @@ TORCH_LIBRARY_FRAGMENT(torch_tp_jit, m) { m.def("jit_conv_double_backward(__torch__.torch.classes.torch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor)"); }; - TORCH_LIBRARY_IMPL(torch_tp_jit, CUDA, m) { m.impl("jit_tp_forward", &jit_tp_forward); m.impl("jit_tp_backward", &jit_tp_backward); diff --git a/openequivariance/templates/loop_unroll_conv_atomic.cuh b/openequivariance/templates/loop_unroll_conv_atomic.cuh index f9256464..caf3219a 100644 --- a/openequivariance/templates/loop_unroll_conv_atomic.cuh +++ b/openequivariance/templates/loop_unroll_conv_atomic.cuh @@ -157,16 +157,74 @@ __global__ void backward( } } - __global__ void double_backward_A( IRREP_T* L1_in, IRREP_T* L2_in, WEIGHT_T* W, IRREP_T* L3_grad, IRREP_T* L1_dgrad, IRREP_T* L2_dgrad, IRREP_T* W_dgrad, IRREP_T* L1_grad, IRREP_T* L2_grad, WEIGHT_T* W_grad, IRREP_T* L3_dgrad, ConvData c, void* workspace, unsigned {{idx_type}}* transpose_perm) { - printf("Hello world!"); -} + extern __shared__ char s[]; + size_t num_products = c.nnz; + unsigned {{idx_type}}* rows = (unsigned {{idx_type}}*) c.rows; + unsigned {{idx_type}}* cols = (unsigned {{idx_type}}*) c.cols; + + {{ set_launch_bound_variables(forward_schedule.launch_config) }} + {%- set tpp = forward_schedule.updated_config %} + char* smem = s + {{forward_schedule.memory_per_warp}} * warp_loc; + + {%- for i, segment in enumerate(forward_schedule.segments) %} { + {{ declare_smem_variables(segment, "smem") }} + WEIGHT_T* w_buffer; + + for(size_t i = start; i < end; i++) { + unsigned {{idx_type}} row = rows[i]; unsigned {{idx_type}} col = cols[i]; + + IRREP_T* l1 = L1_in + col * {{forward_schedule.L1.dim}} + lane_id; + IRREP_T* l2 = L2_in + i * {{forward_schedule.L2.dim}} + lane_id; + IRREP_T* l3 = L3_dgrad + row * {{forward_schedule.L3.dim}} + lane_id; + + IRREP_T* l1_dgrad = L1_dgrad + col * {{forward_schedule.L1.dim}} + lane_id; + IRREP_T* l2_dgrad = L2_dgrad + i * {{forward_schedule.L2.dim}} + lane_id; + + {%- if not tpp.shared_weights %} + WEIGHT_T* w = W + i * {{tpp.weight_numel}}; + WEIGHT_T* w_dgrad = W_dgrad + i * {{tpp.weight_numel}}; + {%- else %} + WEIGHT_T* w = W; + WEIGHT_T* w_dgrad = W_dgrad; + {%- endif %} + + __syncwarp(); + {{ load_ir_segments(segment.L1Map, "l1", "L1_smem", "j") }} + {{ load_ir_segments(segment.L2Map, "l2", "L2_smem", "j") }} + ROW_OPERATION({{segment.L3.dim}}, j, L3_smem[j + lane_id] = 0.0f;) + {%- if not forward_schedule.stream_weights %} + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_smem[j + lane_id] = w[{{segment.weight_offset}} + j + lane_id];) + {%- endif %} + + for(int n = 0; n < 3; n++) { + if(n == 1) { + {% if not forward_schedule.stream_weights%} + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_smem[j + lane_id] = w[{{segment.weight_offset}} + j + lane_id];) + {% endif %} + {{ load_ir_segments(segment.L2Map, "l2_dgrad", "L2_smem", "j") }} + w_buffer = w; + } + else if(n == 2) 
{ + {{ load_ir_segments(segment.L2Map, "l2", "L2_smem", "j") }} + {{ load_ir_segments(segment.L1Map, "l1_dgrad", "L1_smem", "j") }} + } + + __syncwarp(); + forward_loop_unroll_{{i}}(L1_smem, L2_smem, w, weights_smem, L3_smem, scratch_smem, lane_id); + __syncwarp(); + } + + {{ store_ir_segments(segment.L3Map, "l3", "L3_smem", "j") }} + } + } {%- endfor %} +} __global__ void double_backward_B( IRREP_T* L1_in, IRREP_T* L2_in, WEIGHT_T* W, IRREP_T* L3_grad, @@ -174,5 +232,5 @@ __global__ void double_backward_B( IRREP_T* L1_grad, IRREP_T* L2_grad, WEIGHT_T* W_grad, IRREP_T* L3_dgrad, ConvData c, void* workspace, unsigned {{idx_type}}* transpose_perm) { - printf("Hello world!"); + //printf("Hello world!"); } \ No newline at end of file diff --git a/tests/conv_test.py b/tests/conv_test.py index 26249174..c81be5f6 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -13,7 +13,11 @@ def check_result(self, result, fieldname): thresh = result["thresh"] assert result[fieldname]["pass"], f"{fieldname} observed error={error:.2f} >= {thresh}" - @pytest.fixture(params=[np.float32, np.float64], ids=['F32', 'F64'], scope='class') + @pytest.fixture(params=[np.float32, + #np.float64 + ], ids=['F32', + #'F64' + ], scope='class') def dtype(self, request): return request.param @@ -22,21 +26,24 @@ def graph(self, request): download_prefix = "https://portal.nersc.gov/project/m1982/equivariant_nn_graphs/" filename = request.param - graph = None - with tempfile.NamedTemporaryFile() as temp_file: - urllib.request.urlretrieve(download_prefix + filename, temp_file.name) - graph = load_graph(temp_file.name) + #graph = None + #with tempfile.NamedTemporaryFile() as temp_file: + # urllib.request.urlretrieve(download_prefix + filename, temp_file.name) + # graph = load_graph(temp_file.name) - #graph = load_graph("data/1drf_radius3.5.pickle") + graph = load_graph("data/1drf_radius3.5.pickle") return graph - @pytest.fixture(params=['atomic', 'deterministic'], scope='class') + @pytest.fixture(params=['atomic', + #'deterministic' + ], scope='class') def conv_object(self, request, problem): if request.param == 'atomic': return oeq.TensorProductConv(problem, deterministic=False) elif request.param == 'deterministic': return oeq.TensorProductConv(problem, deterministic=True) + @pytest.mark.skip def test_tp_fwd(self, conv_object, graph): result = conv_object.test_correctness_forward(graph, thresh=3e-05, @@ -45,6 +52,7 @@ def test_tp_fwd(self, conv_object, graph): self.check_result(result, "output") + @pytest.mark.skip def test_tp_bwd(self, conv_object, graph): result = conv_object.test_correctness_backward(graph, thresh=3e-04, From 2e8ccbe0ac96ba2c7f95e4a13ab96776ca09cdf5 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 30 Apr 2025 15:44:37 -0700 Subject: [PATCH 05/12] Atomic double backward pass accelerated. 
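The B kernel covers the remaining outputs (L1_grad, L2_grad, W_grad). Each edge
is processed in two passes over the double_backward_loop_unroll body: pass n = 0
consumes (L1_dgrad = E, L2_dgrad = F, W, L3_grad = C) with L2_in = B staged in
the spare buffer, and pass n = 1 swaps to (L1_in = A, L2_in = B, W_dgrad = G,
L3_grad = C) with F staged instead. The gradients accumulated in shared memory
thus amount to (sketch, matching op1, op2, op4, op5 of the reference composition
in PATCH 02, with W = D):

    L1_grad = bwd(E, F, W, C)[0] + bwd(A, B, G, C)[0]
    L2_grad = bwd(E, F, W, C)[1] + bwd(A, B, G, C)[1]
    W_grad  = bwd(E, B, W, C)[2] + bwd(A, F, W, C)[2]

The spare L2 buffer passed to each unroll call supplies the mixed operands
needed for the W_grad cross terms. In this atomic variant L1_grad is accumulated
into global memory atomically, so no fixup kernel is needed.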
--- .../convolution/LoopUnrollConv.py | 5 +- .../templates/loop_unroll_conv_atomic.cuh | 105 +++++++++++++++++- 2 files changed, 105 insertions(+), 5 deletions(-) diff --git a/openequivariance/implementations/convolution/LoopUnrollConv.py b/openequivariance/implementations/convolution/LoopUnrollConv.py index 5cdb58d7..1790b734 100644 --- a/openequivariance/implementations/convolution/LoopUnrollConv.py +++ b/openequivariance/implementations/convolution/LoopUnrollConv.py @@ -199,8 +199,9 @@ def setup_context_double_backward(ctx, inputs, output): ctx.inputs = inputs def double_backward(ctx, E, F, G): - jit, A, B, C, D, rows, cols, wspace, sender_perm = ctx.jit, ctx.L1_in, ctx.L2_in, ctx.grad_output, ctx.W, ctx.rows, ctx.cols, ctx.workspace_buffer, ctx.sender_perm + result = double_backward_op(ctx.jit, ctx.L1_in, ctx.L2_in, ctx.W, ctx.grad_output, E, F, G, ctx.rows, ctx.cols, ctx.workspace_buffer, ctx.sender_perm) + #jit, A, B, C, D, rows, cols, wspace, sender_perm = ctx.jit, ctx.L1_in, ctx.L2_in, ctx.grad_output, ctx.W, ctx.rows, ctx.cols, ctx.workspace_buffer, ctx.sender_perm #op1 = backward_op(jit, E, F, D, C, rows, cols, wspace, sender_perm) #op2 = backward_op(jit, A, B, G, C, rows, cols, wspace, sender_perm) #op3 = forward_op(jit, E, B, D, rows, cols, wspace, sender_perm) @@ -210,8 +211,8 @@ def double_backward(ctx, E, F, G): #op7 = forward_op(jit, A, B, G, rows, cols, wspace, sender_perm) #return None, op1[0] + op2[0], op1[1] + op2[1], op4[2] + op5[2], (op3 + op6 + op7), None, None, None, None + #print(torch.norm(op7 - result[3])) - result = double_backward_op(ctx.jit, ctx.L1_in, ctx.L2_in, ctx.W, ctx.grad_output, E, F, G, rows, cols, wspace, sender_perm) return None, result[0], result[1], result[2], result[3], None, None, None, None torch.library.register_autograd("torch_tp_jit::jit_conv_backward", double_backward, setup_context=setup_context_double_backward) diff --git a/openequivariance/templates/loop_unroll_conv_atomic.cuh b/openequivariance/templates/loop_unroll_conv_atomic.cuh index caf3219a..ef6502c4 100644 --- a/openequivariance/templates/loop_unroll_conv_atomic.cuh +++ b/openequivariance/templates/loop_unroll_conv_atomic.cuh @@ -198,11 +198,13 @@ __global__ void double_backward_A( {{ load_ir_segments(segment.L1Map, "l1", "L1_smem", "j") }} {{ load_ir_segments(segment.L2Map, "l2", "L2_smem", "j") }} ROW_OPERATION({{segment.L3.dim}}, j, L3_smem[j + lane_id] = 0.0f;) - + {%- if not forward_schedule.stream_weights %} ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_smem[j + lane_id] = w[{{segment.weight_offset}} + j + lane_id];) {%- endif %} + w_buffer = w_dgrad; + for(int n = 0; n < 3; n++) { if(n == 1) { {% if not forward_schedule.stream_weights%} @@ -217,7 +219,7 @@ __global__ void double_backward_A( } __syncwarp(); - forward_loop_unroll_{{i}}(L1_smem, L2_smem, w, weights_smem, L3_smem, scratch_smem, lane_id); + forward_loop_unroll_{{i}}(L1_smem, L2_smem, w_buffer, weights_smem, L3_smem, scratch_smem, lane_id); __syncwarp(); } @@ -226,11 +228,108 @@ __global__ void double_backward_A( } {%- endfor %} } +{%- for i, segment in enumerate(double_backward_schedule.segments) %} +{{ generate_segment_kernel_backward(i, segment, double_backward_schedule.launch_config.warp_size, double_bwd=True) }} +{%- endfor %} + +{% set schedule = double_backward_schedule %} + __global__ void double_backward_B( IRREP_T* L1_in, IRREP_T* L2_in, WEIGHT_T* W, IRREP_T* L3_grad, IRREP_T* L1_dgrad, IRREP_T* L2_dgrad, IRREP_T* W_dgrad, IRREP_T* L1_grad, IRREP_T* L2_grad, WEIGHT_T* W_grad, IRREP_T* 
L3_dgrad, ConvData c, void* workspace, unsigned {{idx_type}}* transpose_perm) { - //printf("Hello world!"); + size_t num_products = c.nnz; + unsigned {{idx_type}}* rows = (unsigned {{idx_type}}*) c.rows; + unsigned {{idx_type}}* cols = (unsigned {{idx_type}}*) c.cols; + + extern __shared__ char s[]; + {{ set_launch_bound_variables(schedule.launch_config) }} + char* smem = s + {{schedule.memory_per_warp}} * warp_loc; + + {%- set tpp = schedule.updated_config %} + + {%- for i, segment in enumerate(schedule.segments) %} { + {{ declare_smem_variables(segment, "smem") }} + for(size_t i = start; i < end; i++) { + unsigned {{idx_type}} row = rows[i]; unsigned {{idx_type}} col = cols[i]; + + IRREP_T* l1_shft = L1_dgrad + col * {{schedule.L1.dim}} + lane_id; + IRREP_T* l2_shft = L2_dgrad + i * {{schedule.L2.dim}} + lane_id; + IRREP_T* l3_shft = L3_grad + row * {{schedule.L3.dim}} + lane_id; + + IRREP_T* l1_original = L1_in + col * {{schedule.L1.dim}} + lane_id; + IRREP_T* l2_original = L2_in + i * {{schedule.L2.dim}} + lane_id; + + {%- if not tpp.shared_weights %} + WEIGHT_T* w = W + i * {{tpp.weight_numel}}; + WEIGHT_T* wgrad = W_grad + i * {{tpp.weight_numel}}; + WEIGHT_T* wdgrad = W_dgrad + i * {{tpp.weight_numel}}; + {%- else %} + WEIGHT_T* w = W; + WEIGHT_T* wgrad = W_grad; + WEIGHT_T* wdgrad = W_dgrad; + {%- endif %} + WEIGHT_T* weights_shft = w + lane_id; + WEIGHT_T* weights_dgrad_shft = wdgrad + lane_id; + + {{ load_ir_segments(segment.L3Map, "l3_shft", "L3_grad_smem", "j") }} + {{ load_ir_segments(segment.L1Map, "l1_shft", "L1_smem", "j") }} + {{ load_ir_segments(segment.L2Map, "l2_shft", "L2_smem", "j") }} + {{ load_ir_segments(segment.L2Map, "l2_original", "L2_dgrad_smem", "j") }} + + __syncwarp(); + {%- if not segment.L1Map.persist_load %} + ROW_OPERATION({{segment.L1.dim}}, j, L1_grad_smem[j + lane_id] = 0.0f;) + {%- endif %} + + {%- if not segment.L2Map.persist_load %} + ROW_OPERATION({{segment.L2.dim}}, j, L2_grad_smem[j + lane_id] = 0.0f;) + {%- endif %} + + {% if not schedule.stream_weights%} + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_smem[j + lane_id] = weights_shft[{{segment.weight_offset}} + j];) + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_smem[j + lane_id] = 0.0;) + {% endif %} + + WEIGHT_T* w_buffer = w; + IRREP_T* L2_buffer = L2_smem; + IRREP_T* L2_dgrad_buffer = L2_dgrad_smem; + + for(int n = 0; n < 2; n++) { + if(n == 1) { + {{ load_ir_segments(segment.L1Map, "l1_original", "L1_smem", "j") }} + + {% if not schedule.stream_weights%} + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_smem[j + lane_id] = weights_dgrad_shft[{{segment.weight_offset}} + j];) + {% endif %} + w_buffer = wdgrad; + L2_buffer = L2_dgrad_smem; + L2_dgrad_buffer = L2_smem; + } + + __syncwarp(); + double_backward_loop_unroll_{{i}}(L1_smem, L2_buffer, w_buffer, weights_smem, L3_grad_smem, + L1_grad_smem, L2_grad_smem, L2_dgrad_buffer, n, wgrad, weights_grad_smem, scratch_smem, lane_id); + __syncwarp(); + } + + IRREP_T* l1_grad_shft = L1_grad + col * {{schedule.L1.dim}} + lane_id; + IRREP_T* l2_grad_shft = L2_grad + i * {{schedule.L2.dim}} + lane_id; + + {%- if not tpp.shared_weights %} + WEIGHT_T* weights_grad_shft = W_grad + i * {{schedule.updated_config.weight_numel}} + lane_id; + {%- else %} + WEIGHT_T* weights_grad_shft = W_grad + lane_id; + {%- endif %} + + {{ store_ir_segments(segment.L1Map, "l1_grad_shft", "L1_grad_smem", "j") }} + {{ store_ir_segments(segment.L2Map, "l2_grad_shft", "L2_grad_smem", "j") }} + + {% if not schedule.stream_weights%} + 
ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];) + {% endif %} + } + } {%- endfor %} } \ No newline at end of file From 3ad926e5f9afe42cc8b185895e5eb45846882291 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 30 Apr 2025 18:11:02 -0700 Subject: [PATCH 06/12] More convolution double backward tests passing. --- tests/conv_test.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/conv_test.py b/tests/conv_test.py index c81be5f6..20eb6eff 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -89,15 +89,12 @@ def problem(self, request, dtype): class TestUVWSingleIrrep(ConvCorrectness): muls = [ - (1, 1, 1), #(4, 1, 4), (8, 1, 8), (16, 1, 16), - #(32, 1, 32), (5, 1, 5), (13, 1, 13), (33, 1, 33), (49, 1, 49), (64, 1, 64), - #(1, 2, 1), (1, 4, 1), (1, 16, 1), (1, 32, 1), (16, 3, 16) + (1, 1, 1), (4, 1, 4), (8, 1, 8), (16, 1, 16), + (32, 1, 32), (5, 1, 5), (13, 1, 13), (33, 1, 33), (49, 1, 49), (64, 1, 64), + (1, 2, 1), (1, 4, 1), (1, 16, 1), (1, 32, 1), (16, 3, 16) ] - irs = [ #(0, 0, 0), (1, 1, 1), (1, 0, 1), (1, 2, 1), - (5, 3, 5), - #(7, 2, 5) - ] + irs = [(0, 0, 0), (1, 1, 1), (1, 0, 1), (1, 2, 1), (5, 3, 5), (7, 2, 5) ] def id_func(m, i): return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" From 382251873a27914891d8e86a7ac207067d3129f1 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 30 Apr 2025 22:16:52 -0700 Subject: [PATCH 07/12] Deterministic double backward A kernel written. --- .../templates/loop_unroll_conv_det.cuh | 92 ++++++++++++++++++- tests/conv_test.py | 17 ++-- 2 files changed, 101 insertions(+), 8 deletions(-) diff --git a/openequivariance/templates/loop_unroll_conv_det.cuh b/openequivariance/templates/loop_unroll_conv_det.cuh index 4215764c..8bd7fb39 100644 --- a/openequivariance/templates/loop_unroll_conv_det.cuh +++ b/openequivariance/templates/loop_unroll_conv_det.cuh @@ -246,7 +246,95 @@ __global__ void double_backward_A( IRREP_T* L1_grad, IRREP_T* L2_grad, WEIGHT_T* W_grad, IRREP_T* L3_dgrad, ConvData c, void* workspace, unsigned {{idx_type}}* transpose_perm) { - printf("Hello world!"); + extern __shared__ char s[]; + size_t num_products = c.nnz; + unsigned {{idx_type}}* rows = (unsigned {{idx_type}}*) c.rows; + unsigned {{idx_type}}* cols = (unsigned {{idx_type}}*) c.cols; + + IRREP_T* workspace = (IRREP_T*) workspace_raw; + {{idx_type}}* dst_idxs = ({{idx_type}}*) ((char*) workspace + {{forward_workspace_offset}}); + + if(lane_id == 0) { + if(start < end) { + dst_idxs[warp_id] = rows[start]; + } + else { + dst_idxs[warp_id] = -1; + } + } + + {{ set_launch_bound_variables(forward_schedule.launch_config) }} + {%- set tpp = forward_schedule.updated_config %} + char* smem = s + {{forward_schedule.memory_per_warp}} * warp_loc; + + {%- for i, segment in enumerate(forward_schedule.segments) %} { + {{ declare_smem_variables(segment, "smem") }} + WEIGHT_T* w_buffer; + + bool firstSegment = true; + ROW_OPERATION({{segment.L3.dim}}, j, L3_smem[j + lane_id] = 0.0f;) + + for(size_t i = start; i < end; i++) { + unsigned {{idx_type}} row = rows[i]; unsigned {{idx_type}} col = cols[i]; + + IRREP_T* l1 = L1_in + col * {{forward_schedule.L1.dim}} + lane_id; + IRREP_T* l2 = L2_in + i * {{forward_schedule.L2.dim}} + lane_id; + IRREP_T* l3 = L3_dgrad + row * {{forward_schedule.L3.dim}} + lane_id; + + IRREP_T* l1_dgrad = L1_dgrad + col * {{forward_schedule.L1.dim}} + lane_id; + IRREP_T* l2_dgrad = L2_dgrad + i * {{forward_schedule.L2.dim}} + 
lane_id; + + {%- if not tpp.shared_weights %} + WEIGHT_T* w = W + i * {{tpp.weight_numel}}; + WEIGHT_T* w_dgrad = W_dgrad + i * {{tpp.weight_numel}}; + {%- else %} + WEIGHT_T* w = W; + WEIGHT_T* w_dgrad = W_dgrad; + {%- endif %} + + __syncwarp(); + {{ load_ir_segments(segment.L1Map, "l1", "L1_smem", "j") }} + {{ load_ir_segments(segment.L2Map, "l2", "L2_smem", "j") }} + ROW_OPERATION({{segment.L3.dim}}, j, L3_smem[j + lane_id] = 0.0f;) + + {%- if not forward_schedule.stream_weights %} + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_smem[j + lane_id] = w[{{segment.weight_offset}} + j + lane_id];) + {%- endif %} + + w_buffer = w_dgrad; + + for(int n = 0; n < 3; n++) { + if(n == 1) { + {% if not forward_schedule.stream_weights%} + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_smem[j + lane_id] = w[{{segment.weight_offset}} + j + lane_id];) + {% endif %} + {{ load_ir_segments(segment.L2Map, "l2_dgrad", "L2_smem", "j") }} + w_buffer = w; + } + else if(n == 2) { + {{ load_ir_segments(segment.L2Map, "l2", "L2_smem", "j") }} + {{ load_ir_segments(segment.L1Map, "l1_dgrad", "L1_smem", "j") }} + } + + __syncwarp(); + forward_loop_unroll_{{i}}(L1_smem, L2_smem, w_buffer, weights_smem, L3_smem, scratch_smem, lane_id); + __syncwarp(); + } + + bool changeRow = (i < end - 1) && (row != rows[i+1]); + if(changeRow || i == end - 1) { + IRREP_T* dst = l3; + if(firstSegment) { + dst = workspace + {{forward_schedule.L3.dim}} * warp_id + lane_id; + firstSegment = false; + } + {{ store_ir_segments(segment.L3Map, "dst", "L3_smem", "j") }} + __syncwarp(); + + ROW_OPERATION({{segment.L3.dim}}, j, L3_smem[j + lane_id] = 0.0f;) + } + } + } {%- endfor %} } @@ -256,5 +344,5 @@ __global__ void double_backward_B( IRREP_T* L1_grad, IRREP_T* L2_grad, WEIGHT_T* W_grad, IRREP_T* L3_dgrad, ConvData c, void* workspace, unsigned {{idx_type}}* transpose_perm) { - printf("Hello world!"); + } \ No newline at end of file diff --git a/tests/conv_test.py b/tests/conv_test.py index 20eb6eff..a64b0296 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -34,8 +34,8 @@ def graph(self, request): graph = load_graph("data/1drf_radius3.5.pickle") return graph - @pytest.fixture(params=['atomic', - #'deterministic' + @pytest.fixture(params=[#'atomic', + 'deterministic' ], scope='class') def conv_object(self, request, problem): if request.param == 'atomic': @@ -89,12 +89,17 @@ def problem(self, request, dtype): class TestUVWSingleIrrep(ConvCorrectness): muls = [ - (1, 1, 1), (4, 1, 4), (8, 1, 8), (16, 1, 16), - (32, 1, 32), (5, 1, 5), (13, 1, 13), (33, 1, 33), (49, 1, 49), (64, 1, 64), - (1, 2, 1), (1, 4, 1), (1, 16, 1), (1, 32, 1), (16, 3, 16) + (1, 1, 1), + #(4, 1, 4), (8, 1, 8), (16, 1, 16), + #(32, 1, 32), (5, 1, 5), (13, 1, 13), (33, 1, 33), (49, 1, 49), (64, 1, 64), + #(1, 2, 1), (1, 4, 1), (1, 16, 1), (1, 32, 1), (16, 3, 16) ] - irs = [(0, 0, 0), (1, 1, 1), (1, 0, 1), (1, 2, 1), (5, 3, 5), (7, 2, 5) ] + irs = [#(0, 0, 0), (1, 1, 1), (1, 0, 1), + #(1, 2, 1), + (5, 3, 5), + #(7, 2, 5) + ] def id_func(m, i): return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" From 37474d50f6f42770f96b0f86c787979ee9adef21 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 30 Apr 2025 22:45:13 -0700 Subject: [PATCH 08/12] Double backward deterministic is compiling, but we need to check the output. 
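In the deterministic path each warp owns a contiguous, row-sorted run of edges,
so the first partial row sum a warp produces is spilled to the per-warp
workspace (with its destination row recorded in dst_idxs) instead of being added
to L3_dgrad directly. The existing fixup_forward kernel (index 2 in the kernel
list) is therefore launched after double_backward_A to fold those spilled
partials into L3_dgrad without atomics; the matching fixup for the B kernel's
L1_grad is still a TODO at this point.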
--- openequivariance/extension/convolution.hpp | 12 +++++++++--- openequivariance/templates/loop_unroll_conv_det.cuh | 6 +++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/openequivariance/extension/convolution.hpp b/openequivariance/extension/convolution.hpp index a41b786e..c68b601b 100644 --- a/openequivariance/extension/convolution.hpp +++ b/openequivariance/extension/convolution.hpp @@ -228,12 +228,18 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI &L1_in, &L2_in, &W, &L3_grad, &L1_dgrad, &L2_dgrad, &w_dgrad, &L1_grad, &L2_grad, &W_grad, &L3_dgrad, &conv_data, &wspace, &transpose_perm }; - jit.execute(4, args, forward_config); - // TODO: Execute forward fixup kernel here + jit.execute(4, args, forward_config); + if(reinterpret_cast(wspace) != 0) { + void *fixup_args[] = {&wspace, &L3_dgrad}; + KernelLaunchConfig fixup_config; + fixup_config.num_blocks = forward_config.num_blocks; + fixup_config.num_threads = forward_config.num_threads; + fixup_config.smem = 0; + jit.execute(2, fixup_args, fixup_config); + } jit.execute(5, args, double_backward_config); - // TODO: Execute backward fixup kernel here } diff --git a/openequivariance/templates/loop_unroll_conv_det.cuh b/openequivariance/templates/loop_unroll_conv_det.cuh index 8bd7fb39..d9a70617 100644 --- a/openequivariance/templates/loop_unroll_conv_det.cuh +++ b/openequivariance/templates/loop_unroll_conv_det.cuh @@ -244,13 +244,15 @@ __global__ void double_backward_A( IRREP_T* L1_in, IRREP_T* L2_in, WEIGHT_T* W, IRREP_T* L3_grad, IRREP_T* L1_dgrad, IRREP_T* L2_dgrad, IRREP_T* W_dgrad, IRREP_T* L1_grad, IRREP_T* L2_grad, WEIGHT_T* W_grad, IRREP_T* L3_dgrad, - ConvData c, void* workspace, unsigned {{idx_type}}* transpose_perm) { + ConvData c, void* workspace_raw, unsigned {{idx_type}}* transpose_perm) { extern __shared__ char s[]; size_t num_products = c.nnz; unsigned {{idx_type}}* rows = (unsigned {{idx_type}}*) c.rows; unsigned {{idx_type}}* cols = (unsigned {{idx_type}}*) c.cols; + {{ set_launch_bound_variables(forward_schedule.launch_config) }} + IRREP_T* workspace = (IRREP_T*) workspace_raw; {{idx_type}}* dst_idxs = ({{idx_type}}*) ((char*) workspace + {{forward_workspace_offset}}); @@ -263,7 +265,6 @@ __global__ void double_backward_A( } } - {{ set_launch_bound_variables(forward_schedule.launch_config) }} {%- set tpp = forward_schedule.updated_config %} char* smem = s + {{forward_schedule.memory_per_warp}} * warp_loc; @@ -337,7 +338,6 @@ __global__ void double_backward_A( } {%- endfor %} } - __global__ void double_backward_B( IRREP_T* L1_in, IRREP_T* L2_in, WEIGHT_T* W, IRREP_T* L3_grad, IRREP_T* L1_dgrad, IRREP_T* L2_dgrad, IRREP_T* W_dgrad, From 8a7bd974f1102f50b02703d6c73f59ec7d9a6f06 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 30 Apr 2025 22:48:11 -0700 Subject: [PATCH 09/12] Double backward A kernel is working. 
--- openequivariance/templates/loop_unroll_conv_det.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/openequivariance/templates/loop_unroll_conv_det.cuh b/openequivariance/templates/loop_unroll_conv_det.cuh index d9a70617..3e180aa7 100644 --- a/openequivariance/templates/loop_unroll_conv_det.cuh +++ b/openequivariance/templates/loop_unroll_conv_det.cuh @@ -296,7 +296,6 @@ __global__ void double_backward_A( __syncwarp(); {{ load_ir_segments(segment.L1Map, "l1", "L1_smem", "j") }} {{ load_ir_segments(segment.L2Map, "l2", "L2_smem", "j") }} - ROW_OPERATION({{segment.L3.dim}}, j, L3_smem[j + lane_id] = 0.0f;) {%- if not forward_schedule.stream_weights %} ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_smem[j + lane_id] = w[{{segment.weight_offset}} + j + lane_id];) From d79617efee9b33aa1236569d965f1af7a2d128a9 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Thu, 1 May 2025 12:52:00 -0700 Subject: [PATCH 10/12] Gradient of L1 has an error, need to fix. --- openequivariance/extension/convolution.hpp | 14 +- .../convolution/LoopUnrollConv.py | 9 +- .../templates/loop_unroll_conv_det.cuh | 127 +++++++++++++++++- 3 files changed, 143 insertions(+), 7 deletions(-) diff --git a/openequivariance/extension/convolution.hpp b/openequivariance/extension/convolution.hpp index c68b601b..048aaff6 100644 --- a/openequivariance/extension/convolution.hpp +++ b/openequivariance/extension/convolution.hpp @@ -112,7 +112,7 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI double_backward_config(double_backward_config_i), is_uvw(is_uvw_i) { - vector kernels = {"forward", "backward", "fixup_forward", "fixup_backward", "double_backward_A", "double_backward_B"}; + vector kernels = {"forward", "backward", "fixup_forward", "fixup_backward", "double_backward_A", "double_backward_B", "fixup_double_backwardB"}; int opt_level = 3; #ifdef HIP_BACKEND @@ -120,7 +120,7 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI opt_level = 1; } #endif - jit.compile(kernels, {{}, {}, {}, {}, {}, {}}, opt_level); + jit.compile(kernels, {{}, {}, {}, {}, {}, {}, {}}, opt_level); if(forward_config.smem > 0) { jit.set_max_smem(0, forward_config.smem); @@ -240,7 +240,15 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI } jit.execute(5, args, double_backward_config); - // TODO: Execute backward fixup kernel here + + if(reinterpret_cast(wspace) != 0) { + void *fixup_args[] = {&wspace, &L1_grad}; + KernelLaunchConfig fixup_config; + fixup_config.num_blocks = double_backward_config.num_blocks; + fixup_config.num_threads = double_backward_config.num_threads; + fixup_config.smem = 0; + jit.execute(6, fixup_args, fixup_config); + } } ~JITConvImpl() = default; diff --git a/openequivariance/implementations/convolution/LoopUnrollConv.py b/openequivariance/implementations/convolution/LoopUnrollConv.py index 1790b734..0c52f6ba 100644 --- a/openequivariance/implementations/convolution/LoopUnrollConv.py +++ b/openequivariance/implementations/convolution/LoopUnrollConv.py @@ -103,13 +103,17 @@ def generate_double_backward_schedule(warps_per_block): destination_index_bytes = 32 # Add extra to account for padding workspace_size = max( (self.forward_schedule.L3.dim * np.dtype(config.irrep_dtype).itemsize + destination_index_bytes) * self.forward_schedule.total_warps, - (self.backward_schedule.L1.dim * np.dtype(config.irrep_dtype).itemsize + destination_index_bytes) * self.backward_schedule.total_warps) + (self.backward_schedule.L1.dim * 
np.dtype(config.irrep_dtype).itemsize + destination_index_bytes) * self.backward_schedule.total_warps, + (self.double_backward_schedule.L1.dim * np.dtype(config.irrep_dtype).itemsize + destination_index_bytes) * self.double_backward_schedule.total_warps + ) self.forward_workspace_offset = self.forward_schedule.L3.dim * np.dtype(config.irrep_dtype).itemsize * self.forward_schedule.total_warps self.backward_workspace_offset = self.backward_schedule.L1.dim * np.dtype(config.irrep_dtype).itemsize * self.backward_schedule.total_warps + self.double_backwardB_offset = self.double_backward_schedule.L1.dim * np.dtype(config.irrep_dtype).itemsize * self.double_backward_schedule.total_warps self.forward_workspace_offset = (self.forward_workspace_offset + 7) // 8 * 8 self.backward_workspace_offset = (self.backward_workspace_offset + 7) // 8 * 8 + self.double_backwardB_offset = (self.double_backwardB_offset + 7) // 8 * 8 self.allocate_workspace(workspace_size) @@ -119,7 +123,8 @@ def generate_double_backward_schedule(warps_per_block): double_backward_schedule=self.double_backward_schedule, idx_type=idx_type_map[idx_dtype], forward_workspace_offset=self.forward_workspace_offset, - backward_workspace_offset=self.backward_workspace_offset) + backward_workspace_offset=self.backward_workspace_offset, + double_backwardB_offset=self.double_backwardB_offset) self.jit_kernel = postprocess_kernel(self.jit_kernel) if self.torch_op and extlib.TORCH_COMPILE: diff --git a/openequivariance/templates/loop_unroll_conv_det.cuh b/openequivariance/templates/loop_unroll_conv_det.cuh index 3e180aa7..7592729d 100644 --- a/openequivariance/templates/loop_unroll_conv_det.cuh +++ b/openequivariance/templates/loop_unroll_conv_det.cuh @@ -218,7 +218,6 @@ __global__ void backward( __syncwarp(); bool changeRow = (i < end - 1) && (col != cols[i+1]); - if(changeRow || i == end - 1) { IRREP_T* dst = l1_grad_shft; if(firstSegment) { @@ -337,11 +336,135 @@ __global__ void double_backward_A( } {%- endfor %} } +{{ generate_fixup_kernel("fixup_double_backwardB", double_backward_schedule.launch_config.warp_size, backward_schedule.L1.dim, double_backwardB_offset) }} + +{%- for i, segment in enumerate(double_backward_schedule.segments) %} +{{ generate_segment_kernel_backward(i, segment, double_backward_schedule.launch_config.warp_size, double_bwd=True) }} +{%- endfor %} + +{% set schedule = double_backward_schedule %} + __global__ void double_backward_B( IRREP_T* L1_in, IRREP_T* L2_in, WEIGHT_T* W, IRREP_T* L3_grad, IRREP_T* L1_dgrad, IRREP_T* L2_dgrad, IRREP_T* W_dgrad, IRREP_T* L1_grad, IRREP_T* L2_grad, WEIGHT_T* W_grad, IRREP_T* L3_dgrad, - ConvData c, void* workspace, unsigned {{idx_type}}* transpose_perm) { + ConvData c, void* workspace_raw, unsigned {{idx_type}}* transpose_perm) { + + size_t num_products = c.nnz; + unsigned {{idx_type}}* rows = (unsigned {{idx_type}}*) c.rows; + unsigned {{idx_type}}* cols = (unsigned {{idx_type}}*) c.cols; + + extern __shared__ char s[]; + {{ set_launch_bound_variables(schedule.launch_config) }} + char* smem = s + {{schedule.memory_per_warp}} * warp_loc; + + IRREP_T* workspace = (IRREP_T*) workspace_raw; + {{idx_type}}* dst_idxs = ({{idx_type}}*) ((char*) workspace + {{double_backwardB_offset}}); + + if(lane_id == 0) { + if(start < end) { + dst_idxs[warp_id] = rows[start]; + } + else { + dst_idxs[warp_id] = -1; + } + } + + {%- set tpp = schedule.updated_config %} + + {%- for i, segment in enumerate(schedule.segments) %} { + {{ declare_smem_variables(segment, "smem") }} + + bool firstSegment = true; + 
ROW_OPERATION({{segment.L1.dim}}, j, L1_grad_smem[j + lane_id] = 0.0f;) + + for(size_t i = start; i < end; i++) { + unsigned {{idx_type}} row = rows[i]; unsigned {{idx_type}} col = cols[i]; + + IRREP_T* l1_shft = L1_dgrad + col * {{schedule.L1.dim}} + lane_id; + IRREP_T* l2_shft = L2_dgrad + i * {{schedule.L2.dim}} + lane_id; + IRREP_T* l3_shft = L3_grad + row * {{schedule.L3.dim}} + lane_id; + + IRREP_T* l1_original = L1_in + col * {{schedule.L1.dim}} + lane_id; + IRREP_T* l2_original = L2_in + i * {{schedule.L2.dim}} + lane_id; + + {%- if not tpp.shared_weights %} + WEIGHT_T* w = W + i * {{tpp.weight_numel}}; + WEIGHT_T* wgrad = W_grad + i * {{tpp.weight_numel}}; + WEIGHT_T* wdgrad = W_dgrad + i * {{tpp.weight_numel}}; + {%- else %} + WEIGHT_T* w = W; + WEIGHT_T* wgrad = W_grad; + WEIGHT_T* wdgrad = W_dgrad; + {%- endif %} + + WEIGHT_T* weights_shft = w + lane_id; + WEIGHT_T* weights_dgrad_shft = wdgrad + lane_id; + {{ load_ir_segments(segment.L3Map, "l3_shft", "L3_grad_smem", "j") }} + {{ load_ir_segments(segment.L1Map, "l1_shft", "L1_smem", "j") }} + {{ load_ir_segments(segment.L2Map, "l2_shft", "L2_smem", "j") }} + {{ load_ir_segments(segment.L2Map, "l2_original", "L2_dgrad_smem", "j") }} + + __syncwarp(); + + {%- if not segment.L2Map.persist_load %} + ROW_OPERATION({{segment.L2.dim}}, j, L2_grad_smem[j + lane_id] = 0.0f;) + {%- endif %} + + {% if not schedule.stream_weights%} + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_smem[j + lane_id] = weights_shft[{{segment.weight_offset}} + j];) + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_smem[j + lane_id] = 0.0;) + {% endif %} + + WEIGHT_T* w_buffer = w; + IRREP_T* L2_buffer = L2_smem; + IRREP_T* L2_dgrad_buffer = L2_dgrad_smem; + + for(int n = 0; n < 2; n++) { + if(n == 1) { + {{ load_ir_segments(segment.L1Map, "l1_original", "L1_smem", "j") }} + + {% if not schedule.stream_weights%} + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_smem[j + lane_id] = weights_dgrad_shft[{{segment.weight_offset}} + j];) + {% endif %} + w_buffer = wdgrad; + L2_buffer = L2_dgrad_smem; + L2_dgrad_buffer = L2_smem; + } + + __syncwarp(); + double_backward_loop_unroll_{{i}}(L1_smem, L2_buffer, w_buffer, weights_smem, L3_grad_smem, + L1_grad_smem, L2_grad_smem, L2_dgrad_buffer, n, wgrad, weights_grad_smem, scratch_smem, lane_id); + __syncwarp(); + } + + IRREP_T* l1_grad_shft = L1_grad + col * {{schedule.L1.dim}} + lane_id; + IRREP_T* l2_grad_shft = L2_grad + i * {{schedule.L2.dim}} + lane_id; + {%- if not tpp.shared_weights %} + WEIGHT_T* weights_grad_shft = W_grad + i * {{schedule.updated_config.weight_numel}} + lane_id; + {%- else %} + WEIGHT_T* weights_grad_shft = W_grad + lane_id; + {%- endif %} + + bool changeRow = (i < end - 1) && (col != cols[i+1]); + if(changeRow || i == end - 1) { + IRREP_T* dst = l1_grad_shft; + if(firstSegment) { + dst = workspace + {{backward_schedule.L1.dim}} * warp_id + lane_id; + firstSegment = false; + } + {{ store_ir_segments(segment.L1Map, "dst", "L1_grad_smem", "j") }} + __syncwarp(); + ROW_OPERATION({{segment.L1.dim}}, j, L1_grad_smem[j + lane_id] = 0.0f;) + } + + {{ store_ir_segments(segment.L2Map, "l2_grad_shft", "L2_grad_smem", "j") }} + + {% if not schedule.stream_weights%} + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];) + {% endif %} + } + } {%- endfor %} } \ No newline at end of file From 0b272809fffa436187b593e8c0a2115881f0c29e Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Thu, 1 
May 2025 17:13:56 -0700 Subject: [PATCH 11/12] All tests passing. --- openequivariance/extension/convolution.hpp | 13 ++++----- .../convolution/LoopUnrollConv.py | 21 ++++---------- .../convolution/TensorProductConv.py | 5 +++- .../templates/loop_unroll_conv_atomic.cuh | 13 ++++++--- .../templates/loop_unroll_conv_det.cuh | 28 ++++++++++--------- tests/conv_test.py | 24 ++++------------ 6 files changed, 44 insertions(+), 60 deletions(-) diff --git a/openequivariance/extension/convolution.hpp b/openequivariance/extension/convolution.hpp index 048aaff6..44c79413 100644 --- a/openequivariance/extension/convolution.hpp +++ b/openequivariance/extension/convolution.hpp @@ -216,12 +216,12 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI } void double_backward( - void* L1_in, void* L2_in, void* W, void* L3_grad, - void* L1_dgrad, void* L2_dgrad, void* w_dgrad, - void* L1_grad, void* L2_grad, void* W_grad, void* L3_dgrad, - void* rows, void* cols, - uint64_t nnz, uint64_t node_count, - void* wspace, void* transpose_perm) { + void* L1_in, void* L2_in, void* W, void* L3_grad, + void* L1_dgrad, void* L2_dgrad, void* w_dgrad, + void* L1_grad, void* L2_grad, void* W_grad, void* L3_dgrad, + void* rows, void* cols, + uint64_t nnz, uint64_t node_count, + void* wspace, void* transpose_perm) { ConvData conv_data = {rows, cols, nnz, node_count}; void* args[] = { @@ -240,7 +240,6 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI } jit.execute(5, args, double_backward_config); - if(reinterpret_cast(wspace) != 0) { void *fixup_args[] = {&wspace, &L1_grad}; KernelLaunchConfig fixup_config; diff --git a/openequivariance/implementations/convolution/LoopUnrollConv.py b/openequivariance/implementations/convolution/LoopUnrollConv.py index 0c52f6ba..812304cb 100644 --- a/openequivariance/implementations/convolution/LoopUnrollConv.py +++ b/openequivariance/implementations/convolution/LoopUnrollConv.py @@ -97,6 +97,7 @@ def generate_double_backward_schedule(warps_per_block): self.forward_workspace_offset = None self.backward_workspace_offset = None + self.double_backwardB_offset = None workspace_size = 1 if deterministic: @@ -162,9 +163,10 @@ def register_torch_fakes(cls): class TorchJITConv: def __init__(self, kernel_plaintext: str, fwd_config: dict[str, int], - bwd_config: dict[str, int], + bwd_config: dict[str, int], + dbl_bwd_config: dict[str, int], kernel_dims: dict[str, int]) -> None: - self.kernel_plaintext, self.fwd_config, self.bwd_config, self.kernel_dims = kernel_plaintext, fwd_config, bwd_config, kernel_dims + self.kernel_plaintext, self.fwd_config, self.bwd_config, self.dbl_bwd_config, self.kernel_dims = kernel_plaintext, fwd_config, bwd_config, dbl_bwd_config, kernel_dims @classmethod def __obj_unflatten__(cls, flattened_product): @@ -174,7 +176,7 @@ def __len__(self): return 0 def __setstate__(self, state): - self.kernel_plaintext, self.fwd_config, self.bwd_config, self.kernel_dims = state + self.kernel_plaintext, self.fwd_config, self.bwd_config, self.dbl_bwd_config, self.kernel_dims = state @torch.library.register_fake("torch_tp_jit::jit_conv_forward") def fake_forward(jit, L1_in, L2_in, W, rows, cols, workspace_buffer, sender_perm): @@ -205,19 +207,6 @@ def setup_context_double_backward(ctx, inputs, output): def double_backward(ctx, E, F, G): result = double_backward_op(ctx.jit, ctx.L1_in, ctx.L2_in, ctx.W, ctx.grad_output, E, F, G, ctx.rows, ctx.cols, ctx.workspace_buffer, ctx.sender_perm) - - #jit, A, B, C, D, rows, cols, wspace, 
sender_perm = ctx.jit, ctx.L1_in, ctx.L2_in, ctx.grad_output, ctx.W, ctx.rows, ctx.cols, ctx.workspace_buffer, ctx.sender_perm - #op1 = backward_op(jit, E, F, D, C, rows, cols, wspace, sender_perm) - #op2 = backward_op(jit, A, B, G, C, rows, cols, wspace, sender_perm) - #op3 = forward_op(jit, E, B, D, rows, cols, wspace, sender_perm) - #op4 = backward_op(jit, E, B, D, C, rows, cols, wspace, sender_perm) # op4 and op5 could be combined with op3 and op6 - #op5 = backward_op(jit, A, F, D, C, rows, cols, wspace, sender_perm) - #op6 = forward_op(jit, A, F, D, rows, cols, wspace, sender_perm) - #op7 = forward_op(jit, A, B, G, rows, cols, wspace, sender_perm) - - #return None, op1[0] + op2[0], op1[1] + op2[1], op4[2] + op5[2], (op3 + op6 + op7), None, None, None, None - #print(torch.norm(op7 - result[3])) - return None, result[0], result[1], result[2], result[3], None, None, None, None torch.library.register_autograd("torch_tp_jit::jit_conv_backward", double_backward, setup_context=setup_context_double_backward) diff --git a/openequivariance/implementations/convolution/TensorProductConv.py b/openequivariance/implementations/convolution/TensorProductConv.py index 792e57da..a0d87751 100644 --- a/openequivariance/implementations/convolution/TensorProductConv.py +++ b/openequivariance/implementations/convolution/TensorProductConv.py @@ -1,3 +1,4 @@ +from openequivariance import extlib from openequivariance.implementations.convolution.LoopUnrollConv import * from openequivariance.implementations.TensorProduct import TensorProduct import numpy as np @@ -14,7 +15,9 @@ def __init__(self, config, idx_dtype=np.int64, torch_op=True, deterministic=Fals torch_op=torch_op, deterministic=deterministic) self.dummy_transpose_perm = torch.zeros(1, dtype=torch.int64, device='cuda') - self.forward = self.forward_deterministic if deterministic else self.forward_atomic + + if extlib.TORCH_COMPILE: + self.forward = self.forward_deterministic if deterministic else self.forward_atomic @staticmethod def name(): diff --git a/openequivariance/templates/loop_unroll_conv_atomic.cuh b/openequivariance/templates/loop_unroll_conv_atomic.cuh index ef6502c4..c73dbbc5 100644 --- a/openequivariance/templates/loop_unroll_conv_atomic.cuh +++ b/openequivariance/templates/loop_unroll_conv_atomic.cuh @@ -4,7 +4,8 @@ {%- from 'macros.jinja' import transpose_load, transpose_store, - load_ir_segments, store_ir_segments, + load_ir_segments, load_ir_segments_force, + store_ir_segments, declare_smem_variables, set_launch_bound_variables with context %} @@ -37,6 +38,10 @@ __global__ void fixup_backward(void* workspace, IRREP_T* dst_ptr) { // Empty, no fixup } +__global__ void fixup_double_backwardB(void* workspace, IRREP_T* dst_ptr) { + // Empty, no fixup +} + __global__ void forward( IRREP_T* L1_in, IRREP_T* L2_in, @@ -198,9 +203,9 @@ __global__ void double_backward_A( {{ load_ir_segments(segment.L1Map, "l1", "L1_smem", "j") }} {{ load_ir_segments(segment.L2Map, "l2", "L2_smem", "j") }} ROW_OPERATION({{segment.L3.dim}}, j, L3_smem[j + lane_id] = 0.0f;) - + {%- if not forward_schedule.stream_weights %} - ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_smem[j + lane_id] = w[{{segment.weight_offset}} + j + lane_id];) + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_smem[j + lane_id] = w_dgrad[{{segment.weight_offset}} + j + lane_id];) {%- endif %} w_buffer = w_dgrad; @@ -225,7 +230,7 @@ __global__ void double_backward_A( {{ store_ir_segments(segment.L3Map, "l3", "L3_smem", "j") }} } - } {%- endfor %} + } {%- endfor %} } {%- 
for i, segment in enumerate(double_backward_schedule.segments) %} diff --git a/openequivariance/templates/loop_unroll_conv_det.cuh b/openequivariance/templates/loop_unroll_conv_det.cuh index 7592729d..c7cee536 100644 --- a/openequivariance/templates/loop_unroll_conv_det.cuh +++ b/openequivariance/templates/loop_unroll_conv_det.cuh @@ -297,7 +297,7 @@ __global__ void double_backward_A( {{ load_ir_segments(segment.L2Map, "l2", "L2_smem", "j") }} {%- if not forward_schedule.stream_weights %} - ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_smem[j + lane_id] = w[{{segment.weight_offset}} + j + lane_id];) + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_smem[j + lane_id] = w_dgrad[{{segment.weight_offset}} + j + lane_id];) {%- endif %} w_buffer = w_dgrad; @@ -336,7 +336,7 @@ __global__ void double_backward_A( } {%- endfor %} } -{{ generate_fixup_kernel("fixup_double_backwardB", double_backward_schedule.launch_config.warp_size, backward_schedule.L1.dim, double_backwardB_offset) }} +{{ generate_fixup_kernel("fixup_double_backwardB", double_backward_schedule.launch_config.warp_size, double_backward_schedule.L1.dim, double_backwardB_offset) }} {%- for i, segment in enumerate(double_backward_schedule.segments) %} {{ generate_segment_kernel_backward(i, segment, double_backward_schedule.launch_config.warp_size, double_bwd=True) }} @@ -351,8 +351,9 @@ __global__ void double_backward_B( ConvData c, void* workspace_raw, unsigned {{idx_type}}* transpose_perm) { size_t num_products = c.nnz; - unsigned {{idx_type}}* rows = (unsigned {{idx_type}}*) c.rows; - unsigned {{idx_type}}* cols = (unsigned {{idx_type}}*) c.cols; + {{idx_type}}* rows = ({{idx_type}}*) c.cols; + {{idx_type}}* cols = ({{idx_type}}*) c.rows; + {{idx_type}}* tperm = ({{idx_type}}*) transpose_perm; extern __shared__ char s[]; {{ set_launch_bound_variables(schedule.launch_config) }} @@ -363,7 +364,7 @@ __global__ void double_backward_B( if(lane_id == 0) { if(start < end) { - dst_idxs[warp_id] = rows[start]; + dst_idxs[warp_id] = cols[start]; } else { dst_idxs[warp_id] = -1; @@ -380,18 +381,19 @@ __global__ void double_backward_B( for(size_t i = start; i < end; i++) { unsigned {{idx_type}} row = rows[i]; unsigned {{idx_type}} col = cols[i]; + {{idx_type}} tperm_idx = tperm[i]; IRREP_T* l1_shft = L1_dgrad + col * {{schedule.L1.dim}} + lane_id; - IRREP_T* l2_shft = L2_dgrad + i * {{schedule.L2.dim}} + lane_id; + IRREP_T* l2_shft = L2_dgrad + tperm_idx * {{schedule.L2.dim}} + lane_id; IRREP_T* l3_shft = L3_grad + row * {{schedule.L3.dim}} + lane_id; IRREP_T* l1_original = L1_in + col * {{schedule.L1.dim}} + lane_id; - IRREP_T* l2_original = L2_in + i * {{schedule.L2.dim}} + lane_id; + IRREP_T* l2_original = L2_in + tperm_idx * {{schedule.L2.dim}} + lane_id; {%- if not tpp.shared_weights %} - WEIGHT_T* w = W + i * {{tpp.weight_numel}}; - WEIGHT_T* wgrad = W_grad + i * {{tpp.weight_numel}}; - WEIGHT_T* wdgrad = W_dgrad + i * {{tpp.weight_numel}}; + WEIGHT_T* w = W + tperm_idx * {{tpp.weight_numel}}; + WEIGHT_T* wgrad = W_grad + tperm_idx * {{tpp.weight_numel}}; + WEIGHT_T* wdgrad = W_dgrad + tperm_idx * {{tpp.weight_numel}}; {%- else %} WEIGHT_T* w = W; WEIGHT_T* wgrad = W_grad; @@ -440,10 +442,10 @@ __global__ void double_backward_B( } IRREP_T* l1_grad_shft = L1_grad + col * {{schedule.L1.dim}} + lane_id; - IRREP_T* l2_grad_shft = L2_grad + i * {{schedule.L2.dim}} + lane_id; + IRREP_T* l2_grad_shft = L2_grad + tperm_idx * {{schedule.L2.dim}} + lane_id; {%- if not tpp.shared_weights %} - WEIGHT_T* weights_grad_shft = 
W_grad + i * {{schedule.updated_config.weight_numel}} + lane_id; + WEIGHT_T* weights_grad_shft = W_grad + tperm_idx * {{schedule.updated_config.weight_numel}} + lane_id; {%- else %} WEIGHT_T* weights_grad_shft = W_grad + lane_id; {%- endif %} @@ -452,7 +454,7 @@ __global__ void double_backward_B( if(changeRow || i == end - 1) { IRREP_T* dst = l1_grad_shft; if(firstSegment) { - dst = workspace + {{backward_schedule.L1.dim}} * warp_id + lane_id; + dst = workspace + {{schedule.L1.dim}} * warp_id + lane_id; firstSegment = false; } {{ store_ir_segments(segment.L1Map, "dst", "L1_grad_smem", "j") }} diff --git a/tests/conv_test.py b/tests/conv_test.py index a64b0296..2bf5e2ba 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -13,11 +13,7 @@ def check_result(self, result, fieldname): thresh = result["thresh"] assert result[fieldname]["pass"], f"{fieldname} observed error={error:.2f} >= {thresh}" - @pytest.fixture(params=[np.float32, - #np.float64 - ], ids=['F32', - #'F64' - ], scope='class') + @pytest.fixture(params=[np.float32, np.float64], ids=['F32', 'F64'], scope='class') def dtype(self, request): return request.param @@ -34,16 +30,13 @@ def graph(self, request): graph = load_graph("data/1drf_radius3.5.pickle") return graph - @pytest.fixture(params=[#'atomic', - 'deterministic' - ], scope='class') + @pytest.fixture(params=['atomic', 'deterministic'], scope='class') def conv_object(self, request, problem): if request.param == 'atomic': return oeq.TensorProductConv(problem, deterministic=False) elif request.param == 'deterministic': return oeq.TensorProductConv(problem, deterministic=True) - @pytest.mark.skip def test_tp_fwd(self, conv_object, graph): result = conv_object.test_correctness_forward(graph, thresh=3e-05, @@ -52,7 +45,6 @@ def test_tp_fwd(self, conv_object, graph): self.check_result(result, "output") - @pytest.mark.skip def test_tp_bwd(self, conv_object, graph): result = conv_object.test_correctness_backward(graph, thresh=3e-04, @@ -89,17 +81,11 @@ def problem(self, request, dtype): class TestUVWSingleIrrep(ConvCorrectness): muls = [ - (1, 1, 1), - #(4, 1, 4), (8, 1, 8), (16, 1, 16), - #(32, 1, 32), (5, 1, 5), (13, 1, 13), (33, 1, 33), (49, 1, 49), (64, 1, 64), - #(1, 2, 1), (1, 4, 1), (1, 16, 1), (1, 32, 1), (16, 3, 16) + (1, 1, 1), (4, 1, 4), (8, 1, 8), (16, 1, 16), (32, 1, 32), (5, 1, 5), (13, 1, 13), (33, 1, 33), (49, 1, 49), (64, 1, 64), + (1, 2, 1), (1, 4, 1), (1, 16, 1), (1, 32, 1), (16, 3, 16) ] - irs = [#(0, 0, 0), (1, 1, 1), (1, 0, 1), - #(1, 2, 1), - (5, 3, 5), - #(7, 2, 5) - ] + irs = [(0, 0, 0), (1, 1, 1), (1, 0, 1), (1, 2, 1), (5, 3, 5), (7, 2, 5)] def id_func(m, i): return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" From fd90bebf07caae6b1b5e19d7a63bedc64e8dbfac Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Thu, 1 May 2025 19:06:31 -0700 Subject: [PATCH 12/12] Updated conv test fixture to download graph again. 
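[Editorial note, not part of the commit] This patch points the `graph` fixture back at the NERSC download path instead of the local debugging pickle. A minimal sketch of the restored pattern, assuming `load_graph` is the helper already imported by tests/conv_test.py; the function name `fetch_graph` is introduced here purely for illustration:

    import tempfile
    import urllib.request

    DOWNLOAD_PREFIX = "https://portal.nersc.gov/project/m1982/equivariant_nn_graphs/"

    def fetch_graph(filename):
        # Download the pickled graph into a temporary file, then parse it with the
        # test suite's load_graph helper (assumed to be in scope).
        with tempfile.NamedTemporaryFile() as temp_file:
            urllib.request.urlretrieve(DOWNLOAD_PREFIX + filename, temp_file.name)
            return load_graph(temp_file.name)
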
---
 tests/conv_test.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/tests/conv_test.py b/tests/conv_test.py
index 2bf5e2ba..ca3cd2ab 100644
--- a/tests/conv_test.py
+++ b/tests/conv_test.py
@@ -22,12 +22,12 @@ def graph(self, request):
         download_prefix = "https://portal.nersc.gov/project/m1982/equivariant_nn_graphs/"
         filename = request.param
 
-        #graph = None
-        #with tempfile.NamedTemporaryFile() as temp_file:
-        #    urllib.request.urlretrieve(download_prefix + filename, temp_file.name)
-        #    graph = load_graph(temp_file.name)
+        graph = None
+        with tempfile.NamedTemporaryFile() as temp_file:
+            urllib.request.urlretrieve(download_prefix + filename, temp_file.name)
+            graph = load_graph(temp_file.name)
 
-        graph = load_graph("data/1drf_radius3.5.pickle")
+        #graph = load_graph("data/1drf_radius3.5.pickle")
         return graph
 
     @pytest.fixture(params=['atomic', 'deterministic'], scope='class')
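
[Editorial note, not part of the patch series] With double_backward_A / double_backward_B now registered as the autograd rule for torch_tp_jit::jit_conv_backward, second derivatives through the convolution can be taken with torch.autograd.grad. The sketch below is illustrative only: the TPProblem instance (`problem`), the irrep dimensions, the random graph tensors, and the exact `forward(X, Y, W, rows, cols)` signature are assumptions drawn from the tests and package conventions, not guarantees made by this diff.

    import torch
    import openequivariance as oeq

    # Placeholders (hypothetical values): a real TPProblem and graph come from the
    # existing test helpers; only `weight_numel` and the TensorProductConv
    # constructor call are taken from code shown in this patch series.
    L1_dim, L2_dim = 32, 16
    node_count, nnz = 1000, 5000
    rows = torch.randint(0, node_count, (nnz,), dtype=torch.int64, device="cuda")
    cols = torch.randint(0, node_count, (nnz,), dtype=torch.int64, device="cuda")

    conv = oeq.TensorProductConv(problem, deterministic=False)  # `problem`: an assumed oeq.TPProblem

    X = torch.randn(node_count, L1_dim, device="cuda", requires_grad=True)  # one row per node
    Y = torch.randn(nnz, L2_dim, device="cuda", requires_grad=True)         # one row per edge
    W = torch.randn(nnz, problem.weight_numel, device="cuda", requires_grad=True)

    Z = conv.forward(X, Y, W, rows, cols)  # call signature assumed from package conventions

    # create_graph=True records a graph over the first gradients, so differentiating a
    # scalar built from them routes through the double backward registered for
    # torch_tp_jit::jit_conv_backward (the double_backward_A / double_backward_B kernels).
    gX, gY, gW = torch.autograd.grad(Z.sum(), (X, Y, W), create_graph=True)
    loss = gX.square().sum() + gY.square().sum() + gW.square().sum()
    ddX, ddY, ddW = torch.autograd.grad(loss, (X, Y, W))

In words: the first autograd.grad call with create_graph=True makes the backward pass itself differentiable, so any scalar function of those gradients can be differentiated once more, which is exactly the code path these patches add and the re-enabled tests exercise.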