feat: add vectorize method #134

Open · wants to merge 7 commits into master
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -18,6 +18,7 @@ Changelog
Ver 0.1.*
---------

* |Feature| |API| Add :meth:`vectorize` for faster ensemble model inference using :mod:`functorch` (requiring :mod:`torch` version >= 1.13.0) | `@xuyxu <https://github.com/xuyxu>`__
* |Feature| |API| Add ``voting_strategy`` parameter for :class:`VotingClassifier`, :class:`NeuralForestClassifier`, and :class:`SnapshotEnsembleClassifier` | `@LukasGardberg <https://github.com/LukasGardberg>`__
* |Fix| Fix the sampling issue in :class:`BaggingClassifier` and :class:`BaggingRegressor` | `@SunHaozhe <https://github.com/SunHaozhe>`__
* |Feature| |API| Add :class:`NeuralForestClassifier` and :class:`NeuralForestRegressor` | `@xuyxu <https://github.com/xuyxu>`__
15 changes: 15 additions & 0 deletions docs/advanced.rst
@@ -0,0 +1,15 @@
Advanced Usage
==============

The following sections outline advanced usage in :mod:`torchensemble`.

Faster inference using functorch
--------------------------------

:mod:`functorch`, a library of JAX-like composable function transforms for PyTorch, has been integrated into PyTorch since the release of version 1.13. To enable faster inference of ensembles in :mod:`torchensemble`, you can use the :meth:`vectorize` method of an ensemble to convert it into a stateless version (``fmodel``) along with its stacked parameters and buffers.

The stateless model, parameters, and buffers can then be used with :func:`functorch.vmap` to reduce inference time. More details are available in the `functorch documentation <https://pytorch.org/functorch/stable/notebooks/ensembling.html>`__. The code snippet below demonstrates how to pass the results of :meth:`ensemble.vectorize` into :func:`functorch.vmap`.

.. code:: python

    from torchensemble import VotingClassifier  # voting is a classic ensemble strategy
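The snippet in this diff is truncated after the import; a minimal sketch of the full pattern is given below. The already-fitted ``ensemble`` and the input batch are assumptions for illustration; only :meth:`vectorize` and :func:`functorch.vmap` follow the API added in this PR.

.. code:: python

    import torch
    from functorch import vmap
    from torchensemble import VotingClassifier  # voting is a classic ensemble strategy

    # `ensemble` is assumed to be an already-fitted VotingClassifier
    fmodel, params, buffers = ensemble.vectorize()

    data = torch.randn(8, 10)  # hypothetical input batch

    # Map the stateless forward pass over the stacked parameters and
    # buffers (dim 0), sharing the same input batch across estimators.
    with torch.no_grad():
        outputs = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, data)

    # `outputs` has shape (n_estimators, batch_size, n_outputs)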
1 change: 1 addition & 0 deletions docs/index.rst
@@ -72,6 +72,7 @@ Content
Guidance <guide>
Experiment <experiment>
API Reference <parameters>
Advanced Usage <advanced>

.. toctree::
:maxdepth: 1
18 changes: 17 additions & 1 deletion torchensemble/_base.py
@@ -29,9 +29,10 @@ def get_doc(item):
        __doc = {
            "model": const.__model_doc,
            "seq_model": const.__seq_model_doc,
            "tree_ensmeble_model": const.__tree_ensemble_doc,
            "tree_ensemble_model": const.__tree_ensemble_doc,
            "fit": const.__fit_doc,
            "predict": const.__predict_doc,
            "vectorize": const.__vectorize_doc,
            "set_optimizer": const.__set_optimizer_doc,
            "set_scheduler": const.__set_scheduler_doc,
            "set_criterion": const.__set_criterion_doc,
@@ -199,6 +200,21 @@ def predict(self, *x):
        pred = pred.cpu()
        return pred

    def vectorize(self):
        """Docstrings decorated by downstream ensembles."""
        try:
            from functorch import combine_state_for_ensemble
        except ImportError as e:
            msg = (
                "Failed to import functorch, please make sure that the"
                " PyTorch version is >= 1.13.0."
            )
            raise RuntimeError(msg) from e

        # Switch to eval mode so all estimators behave consistently at
        # inference time before their states are combined.
        self.eval()
        fmodel, params, buffers = combine_state_for_ensemble(self.estimators_)
        return fmodel, params, buffers


class BaseTreeEnsemble(BaseModule):
def __init__(
15 changes: 15 additions & 0 deletions torchensemble/_constants.py
@@ -173,6 +173,21 @@
"""


__vectorize_doc = """
    Return the vectorized version of the ensemble using functorch. Details
    are available at `functorch model ensembling <https://pytorch.org/functorch/stable/notebooks/ensembling.html>`_.

    Returns
    -------
    fmodel : FunctionalModuleWithBuffers
        Stateless (functional) version of one of the models in the ensemble.
    params : tuple
        Tuple of model parameters, each stacked across all estimators in the
        ensemble.
    buffers : tuple
        Tuple of stacked model buffers, empty if the models have no buffers.
"""  # noqa: E501


__classification_forward_doc = """
Parameters
----------
78 changes: 78 additions & 0 deletions torchensemble/tests/test_all_models.py
@@ -3,6 +3,8 @@
import numpy as np
import torch.nn as nn
from numpy.testing import assert_array_equal

from functorch import vmap
from torch.utils.data import TensorDataset, DataLoader

import torchensemble
@@ -302,3 +304,79 @@ def test_predict():
    with pytest.raises(ValueError) as excinfo:
        model.predict([X_test])  # list
    assert "The type of input X should be one of" in str(excinfo.value)


@pytest.mark.parametrize("clf", all_clf)
def test_clf_vectorize_same_output(clf):
    """
    Check that inference with and without vectorize yields the same
    output for all classifiers.
    """
    epochs = 2
    n_estimators = 2

    model = clf(estimator=MLP_clf, n_estimators=n_estimators, cuda=False)

    # Optimizer
    model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4)

    # Prepare data
    train = TensorDataset(X_train, y_train_clf)
    train_loader = DataLoader(train, batch_size=2, shuffle=False)
    test = TensorDataset(X_test, y_test_clf)
    test_loader = DataLoader(test, batch_size=2, shuffle=False)

    # Train
    model.fit(train_loader, epochs=epochs, test_loader=test_loader)

    fmodel, params, buffers = model.vectorize()

    with torch.no_grad():
        for data, _ in test_loader:
            vmap_output = vmap(fmodel, in_dims=(0, 0, None))(
                params, buffers, data
            )
            pytorch_output = [
                estimator(data) for estimator in model.estimators_
            ]
            assert torch.allclose(
                vmap_output, torch.stack(pytorch_output), atol=1e-3, rtol=1e-5
            )


@pytest.mark.parametrize("reg", all_reg)
def test_reg_vectorize_same_output(reg):
    """
    Check that inference with and without vectorize yields the same
    output for all regressors.
    """
    epochs = 2
    n_estimators = 2

    model = reg(estimator=MLP_reg, n_estimators=n_estimators, cuda=False)

    # Optimizer
    model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4)

    # Prepare data
    train = TensorDataset(X_train, y_train_reg)
    train_loader = DataLoader(train, batch_size=2, shuffle=False)
    test = TensorDataset(X_test, y_test_reg)
    test_loader = DataLoader(test, batch_size=2, shuffle=False)

    # Train
    model.fit(train_loader, epochs=epochs, test_loader=test_loader)

    fmodel, params, buffers = model.vectorize()

    with torch.no_grad():
        for data, _ in test_loader:
            vmap_output = vmap(fmodel, in_dims=(0, 0, None))(
                params, buffers, data
            )
            pytorch_output = [
                estimator(data) for estimator in model.estimators_
            ]
            assert torch.allclose(
                vmap_output, torch.stack(pytorch_output), atol=1e-3, rtol=1e-5
            )
20 changes: 18 additions & 2 deletions torchensemble/voting.py
@@ -307,9 +307,13 @@ def evaluate(self, test_loader, return_loss=False):
    def predict(self, *x):
        return super().predict(*x)

    @torchensemble_model_doc(item="vectorize")
    def vectorize(self):
        return super().vectorize()


@torchensemble_model_doc(
    """Implementation on the NeuralForestClassifier.""", "tree_ensmeble_model"
    """Implementation on the NeuralForestClassifier.""", "tree_ensemble_model"
)
class NeuralForestClassifier(BaseTreeEnsemble, VotingClassifier):
    def __init__(self, voting_strategy="soft", **kwargs):
@@ -374,6 +378,10 @@ def fit(
            save_dir=save_dir,
        )

    @torchensemble_model_doc(item="vectorize")
    def vectorize(self):
        return super().vectorize()


@torchensemble_model_doc("""Implementation on the VotingRegressor.""", "model")
class VotingRegressor(BaseRegressor):
@@ -559,9 +567,13 @@ def evaluate(self, test_loader):
    def predict(self, *x):
        return super().predict(*x)

    @torchensemble_model_doc(item="vectorize")
    def vectorize(self):
        return super().vectorize()


@torchensemble_model_doc(
    """Implementation on the NeuralForestRegressor.""", "tree_ensmeble_model"
    """Implementation on the NeuralForestRegressor.""", "tree_ensemble_model"
)
class NeuralForestRegressor(BaseTreeEnsemble, VotingRegressor):
    @torchensemble_model_doc(
@@ -620,3 +632,7 @@ def fit(
            save_model=save_model,
            save_dir=save_dir,
        )

    @torchensemble_model_doc(item="vectorize")
    def vectorize(self):
        return super().vectorize()