In [49]:
import numpy as np
import sklearn.datasets as datasets
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from collections import defaultdict

In [70]:
class MaxEntropy():

    def __init__(self, train_data, train_labels,test_data, test_labels):
        self.train_data = train_data
        self.train_labels = train_labels
        self.test_data = test_data
        self.test_labels = test_labels
        self.feature_nums = train_data.shape[1]#样本特征个数
        self.class_nums = len(np.unique(train_labels))#样本类的个数
        '''
        约束条件的个数，即训练样本所有出现(x,y)对的个数，注意x指单个特征
        比如样本有x1,x2,x3三个特征，(x1=1,y=1)为一个(x,y)对
        '''
        self.n = 0
        self.N = train_data.shape[0]#训练样本个数
        self.xy_num = self.calc_xy_num()#计算所有（x,y）对的个数
        self.M = self.n#类似于学习率
        self.w = np.zeros((self.n, 1))#拉格朗日乘子
        self.fxy2id, self.id2xy = self.create_search_dict()#构建(x,y)对到id及其相反的映射
        self.epf_tilde = self.calc_epf_tilde()#计算f(x,y)关于经验分布p_(x,y)的期望


    def calc_xy_num(self):
        #计算所有（x,y）对的个数
        xy_num = [defaultdict(int) for i in range(self.feature_nums)]
        for i in range(self.N):
            for f in range(self.feature_nums):
                xy_num[f][(self.train_data[i, f], self.train_labels[i])] += 1
        for d in xy_num:
            self.n += len(d)
        return xy_num

    def create_search_dict(self):
        #构建(x,y)对到id及其相反的映射
        fxy2id = {}
        id2xy = {}
        index = 0
        for f in range(self.feature_nums):
            for (x, y) in self.xy_num[f]:
                fxy2id[(f, x, y)] = index
                id2xy[index] = (f, x, y)
                index += 1
        return fxy2id, id2xy

    def calc_epf_tilde(self):
        #计算书中公式6.10右边的期望
        pxy_tilde = np.zeros((self.n, 1))
        for f in range(self.feature_nums):
            for (x, y), num in self.xy_num[f].items():
                index = self.fxy2id[(f, x, y)]
                pxy_tilde[index] = num
        pxy_tilde /= self.N
        return pxy_tilde

    def calc_pwy_x(self, x):
        #计算条件概率对应书中公式6.22
        element = np.zeros((self.class_nums, 1))
        for f in range(self.feature_nums):
            for y in range(self.class_nums):
                if (f, x[f], y) in self.fxy2id:
                    index = self.fxy2id[(f, x[f], y)]
                    element[y] += self.w[index]
        element = np.exp(element)
        z = np.sum(element)
        return element / z

    def calc_epf(self):
        #计算书中公式6.10左边期望
        epf = np.zeros((self.n, 1))
        for i in range(self.N):
            pwy_x = self.calc_pwy_x(self.train_data[i])
            for f in range(self.feature_nums):
                for y in range(self.class_nums):
                    if (f, self.train_data[i, f], y) in self.fxy2id:
                        index = self.fxy2id[(f, self.train_data[i, f], y)]
                        epf[index] += pwy_x[y] / self.N
        return epf

    def train(self, iterations=200):
        for i in tqdm(range(iterations)):
            epf = self.calc_epf()
            delta = (1 / self.M) * np.log(self.epf_tilde / epf)
            self.w += delta
            if i % 10 == 9:
                self.test(self.test_data, self.test_labels)

    def predict(self, x):
        pred = self.calc_pwy_x(x)
        return np.argmax(pred)

    def test(self, test_data, test_labels):
        correct = 0
        length = test_data.shape[0]
        for i in range(length):
            pred = self.predict(test_data[i])
            if pred == test_labels[i]:
                correct += 1
        print('测试集正确率:%.2f%%' %(correct*100/length))

In [69]:
digits_data = datasets.load_digits()
data, labels = digits_data.data, digits_data.target
labels = (labels > 0).astype(np.int32)
train_x, test_x, train_y, test_y = train_test_split(data, labels, test_size=0.3, random_state=0)
model = MaxEntropy(train_x, train_y, test_x, test_y)
model.train()

  4%|███▋                                                                              | 9/200 [00:19<07:08,  2.25s/it]

测试集正确率:91.67%


 10%|███████▋                                                                         | 19/200 [00:40<05:29,  1.82s/it]

测试集正确率:91.67%


 14%|███████████▋                                                                     | 29/200 [00:55<04:25,  1.55s/it]

测试集正确率:91.67%


 20%|███████████████▊                                                                 | 39/200 [01:11<04:01,  1.50s/it]

测试集正确率:91.67%


 24%|███████████████████▊                                                             | 49/200 [01:27<03:51,  1.53s/it]

测试集正确率:91.67%


 30%|███████████████████████▉                                                         | 59/200 [01:43<03:37,  1.54s/it]

测试集正确率:91.67%


 34%|███████████████████████████▉                                                     | 69/200 [01:59<03:21,  1.54s/it]

测试集正确率:91.67%


 40%|███████████████████████████████▉                                                 | 79/200 [02:15<03:16,  1.62s/it]

测试集正确率:91.67%


 44%|████████████████████████████████████                                             | 89/200 [02:33<03:21,  1.81s/it]

测试集正确率:91.85%


 50%|████████████████████████████████████████                                         | 99/200 [02:50<02:56,  1.74s/it]

测试集正确率:92.22%


 55%|███████████████████████████████████████████▌                                    | 109/200 [03:07<02:26,  1.61s/it]

测试集正确率:92.59%


 60%|███████████████████████████████████████████████▌                                | 119/200 [03:24<02:10,  1.61s/it]

测试集正确率:92.78%


 64%|███████████████████████████████████████████████████▌                            | 129/200 [03:43<01:53,  1.59s/it]

测试集正确率:93.52%


 70%|███████████████████████████████████████████████████████▌                        | 139/200 [04:01<02:04,  2.05s/it]

测试集正确率:94.07%


 74%|███████████████████████████████████████████████████████████▌                    | 149/200 [04:20<01:33,  1.84s/it]

测试集正确率:94.26%


 80%|███████████████████████████████████████████████████████████████▌                | 159/200 [04:40<01:15,  1.84s/it]

测试集正确率:95.19%


 84%|███████████████████████████████████████████████████████████████████▌            | 169/200 [04:57<00:50,  1.63s/it]

测试集正确率:96.11%


 90%|███████████████████████████████████████████████████████████████████████▌        | 179/200 [05:13<00:33,  1.62s/it]

测试集正确率:96.48%


 94%|███████████████████████████████████████████████████████████████████████████▌    | 189/200 [05:32<00:22,  2.02s/it]

测试集正确率:97.22%


100%|███████████████████████████████████████████████████████████████████████████████▌| 199/200 [05:51<00:01,  1.95s/it]

测试集正确率:97.22%


100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [05:52<00:00,  1.89s/it]
