In [4]:
import tqdm
import numpy as np
import time

def load_data(filename, threshold=128):
    """Read a label-first CSV file and binarize its feature columns.

    Each non-blank line is expected to look like ``label,v1,v2,...`` where
    every field parses as an integer.  A feature becomes 1 when its value is
    greater than or equal to ``threshold`` and 0 otherwise.

    :param filename: path of the CSV file to read
    :param threshold: cut-off used for binarization (default 128, the value
        that was previously hard-coded)
    :return: ``(data, label)`` - a list of 0/1 feature lists and the list of
        integer labels, in file order
    :raises ValueError: if a field on a non-blank line is not an integer
    """
    data = []
    label = []
    with open(filename, 'r') as f:
        for line in f:
            stripped = line.strip()
            # Blank / whitespace-only lines (e.g. a trailing empty line) would
            # otherwise blow up on int(''), so they are skipped.
            if not stripped:
                continue
            fields = stripped.split(',')
            data.append([1 if int(v) >= threshold else 0 for v in fields[1:]])
            label.append(int(fields[0]))
    return data, label

def max_class(label):
    """Return the value that occurs most often in ``label``.

    When several values share the highest count, the one that was seen first
    wins, because the sort below is stable.

    :param label: sequence of hashable class labels
    :return: the most frequent label
    :raises IndexError: if ``label`` is empty
    """
    tally = {}
    for cls in label:
        tally[cls] = tally.get(cls, 0) + 1
    # Ascending sort on the negated count == descending sort on the count,
    # with insertion order preserved among equal counts.
    ranked = sorted(tally.items(), key=lambda pair: -pair[1])
    return ranked[0][0]

def calculate_H_D(label):
    """Compute the empirical entropy H(D), in bits, of a collection of labels.

    H(D) = -sum_k p_k * log2(p_k), where p_k is the fraction of entries equal
    to class k.

    :param label: array-like of class labels (a numpy array or a plain list)
    :return: the entropy as a float; 0.0 for empty input or a single class
    """
    label = np.asarray(label)
    if label.size == 0:
        # No samples -> no uncertainty; also avoids dividing by zero.
        return 0.0
    # One pass to count every class instead of one boolean mask per class.
    _, counts = np.unique(label, return_counts=True)
    p = counts / label.size
    # Every count is >= 1, so p > 0 and log2 is always finite.
    # "0.0 - x" rather than "-x" so a pure set yields +0.0, not -0.0.
    return 0.0 - np.sum(p * np.log2(p))

def calculate_H_D_A(data, label):
    """Compute the conditional entropy H(D|A) of ``label`` given one feature.

    For every distinct value of the feature, the entropy of the labels whose
    feature equals that value is weighted by the share of samples having it.

    :param data: numpy array with one feature value per sample
    :param label: numpy array of labels, indexed by the same boolean masks as
        ``data``
    :return: the weighted sum of the per-value label entropies
    """
    cond_entropy = 0
    for value in set(data):
        # Boolean selector for the samples whose feature equals this value.
        mask = data == value
        weight = data[mask].size / data.size
        cond_entropy += weight * calculate_H_D(label[mask])
    return cond_entropy

def choose_best_feature(data, label):
    """Pick the feature column with the largest information gain.

    The gain of a feature A is g(D, A) = H(D) - H(D|A).

    :param data: 2-D array-like, one row per sample and one column per feature
    :param label: array-like of class labels, one per row of ``data``
    :return: ``(best_feature, max_gda)`` - the column index with the highest
        gain and that gain.  When no column has a strictly positive gain the
        result is ``(0, 0)``, so the caller can compare the gain against its
        threshold instead of hitting an unbound ``best_feature``.
    :raises ValueError: if ``data`` has no feature columns
    """
    data = np.array(data)
    label = np.array(label)
    feature_num = data.shape[1]
    if feature_num == 0:
        raise ValueError('data has no feature columns to choose from')

    H_D = calculate_H_D(label)
    # Defaults cover the "nothing is informative" case; the first column is
    # kept on ties because the comparison below is strict.
    best_feature = 0
    max_gda = 0
    for feature in range(feature_num):
        # data[:, feature] is already 1-D; no extra copy is needed.
        gda = H_D - calculate_H_D_A(data[:, feature], label)
        if gda > max_gda:
            max_gda = gda
            best_feature = feature
    return best_feature, max_gda

def get_sub_data(data, label, feature, value):
    """Select the samples whose ``feature`` equals ``value`` and drop that column.

    :param data: list of feature lists, one per sample
    :param label: labels, one per sample
    :param feature: index of the column to filter on and then remove
    :param value: value the column must have for a sample to be kept
    :return: ``(ret_data, ret_label)`` - the kept rows with column ``feature``
        removed, and their labels, in the original order
    """
    # Positions of the rows that match the requested value.
    keep = [i for i in range(len(label)) if data[i][feature] == value]
    ret_data = [data[i][:feature] + data[i][feature + 1:] for i in keep]
    ret_label = [label[i] for i in keep]
    return ret_data, ret_label


