In [1]:
import math
from copy import deepcopy

In [2]:
dataset = [['no', 'sunny', 'hot', 'high', 'FALSE'],
           ['no', 'sunny', 'hot', 'high', 'TRUE'],
           ['yes', 'overcast', 'hot', 'high', 'FALSE'],
           ['yes', 'rainy', 'mild', 'high', 'FALSE'],
           ['yes', 'rainy', 'cool', 'normal', 'FALSE'],
           ['no', 'rainy', 'cool', 'normal', 'TRUE'],
           ['yes', 'overcast', 'cool', 'normal', 'TRUE'],
           ['no', 'sunny', 'mild', 'high', 'FALSE'],
           ['yes', 'sunny', 'cool', 'normal', 'FALSE'],
           ['yes', 'rainy', 'mild', 'normal', 'FALSE'],
           ['yes', 'sunny', 'mild', 'normal', 'TRUE'],
           ['yes', 'overcast', 'mild', 'high', 'TRUE'],
           ['yes', 'overcast', 'hot', 'normal', 'FALSE'],
           ['no', 'rainy', 'mild', 'high', 'TRUE']]

In [3]:
class MaxEntropy:
    def __init__(self, EPS=0.005):
        self._samples = []
        self._Y = set()  # 标签集合，相当去去重后的y
        self._numXY = {}  # key为(x,y)，value为出现次数
        self._N = 0  # 样本数
        self._Ep_ = []  # 样本分布的特征期望值
        self._xyID = {}  # key记录(x,y),value记录id号
        self._n = 0  # 特征键值(x,y)的个数
        self._C = 0  # 最大特征数
        self._IDxy = {}  # key为(x,y)，value为对应的id号
        self._w = []
        self._EPS = EPS  # 收敛条件
        self._lastw = []  # 上一次w参数值

    def loadData(self, dataset):
        self._samples = deepcopy(dataset)
        for items in self._samples:
            y = items[0]
            X = items[1:]
            self._Y.add(y)  # 集合中y若已存在则会自动忽略
            for x in X:
                if (x, y) in self._numXY:
                    self._numXY[(x, y)] += 1
                else:
                    self._numXY[(x, y)] = 1

        self._N = len(self._samples)
        self._n = len(self._numXY)
        self._C = max([len(sample) - 1 for sample in self._samples])
        self._w = [0] * self._n
        self._lastw = self._w[:]

        self._Ep_ = [0] * self._n
        for i, xy in enumerate(self._numXY):  # 计算特征函数fi关于经验分布的期望
            self._Ep_[i] = self._numXY[xy] / self._N
            self._xyID[xy] = i
            self._IDxy[i] = xy

    def _Zx(self, X):  # 计算每个Z(x)值
        zx = 0
        for y in self._Y:
            ss = 0
            for x in X:
                if (x, y) in self._numXY:
                    ss += self._w[self._xyID[(x, y)]]
            zx += math.exp(ss)
        return zx

    def _model_pyx(self, y, X):  # 计算每个P(y|x)
        zx = self._Zx(X)
        ss = 0
        for x in X:
            if (x, y) in self._numXY:
                ss += self._w[self._xyID[(x, y)]]
        pyx = math.exp(ss) / zx
        return pyx

    def _model_ep(self, index):  # 计算特征函数fi关于模型的期望
        x, y = self._IDxy[index]
        ep = 0
        for sample in self._samples:
            if x not in sample:
                continue
            pyx = self._model_pyx(y, sample)
            ep += pyx / self._N
        return ep

    def _convergence(self):  # 判断是否全部收敛
        for last, now in zip(self._lastw, self._w):
            if abs(last - now) >= self._EPS:
                return False
        return True

    def predict(self, X):  # 计算预测概率
        Z = self._Zx(X)
        result = {}
        for y in self._Y:
            ss = 0
            for x in X:
                if (x, y) in self._numXY:
                    ss += self._w[self._xyID[(x, y)]]
            pyx = math.exp(ss) / Z
            result[y] = pyx
        return result

    def train(self, maxiter=1000):  # 训练数据
        for loop in range(maxiter):  # 最大训练次数
            print(" * iter:%d" % loop)
            self._lastw = self._w[:]
            for i in range(self._n):
                ep = self._model_ep(i)  # 计算第i个特征的模型期望
                self._w[i] += math.log(self._Ep_[i] / ep) / self._C  # 更新参数
            print("w:", self._w)
            if self._convergence():  # 判断是否收敛
                break

In [4]:
maxent = MaxEntropy()
maxent.loadData(dataset)
maxent.train()

 * iter:0
