# Tree-based models - Examples

In [1]:
import warnings
warnings.filterwarnings("ignore")

import pandas as pd

from sklearn.datasets import load_boston
from sklearn.ensemble import (
    RandomForestClassifier,
    RandomForestRegressor,
    GradientBoostingClassifier,
    GradientBoostingRegressor,
)
from sklearn.metrics import mean_squared_error
from sklearn.tree import (
    DecisionTreeClassifier,
    DecisionTreeRegressor,
)

from utils.utils import *

# Decision trees

## Classification

In [2]:
# Build the toy classification dataset and split it into features / target.
# make_data, separate_target_variable and process_categorical_features are not
# defined in this notebook; presumably they come from the
# `from utils.utils import *` line at the top -- verify against that module.
data = make_data()
X_raw, y = separate_target_variable(data)

# Keep the unprocessed features under their own name so this step can be
# re-run without feeding already-processed data back into the helper.
# `X` and `y` remain the names the later classification cells rely on.
X = process_categorical_features(X_raw)

# Fit a single decision tree and show its predictions for the same rows it
# was trained on (in-sample predictions, not a held-out evaluation).
# random_state is pinned so repeated runs build the same tree.
dtc = DecisionTreeClassifier(random_state=42)
dtc.fit(X, y)
dtc.predict(X)

array([0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0])

## Regression

In [3]:
def prepare_boston():
    """Load the Boston housing data as pandas objects.

    Returns:
        tuple: ``(X, y)`` where ``X`` is a ``pandas.DataFrame`` built from
        ``boston.data`` with ``boston.feature_names`` as column labels, and
        ``y`` is a ``pandas.Series`` built from ``boston.target`` and named
        ``'price'``.
    """
    # NOTE(review): scikit-learn deprecated load_boston in 1.0 and removed it
    # in 1.2, so this only works on older releases -- confirm the pinned
    # version or plan a switch to another regression dataset.
    bunch = load_boston()
    features = pd.DataFrame(bunch.data, columns=bunch.feature_names)
    target = pd.Series(bunch.target, name='price')
    return features, target


# Regression features / target used by every regression cell below.
X_, y_ = prepare_boston()

In [4]:
# Fit a single regression tree on the Boston features and show its
# predictions for the training rows themselves (in-sample predictions).
# random_state is pinned so repeated runs build the same tree.
# NOTE(review): the name `dtc` is kept although this is a regressor, because
# later cells may read it -- consider a clearer name once that is confirmed.
dtc = DecisionTreeRegressor(random_state=42)
dtc.fit(X_, y_)
dtc.predict(X_)

