# SLU13 - Tree-based models: Examples

In [1]:
import warnings
warnings.filterwarnings("ignore")

import pandas as pd

from sklearn.datasets import load_boston
from sklearn.ensemble import (
    RandomForestClassifier,
    RandomForestRegressor,
    GradientBoostingClassifier,
    GradientBoostingRegressor,
)
from sklearn.metrics import mean_squared_error
from sklearn.tree import (
    DecisionTreeClassifier,
    DecisionTreeRegressor,
)

from utils.utils import *

# Decision trees

## Classification

In [2]:
# Build the classification data with the helpers star-imported from
# `utils.utils`, then split it into a feature table and a target.
data = make_data()
X, y = separate_target_variable(data)

# Rebind X to the processed features. Presumably this encodes categorical
# columns numerically so scikit-learn can consume them -- TODO confirm in utils.
X = process_categorical_features(X)

# scikit-learn permutes the candidate features at every split, so ties can be
# broken differently between runs; pinning `random_state` keeps the fitted tree
# identical under "Restart & Run All".
dtc = DecisionTreeClassifier(random_state=42)
dtc.fit(X, y)
# In-sample predictions: the model is asked about the same rows it was fitted on.
dtc.predict(X)

array([0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0])

## Regression

In [3]:
def prepare_boston():
    """Load the Boston housing data as pandas objects.

    Returns
    -------
    tuple
        ``(X, y)`` where ``X`` is a ``pd.DataFrame`` built from ``boston.data``
        with ``boston.feature_names`` as column labels, and ``y`` is a
        ``pd.Series`` named ``'price'`` built from ``boston.target``.

    NOTE(review): ``load_boston`` was deprecated in scikit-learn 1.0 and
    removed in 1.2, so this helper only runs on older scikit-learn versions.
    """
    dataset = load_boston()
    features = pd.DataFrame(dataset.data, columns=dataset.feature_names)
    target = pd.Series(dataset.target, name='price')
    return features, target


# Regression features/target. The trailing underscores keep them apart from the
# classification `X`, `y` defined earlier, which later classifier cells reuse.
X_, y_ = prepare_boston()

In [4]:
# Fit a regression tree on the Boston features/target prepared above.
# `random_state` is pinned because scikit-learn permutes the candidate features
# at every split, so tied splits could otherwise resolve differently per run.
# NOTE: the name `dtc` is kept even though this is a regressor; the notebook
# rebinds the same name in every section.
dtc = DecisionTreeRegressor(random_state=42)
dtc.fit(X_, y_)
# In-sample predictions (same rows used for fitting): a sanity check only,
# not an estimate of generalisation error.
dtc.predict(X_)

