diff --git a/openequivariance/extension/convolution.hpp b/openequivariance/extension/convolution.hpp index 0e412bce..44c79413 100644 --- a/openequivariance/extension/convolution.hpp +++ b/openequivariance/extension/convolution.hpp @@ -97,19 +97,22 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI JIT_IMPL jit; KernelLaunchConfig forward_config; KernelLaunchConfig backward_config; + KernelLaunchConfig double_backward_config; bool is_uvw; JITConvImpl( std::string jit_kernel, KernelLaunchConfig forward_config_i, KernelLaunchConfig backward_config_i, + KernelLaunchConfig double_backward_config_i, bool is_uvw_i) : jit(jit_kernel), forward_config(forward_config_i), backward_config(backward_config_i), + double_backward_config(double_backward_config_i), is_uvw(is_uvw_i) { - vector kernels = {"forward", "backward", "fixup_forward", "fixup_backward"}; + vector kernels = {"forward", "backward", "fixup_forward", "fixup_backward", "double_backward_A", "double_backward_B", "fixup_double_backwardB"}; int opt_level = 3; #ifdef HIP_BACKEND @@ -117,21 +120,27 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI opt_level = 1; } #endif - jit.compile(kernels, {{}, {}, {}, {}}, opt_level); + jit.compile(kernels, {{}, {}, {}, {}, {}, {}, {}}, opt_level); if(forward_config.smem > 0) { jit.set_max_smem(0, forward_config.smem); + jit.set_max_smem(4, forward_config.smem); } if(backward_config.smem > 0) { jit.set_max_smem(1, backward_config.smem); } + + if(double_backward_config.smem > 0) { + jit.set_max_smem(5, double_backward_config.smem); + } } JITConvImpl( std::string jit_kernel, std::unordered_map fwd_dict, std::unordered_map bwd_dict, + std::unordered_map dbl_bwd_dict, std::unordered_map kernel_dims ) : JITConvImpl( jit_kernel, @@ -145,6 +154,11 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI bwd_dict["num_threads"], bwd_dict["smem"] ), + KernelLaunchConfig( + 
dbl_bwd_dict["num_blocks"], + dbl_bwd_dict["num_threads"], + dbl_bwd_dict["smem"] + ), kernel_dims["is_uvw"] == 1) { } void exec_conv( @@ -201,5 +215,40 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI } } + void double_backward( + void* L1_in, void* L2_in, void* W, void* L3_grad, + void* L1_dgrad, void* L2_dgrad, void* w_dgrad, + void* L1_grad, void* L2_grad, void* W_grad, void* L3_dgrad, + void* rows, void* cols, + uint64_t nnz, uint64_t node_count, + void* wspace, void* transpose_perm) { + + ConvData conv_data = {rows, cols, nnz, node_count}; + void* args[] = { + &L1_in, &L2_in, &W, &L3_grad, &L1_dgrad, &L2_dgrad, &w_dgrad, + &L1_grad, &L2_grad, &W_grad, &L3_dgrad, &conv_data, &wspace, &transpose_perm + }; + + jit.execute(4, args, forward_config); + if(reinterpret_cast(wspace) != 0) { + void *fixup_args[] = {&wspace, &L3_dgrad}; + KernelLaunchConfig fixup_config; + fixup_config.num_blocks = forward_config.num_blocks; + fixup_config.num_threads = forward_config.num_threads; + fixup_config.smem = 0; + jit.execute(2, fixup_args, fixup_config); + } + + jit.execute(5, args, double_backward_config); + if(reinterpret_cast(wspace) != 0) { + void *fixup_args[] = {&wspace, &L1_grad}; + KernelLaunchConfig fixup_config; + fixup_config.num_blocks = double_backward_config.num_blocks; + fixup_config.num_threads = double_backward_config.num_threads; + fixup_config.smem = 0; + jit.execute(6, fixup_args, fixup_config); + } + } + ~JITConvImpl() = default; }; \ No newline at end of file diff --git a/openequivariance/extension/generic_module.cpp b/openequivariance/extension/generic_module.cpp index 9b5b3ae6..4231f02d 100644 --- a/openequivariance/extension/generic_module.cpp +++ b/openequivariance/extension/generic_module.cpp @@ -49,6 +49,7 @@ PYBIND11_MODULE(generic_module, m) { .def("backward_rawptrs", &ConvolutionImpl::backward_rawptrs); py::class_, ConvolutionImpl>(m, "JITConvImpl") .def(py::init< std::string, + std::unordered_map, 
std::unordered_map, std::unordered_map, std::unordered_map>()); diff --git a/openequivariance/extension/torch_tp_jit.cpp b/openequivariance/extension/torch_tp_jit.cpp index 353a527d..6711f3f1 100644 --- a/openequivariance/extension/torch_tp_jit.cpp +++ b/openequivariance/extension/torch_tp_jit.cpp @@ -215,25 +215,32 @@ tuple jit_tp_double_ class TorchJITConv : public torch::CustomClassHolder { public: - Map_t fwd_dict, bwd_dict, kernel_dims; + Map_t fwd_dict, bwd_dict, dbl_bwd_dict, kernel_dims; JITConvImpl internal; int64_t L3_dim; - TorchJITConv(string kernel_plaintext, Map_t fwd_dict_i, Map_t bwd_dict_i, Map_t kernel_dims_i) : + TorchJITConv(string kernel_plaintext, Map_t fwd_dict_i, Map_t bwd_dict_i, Map_t dbl_bwd_dict_i, Map_t kernel_dims_i) : fwd_dict(fwd_dict_i.copy()), bwd_dict(bwd_dict_i.copy()), + dbl_bwd_dict(dbl_bwd_dict_i.copy()), kernel_dims(kernel_dims_i.copy()), internal(kernel_plaintext, to_map(fwd_dict_i), to_map(bwd_dict_i), + to_map(dbl_bwd_dict_i), to_map(kernel_dims_i) ), L3_dim(kernel_dims.at("L3_dim")) { } - tuple, tuple, tuple, tuple> __obj_flatten__() { + tuple, + tuple, + tuple, + tuple, + tuple> __obj_flatten__() { return tuple(tuple("kernel_plaintext", internal.jit.kernel_plaintext), tuple("fwd_config", fwd_dict), tuple("bwd_config", bwd_dict), + tuple("dbl_bwd_config", dbl_bwd_dict), tuple("kernel_dims", kernel_dims)); } @@ -347,6 +354,55 @@ tuple jit_conv_backward( return tuple(L1_grad, L2_grad, W_grad); } +tuple jit_conv_double_backward( + const c10::intrusive_ptr &jit_instance, + const torch::Tensor &L1_in, + const torch::Tensor &L2_in, + const torch::Tensor &W, + const torch::Tensor &L3_grad, + const torch::Tensor &L1_dgrad, + const torch::Tensor &L2_dgrad, + const torch::Tensor &W_dgrad, + const torch::Tensor &rows, + const torch::Tensor &cols, + const torch::Tensor &workspace, + const torch::Tensor &transpose_perm) { + + int64_t nnz = rows.sizes()[0]; + int64_t node_count = L1_in.sizes()[0]; + torch::Tensor L1_grad = 
torch::zeros(L1_in.sizes(), L1_in.options()); + torch::Tensor L2_grad = torch::empty(L2_in.sizes(), L2_in.options()); + torch::Tensor W_grad = torch::empty(W.sizes(), W.options()); + torch::Tensor L3_dgrad = torch::zeros(L3_grad.sizes(), L3_grad.options()); + + torch::Tensor L1_in_contig = L1_in.contiguous(); + torch::Tensor L2_in_contig = L2_in.contiguous(); + torch::Tensor W_contig = W.contiguous(); + torch::Tensor L3_grad_contig = L3_grad.contiguous(); + torch::Tensor L1_dgrad_contig = L1_dgrad.contiguous(); + torch::Tensor L2_dgrad_contig = L2_dgrad.contiguous(); + torch::Tensor W_dgrad_contig = W_dgrad.contiguous(); + + torch::Tensor rows_contig = rows.contiguous(); + torch::Tensor cols_contig = cols.contiguous(); + torch::Tensor workspace_contig = workspace.contiguous(); + torch::Tensor transpose_perm_contig = transpose_perm.contiguous(); + + jit_instance->internal.double_backward( + data_ptr(L1_in_contig), data_ptr(L2_in_contig), + data_ptr(W_contig), data_ptr(L3_grad_contig), + data_ptr(L1_dgrad_contig), data_ptr(L2_dgrad_contig), + data_ptr(W_dgrad_contig), + data_ptr(L1_grad), data_ptr(L2_grad), + data_ptr(W_grad), data_ptr(L3_dgrad), + data_ptr(rows_contig), data_ptr(cols_contig), + nnz, node_count, + data_ptr(workspace_contig), data_ptr(transpose_perm_contig) + ); + + return tuple(L1_grad, L2_grad, W_grad, L3_dgrad); +} + // =========================================================== TORCH_LIBRARY_FRAGMENT(torch_tp_jit, m) { @@ -376,7 +432,7 @@ TORCH_LIBRARY_FRAGMENT(torch_tp_jit, m) { m.class_("TorchJITConv") - .def(torch::init()) + .def(torch::init()) .def("__obj_flatten__", &TorchJITConv::__obj_flatten__) .def("exec_conv_rawptrs", &TorchJITConv::exec_conv_rawptrs) .def("backward_rawptrs", &TorchJITConv::backward_rawptrs) @@ -386,20 +442,20 @@ TORCH_LIBRARY_FRAGMENT(torch_tp_jit, m) { .def_pickle( // __getstate__ [](const c10::intrusive_ptr& self) - -> tuple { - return tuple(self->internal.jit.kernel_plaintext, self->fwd_dict, self->bwd_dict, 
self->kernel_dims); + -> tuple { + return tuple(self->internal.jit.kernel_plaintext, self->fwd_dict, self->bwd_dict, self->dbl_bwd_dict, self->kernel_dims); }, // __setstate__ - [](tuple state) + [](tuple state) -> c10::intrusive_ptr { - return c10::make_intrusive(get<0>(state), get<1>(state), get<2>(state), get<3>(state)); + return c10::make_intrusive(get<0>(state), get<1>(state), get<2>(state), get<3>(state), get<4>(state)); }); m.def("jit_conv_forward(__torch__.torch.classes.torch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> Tensor"); m.def("jit_conv_backward(__torch__.torch.classes.torch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor)"); + m.def("jit_conv_double_backward(__torch__.torch.classes.torch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor)"); }; - TORCH_LIBRARY_IMPL(torch_tp_jit, CUDA, m) { m.impl("jit_tp_forward", &jit_tp_forward); m.impl("jit_tp_backward", &jit_tp_backward); @@ -407,6 +463,7 @@ TORCH_LIBRARY_IMPL(torch_tp_jit, CUDA, m) { m.impl("jit_conv_forward", &jit_conv_forward); m.impl("jit_conv_backward", &jit_conv_backward); + m.impl("jit_conv_double_backward", &jit_conv_double_backward); }; PYBIND11_MODULE(torch_tp_jit, m) {} \ No newline at end of file diff --git a/openequivariance/implementations/ComputationSchedule.py b/openequivariance/implementations/ComputationSchedule.py index 58a2fdab..ebb513dd 100644 --- a/openequivariance/implementations/ComputationSchedule.py +++ b/openequivariance/implementations/ComputationSchedule.py @@ -276,7 +276,7 @@ def __init__(self, smem_limit -= 1 self.memory_per_warp = smem_limit // warps_per_block - 
self.memory_per_warp -= self.memory_per_warp % 4 + self.memory_per_warp -= self.memory_per_warp % 8 # ===================================================================== # Shared memory partitioning functions diff --git a/openequivariance/implementations/convolution/LoopUnrollConv.py b/openequivariance/implementations/convolution/LoopUnrollConv.py index 4b6cbc67..812304cb 100644 --- a/openequivariance/implementations/convolution/LoopUnrollConv.py +++ b/openequivariance/implementations/convolution/LoopUnrollConv.py @@ -49,8 +49,21 @@ def generate_backward_schedule(warps_per_block): warp_size=dp.warpsize, include_scratch=self.is_uvw, stream_weights=self.is_uvw) + + def generate_double_backward_schedule(warps_per_block): + self.double_backward_schedule = ComputationSchedule(self.config, + smem_limit=dp.maxSharedMemPerBlock, + warps_per_block=warps_per_block, + warp_size=dp.warpsize, + block_count=dp.multiprocessorCount, + direction = "double_backward", + irrep_dtype = config.irrep_dtype, + weight_dtype = config.weight_dtype, + include_scratch=self.is_uvw, + stream_weights=self.is_uvw, + schedule_type=3) - scheduler_generators = [generate_forward_schedule, generate_backward_schedule] + scheduler_generators = [generate_forward_schedule, generate_backward_schedule, generate_double_backward_schedule] for generate_schedule in scheduler_generators: warp_count = 6 @@ -72,6 +85,11 @@ def generate_backward_schedule(warps_per_block): for key in segment.L1Map.storeback_procedure: segment.L1Map.storeback_procedure[key] = "atomic_accumulate" + for segment in self.double_backward_schedule.segments: + for key in segment.L1Map.storeback_procedure: + segment.L1Map.storeback_procedure[key] = "atomic_accumulate" + + idx_type_map = {np.int32: "int", np.int64: "long"} if self.torch_op: @@ -79,28 +97,35 @@ def generate_backward_schedule(warps_per_block): self.forward_workspace_offset = None self.backward_workspace_offset = None + self.double_backwardB_offset = None workspace_size = 1 if 
deterministic: destination_index_bytes = 32 # Add extra to account for padding workspace_size = max( (self.forward_schedule.L3.dim * np.dtype(config.irrep_dtype).itemsize + destination_index_bytes) * self.forward_schedule.total_warps, - (self.backward_schedule.L1.dim * np.dtype(config.irrep_dtype).itemsize + destination_index_bytes) * self.backward_schedule.total_warps) + (self.backward_schedule.L1.dim * np.dtype(config.irrep_dtype).itemsize + destination_index_bytes) * self.backward_schedule.total_warps, + (self.double_backward_schedule.L1.dim * np.dtype(config.irrep_dtype).itemsize + destination_index_bytes) * self.double_backward_schedule.total_warps + ) self.forward_workspace_offset = self.forward_schedule.L3.dim * np.dtype(config.irrep_dtype).itemsize * self.forward_schedule.total_warps self.backward_workspace_offset = self.backward_schedule.L1.dim * np.dtype(config.irrep_dtype).itemsize * self.backward_schedule.total_warps + self.double_backwardB_offset = self.double_backward_schedule.L1.dim * np.dtype(config.irrep_dtype).itemsize * self.double_backward_schedule.total_warps self.forward_workspace_offset = (self.forward_workspace_offset + 7) // 8 * 8 self.backward_workspace_offset = (self.backward_workspace_offset + 7) // 8 * 8 + self.double_backwardB_offset = (self.double_backwardB_offset + 7) // 8 * 8 self.allocate_workspace(workspace_size) self.jit_kernel = template.render( forward_schedule=self.forward_schedule, backward_schedule=self.backward_schedule, + double_backward_schedule=self.double_backward_schedule, idx_type=idx_type_map[idx_dtype], forward_workspace_offset=self.forward_workspace_offset, - backward_workspace_offset=self.backward_workspace_offset) + backward_workspace_offset=self.backward_workspace_offset, + double_backwardB_offset=self.double_backwardB_offset) self.jit_kernel = postprocess_kernel(self.jit_kernel) if self.torch_op and extlib.TORCH_COMPILE: @@ -115,6 +140,7 @@ def generate_backward_schedule(warps_per_block): self.internal = 
internal_cls(self.jit_kernel, vars(self.forward_schedule.launch_config), vars(self.backward_schedule.launch_config), + vars(self.double_backward_schedule.launch_config), {"L3_dim": self.L3.dim, "is_uvw": int(self.is_uvw)}) logger.info("Kernel compiled!") @@ -137,9 +163,10 @@ def register_torch_fakes(cls): class TorchJITConv: def __init__(self, kernel_plaintext: str, fwd_config: dict[str, int], - bwd_config: dict[str, int], + bwd_config: dict[str, int], + dbl_bwd_config: dict[str, int], kernel_dims: dict[str, int]) -> None: - self.kernel_plaintext, self.fwd_config, self.bwd_config, self.kernel_dims = kernel_plaintext, fwd_config, bwd_config, kernel_dims + self.kernel_plaintext, self.fwd_config, self.bwd_config, self.dbl_bwd_config, self.kernel_dims = kernel_plaintext, fwd_config, bwd_config, dbl_bwd_config, kernel_dims @classmethod def __obj_unflatten__(cls, flattened_product): @@ -149,7 +176,7 @@ def __len__(self): return 0 def __setstate__(self, state): - self.kernel_plaintext, self.fwd_config, self.bwd_config, self.kernel_dims = state + self.kernel_plaintext, self.fwd_config, self.bwd_config, self.dbl_bwd_config, self.kernel_dims = state @torch.library.register_fake("torch_tp_jit::jit_conv_forward") def fake_forward(jit, L1_in, L2_in, W, rows, cols, workspace_buffer, sender_perm): @@ -163,6 +190,7 @@ def fake_backward(jit, L1_in, L2_in, W, L3_grad, rows, cols, workspace_buffer, s def register_autograd(cls): forward_op = torch.ops.torch_tp_jit.jit_conv_forward backward_op = torch.ops.torch_tp_jit.jit_conv_backward + double_backward_op = torch.ops.torch_tp_jit.jit_conv_double_backward def setup_context(ctx, inputs, output): ctx.jit, ctx.L1_in, ctx.L2_in, ctx.W, ctx.rows, ctx.cols, ctx.workspace_buffer, ctx.sender_perm = inputs @@ -178,17 +206,8 @@ def setup_context_double_backward(ctx, inputs, output): ctx.inputs = inputs def double_backward(ctx, E, F, G): - jit, A, B, C, D, rows, cols, wspace, sender_perm = ctx.jit, ctx.L1_in, ctx.L2_in, ctx.grad_output, ctx.W, 
ctx.rows, ctx.cols, ctx.workspace_buffer, ctx.sender_perm - - op1 = backward_op(jit, E, F, D, C, rows, cols, wspace, sender_perm) - op2 = backward_op(jit, A, B, G, C, rows, cols, wspace, sender_perm) - op3 = forward_op(jit, E, B, D, rows, cols, wspace, sender_perm) - op4 = backward_op(jit, E, B, D, C, rows, cols, wspace, sender_perm) # op4 and op5 could be combined with op3 and op6 - op5 = backward_op(jit, A, F, D, C, rows, cols, wspace, sender_perm) - op6 = forward_op(jit, A, F, D, rows, cols, wspace, sender_perm) - op7 = forward_op(jit, A, B, G, rows, cols, wspace, sender_perm) - - return None, op1[0] + op2[0], op1[1] + op2[1], op4[2] + op5[2], (op3 + op6 + op7), None, None, None, None + result = double_backward_op(ctx.jit, ctx.L1_in, ctx.L2_in, ctx.W, ctx.grad_output, E, F, G, ctx.rows, ctx.cols, ctx.workspace_buffer, ctx.sender_perm) + return None, result[0], result[1], result[2], result[3], None, None, None, None torch.library.register_autograd("torch_tp_jit::jit_conv_backward", double_backward, setup_context=setup_context_double_backward) diff --git a/openequivariance/implementations/convolution/TensorProductConv.py b/openequivariance/implementations/convolution/TensorProductConv.py index 792e57da..a0d87751 100644 --- a/openequivariance/implementations/convolution/TensorProductConv.py +++ b/openequivariance/implementations/convolution/TensorProductConv.py @@ -1,3 +1,4 @@ +from openequivariance import extlib from openequivariance.implementations.convolution.LoopUnrollConv import * from openequivariance.implementations.TensorProduct import TensorProduct import numpy as np @@ -14,7 +15,9 @@ def __init__(self, config, idx_dtype=np.int64, torch_op=True, deterministic=Fals torch_op=torch_op, deterministic=deterministic) self.dummy_transpose_perm = torch.zeros(1, dtype=torch.int64, device='cuda') - self.forward = self.forward_deterministic if deterministic else self.forward_atomic + + if extlib.TORCH_COMPILE: + self.forward = self.forward_deterministic if 
deterministic else self.forward_atomic @staticmethod def name(): diff --git a/openequivariance/templates/loop_unroll_conv_atomic.cuh b/openequivariance/templates/loop_unroll_conv_atomic.cuh index 34541b28..c73dbbc5 100644 --- a/openequivariance/templates/loop_unroll_conv_atomic.cuh +++ b/openequivariance/templates/loop_unroll_conv_atomic.cuh @@ -4,7 +4,8 @@ {%- from 'macros.jinja' import transpose_load, transpose_store, - load_ir_segments, store_ir_segments, + load_ir_segments, load_ir_segments_force, + store_ir_segments, declare_smem_variables, set_launch_bound_variables with context %} @@ -37,6 +38,10 @@ __global__ void fixup_backward(void* workspace, IRREP_T* dst_ptr) { // Empty, no fixup } +__global__ void fixup_double_backwardB(void* workspace, IRREP_T* dst_ptr) { + // Empty, no fixup +} + __global__ void forward( IRREP_T* L1_in, IRREP_T* L2_in, @@ -155,4 +160,181 @@ __global__ void backward( {%- endif %} } {%- endfor %} } +} + +__global__ void double_backward_A( + IRREP_T* L1_in, IRREP_T* L2_in, WEIGHT_T* W, IRREP_T* L3_grad, + IRREP_T* L1_dgrad, IRREP_T* L2_dgrad, IRREP_T* W_dgrad, + IRREP_T* L1_grad, IRREP_T* L2_grad, WEIGHT_T* W_grad, IRREP_T* L3_dgrad, + ConvData c, void* workspace, unsigned {{idx_type}}* transpose_perm) { + + extern __shared__ char s[]; + size_t num_products = c.nnz; + unsigned {{idx_type}}* rows = (unsigned {{idx_type}}*) c.rows; + unsigned {{idx_type}}* cols = (unsigned {{idx_type}}*) c.cols; + + {{ set_launch_bound_variables(forward_schedule.launch_config) }} + {%- set tpp = forward_schedule.updated_config %} + char* smem = s + {{forward_schedule.memory_per_warp}} * warp_loc; + + {%- for i, segment in enumerate(forward_schedule.segments) %} { + {{ declare_smem_variables(segment, "smem") }} + WEIGHT_T* w_buffer; + + for(size_t i = start; i < end; i++) { + unsigned {{idx_type}} row = rows[i]; unsigned {{idx_type}} col = cols[i]; + + IRREP_T* l1 = L1_in + col * {{forward_schedule.L1.dim}} + lane_id; + IRREP_T* l2 = L2_in + i * 
{{forward_schedule.L2.dim}} + lane_id; + IRREP_T* l3 = L3_dgrad + row * {{forward_schedule.L3.dim}} + lane_id; + + IRREP_T* l1_dgrad = L1_dgrad + col * {{forward_schedule.L1.dim}} + lane_id; + IRREP_T* l2_dgrad = L2_dgrad + i * {{forward_schedule.L2.dim}} + lane_id; + + {%- if not tpp.shared_weights %} + WEIGHT_T* w = W + i * {{tpp.weight_numel}}; + WEIGHT_T* w_dgrad = W_dgrad + i * {{tpp.weight_numel}}; + {%- else %} + WEIGHT_T* w = W; + WEIGHT_T* w_dgrad = W_dgrad; + {%- endif %} + + __syncwarp(); + {{ load_ir_segments(segment.L1Map, "l1", "L1_smem", "j") }} + {{ load_ir_segments(segment.L2Map, "l2", "L2_smem", "j") }} + ROW_OPERATION({{segment.L3.dim}}, j, L3_smem[j + lane_id] = 0.0f;) + + {%- if not forward_schedule.stream_weights %} + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_smem[j + lane_id] = w_dgrad[{{segment.weight_offset}} + j + lane_id];) + {%- endif %} + + w_buffer = w_dgrad; + + for(int n = 0; n < 3; n++) { + if(n == 1) { + {% if not forward_schedule.stream_weights%} + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_smem[j + lane_id] = w[{{segment.weight_offset}} + j + lane_id];) + {% endif %} + {{ load_ir_segments(segment.L2Map, "l2_dgrad", "L2_smem", "j") }} + w_buffer = w; + } + else if(n == 2) { + {{ load_ir_segments(segment.L2Map, "l2", "L2_smem", "j") }} + {{ load_ir_segments(segment.L1Map, "l1_dgrad", "L1_smem", "j") }} + } + + __syncwarp(); + forward_loop_unroll_{{i}}(L1_smem, L2_smem, w_buffer, weights_smem, L3_smem, scratch_smem, lane_id); + __syncwarp(); + } + + {{ store_ir_segments(segment.L3Map, "l3", "L3_smem", "j") }} + } + } {%- endfor %} +} + +{%- for i, segment in enumerate(double_backward_schedule.segments) %} +{{ generate_segment_kernel_backward(i, segment, double_backward_schedule.launch_config.warp_size, double_bwd=True) }} +{%- endfor %} + +{% set schedule = double_backward_schedule %} + +__global__ void double_backward_B( + IRREP_T* L1_in, IRREP_T* L2_in, WEIGHT_T* W, IRREP_T* L3_grad, + IRREP_T* 
L1_dgrad, IRREP_T* L2_dgrad, IRREP_T* W_dgrad, + IRREP_T* L1_grad, IRREP_T* L2_grad, WEIGHT_T* W_grad, IRREP_T* L3_dgrad, + ConvData c, void* workspace, unsigned {{idx_type}}* transpose_perm) { + + size_t num_products = c.nnz; + unsigned {{idx_type}}* rows = (unsigned {{idx_type}}*) c.rows; + unsigned {{idx_type}}* cols = (unsigned {{idx_type}}*) c.cols; + + extern __shared__ char s[]; + {{ set_launch_bound_variables(schedule.launch_config) }} + char* smem = s + {{schedule.memory_per_warp}} * warp_loc; + + {%- set tpp = schedule.updated_config %} + + {%- for i, segment in enumerate(schedule.segments) %} { + {{ declare_smem_variables(segment, "smem") }} + for(size_t i = start; i < end; i++) { + unsigned {{idx_type}} row = rows[i]; unsigned {{idx_type}} col = cols[i]; + + IRREP_T* l1_shft = L1_dgrad + col * {{schedule.L1.dim}} + lane_id; + IRREP_T* l2_shft = L2_dgrad + i * {{schedule.L2.dim}} + lane_id; + IRREP_T* l3_shft = L3_grad + row * {{schedule.L3.dim}} + lane_id; + + IRREP_T* l1_original = L1_in + col * {{schedule.L1.dim}} + lane_id; + IRREP_T* l2_original = L2_in + i * {{schedule.L2.dim}} + lane_id; + + {%- if not tpp.shared_weights %} + WEIGHT_T* w = W + i * {{tpp.weight_numel}}; + WEIGHT_T* wgrad = W_grad + i * {{tpp.weight_numel}}; + WEIGHT_T* wdgrad = W_dgrad + i * {{tpp.weight_numel}}; + {%- else %} + WEIGHT_T* w = W; + WEIGHT_T* wgrad = W_grad; + WEIGHT_T* wdgrad = W_dgrad; + {%- endif %} + WEIGHT_T* weights_shft = w + lane_id; + WEIGHT_T* weights_dgrad_shft = wdgrad + lane_id; + + {{ load_ir_segments(segment.L3Map, "l3_shft", "L3_grad_smem", "j") }} + {{ load_ir_segments(segment.L1Map, "l1_shft", "L1_smem", "j") }} + {{ load_ir_segments(segment.L2Map, "l2_shft", "L2_smem", "j") }} + {{ load_ir_segments(segment.L2Map, "l2_original", "L2_dgrad_smem", "j") }} + + __syncwarp(); + {%- if not segment.L1Map.persist_load %} + ROW_OPERATION({{segment.L1.dim}}, j, L1_grad_smem[j + lane_id] = 0.0f;) + {%- endif %} + + {%- if not segment.L2Map.persist_load %} + 
ROW_OPERATION({{segment.L2.dim}}, j, L2_grad_smem[j + lane_id] = 0.0f;) + {%- endif %} + + {% if not schedule.stream_weights%} + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_smem[j + lane_id] = weights_shft[{{segment.weight_offset}} + j];) + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_smem[j + lane_id] = 0.0;) + {% endif %} + + WEIGHT_T* w_buffer = w; + IRREP_T* L2_buffer = L2_smem; + IRREP_T* L2_dgrad_buffer = L2_dgrad_smem; + + for(int n = 0; n < 2; n++) { + if(n == 1) { + {{ load_ir_segments(segment.L1Map, "l1_original", "L1_smem", "j") }} + + {% if not schedule.stream_weights%} + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_smem[j + lane_id] = weights_dgrad_shft[{{segment.weight_offset}} + j];) + {% endif %} + w_buffer = wdgrad; + L2_buffer = L2_dgrad_smem; + L2_dgrad_buffer = L2_smem; + } + + __syncwarp(); + double_backward_loop_unroll_{{i}}(L1_smem, L2_buffer, w_buffer, weights_smem, L3_grad_smem, + L1_grad_smem, L2_grad_smem, L2_dgrad_buffer, n, wgrad, weights_grad_smem, scratch_smem, lane_id); + __syncwarp(); + } + + IRREP_T* l1_grad_shft = L1_grad + col * {{schedule.L1.dim}} + lane_id; + IRREP_T* l2_grad_shft = L2_grad + i * {{schedule.L2.dim}} + lane_id; + + {%- if not tpp.shared_weights %} + WEIGHT_T* weights_grad_shft = W_grad + i * {{schedule.updated_config.weight_numel}} + lane_id; + {%- else %} + WEIGHT_T* weights_grad_shft = W_grad + lane_id; + {%- endif %} + + {{ store_ir_segments(segment.L1Map, "l1_grad_shft", "L1_grad_smem", "j") }} + {{ store_ir_segments(segment.L2Map, "l2_grad_shft", "L2_grad_smem", "j") }} + + {% if not schedule.stream_weights%} + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];) + {% endif %} + } + } {%- endfor %} } \ No newline at end of file diff --git a/openequivariance/templates/loop_unroll_conv_det.cuh b/openequivariance/templates/loop_unroll_conv_det.cuh index 507c7435..c7cee536 100644 --- 
a/openequivariance/templates/loop_unroll_conv_det.cuh +++ b/openequivariance/templates/loop_unroll_conv_det.cuh @@ -218,7 +218,6 @@ __global__ void backward( __syncwarp(); bool changeRow = (i < end - 1) && (col != cols[i+1]); - if(changeRow || i == end - 1) { IRREP_T* dst = l1_grad_shft; if(firstSegment) { @@ -237,4 +236,237 @@ __global__ void backward( {%- endif %} } } {%- endfor %} +} + + +__global__ void double_backward_A( + IRREP_T* L1_in, IRREP_T* L2_in, WEIGHT_T* W, IRREP_T* L3_grad, + IRREP_T* L1_dgrad, IRREP_T* L2_dgrad, IRREP_T* W_dgrad, + IRREP_T* L1_grad, IRREP_T* L2_grad, WEIGHT_T* W_grad, IRREP_T* L3_dgrad, + ConvData c, void* workspace_raw, unsigned {{idx_type}}* transpose_perm) { + + extern __shared__ char s[]; + size_t num_products = c.nnz; + unsigned {{idx_type}}* rows = (unsigned {{idx_type}}*) c.rows; + unsigned {{idx_type}}* cols = (unsigned {{idx_type}}*) c.cols; + + {{ set_launch_bound_variables(forward_schedule.launch_config) }} + + IRREP_T* workspace = (IRREP_T*) workspace_raw; + {{idx_type}}* dst_idxs = ({{idx_type}}*) ((char*) workspace + {{forward_workspace_offset}}); + + if(lane_id == 0) { + if(start < end) { + dst_idxs[warp_id] = rows[start]; + } + else { + dst_idxs[warp_id] = -1; + } + } + + {%- set tpp = forward_schedule.updated_config %} + char* smem = s + {{forward_schedule.memory_per_warp}} * warp_loc; + + {%- for i, segment in enumerate(forward_schedule.segments) %} { + {{ declare_smem_variables(segment, "smem") }} + WEIGHT_T* w_buffer; + + bool firstSegment = true; + ROW_OPERATION({{segment.L3.dim}}, j, L3_smem[j + lane_id] = 0.0f;) + + for(size_t i = start; i < end; i++) { + unsigned {{idx_type}} row = rows[i]; unsigned {{idx_type}} col = cols[i]; + + IRREP_T* l1 = L1_in + col * {{forward_schedule.L1.dim}} + lane_id; + IRREP_T* l2 = L2_in + i * {{forward_schedule.L2.dim}} + lane_id; + IRREP_T* l3 = L3_dgrad + row * {{forward_schedule.L3.dim}} + lane_id; + + IRREP_T* l1_dgrad = L1_dgrad + col * {{forward_schedule.L1.dim}} + 
lane_id; + IRREP_T* l2_dgrad = L2_dgrad + i * {{forward_schedule.L2.dim}} + lane_id; + + {%- if not tpp.shared_weights %} + WEIGHT_T* w = W + i * {{tpp.weight_numel}}; + WEIGHT_T* w_dgrad = W_dgrad + i * {{tpp.weight_numel}}; + {%- else %} + WEIGHT_T* w = W; + WEIGHT_T* w_dgrad = W_dgrad; + {%- endif %} + + __syncwarp(); + {{ load_ir_segments(segment.L1Map, "l1", "L1_smem", "j") }} + {{ load_ir_segments(segment.L2Map, "l2", "L2_smem", "j") }} + + {%- if not forward_schedule.stream_weights %} + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_smem[j + lane_id] = w_dgrad[{{segment.weight_offset}} + j + lane_id];) + {%- endif %} + + w_buffer = w_dgrad; + + for(int n = 0; n < 3; n++) { + if(n == 1) { + {% if not forward_schedule.stream_weights%} + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_smem[j + lane_id] = w[{{segment.weight_offset}} + j + lane_id];) + {% endif %} + {{ load_ir_segments(segment.L2Map, "l2_dgrad", "L2_smem", "j") }} + w_buffer = w; + } + else if(n == 2) { + {{ load_ir_segments(segment.L2Map, "l2", "L2_smem", "j") }} + {{ load_ir_segments(segment.L1Map, "l1_dgrad", "L1_smem", "j") }} + } + + __syncwarp(); + forward_loop_unroll_{{i}}(L1_smem, L2_smem, w_buffer, weights_smem, L3_smem, scratch_smem, lane_id); + __syncwarp(); + } + + bool changeRow = (i < end - 1) && (row != rows[i+1]); + if(changeRow || i == end - 1) { + IRREP_T* dst = l3; + if(firstSegment) { + dst = workspace + {{forward_schedule.L3.dim}} * warp_id + lane_id; + firstSegment = false; + } + {{ store_ir_segments(segment.L3Map, "dst", "L3_smem", "j") }} + __syncwarp(); + + ROW_OPERATION({{segment.L3.dim}}, j, L3_smem[j + lane_id] = 0.0f;) + } + } + } {%- endfor %} +} + +{{ generate_fixup_kernel("fixup_double_backwardB", double_backward_schedule.launch_config.warp_size, double_backward_schedule.L1.dim, double_backwardB_offset) }} + +{%- for i, segment in enumerate(double_backward_schedule.segments) %} +{{ generate_segment_kernel_backward(i, segment, 
double_backward_schedule.launch_config.warp_size, double_bwd=True) }} +{%- endfor %} + +{% set schedule = double_backward_schedule %} + +__global__ void double_backward_B( + IRREP_T* L1_in, IRREP_T* L2_in, WEIGHT_T* W, IRREP_T* L3_grad, + IRREP_T* L1_dgrad, IRREP_T* L2_dgrad, IRREP_T* W_dgrad, + IRREP_T* L1_grad, IRREP_T* L2_grad, WEIGHT_T* W_grad, IRREP_T* L3_dgrad, + ConvData c, void* workspace_raw, unsigned {{idx_type}}* transpose_perm) { + + size_t num_products = c.nnz; + {{idx_type}}* rows = ({{idx_type}}*) c.cols; + {{idx_type}}* cols = ({{idx_type}}*) c.rows; + {{idx_type}}* tperm = ({{idx_type}}*) transpose_perm; + + extern __shared__ char s[]; + {{ set_launch_bound_variables(schedule.launch_config) }} + char* smem = s + {{schedule.memory_per_warp}} * warp_loc; + + IRREP_T* workspace = (IRREP_T*) workspace_raw; + {{idx_type}}* dst_idxs = ({{idx_type}}*) ((char*) workspace + {{double_backwardB_offset}}); + + if(lane_id == 0) { + if(start < end) { + dst_idxs[warp_id] = cols[start]; + } + else { + dst_idxs[warp_id] = -1; + } + } + + {%- set tpp = schedule.updated_config %} + + {%- for i, segment in enumerate(schedule.segments) %} { + {{ declare_smem_variables(segment, "smem") }} + + bool firstSegment = true; + ROW_OPERATION({{segment.L1.dim}}, j, L1_grad_smem[j + lane_id] = 0.0f;) + + for(size_t i = start; i < end; i++) { + unsigned {{idx_type}} row = rows[i]; unsigned {{idx_type}} col = cols[i]; + {{idx_type}} tperm_idx = tperm[i]; + + IRREP_T* l1_shft = L1_dgrad + col * {{schedule.L1.dim}} + lane_id; + IRREP_T* l2_shft = L2_dgrad + tperm_idx * {{schedule.L2.dim}} + lane_id; + IRREP_T* l3_shft = L3_grad + row * {{schedule.L3.dim}} + lane_id; + + IRREP_T* l1_original = L1_in + col * {{schedule.L1.dim}} + lane_id; + IRREP_T* l2_original = L2_in + tperm_idx * {{schedule.L2.dim}} + lane_id; + + {%- if not tpp.shared_weights %} + WEIGHT_T* w = W + tperm_idx * {{tpp.weight_numel}}; + WEIGHT_T* wgrad = W_grad + tperm_idx * {{tpp.weight_numel}}; + WEIGHT_T* wdgrad = 
W_dgrad + tperm_idx * {{tpp.weight_numel}}; + {%- else %} + WEIGHT_T* w = W; + WEIGHT_T* wgrad = W_grad; + WEIGHT_T* wdgrad = W_dgrad; + {%- endif %} + + WEIGHT_T* weights_shft = w + lane_id; + WEIGHT_T* weights_dgrad_shft = wdgrad + lane_id; + + {{ load_ir_segments(segment.L3Map, "l3_shft", "L3_grad_smem", "j") }} + {{ load_ir_segments(segment.L1Map, "l1_shft", "L1_smem", "j") }} + {{ load_ir_segments(segment.L2Map, "l2_shft", "L2_smem", "j") }} + {{ load_ir_segments(segment.L2Map, "l2_original", "L2_dgrad_smem", "j") }} + + __syncwarp(); + + {%- if not segment.L2Map.persist_load %} + ROW_OPERATION({{segment.L2.dim}}, j, L2_grad_smem[j + lane_id] = 0.0f;) + {%- endif %} + + {% if not schedule.stream_weights%} + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_smem[j + lane_id] = weights_shft[{{segment.weight_offset}} + j];) + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_smem[j + lane_id] = 0.0;) + {% endif %} + + WEIGHT_T* w_buffer = w; + IRREP_T* L2_buffer = L2_smem; + IRREP_T* L2_dgrad_buffer = L2_dgrad_smem; + + for(int n = 0; n < 2; n++) { + if(n == 1) { + {{ load_ir_segments(segment.L1Map, "l1_original", "L1_smem", "j") }} + + {% if not schedule.stream_weights%} + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_smem[j + lane_id] = weights_dgrad_shft[{{segment.weight_offset}} + j];) + {% endif %} + w_buffer = wdgrad; + L2_buffer = L2_dgrad_smem; + L2_dgrad_buffer = L2_smem; + } + + __syncwarp(); + double_backward_loop_unroll_{{i}}(L1_smem, L2_buffer, w_buffer, weights_smem, L3_grad_smem, + L1_grad_smem, L2_grad_smem, L2_dgrad_buffer, n, wgrad, weights_grad_smem, scratch_smem, lane_id); + __syncwarp(); + } + + IRREP_T* l1_grad_shft = L1_grad + col * {{schedule.L1.dim}} + lane_id; + IRREP_T* l2_grad_shft = L2_grad + tperm_idx * {{schedule.L2.dim}} + lane_id; + + {%- if not tpp.shared_weights %} + WEIGHT_T* weights_grad_shft = W_grad + tperm_idx * {{schedule.updated_config.weight_numel}} + lane_id; + {%- else %} + WEIGHT_T* 
weights_grad_shft = W_grad + lane_id; + {%- endif %} + + bool changeRow = (i < end - 1) && (col != cols[i+1]); + if(changeRow || i == end - 1) { + IRREP_T* dst = l1_grad_shft; + if(firstSegment) { + dst = workspace + {{schedule.L1.dim}} * warp_id + lane_id; + firstSegment = false; + } + {{ store_ir_segments(segment.L1Map, "dst", "L1_grad_smem", "j") }} + __syncwarp(); + ROW_OPERATION({{segment.L1.dim}}, j, L1_grad_smem[j + lane_id] = 0.0f;) + } + + {{ store_ir_segments(segment.L2Map, "l2_grad_shft", "L2_grad_smem", "j") }} + + {% if not schedule.stream_weights%} + ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];) + {% endif %} + } + } {%- endfor %} } \ No newline at end of file diff --git a/tests/conv_test.py b/tests/conv_test.py index 782021ee..ca3cd2ab 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -81,12 +81,11 @@ def problem(self, request, dtype): class TestUVWSingleIrrep(ConvCorrectness): muls = [ - (1, 1, 1), (4, 1, 4), (8, 1, 8), (16, 1, 16), - (32, 1, 32), (5, 1, 5), (13, 1, 13), (33, 1, 33), (49, 1, 49), (64, 1, 64), + (1, 1, 1), (4, 1, 4), (8, 1, 8), (16, 1, 16), (32, 1, 32), (5, 1, 5), (13, 1, 13), (33, 1, 33), (49, 1, 49), (64, 1, 64), (1, 2, 1), (1, 4, 1), (1, 16, 1), (1, 32, 1), (16, 3, 16) ] - irs = [ (0, 0, 0), (1, 1, 1), (1, 0, 1), (1, 2, 1), (5, 3, 5), (7, 2, 5) ] + irs = [(0, 0, 0), (1, 1, 1), (1, 0, 1), (1, 2, 1), (5, 3, 5), (7, 2, 5)] def id_func(m, i): return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e"