4 changes: 2 additions & 2 deletions docs/api.rst
@@ -18,12 +18,12 @@ trying our code. OpenEquivariance cannot accelerate all tensor products; see
:doc:`this page </supported_ops>` for a list of supported configurations.

.. autoclass:: openequivariance.TensorProduct
:members:
:members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn, to
:undoc-members:
:exclude-members: name

.. autoclass:: openequivariance.TensorProductConv
:members:
:members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn, to
:undoc-members:
:exclude-members: name

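For reference, a minimal usage sketch of the openequivariance.TensorProduct API that these directives now document, adapted from the project README; the irreps, the single "uvu" instruction, and the batch size are illustrative assumptions rather than part of this change.

import torch
import e3nn.o3 as o3
import openequivariance as oeq

# Problem setup mirrors e3nn's o3.TensorProduct constructor (assumed, per the README).
X_ir, Y_ir, Z_ir = o3.Irreps("1x2e"), o3.Irreps("1x3e"), o3.Irreps("1x2e")
problem = oeq.TPProblem(
    X_ir, Y_ir, Z_ir,
    [(0, 0, 0, "uvu", True)],
    shared_weights=False,
    internal_weights=False,
)
tp = oeq.TensorProduct(problem, torch_op=True)

batch = 1000
X = torch.rand(batch, X_ir.dim, device="cuda")
Y = torch.rand(batch, Y_ir.dim, device="cuda")
W = torch.rand(batch, problem.weight_numel, device="cuda")

# forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn, and to are the
# members documented above; e3nn-ordered weights would first pass through
# tp.reorder_weights_from_e3nn before being handed to forward.
Z = tp.forward(X, Y, W)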
18 changes: 6 additions & 12 deletions openequivariance/benchmark/correctness_utils.py
@@ -214,13 +214,9 @@ def correctness_double_backward(
if impl == CUETensorProduct and problem.shared_weights:
weights = weights[np.newaxis, :]

weights_reordered = np.zeros_like(weights)
if tp.reorder_weights_e3nn_to_oeq is not None:
tp.reorder_weights_e3nn_to_oeq(
weights, weights_reordered, not tp.config.shared_weights
)
else:
weights_reordered = weights
weights_reordered = tp.reorder_weights_from_e3nn(
weights, not tp.config.shared_weights
)

in1_torch = torch.tensor(in1, device="cuda", requires_grad=True)
in2_torch = torch.tensor(in2, device="cuda", requires_grad=True)
@@ -248,11 +244,9 @@
)

weights_grad = weights_torch.grad.detach().cpu().numpy()
if tp.reorder_weights_oeq_to_e3nn is not None:
weights_grad_copy = weights_grad.copy()
tp.reorder_weights_oeq_to_e3nn(
weights_grad_copy, weights_grad, not tp.config.shared_weights
)
weights_grad = tp.reorder_weights_to_e3nn(
weights_grad, not tp.config.shared_weights
)

tensors.append(
(
35 changes: 31 additions & 4 deletions openequivariance/implementations/ComputationSchedule.py
@@ -620,7 +620,7 @@ def calculate_backward_smem(
smem=self.memory_per_warp * warps_per_block,
)

def reorder_weights(self, weights_in, weights_out, direction, has_batch_dim):
def reorder_weights(self, weights_in, direction, has_batch_dim):
"""
Reorders weights from the canonical e3nn form to the
form that LoopUnrollTP can ingest. Can also reorder the parameters
@@ -629,7 +629,9 @@ def reorder_weights(self, weights_in, weights_out, direction, has_batch_dim):
If has_batch_dim is true, the first dimension of the input weight matrix
is treated as the batch dimension.
"""
weights_out *= 0.0
import torch # TODO-someday: no need to specialize this to PyTorch

weights_out = torch.zeros_like(weights_in)
assert direction in ["forward", "backward"]
for i, child_inst in enumerate(self.problem_splitter.new_instructions):
parent_start, parent_end = (
@@ -670,16 +672,41 @@ def reorder_weights(self, weights_in, weights_out, direction, has_batch_dim):
sliced_weights = weights_in[tuple(parent_range)].reshape(parent_shape)[
tuple(weights_subrange)
]
weights_out[tuple(child_range)] = sliced_weights.transpose(
weights_out[tuple(child_range)] = sliced_weights.permute(
transpose_perm
).reshape(reshape_size)
elif direction == "backward":
transpose_child_shape = [child_shape[i] for i in transpose_perm]
sliced_weights = (
weights_in[tuple(child_range)]
.reshape(transpose_child_shape)
.transpose(transpose_perm)
.permute(transpose_perm)
)
weights_out[tuple(parent_range)].reshape(parent_shape)[
tuple(weights_subrange)
] = sliced_weights.flatten().reshape(child_shape)

return weights_out

def reorder_weights_numpy(self, weights_in, direction, has_batch_dim):
import torch

weights_in = torch.from_numpy(weights_in.copy())
result = self.reorder_weights(weights_in, direction, has_batch_dim)
return result.detach().cpu().numpy().copy()

def reorder_weights_from_e3nn(self, weights_in, has_batch_dim):
import torch

if isinstance(weights_in, np.ndarray):
return self.reorder_weights_numpy(weights_in, "forward", has_batch_dim)
elif isinstance(weights_in, torch.Tensor):
return self.reorder_weights(weights_in, "forward", has_batch_dim)

def reorder_weights_to_e3nn(self, weights_in, has_batch_dim):
import torch

if isinstance(weights_in, np.ndarray):
return self.reorder_weights_numpy(weights_in, "backward", has_batch_dim)
elif isinstance(weights_in, torch.Tensor):
return self.reorder_weights(weights_in, "backward", has_batch_dim)
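A hypothetical round-trip check of the new type-dispatching helpers added above; `schedule` stands in for a ComputationSchedule instance and `numel` for its flattened weight count, neither of which is defined in this diff.

import numpy as np

# np.ndarray input takes the reorder_weights_numpy path (a torch round trip), while
# torch.Tensor input goes straight to reorder_weights; each returns its input's type.
w = np.random.rand(64, numel).astype(np.float32)  # batched weights, assumed shape
w_oeq = schedule.reorder_weights_from_e3nn(w, has_batch_dim=True)
w_back = schedule.reorder_weights_to_e3nn(w_oeq, has_batch_dim=True)
assert isinstance(w_oeq, np.ndarray)
assert np.allclose(w, w_back)  # forward then backward reordering should be the identity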
35 changes: 10 additions & 25 deletions openequivariance/implementations/LoopUnrollTP.py
@@ -120,19 +120,13 @@ def generate_double_backward_schedule(warps_per_block):
},
)
logger.info("Kernel compiled!")

logger.info(f"Kernel File Size: {len(self.jit_kernel) // 1024} KB")

self.reorder_weights_e3nn_to_oeq = (
lambda input, output, has_batch_dim: self.forward_schedule.reorder_weights(
input, output, "forward", has_batch_dim
)
)
self.reorder_weights_oeq_to_e3nn = (
lambda input, output, has_batch_dim: self.forward_schedule.reorder_weights(
input, output, "backward", has_batch_dim
)
)
def reorder_weights_from_e3nn(self, weights, has_batch_dim=True):
return self.forward_schedule.reorder_weights_from_e3nn(weights, has_batch_dim)

def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
return self.forward_schedule.reorder_weights_to_e3nn(weights, has_batch_dim)

@classmethod
def register_torch_fakes(cls):
@@ -177,24 +171,15 @@ def __setstate__(self, state):
self.dbl_bwd_config = state["dbl_bwd_config"]
self.kernel_dims = state["kernel_dims"]

def exec_tensor_product_rawptr(
self, batch: int, L1_in: int, L2_in: int, L3_out: int, weights: int
) -> None:
def exec_tensor_product_rawptr(*args, **kwargs):
pass

def backward_rawptr(
self,
batch_size: int,
L1_in: int,
L1_grad: int,
L2_in: int,
L2_grad: int,
weights: int,
weights_grad: int,
L3_grad: int,
):
def backward_rawptr(*args, **kwargs):
pass

def get_L3_dim(self):
return self.kernel_dims["L3_dim"]

@torch.library.register_fake("libtorch_tp_jit::jit_tp_forward")
def fake_forward(jit, L1_in, L2_in, W):
L3_dim = None
3 changes: 3 additions & 0 deletions openequivariance/implementations/TensorProduct.py
@@ -38,6 +38,9 @@ def _init_class(self):
self.forward = self.forward_opaque

def to(self, *args, **kwargs):
r"""
See `torch.nn.Module.to() <https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.to>`_.
"""
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
*args, **kwargs
)
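Since the new override defers to the same argument parsing as torch.nn.Module.to, the usual call patterns apply; a brief sketch, with `tp` standing in for an existing TensorProduct instance.

import torch

tp = tp.to(torch.device("cuda:0"))  # move the module to a device, as with nn.Module.to
# tp = tp.to(dtype=torch.float64)   # the dtype form is parsed as well; assumed supported here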
57 changes: 37 additions & 20 deletions openequivariance/implementations/TensorProductBase.py
@@ -32,7 +32,6 @@ def __init__(self, config: TPProblem, torch_op: bool = False):
config.irreps_out,
)
self.irrep_dtype, self.weight_dtype = config.irrep_dtype, config.weight_dtype
self.reorder_weights_e3nn_to_oeq, self.reorder_weights_oeq_to_e3nn = None, None

self.tp_id = TensorProductBase.next_tp_id
TensorProductBase.next_tp_id += 1
@@ -44,6 +43,34 @@ def __init__(self, config: TPProblem, torch_op: bool = False):
def __call__(self, L1_in, L2_in, weights):
return self.forward(L1_in, L2_in, weights)

def reorder_weights_from_e3nn(self, weights, has_batch_dim: bool = True):
r"""
Reorders weights from ``e3nn`` canonical order to the order used by ``oeq``.

:param weights: Weights in ``e3nn`` canonical order, either an
np.ndarray or a torch.Tensor. Tensor of dimensions ``[B, problem.weight_numel]``
when ``has_batch_dim=True``, otherwise of dimensions ``[problem.weight_numel]``.

:param has_batch_dim: If ``True``, treats the first dimension of weights as a batch dimension. Default: ``True``.

:return: Weights in ``oeq`` order. Output type is identical to input.
"""
return weights

def reorder_weights_to_e3nn(self, weights, has_batch_dim: bool = True):
r"""
Reorders weights from ``oeq`` order to the canonical order used by ``e3nn``.

:param weights: Weights in ``oeq`` order, either an
np.ndarray or a torch.Tensor. Tensor of dimensions ``[B, problem.weight_numel]``
when ``has_batch_dim=True``, otherwise of dimensions ``[problem.weight_numel]``.

:param has_batch_dim: If ``True``, treats the first dimension of weights as a batch dimension. Default: ``True``.

:return: Weights in ``e3nn`` order. Output type is identical to input.
"""
return weights

def forward_raw(
self,
batch: np.uint64,
Expand Down Expand Up @@ -76,13 +103,9 @@ def forward_cpu(
L3_out: np.ndarray,
weights: np.ndarray,
) -> None:
weights_chunked = np.zeros_like(weights)
if self.reorder_weights_e3nn_to_oeq is not None:
self.reorder_weights_e3nn_to_oeq(
weights, weights_chunked, not self.config.shared_weights
)
else:
weights_chunked = weights
weights_chunked = self.reorder_weights_from_e3nn(
weights, not self.config.shared_weights
)

batch = L1_in.shape[0]
L1_d = DeviceBuffer(L1_in)
Expand All @@ -101,13 +124,9 @@ def forward_cpu(
def backward_cpu(
self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad
) -> None:
weights_chunked = np.zeros_like(weights)
if self.reorder_weights_e3nn_to_oeq is not None:
self.reorder_weights_e3nn_to_oeq(
weights, weights_chunked, not self.config.shared_weights
)
else:
weights_chunked = weights
weights_chunked = self.reorder_weights_from_e3nn(
weights, not self.config.shared_weights
)

batch = L1_in.shape[0]
L1_d, L2_d, L3_d = (
Expand Down Expand Up @@ -136,11 +155,9 @@ def backward_cpu(
L2_grad_d.copy_to_host()
weights_grad_d.copy_to_host()

if self.reorder_weights_oeq_to_e3nn is not None:
weights_grad_copy = weights_grad.copy()
self.reorder_weights_oeq_to_e3nn(
weights_grad_copy, weights_grad, not self.config.shared_weights
)
weights_grad[:] = self.reorder_weights_to_e3nn(
weights_grad, not self.config.shared_weights
)

def benchmark_forward(
self,
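To tie the two docstrings above to the forward_cpu/backward_cpu changes in this file, a hedged sketch of the intended interop pattern; `tp`, `problem`, `X`, and `Y` are assumed to already exist and are not part of this diff.

import numpy as np
import torch

# Assuming problem.shared_weights is False, so weights carry a batch dimension.
B = X.shape[0]
w_e3nn = np.random.rand(B, problem.weight_numel).astype(np.float32)  # e3nn-ordered

# Reorder once up front, then hand the oeq-ordered weights to the kernel.
w_oeq = tp.reorder_weights_from_e3nn(w_e3nn, has_batch_dim=True)
w_torch = torch.tensor(w_oeq, device="cuda", requires_grad=True)

Z = tp.forward(X, Y, w_torch)
Z.sum().backward()

# The weight gradient comes back in oeq order; map it back for e3nn comparison.
grad_oeq = w_torch.grad.detach().cpu().numpy()
grad_e3nn = tp.reorder_weights_to_e3nn(grad_oeq, has_batch_dim=True)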
67 changes: 28 additions & 39 deletions openequivariance/implementations/convolution/ConvolutionBase.py
@@ -117,6 +117,18 @@ def __init__(
self.workspace_ptr = 0
self.workspace_size = 0

def reorder_weights_from_e3nn(self, weights, has_batch_dim=True):
r"""
See :py:func:`oeq.TensorProduct.reorder_weights_from_e3nn`.
"""
return weights

def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
r"""
See :py:func:`oeq.TensorProduct.reorder_weights_to_e3nn`.
"""
return weights

def allocate_workspace(self, size_bytes):
self.workspace_size = size_bytes
if self.torch_op:
@@ -136,13 +148,9 @@ def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph):
assert graph.rows.dtype == self.idx_dtype
assert graph.cols.dtype == self.idx_dtype

weights_chunked = np.zeros_like(weights)
if self.reorder_weights_e3nn_to_oeq is not None:
self.reorder_weights_e3nn_to_oeq(
weights, weights_chunked, not self.config.shared_weights
)
else:
weights_chunked = weights
weights_chunked = self.reorder_weights_from_e3nn(
weights, not self.config.shared_weights
)

L1_d, L2_d, weights_d = (
DeviceBuffer(L1_in),
@@ -174,13 +182,9 @@ def backward_cpu(
assert graph.rows.dtype == self.idx_dtype
assert graph.cols.dtype == self.idx_dtype

weights_chunked = np.zeros_like(weights)
if self.reorder_weights_e3nn_to_oeq is not None:
self.reorder_weights_e3nn_to_oeq(
weights, weights_chunked, not self.config.shared_weights
)
else:
weights_chunked = weights
weights_chunked = self.reorder_weights_from_e3nn(
weights, not self.config.shared_weights
)

L1_d = DeviceBuffer(L1_in)
L2_d = DeviceBuffer(L2_in)
@@ -219,11 +223,9 @@ def backward_cpu(
L2_grad_d.copy_to_host()
weights_grad_d.copy_to_host()

if self.reorder_weights_oeq_to_e3nn is not None:
weights_grad_copy = weights_grad.copy()
self.reorder_weights_oeq_to_e3nn(
weights_grad_copy, weights_grad, not self.config.shared_weights
)
weights_grad[:] = self.reorder_weights_to_e3nn(
weights_grad, not self.config.shared_weights
)

return L1_grad, L2_grad, weights_grad

@@ -712,17 +714,10 @@ def test_correctness_double_backward(
in1_torch = torch.tensor(in1, device="cuda", requires_grad=True)
in2_torch = torch.tensor(in2, device="cuda", requires_grad=True)

weights_reordered = np.zeros_like(weights)
if (
i == 0
and hasattr(self, "reorder_weights_e3nn_to_oeq")
and self.reorder_weights_e3nn_to_oeq is not None
):
self.reorder_weights_e3nn_to_oeq(
weights, weights_reordered, not self.config.shared_weights
)
else:
weights_reordered[:] = weights
weights_reordered = tp.reorder_weights_from_e3nn(
weights, not self.config.shared_weights
)

weights_torch = torch.tensor(
weights_reordered, device="cuda", requires_grad=True
)
Expand Down Expand Up @@ -754,15 +749,9 @@ def test_correctness_double_backward(
)

weights_grad = weights_torch.grad.detach().cpu().numpy()
if (
i == 0
and hasattr(self, "reorder_weights_e3nn_to_oeq")
and self.reorder_weights_oeq_to_e3nn is not None
):
weights_grad_copy = weights_grad.copy()
self.reorder_weights_oeq_to_e3nn(
weights_grad_copy, weights_grad, not self.config.shared_weights
)
weights_grad = tp.reorder_weights_to_e3nn(
weights_grad, not self.config.shared_weights
)

tensors.append(
(
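The convolution classes gain the same pair of helpers, so the pattern mirrors the TensorProduct one above; a hedged sketch where `conv`, `problem`, `W_e3nn`, `X`, `Y`, and the `rows`/`cols` index arrays are assumed to exist, and the forward signature follows the project README rather than anything in this diff.

# Reorder e3nn-ordered weights once, then run the fused convolution.
W_oeq = conv.reorder_weights_from_e3nn(W_e3nn, has_batch_dim=not problem.shared_weights)
Z = conv.forward(X, Y, W_oeq, rows, cols)

# Weight gradients arrive in oeq order and can be mapped back the same way:
# grad_e3nn = conv.reorder_weights_to_e3nn(grad_oeq, has_batch_dim=not problem.shared_weights)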