def create_tree(*dataset, epsilon=0.1):
    """Recursively build a decision tree over binary (0/1) features.

    :param dataset: a single positional argument, the pair ``(data, label)``
        where ``data`` is a list of feature lists and ``label`` the matching
        list of class labels
    :param epsilon: minimum information gain required to split; below it the
        node becomes a leaf
    :return: a class label for a leaf, otherwise a nested dict of the form
        ``{feature: {0: subtree, 1: subtree}}``.  Because the chosen column is
        removed from the data handed to the children, each ``feature`` index
        is relative to the columns remaining at that depth.
    """
    data = dataset[0][0]
    label = dataset[0][1]

    # Pure node: every sample has the same class.
    if len(set(label)) == 1:
        return label[0]

    # No features left to split on: fall back to the majority class.
    if len(data[0]) == 0:
        return max_class(label)

    best_feature, max_gda = choose_best_feature(data, label)
    if max_gda < epsilon:
        return max_class(label)

    branches = {}
    for value in (0, 1):
        sub_data, sub_label = get_sub_data(data, label, best_feature, value)
        if len(sub_label) == 0:
            # No sample takes this value: use the parent's majority class
            # rather than recursing into an empty data set.
            branches[value] = max_class(label)
        else:
            # Pass epsilon down so a caller-supplied threshold applies to
            # the whole tree, not just the root.
            branches[value] = create_tree((sub_data, sub_label), epsilon=epsilon)

    return {best_feature: branches}



def test(data, label, tree):
    '''
    Measure the accuracy of a tree on a labelled data set.
    :param data: samples to classify
    :param label: expected label of each sample
    :param tree: decision tree built from the training set
    :return: accuracy, i.e. the fraction of samples predicted correctly
    '''
    # Number of misclassified samples.
    errorCnt = 0
    # Walk over every sample of the data set.
    for i in range(len(data)):
        # The prediction must be made from the sample itself; the previous
        # code passed label[i] to predict() instead of data[i].
        # NOTE(review): predict() is defined outside this block - presumably
        # predict(sample, tree) returns the predicted label; confirm.
        if label[i] != predict(data[i], tree):
            errorCnt += 1
    # Accuracy = 1 - error rate.
    return 1 - errorCnt / len(data)

def main():
    """Train a tree on the training CSV, score it on the test CSV, and print
    the tree, the accuracy and the elapsed wall-clock time."""
    # Start of the timed section.
    t0 = time.time()

    # Load the training and test sets.
    train_data, train_label = load_data('../mnist_train.csv')
    test_data, test_label = load_data('../mnist_test.csv')

    # Build the decision tree.
    print('start create tree')
    tree = create_tree((train_data, train_label))
    print('tree is:', tree)

    # Evaluate the accuracy.
    print('start test')
    accur = test(test_data, test_label, tree)
    print('the accur is:', accur)

    # Report how long the whole run took.
    t1 = time.time()
    print('time span:', t1 - t0)

# Entry point: run the whole load / train / evaluate pipeline when this code
# is executed directly rather than imported.
if __name__ == '__main__':
    main()
        





    
    
    
    
    

start create tree
tree is: {378: {0: {567: {0: {541: {0: {458: {0: {403: {0: {590: {0: {483: {0: {154: {0: {429: {0: 7, 1: {209: {0: {265: {0: {237: {0: {181: {0: {317: {0: {542: {0: {607: {0: 4, 1: {99: {0: 5, 1: 2}}}}, 1: {202: {0: 9, 1: 3}}}}, 1: {261: {0: 5, 1: {158: {0: 9, 1: 8}}}}}}, 1: {480: {0: {155: {0: 9, 1: 3}}, 1: {121: {0: {153: {0: 4, 1: 8}}, 1: 2}}}}}}, 1: {403: {0: {286: {0: {209: {0: 2, 1: 4}}, 1: 5}}, 1: {401: {0: {185: {0: 4, 1: 9}}, 1: {231: {0: 9, 1: 3}}}}}}}}, 1: {403: {0: {98: {0: {149: {0: 5, 1: 7}}, 1: 2}}, 1: {398: {0: {644: {0: {178: {0: 7, 1: 3}}, 1: {204: {0: 9, 1: {174: {0: 7, 1: 2}}}}}}, 1: {263: {0: 4, 1: {405: {0: 9, 1: 5}}}}}}}}}}, 1: {404: {0: {431: {0: {323: {0: {240: {0: 5, 1: 2}}, 1: {185: {0: {177: {0: 9, 1: 3}}, 1: 8}}}}, 1: {481: {0: {272: {0: {316: {0: 9, 1: 5}}, 1: {179: {0: 4, 1: 3}}}}, 1: 8}}}}, 1: {426: {0: {643: {0: {372: {0: 7, 1: {103: {0: {175: {0: 9, 1: 8}}, 1: 6}}}}, 1: 3}}, 1: {177: {0: {289: {0: 9, 1: {400: {0: {210: {0: {205: {0: 4

AttributeError: 'int' object has no attribute 'items'

In [8]:



# NOTE(review): testDataList, testLabelList and tree are only assigned as
# local variables inside main() in the visible code, so they are not
# available at cell level - which matches the NameError recorded below.
# model_test is not defined in the visible cells either - TODO confirm where
# it comes from (possibly a renamed test()) and load the data / build the
# tree at cell level before running this.
accur = model_test(testDataList, testLabelList, tree)
print('the accur is:', accur)

NameError: name 'testDataList' is not defined