array([24. , 21.6, 34.7, 33.4, 36.2, 28.7, 22.9, 27.1, 16.5, 18.9, 15. ,
       18.9, 21.7, 20.4, 18.2, 19.9, 23.1, 17.5, 20.2, 18.2, 13.6, 19.6,
       15.2, 14.5, 15.6, 13.9, 16.6, 14.8, 18.4, 21. , 12.7, 14.5, 13.2,
       13.1, 13.5, 18.9, 20. , 21. , 24.7, 30.8, 34.9, 26.6, 25.3, 24.7,
       21.2, 19.3, 20. , 16.6, 14.4, 19.4, 19.7, 20.5, 25. , 23.4, 18.9,
       35.4, 24.7, 31.6, 23.3, 19.6, 18.7, 16. , 22.2, 25. , 33. , 23.5,
       19.4, 22. , 17.4, 20.9, 24.2, 21.7, 22.8, 23.4, 24.1, 21.4, 20. ,
       20.8, 21.2, 20.3, 28. , 23.9, 24.8, 22.9, 23.9, 26.6, 22.5, 22.2,
       23.6, 28.7, 22.6, 22. , 22.9, 25. , 20.6, 28.4, 21.4, 38.7, 43.8,
       33.2, 27.5, 26.5, 18.6, 19.3, 20.1, 19.5, 19.5, 20.4, 19.8, 19.4,
       21.7, 22.8, 18.8, 18.7, 18.5, 18.3, 21.2, 19.2, 20.4, 19.3, 22. ,
       20.3, 20.5, 17.3, 18.8, 21.4, 15.7, 16.2, 18. , 14.3, 19.2, 19.6,
       23. , 18.4, 15.6, 18.1, 17.4, 17.1, 13.3, 17.8, 14. , 14.4, 13.4,
       15.6, 11.8, 13.8, 15.6, 14.6, 17.8, 15.4, 21

# Random Forests

## Classification

In [5]:
# Fit a random forest classifier on the toy classification data and show its
# predictions for the training rows themselves (in-sample predictions).
# A random forest is a randomized ensemble; without a fixed random_state the
# predictions shown under this cell can differ from one run to the next.
# NOTE(review): the name `dtc` is kept although this is not a decision tree,
# because later cells may read it.
dtc = RandomForestClassifier(random_state=42)
dtc.fit(X, y)
dtc.predict(X)

array([0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0])

## Regression

In [6]:
# Fit a random forest regressor on the Boston features and show its
# predictions for the training rows themselves (in-sample predictions).
# A random forest is a randomized ensemble; without a fixed random_state the
# predictions shown under this cell can differ from one run to the next.
# NOTE(review): the name `dtc` is kept although this is not a decision tree,
# because later cells may read it.
dtc = RandomForestRegressor(random_state=42)
dtc.fit(X_, y_)
dtc.predict(X_)

array([26.43, 22.2 , 35.22, 34.9 , 34.68, 26.36, 22.77, 23.3 , 16.2 ,
       18.61, 17.22, 19.75, 21.46, 20.25, 18.47, 20.05, 22.54, 18.59,
       19.9 , 18.34, 13.72, 18.92, 15.59, 14.49, 15.74, 14.7 , 16.97,
       15.11, 19.15, 23.16, 12.99, 16.99, 14.51, 13.45, 13.78, 20.2 ,
       20.61, 20.78, 24.07, 30.53, 35.26, 27.9 , 25.24, 24.67, 21.  ,
       19.56, 20.06, 18.68, 15.95, 18.83, 20.24, 20.92, 25.03, 22.89,
       18.67, 34.9 , 23.92, 31.95, 23.28, 19.86, 18.71, 16.45, 23.08,
       24.74, 31.36, 23.64, 20.5 , 21.61, 18.79, 20.91, 24.04, 21.16,
       22.79, 23.3 , 24.07, 21.59, 20.78, 20.82, 20.96, 20.66, 26.35,
       24.07, 24.01, 23.22, 23.82, 26.4 , 21.68, 22.72, 27.31, 30.2 ,
       22.09, 21.81, 22.82, 24.55, 21.71, 27.85, 21.12, 40.29, 43.08,
       33.06, 25.65, 25.91, 19.35, 19.3 , 19.81, 18.93, 19.59, 19.89,
       19.9 , 19.19, 21.37, 23.57, 19.27, 18.1 , 18.73, 18.51, 20.91,
       19.37, 20.06, 19.37, 21.8 , 20.33, 20.59, 17.3 , 18.62, 21.29,
       15.96, 15.55,

# Gradient Boosting

## Classification

In [7]:
# Fit a gradient boosting classifier on the toy classification data and show
# its predictions for the training rows themselves (in-sample predictions).
# random_state seeds the estimator's internal randomness so the fitted model
# and the output shown under this cell are the same on every run.
# NOTE(review): the name `dtc` is kept although this is not a decision tree,
# because later cells may read it.
dtc = GradientBoostingClassifier(random_state=42)
dtc.fit(X, y)
dtc.predict(X)

array([0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0])

## Regression

In [8]:
# Fit a gradient boosting regressor on the Boston features and show its
# predictions for the training rows themselves (in-sample predictions).
# random_state seeds the estimator's internal randomness so the fitted model
# and the output shown under this cell are the same on every run.
# NOTE(review): the name `dtc` is kept although this is not a decision tree,
# because later cells may read it.
dtc = GradientBoostingRegressor(random_state=42)
dtc.fit(X_, y_)
dtc.predict(X_)

array([25.90772604, 21.96320179, 33.92712155, 34.14528061, 35.41267912,
       26.7925396 , 21.48031031, 20.87839556, 16.95411564, 18.45898255,
       18.05928146, 20.04582877, 19.88575493, 20.39575276, 18.96852027,
       20.21179657, 21.76179638, 16.96912497, 19.2506871 , 19.01636451,
       14.00404763, 18.24207805, 16.1851528 , 14.57787808, 15.77917671,
       14.86551421, 16.81521596, 14.71617291, 18.54626475, 20.6745727 ,
       13.39158783, 18.14465195, 13.44652826, 15.1383587 , 14.41160215,
       21.05815993, 21.42388009, 22.12656447, 23.40601618, 29.19035841,
       34.2785307 , 28.48271366, 24.45460595, 24.43134015, 22.55323935,
       20.688097  , 20.8094598 , 18.07873592, 16.27250646, 18.89252819,
       20.53847579, 21.69114043, 25.03984638, 21.75334966, 17.27619882,
       34.80045814, 24.13821604, 31.24080817, 22.9320532 , 20.75925478,
       18.78591467, 17.35922753, 22.55596532, 24.23089689, 32.40378007,
       24.75067937, 19.7671051 , 20.89998955, 19.02139353, 21.23