53 changes: 51 additions & 2 deletions openequivariance/extension/convolution.hpp
@@ -97,41 +97,50 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI
JIT_IMPL jit;
KernelLaunchConfig forward_config;
KernelLaunchConfig backward_config;
KernelLaunchConfig double_backward_config;
bool is_uvw;

JITConvImpl(
std::string jit_kernel,
KernelLaunchConfig forward_config_i,
KernelLaunchConfig backward_config_i,
KernelLaunchConfig double_backward_config_i,
bool is_uvw_i) :
jit(jit_kernel),
forward_config(forward_config_i),
backward_config(backward_config_i),
double_backward_config(double_backward_config_i),
is_uvw(is_uvw_i) {

vector<string> kernels = {"forward", "backward", "fixup_forward", "fixup_backward"};
vector<string> kernels = {"forward", "backward", "fixup_forward", "fixup_backward", "double_backward_A", "double_backward_B", "fixup_double_backwardB"};

int opt_level = 3;
#ifdef HIP_BACKEND
if(is_uvw) {
opt_level = 1;
}
#endif
jit.compile(kernels, {{}, {}, {}, {}}, opt_level);
jit.compile(kernels, {{}, {}, {}, {}, {}, {}, {}}, opt_level);

if(forward_config.smem > 0) {
jit.set_max_smem(0, forward_config.smem);
jit.set_max_smem(4, forward_config.smem);
}

if(backward_config.smem > 0) {
jit.set_max_smem(1, backward_config.smem);
}

if(double_backward_config.smem > 0) {
jit.set_max_smem(5, double_backward_config.smem);
}
}

JITConvImpl(
std::string jit_kernel,
std::unordered_map<string, int64_t> fwd_dict,
std::unordered_map<string, int64_t> bwd_dict,
std::unordered_map<string, int64_t> dbl_bwd_dict,
std::unordered_map<string, int64_t> kernel_dims
) : JITConvImpl(
jit_kernel,
@@ -145,6 +154,11 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI
bwd_dict["num_threads"],
bwd_dict["smem"]
),
KernelLaunchConfig(
dbl_bwd_dict["num_blocks"],
dbl_bwd_dict["num_threads"],
dbl_bwd_dict["smem"]
),
kernel_dims["is_uvw"] == 1) { }

void exec_conv(
@@ -201,5 +215,40 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI
}
}

void double_backward(
void* L1_in, void* L2_in, void* W, void* L3_grad,
void* L1_dgrad, void* L2_dgrad, void* w_dgrad,
void* L1_grad, void* L2_grad, void* W_grad, void* L3_dgrad,
void* rows, void* cols,
uint64_t nnz, uint64_t node_count,
void* wspace, void* transpose_perm) {

ConvData conv_data = {rows, cols, nnz, node_count};
void* args[] = {
&L1_in, &L2_in, &W, &L3_grad, &L1_dgrad, &L2_dgrad, &w_dgrad,
&L1_grad, &L2_grad, &W_grad, &L3_dgrad, &conv_data, &wspace, &transpose_perm
};

jit.execute(4, args, forward_config);
if(reinterpret_cast<uint64_t>(wspace) != 0) {
void *fixup_args[] = {&wspace, &L3_dgrad};
KernelLaunchConfig fixup_config;
fixup_config.num_blocks = forward_config.num_blocks;
fixup_config.num_threads = forward_config.num_threads;
fixup_config.smem = 0;
jit.execute(2, fixup_args, fixup_config);
}

jit.execute(5, args, double_backward_config);
if(reinterpret_cast<uint64_t>(wspace) != 0) {
void *fixup_args[] = {&wspace, &L1_grad};
KernelLaunchConfig fixup_config;
fixup_config.num_blocks = double_backward_config.num_blocks;
fixup_config.num_threads = double_backward_config.num_threads;
fixup_config.smem = 0;
jit.execute(6, fixup_args, fixup_config);
}
}

~JITConvImpl() = default;
};
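For readers tracing the integer arguments to `jit.execute` and `jit.set_max_smem`, the index is simply the position of the kernel in the `kernels` vector above. A sketch of the implied mapping (illustrative only, not code from the repository):

```python
# Kernel indices follow the order of the `kernels` vector in the
# JITConvImpl constructor (illustrative mapping, not repository code).
KERNEL_INDEX = {
    "forward": 0,
    "backward": 1,
    "fixup_forward": 2,
    "fixup_backward": 3,
    "double_backward_A": 4,   # launched with forward_config
    "double_backward_B": 5,   # launched with double_backward_config
    "fixup_double_backwardB": 6,
}
```

This is why `double_backward` launches kernel 4 under `forward_config`, reuses `fixup_forward` (kernel 2) to accumulate `L3_dgrad`, then launches kernel 5 under `double_backward_config` with its dedicated fixup (kernel 6) for `L1_grad`; it also explains the paired `set_max_smem(4, forward_config.smem)` and `set_max_smem(5, double_backward_config.smem)` calls in the constructor.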
1 change: 1 addition & 0 deletions openequivariance/extension/generic_module.cpp
@@ -49,6 +49,7 @@ PYBIND11_MODULE(generic_module, m) {
.def("backward_rawptrs", &ConvolutionImpl::backward_rawptrs);
py::class_<JITConvImpl<JITKernel>, ConvolutionImpl>(m, "JITConvImpl")
.def(py::init< std::string,
std::unordered_map<string, int64_t>,
std::unordered_map<string, int64_t>,
std::unordered_map<string, int64_t>,
std::unordered_map<string, int64_t>>());
75 changes: 66 additions & 9 deletions openequivariance/extension/torch_tp_jit.cpp
@@ -215,25 +215,32 @@ tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> jit_tp_double_

class TorchJITConv : public torch::CustomClassHolder {
public:
Map_t fwd_dict, bwd_dict, kernel_dims;
Map_t fwd_dict, bwd_dict, dbl_bwd_dict, kernel_dims;
JITConvImpl<JITKernel> internal;
int64_t L3_dim;

TorchJITConv(string kernel_plaintext, Map_t fwd_dict_i, Map_t bwd_dict_i, Map_t kernel_dims_i) :
TorchJITConv(string kernel_plaintext, Map_t fwd_dict_i, Map_t bwd_dict_i, Map_t dbl_bwd_dict_i, Map_t kernel_dims_i) :
fwd_dict(fwd_dict_i.copy()),
bwd_dict(bwd_dict_i.copy()),
dbl_bwd_dict(dbl_bwd_dict_i.copy()),
kernel_dims(kernel_dims_i.copy()),
internal(kernel_plaintext,
to_map(fwd_dict_i),
to_map(bwd_dict_i),
to_map(dbl_bwd_dict_i),
to_map(kernel_dims_i)
),
L3_dim(kernel_dims.at("L3_dim")) { }

tuple<tuple<string, string>, tuple<string, Map_t>, tuple<string, Map_t>, tuple<string, Map_t>> __obj_flatten__() {
tuple<tuple<string, string>,
tuple<string, Map_t>,
tuple<string, Map_t>,
tuple<string, Map_t>,
tuple<string, Map_t>> __obj_flatten__() {
return tuple(tuple("kernel_plaintext", internal.jit.kernel_plaintext),
tuple("fwd_config", fwd_dict),
tuple("bwd_config", bwd_dict),
tuple("dbl_bwd_config", dbl_bwd_dict),
tuple("kernel_dims", kernel_dims));
}

@@ -347,6 +354,55 @@ tuple<torch::Tensor, torch::Tensor, torch::Tensor> jit_conv_backward(
return tuple(L1_grad, L2_grad, W_grad);
}

tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> jit_conv_double_backward(
const c10::intrusive_ptr<TorchJITConv> &jit_instance,
const torch::Tensor &L1_in,
const torch::Tensor &L2_in,
const torch::Tensor &W,
const torch::Tensor &L3_grad,
const torch::Tensor &L1_dgrad,
const torch::Tensor &L2_dgrad,
const torch::Tensor &W_dgrad,
const torch::Tensor &rows,
const torch::Tensor &cols,
const torch::Tensor &workspace,
const torch::Tensor &transpose_perm) {

int64_t nnz = rows.sizes()[0];
int64_t node_count = L1_in.sizes()[0];
torch::Tensor L1_grad = torch::zeros(L1_in.sizes(), L1_in.options());
torch::Tensor L2_grad = torch::empty(L2_in.sizes(), L2_in.options());
torch::Tensor W_grad = torch::empty(W.sizes(), W.options());
torch::Tensor L3_dgrad = torch::zeros(L3_grad.sizes(), L3_grad.options());

torch::Tensor L1_in_contig = L1_in.contiguous();
torch::Tensor L2_in_contig = L2_in.contiguous();
torch::Tensor W_contig = W.contiguous();
torch::Tensor L3_grad_contig = L3_grad.contiguous();
torch::Tensor L1_dgrad_contig = L1_dgrad.contiguous();
torch::Tensor L2_dgrad_contig = L2_dgrad.contiguous();
torch::Tensor W_dgrad_contig = W_dgrad.contiguous();

torch::Tensor rows_contig = rows.contiguous();
torch::Tensor cols_contig = cols.contiguous();
torch::Tensor workspace_contig = workspace.contiguous();
torch::Tensor transpose_perm_contig = transpose_perm.contiguous();

jit_instance->internal.double_backward(
data_ptr(L1_in_contig), data_ptr(L2_in_contig),
data_ptr(W_contig), data_ptr(L3_grad_contig),
data_ptr(L1_dgrad_contig), data_ptr(L2_dgrad_contig),
data_ptr(W_dgrad_contig),
data_ptr(L1_grad), data_ptr(L2_grad),
data_ptr(W_grad), data_ptr(L3_dgrad),
data_ptr(rows_contig), data_ptr(cols_contig),
nnz, node_count,
data_ptr(workspace_contig), data_ptr(transpose_perm_contig)
);

return tuple(L1_grad, L2_grad, W_grad, L3_dgrad);
}

// ===========================================================

TORCH_LIBRARY_FRAGMENT(torch_tp_jit, m) {
@@ -376,7 +432,7 @@ TORCH_LIBRARY_FRAGMENT(torch_tp_jit, m) {


m.class_<TorchJITConv>("TorchJITConv")
.def(torch::init<string, Map_t, Map_t, Map_t>())
.def(torch::init<string, Map_t, Map_t, Map_t, Map_t>())
.def("__obj_flatten__", &TorchJITConv::__obj_flatten__)
.def("exec_conv_rawptrs", &TorchJITConv::exec_conv_rawptrs)
.def("backward_rawptrs", &TorchJITConv::backward_rawptrs)
@@ -386,27 +442,28 @@ TORCH_LIBRARY_FRAGMENT(torch_tp_jit, m) {
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<TorchJITConv>& self)
-> tuple<string, Map_t, Map_t, Map_t> {
return tuple(self->internal.jit.kernel_plaintext, self->fwd_dict, self->bwd_dict, self->kernel_dims);
-> tuple<string, Map_t, Map_t, Map_t, Map_t> {
return tuple(self->internal.jit.kernel_plaintext, self->fwd_dict, self->bwd_dict, self->dbl_bwd_dict, self->kernel_dims);
},
// __setstate__
[](tuple<string, Map_t, Map_t, Map_t> state)
[](tuple<string, Map_t, Map_t, Map_t, Map_t> state)
-> c10::intrusive_ptr<TorchJITConv> {
return c10::make_intrusive<TorchJITConv>(get<0>(state), get<1>(state), get<2>(state), get<3>(state));
return c10::make_intrusive<TorchJITConv>(get<0>(state), get<1>(state), get<2>(state), get<3>(state), get<4>(state));
});

m.def("jit_conv_forward(__torch__.torch.classes.torch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> Tensor");
m.def("jit_conv_backward(__torch__.torch.classes.torch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor)");
m.def("jit_conv_double_backward(__torch__.torch.classes.torch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor)");
};


TORCH_LIBRARY_IMPL(torch_tp_jit, CUDA, m) {
m.impl("jit_tp_forward", &jit_tp_forward);
m.impl("jit_tp_backward", &jit_tp_backward);
m.impl("jit_tp_double_backward", &jit_tp_double_backward);

m.impl("jit_conv_forward", &jit_conv_forward);
m.impl("jit_conv_backward", &jit_conv_backward);
m.impl("jit_conv_double_backward", &jit_conv_double_backward);
};

PYBIND11_MODULE(torch_tp_jit, m) {}
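A minimal sketch of calling the newly registered op from Python, assuming `conv` is a constructed `TorchJITConv` script object and the tensors are CUDA tensors shaped as in `jit_conv_double_backward` above (all variable names here are assumptions):

```python
import torch

# Returns (L1_grad, L2_grad, W_grad, L3_dgrad), matching the schema string
# registered in TORCH_LIBRARY_FRAGMENT above.
L1_grad, L2_grad, W_grad, L3_dgrad = torch.ops.torch_tp_jit.jit_conv_double_backward(
    conv, L1_in, L2_in, W, L3_grad,
    L1_dgrad, L2_dgrad, W_dgrad,
    rows, cols, workspace, transpose_perm,
)
```

Note the allocation split in the C++ body: `L1_grad` and `L3_dgrad` are created with `torch::zeros` because the kernels accumulate into them (atomically, or via the fixup passes on the deterministic path), while `L2_grad` and `W_grad` are `torch::empty` allocations the kernel fully overwrites.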
2 changes: 1 addition & 1 deletion openequivariance/implementations/ComputationSchedule.py
@@ -276,7 +276,7 @@ def __init__(self,

smem_limit -= 1
self.memory_per_warp = smem_limit // warps_per_block
self.memory_per_warp -= self.memory_per_warp % 4
self.memory_per_warp -= self.memory_per_warp % 8
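The modulus change rounds each warp's shared-memory slice down to a multiple of 8 bytes instead of 4, presumably so that double-precision irreps and weights stay naturally aligned. A sketch of the arithmetic with illustrative numbers, not values from the codebase:

```python
def align_down(nbytes: int, alignment: int = 8) -> int:
    """Round a byte count down to a multiple of `alignment`."""
    return nbytes - nbytes % alignment

smem_limit = 49151            # illustrative, after the `smem_limit -= 1` above
warps_per_block = 6
memory_per_warp = smem_limit // warps_per_block   # 8191
memory_per_warp = align_down(memory_per_warp)     # 8184 == 8 * 1023
```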

# =====================================================================
# Shared memory partitioning functions
53 changes: 36 additions & 17 deletions openequivariance/implementations/convolution/LoopUnrollConv.py
@@ -49,8 +49,21 @@ def generate_backward_schedule(warps_per_block):
warp_size=dp.warpsize,
include_scratch=self.is_uvw,
stream_weights=self.is_uvw)

def generate_double_backward_schedule(warps_per_block):
self.double_backward_schedule = ComputationSchedule(self.config,
smem_limit=dp.maxSharedMemPerBlock,
warps_per_block=warps_per_block,
warp_size=dp.warpsize,
block_count=dp.multiprocessorCount,
direction = "double_backward",
irrep_dtype = config.irrep_dtype,
weight_dtype = config.weight_dtype,
include_scratch=self.is_uvw,
stream_weights=self.is_uvw,
schedule_type=3)

scheduler_generators = [generate_forward_schedule, generate_backward_schedule]
scheduler_generators = [generate_forward_schedule, generate_backward_schedule, generate_double_backward_schedule]

for generate_schedule in scheduler_generators:
warp_count = 6
@@ -72,35 +85,47 @@ def generate_backward_schedule(warps_per_block):
for key in segment.L1Map.storeback_procedure:
segment.L1Map.storeback_procedure[key] = "atomic_accumulate"

for segment in self.double_backward_schedule.segments:
for key in segment.L1Map.storeback_procedure:
segment.L1Map.storeback_procedure[key] = "atomic_accumulate"


idx_type_map = {np.int32: "int", np.int64: "long"}

if self.torch_op:
self.setup_torch_module()

self.forward_workspace_offset = None
self.backward_workspace_offset = None
self.double_backwardB_offset = None

workspace_size = 1
if deterministic:
destination_index_bytes = 32 # Add extra to account for padding
workspace_size = max(
(self.forward_schedule.L3.dim * np.dtype(config.irrep_dtype).itemsize + destination_index_bytes) * self.forward_schedule.total_warps,
(self.backward_schedule.L1.dim * np.dtype(config.irrep_dtype).itemsize + destination_index_bytes) * self.backward_schedule.total_warps)
(self.backward_schedule.L1.dim * np.dtype(config.irrep_dtype).itemsize + destination_index_bytes) * self.backward_schedule.total_warps,
(self.double_backward_schedule.L1.dim * np.dtype(config.irrep_dtype).itemsize + destination_index_bytes) * self.double_backward_schedule.total_warps
)

self.forward_workspace_offset = self.forward_schedule.L3.dim * np.dtype(config.irrep_dtype).itemsize * self.forward_schedule.total_warps
self.backward_workspace_offset = self.backward_schedule.L1.dim * np.dtype(config.irrep_dtype).itemsize * self.backward_schedule.total_warps
self.double_backwardB_offset = self.double_backward_schedule.L1.dim * np.dtype(config.irrep_dtype).itemsize * self.double_backward_schedule.total_warps

self.forward_workspace_offset = (self.forward_workspace_offset + 7) // 8 * 8
self.backward_workspace_offset = (self.backward_workspace_offset + 7) // 8 * 8
self.double_backwardB_offset = (self.double_backwardB_offset + 7) // 8 * 8

self.allocate_workspace(workspace_size)

self.jit_kernel = template.render(
forward_schedule=self.forward_schedule,
backward_schedule=self.backward_schedule,
double_backward_schedule=self.double_backward_schedule,
idx_type=idx_type_map[idx_dtype],
forward_workspace_offset=self.forward_workspace_offset,
backward_workspace_offset=self.backward_workspace_offset)
backward_workspace_offset=self.backward_workspace_offset,
double_backwardB_offset=self.double_backwardB_offset)
self.jit_kernel = postprocess_kernel(self.jit_kernel)

if self.torch_op and extlib.TORCH_COMPILE:
Expand All @@ -115,6 +140,7 @@ def generate_backward_schedule(warps_per_block):
self.internal = internal_cls(self.jit_kernel,
vars(self.forward_schedule.launch_config),
vars(self.backward_schedule.launch_config),
vars(self.double_backward_schedule.launch_config),
{"L3_dim": self.L3.dim,
"is_uvw": int(self.is_uvw)})
logger.info("Kernel compiled!")
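On the deterministic path above, the workspace must now cover the double-backward schedule as well, and each region's size is rounded up to an 8-byte boundary with the `(x + 7) // 8 * 8` idiom before being used as an offset. A small sketch of that computation (illustrative values, not repository code):

```python
def round_up8(nbytes: int) -> int:
    """Round a byte count up to the next multiple of 8."""
    return (nbytes + 7) // 8 * 8

# Each schedule stages L_dim * itemsize bytes per warp, as in the
# *_workspace_offset expressions above.
assert round_up8(60) == 64          # e.g. L_dim=3, float32, 5 warps
assert round_up8(528384) == 528384  # already 8-aligned, left unchanged
```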
@@ -137,9 +163,10 @@ def register_torch_fakes(cls):
class TorchJITConv:
def __init__(self, kernel_plaintext: str,
fwd_config: dict[str, int],
bwd_config: dict[str, int],
bwd_config: dict[str, int],
dbl_bwd_config: dict[str, int],
kernel_dims: dict[str, int]) -> None:
self.kernel_plaintext, self.fwd_config, self.bwd_config, self.kernel_dims = kernel_plaintext, fwd_config, bwd_config, kernel_dims
self.kernel_plaintext, self.fwd_config, self.bwd_config, self.dbl_bwd_config, self.kernel_dims = kernel_plaintext, fwd_config, bwd_config, dbl_bwd_config, kernel_dims

@classmethod
def __obj_unflatten__(cls, flattened_product):
@@ -149,7 +176,7 @@ def __len__(self):
return 0

def __setstate__(self, state):
self.kernel_plaintext, self.fwd_config, self.bwd_config, self.kernel_dims = state
self.kernel_plaintext, self.fwd_config, self.bwd_config, self.dbl_bwd_config, self.kernel_dims = state

@torch.library.register_fake("torch_tp_jit::jit_conv_forward")
def fake_forward(jit, L1_in, L2_in, W, rows, cols, workspace_buffer, sender_perm):
@@ -163,6 +190,7 @@ def fake_backward(jit, L1_in, L2_in, W, L3_grad, rows, cols, workspace_buffer, s
def register_autograd(cls):
forward_op = torch.ops.torch_tp_jit.jit_conv_forward
backward_op = torch.ops.torch_tp_jit.jit_conv_backward
double_backward_op = torch.ops.torch_tp_jit.jit_conv_double_backward

def setup_context(ctx, inputs, output):
ctx.jit, ctx.L1_in, ctx.L2_in, ctx.W, ctx.rows, ctx.cols, ctx.workspace_buffer, ctx.sender_perm = inputs
Expand All @@ -178,17 +206,8 @@ def setup_context_double_backward(ctx, inputs, output):
ctx.inputs = inputs

def double_backward(ctx, E, F, G):
jit, A, B, C, D, rows, cols, wspace, sender_perm = ctx.jit, ctx.L1_in, ctx.L2_in, ctx.grad_output, ctx.W, ctx.rows, ctx.cols, ctx.workspace_buffer, ctx.sender_perm

op1 = backward_op(jit, E, F, D, C, rows, cols, wspace, sender_perm)
op2 = backward_op(jit, A, B, G, C, rows, cols, wspace, sender_perm)
op3 = forward_op(jit, E, B, D, rows, cols, wspace, sender_perm)
op4 = backward_op(jit, E, B, D, C, rows, cols, wspace, sender_perm) # op4 and op5 could be combined with op3 and op6
op5 = backward_op(jit, A, F, D, C, rows, cols, wspace, sender_perm)
op6 = forward_op(jit, A, F, D, rows, cols, wspace, sender_perm)
op7 = forward_op(jit, A, B, G, rows, cols, wspace, sender_perm)

return None, op1[0] + op2[0], op1[1] + op2[1], op4[2] + op5[2], (op3 + op6 + op7), None, None, None, None
result = double_backward_op(ctx.jit, ctx.L1_in, ctx.L2_in, ctx.W, ctx.grad_output, E, F, G, ctx.rows, ctx.cols, ctx.workspace_buffer, ctx.sender_perm)
return None, result[0], result[1], result[2], result[3], None, None, None, None

torch.library.register_autograd("torch_tp_jit::jit_conv_backward", double_backward, setup_context=setup_context_double_backward)
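The Python composition deleted above (ops 1–7) is the reference behavior the fused op must reproduce; under the removed code's naming, A = L1_in, B = L2_in, C = grad_output (the L3 gradient), D = W, and E, F, G are the incoming double-backward gradients. A hedged consistency-check sketch, assuming access to the `forward_op`/`backward_op` handles defined at the top of `register_autograd`, and useful only as a regression test against the fused kernel:

```python
def reference_double_backward(jit, A, B, C, D, E, F, G, rows, cols, wspace, perm):
    """Seven-launch reference from the removed code path (testing only)."""
    op1 = backward_op(jit, E, F, D, C, rows, cols, wspace, perm)
    op2 = backward_op(jit, A, B, G, C, rows, cols, wspace, perm)
    op3 = forward_op(jit, E, B, D, rows, cols, wspace, perm)
    op4 = backward_op(jit, E, B, D, C, rows, cols, wspace, perm)
    op5 = backward_op(jit, A, F, D, C, rows, cols, wspace, perm)
    op6 = forward_op(jit, A, F, D, rows, cols, wspace, perm)
    op7 = forward_op(jit, A, B, G, rows, cols, wspace, perm)
    return (op1[0] + op2[0],   # L1_grad
            op1[1] + op2[1],   # L2_grad
            op4[2] + op5[2],   # W_grad
            op3 + op6 + op7)   # L3_dgrad
```

Comparing these four tensors against the output of `double_backward_op` with `torch.allclose` is a natural check; the fused path collapses the seven launches into the two `double_backward_A`/`double_backward_B` passes in convolution.hpp.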