array([24. , 21.6, 34.7, 33.4, 36.2, 28.7, 22.9, 27.1, 16.5, 18.9, 15. ,
       18.9, 21.7, 20.4, 18.2, 19.9, 23.1, 17.5, 20.2, 18.2, 13.6, 19.6,
       15.2, 14.5, 15.6, 13.9, 16.6, 14.8, 18.4, 21. , 12.7, 14.5, 13.2,
       13.1, 13.5, 18.9, 20. , 21. , 24.7, 30.8, 34.9, 26.6, 25.3, 24.7,
       21.2, 19.3, 20. , 16.6, 14.4, 19.4, 19.7, 20.5, 25. , 23.4, 18.9,
       35.4, 24.7, 31.6, 23.3, 19.6, 18.7, 16. , 22.2, 25. , 33. , 23.5,
       19.4, 22. , 17.4, 20.9, 24.2, 21.7, 22.8, 23.4, 24.1, 21.4, 20. ,
       20.8, 21.2, 20.3, 28. , 23.9, 24.8, 22.9, 23.9, 26.6, 22.5, 22.2,
       23.6, 28.7, 22.6, 22. , 22.9, 25. , 20.6, 28.4, 21.4, 38.7, 43.8,
       33.2, 27.5, 26.5, 18.6, 19.3, 20.1, 19.5, 19.5, 20.4, 19.8, 19.4,
       21.7, 22.8, 18.8, 18.7, 18.5, 18.3, 21.2, 19.2, 20.4, 19.3, 22. ,
       20.3, 20.5, 17.3, 18.8, 21.4, 15.7, 16.2, 18. , 14.3, 19.2, 19.6,
       23. , 18.4, 15.6, 18.1, 17.4, 17.1, 13.3, 17.8, 14. , 14.4, 13.4,
       15.6, 11.8, 13.8, 15.6, 14.6, 17.8, 15.4, 21

# Random Forests

## Classification

In [5]:
# A random forest grows each tree on a bootstrap sample of the rows, so an
# unseeded fit yields different predictions on every run. Pin `random_state`
# so the displayed output is reproducible.
# NOTE: `dtc` is the notebook's shared estimator name; here it holds a forest.
dtc = RandomForestClassifier(random_state=42)
dtc.fit(X, y)
# In-sample predictions on the classification features used for fitting.
dtc.predict(X)

array([0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0])

## Regression

In [6]:
# Same idea as the classification forest, on the Boston regression data.
# Bootstrap resampling makes the fit stochastic, so `random_state` is pinned
# to make the averaged predictions below reproducible between runs.
# NOTE: `dtc` is the notebook's shared estimator name; here it holds a forest.
dtc = RandomForestRegressor(random_state=42)
dtc.fit(X_, y_)
# In-sample predictions: each value is the forest's mean prediction for a
# row it was trained on.
dtc.predict(X_)

array([25.981, 22.002, 34.861, 34.066, 35.619, 26.871, 22.1  , 22.036,
       17.113, 19.183, 16.939, 19.821, 21.134, 20.253, 19.127, 19.928,
       22.623, 17.684, 19.594, 18.585, 13.844, 19.025, 15.587, 15.057,
       15.891, 14.631, 16.976, 14.949, 19.108, 21.903, 13.301, 16.293,
       14.316, 13.674, 13.627, 19.554, 20.305, 21.046, 23.246, 29.662,
       34.778, 28.192, 24.979, 24.601, 21.508, 19.462, 20.073, 17.745,
       15.709, 19.537, 20.202, 21.152, 25.329, 22.248, 19.096, 35.   ,
       24.046, 31.352, 23.051, 19.896, 18.705, 17.44 , 22.878, 25.304,
       32.906, 23.652, 19.839, 21.394, 18.22 , 20.825, 23.865, 21.605,
       22.727, 23.798, 24.381, 21.877, 20.54 , 21.009, 21.164, 20.557,
       27.43 , 24.577, 24.168, 23.173, 23.267, 26.788, 21.511, 22.16 ,
       26.537, 29.854, 22.639, 22.133, 22.945, 24.84 , 21.061, 27.933,
       21.812, 41.216, 43.38 , 32.994, 26.746, 26.195, 19.028, 19.669,
       20.127, 19.345, 18.869, 20.09 , 19.85 , 19.011, 21.277, 23.939,
      

# Gradient Boosting

## Classification

In [7]:
# Gradient boosting fits its trees sequentially; `random_state` fixes the seed
# handed to each stage's tree so repeated runs build the same ensemble.
# NOTE: `dtc` is the notebook's shared estimator name; here it holds a
# gradient-boosted classifier.
dtc = GradientBoostingClassifier(random_state=42)
dtc.fit(X, y)
# In-sample predictions on the classification features used for fitting.
dtc.predict(X)

array([0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0])

## Regression

In [8]:
# Gradient-boosted regression on the Boston data. `random_state` fixes the
# seed handed to each stage's tree so repeated runs build the same ensemble.
# NOTE: `dtc` is the notebook's shared estimator name; here it holds a
# gradient-boosted regressor.
dtc = GradientBoostingRegressor(random_state=42)
dtc.fit(X_, y_)
# In-sample predictions on the same rows the model was fitted on.
dtc.predict(X_)

array([25.90772604, 21.96320179, 33.92712155, 34.14528061, 35.41267912,
       26.7925396 , 21.48031031, 20.87839556, 16.95411564, 18.45898255,
       18.05928146, 20.04582877, 19.88575493, 20.39575276, 18.96852027,
       20.21179657, 21.76179638, 16.96912497, 19.2506871 , 19.01636451,
       14.00404763, 18.24207805, 16.1851528 , 14.57787808, 15.77917671,
       14.86551421, 16.81521596, 14.71617291, 18.54626475, 20.6745727 ,
       13.39158783, 18.14465195, 13.44652826, 15.1383587 , 14.41160215,
       21.05815993, 21.42388009, 22.12656447, 23.40601618, 29.19035841,
       34.2785307 , 28.48271366, 24.45460595, 24.43134015, 22.55323935,
       20.688097  , 20.8094598 , 18.07873592, 16.27250646, 18.89252819,
       20.53847579, 21.69114043, 25.03984638, 21.75334966, 17.27619882,
       34.80045814, 24.13821604, 31.24080817, 22.9320532 , 20.75925478,
       18.78591467, 17.35922753, 22.55596532, 24.23089689, 32.40378007,
       24.75067937, 19.7671051 , 20.89998955, 19.02139353, 21.23

---