In [1]:
import cvxpy as cp
import numpy as np
import tqdm
import scipy
import math
from scipy.special import xlogy
import time

## Setup: symbols, output grid, channel noise, and prior

In [454]:
# Channel input symbols x_i and their count Q.
X = np.array([-5, 2, 5])
Q = len(X)

# Channel-output grid: N uniform bins over [start, end]; S holds the N+1 bin edges.
N = 200
start = -8
end = 8
step = (end-start)/N
S = np.linspace(start, end, N+1)

# Number of quantizer cells (output levels) designed by the DP below.
M = 3

# Additive Gaussian noise with standard deviation sigma. The generator is
# seeded so that the sample Y is reproducible under Restart & Run All.
sigma = 1
SEED = 0
rng = np.random.default_rng(SEED)
Y = X + rng.standard_normal(Q)*sigma

# Likelihoods p(y | x_i) = Normal(x_i, sigma^2), one frozen distribution per symbol.
Phi = [scipy.stats.norm(loc=X[i], scale=sigma) for i in range(Q)]

# Uniform prior over the Q symbols. Derived from Q so it stays consistent if
# X changes; for Q == 3 this equals [1/3, 1/3, 1/3].
Px = np.full(Q, 1/Q)

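Compute the discretized channel matrix $A(Y|X)$, with $A(Y|X)[n, j] = P(S_n \le Y < S_{n+1} \mid X = x_j)$, and the posterior matrix $A(X|Y)$, with $A(X|Y)[m, n] = P(X = x_m \mid Y \in \text{bin } n)$, obtained from $A(Y|X)$ and the prior $P_X$ via Bayes' rule.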

In [455]:
# Ayx[n, j] = P(S[n] <= Y < S[n+1] | X = x_j): the mass of each output bin under
# each Gaussian likelihood. There is one vectorized cdf call per symbol instead
# of 2*N scalar calls.
Ayx = np.column_stack([np.diff(phi.cdf(S)) for phi in Phi])   # (N, Q)

# Axy[m, n] = P(X = x_m | Y in bin n) by Bayes' rule.
# A bin whose total mass underflows to 0 gives 0/0 = nan; such bins are repaired below.
weighted = Ayx * np.asarray(Px)                  # weighted[n, m] = Ayx[n, m] * Px[m]
with np.errstate(divide='ignore', invalid='ignore'):
    Axy = (weighted / weighted.sum(axis=1, keepdims=True)).T   # (Q, N)

# Fix nan values by repeating the nearest valid bin.
# This assumes the nan bins form contiguous runs at the two ends of the grid:
# - every bin up to the last nan in the first half copies the first bin after it;
# - every bin from the first nan in the second half copies the last bin before it.
Axy_cp = Axy.T.copy()                            # (N, Q): one row per bin

nan_index = np.flatnonzero(np.isnan(Axy_cp).any(axis=1))
upper_half = nan_index[nan_index < N/2]
lower_half = nan_index[nan_index >= N/2]

if upper_half.size > 0:
    upper_half_idx = upper_half[-1]
    Axy_cp[:upper_half_idx+1, :] = Axy_cp[upper_half_idx+1, :]

if lower_half.size > 0:
    lower_half_idx = lower_half[0]
    Axy_cp[lower_half_idx:, :] = Axy_cp[lower_half_idx-1, :]

Axy = Axy_cp.T
# A leftover nan would silently propagate into Pxy and the DP costs.
assert not np.isnan(Axy).any(), "nan values remain in Axy after edge filling"

# Marginal Py[n] = P(Y in bin n); joint Pxy[m, n] = P(X = x_m, Y in bin n), shape (Q, N).
Py = np.matmul(Ayx, Px)

Pxy = Axy*Py

In [457]:
def calc_w(l, r, Pxy, Py):
    """Weighted conditional entropy of X over the quantizer cell of bins l..r.

    Returns ``P(cell) * H(X | Y in cell)``, where
    ``P(cell) = sum(Py[l:r+1])`` and
    ``P(X = x_i | Y in cell) = sum(Pxy[i, l:r+1]) / P(cell)``.
    Summed over a partition of the bins, this is H(X | quantized Y). Minimizing
    it therefore maximizes the mutual information between X and the quantizer output.

    Parameters
    ----------
    l, r : int
        Inclusive first and last bin indices of the cell.
    Pxy : ndarray, shape (Q, N)
        Joint probabilities P(X = x_i, Y in bin n).
    Py : ndarray, shape (N,)
        Marginal probabilities P(Y in bin n).

    Returns
    -------
    float
        The non-negative cost in nats. It is 0.0 for an empty range or a
        zero-probability cell, instead of nan.
    """
    p_cell = np.sum(Py[l:r+1])
    if p_cell == 0:
        # Empty range or underflowed mass: the cell contributes 0 * H = 0.
        # This avoids a 0/0 nan that would poison the DP table.
        return 0.0
    # The conditional distribution does not depend on the bin inside the cell,
    # so compute it once rather than once per bin. The number of symbols comes
    # from Pxy, so the function no longer reads a global.
    p_x_given_cell = np.sum(Pxy[:, l:r+1], axis=1) / p_cell
    return -p_cell * np.sum(xlogy(p_x_given_cell, p_x_given_cell))

In [458]:
# Sanity check: cost of the cell covering bins 12..19 (inclusive).
calc_w(l=12, r=19, Pxy=Pxy, Py=Py)

2.701497043691747e-16

In [459]:
# Sanity check: cost of the cell covering bins 0..7 (inclusive).
calc_w(l=0, r=7, Pxy=Pxy, Py=Py)

5.586549053562965e-20

In [460]:
# Dynamic program over the N output bins.
#   DP[n, m]  = minimal total calc_w cost of splitting bins 0..n into m+1 cells.
#   SOL[n, m] = last bin t of the preceding cell in that optimum, so the
#               last cell is bins t+1..n. Column 0 stays 0, because a single
#               cell has no predecessor.
DP = np.zeros((N, M))
SOL = np.zeros((N, M), dtype=int)   # bin indices, so store them as integers

# One cell: bins 0..n form it.
for n in range(N):
    DP[n, 0] = calc_w(0, n, Pxy, Py)

for m in range(1, M):
    # Allowed ends n leave at least one bin for each of the m earlier cells
    # (n >= m) and for each of the M-1-m later cells (n <= N-M+m).
    # Order within a layer does not matter: layer m reads only layer m-1.
    for n in tqdm.tqdm(np.arange(m, N-M+m+1)[::-1]):
        candidates = np.arange(m-1, n)   # possible last bins t of the previous cell
        costs = [DP[t, m-1] + calc_w(t+1, n, Pxy, Py) for t in candidates]
        best = int(np.argmin(costs))
        SOL[n, m] = candidates[best]
        # Reuse the minimal cost instead of recomputing calc_w for the chosen t.
        DP[n, m] = costs[best]

100%|█████████████████████████████████████████████████████████████████████████████████| 198/198 [00:19<00:00, 10.31it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 198/198 [00:19<00:00, 10.26it/s]


In [461]:
# Follow the DP back-pointers from the last bin to recover the cell boundaries
# as bin-edge indices into S, ordered H[0] = 0 < H[1] < ... < H[M] = N.
H = [N]
cell_end = N - 1                          # last bin of the current (right-most) cell
for level in range(M - 1, 0, -1):
    prev_end = int(SOL[cell_end, level])  # last bin of the cell to its left
    H.append(prev_end + 1)                # left edge of the current cell
    cell_end = prev_end
# The left-most cell always starts at edge 0 (SOL[:, 0] is always 0).
H.append(0)
H.reverse()

In [462]:
# Boundary indices into S, and the corresponding thresholds on the y axis.
thresholds = S[H]
print(H, thresholds)

[0, 81, 144, 200] [-8.   -1.52  3.52  8.  ]
