
Commit

Support CPUNearestNeighbor for benchmarking exact nearest neighbors. (#655)

* support benchmarking exact CPU knn

* set maxResultSize=0 to remove the broadcast result size limit

* fix a typo in run_benchmark.sh, add test functions for CPUNearestNeighbors

* revise

* limit OMP threads to 1 per Spark task when using sklearn

---------

Signed-off-by: Jinfeng <jinfengl@nvidia.com>
lijinf2 committed May 25, 2024
1 parent 7dce71b commit d608e96
Showing 4 changed files with 220 additions and 59 deletions.
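
The new CPUNearestNeighborsModel (first file below) reuses the ApproximateNearestNeighborsModel interface but swaps in scikit-learn's brute-force search, so the benchmark can produce exact nearest neighbors on CPU. A minimal usage sketch, modeled on the added test code and assuming an active SparkSession `spark` and that `python/benchmark` is on the Python path (as the new test arranges); the dataset sizes and k value here are made up:

    import numpy as np
    from pyspark.sql.functions import array
    from sklearn.datasets import make_blobs

    from benchmark.bench_nearest_neighbors import CPUNearestNeighborsModel

    # Small synthetic item set; the same rows are reused as queries here.
    X, _ = make_blobs(n_samples=100, n_features=8, centers=3, random_state=0)

    df = spark.createDataFrame(X.tolist())                # assumes `spark` is an active SparkSession
    df = df.select(array(df.columns).alias("features"))   # collapse the columns into one array column

    model = CPUNearestNeighborsModel(df).setInputCol("features").setK(5)

    # kneighbors returns (item_df_withid, query_df_withid, knn_df); knn_df holds
    # the exact neighbor indices and distances for each query row.
    item_df, query_df, knn_df = model.kneighbors(df)
    knn_df.show(5, truncate=False)
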
147 changes: 92 additions & 55 deletions python/benchmark/benchmark/bench_nearest_neighbors.py
@@ -14,14 +14,53 @@
# limitations under the License.
#
import time
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from pyspark.ml.feature import VectorAssembler
from pyspark.ml.functions import array_to_vector
from pyspark.sql import DataFrame, SparkSession

from benchmark.base import BenchmarkBase
from benchmark.utils import with_benchmark
from spark_rapids_ml.core import (
EvalMetricInfo,
_ConstructFunc,
_EvaluateFunc,
_TransformFunc,
alias,
)
from spark_rapids_ml.knn import ApproximateNearestNeighborsModel


class CPUNearestNeighborsModel(ApproximateNearestNeighborsModel):
def __init__(self, item_df: DataFrame):
super().__init__(item_df)

def kneighbors(self, query_df: DataFrame) -> Tuple[DataFrame, DataFrame, DataFrame]:
self._item_df_withid = self._ensureIdCol(self._item_df_withid)
return super().kneighbors(query_df)

def _get_cuml_transform_func(
self, dataset: DataFrame, eval_metric_info: Optional[EvalMetricInfo] = None
) -> Tuple[
_ConstructFunc,
_TransformFunc,
Optional[_EvaluateFunc],
]:
self._cuml_params["algorithm"] = "brute"
_, _transform_internal, _ = super()._get_cuml_transform_func(
dataset, eval_metric_info
)

from sklearn.neighbors import NearestNeighbors as SKNN

n_neighbors = self.getK()

def _construct_sknn() -> SKNN:
nn_object = SKNN(algorithm="brute", n_neighbors=n_neighbors)
return nn_object

return _construct_sknn, _transform_internal, None


class BenchmarkNearestNeighbors(BenchmarkBase):
@@ -37,6 +37,13 @@ def _add_extra_arguments(self) -> None:
help="whether to enable dataframe repartition, cache and cout outside fit function",
)

self._parser.add_argument(
"--fraction_sampled_queries",
type=float,
required=True,
help="the number of vectors sampled from the dataset as query vectors",
)

def run_once(
self,
spark: SparkSession,
Expand All @@ -54,6 +100,8 @@ def run_once(
num_cpus = self.args.num_cpus
no_cache = self.args.no_cache
n_neighbors = self.args.n_neighbors
fraction_sampled_queries = self.args.fraction_sampled_queries
seed = 0

func_start_time = time.time()

Expand All @@ -65,22 +113,33 @@ def run_once(
if not is_single_col:
input_cols = [c for c in train_df.schema.names]

query_df = train_df.sample(
withReplacement=False, fraction=fraction_sampled_queries, seed=seed
)

def cache_df(dfA: DataFrame, dfB: DataFrame) -> Tuple[DataFrame, DataFrame]:
dfA = dfA.cache()
dfB = dfB.cache()

def func_dummy(pdf_iter): # type: ignore
import pandas as pd

yield pd.DataFrame({"dummy": [1]})

dfA.mapInPandas(func_dummy, schema="dummy int").count()
dfB.mapInPandas(func_dummy, schema="dummy int").count()
return (dfA, dfB)

params = self.class_params
if num_gpus > 0:
from spark_rapids_ml.knn import NearestNeighbors, NearestNeighborsModel

assert num_cpus <= 0
if not no_cache:

def gpu_cache_df(df: DataFrame) -> DataFrame:
df = df.repartition(num_gpus).cache()
df.count()
return df

train_df, prepare_time = with_benchmark(
"prepare dataset", lambda: gpu_cache_df(train_df)
(train_df, query_df), prepare_time = with_benchmark(
"prepare dataset", lambda: cache_df(train_df, query_df)
)

params = self.class_params
gpu_estimator = NearestNeighbors(
num_workers=num_gpus, verbose=self.args.verbose, **params
)
@@ -100,67 +159,44 @@ def transform(model: NearestNeighborsModel, df: DataFrame) -> DataFrame:
return knn_df

knn_df, transform_time = with_benchmark(
"gpu transform", lambda: transform(gpu_model, train_df)
"gpu transform", lambda: transform(gpu_model, query_df)
)
total_time = round(time.time() - func_start_time, 2)
print(f"gpu total took: {total_time} sec")

if num_cpus > 0:
assert num_gpus <= 0
if is_array_col:
vector_df = train_df.select(
array_to_vector(train_df[first_col]).alias(first_col)
)
elif not is_vector_col:
vector_assembler = VectorAssembler(outputCol="features").setInputCols(
input_cols
)
vector_df = vector_assembler.transform(train_df).drop(*input_cols)
first_col = "features"
else:
vector_df = train_df

if not no_cache:

def cpu_cache_df(df: DataFrame) -> DataFrame:
df = df.cache()
df.count()
return df

vector_df, prepare_time = with_benchmark(
"prepare dataset", lambda: cpu_cache_df(vector_df)
(train_df, query_df), prepare_time = with_benchmark(
"prepare dataset", lambda: cache_df(train_df, query_df)
)

from pyspark.ml.feature import (
BucketedRandomProjectionLSH,
BucketedRandomProjectionLSHModel,
)
def get_cpu_model() -> CPUNearestNeighborsModel:
cpu_estimator = CPUNearestNeighborsModel(train_df).setK(
params["n_neighbors"]
)

cpu_estimator = BucketedRandomProjectionLSH(
inputCol=first_col,
outputCol="hashes",
bucketLength=2.0,
numHashTables=3,
)
return cpu_estimator

cpu_model, fit_time = with_benchmark(
"cpu fit time", lambda: cpu_estimator.fit(vector_df)
"cpu fit time", lambda: get_cpu_model()
)

if is_single_col:
cpu_model = cpu_model.setInputCol(first_col)
else:
cpu_model = cpu_model.setInputCols(input_cols)

def cpu_transform(
model: BucketedRandomProjectionLSHModel, df: DataFrame, n_neighbors: int
) -> None:
queries = df.collect()
for row in queries:
query = row[first_col]
knn_df = model.approxNearestNeighbors(
dataset=df, key=query, numNearestNeighbors=n_neighbors
)
knn_df.count()

_, transform_time = with_benchmark(
model: CPUNearestNeighborsModel, df: DataFrame
) -> DataFrame:
(item_df_withid, query_df_withid, knn_df) = model.kneighbors(df)
knn_df.count()
return knn_df

knn_df, transform_time = with_benchmark(
"cpu transform",
lambda: cpu_transform(cpu_model, vector_df, n_neighbors),
lambda: cpu_transform(cpu_model, query_df),
)

total_time = round(time.time() - func_start_time, 2)
@@ -171,6 +207,7 @@ def cpu_transform(
"transform": transform_time,
"total_time": total_time,
"n_neighbors": n_neighbors,
"fraction_sampled_queries": fraction_sampled_queries,
"num_gpus": num_gpus,
"num_cpus": num_cpus,
"no_cache": no_cache,
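
For reference, the scikit-learn estimator that _construct_sknn builds performs an exact brute-force search; outside of Spark, the equivalent single-node computation looks like this (a sketch with made-up sizes, not benchmark code):

    import numpy as np
    from sklearn.neighbors import NearestNeighbors

    rng = np.random.default_rng(0)
    items = rng.standard_normal((1000, 16))
    queries = items[:10]  # the benchmark samples its queries from the item set

    # algorithm="brute" matches _construct_sknn and guarantees exact results.
    nn = NearestNeighbors(algorithm="brute", n_neighbors=20).fit(items)
    distances, indices = nn.kneighbors(queries)
    print(indices.shape)  # (10, 20)
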
7 changes: 5 additions & 2 deletions python/run_benchmark.sh
@@ -68,6 +68,7 @@ unset SPARK_HOME
# data set params
num_rows=${num_rows:-5000}
knn_num_rows=$num_rows
knn_fraction_sampled_queries=${knn_fraction_sampled_queries:-0.01}
num_cols=${num_cols:-3000}
num_sparse_cols=${num_sparse_cols:-3000}
density=${density:-0.1}
@@ -185,14 +186,16 @@ if [[ "${MODE}" =~ "knn" ]] || [[ "${MODE}" == "all" ]]; then
fi

echo "$sep algo: knn $sep"
python ./benchmark/benchmark_runner.py knn \
--n_neighbors 3 \
OMP_NUM_THREADS=1 python ./benchmark/benchmark_runner.py knn \
--n_neighbors 20 \
--fraction_sampled_queries ${knn_fraction_sampled_queries} \
--num_gpus $num_gpus \
--num_cpus $num_cpus \
--no_cache \
--num_runs $num_runs \
--train_path "${gen_data_root}/blobs/r${knn_num_rows}_c${num_cols}_float32.parquet" \
--report_path "report_knn_${cluster_type}.csv" \
--spark_confs "spark.driver.maxResultSize=0" \
$common_confs $spark_rapids_confs \
${EXTRA_ARGS}
fi
9 changes: 7 additions & 2 deletions python/src/spark_rapids_ml/knn.py
@@ -1334,14 +1334,19 @@ def _transform_internal(
):
distances = distances * distances

indices = indices.get()
if isinstance(distances, cp.ndarray):
distances = distances.get()

if isinstance(indices, cp.ndarray):
indices = indices.get()

indices_global = item_row_number[indices]

res = pd.DataFrame(
{
f"query_{id_col_name}": bcast_qids.value,
"indices": list(indices_global),
"distances": list(distances.get()),
"distances": list(distances),
}
)
return res
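
The knn.py change above makes the device-to-host copy conditional: .get() exists on CuPy arrays but not on NumPy arrays, and the sklearn code path already returns NumPy results. A standalone illustration of that pattern (a sketch; `to_host` is a hypothetical helper, not library code):

    import numpy as np

    def to_host(arr):
        # Return a NumPy array whether `arr` is a CuPy (GPU) or NumPy (CPU) array.
        try:
            import cupy as cp
            if isinstance(arr, cp.ndarray):
                return arr.get()  # copy device memory to host
        except ImportError:
            pass  # CuPy not installed; arr is already on the host
        return arr

    print(to_host(np.arange(3)))  # works for both backends
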
116 changes: 116 additions & 0 deletions python/tests/test_benchmark.py
@@ -0,0 +1,116 @@
import os
import sys

file_path = os.path.abspath(__file__)
file_dir_path = os.path.dirname(file_path)
extra_python_path = file_dir_path + "/../benchmark"
sys.path.append(extra_python_path)

from typing import List, Tuple

import numpy as np
import pandas as pd
import pytest
from pyspark.sql import DataFrame
from sklearn.datasets import make_blobs

from benchmark.bench_nearest_neighbors import CPUNearestNeighborsModel
from spark_rapids_ml.core import alias

from .sparksession import CleanSparkSession
from .utils import array_equal


def get_sgnn_res(
X_item: np.ndarray, X_query: np.ndarray, n_neighbors: int
) -> Tuple[np.ndarray, np.ndarray]:
from sklearn.neighbors import NearestNeighbors as SGNN

sg_nn = SGNN(n_neighbors=n_neighbors)
sg_nn.fit(X_item)
sg_distances, sg_indices = sg_nn.kneighbors(X_query)
return (sg_distances, sg_indices)


def assert_knn_equal(
knn_df: DataFrame, id_col_name: str, distances: np.ndarray, indices: np.ndarray
) -> None:
res_pd: pd.DataFrame = knn_df.sort(f"query_{id_col_name}").toPandas()
mg_indices: np.ndarray = np.array(res_pd["indices"].to_list())
mg_distances: np.ndarray = np.array(res_pd["distances"].to_list())

assert array_equal(mg_indices, indices)
assert array_equal(mg_distances, distances)


@pytest.mark.slow
def test_cpunn_withid() -> None:

n_samples = 1000
n_features = 50
n_clusters = 10
n_neighbors = 30

X, _ = make_blobs(
n_samples=n_samples,
n_features=n_features,
centers=n_clusters,
random_state=0,
) # make_blobs creates a random dataset of isotropic gaussian blobs.

sg_distances, sg_indices = get_sgnn_res(X, X, n_neighbors)

with CleanSparkSession({}) as spark:

def py_func(id: int) -> List[int]:
return X[id].tolist()

from pyspark.sql.functions import udf

spark_func = udf(py_func, "array<float>")
df = spark.range(len(X)).select("id", spark_func("id").alias("features"))

mg_model = (
CPUNearestNeighborsModel(df)
.setInputCol("features")
.setIdCol("id")
.setK(n_neighbors)
)

_, _, knn_df = mg_model.kneighbors(df)
assert_knn_equal(knn_df, "id", sg_distances, sg_indices)


# @pytest.mark.slow
def test_cpunn_noid() -> None:

n_samples = 1000
n_features = 50
n_clusters = 10
n_neighbors = 30

X, _ = make_blobs(
n_samples=n_samples,
n_features=n_features,
centers=n_clusters,
random_state=0,
) # make_blobs creates a random dataset of isotropic gaussian blobs.

with CleanSparkSession({}) as spark:

df = spark.createDataFrame(X)
from pyspark.sql.functions import array

df = df.select(array(df.columns).alias("features"))

mg_model = (
CPUNearestNeighborsModel(df).setInputCol("features").setK(n_neighbors)
)

df_withid, _, knn_df = mg_model.kneighbors(df)

pdf: pd.DataFrame = df_withid.sort(alias.row_number).toPandas()
X = np.array(pdf["features"].to_list())

distances, indices = get_sgnn_res(X, X, n_neighbors)
assert_knn_equal(knn_df, alias.row_number, distances, indices)
