Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 109 additions & 64 deletions cpp/src/neighbors/detail/nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "ann_utils.cuh"
#include "cagra/device_common.hpp"
#include "cuvs/distance/distance.h"
#include "nn_descent_gnnd.hpp"

#include <cuvs/distance/distance.hpp>
Expand Down Expand Up @@ -285,6 +286,10 @@ RAFT_KERNEL preprocess_data_kernel(
} else if (metric == cuvs::distance::DistanceType::CosineExpanded) {
output_data[list_id * dim + idx] =
(float)input_data[(size_t)blockIdx.x * dim + idx] / sqrt(l2_norm);
} else if (metric == cuvs::distance::DistanceType::BitwiseHamming) {
int idx_for_byte = list_id * dim + idx; // uint8 or int8 data
uint8_t* output_bytes = reinterpret_cast<uint8_t*>(output_data);
output_bytes[idx_for_byte] = input_data[idx_for_byte];
} else { // L2Expanded or L2SqrtExpanded
output_data[list_id * dim + idx] = input_data[(size_t)blockIdx.x * dim + idx];
if (idx == 0) { l2_norms[list_id] = l2_norm; }
Expand Down Expand Up @@ -588,39 +593,44 @@ __launch_bounds__(BLOCK_SIZE)
wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;
wmma::fill_fragment(c_frag, 0.0);
for (int step = 0; step < raft::ceildiv(data_dim, TILE_COL_WIDTH); step++) {
int num_load_elems = (step == raft::ceildiv(data_dim, TILE_COL_WIDTH) - 1)
? data_dim - step * TILE_COL_WIDTH
: TILE_COL_WIDTH;
if (metric != cuvs::distance::DistanceType::BitwiseHamming) {
wmma::fill_fragment(c_frag, 0.0);

for (int step = 0; step < raft::ceildiv(data_dim, TILE_COL_WIDTH); step++) {
int num_load_elems = (step == raft::ceildiv(data_dim, TILE_COL_WIDTH) - 1)
? data_dim - step * TILE_COL_WIDTH
: TILE_COL_WIDTH;
#pragma unroll
for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) {
int idx = i * num_warps + warp_id;
if (idx < list_new_size) {
size_t neighbor_id = new_neighbors[idx];
size_t idx_in_data = neighbor_id * data_dim;
load_vec(s_nv[idx],
data + idx_in_data + step * TILE_COL_WIDTH,
num_load_elems,
TILE_COL_WIDTH,
lane_id);
for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) {
int idx = i * num_warps + warp_id;
if (idx < list_new_size) {
size_t neighbor_id = new_neighbors[idx];
size_t idx_in_data = neighbor_id * data_dim;
load_vec(s_nv[idx],
data + idx_in_data + step * TILE_COL_WIDTH,
num_load_elems,
TILE_COL_WIDTH,
lane_id);
}
}
}
__syncthreads();

for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) {
wmma::load_matrix_sync(a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, TILE_COL_WIDTH + APAD);
wmma::load_matrix_sync(b_frag, s_nv[warp_id_x * WMMA_N] + i * WMMA_K, TILE_COL_WIDTH + BPAD);
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
__syncthreads();

for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) {
wmma::load_matrix_sync(
a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, TILE_COL_WIDTH + APAD);
wmma::load_matrix_sync(
b_frag, s_nv[warp_id_x * WMMA_N] + i * WMMA_K, TILE_COL_WIDTH + BPAD);
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
__syncthreads();
}
}
}

wmma::store_matrix_sync(
s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N,
c_frag,
SKEWED_MAX_NUM_BI_SAMPLES,
wmma::mem_row_major);
wmma::store_matrix_sync(
s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N,
c_frag,
SKEWED_MAX_NUM_BI_SAMPLES,
wmma::mem_row_major);
}
__syncthreads();

