
CSYE7105 - Parallel Machine Learning and AI

Instructor: Dr. Handan Liu

Example: Parallel Machine Learning: XGBoost in Parallel with Dask


In [None]:
from dask.distributed import Client

# Start a local Dask cluster for this notebook: one worker, created with
# processes=False, using eight threads.
client = Client(
    processes=False,
    n_workers=1,
    threads_per_worker=8,
)
client

In [None]:
from dask_ml.datasets import make_classification

# Synthetic classification problem: 100,000 samples x 20 features, of which
# 4 are informative, generated in chunks of 1000 with a fixed seed so the
# data is the same on every run.
X, y = make_classification(
    n_samples=100000,
    n_features=20,
    n_informative=4,
    chunks=1000,
    random_state=0,
)
X

In [None]:
from dask_ml.model_selection import train_test_split

# Hold out 15% of the samples for evaluation. The split is seeded (like the
# data generation) so that Restart & Run All reproduces the same train/test
# partition and therefore the same metrics and plots.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.15, random_state=0
)

In [None]:
import dask_ml
import xgboost

In [None]:
from xgboost import dask as dxgb

In [None]:
# from xgboost import XGBRegressor 

In [None]:
import dask_xgboost  # NOTE(review): standalone dask-xgboost package; xgboost.dask (imported as dxgb) appears to cover the same role -- confirm this import is still needed

In [None]:
# Booster hyper-parameters: logistic objective for binary classification,
# shallow trees, a small learning rate and 50% row subsampling per round.
params = {'objective': 'binary:logistic',
          'max_depth': 4, 'eta': 0.01, 'subsample': 0.5,
          'min_child_weight': 0.5}

# Train through XGBoost's built-in Dask interface (`dxgb`) rather than the
# retired standalone dask_xgboost package. The training data must first be
# wrapped in a DaskDMatrix; dxgb.train returns a dict holding the trained
# booster under 'booster' and the evaluation log under 'history'.
dtrain = dxgb.DaskDMatrix(client, X_train, y_train)
train_output = dxgb.train(client, params, dtrain, num_boost_round=10)

# Keep `bst` bound to the Booster itself, as before.
bst = train_output['booster']

In [None]:
# Scikit-learn style estimator backed by Dask: 100 boosting rounds using the
# histogram tree method.
clf = dxgb.DaskXGBClassifier(n_estimators=100, tree_method="hist")
clf.client = client  # assign the client

# fit() returns the estimator itself, so do not bind its result to `bst`:
# later cells hand `bst` to dxgb.predict, which accepts a Booster (or the
# dict returned by dxgb.train) and raises TypeError for an estimator.
clf.fit(X_train, y_train, eval_set=[(X_test, y_test)])
bst = clf.get_booster()

# Class probabilities for the training samples.
proba = clf.predict_proba(X_train)

In [None]:
# Show the fitted classifier's repr.
clf

In [None]:
# Show the result of clf.predict_proba(X_train); presumably still a lazy
# dask collection at this point -- verify.
proba 

In [None]:
# Materialise the predicted probabilities (the method is `compute`;
# `computer` does not exist and raised AttributeError).
proba.compute()

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

ax = xgboost.plot_importance(bst, height=0.8, max_num_features=9)
ax.grid(False, axis="y")
ax.set_title('Estimated feature importance');

In [None]:
# dxgb.predict needs a Booster (or the dict returned by dxgb.train); it raises
# TypeError for a scikit-learn style estimator. Unwrap the underlying booster
# when `bst` is an estimator, otherwise use it as is.
booster = bst.get_booster() if hasattr(bst, "get_booster") else bst

# Predict on the held-out set and persist the result on the cluster.
y_hat = dxgb.predict(client, booster, X_test).persist()
y_hat

In [None]:
# `dask` itself was never imported (only dask.distributed.Client and dask_ml),
# so dask.compute raised NameError on a fresh kernel.
import dask
from sklearn.metrics import roc_curve

# Materialise labels and predictions together in one compute call, then hand
# the concrete results to scikit-learn.
y_test, y_hat = dask.compute(y_test, y_hat)
fpr, tpr, _ = roc_curve(y_test, y_hat)

In [None]:
from sklearn.metrics import auc

# Area under the ROC curve, shown in the legend label.
roc_label = 'ROC Curve (area = {:.2f})'.format(auc(fpr, tpr))

fig, ax = plt.subplots(figsize=(5, 5))

# ROC curve plus a dashed diagonal from (0, 0) to (1, 1) for reference.
ax.plot(fpr, tpr, lw=3, label=roc_label)
ax.plot([0, 1], [0, 1], 'k--', lw=2)

# Both axes span [0, 1].
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_title("ROC Curve")
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.legend();