In [104]:
import numpy as np
import math
import gc
from collections import defaultdict, Counter
from sklearn.datasets import load_iris

class Node():
    def __init__(self, feature=None, predict=None, leaf=False):
        self.feature = feature
        self.leaf = leaf
        self.predict = predict
        self.child = {}
        
class ID3():
    def __init__(self, min_gain):
        self.min_gain = min_gain
        self.root = Node()
        self.feature = {}
        
    def build_tree(self, data, label, root, feature):
        best_feature = self.find_best_feature(data, label, feature)
        if best_feature == -1:
            root.leaf = True
            dic = Counter(label)
            root.predict = dic.most_common(1)[0][0]
            return 
        
        root.feature = best_feature
        feature.remove(best_feature)
        
        data_dic  = defaultdict(list)
        label_dic = defaultdict(list)
        
        for d, l in zip(data, label):
            if d[best_feature] not in root.child:
                new_node = Node()
                root.child[d[best_feature]] = new_node
            data_dic[d[best_feature]].append(d)
            label_dic[d[best_feature]].append(l)
        for k in data_dic.keys():
            self.build_tree(data_dic[k], label_dic[k], root.child[k], feature)
            # gc.collect()
        
    def fit(self, data, label):
        feature = set(i for i in range(len(data[0])))
        self.build_tree(data, label, self.root, feature)
        
    def predict(self, data):
        node = self.root
        print(node.feature, node.child, node.predict)
        while not node.leaf:
            try:
                node = node.child[data[node.feature]]
                print(node.feature, node.child, node.predict)
            except:
                raise ValueError("undefined value")
        return node.predict
        
    def find_best_feature(self, data, label, feature):
        if not feature:
            return -1
        best_feature = None
        best_entropy = float('inf')
        origin_entropy = self.entropy(label)
        for i in feature:
            dic = defaultdict(list)
            entropy = 0
            for d, l in zip(data, label):
                dic[d[i]].append(l)
            for k,v in dic.items():  
                entropy += self.entropy(v) * len(v) / len(data)
            if entropy < best_entropy:
                best_entropy = entropy
                best_feature = i
        if best_entropy >= origin_entropy or not best_feature:
            return -1
        return best_feature
    
    def entropy(self, label):
        dic = Counter(label)
        prob = list(dic.values())
        prob /= np.sum(prob)
        return np.sum(- prob * np.log(prob))
    
class KNN():
    def __init__(self, k):
        self.k = k
        self.data = []
        self.label = []
        
    def fit(self, data, label):
        size = data.shape[0]
        self.data = data.reshape(size, -1)
        self.label = label
        
    def predict(self, data):
        dist = []
        data = data.reshape(-1)
        for d,l in zip(self.data, self.label):
            dist.append((self.l2norm(d, data), l))
        dist.sort(key=lambda x:x[0])
        dic = defaultdict(int)
        for i in range(self.k):
            dic[dist[i][1]] += 1
        return max(dic, key=lambda x: x.__getitem__)
    
    def l2norm(self, d1, d2):
        return np.sum(np.square(d1 - d2))

In [105]:
if __name__ == '__main__':
    model = ID3(0)
    data = [['高','否','一般'],
            ['高','否','好'],
            ['高','否','一般'],
            ['中','否','一般'],
            ['低','是','一般'],
            ['低','是','好'],
            ['低','是','好'],
            ['中','否','一般'],
            ['低','是','一般'],
            ['中','是','一般'],
           ]
    label = ['不买','不买','买','买','买','不买','买','不买','买','买']
    model.fit(data, label)
    test = ['中','是','好']
    print('Predict : {} , True : {}'.format(model.predict(test), '买'))

1 {'否': <__main__.Node object at 0x11ed6f358>, '是': <__main__.Node object at 0x11ed6f208>} None
None {} 买
Predict : 买 , True : 买
