In [32]:
import numpy as np
import torch, pickle
import lightgbm as lgb
from hummingbird import convert_sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import fetch_covtype


In [33]:
# Load the Forest CoverType data bundled with scikit-learn, keeping only the
# first `nrows` rows so the benchmark stays quick.
nrows = 2500
X, y = fetch_covtype(return_X_y=True)
X, y = X[:nrows], y[:nrows]
# Float32 copy of the features; this is what the converted PyTorch model is fed below.
X_torch = torch.from_numpy(X).float()

In [34]:
# Create and train a model
model = RandomForestClassifier(n_estimators=10)
model.fit(X, y)

RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
                       max_depth=None, max_features='auto', max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=10,
                       n_jobs=None, oob_score=False, random_state=None,
                       verbose=0, warm_start=False)

In [35]:
# Compile the fitted scikit-learn forest into a PyTorch model with Hummingbird,
# selecting its "gemm" tree implementation.
pytorch_model = convert_sklearn(model, extra_config={"tree_implementation": "gemm"})

In [36]:
%%timeit -r 3

#time for skl
skl = model.predict(X)

4.58 ms ± 66.7 µs per loop (mean ± std. dev. of 3 runs, 100 loops each)


In [37]:
%%timeit -r 3

# time for hummingbird - CPU
pytorch_model.to('cpu')
hum_cpu = pytorch_model(X_torch.float())

57.5 ms ± 119 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)


In [38]:
%%timeit -r 3

# time for hummingbird - GPU. Note that you must have a GPU-enabled machine.
pytorch_model.to('cuda')
hum_gpu = pytorch_model(X_torch.to('cuda'))

3.49 ms ± 664 ns per loop (mean ± std. dev. of 3 runs, 100 loops each)


In [51]:
# Make sure Hummingbird's output matches scikit-learn as expected.
# (`skl` has to be recomputed here: variables assigned inside a `%%timeit`
# cell do not persist in the notebook namespace.)
skl = model.predict_proba(X)
pytorch_model.to('cuda')
# Inference only: no_grad skips autograd bookkeeping for this forward pass.
with torch.no_grad():
    hum_gpu = pytorch_model(X_torch.to('cuda'))

# The model returns a tuple; element 1 is the one compared against
# predict_proba, i.e. it is taken to hold the class probabilities.
# TODO: atol=1e-1 is very loose for probabilities in [0, 1]. Tighten it once
# the source of the mismatch (presumably float32 inference vs. scikit-learn's
# float64 output) is understood.
np.testing.assert_allclose(skl, hum_gpu[1].detach().cpu().numpy(), rtol=1e-6, atol=1e-1)