In [1]:
import os
import torch
from torch.utils.cpp_extension import load

# NOTE(review): TORCH_USE_CUDA_DSA looks like a flag consumed when PyTorch itself
# is compiled; setting it at runtime, after `import torch`, probably does not turn
# on device-side assertions. Kept unchanged — confirm whether it is needed at all.
os.environ["TORCH_USE_CUDA_DSA"] = "1"

# JIT-compile the rspmm C++/CUDA extension. Sources are resolved relative to the
# current working directory, so the notebook must be started from the project root.
print("Load rspmm extension. This may take a while...")
path = os.getcwd()
rspmm_sources = [os.path.join(path, rel) for rel in ("source/rspmm.cpp", "source/rspmm.cu")]
rspmm = load(
    name="rspmm",
    sources=rspmm_sources,
    # Defines the CUDA_OP macro for the host C++ compiler
    # (presumably gates the CUDA bindings in rspmm.cpp — confirm).
    extra_cflags=["-DCUDA_OP"],
    # nvcc: allow calling constexpr host functions from device code.
    extra_cuda_cflags=["--expt-relaxed-constexpr"],
)

# Directed test graph with 8 nodes and 16 edges.
# Column j of `edge_index` is the directed edge src_nodes[j] -> dst_nodes[j].
src_nodes = [0, 1, 2, 3, 4, 5, 6, 7, 0, 2, 4, 6, 1, 3, 5, 7]
dst_nodes = [1, 2, 3, 4, 5, 6, 7, 0, 2, 4, 6, 0, 3, 5, 7, 1]
edge_index = torch.tensor([src_nodes, dst_nodes], dtype=torch.long)  # shape (2, 16)

# Per-edge relation type (0 or 1) and scalar weight (1.0, 1.1, ..., 2.5).
edge_type = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0], dtype=torch.long)
edge_weight = torch.tensor(
    [1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3, 2.4, 2.5],
    dtype=torch.float32,
)

# Edge attributes (16 x 2) and node features (8 x 4) hold the same 32 values
# 0.1, 0.2, ..., 3.2, laid out row-major in two different shapes.
# Note the attribute width (2) differs from the feature width (4).
ramp_values = [
    0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8,
    0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6,
    1.7, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3, 2.4,
    2.5, 2.6, 2.7, 2.8, 2.9, 3.0, 3.1, 3.2,
]
edge_attr = torch.tensor(ramp_values, dtype=torch.float32).reshape(16, 2)
input_feat = torch.tensor(ramp_values, dtype=torch.float32).reshape(8, 4)
# Copy every graph tensor to the GPU; the CUDA kernels below read device memory.
device = torch.device("cuda")
edge_index, edge_type, edge_weight, edge_attr, input_feat = (
    tensor.to(device)
    for tensor in (edge_index, edge_type, edge_weight, edge_attr, input_feat)
)

# Reorder edges lexicographically by (source, target), mirroring the ordering
# generalized_rspmm applies before calling the kernels (the kernels presumably
# require it — confirm against generalized_rspmm).
node_in, node_out = edge_index
num_targets = node_out.max() + 1
# Pack each (src, dst) pair into one integer; because dst < num_targets, sorting
# by this key is the same as sorting by src first and dst second.
key = node_in * num_targets + node_out
order = key.argsort()

# Apply the same permutation to every per-edge tensor.
sorted_edge_index = edge_index.index_select(1, order)
sorted_edge_type = edge_type.index_select(0, order)
sorted_edge_weight = edge_weight.index_select(0, order)
sorted_edge_attr = edge_attr.index_select(0, order)

# Forward pass through the custom CUDA kernel on the sorted edge list.
output = rspmm.rspmm_add_mul_forward_cuda(
    sorted_edge_index, sorted_edge_type, sorted_edge_weight, sorted_edge_attr, input_feat
)

# Backward pass through the custom CUDA kernel with an all-ones upstream gradient
# (i.e. the gradient of output.sum()). `torch.ones_like` already allocates on
# `output`'s device, so no extra `.to(device)` is needed.
output_grad = torch.ones_like(output)
weight_grad_cuda, edge_attr_grad_cuda, input_grad_cuda = rspmm.rspmm_add_mul_backward_cuda(
    sorted_edge_index, sorted_edge_type, sorted_edge_weight, sorted_edge_attr, input_feat, output, output_grad
)


