In [89]:
%matplotlib inline
from collections import Counter
from math import log2

import numpy as np
import pandas as pd
import torch
from d2l import torch as d2l

In [458]:
class Tree:
    """A node of the decision tree.

    Attributes:
        out: for a leaf, the predicted label; the sentinel value 2 marks an
            internal (splitting) node.
        att: name of the attribute tested at an internal node; "" for leaves.
        children: dict mapping each attribute value to the child ``Tree``.
    """

    def __init__(self, out, att):
        self.out, self.att = out, att
        self.children = dict()

In [459]:
from math import log
def entropy(labels):
    """Shannon entropy, in bits, of a sequence of class labels.

    Works for any number of distinct hashable labels. Counting with a
    ``Counter`` replaces a fixed two-slot list, which only supported labels
    0/1 (and silently miscounted label -1 through negative indexing).

    Args:
        labels: 1-D sequence of labels supporting ``len`` and iteration
            (list, numpy array, pandas Series).

    Returns:
        float: ``-sum(p * log2(p))`` over the empirical class frequencies;
        0.0 for an empty or pure sequence.
    """
    total = len(labels)
    ent = 0.0
    for count in Counter(labels).values():
        p = count / total
        ent -= p * log2(p)
    return ent

In [460]:
def divide_data(features, labels, att):
    """Split the dataset by the values of column ``att``.

    Yields one ``(feature_subset, label_subset)`` pair per distinct value of
    ``features[att]``, in order of first appearance. ``labels`` is assumed to
    share ``features``' index; the label subset is selected by index label.
    """
    column = features[att]
    for value in column.unique():
        rows = features[column == value]
        yield rows, labels.loc[rows.index]

In [461]:
def info_gain(features, labels, att):
    """Information gain from splitting (``features``, ``labels``) on column ``att``.

    gain = H(labels) - sum_v |D_v| / |D| * H(labels of D_v), where D_v are the
    row subsets produced by ``divide_data`` for each value v of ``att``.

    Subset weights use row counts (``len``), not ``DataFrame.size``.
    ``size`` counts cells (rows x columns) and only gives the right ratio
    because every subset happens to keep all columns.
    """
    n_rows = len(features)
    gain = entropy(labels)
    for subset, subset_labels in divide_data(features, labels, att):
        gain -= entropy(subset_labels) * len(subset) / n_rows
    return gain

In [462]:
def select_best_att(features, labels):
    """Return the column of ``features`` with the highest information gain.

    Ties go to the earliest column. The result is always a real column
    whenever ``features`` has at least one. Previously the running maximum
    started at 0, so a dataset where every gain was <= 0 (e.g. XOR-like
    labels, or float noise) returned "". That later crashed ``features[""]``.

    Returns:
        The chosen column name, or "" only if ``features`` has no columns.
    """
    candidates = list(features.columns)
    if not candidates:
        return ""
    # max() keeps the first maximal element, matching the original's
    # strict ">" tie-breaking whenever some gain is positive.
    return max(candidates, key=lambda col: info_gain(features, labels, col))

In [463]:
def check_eq(features):
    """Return True if every row of ``features`` has identical attribute values.

    Rows are compared with ``Series.equals`` against the first row. The loop
    runs over the number of ROWS (``len``). Using ``DataFrame.size``
    (rows x columns) overran the last row and raised ``IndexError``
    whenever all rows matched and there were two or more columns.

    An empty frame, or one with a single row, counts as uniform (True).
    """
    if len(features) <= 1:
        return True
    first = features.iloc[0]
    return all(features.iloc[i].equals(first) for i in range(1, len(features)))

In [464]:
def cnt_label(labels):
    """Majority vote over 0/1 labels.

    Returns 1 when at least half of ``labels`` equal 1 (so ties go to 1),
    otherwise 0. ``labels`` must expose ``.size`` (numpy array / pandas Series).
    """
    positives = sum(1 for label in labels if label == 1)
    return 1 if positives >= labels.size / 2 else 0

