
Integrate FAISS index for KNN classifier #557

Merged: 10 commits, Jul 27, 2020

Conversation

@brc7 (Contributor) commented Jul 16, 2020

Description of changes:
This code adds a FAISS-backed version of the KNN classifier.

The new KNeighborsClassifier is a drop-in replacement for the sklearn model, but is built on top of the FAISS index. There is one additional optional parameter for the model: index_factory_string, which describes what type of FAISS index to use (this just gets passed to FAISS's index construction method).
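For readers skimming the diff, a minimal usage sketch of what a drop-in replacement like this looks like (the import path and exact class name here are assumptions for illustration, not taken verbatim from this PR):

```python
import numpy as np
# Hypothetical import path; the actual module location in autogluon may differ.
from autogluon.utils.tabular.ml.models.knn.knn_utils import FAISSNeighborsClassifier

X = np.random.rand(1000, 16).astype(np.float32)  # FAISS works on float32 vectors
y = np.random.randint(0, 2, size=1000)

# 'Flat' requests an exact brute-force index; other factory strings trade accuracy for speed.
clf = FAISSNeighborsClassifier(n_neighbors=5, weights='uniform', index_factory_string='Flat')
clf.fit(X, y)
preds = clf.predict(X[:10])
proba = clf.predict_proba(X[:10])
```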

There is one new dependency on faiss-cpu.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

@szha commented Jul 16, 2020

Job PR-557-1 is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-557/1/index.html

@Innixma (Contributor) left a comment

Excited to see this! Once my comments relating to extending KNNModel are addressed, I will try this out and run it on our benchmarks to see how it compares to the default KNN.



def fit(self, X_train, y_train):
X_train = X_train.astype(np.float32)
Contributor:

If we check if X_train is a dataframe, we could do X_train.to_numpy(dtype=np.float32), which may eliminate an unnecessary data copy.
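A minimal sketch of the suggested check (hypothetical helper, only to illustrate the idea):

```python
import numpy as np
import pandas as pd

def _as_float32(X):
    # Avoid an extra copy when X is a DataFrame by converting directly to float32.
    if isinstance(X, pd.DataFrame):
        return X.to_numpy(dtype=np.float32)
    return np.asarray(X, dtype=np.float32)
```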

Contributor Author:

Great point - fixed. Also fixed a second unnecessary copy that might happen in the following line.

return self

def predict(self, X):
X = X.astype(np.float32)
Contributor:

ditto

@@ -32,7 +32,7 @@ def preprocess(self, X):
def _set_default_params(self):
default_params = {
'weights': 'uniform',
'n_jobs': -1,
'index_factory_string' : 'Flat',
Contributor:

Are we not able to use multiple cores with this implementation? Why was n_jobs removed?

Contributor Author:

Added this param back in to allow for the use of fewer cores. It was removed because FAISS uses OpenMP underneath, which defaults to using all cores, and there wasn't much in the docs on how to change that.
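For reference, FAISS's Python bindings do expose a thread-count control; a hedged sketch of wiring n_jobs through to OpenMP (faiss.omp_set_num_threads is a real binding, but the n_jobs handling below is an assumption, not the PR's code):

```python
import multiprocessing
import faiss

def _set_faiss_threads(n_jobs):
    # n_jobs <= 0 or None -> keep the OpenMP default of all cores; otherwise cap it.
    if n_jobs is not None and n_jobs > 0:
        faiss.omp_set_num_threads(n_jobs)
    else:
        faiss.omp_set_num_threads(multiprocessing.cpu_count())
```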

from ...constants import REGRESSION
from ....utils.exceptions import NotEnoughMemoryError

logger = logging.getLogger(__name__)


# TODO: Normalize data!
class KNNModel(SKLearnModel):
class KNNModel(AbstractModel):
Contributor:

In the first version of this new FAISS implementation, could you instead extend the KNN model to a new model class so we aren't removing the existing KNNModel? Something like FAISSModel? This will help significantly with apples to apples comparisons.

In regards to changing KNNModel to implement AbstractModel instead of SKLearnModel, I believe this is fine as KNN no longer utilizes SKLearnModel to my knowledge.

Contributor Author:

good suggestion - done

Contributor:

Perhaps I wasn't quite clear in my first comment. I was referring to the KNNModel class, such as creating a new:

class FAISSModel(KNNModel):
    def _get_model_type(self):
        if self.problem_type == REGRESSION:
            return FAISSNeighborsRegressor
        else:
            return FAISSNeighborsClassifier

    def _set_default_params(self):
        default_params = {
            'index_factory_string': 'Flat',
        }
        for param, val in default_params.items():
            self._set_default_param_value(param, val)
        super()._set_default_params()

and changing KNNModel slightly to look like:

class KNNModel(AbstractModel):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._model_type = self._get_model_type()

    def _get_model_type(self):
        if self.problem_type == REGRESSION:
            return KNeighborsRegressor
        else:
            return KNeighborsClassifier

Contributor Author:

Ah, thanks - this makes a bit more sense than what I had in mind. Should be able to easily use both model types for comparisons now.


# Rather than try to import non-public sklearn internals, we implement our own weighting functions here
# These support the same operations as the sklearn functions - at least as far as possible with FAISS
def _check_weights(weights):
Contributor:

minor nit comments relating to code readability: If you haven't already I'd recommend adding PEP8 violation highlighting to your IDE. While we don't follow the line-length limit PEP8 recommendation, we do follow most others, such as having 2 new lines between each def and class, as well as 2 new lines after imports before code is written.

This isn't critical, as I can clean-up the code afterwards, but is something to keep in mind in future PR's to make the process smoother.

Contributor Author:

Thanks for the tip, I didn't know you could do this. I did it and formatted everything in a separate commit.

if weights is None:
y_pred, _ = mode(outputs, axis = 1)
else:
y_pred,_ = weighted_mode(outputs, weights, axis = 1)
Contributor:

nit: add space between , and _

return y_pred

def predict_proba(self, X):
X = X.astype(np.float32)
Contributor:

ditto

setup.py Outdated
@@ -57,7 +57,8 @@ def create_version_file():
'pandas>=0.24.0,<1.0',
'psutil>=5.0.0',
'scikit-learn>=0.22.0,<0.23',
'networkx>=2.3,<3.0'
'networkx>=2.3,<3.0',
'faiss-cpu>=1.6.3'
Contributor:

May decide to remove this as a required dependency, but can keep for the moment. Will have to identify how heavy the package is.

self.index = faiss.deserialize_index(self.index)



Contributor:

nit: Have exactly 2 new lines between classes

self.__dict__.update(state)
self.index = faiss.deserialize_index(self.index)


Contributor:

nit: have exactly 1 new line before the EOF

@@ -53,6 +53,8 @@ def __init__(self, n_neighbors = 5, weights='uniform', n_jobs = -1, index_factor

The model itself is a clone of the sklearn one
"""
try_import_faiss()
import faiss
Contributor:

You can just set self.faiss = faiss and then you don't need to import it in every method
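A minimal sketch of the suggested pattern (abbreviated, not the PR's actual class):

```python
class FAISSNeighborsClassifier:
    def __init__(self, index_factory_string='Flat'):
        import faiss  # deferred import so faiss stays an optional dependency
        self.faiss = faiss  # cache the module once; later methods use self.faiss
        self.index_factory_string = index_factory_string
        self.index = None

    def fit(self, X, y):
        d = X.shape[1]
        self.index = self.faiss.index_factory(d, self.index_factory_string)
        self.index.add(X)  # X must be a C-contiguous float32 array
        self.y = y
        return self
```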

@jwmueller (Contributor)

Should there be a hyperparameter that quantifies the exactness (approximation-error) of the nearest neighbor retrieval when the index gets built? Would be good to document at least one of the straightforward ways to control this, in say a comment for now.

@jwmueller (Contributor)

Do you plan to extend this to categorical + numeric features?

@brc7 (Contributor Author) commented Jul 19, 2020

> Should there be a hyperparameter that quantifies the exactness (approximation-error) of the nearest neighbor retrieval when the index gets built? Would be good to document at least one of the straightforward ways to control this, in say a comment for now.

This is what "index_factory_string" is for - different factory strings produce an index with different exactness tradeoffs.

> Do you plan to extend this to categorical + numeric features?

For now, no. FAISS only supports float32 vectors, so it's not immediately clear how to do this. You could code categorical features as vectors in R^d (either naively or with a mapping to a sphere). You could also build a classifier around a (non-faiss) index that supports something like the Jaccard metric. I'm not sure which is best.

@Innixma (Contributor) commented Jul 19, 2020

I think CI is broken atm by some package update; will look into it Sunday.

@Innixma (Contributor) commented Jul 20, 2020

@brc7 I believe the CI is now fixed on mainline. Can you rebase your PR onto mainline to get the latest fix?

@jwmueller (Contributor) commented Jul 21, 2020

> Should there be a hyperparameter that quantifies the exactness (approximation-error) of the nearest neighbor retrieval when the index gets built? Would be good to document at least one of the straightforward ways to control this, in say a comment for now.

> This is what "index_factory_string" is for - different factory strings produce an index with different exactness tradeoffs.

> Do you plan to extend this to categorical + numeric features?

> For now, no. FAISS only supports float32 vectors, so it's not immediately clear how to do this. You could code categorical features as vectors in R^d (either naively or with a mapping to a sphere). You could also build a classifier around a (non-faiss) index that supports something like the Jaccard metric. I'm not sure which is best.

Sounds good, thanks for clarifying. I'm unsure whether mixing Jaccard + Euclidean metrics for datasets with both categorical + continuous features would work well with any existing indices.
I think since FAISS is scalable, we can instead just use a preprocessor like the neural-net data preprocessor to convert categoricals -> floats with one-hot encoding + a strict limit on the number of possible categories (i.e. keep the top K categories per feature, based on prevalence; all rarer categories are lumped into a single "other" category, so the total OHE dimension is at most K+1); see the sketch below.

Might you have bandwidth to add this in a followup PR?
The preprocessor code is here:
https://github.com/awslabs/autogluon/blob/master/autogluon/utils/tabular/ml/models/tabular_nn/tabular_nn_model.py#L659
https://github.com/awslabs/autogluon/blob/master/autogluon/utils/tabular/ml/models/tabular_nn/tabular_nn_model.py#L510-L516
(except for KNN we may want to counterintuitively not normalize continuous features, but probably do at least want to impute them if missing as done in this preprocessor)
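A rough sketch of the top-K one-hot idea above, in plain pandas (illustrative only; the linked AutoGluon preprocessor differs in the details):

```python
import numpy as np
import pandas as pd

def ohe_top_k(col: pd.Series, k: int = 20) -> pd.DataFrame:
    # Keep the K most frequent categories; lump everything else into one 'other' bucket,
    # so the one-hot dimension is at most K + 1.
    top = col.value_counts().index[:k]
    reduced = col.where(col.isin(top), other='__other__')
    return pd.get_dummies(reduced, prefix=col.name).astype(np.float32)
```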

Other ideas for followup PRs:

  • Auto-selection of best k-value based on the validation data would be awesome :)
  • Subclass BaggedEnsemble for KNN to just have one KNN model with leave-one-out estimates instead of k KNN copies
  • If hyperparameter_tune=True, use hyperparameter-tuning to select best metric and whether distance-weighted-mean over neighbors should be used to produce predictions or not.
  • Feature selection capabilities to extend to higher-dimensional datasets.

@brc7 (Contributor Author) commented Jul 21, 2020

@Innixma I rebased onto the psutil fix, but it seems that CI is still breaking. Can you confirm? (maybe I did rebase wrong)

@Innixma (Contributor) commented Jul 21, 2020

@brc7 No worries, I think the psutil fix we did did not resolve the issue, and the issue appears to be non-deterministic. We will continue looking into it, but it isn't a problem with your code.

@Innixma (Contributor) commented Jul 21, 2020

@brc7 Could you try rebasing again? I believe (I really hope) the CI is fixed now after merging #577.

@brc7 (Contributor Author) commented Jul 21, 2020

> @brc7 Could you try rebasing again? I believe (I really hope) the CI is fixed now after merging #577.

Done!

@szha commented Jul 21, 2020

Job PR-557-10 is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-557/10/index.html

@Innixma (Contributor) commented Jul 21, 2020

Great, now it has passed CI!

Next is to benchmark the new implementation against the previous.

I have written some code to help you get started:

import time

from collections import defaultdict

from autogluon import TabularPrediction as task

SAMPLE = 50000
# SAMPLE = None


if __name__ == '__main__':
    banned_models_to_compare = ['weighted_ensemble_k0_l1']

    from autogluon.utils.tabular.ml.models.knn.knn_model import FAISSModel, KNNModel
    hyperparameters = {
        FAISSModel: {},
        KNNModel: {},
    }

    path_s3_prefix = 'https://autogluon.s3.amazonaws.com/datasets/'

    dataset_names = ['AdultIncomeBinaryClassification', 'AmesHousingPriceRegression', 'CoverTypeMulticlassClassification']
    dataset_labels = ['class', 'SalePrice', 'Cover_Type']

    dataset_info = {}
    for i, dataset_name in enumerate(dataset_names):
        label = dataset_labels[i]
        path_prefix = path_s3_prefix + dataset_name + '/'
        path_train = path_prefix + 'train_data.csv'
        path_test = path_prefix + 'test_data.csv'

        X_train = task.Dataset(file_path=path_train)
        X_test = task.Dataset(file_path=path_test)

        if SAMPLE is not None and SAMPLE < len(X_train):
            X_train = X_train.sample(n=SAMPLE, random_state=0)

        predictor = task.fit(
            train_data=X_train, label=label,
            hyperparameters=hyperparameters,
        )

        leaderboard = predictor.leaderboard(X_test)

        models = leaderboard['model'].unique()
        valid_models = [model for model in models if model not in banned_models_to_compare]
        valid_models.sort()
        predictor_info = predictor.info()

        time_avg_dict = defaultdict(int)
        disk_usage_dict = {}
        num_repeats = 2
        for model in valid_models:
            time_start = time.time()
            for i in range(num_repeats):
                predictor.predict(X_test, model=model)
            time_end = time.time()
            time_avg_dict[model] = (time_end - time_start) / num_repeats
            disk_usage_dict[model] = predictor_info['model_info'][model]['memory_size']
        dataset_info[dataset_name] = (time_avg_dict, disk_usage_dict)

    for dataset_name, (time_avg_dict, disk_usage_dict) in dataset_info.items():
        print(f'On dataset {dataset_name}:')
        for model, inference_time in time_avg_dict.items():
            print(f'\tModel {model} took an average of {round(inference_time, 3)}s to infer')
        for model, disk_usage in disk_usage_dict.items():
            print(f'\tModel {model} took {round(disk_usage/1000000, 3)}MB of disk')

Running this on my Mac, I get the following results:

On dataset AdultIncomeBinaryClassification:
	Model FAISSModelClassifier took an average of 2.417s to infer
	Model KNeighborsClassifier took an average of 0.182s to infer
	Model FAISSModelClassifier took 2.05MB of disk
	Model KNeighborsClassifier took 4.361MB of disk
On dataset AmesHousingPriceRegression:
	Model FAISSModelRegressor took an average of 0.226s to infer
	Model KNeighborsRegressor took an average of 0.304s to infer
	Model FAISSModelRegressor took 0.588MB of disk
	Model KNeighborsRegressor took 1.214MB of disk
On dataset CoverTypeMulticlassClassification:
	Model FAISSModelClassifier took an average of 47.672s to infer
	Model KNeighborsClassifier took an average of 5.648s to infer
	Model FAISSModelClassifier took 20.523MB of disk
	Model KNeighborsClassifier took 42.845MB of disk

While I'm very glad to see that the disk usage is halved when using FAISS, it seems to be much slower to infer. Could you try to experiment to see why this is and remedy it? The difference seems to become larger with more training data: with SAMPLE=None, KNN is 9.7s and FAISS is 454s on CoverType.

Also, the scores between KNN and FAISS are not always identical, and can at times differ drastically. Can you explain why this is?

Another bit of optimization I stumbled upon which you might want to give a try: https://www.sicara.ai/blog/2017-07-05-fast-custom-knn-sklearn-cython

@brc7 (Contributor Author) commented Jul 22, 2020

> While I'm very glad to see that the disk usage is halved when using FAISS, it seems to be much slower to infer. Could you try to experiment to see why this is and remedy it? The difference seems to become larger with more training data: with SAMPLE=None, KNN is 9.7s and FAISS is 454s on CoverType.

The default settings in the FAISSNeighbor classes just do exact brute-force search on CPU with no smart algorithms, while sklearn automatically chooses a good index type. More sensible defaults will probably fix it, but I'll profile / benchmark to make sure there's nothing too crazy going on.

> Also, the scores between KNN and FAISS are not always identical, and can at times differ drastically. Can you explain why this is?

I have no idea but it will be interesting to find out, since they're supposed to be the same.

@jwmueller (Contributor)

Auto-selection of the best index type based on dataset characteristics & GPU/CPU hardware would be great to have, but for now I'd say the default index type should be set to maximize performance for training datasets with #rows in the 50-500k range and #columns in the 10-100 range (I'd guess this is the bulk of autogluon-KNN usage where performance actually starts to matter).

@brc7 (Contributor Author) commented Jul 23, 2020

Regarding the validation score differences:

TL;DR: FAISS cheats when it computes distances
"Flat" should have exactly the same performance as sklearn, since it brute forces all the distances and then it selects points the same way that sklearn does. But it doesn't.

Example: I issued a query for sklearn, and it returned:
indices of nearest 5 = [33858, 9245, 7542, 17405, 22308]
distances = [534.3, 548.9, 798.8, 820.9, 1424.3]

Okay, what does FAISS say for this query?
indices of nearest 5 = [33858, 9245, 7542, 17405, 22308]
distances = [535.5, 543.1, 799.4, 819.6, 1425.3]

This difference seems small but if many points have close distances, you might get a different classification with the FAISSNeighbor classifier than with sklearn.

This happens because FAISS uses BLAS on CPU to brute force distances for batch sizes larger than 20, and BLAS computes the distance in such a way that the 32bit rounding errors become a problem. This is controlled by the faiss.distance_compute_blas_threshold parameter. If you set it larger than the query batch size, I have verified that the distances become correct and the validation scores are the same, at a small speed cost.
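For reference, a sketch of raising that threshold from Python (assuming the global is exposed via faiss.cvar, as in recent faiss-cpu builds):

```python
import faiss

# Use the non-BLAS exact distance path for query batches up to this size,
# at a small speed cost; the returned distances then match sklearn's.
faiss.cvar.distance_compute_blas_threshold = 1_000_000
```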

Do we want to do anything about this? Note that the other (faster) index types, such as HNSW, are approximate, so the rounding errors are less noticeable there.

@Innixma (Contributor) commented Jul 23, 2020

It's primarily down to what is practically most useful. I think the validation score differences are less of a concern at present than the inference times. If it is only a small speed decrease to match KNN on validation, it might be best so that we don't have to take that additional factor into consideration when determining which approach is better.

@brc7 (Contributor Author) commented Jul 24, 2020

All right, I did some comparisons. FAISS has a lot of index types, so this is only a partial comparison (even though there are lots of rows in the table).

PQ is product quantization (with brute-force search), IVFx is a type of sharding (where we only check some of the "x" shards), Flat is a brute-force search, and HNSWx is the approximate near neighbor graph method with "x" connections per node.

ef_search is a parameter that trades off the HNSW accuracy for speed. By setting it to minimum, you get the fastest possible HNSW with maybe less good results.
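For context on the "(min ef_search)" rows in the tables below, a sketch of how that parameter is set on an HNSW index via the Python bindings (values are illustrative, not the exact ones used for these benchmarks):

```python
import faiss

d = 64
index = faiss.index_factory(d, "HNSW4")

# efSearch trades accuracy for speed at query time; larger = more accurate but slower.
index.hnsw.efSearch = 16  # a moderate setting
index.hnsw.efSearch = 1   # "min ef_search": fastest, least accurate
```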

AdultIncomeBinaryClassification

| Index | Train (sec) | Infer (sec) | Size (MB) | Validation Score |
|---|---|---|---|---|
| sklearn | 0.200 | 0.154 | 4.361 | 0.7752 |
| PQ2 | 19.172 | 0.501 | 0.453 | 0.772 |
| PQ3 | 20.214 | 0.606 | 0.527 | 0.788 |
| IVF256,Flat | 0.325 | 0.121 | 2.652 | 0.772 |
| IVF256,HNSW4 | 0.630 | 0.102 | 5.66 | 0.766 |
| IVF1024,HNSW4 | 0.613 | 0.111 | 5.66 | 0.766 |
| Flat | 0.078 | 0.879 | 2.05 | 0.7736 |
| HNSW4 | 0.617 | 0.121 | 5.66 | 0.765 |
| HNSW4 (min ef_search) | 0.615 | 0.098 | 5.66 | 0.762 |
| HNSW4_SQ8 | 0.656 | 0.101 | 5.66 | 0.765 |
| HNSW4_SQ8 (min ef_search) | 0.606 | 0.093 | 5.66 | 0.762 |
| HNSW6 | 0.552 | 0.115 | 6.787 | 0.768 |
| HNSW6 (min ef_search) | 0.581 | 0.105 | 6.787 | 0.765 |
| HNSW8_SQ8 (min ef_search) | 0.751 | 0.128 | 7.945 | 0.759 |
| HNSW16 | 0.920 | 0.288 | 21.963 | 0.770 |
| HNSW16 (min ef_search) | 0.591 | 0.162 | 12.604 | 0.765 |
| HNSW32 (min ef_search) | 0.934 | 0.253 | 21.963 | 0.769 |

AmesHousingPriceRegression

| Index | Train (sec) | Infer (sec) | Size (MB) | Validation Score |
|---|---|---|---|---|
| sklearn | 0.024 | 0.217 | 1.214 | -- |
| Flat | 0.126 | 0.134 | 0.588 | -- |
| PQ2 | 18.741 | 0.151 | 0.104 | -- |
| IVF256,Flat | 0.059 | 0.152 | 0.699 | -- |
| IVF256,HNSW4 | 0.053 | 0.130 | 0.773 | -- |
| IVF1024,HNSW4 | 0.054 | 0.135 | 0.773 | -- |
| HNSW4 | 0.055 | 0.136 | 0.773 | -- |
| HNSW4 (min ef_search) | 0.053 | 0.141 | 0.773 | -- |
| HNSW4_SQ8 | 0.055 | 0.142 | 0.773 | -- |
| HNSW4_SQ8 (min ef_search) | 0.052 | 0.149 | 0.773 | -- |
| HNSW6 | 0.064 | 0.135 | 0.831 | -- |
| HNSW6 (min ef_search) | 0.045 | 0.136 | 0.831 | -- |
| HNSW8_PQ8 (min ef_search) | 0.047 | 0.151 | 0.891 | -- |
| HNSW16 | 0.077 | 0.178 | 1.608 | -- |
| HNSW16 (min ef_search) | 0.038 | 0.155 | 1.128 | -- |
| HNSW32 (min ef_search) | 0.061 | 0.148 | 1.608 | -- |

CoverTypeMulticlassClassification

| Index | Train (sec) | Infer (sec) | Size (MB) | Validation Score |
|---|---|---|---|---|
| sklearn | 5.06 | 2.272 | 419.6 | 0.9653 |
| Flat | 0.982 | 139.4 | 202.5 | 0.9655 |
| PQ2 | 21.214 | 72.180 | 5.636 | 0.707 |
| PQ3 | 21.545 | 96.833 | 6.556 | 0.726 |
| IVF256,Flat | 4.199 | 6.370 | 209.951 | 0.963 |
| IVF256,HNSW4 | 12.283 | 2.964 | 247.869 | 0.956 |
| IVF1024,HNSW4 | 12.130 | 2.887 | 247.869 | 0.955 |
| HNSW4 | 13.447 | 2.921 | 247.869 | 0.955 |
| HNSW4 (min ef_search) | 13.363 | 2.575 | 247.869 | 0.934 |
| HNSW4_PQ8 | 11.938 | 3.175 | 247.869 | 0.955 |
| HNSW4_PQ8 (min ef_search) | 12.651 | 2.680 | 247.869 | 0.934 |
| HNSW6 | 10.973 | 3.542 | 262.091 | 0.964 |
| HNSW6 (min ef_search) | 11.240 | 2.578 | 262.091 | 0.953 |
| HNSW8_PQ8 (min ef_search) | 12.095 | 2.685 | 276.621 | 0.957 |
| HNSW16 (min ef_search) | 10.803 | 3.343 | 335.245 | 0.961 |
| HNSW16 | 17.121 | 5.208 | 452.907 | 0.965 |
| HNSW32 (min ef_search) | 17.856 | 5.254 | 452.907 | 0.963 |

A good step toward an auto-selection routine would be to encode these guidelines, but for now I am leaning towards an HNSW index, probably with IVF (dependent on n) and some form of vector compression as the default. The trouble is that the best index is pretty problem-dependent, but I think HNSW is a good all-purpose option.

@jwmueller (Contributor)

@brc7 Thanks for the experiments! Hopefully extending the KNN to categoricals will not greatly alter these conclusions (I guess if we use OHE, the KNN will just operate on higher-dimensional vectors where quantization would anyway have no accuracy-effect for OHE dimensions).

Do you think we should view validation scores for the regression task as well? I suspect the approximation-factor of nearest-neighbor retrieval may have an even greater effect on accuracy in regression.

Re the 2 datasets with validation accuracies: sklearn performance seems unmatched by any point along FAISS' Pareto-frontier of accuracy/inference-speed. If all sklearn were doing differently was automatically choosing the appropriate index, I'd expect its performance would lie at one point along FAISS' Pareto-frontier. In particular, none of the FAISS models (regardless of index) are as fast as sklearn for CoverType; any insight on this?

@brc7 (Contributor Author) commented Jul 27, 2020

> Do you think we should view validation scores for the regression task as well?

I was getting negative validation scores that I wasn't sure how to interpret. It's trivial to re-run experiments, so happy to include if desired.

> sklearn performance seems unmatched by any point along FAISS' Pareto-frontier of accuracy/inference-speed ... In particular, none of the FAISS models (regardless of index) are as fast as sklearn for CoverType; any insight on this?

Sklearn has different algorithm options than FAISS. I looked at their auto-selection code - it seems sklearn always uses a BallTree unless you have a really weird distance function. It may be that I am configuring the FAISS index wrong, or maybe BallTrees are actually better for typical AutoGluon use cases.

I think it's a combination of both. FAISS (HNSW) can get great performance (look at ann-benchmarks), but those benchmarks do a lot of dataset-specific tuning and are on large problems. I suspect that due to memory limits, FAISS will be the only option for large-scale AutoML (i.e. N > 10 million) or with really high dimensional inputs (i.e. OHE for a big multi-class problem), and the best option when a GPU is available.

@Innixma (Contributor) commented Jul 27, 2020

Regarding large scale data, can you try to also get benchmark results for uncapped Covertype (SAMPLE=None)? This might help indicate how well FAISS scales compared to Sklearn. Also, even if FAISS can't exceed sklearn but at least get close, it still may have significant use when bagging if we can get a shared index.

Regarding the state of the code, I think we are good to merge after the dependency is removed from setup.py for faiss, so that it is purely an extra optional dependency until we nail down more aspects of its performance. Can you please update the PR to have setup.py be unchanged? Then I am good to give approval and merge.

@brc7 (Contributor Author) commented Jul 27, 2020

Should have clarified - all results are for SAMPLE=None. The covtype experiments are with the full 464k data points.

I've removed the setup.py dependency - should be good to go

@szha commented Jul 27, 2020

Job PR-557-11 is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-557/11/index.html

@Innixma (Contributor) left a comment

Thanks for the contribution!

@Innixma Innixma merged commit 7afea4a into autogluon:master Jul 27, 2020
@jwmueller (Contributor) commented Jul 28, 2020

@brc7 The validation scores are negative for regression because we assume higher==better for all metrics so those like MSE are flipped in sign. I think looking at some regression performance will be important in choosing the index (we can ofc use different default index for regression vs classification).

Thanks for the explanation re sklearn's BallTree vs FAISS options. I agree FAISS should still be the best option for large datasets, but its advantage seems less clear for these 3 datasets (except model-size has improved), so it's important to establish where FAISS >> sklearn.

Also, if we end up needing to keep sklearn KNN around as well as FAISS (eg. for different data regimes), it may be worth standardizing their APIs a bit more so new KNN-functionality can be developed for both sklearn/FAISS models without duplicate code.

@jwmueller (Contributor)

@Innixma

> even if FAISS can't exceed sklearn but at least get close, it still may have significant use when bagging if we can get a shared index.

I'm not sure what you mean here. The way I envisioned it, either sklearn-KNN or FAISS could be used to produce held-out predictions on-demand from one index (i.e. no bagging), so whichever is better could be used. My thinking was just to produce leave-one-out predictions for each held-out datapoint after first calling KNN fit() only once on the entire training set (building the shared index). To produce a KNN held-out prediction for a datapoint, the resulting KNN model just needs a special predict_loo() method which can ignore 1 of the nearest neighbors (whose distance == 0), since this same datapoint is also present in the KNN training set.

Technically this shared KNN model should be built with (k+1) replacing hyperparameter k, so if k affects the index-creation, there theoretically might need to be two KNN models (one with k used for inference, one with (k+1) used during training); for reasonable values of k this difference probably doesn't practically matter...

One practical issue that may arise with this leave-one-out scheme is the KNN validation-scores may be biased higher than other models because KNN is trained on more data (all except one datapoint) than 10-fold training offers. But this also shouldn't affect ensemble-construction, just the leaderboard.
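A rough sketch of the predict_loo() idea described above (not implemented in this PR), shown with sklearn's NearestNeighbors for concreteness:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def predict_loo(nn_index, X_train, y_train, k=5):
    # Query k+1 neighbors from the shared index and drop the self-match
    # (the training point itself, at distance 0), leaving k genuine neighbors.
    dist, ind = nn_index.kneighbors(X_train, n_neighbors=k + 1)
    ind = ind[:, 1:]
    return y_train[ind].mean(axis=1)  # regression-style leave-one-out prediction

# Usage sketch:
# nn = NearestNeighbors().fit(X_train)
# oof_preds = predict_loo(nn, X_train, y_train, k=5)
```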

@Innixma (Contributor) commented Jul 29, 2020

@brc7 and @jwmueller

Regarding small - medium datasets (<10M rows), perhaps we could use nmslib?

According to ann-benchmarks, nmslib's HNSW implementation is SOTA, and nearly an order of magnitude faster compared to FAISS: https://github.com/erikbern/ann-benchmarks

The downside is it doesn't scale as well to truly huge 1B+ row datasets like FAISS can, but even the FAISS authors recommend using nmslib for medium-small datasets when memory isn't a major concern: https://github.com/facebookresearch/faiss/wiki/Indexing-1M-vectors facebookresearch/faiss#23
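For reference, a quick sketch of what nmslib's HNSW usage looks like (standard nmslib API; parameter values are illustrative):

```python
import numpy as np
import nmslib

data = np.random.rand(100000, 32).astype(np.float32)

index = nmslib.init(method='hnsw', space='l2')
index.addDataPointBatch(data)
index.createIndex({'M': 16, 'efConstruction': 100}, print_progress=False)
index.setQueryTimeParams({'efSearch': 50})

# Batched k-NN queries; returns one (ids, distances) pair per query row.
neighbors = index.knnQueryBatch(data[:10], k=5, num_threads=4)
```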

@willsmithorg (Contributor) commented Jan 12, 2022

There's a great video on HNSW, how it works, and how to choose the parameters to trade off disk space, accuracy, inference, and model build time here: https://www.pinecone.io/learn/hnsw/.
Or just read that page - the content is the same as the video. The video is 35 minutes.

@Innixma (Contributor) commented Jan 12, 2022

@willsmithorg Thanks for this! We may consider this deeper when we look into optimizing for 10M+ row datasets. Currently KNN is ~250x faster in v0.3.1 than it was at the time of this PR due to numerous optimizations (usage of sklearnex + efficient LOO for bagging), so the concern over having a faster KNN model has become less critical. Still, the issue of practicality on truly large datasets remains an open problem and we will likely turn towards ANN to solve.
