In [15]:
import torch
import time

try:
    from torch_scatter import segment_csr as torch_scatter_segment_csr
    from torch_scatter import scatter_softmax as torch_scatter_softmax
    from torch_scatter import scatter as torch_scatter_scatter
    HAS_TORCH_SCATTER = True
except ImportError:
    HAS_TORCH_SCATTER = False
    print("torch_scatter not found. Comparisons will be skipped.")

In [9]:
def benchmark_function(func, *args, num_runs=100):
    """Time ``func(*args)`` and return the mean wall-clock seconds per call.

    ``func`` is invoked ``num_runs`` times. When a CUDA device is available,
    the device is synchronized once before the clock starts and again after
    every call, so asynchronously launched kernels are included in the
    measurement. Without CUDA the synchronization step is skipped.

    Args:
        func: Callable to benchmark.
        *args: Positional arguments forwarded to ``func`` on every call.
        num_runs: Number of timed calls; must be at least 1.

    Returns:
        A ``(average_seconds, result)`` tuple, where ``result`` is the value
        returned by the last call.

    Raises:
        ValueError: If ``num_runs`` is smaller than 1.
    """
    if num_runs < 1:
        raise ValueError("num_runs must be at least 1")

    # torch.cuda.synchronize() raises when no CUDA device is usable, so it is
    # only called when one is present.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)

    sync()
    # perf_counter is monotonic and high resolution, unlike time.time().
    start_time = time.perf_counter()
    for _ in range(num_runs):
        result = func(*args)
        sync()
    end_time = time.perf_counter()
    return (end_time - start_time) / num_runs, result

In [20]:
def segment_csr(src: torch.Tensor, indptr: torch.Tensor, reduce: str = 'sum') -> torch.Tensor:
    """Reduce contiguous row segments of ``src`` described by CSR-style offsets.

    Segment ``i`` consists of the rows ``src[indptr[i]:indptr[i + 1]]`` and is
    reduced along dimension 0.

    Args:
        src: Tensor with at least one dimension; segments run along dim 0.
        indptr: Non-decreasing segment offsets. Size-1 dimensions are squeezed
            away; what remains must be one-dimensional.
        reduce: One of ``'sum'``, ``'mean'``, ``'min'`` or ``'max'``.

    Returns:
        Tensor of shape ``(num_segments, *src.shape[1:])`` with the dtype and
        device of ``src``. Empty segments yield 0 for every reduction. For
        integer inputs, ``'mean'`` uses floor division.

    Raises:
        ValueError: If ``reduce`` is not supported, ``src`` is a scalar,
            ``indptr`` is not one-dimensional after squeezing, or its offsets
            fall outside ``[0, src.shape[0]]``.
    """
    if reduce not in ['sum', 'mean', 'min', 'max']:
        raise ValueError("reduce must be one of 'sum', 'mean', 'min', or 'max'")
    if src.dim() == 0:
        raise ValueError("src must have at least one dimension")

    indptr = indptr.squeeze()
    if indptr.dim() == 0:
        # A single offset squeezes down to a scalar; it describes zero segments.
        indptr = indptr.reshape(1)
    if indptr.dim() != 1:
        raise ValueError("indptr must be one-dimensional after squeezing")

    segment_lengths = indptr[1:] - indptr[:-1]
    num_segments = segment_lengths.numel()

    # Allocate with src's dtype: scatter ops require matching dtypes.
    out = torch.zeros(num_segments, *src.shape[1:], device=src.device, dtype=src.dtype)
    if num_segments == 0:
        return out

    # Segment id of every covered row, e.g. lengths [2, 0, 2] -> [0, 0, 2, 2].
    index = torch.repeat_interleave(torch.arange(num_segments, device=src.device), segment_lengths)

    if index.numel() != src.shape[0]:
        # indptr does not span all of src: reduce only the rows it covers.
        start, end = int(indptr[0]), int(indptr[-1])
        if start < 0 or end > src.shape[0]:
            raise ValueError("indptr offsets must lie within [0, src.shape[0]]")
        src = src[start:end]

    trailing_ones = [1] * (src.dim() - 1)
    expanded_index = index.view(-1, *trailing_ones).expand_as(src)

    if reduce in ('sum', 'mean'):
        out.scatter_add_(0, expanded_index, src)
        if reduce == 'mean':
            # Clamp so empty segments divide by 1 (giving 0) instead of 0 (giving NaN).
            counts = segment_lengths.clamp(min=1).view(-1, *trailing_ones)
            if out.is_floating_point() or out.is_complex():
                out /= counts
            else:
                out.div_(counts, rounding_mode='floor')
    else:
        # scatter_reduce_ names these reductions 'amin'/'amax'. include_self=False
        # ignores the zero fill for populated segments and leaves empty ones at 0.
        out.scatter_reduce_(
            0,
            expanded_index,
            src,
            reduce='amin' if reduce == 'min' else 'amax',
            include_self=False,
        )

    return out

In [25]:
# Fall back to the CPU so the cell still runs on machines without a CUDA device.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Dedicated generator: reproducible inputs without touching the global RNG state.
segment_rng = torch.Generator(device=device).manual_seed(0)

# Test and benchmark segment_csr
src = torch.randn(1_000_000, 64, generator=segment_rng, device=device)
indptr = torch.tensor([0, 200000, 500000, 1000000], device=device)
# Per-row segment ids equivalent to `indptr` (0 for the first segment, 1 for the
# second, ...), built on-device instead of via a million-element Python list.
iidx = torch.repeat_interleave(
    torch.arange(indptr.numel() - 1, device=device),
    indptr[1:] - indptr[:-1],
)

