In [1]:
import numpy as np
from loguru import logger

from CART import CART
from loadDataSet import loadDataSet

class AdaBoost:
    """AdaBoost ensemble built from CART weak classifiers.

    Each boosting round draws a weighted sample (with replacement) of the
    training set, fits a ``CART`` tree on it, and measures that tree's weighted
    error on the full training set. If the tree beats random guessing, it is
    added to the ensemble with coefficient ``0.5 * ln((1 - error) / error)``
    and the sample weights are updated.

    Rows of ``dataset`` are sequences whose last element is the class label.
    """

    # Clamp applied to the weighted error when computing a classifier coefficient,
    # so that log((1 - e) / e) stays finite for a perfect tree (e == 0).
    _EPS = 1e-10

    def __init__(self, dataset, feature, T, sample_ratio=0.5, rng=None):
        """
        :param dataset: training rows; the last element of each row is the label.
        :param feature: feature names, passed unchanged to ``CART``.
        :param T: number of boosting rounds. This is an upper bound on the number
            of weak classifiers, because rounds whose tree is no better than chance
            are discarded.
        :param sample_ratio: fraction of ``dataset`` drawn (with replacement) to
            train each weak classifier. Defaults to 0.5.
        :param rng: object providing numpy-style ``choice(a, size, p=...)``, such as
            ``np.random.default_rng(seed)`` for reproducible runs. Defaults to the
            global ``np.random`` state.
        """
        self._dataset = dataset  # training data
        self._feature = feature  # feature names
        self._classifier = []  # accepted weak classifiers
        self._Weight_classifier = []  # coefficient (alpha) of each accepted weak classifier
        self._error = []  # weighted training error of each accepted weak classifier
        self._T = T  # number of boosting rounds
        self._sample_ratio = sample_ratio
        self._rng = np.random if rng is None else rng
        n = len(dataset)
        # Sample weights start uniform and always sum to 1 (required by choice's ``p``).
        self._Weight_sample = np.full(n, 1.0 / n) if n else np.empty(0)

    def selectSample(self):
        """Draw a weighted sample, with replacement, of the training rows.

        :return: list of rows drawn with probabilities equal to the current
            sample weights. Its length is ``max(1, int(len(dataset) * sample_ratio))``.
        """
        n = len(self._dataset)
        # Draw at least one row so that tiny datasets still yield a trainable sample.
        size = max(1, int(n * self._sample_ratio))
        indices = self._rng.choice(n, size, p=self._Weight_sample)
        return [self._dataset[i] for i in indices]

    def fit(self):
        """Run ``T`` boosting rounds, populating the ensemble in place."""
        labels = [row[-1] for row in self._dataset]
        for i in range(self._T):
            cart = CART(self.selectSample(), self._feature)
            cart.fit()

            predict = cart.predict(self._dataset)
            # True where the tree mislabels the corresponding training row.
            missed = np.array([p != y for p, y in zip(predict, labels)], dtype=bool)
            weights = np.asarray(self._Weight_sample, dtype=float)
            error = float(np.sum(weights[missed]))

            # No better than random guessing: discard this tree (alpha would be <= 0).
            if error >= 0.5:
                logger.warning("第{}个弱分类器正确率{:.2f}不高于随机猜测，已丢弃".format(i + 1, 1 - error))
                continue

            self._classifier.append(cart)
            # Clamp so that a perfect tree (error == 0) gets a large but finite
            # coefficient instead of dividing by zero.
            clipped = min(max(error, self._EPS), 1 - self._EPS)
            weight_classifier = 0.5 * np.log((1 - clipped) / clipped)
            self._Weight_classifier.append(weight_classifier)

            # Up-weight misclassified rows, down-weight correct ones, then renormalize.
            weights = weights * np.exp(np.where(missed, weight_classifier, -weight_classifier))
            self._Weight_sample = weights / np.sum(weights)
            self._error.append(error)
            logger.info("第{}个弱分类器训练完成，正确率{:.2f}".format(i + 1, 1 - error))

    def classify(self, test_data):
        """Predict labels by a weighted majority vote of the weak classifiers.

        :param test_data: rows to classify; each row is passed to the trees' ``classify``.
        :return: list of predicted labels, one per row. On a tied vote, the label
            first predicted by the earliest classifier wins.
        :raises ValueError: if the ensemble is empty (not fitted, or every round
            was discarded).
        """
        if not self._classifier:
            raise ValueError("AdaBoost has no weak classifiers; call fit() first "
                             "(or every round was discarded as no better than chance)")
        predict_list = []
        for sample in test_data:
            predict = []  # distinct labels, in first-seen order
            weight = []  # summed coefficient for each label in ``predict``
            for cart, alpha in zip(self._classifier, self._Weight_classifier):
                pred = cart.classify(cart.tree, self._feature, sample)
                if pred in predict:
                    weight[predict.index(pred)] += alpha
                else:
                    predict.append(pred)
                    weight.append(alpha)
            predict_list.append(predict[weight.index(max(weight))])
        return predict_list

    def score(self, test_data):
        """Return the fraction of ``test_data`` rows whose last element equals the prediction.

        :raises ZeroDivisionError: if ``test_data`` is empty and the ensemble is fitted.
        """
        predict = self.classify(test_data)
        correct = sum(1 for pred, row in zip(predict, test_data) if pred == row[-1])
        return correct / len(test_data)

    def base_score(self):
        """Return the mean weighted training accuracy (1 - error) of the accepted weak classifiers.

        :return: a value above 0.5, or NaN if no weak classifier has been accepted.
        """
        if not self._error:
            return float("nan")
        return 1 - np.average(self._error)

    def __repr__(self):
        return "AdaBoost"

    def __method__(self):
        """Return the name of the weak-learner algorithm used by this ensemble."""
        return "CART"

