Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace ML::MetricType with raft::distance::DistanceType #3389

Merged
merged 38 commits into from
Mar 2, 2021

Conversation

lowener
Copy link
Contributor

@lowener lowener commented Jan 20, 2021

Closes #3319.
This PR will replace the distance type from ML::MetricType to raft::distance::DistanceType.

Since Raft DistanceType makes the distinction between the expanded and non-expanded distances in the name, I changed the C++ API to remove the boolean parameter expanded which becomes useless.

@github-actions github-actions bot added Cython / Python Cython or Python issue libcuml labels Jan 20, 2021
@lowener lowener changed the title Replace ML::MetricType with raft::distance::DistanceType [WIP] Replace ML::MetricType with raft::distance::DistanceType Jan 20, 2021
@lowener lowener changed the title [WIP] Replace ML::MetricType with raft::distance::DistanceType Replace ML::MetricType with raft::distance::DistanceType Jan 31, 2021
@lowener lowener marked this pull request as ready for review January 31, 2021 14:57
@lowener lowener requested review from a team as code owners January 31, 2021 14:57
@cjnolet cjnolet added this to PR-WIP in v0.18 Release via automation Feb 3, 2021
@cjnolet cjnolet added the 3 - Ready for Review Ready for review by team label Feb 3, 2021
@lowener lowener requested a review from cjnolet February 11, 2021 22:15
@cjnolet cjnolet added 3 - Ready for Review Ready for review by team and removed 4 - Waiting on Author Waiting for author to respond to review labels Feb 11, 2021
Copy link
Contributor

@viclafargue viclafargue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, just have a few suggestions.

MetricType metric = MetricType::METRIC_L2,
float metric_arg = 2.0f, bool expanded = false);
raft::distance::DistanceType metric =
raft::distance::DistanceType::L2Unexpanded,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we may want L2SqrtUnexpanded here to have the euclidean distance as default. At least if the results is seen by the end-user and not only used internally. Normally, METRIC_L2 in FAISS provides the euclidean distance before root-squaring. Then post-processing should apply the root-square. @cjnolet probably knows better about this though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would leave these in expanded form, actually. It's the most used metric and the difference in performance is pretty huge

ML::MetricType metric = ML::MetricType::METRIC_L2,
float metricArg = 0, bool expanded_form = false);
raft::distance::DistanceType metric =
raft::distance::DistanceType::L2Unexpanded,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same, L2SqrtUnexpanded might be needed here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, I would leave this in expanded form.

@@ -71,7 +71,8 @@ cumlError_t knn_search(const cumlHandle_t handle, float **input, int *sizes,
try {
ML::brute_force_knn(*handle_ptr, input_vec, sizes_vec, D, search_items, n,
res_I, res_D, k, rowMajorIndex, rowMajorQuery,
(ML::MetricType)metric_type, metric_arg, expanded);
(raft::distance::DistanceType)metric_type,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to keep making this conversion explicit?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're using the same enum type everywhere now, I think this conversion can be removed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This conversion from int to DistanceType is needed to not modify the C API of knn_search. I changed it to a static_cast.

@@ -89,7 +89,7 @@ void get_distances(const raft::handle_t &handle,
k_graph.knn_indices, k_graph.knn_dists, k_graph.n_neighbors,
handle.get_cusparse_handle(), handle.get_device_allocator(), stream,
ML::Sparse::DEFAULT_BATCH_SIZE, ML::Sparse::DEFAULT_BATCH_SIZE,
ML::MetricType::METRIC_L2);
raft::distance::DistanceType::L2Expanded);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need L2SqrtUnexpanded here, unless this distance value doesn't reach user's eye (only used by TSNE internally).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't reach the users so I think it's okay not to use the sqrt here. I think the expanded form is also good to use here for speed.

@@ -91,7 +92,7 @@ void launcher(const raft::handle_t &handle,
inputsB.n, inputsB.d, out.knn_indices, out.knn_dists, n_neighbors,
handle.get_cusparse_handle(), d_alloc, stream,
ML::Sparse::DEFAULT_BATCH_SIZE, ML::Sparse::DEFAULT_BATCH_SIZE,
ML::MetricType::METRIC_L2);
raft::distance::DistanceType::L2Expanded);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same, L2SqrtUnexpanded might be needed here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would leave all of these in expanded form. The unexpanded is more stable under some conditions but in general it's a better (and faster) starting point.

Comment on lines +484 to +486
if (metric == raft::distance::DistanceType::L2SqrtExpanded ||
metric == raft::distance::DistanceType::L2SqrtUnexpanded ||
metric == raft::distance::DistanceType::LpUnexpanded) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be wrong, but I think that in this case, only unexpanded forms (that need post-procesing) : L2SqrtUnexpanded and LpUnexpanded should have post-processing. @cjnolet probably knows more about this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FAISS only supports the expanded form but I believe we're converting both the Expanded and Unexpanded L2 forms into faiss::METRIC_L2 so we'll need to sqrt both of them.

ML::MetricType metric_ = ML::MetricType::METRIC_L2,
float metricArg_ = 0, bool expanded_form_ = false)
raft::distance::DistanceType metric_ =
raft::distance::DistanceType::L2Unexpanded,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same, L2SqrtUnexpanded might be needed here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would revert this to expanded form as well.

