In [1]:
import sys
sys.path.append('../')

In [2]:
import math
import random
import pandas as pd

In [3]:
from ml.classification.DecisionTreeClassifier import DecisionTree
from ml.validation.split_data import train_test_split
from ml.validation.accuracy import accuracy

In [4]:
class RandomForestClassifier(object):
    """Boolean classifier: an ensemble of ``DecisionTree`` models combined by majority vote.

    Each tree is trained on its own random subset of rows and of attributes
    (see ``train``); ``predict`` returns the label most trees vote for.

    Parameters
    ----------
    tree_number : int
        Number of trees in the forest (an odd number avoids voting ties).
    result_label :
        Forwarded to every ``DecisionTree``; presumably the name of the label
        column (``0`` in the hepatitis example) -- confirm against DecisionTree.
    max_depth : int
        Maximum tree depth, forwarded to every ``DecisionTree``.
    random_state : int or None
        Seed for row/attribute sampling and vote tie-breaking, making training
        and prediction reproducible. ``None`` (default) keeps the historical
        behaviour: attributes and ties are drawn from the global ``random``
        module and rows from pandas' default (global numpy) RNG.
    """

    def __init__(self, tree_number=15, result_label='label', max_depth=10, random_state=None):
        self.trees = [DecisionTree(max_depth=max_depth, result_label=result_label) for _ in range(tree_number)]
        self.max_depth = max_depth
        self.random_state = random_state
        # A private seeded generator when reproducibility is requested; otherwise the
        # `random` module itself, which exposes the same sample/choice API.
        self._rng = random.Random(random_state) if random_state is not None else random

    def train(self, data, attributes, alpha=0.2, beta=0.5):
        """Fit every tree on its own random sample of rows and attributes.

        Parameters
        ----------
        data : pandas.DataFrame
            Training rows, including the label column.
        attributes : iterable
            Candidate feature column names (list, tuple, numpy array or
            pandas Index).
        alpha : float
            Fraction of rows drawn, without replacement, for each tree. At
            least one row is always drawn. ``alpha=1`` hands every tree the
            full training set, so trees then differ only in their attributes.
        beta : float
            Fraction of attributes drawn, without replacement, for each tree.
            At least one attribute is always drawn.

        Raises
        ------
        ValueError
            If ``data`` has no rows or ``attributes`` is empty; also raised by
            the samplers when ``alpha`` or ``beta`` is greater than 1.
        """
        # Accept plain sequences as well as numpy arrays / pandas Index objects.
        attribute_list = attributes.tolist() if hasattr(attributes, 'tolist') else list(attributes)
        row_number = data.shape[0]
        if row_number == 0:
            raise ValueError('cannot train on an empty dataset')
        if not attribute_list:
            raise ValueError('at least one attribute is required')

        # int() truncates, so small fractions could otherwise select 0 rows or
        # 0 attributes (e.g. beta=0.5 with one attribute) and train a useless tree.
        rows_per_tree = max(1, int(alpha * row_number))
        attrs_per_tree = max(1, int(beta * len(attribute_list)))

        for tree in self.trees:
            # Derive a per-tree pandas seed only when seeded, so the unseeded
            # path still uses pandas' default RNG exactly as before.
            row_seed = self._rng.randrange(2 ** 32) if self.random_state is not None else None
            sampled_data = data.sample(n=rows_per_tree, random_state=row_seed)
            sampled_attr = self._rng.sample(attribute_list, attrs_per_tree)
            tree.train(data=sampled_data, attributes=sampled_attr)

    def predict(self, data):
        """Return the majority vote (``True``/``False``) of the trees for ``data``.

        ``data`` is passed unchanged to each tree's ``predict`` (presumably a
        single row -- confirm against DecisionTree). Tree outputs are counted
        by truthiness; a tie is broken uniformly at random.
        """
        vote_true = sum(1 for tree in self.trees if tree.predict(data))
        vote_false = len(self.trees) - vote_true
        if vote_true > vote_false:
            return True
        if vote_true < vote_false:
            return False
        # Tie (only possible with an even number of trees): pick at random.
        return self._rng.choice([True, False])

### Test: Hepatitis dataset

Column 0 holds the class label; every other column is used as a candidate feature.

In [10]:
# The file has no header row, so pandas labels the columns 0, 1, 2, ...
hepatitis_all_data = pd.read_csv(
    '../datasets/hepatitis.csv',
    header=None,
)

In [11]:
# Recode the class column (column 0) as booleans: 1 -> False, any other value -> True.
# NOTE(review): in the UCI hepatitis data the class is presumably 1 = DIE, 2 = LIVE,
# which would make True mean "survived" -- confirm against the dataset description.
# A single whole-column assignment replaces the row-by-row iterrows()/.loc loop, which was
# slow and wrote bools cell-by-cell into an integer column (upcasting it to object dtype).
# The dtype guard keeps the cell idempotent: re-running the old loop flipped every label.
if hepatitis_all_data[0].dtype != bool:
    hepatitis_all_data[0] = hepatitis_all_data[0] != 1

In [12]:
# Presumably holds out 20% of the rows as the test set.
# NOTE(review): unless train_test_split seeds itself, this split (and the accuracy
# reported below) changes on every run.
hepatitis_train_data, hepatitis_test_data = train_test_split(
    hepatitis_all_data,
    test_size=0.2,
)

In [13]:
# Features: every column except column 0, which holds the label.
hepatitis_attributes = hepatitis_train_data.columns[1:].values
RandomF = RandomForestClassifier(tree_number=15, result_label=0, max_depth=4)
# alpha=1: rows are drawn without replacement, so each tree sees every training row
# and the trees differ only in their random ~half of the attributes (beta=0.5).
RandomF.train(hepatitis_train_data, hepatitis_attributes, alpha=1, beta=0.5)

In [14]:
# Score the forest on the held-out rows (presumably the fraction predicted
# correctly -- see ml.validation.accuracy).
accuracy(
    RandomF,
    hepatitis_test_data,
    result_label=0,
)

0.8709677419354839