w: [0.0455803891984887, -0.002832177999673058, 0.031103560672370825, -0.1772024616282862, -0.0037548445453157455, 0.16394435955437575, -0.02051493923938058, -0.049675901430111545, 0.08288783767234777, 0.030474400362443962, 0.05913652210443954, 0.08028783103573349, 0.1047516055195683, -0.017733409097415182, -0.12279936099838235, -0.2525211841208849, -0.033080678592754015, -0.06511302013721994, -0.08720030253991244]
 * iter:1
w: [0.11525071899801315, 0.019484939219927316, 0.07502777039579785, -0.29094979172869884, 0.023544184009850026, 0.2833018051925922, -0.04928887087664562, -0.101950931659509, 0.12655289130431963, 0.016078718904129236, 0.09710585487843026, 0.10327329399123442, 0.16183727320804359, 0.013224083490515591, -0.17018583153306513, -0.44038644519804815, -0.07026660158873668, -0.11606564516054546, -0.1711390483931799]
 * iter:2
w: [0.18178907332733973, 0.04233703122822168, 0.11301330241050131, -0.37456674484068975, 0.05599764270990431, 0.38356978711239126, -0.0748854

w: [2.501222567189321, 0.07660294489544052, 1.0698479288060314, -2.9388812234297106, 1.1222929054819264, 3.617315289711025, -0.17436469624865278, -1.477232815260255, 0.9849043527057243, -1.2755231410610537, 1.345606576827002, -0.7916103080415674, 1.1708068359801682, 2.0455876865092146, 2.500103832194786, -6.190962116652396, -1.1966302261985702, -2.3444255784053825, -3.6418667922102084]
 * iter:220
w: [2.5058658398161286, 0.07634355634284108, 1.0719562244728462, -2.9442804763682826, 1.1246589207134539, 3.623238765355398, -0.17411802990946834, -1.4801154164075447, 0.986664421987592, -1.2776219907893833, 1.3477652396792692, -0.793458197553706, 1.1728304543768011, 2.048733564294805, 2.505621601353709, -6.2030803691367336, -1.199052472330526, -2.348548991370976, -3.6489786092214884]
 * iter:221
w: [2.510493990960098, 0.07608595862721762, 1.0740570452750384, -2.9496626182056476, 1.1270174235551282, 3.629143694292232, -0.17387300656885243, -1.4829877457879022, 0.9884189706414899, -1.279714792

w: [3.0042704633052573, 0.053842900013909074, 1.2944830603559314, -3.525448600758213, 1.379070609088711, 4.263209437240499, -0.15250900621160315, -1.7840201338102224, 1.1764527926559614, -1.5072519463804777, 1.5758763948193175, -0.9833606876858102, 1.3877077017248847, 2.3928763466454406, 3.072231790767347, -7.491177334629225, -1.4586885080720642, -2.7873045847460918, -4.411751506905066]
 * iter:349
w: [3.007583612940455, 0.05372632696303156, 1.2959374541865332, -3.529316638566334, 1.380761554555724, 4.267488825575211, -0.1523960347083192, -1.7860037018312092, 1.1777182010267457, -1.5088045126968812, 1.5773716674632476, -0.984571505260355, 1.389120549633482, 2.3952030335784165, 3.0758427941370483, -7.499653207128619, -1.460409275193239, -2.7901983197981437, -4.416816216857005]
 * iter:350
w: [3.01088967901, 0.053610395435904616, 1.2973884301624294, -3.533176410168411, 1.3824488545351221, 4.271759357008156, -0.15228367498030393, -1.7879825680015577, 1.1789809349231992, -1.510354060500101

w: [3.5901178866136347, 0.038477299499200486, 1.5474697598232439, -4.208959885084161, 1.6774674334876654, 5.023726951454974, -0.13753167195904084, -2.1283744232011848, 1.4005201905369669, -1.785683004669823, 1.8371429994633013, -1.1894038252219525, 1.6348244158639695, 2.8101454428871873, 3.6867383032315755, -8.974320209509877, -1.7615403497253723, -3.295671493715933, -5.3057119298826905]
 * iter:563
w: [3.5924073839023642, 0.03843460009044659, 1.5484445182937783, -4.211627933941348, 1.67863103407836, 5.02671133282787, -0.13748980365245503, -2.1296983701567855, 1.4013966082883493, -1.7867835572546147, 1.8381539240822165, -1.1901825997350908, 1.6357809303960866, 2.811794858953236, 3.6890615341410324, -8.980062727770449, -1.7627184836519956, -3.2976479799764755, -5.309198906022444]
 * iter:564
w: [3.5946936303861565, 0.038392071465137956, 1.5494178047929361, -4.214292164169716, 1.6797929636038385, 5.029691552330711, -0.13744810137799668, -2.131020295000344, 1.4022717851304467, -1.78788262

In [5]:
x = ['overcast', 'mild', 'high', 'FALSE']
print('predict:', maxent.predict(x))

predict: {'yes': 0.9999971802186581, 'no': 2.8197813418816512e-06}
