In [None]:
import numpy as np
import scipy as sp
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils import check_random_state

# Load the digits dataset once; later cells use `digits.data` as the feature
# matrix and `digits.target` as the class labels.
digits = load_digits(return_X_y=False)


In [None]:
# Normal Decision Tree (single-model baseline).
# DecisionTreeClassifier randomly permutes the features at every split, so
# random_state is fixed to get the same cross-validated score on every run.
tree = DecisionTreeClassifier(random_state=0)
cross_val_score(tree, digits.data, digits.target, cv=10, scoring="accuracy").mean()


In [None]:
# Scikit-learn implementation of Random Forest
# random_state fixes the bootstrap samples and per-split feature draws, so the
# score is reproducible under Restart & Run All.
rf = RandomForestClassifier(n_estimators=100, random_state=0)
cross_val_score(rf, digits.data, digits.target, cv=10, scoring="accuracy").mean()


In [None]:
# Your implementation of Random Forest
# ClassifierMixin goes before BaseEstimator, as scikit-learn's estimator guide
# asks. It marks this estimator as a classifier, so cross_val_score(cv=10) uses
# StratifiedKFold exactly as it does for the scikit-learn models above.
# It also provides a .score() accuracy method.
class MyRandomForest(ClassifierMixin, BaseEstimator):
    """Random forest classifier: bagged decision trees combined by majority vote.

    Each of the ``n_estimators`` trees is a
    ``DecisionTreeClassifier(max_features="sqrt")`` fit on a bootstrap sample
    (``len(X)`` rows drawn with replacement) of the training data.

    Parameters
    ----------
    n_estimators : int, default=100
        Number of trees to fit; must be at least 1.
    random_state : int, numpy.random.RandomState or None, default=None
        Controls the bootstrap draws and the seed given to each tree.
        ``None`` uses numpy's global random state, so results differ between
        runs. Pass an int for a reproducible forest.

    Attributes
    ----------
    trees : list of DecisionTreeClassifier
        Trees from the most recent call to ``fit``.
    classes_ : ndarray
        Sorted unique training labels (set by ``fit``).
    """

    def __init__(self, n_estimators=100, random_state=None):
        # Constructor arguments are stored unchanged so get_params()/clone()
        # (used by cross_val_score) can rebuild an equivalent estimator.
        self.n_estimators = n_estimators  # number of trees to fit
        self.random_state = random_state  # seed / RandomState / None
        self.trees = []                   # fitted trees; replaced by every fit()

    def fit(self, X, y):
        """Fit ``n_estimators`` trees, each on a bootstrap sample of (X, y).

        Returns
        -------
        self : MyRandomForest
            The fitted estimator.

        Raises
        ------
        ValueError
            If ``n_estimators`` is less than 1.
        """
        if self.n_estimators < 1:
            raise ValueError(
                f"n_estimators must be at least 1, got {self.n_estimators!r}"
            )
        # Accept lists / DataFrames as well as ndarrays for row indexing below.
        X = np.asarray(X)
        y = np.asarray(y)
        rng = check_random_state(self.random_state)
        n_samples = len(X)

        self.classes_ = np.unique(y)
        # Build a fresh list. Appending to the old one would keep trees from a
        # previous fit, and predict() would keep voting with them.
        trees = []
        for _ in range(self.n_estimators):
            # Bootstrap sample: n_samples row indices drawn with replacement.
            idx = rng.randint(0, n_samples, size=n_samples)
            tree = DecisionTreeClassifier(
                max_features="sqrt",
                # Seed each tree from rng so a single random_state makes the
                # whole forest (bootstraps + feature sampling) reproducible.
                random_state=rng.randint(np.iinfo(np.int32).max),
            )
            tree.fit(X[idx], y[idx])
            trees.append(tree)
        self.trees = trees
        return self

    def predict(self, X):
        """Predict a label for each row of X by majority vote over all trees.

        Returns an array of labels taken from ``classes_``, with the same
        dtype as the training labels. Ties go to the smallest label.

        Raises
        ------
        ValueError
            If called before ``fit``.
        """
        if not self.trees:
            raise ValueError(
                "This MyRandomForest instance is not fitted yet; call fit() first."
            )
        # votes[i, t] is the label tree t predicts for row i. column_stack keeps
        # the labels' own dtype, so non-numeric labels work too.
        votes = np.column_stack([tree.predict(X) for tree in self.trees])
        # counts[i, k] = number of trees voting for classes_[k] on row i.
        counts = np.column_stack(
            [(votes == label).sum(axis=1) for label in self.classes_]
        )
        # classes_ is sorted (np.unique) and argmax returns the first maximum,
        # so ties are resolved in favor of the smallest label.
        return self.classes_[counts.argmax(axis=1)]


# Test MyRandomForest


In [None]:
# Fit on the full digits data. fit() returns the estimator itself, so the
# chained call binds the fitted forest and the last line displays it.
mrf = MyRandomForest().fit(digits.data, digits.target)
mrf


In [None]:
# Training-set accuracy: the forest is scored on the same rows it was fit on,
# so expect an optimistic number. The cross-validated score below is the fair one.
np.mean(mrf.predict(digits.data) == digits.target)


In [None]:
# 10-fold cross-validated accuracy. cross_val_score fits a fresh clone of mrf
# on each training fold, so the fit from the cell above is not reused.
fold_accuracies = cross_val_score(
    mrf, digits.data, digits.target, scoring="accuracy", cv=10
)
fold_accuracies.mean()


In [None]:
# Example of computing a majority vote with scipy.stats.mode:
# rows are 20 test data points, columns are 5 trees, each vote is class 0 or 1.
# A seeded generator keeps the printed example identical on every run.
vote_rng = np.random.default_rng(0)
example_votes = vote_rng.integers(2, size=(20, 5))
print(example_votes)
# mode along axis=1 gives the most common vote per row; ravel() flattens the
# result whether or not this SciPy version keeps the reduced axis.
majority = sp.stats.mode(example_votes, axis=1)[0].ravel()
majority
