In [1]:
import pandas as pd
from math import log2
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split

def accuracy(pred, labels) -> float:
    """Return the fraction of predictions that equal their true label.

    ``pred`` is read with plain ``[]`` indexing and ``labels`` with ``.iloc``,
    so the two are compared position by position regardless of the index
    that ``labels`` carries.
    """
    n = len(pred)
    hits = sum(1 for pos in range(n) if labels.iloc[pos] == pred[pos])
    return hits / n

# Read the raw table from disk.
appleData = pd.read_csv('apple_quality.csv')

# Features: every column except the first and the last one.
data = appleData.iloc[:, 1:-1]

# Target: the 'Quality' column with 'good' recoded as 1 and 'bad' as 0.
labels = appleData['Quality'].replace({'good': 1, 'bad': 0})

# Hold out 20% of the rows for testing; the seed keeps the split repeatable.
xtrain, xtest, ytrain, ytest = train_test_split(
    data,
    labels,
    test_size=0.2,
    random_state=42,
)

In [13]:
def entropy(labels: pd.DataFrame) -> float:
    """Return the Shannon entropy, in bits, of a collection of binary labels.

    Entries equal to 0 form one class; every other entry is counted as the
    second class.

    Args:
        labels: pandas object holding one label per row. The notebook passes
            a Series of 0/1 values, despite the DataFrame annotation.

    Returns:
        ``-p1*log2(p1) - p0*log2(p0)``, or 0.0 when ``labels`` is empty or
        all labels fall in a single class.
    """
    total = labels.shape[0]
    # An empty partition carries no uncertainty; returning early also avoids
    # dividing by zero below.
    if total == 0:
        return 0.0

    # Count the zeros in one vectorised pass instead of a Python-level loop.
    zeros = int((labels.values == 0).sum())
    p0 = zeros / total
    p1 = (total - zeros) / total

    # A pure node would otherwise evaluate log2(0).
    if p0 == 0 or p1 == 0:
        return 0.0
    return -p1 * log2(p1) - p0 * log2(p0)


def infoGain(data: pd.DataFrame, labels: pd.DataFrame, col: str, split: float) -> float:
    """Return the entropy reduction from thresholding ``data[col]`` at ``split``.

    Rows with ``data[col] <= split`` form one partition and rows with
    ``data[col] > split`` the other. The result is the entropy of ``labels``
    minus the size-weighted entropies of the two partitions.
    """
    feature = data[col]
    n_total = len(labels)

    below = labels[feature <= split]
    above = labels[feature > split]

    weight_below = len(below) / n_total
    weight_above = len(above) / n_total

    return entropy(labels) - weight_below * entropy(below) - weight_above * entropy(above)
    
def bestSplit(data: pd.DataFrame, labels: pd.DataFrame) -> (str, float):
    """Search every column for the threshold with the largest information gain.

    Candidate thresholds are the midpoints between consecutive sorted distinct
    values of each column. Returns ``(column, threshold)`` for the first
    candidate that reaches the maximum gain, or ``('', 0)`` when no candidate
    beats the initial gain of -1.
    """
    winner = ('', 0)
    topGain = -1

    for feature in data.columns:
        distinct = data[feature].unique()
        distinct.sort()

        # Pair each distinct value with its successor in sorted order.
        for lower, upper in zip(distinct[:-1], distinct[1:]):
            threshold = (lower + upper) / 2
            score = infoGain(data, labels, feature, threshold)

            # Strict comparison: on ties the earlier candidate is kept.
            if score > topGain:
                topGain = score
                winner = (feature, threshold)

    return winner

def buildTree(data: pd.DataFrame, labels: pd.DataFrame, criterion: str, min_split: int):
    """Placeholder for the hand-written decision-tree builder.

    The body has not been written yet. Without any body the ``def`` is a
    SyntaxError that prevents the other functions in this cell from being
    defined, so the stub fails loudly only when it is actually called.

    Args:
        data: feature table.
        labels: target values for the rows of ``data``.
        criterion: presumably the name of the impurity measure, mirroring
            ``criterion="entropy"`` passed to the library tree - TODO confirm.
        min_split: presumably the minimum number of samples a node needs
            before it may be split, mirroring ``min_samples_split`` - TODO
            confirm.

    Raises:
        NotImplementedError: always, until the builder is implemented.
    """
    raise NotImplementedError("buildTree is not implemented yet")

In [11]:
# Reference model: scikit-learn's decision tree with the entropy criterion.
# random_state is fixed (same seed as the train/test split) so that repeated
# runs fit the same tree and print the same accuracy.
libraryTree = DecisionTreeClassifier(
    criterion="entropy",
    min_samples_split=2,
    random_state=42,
)
libraryTree.fit(xtrain, ytrain)

# Score the held-out rows with the notebook's own accuracy helper.
pred = libraryTree.predict(xtest)
acc = accuracy(pred, ytest)
print(acc)

0.8225
