In [1]:
import os
import torch
from torch.utils.cpp_extension import load


# NOTE(review): device-side assertions are a PyTorch *build-time* option
# (TORCH_USE_CUDA_DSA when compiling PyTorch itself). Setting this variable at
# runtime, after `import torch`, most likely has no effect — confirm. For
# debugging kernel faults, CUDA_LAUNCH_BLOCKING=1 is the usual runtime switch.
os.environ["TORCH_USE_CUDA_DSA"] = "1"

# JIT-compile the rspmm C++/CUDA extension from the `source/` folder located
# under the current working directory (i.e. wherever the notebook is launched).
print("Load rspmm extension. This may take a while...")
path = os.getcwd()
rspmm_sources = [os.path.join(path, rel) for rel in ("source/rspmm.cpp", "source/rspmm.cu")]
rspmm = load(
    name="rspmm",
    sources=rspmm_sources,
    # Preprocessor macro consumed by rspmm.cpp — presumably enables the CUDA op path; verify.
    extra_cflags=["-DCUDA_OP"],
    # nvcc flag allowing constexpr host functions to be called from device code.
    extra_cuda_cflags=["--expt-relaxed-constexpr"],
)

# Test graph.
# Default: a 3-node triangle whose undirected edges are each stored as two
# directed edges, with the two directions carrying different edge types.
# Set USE_MINIMAL_GRAPH = True for a 2-node / single-undirected-edge smoke test.
USE_MINIMAL_GRAPH = False

if USE_MINIMAL_GRAPH:
    edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)  # row 0: node_in, row 1: node_out
    edge_type = torch.tensor([0, 1], dtype=torch.long)
    edge_weight = torch.tensor([1.0, 1.0], dtype=torch.float32)
    edge_attr = torch.tensor([[0.1], [0.2]], dtype=torch.float32)  # one attribute column per edge
    relation = torch.tensor([[1.0, 1.0], [1.0, 1.0]], dtype=torch.float32)  # placeholder
    input_feat = torch.tensor([[0.1, 0.2], [0.3, 0.4]], dtype=torch.float32)  # node features (dim=2)
else:
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 0], [1, 0, 2, 1, 0, 2]], dtype=torch.long)  # row 0: node_in, row 1: node_out
    edge_type = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.long)  # edge type of each directed edge
    edge_weight = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], dtype=torch.float32)  # per-edge weights
    edge_attr = torch.tensor(
        [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8], [0.9, 1.0], [1.1, 1.2]], dtype=torch.float32
    )  # two attribute columns per edge
    # Placeholder, all ones; presumably one row per edge type — confirm against rspmm.cu.
    relation = torch.tensor([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]], dtype=torch.float32)
    input_feat = torch.tensor(
        [[0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2], [1.3, 1.4, 1.5, 1.6]], dtype=torch.float32
    )  # node features (dim=4)

# rspmm_add_mul_forward_cuda needs CUDA tensors: fail early with a clear message
# instead of an obscure error from the device transfer or the kernel call.
if not torch.cuda.is_available():
    raise RuntimeError("This rspmm test requires a CUDA device, but torch.cuda.is_available() is False.")
device = torch.device("cuda")
edge_index, edge_type, edge_weight, edge_attr, relation, input_feat = (
    t.to(device) for t in (edge_index, edge_type, edge_weight, edge_attr, relation, input_feat)
)

# Reorder edges by the composite (node_in, node_out) key used in
# generalized_rspmm. Presumably the kernel expects edges grouped by node_in in
# ascending order (CSR-style) — confirm against rspmm.cu.
node_in, node_out = edge_index
key = node_in * (node_out.max() + 1) + node_out  # unique for every (node_in, node_out) pair
order = torch.argsort(key)
sorted_edge_index = edge_index.index_select(1, order)
sorted_edge_type = edge_type.index_select(0, order)
sorted_edge_weight = edge_weight.index_select(0, order)
sorted_edge_attr = edge_attr[order, :]  # explicit 2-D indexing: one attribute row per edge

# Run the custom forward kernel on the sorted edge list and copy the result to
# the host so it can be compared with the CPU reference computation.
output = rspmm.rspmm_add_mul_forward_cuda(
    sorted_edge_index,
    sorted_edge_type,
    sorted_edge_weight,
    sorted_edge_attr,
    relation,
    input_feat,
).cpu()

# Reference computation on the CPU, used to validate the kernel output.
# Each directed edge (u, v) adds weight * attr * input_feat[v] to row u, where
# feature dim d reads edge_attr column d % num_attr_cols (attribute columns are
# cycled across the feature dimension).
# NOTE(review): this reference ignores `relation` and `edge_type`; it only
# matches the kernel if their contribution is neutral for this data (here
# `relation` is all ones) — confirm against source/rspmm.cu.
# Vectorized (one gather + index_add_) instead of a per-element Python loop that
# forced a GPU->CPU copy and device sync for every scalar.
ref_src, ref_dst = sorted_edge_index.cpu()
ref_weight = sorted_edge_weight.cpu()
ref_attr = sorted_edge_attr.cpu()
ref_input = input_feat.cpu()

attr_cols = torch.arange(ref_input.size(1)) % ref_attr.size(1)
ref_messages = ref_weight[:, None] * ref_attr[:, attr_cols] * ref_input[ref_dst]
expected_output = torch.zeros_like(ref_input).index_add_(0, ref_src, ref_messages)

# Show both results for manual inspection.
print("Computed output:")
print(output)

print("Expected output:")
print(expected_output)

# Validation. torch.testing.assert_close checks shape, dtype and device strictly
# (torch.allclose broadcasts, so a wrongly shaped kernel output could pass) and
# reports the largest absolute/relative mismatch on failure. The tolerances are
# the same as before (atol=1e-6, allclose's default rtol=1e-5).
torch.testing.assert_close(
    output,
    expected_output,
    rtol=1e-5,
    atol=1e-6,
    msg=lambda details: f"Output mismatch detected!\n{details}",
)
print("Output computation test passed successfully.")


Load rspmm extension. This may take a while...
block_ptr: 2, offset_ptr: 0, global_ptr: 0
block_ptr: 4, offset_ptr: 0, global_ptr: 0
block_ptr: 0, offset_ptr: 0, global_ptr: 0
block_ptr: 2, offset_ptr: 0, global_ptr: 1
block_ptr: 4, offset_ptr: 0, global_ptr: 1
block_ptr: 0, offset_ptr: 0, global_ptr: 1
Computed output:
tensor([[1.5200, 1.8800, 1.7600, 2.1600],
        [0.8000, 1.0800, 0.9600, 1.2800],
        [1.0800, 1.4000, 1.4000, 1.7600]])
Expected output:
tensor([[1.5200, 1.8800, 1.7600, 2.1600],
        [0.8000, 1.0800, 0.9600, 1.2800],
        [1.0800, 1.4000, 1.4000, 1.7600]])
Output computation test passed successfully.