In [471]:
def build_dt(features, labels, out_label) -> Tree:
    """Recursively grow an ID3-style decision tree.

    Stopping rules:
      * all labels identical -> leaf with that label;
      * all feature rows identical (including no columns left) -> leaf with
        the majority label (``cnt_label``);
      * no attribute can be selected -> majority leaf.
    Otherwise the best attribute (``select_best_att``) becomes an internal
    node (``out == 2``). There is one child per value of that attribute,
    built ONLY from the rows having that value, with the attribute removed.

    Args:
        features: DataFrame of attribute columns.
        labels: Series of class labels sharing ``features``' index.
        out_label: passed through to recursive calls; not otherwise used.

    Returns:
        Root ``Tree`` node.
    """
    # Positional access: after subsetting, index label 0 may not exist.
    first = labels.iloc[0]
    if np.all(labels == first):
        return Tree(first, "")
    if check_eq(features):
        return Tree(cnt_label(labels), "")

    att = select_best_att(features, labels)
    if att == "":
        # Defensive: no attribute to split on -> majority leaf.
        return Tree(cnt_label(labels), "")

    node = Tree(2, att)
    column = features[att]
    for value in column.unique():
        mask = column == value
        if not mask.any():
            # e.g. NaN never compares equal to itself -> empty subset.
            continue
        # Partition: each child sees only the rows that take this branch.
        node.children[value] = build_dt(
            features.loc[mask].drop(columns=att),
            labels.loc[mask],
            out_label,
        )
    return node

In [504]:
def decide_label(feature, tree):
    """Walk ``tree`` with one sample and return the predicted label.

    Nodes with ``out == 2`` are internal; any other ``out`` is a leaf label.
    At an internal node, the sample's value for ``node.att`` selects the child.
    If that value has no child, the value never reached this node during
    training. Rather than raising ``KeyError``, the prediction is then a
    majority vote over the predictions of all children. Ties go to the label
    first produced in the children's insertion order.

    Args:
        feature: mapping from attribute name to value (e.g. a pandas Series row).
        tree: root node exposing ``out``, ``att`` and ``children``.

    Raises:
        KeyError: if ``feature`` lacks a tested attribute, or an internal node
            has no children to fall back on.
    """
    node = tree
    while node.out == 2:
        value = feature[node.att]
        if value in node.children:
            node = node.children[value]
            continue
        if not node.children:
            raise KeyError(value)
        votes = Counter(decide_label(feature, child) for child in node.children.values())
        return votes.most_common(1)[0][0]
    return node.out

In [505]:
def train(features, labels, out_label):
    """Fit a decision tree on ``features``/``labels``.

    Thin wrapper that delegates to ``build_dt`` with the same arguments and
    returns the resulting root node.
    """
    root = build_dt(features, labels, out_label)
    return root

In [506]:
def predict(features, tree):
    """Print the predicted label for every row of ``features``.

    The first column is skipped, mirroring the training slice
    ``data.iloc[:, 1:7]``. Presumably it holds a row ID; confirm against
    test.CSV. Nothing is returned; each row and its label are printed.
    """
    attributes = features.iloc[:, 1:]
    for row_pos in range(len(attributes)):
        row = attributes.iloc[row_pos]
        print("feature:", row, ", label:", decide_label(row, tree))

In [507]:
# Training data is GBK-encoded (its attribute headers are Chinese).
data = pd.read_csv("watermelon.CSV", encoding="GBK")
# Columns 1-6 are the attributes (column 0 presumably an ID - confirm);
# column 7 is the class label.
features = data.iloc[:, 1:7]
labels = data.iloc[:, 7]
# Forwarded to train/build_dt, which do not otherwise use it.
out_label = 0

In [508]:
DTree = train(features=features, labels=labels, out_label=out_label)

In [509]:
# Held-out samples use the same GBK encoding and column layout as training.
test_data = pd.read_csv("test.CSV", encoding="GBK")
predict(features=test_data, tree=DTree)

feature: 色泽    1
根蒂    0
敲声    0
纹理    0
脐部    1
触感    1
Name: 0, dtype: int64 , label: 0
feature: 色泽    2
根蒂    1
敲声    0
纹理    1
脐部    0
触感    0
Name: 1, dtype: int64 , label: 0
feature: 色泽    0
根蒂    0
敲声    1
纹理    1
脐部    1
触感    1
Name: 2, dtype: int64 , label: 0