for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) {
Expand All @@ -632,6 +642,16 @@ __launch_bounds__(BLOCK_SIZE)
s_distances[i] = -s_distances[i];
} else if (metric == cuvs::distance::DistanceType::CosineExpanded) {
s_distances[i] = 1.0 - s_distances[i];
} else if (metric == cuvs::distance::DistanceType::BitwiseHamming) {
s_distances[i] = 0.0;
int n1 = new_neighbors[row_id];
int n2 = new_neighbors[col_id];
// TODO: https://github.com/rapidsai/cuvs/issues/1127
const uint8_t* data_n1 = reinterpret_cast<const uint8_t*>(data) + n1 * data_dim;
const uint8_t* data_n2 = reinterpret_cast<const uint8_t*>(data) + n2 * data_dim;
for (int d = 0; d < data_dim; d++) {
s_distances[i] += __popc(static_cast<uint32_t>(data_n1[d] ^ data_n2[d]) & 0xff);
}
} else { // L2Expanded or L2SqrtExpanded
s_distances[i] =
l2_norms[new_neighbors[row_id]] + l2_norms[new_neighbors[col_id]] - 2.0 * s_distances[i];
Expand Down Expand Up @@ -659,56 +679,60 @@ __launch_bounds__(BLOCK_SIZE)

__syncthreads();

wmma::fill_fragment(c_frag, 0.0);
for (int step = 0; step < raft::ceildiv(data_dim, TILE_COL_WIDTH); step++) {
int num_load_elems = (step == raft::ceildiv(data_dim, TILE_COL_WIDTH) - 1)
? data_dim - step * TILE_COL_WIDTH
: TILE_COL_WIDTH;
if (TILE_COL_WIDTH < data_dim) {
if (metric != cuvs::distance::DistanceType::BitwiseHamming) {
wmma::fill_fragment(c_frag, 0.0);
for (int step = 0; step < raft::ceildiv(data_dim, TILE_COL_WIDTH); step++) {
int num_load_elems = (step == raft::ceildiv(data_dim, TILE_COL_WIDTH) - 1)
? data_dim - step * TILE_COL_WIDTH
: TILE_COL_WIDTH;
if (TILE_COL_WIDTH < data_dim) {
#pragma unroll
for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) {
int idx = i * num_warps + warp_id;
if (idx < list_new_size) {
size_t neighbor_id = new_neighbors[idx];
size_t idx_in_data = neighbor_id * data_dim;
load_vec(s_nv[idx],
data + idx_in_data + step * TILE_COL_WIDTH,
num_load_elems,
TILE_COL_WIDTH,
lane_id);
}
}
}
#pragma unroll
for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) {
int idx = i * num_warps + warp_id;
if (idx < list_new_size) {
size_t neighbor_id = new_neighbors[idx];
if (idx < list_old_size) {
size_t neighbor_id = old_neighbors[idx];
size_t idx_in_data = neighbor_id * data_dim;
load_vec(s_nv[idx],
load_vec(s_ov[idx],
data + idx_in_data + step * TILE_COL_WIDTH,
num_load_elems,
TILE_COL_WIDTH,
lane_id);
}
}
}
#pragma unroll
for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) {
int idx = i * num_warps + warp_id;
if (idx < list_old_size) {
size_t neighbor_id = old_neighbors[idx];
size_t idx_in_data = neighbor_id * data_dim;
load_vec(s_ov[idx],
data + idx_in_data + step * TILE_COL_WIDTH,
num_load_elems,
TILE_COL_WIDTH,
lane_id);
__syncthreads();

for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) {
wmma::load_matrix_sync(
a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, TILE_COL_WIDTH + APAD);
wmma::load_matrix_sync(
b_frag, s_ov[warp_id_x * WMMA_N] + i * WMMA_K, TILE_COL_WIDTH + BPAD);
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
__syncthreads();
}
}
__syncthreads();

for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) {
wmma::load_matrix_sync(a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, TILE_COL_WIDTH + APAD);
wmma::load_matrix_sync(b_frag, s_ov[warp_id_x * WMMA_N] + i * WMMA_K, TILE_COL_WIDTH + BPAD);
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
__syncthreads();
}
wmma::store_matrix_sync(
s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N,
c_frag,
SKEWED_MAX_NUM_BI_SAMPLES,
wmma::mem_row_major);
__syncthreads();
}

wmma::store_matrix_sync(
s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N,
c_frag,
SKEWED_MAX_NUM_BI_SAMPLES,
wmma::mem_row_major);
__syncthreads();

for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) {
int row_id = i % SKEWED_MAX_NUM_BI_SAMPLES;
int col_id = i / SKEWED_MAX_NUM_BI_SAMPLES;
Expand All @@ -717,6 +741,16 @@ __launch_bounds__(BLOCK_SIZE)
s_distances[i] = -s_distances[i];
} else if (metric == cuvs::distance::DistanceType::CosineExpanded) {
s_distances[i] = 1.0 - s_distances[i];
} else if (metric == cuvs::distance::DistanceType::BitwiseHamming) {
s_distances[i] = 0.0;
int n1 = old_neighbors[row_id];
int n2 = new_neighbors[col_id];
// TODO: https://github.com/rapidsai/cuvs/issues/1127
const uint8_t* data_n1 = reinterpret_cast<const uint8_t*>(data) + n1 * data_dim;
const uint8_t* data_n2 = reinterpret_cast<const uint8_t*>(data) + n2 * data_dim;
for (int d = 0; d < data_dim; d++) {
s_distances[i] += __popc(static_cast<uint32_t>(data_n1[d] ^ data_n2[d]) & 0xff);
}
} else { // L2Expanded or L2SqrtExpanded
s_distances[i] =
l2_norms[old_neighbors[row_id]] + l2_norms[new_neighbors[col_id]] - 2.0 * s_distances[i];
Expand Down Expand Up @@ -980,7 +1014,11 @@ GNND<Data_t, Index_t>::GNND(raft::resources const& res, const BuildConfig& build
nrow_(build_config.max_dataset_size),
ndim_(build_config.dataset_dim),
d_data_{raft::make_device_matrix<__half, size_t, raft::row_major>(
res, nrow_, build_config.dataset_dim)},
res,
nrow_,
build_config.metric == cuvs::distance::DistanceType::BitwiseHamming
? (build_config.dataset_dim + 1) / 2
: build_config.dataset_dim)},
l2_norms_{raft::make_device_vector<DistData_t, size_t>(res, 0)},
graph_buffer_{
raft::make_device_matrix<ID_t, size_t, raft::row_major>(res, nrow_, DEGREE_ON_DEVICE)},
Expand Down Expand Up @@ -1071,12 +1109,19 @@ void GNND<Data_t, Index_t>::build(Data_t* data,
{
using input_t = typename std::remove_const<Data_t>::type;

if (build_config_.metric == cuvsDistanceType::BitwiseHamming &&
!(std::is_same_v<input_t, uint8_t> || std::is_same_v<input_t, int8_t>)) {
RAFT_FAIL(
"Data type needs to be int8 or uint8 for NN Descent to run with BitwiseHamming distance.");
}

cudaStream_t stream = raft::resource::get_cuda_stream(res);
nrow_ = nrow;
graph_.nrow = nrow;
graph_.bloom_filter.set_nrow(nrow);
update_counter_ = 0;
graph_.h_graph = (InternalID_t<Index_t>*)output_graph;
raft::matrix::fill(res, d_data_.view(), static_cast<__half>(0));

cudaPointerAttributes data_ptr_attr;
RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data));
Expand Down
7 changes: 4 additions & 3 deletions cpp/src/neighbors/detail/nn_descent_gnnd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,10 +271,11 @@ inline BuildConfig get_build_config(raft::resources const& res,
auto allowed_metrics = params.metric == cuvs::distance::DistanceType::L2Expanded ||
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
params.metric == cuvs::distance::DistanceType::CosineExpanded ||
params.metric == cuvs::distance::DistanceType::InnerProduct;
params.metric == cuvs::distance::DistanceType::InnerProduct ||
params.metric == cuvs::distance::DistanceType::BitwiseHamming;
RAFT_EXPECTS(allowed_metrics,
"The metric for NN Descent should be L2Expanded, L2SqrtExpanded, CosineExpanded or "
"InnerProduct");
"The metric for NN Descent should be L2Expanded, L2SqrtExpanded, CosineExpanded, "
"InnerProduct or BitwiseHamming");
RAFT_EXPECTS(
metric == params.metric,
"The metrics set in nn_descent::index_params and nn_descent::index are inconsistent");
Expand Down
14 changes: 10 additions & 4 deletions cpp/tests/neighbors/ann_nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <cuvs/distance/distance.hpp>
#include <cuvs/neighbors/nn_descent.hpp>