# Copy the CUDA gradients, plus every kernel input and the forward result, to host
# memory so the CPU kernel and the manual reference can consume them.
input_grad_cuda, edge_attr_grad_cuda = (g.cpu() for g in (input_grad_cuda, edge_attr_grad_cuda))
sorted_edge_index_cpu, sorted_edge_type_cpu, sorted_edge_weight_cpu, sorted_edge_attr_cpu = (
    t.cpu() for t in (sorted_edge_index, sorted_edge_type, sorted_edge_weight, sorted_edge_attr)
)
input_feat_cpu, output_cpu, output_grad_cpu = (t.cpu() for t in (input_feat, output, output_grad))


# Manual reference gradients, computed on the host with an explicit loop.
# Assumed forward model (inferred from this reference, not from the kernel
# source — confirm against rspmm.cu):
#   output[u, d] = sum over sorted edges e = (u, v) of
#                  weight[e] * edge_attr[e, d % num_attr] * input_feat[v, d]
# Each accumulation below is a partial derivative of that expression, scaled by
# the upstream gradient. All operands are the host copies, so the loop never
# triggers a per-element device-to-host transfer.
num_feat = input_feat_cpu.size(1)
num_attr = sorted_edge_attr_cpu.size(1)

input_grad_manual = torch.zeros_like(input_feat_cpu)
edge_attr_grad_manual = torch.zeros_like(sorted_edge_attr_cpu)
weight_grad_manual = torch.zeros_like(sorted_edge_weight_cpu)

for idx, (u, v) in enumerate(sorted_edge_index_cpu.t()):
    weight = sorted_edge_weight_cpu[idx]
    for d in range(num_feat):
        a = d % num_attr  # attribute column reused by feature dims d, d + num_attr, ...
        attr = sorted_edge_attr_cpu[idx, a]
        in_feat = input_feat_cpu[v, d]
        out_grad = output_grad_cpu[u, d]

        input_grad_manual[v, d] += out_grad * weight * attr
        edge_attr_grad_manual[idx, a] += out_grad * weight * in_feat
        weight_grad_manual[idx] += out_grad * attr * in_feat

# Backward pass using the CPU implementation of the same kernel.
weight_grad_cpu, edge_attr_grad_cpu, input_grad_cpu = rspmm.rspmm_add_mul_backward_cpu(
    sorted_edge_index_cpu, sorted_edge_type_cpu, sorted_edge_weight_cpu, sorted_edge_attr_cpu,
    input_feat_cpu, output_cpu, output_grad_cpu
)


def assert_grads_close(actual: torch.Tensor, expected: torch.Tensor, label: str, atol: float = 1e-6) -> None:
    """Raise AssertionError unless `actual` and `expected` match.

    Both tensors are moved to the host first. Shapes must be identical, because
    torch.allclose would otherwise broadcast and could hide a shape bug. Values
    are compared with torch.allclose(atol=atol) and its default rtol. On failure
    the message names the comparison and reports the largest absolute deviation.
    """
    actual, expected = actual.cpu(), expected.cpu()
    if actual.shape != expected.shape:
        raise AssertionError(
            f"{label}: shape mismatch {tuple(actual.shape)} vs {tuple(expected.shape)}"
        )
    if not torch.allclose(actual, expected, atol=atol):
        max_diff = (actual - expected).abs().max().item()
        raise AssertionError(f"{label} mismatch (max abs diff {max_diff:.3e})")


gradient_checks = [
    (input_grad_cuda, input_grad_manual, "CUDA vs Manual: Input gradients"),
    (edge_attr_grad_cuda, edge_attr_grad_manual, "CUDA vs Manual: Edge attribute gradients"),
    (weight_grad_cuda, weight_grad_manual, "CUDA vs Manual: Edge weight gradients"),
    (input_grad_cpu, input_grad_manual, "CPU vs Manual: Input gradients"),
    (edge_attr_grad_cpu, edge_attr_grad_manual, "CPU vs Manual: Edge attribute gradients"),
    (weight_grad_cpu, weight_grad_manual, "CPU vs Manual: Edge weight gradients"),
    (input_grad_cpu, input_grad_cuda, "CPU vs Cuda: Input gradients"),
    (edge_attr_grad_cpu, edge_attr_grad_cuda, "CPU vs Cuda: Edge attribute gradients"),
    (weight_grad_cpu, weight_grad_cuda, "CPU vs Cuda: Edge weight gradients"),
]
for actual, expected, label in gradient_checks:
    assert_grads_close(actual, expected, label)
print("All gradient checks passed successfully.")


Load rspmm extension. This may take a while...
All gradient checks passed successfully.
