Support CPUNearestNeighbor for benchmarking exact nearest neighbors. #655

Merged: 5 commits, merged May 25, 2024

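For orientation, a minimal sketch (not part of the diff itself) of how the CPU exact nearest-neighbors model added here is driven, mirroring the usage in python/tests/test_benchmark.py below; item_df and query_df are assumed to be Spark DataFrames with an "id" column and an array-typed "features" column.

from benchmark.bench_nearest_neighbors import CPUNearestNeighborsModel

# Exact (brute-force) kNN on CPU via scikit-learn, wrapped in the
# spark-rapids-ml model interface; mirrors the new test file in this PR.
model = (
    CPUNearestNeighborsModel(item_df)
    .setInputCol("features")
    .setIdCol("id")
    .setK(20)
)
item_df_withid, query_df_withid, knn_df = model.kneighbors(query_df)
knn_df.show()
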
147 changes: 92 additions & 55 deletions python/benchmark/benchmark/bench_nearest_neighbors.py
@@ -14,14 +14,53 @@
# limitations under the License.
#
import time
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from pyspark.ml.feature import VectorAssembler
from pyspark.ml.functions import array_to_vector
from pyspark.sql import DataFrame, SparkSession

from benchmark.base import BenchmarkBase
from benchmark.utils import with_benchmark
from spark_rapids_ml.core import (
EvalMetricInfo,
_ConstructFunc,
_EvaluateFunc,
_TransformFunc,
alias,
)
from spark_rapids_ml.knn import ApproximateNearestNeighborsModel


class CPUNearestNeighborsModel(ApproximateNearestNeighborsModel):
def __init__(self, item_df: DataFrame):
super().__init__(item_df)

def kneighbors(self, query_df: DataFrame) -> Tuple[DataFrame, DataFrame, DataFrame]:
self._item_df_withid = self._ensureIdCol(self._item_df_withid)
return super().kneighbors(query_df)

def _get_cuml_transform_func(
self, dataset: DataFrame, eval_metric_info: Optional[EvalMetricInfo] = None
) -> Tuple[
_ConstructFunc,
_TransformFunc,
Optional[_EvaluateFunc],
]:
self._cuml_params["algorithm"] = "brute"
_, _transform_internal, _ = super()._get_cuml_transform_func(
dataset, eval_metric_info
)

from sklearn.neighbors import NearestNeighbors as SKNN

n_neighbors = self.getK()

def _construct_sknn() -> SKNN:
nn_object = SKNN(algorithm="brute", n_neighbors=n_neighbors)
return nn_object

return _construct_sknn, _transform_internal, None


class BenchmarkNearestNeighbors(BenchmarkBase):
@@ -37,6 +76,13 @@ def _add_extra_arguments(self) -> None:
help="whether to enable dataframe repartition, cache and cout outside fit function",
)

self._parser.add_argument(
"--fraction_sampled_queries",
type=float,
required=True,
help="the number of vectors sampled from the dataset as query vectors",
)

def run_once(
self,
spark: SparkSession,
@@ -54,6 +100,8 @@ def run_once(
num_cpus = self.args.num_cpus
no_cache = self.args.no_cache
n_neighbors = self.args.n_neighbors
fraction_sampled_queries = self.args.fraction_sampled_queries
seed = 0

func_start_time = time.time()

@@ -65,22 +113,33 @@
if not is_single_col:
input_cols = [c for c in train_df.schema.names]

query_df = train_df.sample(
withReplacement=False, fraction=fraction_sampled_queries, seed=seed
)

def cache_df(dfA: DataFrame, dfB: DataFrame) -> Tuple[DataFrame, DataFrame]:
dfA = dfA.cache()
dfB = dfB.cache()

def func_dummy(pdf_iter): # type: ignore
import pandas as pd

yield pd.DataFrame({"dummy": [1]})

dfA.mapInPandas(func_dummy, schema="dummy int").count()
Collaborator: Better to avoid python udfs for this kind of thing but probably ok.

dfB.mapInPandas(func_dummy, schema="dummy int").count()
return (dfA, dfB)
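
A side note on the review comment above about avoiding Python UDFs: below is a minimal sketch of a UDF-free way to cache and fully materialize a DataFrame, assuming Spark 3.0+ where the built-in "noop" sink is available. Illustration only, not part of this PR.

from pyspark.sql import DataFrame

def materialize(df: DataFrame) -> DataFrame:
    # Cache, then force full materialization without a Python UDF by writing
    # to the no-op data source (Spark 3.0+), which scans every column but
    # writes nothing.
    df = df.cache()
    df.write.format("noop").mode("overwrite").save()
    return df
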

params = self.class_params
if num_gpus > 0:
from spark_rapids_ml.knn import NearestNeighbors, NearestNeighborsModel

assert num_cpus <= 0
if not no_cache:

def gpu_cache_df(df: DataFrame) -> DataFrame:
df = df.repartition(num_gpus).cache()
df.count()
return df

train_df, prepare_time = with_benchmark(
"prepare dataset", lambda: gpu_cache_df(train_df)
(train_df, query_df), prepare_time = with_benchmark(
"prepare dataset", lambda: cache_df(train_df, query_df)
)

params = self.class_params
gpu_estimator = NearestNeighbors(
num_workers=num_gpus, verbose=self.args.verbose, **params
)
@@ -100,67 +159,44 @@ def transform(model: NearestNeighborsModel, df: DataFrame) -> DataFrame:
return knn_df

knn_df, transform_time = with_benchmark(
"gpu transform", lambda: transform(gpu_model, train_df)
"gpu transform", lambda: transform(gpu_model, query_df)
)
total_time = round(time.time() - func_start_time, 2)
print(f"gpu total took: {total_time} sec")

if num_cpus > 0:
assert num_gpus <= 0
if is_array_col:
vector_df = train_df.select(
array_to_vector(train_df[first_col]).alias(first_col)
)
elif not is_vector_col:
vector_assembler = VectorAssembler(outputCol="features").setInputCols(
input_cols
)
vector_df = vector_assembler.transform(train_df).drop(*input_cols)
first_col = "features"
else:
vector_df = train_df

if not no_cache:

def cpu_cache_df(df: DataFrame) -> DataFrame:
df = df.cache()
df.count()
return df

vector_df, prepare_time = with_benchmark(
"prepare dataset", lambda: cpu_cache_df(vector_df)
(train_df, query_df), prepare_time = with_benchmark(
"prepare dataset", lambda: cache_df(train_df, query_df)
)

from pyspark.ml.feature import (
BucketedRandomProjectionLSH,
BucketedRandomProjectionLSHModel,
)
def get_cpu_model() -> CPUNearestNeighborsModel:
cpu_estimator = CPUNearestNeighborsModel(train_df).setK(
params["n_neighbors"]
)

cpu_estimator = BucketedRandomProjectionLSH(
inputCol=first_col,
outputCol="hashes",
bucketLength=2.0,
numHashTables=3,
)
return cpu_estimator

cpu_model, fit_time = with_benchmark(
"cpu fit time", lambda: cpu_estimator.fit(vector_df)
"cpu fit time", lambda: get_cpu_model()
)

if is_single_col:
cpu_model = cpu_model.setInputCol(first_col)
else:
cpu_model = cpu_model.setInputCols(input_cols)

def cpu_transform(
model: BucketedRandomProjectionLSHModel, df: DataFrame, n_neighbors: int
) -> None:
queries = df.collect()
for row in queries:
query = row[first_col]
knn_df = model.approxNearestNeighbors(
dataset=df, key=query, numNearestNeighbors=n_neighbors
)
knn_df.count()

_, transform_time = with_benchmark(
model: CPUNearestNeighborsModel, df: DataFrame
) -> DataFrame:
(item_df_withid, query_df_withid, knn_df) = model.kneighbors(df)
knn_df.count()
return knn_df

knn_df, transform_time = with_benchmark(
"cpu transform",
lambda: cpu_transform(cpu_model, vector_df, n_neighbors),
lambda: cpu_transform(cpu_model, query_df),
)

total_time = round(time.time() - func_start_time, 2)
@@ -171,6 +207,7 @@ def cpu_transform(
"transform": transform_time,
"total_time": total_time,
"n_neighbors": n_neighbors,
"fraction_sampled_queries": fraction_sampled_queries,
"num_gpus": num_gpus,
"num_cpus": num_cpus,
"no_cache": no_cache,
7 changes: 5 additions & 2 deletions python/run_benchmark.sh
@@ -68,6 +68,7 @@ unset SPARK_HOME
# data set params
num_rows=${num_rows:-5000}
knn_num_rows=$num_rows
knn_fraction_sampled_queries=${knn_fraction_sampled_queries:-0.01}
num_cols=${num_cols:-3000}
num_sparse_cols=${num_sparse_cols:-3000}
density=${density:-0.1}
@@ -185,14 +186,16 @@ if [[ "${MODE}" =~ "knn" ]] || [[ "${MODE}" == "all" ]]; then
fi

echo "$sep algo: knn $sep"
python ./benchmark/benchmark_runner.py knn \
--n_neighbors 3 \
OMP_NUM_THREADS=1 python ./benchmark/benchmark_runner.py knn \
--n_neighbors 20 \
--fraction_sampled_queries ${knn_fraction_sampled_queries} \
--num_gpus $num_gpus \
--num_cpus $num_cpus \
--no_cache \
--num_runs $num_runs \
--train_path "${gen_data_root}/blobs/r${knn_num_rows}_c${num_cols}_float32.parquet" \
--report_path "report_knn_${cluster_type}.csv" \
--spark_confs "spark.driver.maxResultSize=0" \
$common_confs $spark_rapids_confs \
${EXTRA_ARGS}
fi
9 changes: 7 additions & 2 deletions python/src/spark_rapids_ml/knn.py
@@ -1334,14 +1334,19 @@ def _transform_internal(
):
distances = distances * distances

indices = indices.get()
if isinstance(distances, cp.ndarray):
distances = distances.get()

if isinstance(indices, cp.ndarray):
indices = indices.get()

indices_global = item_row_number[indices]

res = pd.DataFrame(
{
f"query_{id_col_name}": bcast_qids.value,
"indices": list(indices_global),
"distances": list(distances.get()),
"distances": list(distances),
}
)
return res
116 changes: 116 additions & 0 deletions python/tests/test_benchmark.py
@@ -0,0 +1,116 @@
import os
import sys

file_path = os.path.abspath(__file__)
file_dir_path = os.path.dirname(file_path)
extra_python_path = file_dir_path + "/../benchmark"
sys.path.append(extra_python_path)

from typing import List, Tuple

import numpy as np
import pandas as pd
import pytest
from pyspark.sql import DataFrame
from sklearn.datasets import make_blobs

from benchmark.bench_nearest_neighbors import CPUNearestNeighborsModel
from spark_rapids_ml.core import alias

from .sparksession import CleanSparkSession
from .utils import array_equal


def get_sgnn_res(
X_item: np.ndarray, X_query: np.ndarray, n_neighbors: int
) -> Tuple[np.ndarray, np.ndarray]:
from sklearn.neighbors import NearestNeighbors as SGNN

sg_nn = SGNN(n_neighbors=n_neighbors)
sg_nn.fit(X_item)
sg_distances, sg_indices = sg_nn.kneighbors(X_query)
return (sg_distances, sg_indices)


def assert_knn_equal(
knn_df: DataFrame, id_col_name: str, distances: np.ndarray, indices: np.ndarray
) -> None:
res_pd: pd.DataFrame = knn_df.sort(f"query_{id_col_name}").toPandas()
mg_indices: np.ndarray = np.array(res_pd["indices"].to_list())
mg_distances: np.ndarray = np.array(res_pd["distances"].to_list())

assert array_equal(mg_indices, indices)
assert array_equal(mg_distances, distances)


@pytest.mark.slow
def test_cpunn_withid() -> None:

n_samples = 1000
n_features = 50
n_clusters = 10
n_neighbors = 30

X, _ = make_blobs(
n_samples=n_samples,
n_features=n_features,
centers=n_clusters,
random_state=0,
) # make_blobs creates a random dataset of isotropic gaussian blobs.

sg_distances, sg_indices = get_sgnn_res(X, X, n_neighbors)

with CleanSparkSession({}) as spark:

def py_func(id: int) -> List[int]:
return X[id].tolist()

from pyspark.sql.functions import udf

spark_func = udf(py_func, "array<float>")
df = spark.range(len(X)).select("id", spark_func("id").alias("features"))
Collaborator: Any advantage to doing this way vs createDataFrame from pandas df?

Collaborator (Author): It does not seem to throw a "task size larger than 1000k" warning on a large dataset, but it looks no different on a small dataset.

mg_model = (
CPUNearestNeighborsModel(df)
.setInputCol("features")
.setIdCol("id")
.setK(n_neighbors)
)

_, _, knn_df = mg_model.kneighbors(df)
assert_knn_equal(knn_df, "id", sg_distances, sg_indices)
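
To make the review thread above concrete, here is a minimal sketch of the createDataFrame-from-pandas alternative the reviewer asks about; spark and X are assumed to be the session and NumPy feature matrix from the test above, and df_alt is a hypothetical name.

import pandas as pd

# Hypothetical alternative to the UDF-based construction above: build the
# id/features DataFrame directly from a pandas DataFrame. For a large X this
# ships the data with the plan and may trigger the "task size" warning the
# author mentions in the review thread.
pdf = pd.DataFrame({"id": range(len(X)), "features": [row.tolist() for row in X]})
df_alt = spark.createDataFrame(pdf)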


# @pytest.mark.slow
def test_cpunn_noid() -> None:

n_samples = 1000
n_features = 50
n_clusters = 10
n_neighbors = 30

X, _ = make_blobs(
n_samples=n_samples,
n_features=n_features,
centers=n_clusters,
random_state=0,
) # make_blobs creates a random dataset of isotropic gaussian blobs.

with CleanSparkSession({}) as spark:

df = spark.createDataFrame(X)
from pyspark.sql.functions import array

df = df.select(array(df.columns).alias("features"))

mg_model = (
CPUNearestNeighborsModel(df).setInputCol("features").setK(n_neighbors)
)

df_withid, _, knn_df = mg_model.kneighbors(df)

pdf: pd.DataFrame = df_withid.sort(alias.row_number).toPandas()
X = np.array(pdf["features"].to_list())

distances, indices = get_sgnn_res(X, X, n_neighbors)
assert_knn_equal(knn_df, alias.row_number, distances, indices)