# 第6章 实现最大熵模型，使用IIS优化

In [2]:
import math
from copy import deepcopy

In [8]:
class MaxEntropy:
    def __init__(self, EPS=0.005):
        self._samples = []
        self._Y = set()  # label set
        self._numXY = {}  # the parameter: key-(x, y), value- the number of times (x, y) appears
        self._N = 0  # the number of samples
        self._Ep_ = []   # the feature expectation of sample's distribution
        self._xyID = {}   # the parameter: key-(x, y), value- the number of id
        self._n = 0  # the number of feature key (x, y)
        self._C = 0   # the number of maximum feature
        self._IDxy = {}    # the parameter: key-(x, y), value- the number of id
        self._w = []
        self._EPS = EPS   # the condition of convergence
        self._lastw = []    # last w parameter value
        
    def loadData(self, dataset):
        self._samples = deepcopy(dataset)
        for items in self._samples:
                y = items[0]
                X = items[1:]
                self._Y.add(y)  # if y is in the dataset, ignore
                for x in X:
                    if (x, y) in self._numXY:
                        self._numXY[(x, y)] += 1
                    else:
                        self._numXY[(x, y)] = 1

        self._N = len(self._samples)
        self._n = len(self._numXY)
        self._C = max([len(sample)-1 for sample in self._samples])
        self._w = [0]*self._n
        self._lastw = self._w[:]

        self._Ep_ = [0] * self._n
        for i, xy in enumerate(self._numXY):   # calculating the expectation of the feature function fi on the empirical distribution
            self._Ep_[i] = self._numXY[xy]/self._N
            self._xyID[xy] = i
            self._IDxy[i] = xy

    def _Zx(self, X):    # calculating the value of every Z(x)
        zx = 0
        for y in self._Y:
            ss = 0
            for x in X:
                if (x, y) in self._numXY:
                    ss += self._w[self._xyID[(x, y)]]
            zx += math.exp(ss)
        return zx

    def _model_pyx(self, y, X):   # calculating the value of every P(y|x)
        zx = self._Zx(X)
        ss = 0
        for x in X:
            if (x, y) in self._numXY:
                ss += self._w[self._xyID[(x, y)]]
        pyx = math.exp(ss)/zx
        return pyx

    def _model_ep(self, index):   # calculating the expectation of feature function on the model
        x, y = self._IDxy[index]
        ep = 0
        for sample in self._samples:
            if x not in sample:
                continue
            pyx = self._model_pyx(y, sample)
            ep += pyx/self._N
        return ep

    def _convergence(self):  # determine whether all convergence
        for last, now in zip(self._lastw, self._w):
            if abs(last - now) >= self._EPS:
                return False
        return True

    def predict(self, X):   # calculating the probability of prediction
        Z = self._Zx(X)
        result = {}
        for y in self._Y:
            ss = 0
            for x in X:
                if (x, y) in self._numXY:
                    ss += self._w[self._xyID[(x, y)]]
            pyx = math.exp(ss)/Z
            result[y] = pyx
        return result

    def train(self, maxiter=1000):   # train the data
        for loop in range(maxiter):  # the maximum number of train 
            print("iter:%d" % loop)
            self._lastw = self._w[:]
            for i in range(self._n):
                ep = self._model_ep(i)    # calculating the model expectation of the i-th feature
                self._w[i] += math.log(self._Ep_[i]/ep)/self._C   # updating parameter
            print("w:", self._w)
            if self._convergence():  # determine the convergence
                break

In [9]:
dataset = [['no', 'sunny', 'hot', 'high', 'FALSE'],
           ['no', 'sunny', 'hot', 'high', 'TRUE'],
           ['yes', 'overcast', 'hot', 'high', 'FALSE'],
           ['yes', 'rainy', 'mild', 'high', 'FALSE'],
           ['yes', 'rainy', 'cool', 'normal', 'FALSE'],
           ['no', 'rainy', 'cool', 'normal', 'TRUE'],
           ['yes', 'overcast', 'cool', 'normal', 'TRUE'],
           ['no', 'sunny', 'mild', 'high', 'FALSE'],
           ['yes', 'sunny', 'cool', 'normal', 'FALSE'],
           ['yes', 'rainy', 'mild', 'normal', 'FALSE'],
           ['yes', 'sunny', 'mild', 'normal', 'TRUE'],
           ['yes', 'overcast', 'mild', 'high', 'TRUE'],
           ['yes', 'overcast', 'hot', 'normal', 'FALSE'],
           ['no', 'rainy', 'mild', 'high', 'TRUE']]

maxent = MaxEntropy()
x = ['overcast', 'mild', 'high', 'FALSE']
maxent.loadData(dataset)
maxent.train()
print('predict:', maxent.predict(x))

iter:0
w: [0.0455803891984887, -0.002832177999673058, 0.031103560672370825, -0.1772024616282862, -0.0037548445453157455, 0.16394435955437575, -0.02051493923938058, -0.049675901430111545, 0.08288783767234777, 0.030474400362443962, 0.05913652210443954, 0.08028783103573349, 0.1047516055195683, -0.017733409097415182, -0.12279936099838235, -0.2525211841208849, -0.033080678592754015, -0.06511302013721994, -0.08720030253991244]
iter:1
w: [0.11525071899801315, 0.019484939219927316, 0.07502777039579785, -0.29094979172869884, 0.023544184009850026, 0.2833018051925922, -0.04928887087664562, -0.101950931659509, 0.12655289130431963, 0.016078718904129236, 0.09710585487843026, 0.10327329399123442, 0.16183727320804359, 0.013224083490515591, -0.17018583153306513, -0.44038644519804815, -0.07026660158873668, -0.11606564516054546, -0.1711390483931799]
iter:2
w: [0.18178907332733973, 0.04233703122822168, 0.11301330241050131, -0.37456674484068975, 0.05599764270990431, 0.38356978711239126, -0.0748854616816094

iter:489
w: [3.4135033580868983, 0.042125378861913976, 1.4719920166710738, -4.003055212090294, 1.5876471557816685, 4.793753364916262, -0.14110352695827894, -2.0257892778805826, 1.332924425089821, -1.7010318402027622, 1.7589565009367003, -1.1287916291006714, 1.5608460387812897, 2.6832797154169663, 3.505941591242912, -8.530213012230332, -1.6705368830066658, -3.1430080512953533, -5.0365646301223626]
iter:490
w: [3.4160590556197983, 0.04206729924026736, 1.4730884643175148, -4.00603593723487, 1.5889477450740754, 4.7970775028563635, -0.14104674058908617, -2.0272804732727665, 1.3339023729201855, -1.7022530568401455, 1.7600909428250469, -1.1296767500875482, 1.561919388922316, 2.6851099079674032, 3.508581506887267, -8.536656090232658, -1.6718554984635672, -3.1452201634393435, -5.040461543088489]
iter:491
w: [3.4186106726606695, 0.04200947622329211, 1.4741830295265326, -4.00901187051487, 1.590246232582592, 4.8003964486162305, -0.14099020257808342, -2.0287690804721654, 1.3348787658597632, -1.7034