In [9]:
import numpy as np
from collections import Counter
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from DecisionTree import DecisionTree

In [10]:
# Load scikit-learn's breast-cancer dataset and hold out a stratified 25% test
# set; stratify=y keeps the class proportions equal across both splits and the
# fixed random_state makes the split reproducible.
dataset = datasets.load_breast_cancer()
X = dataset.data
y = dataset.target
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.25,
    stratify=y,
    random_state=42,
)

In [46]:
class RandomForest:
    """Bagged ensemble of ``DecisionTree`` classifiers combined by majority vote.

    Each tree is fit on its own random sample of the training rows, drawn with
    replacement. ``predict`` returns, for every input row, the label chosen by
    the most trees.

    Parameters
    ----------
    n_estimators : int
        Number of trees in the ensemble; must be >= 1. An odd value avoids
        voting ties for binary targets.
    min_sample_split : int
        Forwarded to each ``DecisionTree`` as ``min_sample_split``.
    max_depth : int
        Forwarded to each ``DecisionTree`` as ``depth``.
    max_samples : float
        Size of each tree's sample as a fraction of the training rows (at
        least one row is always drawn). The default 0.5 matches earlier
        versions of this class; use 1.0 for a classic bootstrap that draws as
        many rows as the training set has.
    random_state : int or None
        ``None`` samples from NumPy's global RNG, so ``np.random.seed`` still
        controls it. Any other value is passed to ``np.random.default_rng`` at
        the start of every ``fit``, making repeated fits reproducible.

    Raises
    ------
    ValueError
        If ``n_estimators < 1`` or ``max_samples <= 0``.
    """

    def __init__(self, n_estimators=2, min_sample_split=5, max_depth=4,
                 max_samples=0.5, random_state=None):
        # Validate before building trees: an empty ensemble would only fail
        # much later, inside predict().
        if n_estimators < 1:
            raise ValueError(f"n_estimators must be >= 1, got {n_estimators}")
        if max_samples <= 0:
            raise ValueError(f"max_samples must be > 0, got {max_samples}")
        self.n_estimators = n_estimators
        self.max_samples = max_samples
        self.random_state = random_state
        self.estimators = [DecisionTree(min_sample_split=min_sample_split,
                                        depth=max_depth)
                           for _ in range(self.n_estimators)]

    def fit(self, X: np.ndarray, y: np.ndarray):
        """Fit every tree on its own random resample of ``(X, y)``.

        Raises
        ------
        ValueError
            If ``X`` has no rows or ``X`` and ``y`` differ in length.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if len(X) == 0:
            raise ValueError("cannot fit RandomForest on an empty dataset")
        if len(X) != len(y):
            raise ValueError(
                f"X and y have different lengths: {len(X)} != {len(y)}")
        # Create the generator once per fit so an int seed reproduces the same
        # samples on refit; None keeps drawing from the global np.random state.
        if self.random_state is None:
            rng = np.random
        else:
            rng = np.random.default_rng(self.random_state)
        for estimator in self.estimators:
            X_sampled, y_sampled = self._bootstrap(X, y, rng)
            estimator.fit(X_sampled, y_sampled)

    def _bootstrap(self, X: np.ndarray, y: np.ndarray, rng=None):
        """Return aligned rows of ``X`` and ``y`` drawn uniformly with replacement.

        ``max(1, int(len(X) * self.max_samples))`` rows are drawn. The lower
        bound of 1 keeps a tiny training set from producing an empty sample.
        ``rng`` is any object with a NumPy-style ``choice`` method; it
        defaults to the global ``np.random`` module.
        """
        if rng is None:
            rng = np.random
        n_samples = X.shape[0]
        n_draw = max(1, int(n_samples * self.max_samples))
        sample_idxs = rng.choice(n_samples, n_draw, replace=True)
        return X[sample_idxs], y[sample_idxs]

    # Backward-compatible alias for the original, misspelled method name.
    _boostrap = _bootstrap

    def predict(self, X: np.ndarray):
        """Return the majority-vote label for each row of ``X`` as a list.

        Ties are broken in favour of the tied label that appears first in
        estimator order (``Counter.most_common`` keeps first-seen order among
        equal counts).
        """
        # Presumably each tree returns one label per row, giving an array of
        # shape (n_estimators, n_rows). Swap the first two axes so each row
        # holds one sample's votes from every tree.
        votes = np.array([estimator.predict(X)
                          for estimator in self.estimators]).swapaxes(0, 1)
        return [Counter(row).most_common(1)[0][0] for row in votes]

In [51]:
# RandomForest.fit draws its per-tree row samples from NumPy's global RNG, so
# seed it right before fitting to make the forest (and the classification
# report below) reproducible under Restart & Run All. 42 matches the seed used
# for the train/test split.
np.random.seed(42)

rf = RandomForest(n_estimators=5,
                  min_sample_split=4,
                  max_depth=5)

rf.fit(X_train, y_train)

In [52]:
# Score the forest on the held-out split. classification_report returns a
# preformatted string, so print() it to render the table rather than a repr.
y_preds = rf.predict(X_test)
print(classification_report(y_true=y_test, y_pred=y_preds))

              precision    recall  f1-score   support

           0       0.89      0.94      0.92        53
           1       0.97      0.93      0.95        90

    accuracy                           0.94       143
   macro avg       0.93      0.94      0.93       143
weighted avg       0.94      0.94      0.94       143

