In [1]:
import numpy as np
np.seterr(all='raise')

{'divide': 'warn', 'over': 'warn', 'under': 'ignore', 'invalid': 'warn'}

In [62]:
class FastGradMethod:
    def __init__(self, gamma, n, epsilon):
        self.alpha = 0
        self.a     = 0
        
        # constants
        self.epsilon = epsilon
        self.gamma = gamma
        self.l     = 1 / gamma
        self.n     = n
        
        self.u_lambda = np.ones(n) / n**2
        self.u_mu     = np.ones(n) / n**2
         
        self.y_lambda = np.zeros(n)
        self.y_mu     = np.zeros(n)
        
        self.x_lambda   = np.zeros([n])
        self.x_mu       = np.zeros([n])
        
        self.x_0 = np.ones([n, n]) / (n**2)
        
        self.alpha_sum = 0
        # x(y_lambda, y_mu)
        self.x_sum = np.zeros([n, n])
        
    def x_hat(self, c):
        n = self.n
        a = self.x_0 * np.exp(
            -(self.gamma + c + self.y_lambda.repeat(n).reshape(-1, n) + self.y_mu.repeat(n).reshape(-1, n)) / self.gamma
        )
        
        print(a.sum())
        
        return a / a.sum()
    
    def x_wave(self, c):
        return 1 / self.a * self.alpha_sum * self.x_sum
    
    def __new_alpha(self):
        return 1 / (2 * self.l) + np.sqrt(1 / (4 * (self.l**2)) + self.alpha**2)
    
    def __new_a(self, new_alpha):
        return self.a + new_alpha
    
    def __new_y(self):
        new_a = self.__new_a(self.__new_alpha())
        return (self.__new_alpha() * self.u_lambda + self.a * self.x_lambda) / new_a,\
               (self.__new_alpha() * self.u_mu + self.a * self.x_mu) / new_a
    
    def __new_u(self):
        x_hat = self.x_hat(c)
        return self.u_lambda - self.__new_alpha() * (p - x_hat.sum(1)),\
               self.u_mu - self.__new_alpha() * (q - x_hat.sum(0))
        
    def __new_x(self):
        new_alpha = self.__new_alpha()
        new_a = self.__new_a(new_alpha)
        return (new_alpha * self.u_lambda + self.a * self.x_lambda) / new_a,\
               (new_alpha * self.u_mu + self.a * self.x_mu) / new_a
    
    def f(self, c, x_wave):
        return np.sum(c * x_wave) + self.gamma * np.sum(x_wave * np.log(x_wave / self.x_0))
        
    def phi(self, c, lambda_x, mu_x):
        n = self.n
        return -np.sum(lambda_x * p) - np.sum(mu_x * q) +\
        self.gamma * np.log(1 / np.exp(1) * np.sum(
            self.x_0 * np.exp(
                -(self.gamma + c + lambda_x.repeat(n).reshape(-1, n) + mu_x.repeat(n).reshape(-1, n).T)
            )
        ))
    
    def fit(self, c, p, q):
        k = 0
        while True:
            k+=1
            
            self.y_lambda, self.y_mu = self.__new_y()          
            self.u_lambda, self.u_mu = self.__new_u()            
            self.x_lambda, self.x_mu = self.__new_x()
            
            self.alpha = self.__new_alpha()
            self.a     = self.__new_a(self.alpha)
            
            self.alpha_sum += self.alpha
            self.x_sum += self.x_hat(c)
            
            x_wave = self.x_wave(c)
            r = self.epsilon / np.sqrt((self.x_lambda**2).sum() + (self.x_mu**2).sum())
            
            criteria_a = np.sqrt(np.sum((x_wave.sum(1) - p)**2) + np.sum((x_wave.sum(0) - 1)**2)) <= r
            criteria_b = self.f(c, x_wave) - self.phi(c, p, q) <= self.epsilon
            
            if criteria_a and criteria_b:
                return k

In [63]:
fgrad = FastGradMethod(0.3, 5, 0.001)

In [64]:
def sample_batch(n):
    C = np.random.uniform(0, 10, size=[n, n])
    p = np.random.dirichlet(np.ones(5), size=1).ravel()
    q = np.random.dirichlet(np.ones(5), size=1).ravel()
    return C, p, q

In [65]:
c, p, q = sample_batch(5)

In [66]:
fgrad.fit(c,p,q)

0.00363375258193
0.00363375258193
0.00310719754473
0.00310719754473
0.00339035027417
0.00339035027417
0.00423160785421
0.00423160785421
0.00556901439545
0.00556901439545
0.00752844765304
0.00752844765304
0.0103856775616
0.0103856775616
0.0146145881406
0.0146145881406
0.0210109112333
0.0210109112333
0.0309374652324
0.0309374652324
0.046810114998
0.046810114998
0.0731018592003
0.0731018592003
0.118530813724
0.118530813724
0.201044642232
0.201044642232
0.359359904111
0.359359904111
0.679072571152
0.679072571152
1.34931311419
1.34931311419
2.78458080253
2.78458080253
5.89670475507
5.89670475507
12.7189232086
12.7189232086
27.8482076997
27.8482076997
61.840101244
61.840101244
139.434531057
139.434531057
320.039916076
320.039916076
750.043629584
750.043629584
1799.80589555
1799.80589555
4431.56498657
4431.56498657
11212.1640658
11212.1640658
29169.0813148
29169.0813148
78038.9144806
78038.9144806
214655.79268
214655.79268
606758.99313
606758.99313
1761616.32163
1761616.32163
5251015.55433
52

FloatingPointError: underflow encountered in true_divide