ML::MetricType metric = ML::MetricType::METRIC_L2,
float metricArg = 0, bool expanded_form = false) {
raft::distance::DistanceType metric =
raft::distance::DistanceType::L2Unexpanded,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same, L2SqrtUnexpanded might be needed here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And here.

if metric == "euclidean" or metric == "l2":
m = MetricType.METRIC_L2
m = DistanceType.L2SqrtExpanded
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same, L2SqrtUnexpanded might be needed here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Expanded version is preferred for speed

Comment on lines 400 to 401
elif metric == "cityblock" or metric == "l1"\
or metric == "manhattan" or metric == 'taxicab':
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can maybe be replaced by elif metric in [..., ...]:.

Copy link
Member

@cjnolet cjnolet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@@ -57,6 +57,8 @@ cumlError_t knn_search(const cumlHandle_t handle, float **input, int *sizes,
cumlError_t status;
raft::handle_t *handle_ptr;
std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(handle);
raft::distance::DistanceType metric_distance_type =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this. The intent becomes more clear.

@@ -434,8 +378,12 @@ class sparse_knn_t {
dist_config.allocator = allocator;
dist_config.stream = stream;

raft::sparse::distance::pairwiseDistance(batch_dists, dist_config,
get_pw_metric(), metricArg);
if (raft::sparse::distance::supportedDistance.find(metric) ==
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the introduction of the explicit set for supported distances. For a follow-on / future PR, we might want to consider just using a hash map to map the distance enum type to its “distances_t” implementation, which will allow us to get rid of the switch statement altogether.

v0.19 Release automation moved this from PR-Needs review to PR-Reviewer approved Feb 22, 2021
@cjnolet
Copy link
Member

cjnolet commented Feb 22, 2021

@gpucibot merge

@cjnolet
Copy link
Member

cjnolet commented Feb 22, 2021

Adding a reminder / note here that the dense & sparse knn primitives will need to be updated in RAFT before #3476 goes in.

@cjnolet
Copy link
Member

cjnolet commented Feb 22, 2021

rerun tests

@cjnolet
Copy link
Member

cjnolet commented Feb 23, 2021

@lowener, according to the logs, it looks like there was a gtest failure in KNNTest:

[----------] 1 test from KNNTestF
[ RUN      ] KNNTestF.Fit
/opt/conda/envs/rapids/conda-bld/libcuml_1614025928022/work/cpp/test/prims/knn.cu:103: Failure
Value of: raft::devArrMatch(d_ref_D, d_pred_D, n * n, raft::CompareApprox<float>(1e-3))
  Actual: false (actual=2401 != expected=49 @1)
Expected: true
[  FAILED  ] KNNTestF.Fit (218 ms)

@JohnZed JohnZed added 5 - Ready to Merge Testing and reviews complete, ready to merge and removed 3 - Ready for Review Ready for review by team labels Feb 25, 2021
@JohnZed
Copy link
Contributor

JohnZed commented Feb 25, 2021

@gpucibot merge

@cjnolet
Copy link
Member

cjnolet commented Mar 2, 2021

rerun tests

@codecov-io
Copy link

Codecov Report

Merging #3389 (03fd464) into branch-0.19 (39c7262) will increase coverage by 9.02%.
The diff coverage is 73.19%.

Impacted file tree graph

@@               Coverage Diff               @@
##           branch-0.19    #3389      +/-   ##
===============================================
+ Coverage        71.77%   80.80%   +9.02%     
===============================================
  Files              212      227      +15     
  Lines            17075    17735     +660     
===============================================
+ Hits             12256    14331    +2075     
+ Misses            4819     3404    -1415     
Flag Coverage Δ
dask 45.30% <15.38%> (?)
non-dask 73.06% <71.79%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
...on/cuml/_thirdparty/sklearn/preprocessing/_data.py 63.24% <ø> (-0.30%) ⬇️
python/cuml/experimental/explainer/common.py 88.05% <42.85%> (-4.01%) ⬇️
python/cuml/neighbors/ann.pyx 61.62% <61.62%> (ø)
python/cuml/common/import_utils.py 59.43% <66.66%> (+3.43%) ⬆️
python/cuml/experimental/explainer/base.pyx 67.06% <67.06%> (ø)
python/cuml/metrics/hinge_loss.pyx 73.33% <73.33%> (ø)
python/cuml/dask/common/utils.py 43.68% <83.33%> (+16.13%) ⬆️
python/cuml/neighbors/nearest_neighbors.pyx 92.25% <85.71%> (-0.46%) ⬇️
...l/_thirdparty/sklearn/preprocessing/_imputation.py 62.80% <100.00%> (+0.29%) ⬆️
python/cuml/experimental/explainer/kernel_shap.pyx 97.75% <100.00%> (+0.48%) ⬆️
... and 108 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 6f06d4b...03fd464. Read the comment docs.

@mike-wendt
Copy link
Contributor

rerun tests

@rapids-bot rapids-bot bot merged commit 9fa6e17 into rapidsai:branch-0.19 Mar 2, 2021
v0.19 Release automation moved this from PR-Reviewer approved to Done Mar 2, 2021
@lowener lowener deleted the 018_replace_distancetype branch April 8, 2021 09:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
5 - Ready to Merge Testing and reviews complete, ready to merge Cython / Python Cython or Python issue improvement Improvement / enhancement to an existing function libcuml non-breaking Non-breaking change
Projects
No open projects
v0.19 Release
  
Done
Development

Successfully merging this pull request may close these issues.

[FEA] Replace ML::MetricType with raft::distance::DistanceType
6 participants