custom_time, custom_result = benchmark_function(segment_csr, src, indptr)
print(f"Custom segment_csr average time: {custom_time:.6f} seconds")

if HAS_TORCH_SCATTER:
    torch_scatter_time, torch_scatter_result = benchmark_function(torch_scatter_segment_csr, src, indptr)
    print(f"torch_scatter segment_csr average time: {torch_scatter_time:.6f} seconds")

    # Result comparison is left disabled, as in the original cell.
    # NOTE(review): if re-enabled, pass the message via `msg=`; the trailing
    # `, "..."` below only builds a tuple and the message is never used.
    # torch.testing.assert_close(custom_result, torch_scatter_result), "segment_csr results are not close!"
    # print("segment_csr results are close to torch_scatter implementation.")

    # Same reduction expressed with explicit per-row indices (sum along dim 0).
    torch_scatter_simple_time, torch_scatter_simple_result = benchmark_function(torch_scatter_scatter, src, iidx, 0)
    print(f"torch_scatter scatter average time: {torch_scatter_simple_time:.6f} seconds")

    # torch.testing.assert_close(custom_result, torch_scatter_simple_result), "segment_csr results are not close!"
    # print("scatter results are close to torch_scatter implementation.")


Custom segment_csr average time: 0.010329 seconds
torch_scatter segment_csr average time: 0.035853 seconds
torch_scatter scatter average time: 0.009620 seconds


In [26]:
def scatter_softmax(src: torch.Tensor, index: torch.Tensor, dim_size: int = None) -> torch.Tensor:
    """Softmax over groups of ``src`` entries that share the same ``index`` value.

    Groups are formed along dimension 0. Each value is shifted by the maximum
    of its own group before exponentiation, which keeps the computation
    numerically stable even when every value in a group is very negative.

    Args:
        src: Floating point tensor of values.
        index: Integer tensor of group ids. Either the same shape as ``src``,
            or 1-D with ``src.shape[0]`` entries, in which case it is
            broadcast over the trailing dimensions of ``src``.
        dim_size: Number of groups. Defaults to ``index.max() + 1``.

    Returns:
        Tensor shaped like ``src`` whose entries sum to 1 within each group.

    Raises:
        ValueError: If ``src`` is not a floating point tensor.
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_softmax` can only be computed over tensors with floating point data types.')

    if src.numel() == 0:
        # Nothing to normalize, and index.max() would fail on an empty tensor.
        return torch.empty_like(src)

    if dim_size is None:
        dim_size = index.max().item() + 1

    if index.dim() == 1 and src.dim() > 1:
        # Broadcast per-row group ids across the trailing dimensions.
        index = index.view(-1, *([1] * (src.dim() - 1))).expand_as(src)

    group_shape = (dim_size, *src.shape[1:])

    # include_self=False makes the result the true group maximum; with the
    # default, the zero fill would take part and cap the shift at 0, letting
    # exp() underflow to 0/0 for groups whose values are all very negative.
    max_value_per_index = torch.zeros(group_shape, device=src.device, dtype=src.dtype)
    max_value_per_index.scatter_reduce_(0, index, src, reduce='amax', include_self=False)
    max_per_src_element = max_value_per_index.gather(0, index)

    recentered_scores = src - max_per_src_element
    recentered_scores_exp = recentered_scores.exp()

    sum_per_index = torch.zeros(group_shape, device=src.device, dtype=src.dtype)
    sum_per_index.scatter_add_(0, index, recentered_scores_exp)
    normalizing_constants = sum_per_index.gather(0, index)

    return recentered_scores_exp / normalizing_constants

In [29]:

# Test and benchmark scatter_softmax
NUM_ELEMENTS = 1_000_000  # number of values to normalize
NUM_GROUPS = 100  # group ids are drawn from [0, NUM_GROUPS)
# Dedicated generator: reproducible inputs without touching the global RNG state.
softmax_rng = torch.Generator(device=device).manual_seed(0)

src = torch.randn(NUM_ELEMENTS, generator=softmax_rng, device=device)
index = torch.randint(0, NUM_GROUPS, (NUM_ELEMENTS,), generator=softmax_rng, device=device)

custom_time, custom_result = benchmark_function(scatter_softmax, src, index)
print(f"Custom scatter_softmax average time: {custom_time:.6f} seconds")

if HAS_TORCH_SCATTER:
    torch_scatter_time, torch_scatter_result = benchmark_function(torch_scatter_softmax, src, index)
    print(f"torch_scatter scatter_softmax average time: {torch_scatter_time:.6f} seconds")

    # Result comparison is left disabled, as in the original cell.
    # assert torch.allclose(custom_result, torch_scatter_result, atol=1e-4), "scatter_softmax results are not close!"
    # print("scatter_softmax results are close to torch_scatter implementation.")

Custom scatter_softmax average time: 0.010368 seconds
torch_scatter scatter_softmax average time: 0.007084 seconds


In [28]:
# `torch_scatter_result` is only defined when torch_scatter could be imported,
# so show just the custom result's shape otherwise instead of raising NameError.
(custom_result.shape, torch_scatter_result.shape) if HAS_TORCH_SCATTER else (custom_result.shape,)

(torch.Size([1000000]), torch.Size([1000000]))