Potential Integer Overflows, Out-of-Bounds Access, and Data Races in rspmm.cu #33

@molly-ting

I'm performing static analysis on CUDA programs and have identified potential issues in rspmm.cu, including integer overflows, out-of-bounds memory accesses, and data races.

In function rspmm_forward_cuda:

int64_t num_row = input.size(0);
int64_t dim = input.size(1);
Tensor row_ptr = ind2ptr(row_ind, num_row);
const int num_dim_block = (dim + dim_per_block * kCoarseningFactor - 1) / (dim_per_block * kCoarseningFactor);

Overflow Risk 1: ind2ptr(row_ind, num_row) narrows num_row, a 64-bit value, to 32 bits internally. From analyzing the model code, num_row is derived from the input and can be very large, so the narrowing conversion can truncate it.

Overflow Risk 2: The num_dim_block computation likewise narrows a 64-bit result to a 32-bit int. dim is also derived from the input and can be very large, so the stored value may be truncated.
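To make the two narrowing risks concrete, here is a minimal host-side sketch of a checked conversion. The helper names checked_narrow_to_int and checked_num_dim_block are hypothetical and do not exist in the repository; the sketch also assumes dim_per_block and kCoarseningFactor are positive. It only illustrates one way to fail loudly instead of truncating, not how the project should necessarily fix it.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <stdexcept>

// Hypothetical helper (not part of the repository): narrow a 64-bit value to
// int, throwing instead of silently truncating.
inline int checked_narrow_to_int(int64_t value) {
    if (value < std::numeric_limits<int>::min() ||
        value > std::numeric_limits<int>::max()) {
        throw std::overflow_error("value does not fit in a 32-bit int");
    }
    return static_cast<int>(value);
}

// Ceiling division done entirely in 64-bit arithmetic, narrowed with a check.
// Mirrors the shape of the num_dim_block expression quoted above.
inline int checked_num_dim_block(int64_t dim, int64_t dim_per_block,
                                 int64_t coarsening_factor) {
    // Assumption: both factors are positive, so the divisor is never zero.
    assert(dim_per_block > 0 && coarsening_factor > 0);
    const int64_t span = dim_per_block * coarsening_factor;
    const int64_t blocks = (dim + span - 1) / span;
    return checked_narrow_to_int(blocks);
}
```

The same check could guard num_row before it is passed to a routine that stores it in a 32-bit integer.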

In function rspmm_forward_out_cuda:

int64_t d_start = blockIdx.y * warpSize * kCoarseningFactor + threadIdx.x;
...
for (int64_t block_ptr = ptr_start; block_ptr < ptr_end; block_ptr += warpSize) {
    int64_t ptr = block_ptr + threadIdx.x;
    if (ptr < ptr_end) {
        col_ind_buf[threadIdx.x] = col_ind[ptr];
        layer_ind_buf[threadIdx.x] = layer_ind[ptr];
        weight_buf[threadIdx.x] = weight[ptr];
    }
}
...
for (int64_t i = 0; i < kCoarseningFactor; i++) {
    int64_t d = d_start + i * warpSize;
    if (d >= dim) break;
    scalar_t x = BinaryOp::forward(relation[layer * dim + d], input[col * dim + d]);
    scalar_t y = w * x;
    out[i] = NaryOp::forward(out[i], y);
}

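Overflow Risk: The multiplication in the d_start calculation (blockIdx.y * warpSize * kCoarseningFactor) may overflow, because num_dim_block, which bounds blockIdx.y, could be very large. blockIdx.y and warpSize are 32-bit built-ins, so unless kCoarseningFactor has a 64-bit type the product is evaluated in 32 bits before it is widened to int64_t.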
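The difference between multiplying first and widening first can be shown on the host. In this sketch the parameters are stand-ins for blockIdx.y, warpSize, kCoarseningFactor and threadIdx.x, and kCoarseningFactor is assumed to be an int; the function names are made up for illustration. The large block index is chosen only to trigger wraparound. CUDA limits gridDim.y to 65535, so whether such a value is reachable in practice depends on the launch configuration and on kCoarseningFactor.

```cpp
#include <cstdint>

// Same shape as the kernel line: the product is formed in 32-bit unsigned
// arithmetic and wraps modulo 2^32 before the assignment widens it.
inline int64_t d_start_narrow(uint32_t block_y, int warp_size,
                              int coarsening, uint32_t thread_x) {
    return block_y * warp_size * coarsening + thread_x;
}

// Widening the first operand makes every later operation 64-bit.
inline int64_t d_start_wide(uint32_t block_y, int warp_size,
                            int coarsening, uint32_t thread_x) {
    return static_cast<int64_t>(block_y) * warp_size * coarsening + thread_x;
}
```

For small block indices the two functions agree; they diverge only once the 32-bit product exceeds 2^32 - 1.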

Out-of-Bounds Access:
col_ind[ptr], layer_ind[ptr], and weight[ptr] may be read out of bounds, because ptr_end depends on the contents of edge_index.
Similarly, relation[layer * dim + d] and input[col * dim + d] may read beyond the allocated memory.
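One possible mitigation is to validate the index data on the host before launching the kernel. The sketch below is an assumption-laden illustration: indices_in_bounds is a hypothetical function, plain std::vector stands in for the tensors, and num_relation (the number of rows in relation) is a name introduced here because the excerpt does not show it. It assumes input has num_row rows, so a valid col lies in [0, num_row).

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Returns true when every index the kernel excerpt dereferences stays in range.
inline bool indices_in_bounds(const std::vector<int64_t>& row_ptr,
                              const std::vector<int64_t>& col_ind,
                              const std::vector<int64_t>& layer_ind,
                              int64_t num_row, int64_t num_relation) {
    const int64_t num_edge = static_cast<int64_t>(col_ind.size());
    if (static_cast<int64_t>(layer_ind.size()) != num_edge) return false;

    // ptr_start / ptr_end come from row_ptr, so it must be non-decreasing
    // and must never point past the edge arrays.
    if (row_ptr.empty() || row_ptr.front() < 0 || row_ptr.back() > num_edge) {
        return false;
    }
    for (std::size_t i = 1; i < row_ptr.size(); ++i) {
        if (row_ptr[i] < row_ptr[i - 1]) return false;
    }

    // col and layer are multiplied by dim to index input and relation.
    for (int64_t e = 0; e < num_edge; ++e) {
        if (col_ind[e] < 0 || col_ind[e] >= num_row) return false;
        if (layer_ind[e] < 0 || layer_ind[e] >= num_relation) return false;
    }
    return true;
}
```

A check like this costs one pass over the edge list, which may or may not be acceptable on the hot path; it could also be enabled only in debug builds.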

Data Race Risk: Multiple threads may access col_ind_buf[threadIdx.x], layer_ind_buf[threadIdx.x], and weight_buf[threadIdx.x] concurrently, causing potential data races. Threads in the same block can share the same threadIdx.x, so whether they touch the same buffer slots depends on the launch configuration: the grid is dim3(num_row_block, num_dim_block) and the block is dim3(dim_per_block, row_per_block).
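The condition for overlap can be illustrated with a small host-side simulation of one thread block. The excerpt does not show how the buffers are declared, so this is only a sketch of the two possibilities: indexing a shared buffer by threadIdx.x alone, versus giving each row of threads its own slice. max_writers_per_slot and offset_by_row are hypothetical names introduced for the illustration.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Simulates a block of dim3(dim_per_block, row_per_block) threads and returns
// the largest number of threads that write to the same buffer slot.
inline int max_writers_per_slot(int dim_per_block, int row_per_block,
                                bool offset_by_row) {
    std::vector<int> writers(
        static_cast<std::size_t>(dim_per_block) * row_per_block, 0);
    int worst = 0;
    for (int ty = 0; ty < row_per_block; ++ty) {      // stand-in for threadIdx.y
        for (int tx = 0; tx < dim_per_block; ++tx) {  // stand-in for threadIdx.x
            // Either index by tx alone, or give each row its own slice.
            const int slot = offset_by_row ? ty * dim_per_block + tx : tx;
            worst = std::max(worst, ++writers[slot]);
        }
    }
    return worst;
}
```

A result greater than 1 means several threads target one slot. If the buffers in rspmm.cu are already offset per row of threads, the overlap described above would not occur, so the declaration of the buffers is worth checking.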