In [3]:
def main(path='../dataset/archive/DATA.csv', n_classifiers=20):
    """Train AdaBoost on the dataset at ``path`` and log base and ensemble accuracy.

    :param path: CSV file passed to ``loadDataSet``, which is unpacked into
        (training rows, test rows, features).
    :param n_classifiers: number of boosting rounds ``T`` given to ``AdaBoost``.
    """
    train_data, test_data, Feature = loadDataSet(path)
    adaboost = AdaBoost(train_data, Feature, n_classifiers)
    adaboost.fit()
    logger.info("{} 基分类器平均正确率: {:.2f}%".format(adaboost.__method__(), adaboost.base_score() * 100))
    logger.info("{}_{} 分类器正确率: {:.2f}%".format(repr(adaboost), adaboost.__method__(),
                                               adaboost.score(test_data) * 100))

In [5]:
# Entry point. In a Jupyter kernel __name__ is also "__main__", so this runs on cell execution too.
if __name__ == "__main__":
    main()

[32m2023-12-30 21:40:39.339[0m | [1mINFO    [0m | [36m__main__[0m:[36mfit[0m:[36m52[0m - [1m第1个弱分类器训练完成，正确率0.66[0m
[32m2023-12-30 21:40:39.615[0m | [1mINFO    [0m | [36m__main__[0m:[36mfit[0m:[36m52[0m - [1m第2个弱分类器训练完成，正确率0.62[0m
[32m2023-12-30 21:40:39.891[0m | [1mINFO    [0m | [36m__main__[0m:[36mfit[0m:[36m52[0m - [1m第3个弱分类器训练完成，正确率0.59[0m
[32m2023-12-30 21:40:40.163[0m | [1mINFO    [0m | [36m__main__[0m:[36mfit[0m:[36m52[0m - [1m第4个弱分类器训练完成，正确率0.60[0m
[32m2023-12-30 21:40:40.459[0m | [1mINFO    [0m | [36m__main__[0m:[36mfit[0m:[36m52[0m - [1m第5个弱分类器训练完成，正确率0.60[0m
[32m2023-12-30 21:40:40.759[0m | [1mINFO    [0m | [36m__main__[0m:[36mfit[0m:[36m52[0m - [1m第6个弱分类器训练完成，正确率0.58[0m
[32m2023-12-30 21:40:41.032[0m | [1mINFO    [0m | [36m__main__[0m:[36mfit[0m:[36m52[0m - [1m第7个弱分类器训练完成，正确率0.58[0m
[32m2023-12-30 21:40:41.308[0m | [1mINFO    [0m | [36m__main__[0m:[36mfit[0m:[36m52[0m - [1m第8个弱分类器训