#include <raft/core/host_mdarray.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/util/cudart_utils.hpp>
#include <raft/util/itertools.hpp>
Expand Down Expand Up @@ -86,6 +87,10 @@ class AnnNNDescentTest : public ::testing::TestWithParam<AnnNNDescentInputs> {
protected:
void testNNDescent()
{
if (ps.metric == cuvs::distance::DistanceType::BitwiseHamming &&
!(std::is_same_v<DataT, uint8_t> || std::is_same_v<DataT, int8_t>)) {
GTEST_SKIP();
}
size_t queries_size = ps.n_rows * ps.graph_degree;
std::vector<IdxT> indices_NNDescent(queries_size);
std::vector<DistanceT> distances_NNDescent(queries_size);
Expand Down Expand Up @@ -471,10 +476,11 @@ class AnnNNDescentBatchTest : public ::testing::TestWithParam<AnnNNDescentBatchI
};

const std::vector<AnnNNDescentInputs> inputs =
raft::util::itertools::product<AnnNNDescentInputs>({2000, 4000}, // n_rows
{4, 16, 64, 256, 1024}, // dim
{32, 64}, // graph_degree
{cuvs::distance::DistanceType::L2Expanded,
raft::util::itertools::product<AnnNNDescentInputs>({2000, 4000}, // n_rows
{4, 16, 31, 64, 256, 1024}, // dim
{32, 64}, // graph_degree
{cuvs::distance::DistanceType::BitwiseHamming,
cuvs::distance::DistanceType::L2Expanded,
cuvs::distance::DistanceType::L2SqrtExpanded,
cuvs::distance::DistanceType::InnerProduct,
cuvs::distance::DistanceType::CosineExpanded},
Expand Down
4 changes: 2 additions & 2 deletions cpp/tests/neighbors/naive_knn.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
* Copyright (c) 2023-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -66,7 +66,7 @@ RAFT_KERNEL naive_distance_kernel(EvalT* dist,
acc += diff * diff;
} break;
case cuvs::distance::DistanceType::BitwiseHamming: {
if constexpr (std::is_same_v<uint8_t, DataT>) {
if constexpr (std::is_same_v<uint8_t, DataT> || std::is_same_v<int8_t, DataT>) {
acc += __popc(static_cast<uint32_t>(xv ^ yv) & 0xff);
}
} break;
Expand Down