# In[1]:
import numpy as np
import scipy.stats as sps

def generate_F(h, w):
    """Draw a random h-by-w template of integers in [0, 255].

    Values are sampled uniformly with ``scipy.stats.randint`` (upper bound
    256 is exclusive); ``run_EM`` uses this as the default initial template.
    """
    shape = (h, w)
    return sps.randint.rvs(0, 256, size=shape)


def generate_B(H, W):
    """Draw a random H-by-W background of integers in [0, 255].

    Values are sampled uniformly with ``scipy.stats.randint`` (upper bound
    256 is exclusive); ``run_EM`` uses this as the default initial background.
    """
    shape = (H, W)
    return sps.randint.rvs(0, 256, size=shape)


def generate_prior(H, W, h, w):
    """Return a random prior over template offsets.

    The result has shape (H - h + 1, W - w + 1): one entry per top-left
    position at which an h-by-w window fits inside an H-by-W image.
    Entries are uniform random weights normalised to sum to one.
    """
    n_rows = H - h + 1
    n_cols = W - w + 1
    weights = sps.uniform.rvs(size=(n_rows, n_cols))
    return weights / weights.sum()


def generate_sigma(max_s=100):
    """Draw a random noise standard deviation uniformly from [0, max_s].

    Returned as a length-1 array (``size=1``), not a Python scalar.
    """
    dist = sps.uniform(loc=0, scale=max_s)
    return dist.rvs(size=1)


def compute_prior(X, q, h, w, use_map = False):
    """Estimate the prior over template offsets (M-step).

    Parameters
    ----------
    X : ndarray, shape (H, W, N)
        Observed images; only the shape is used here.
    q : ndarray
        Soft posterior of shape (H - h + 1, W - w + 1, N), or, when
        ``use_map`` is true, a (2, N) array whose column k holds the
        (row, column) offset chosen for image k. Float-typed offsets are
        accepted and truncated to int.
    h, w : int
        Template height and width.
    use_map : bool
        Use hard (MAP) assignments instead of the soft posterior.

    Returns
    -------
    ndarray, shape (H - h + 1, W - w + 1)
        Fraction of images assigned to each offset (MAP), or the posterior
        averaged over images (soft).
    """
    H, W, N = X.shape

    if use_map:
        # Offsets can arrive as floats; numpy refuses float indices.
        q = np.asarray(q, dtype=int)
        prior = np.zeros((H - h + 1, W - w + 1))
        # add.at accumulates correctly when several images share an offset.
        np.add.at(prior, (q[0], q[1]), 1)
        return prior / N

    return np.mean(q, axis=2)


def compute_F(X, q, h, w, use_map = False):
    """Estimate the h-by-w template F (M-step).

    F is the average, over all N images, of the image window at each
    image's (expected) offset.

    Parameters
    ----------
    X : ndarray, shape (H, W, N)
        Observed images.
    q : ndarray
        Soft posterior of shape (H - h + 1, W - w + 1, N), or, when
        ``use_map`` is true, a (2, N) array of (row, column) offsets.
        Float-typed offsets are accepted and truncated to int.
    h, w : int
        Template height and width.
    use_map : bool
        Use hard (MAP) assignments instead of the soft posterior.

    Returns
    -------
    ndarray, shape (h, w)
    """
    H, W, N = X.shape

    if use_map:
        # Offsets can arrive as floats; numpy refuses float indices.
        q = np.asarray(q, dtype=int)
        F = np.zeros((h, w))
        for k in range(N):
            dh, dw = q[0, k], q[1, k]
            F += X[dh:dh + h, dw:dw + w, k]
        return F / N

    # Accumulate in place rather than keeping one (h, w) array per offset.
    F = np.zeros((h, w))
    for dh in range(H - h + 1):
        for dw in range(W - w + 1):
            # Weighted sum over images of the window at this offset.
            F += np.dot(X[dh:dh + h, dw:dw + w, :], q[dh, dw, :])
    return F / N


def compute_B(X, q, h, w, use_map = False):
    """Estimate the H-by-W background B (M-step).

    Each background pixel is the average of X over the images in which that
    pixel is not covered by the template. Coverage is hard (MAP) or expected
    under the soft posterior. Pixels covered in every image divide 0 by 0;
    the resulting NaNs are replaced by 0.

    Parameters
    ----------
    X : ndarray, shape (H, W, N)
        Observed images.
    q : ndarray
        Soft posterior of shape (H - h + 1, W - w + 1, N), or, when
        ``use_map`` is true, a (2, N) array of (row, column) offsets.
        Float-typed offsets are accepted and truncated to int.
    h, w : int
        Template height and width.
    use_map : bool
        Use hard (MAP) assignments instead of the soft posterior.

    Returns
    -------
    ndarray, shape (H, W)
    """
    H, W, N = X.shape

    if use_map:
        # Offsets can arrive as floats; numpy refuses float indices.
        q = np.asarray(q, dtype=int)
        B = X.sum(axis=2)
        Z = np.ones((H, W)) * N  # number of images where the pixel is background

        for k in range(N):
            dh, dw = q[0, k], q[1, k]
            B[dh:dh + h, dw:dw + w] -= X[dh:dh + h, dw:dw + w, k]
            Z[dh:dh + h, dw:dw + w] -= 1

        with np.errstate(divide='ignore', invalid='ignore'):
            B = B / Z
    else:
        # Q[i, j, k] = sum of q[:i+1, :j+1, k] (2-D prefix sum over offsets).
        Q = np.cumsum(np.cumsum(q, axis=0), axis=1)

        # Q1..Q4 are Q looked up at clamped/shifted corners so that
        # Q1 - Q2 - Q3 + Q4 is, per pixel (i, j), the posterior mass of all
        # offsets whose window covers that pixel.
        Q1 = np.zeros((H, W, N)) + Q[-1, -1]
        Q1[H - h + 1:, :W - w + 1] = Q[-1:, :]
        Q1[:H - h + 1, W - w + 1:] = Q[:, -1:]
        Q1[:H - h + 1, :W - w + 1] = Q

        Q2 = np.zeros((H, W, N))
        Q2[:H - h + 1, w:] = Q[:, :-1]
        Q2[H - h + 1:, w:] = Q[-1:, :-1]

        Q3 = np.zeros((H, W, N))
        Q3[h:, :W - w + 1] = Q[:-1, :]
        Q3[h:, W - w + 1:] = Q[:-1, -1:]

        Q4 = np.zeros((H, W, N))
        Q4[h:, w:] = Q[:-1, :-1]

        # S = probability that each pixel is background in each image.
        S = 1 - (Q1 - Q2 - Q3 + Q4)

        with np.errstate(divide='ignore', invalid='ignore'):
            B = (X * S).sum(axis=2) / S.sum(axis=2)

    return np.nan_to_num(B)


def compute_sigma(X, q, h, w, F, B, use_map = False):
    """Estimate the noise standard deviation (M-step).

    sigma**2 is the (expected) squared residual averaged over all
    H * W * N pixels. Pixels inside the template window are compared with F;
    all other pixels are compared with B.

    Parameters
    ----------
    X : ndarray, shape (H, W, N)
        Observed images.
    q : ndarray
        Soft posterior of shape (H - h + 1, W - w + 1, N), or, when
        ``use_map`` is true, a (2, N) array of (row, column) offsets.
        Float-typed offsets are accepted and truncated to int.
    h, w : int
        Template height and width.
    F : ndarray, shape (h, w)
        Current template.
    B : ndarray, shape (H, W)
        Current background.
    use_map : bool
        Use hard (MAP) assignments instead of the soft posterior.

    Returns
    -------
    float
        The estimated sigma.
    """
    H, W, N = X.shape

    L = np.zeros((H - h + 1, W - w + 1, N))
    X = X.transpose([2, 0, 1])  # (N, H, W)

    dB = (X - B)**2
    # 2-D prefix sums of the background residual, moved to shape (H, W, N).
    cumsum_dB = np.cumsum(np.cumsum(dB, axis=1), axis=2)
    cumsum_dB = cumsum_dB.transpose([1, 2, 0])

    def window_sum(C, dh, dw):
        # Inclusion-exclusion on prefix sums C (leading axes are H, W):
        # total of the summed quantity over the h x w window at (dh, dw).
        s = C[dh + h - 1, dw + w - 1].copy()  # copy: C may be a view
        if dh > 0:
            s -= C[dh - 1, dw + w - 1]
        if dw > 0:
            s -= C[dh + h - 1, dw - 1]
        if dh > 0 and dw > 0:
            s += C[dh - 1, dw - 1]
        return s

    # L holds, per offset and image, the template residual inside the window
    # minus the background residual it replaces (weighted by q when soft).
    if use_map:
        # Offsets can arrive as floats; numpy refuses float indices.
        q = np.asarray(q, dtype=int)
        for k in range(N):
            dh, dw = q[0, k], q[1, k]
            current_dB = window_sum(cumsum_dB[:, :, k], dh, dw)
            L[dh, dw, k] = ((X[k, dh:dh + h, dw:dw + w] - F)**2).sum() - current_dB
    else:
        for dh in range(H - h + 1):
            for dw in range(W - w + 1):
                current_dB = window_sum(cumsum_dB, dh, dw)
                L[dh, dw] = q[dh, dw] * (((X[:, dh:dh + h, dw:dw + w] - F)**2).sum(axis=(1, 2)) - current_dB)

    # Total background residual over every image plus the window corrections.
    s_sq = (np.sum(np.nan_to_num(L)) + np.sum(cumsum_dB[-1, -1])) / (H * W * N)

    return np.sqrt(s_sq)


def likelihood(X, F, B, sigma):
    """Log-likelihood of every image for every template offset.

    For offset (dh, dw), pixels in the h-by-w window starting there are
    modelled as Gaussian around F; all other pixels are Gaussian around B.
    Both use standard deviation sigma.

    Parameters
    ----------
    X : ndarray, shape (H, W, N)
        Observed images.
    F : ndarray, shape (h, w)
        Template.
    B : ndarray, shape (H, W)
        Background.
    sigma : float or length-1 array
        Noise standard deviation.

    Returns
    -------
    ndarray, shape (H - h + 1, W - w + 1, N)
        log p(X_k | offset) for each offset and image k.
    """
    H, W, N = X.shape
    h, w = F.shape
    n_rows, n_cols = H - h + 1, W - w + 1

    images = X.transpose([2, 0, 1])  # (N, H, W)
    neg_half_inv_var = -1. / (2 * sigma ** 2)
    pixel_log_norm = -np.log(2 * np.pi * sigma ** 2) / 2

    # Per-pixel background log-density, 2-D prefix-summed, as (H, W, N).
    bg_logpdf = pixel_log_norm + neg_half_inv_var * (images - B) ** 2
    prefix = np.cumsum(np.cumsum(bg_logpdf, axis=1), axis=2).transpose([1, 2, 0])
    bg_total = prefix[-1, -1]

    window_log_norm = h * w * pixel_log_norm
    result = np.zeros((n_rows, n_cols, N))

    for top in range(n_rows):
        bottom = top + h - 1
        for left in range(n_cols):
            right = left + w - 1

            # Background log-density inside the window (inclusion-exclusion).
            # copy() because prefix[...] is a view and is updated in place.
            inside = prefix[bottom, right].copy()
            if top:
                inside -= prefix[top - 1, right]
            if left:
                inside -= prefix[bottom, left - 1]
            if top and left:
                inside += prefix[top - 1, left - 1]

            sq_err = ((images[:, top:bottom + 1, left:right + 1] - F) ** 2).sum(axis=(1, 2))
            # Swap the window's background term for the template term.
            result[top, left] = bg_total - inside + window_log_norm + neg_half_inv_var * sq_err

    return result


def lower_bound(X, F, B, sigma, prior, q, is_map = False, ll_x = None, use_map = None):
    """Evaluate the EM lower bound on the data log-likelihood.

    Soft case: sum over images k and offsets d of
    q[d, k] * (ll_x[d, k] + log prior[d] - log q[d, k]).
    MAP case: sum over k of ll_x[d_k, k] + log prior[d_k].

    Parameters
    ----------
    X, F, B, sigma
        Passed to ``likelihood`` when ``ll_x`` is not supplied.
    prior : ndarray, shape (H - h + 1, W - w + 1)
        Prior over offsets.
    q : ndarray
        Soft posterior of shape (H - h + 1, W - w + 1, N), or, in MAP mode,
        a (2, N) array of (row, column) offsets. Float-typed offsets are
        accepted and truncated to int.
    is_map : bool
        Evaluate the MAP (hard-assignment) bound.
    ll_x : ndarray, shape (H - h + 1, W - w + 1, N), optional
        Precomputed output of ``likelihood``.
    use_map : bool, optional
        Alias for ``is_map`` (the keyword used by the M-step functions);
        overrides ``is_map`` when given.

    Returns
    -------
    float
    """
    if use_map is not None:
        is_map = use_map

    if ll_x is None:
        ll_x = likelihood(X, F, B, sigma)

    if is_map:
        # Offsets can arrive as floats; numpy refuses float indices.
        q = np.asarray(q, dtype=int)
        images = np.arange(ll_x.shape[2])
        with np.errstate(divide='ignore'):
            lb = np.sum(ll_x[q[0], q[1], images] + np.log(prior[q[0], q[1]]))
    else:
        ll_x = ll_x.transpose([2, 0, 1])
        q = q.transpose([2, 0, 1])

        # Entries with q == 0 give inf/nan inside; nan_to_num makes them
        # finite so that q * (...) evaluates to 0 there.
        with np.errstate(divide='ignore', invalid='ignore'):
            lb = np.sum(q * (ll_x + np.nan_to_num(np.log(prior) - np.log(q))))

    return lb


def e_step(X, F, B, sigma, prior, is_map = False, ll_x = None, use_map = None):
    """Compute the posterior over template offsets (E-step).

    Parameters
    ----------
    X, F, B, sigma
        Passed to ``likelihood`` when ``ll_x`` is not supplied.
    prior : ndarray, shape (H - h + 1, W - w + 1)
        Prior over offsets.
    is_map : bool
        Return hard (MAP) offsets instead of the full posterior.
    ll_x : ndarray, shape (H - h + 1, W - w + 1, N), optional
        Precomputed output of ``likelihood``.
    use_map : bool, optional
        Alias for ``is_map`` (the keyword used by the M-step functions);
        overrides ``is_map`` when given.

    Returns
    -------
    ndarray
        Posterior q of shape (H - h + 1, W - w + 1, N) whose entries sum to
        one over the first two axes. In MAP mode, returns an integer (2, N)
        array whose column k is the (row, column) of the posterior argmax for
        image k.
    """
    if use_map is not None:
        is_map = use_map

    if ll_x is None:
        ll_x = likelihood(X, F, B, sigma)

    # Unnormalised log posterior; zero prior entries give -inf (probability 0).
    with np.errstate(divide='ignore'):
        log_joint = ll_x + np.log(prior)[:, :, np.newaxis]
    # Subtract the per-image maximum before exp for numerical stability.
    log_joint = log_joint - log_joint.max(axis=(0, 1))
    joint = np.exp(log_joint)
    q = joint / joint.sum(axis=(0, 1))

    if not is_map:
        return q

    n_rows, n_cols, N = q.shape
    # Row-major flat argmax per image, split into integer (row, col) offsets.
    flat_best = q.reshape(n_rows * n_cols, N).argmax(axis=0)
    rows, cols = np.unravel_index(flat_best, (n_rows, n_cols))
    return np.vstack([rows, cols])


def m_step(X, q, h, w, use_map = False, B = None):
    """Re-estimate all model parameters from the posterior q (M-step).

    The prior, then F, then B (only when no fixed B is passed in), then sigma
    are computed. Sigma uses the freshly estimated F and B.

    Returns
    -------
    tuple
        (F, B, sigma, prior).
    """
    new_prior = compute_prior(X, q, h, w, use_map = use_map)
    new_F = compute_F(X, q, h, w, use_map = use_map)
    new_B = compute_B(X, q, h, w, use_map = use_map) if B is None else B
    new_sigma = compute_sigma(X, q, h, w, new_F, new_B, use_map = use_map)
    return new_F, new_B, new_sigma, new_prior


def run_EM(X,
           h,
           w,
           F = None,
           B = None,
           sigma = None,
           prior = None,
           tol = 1e-3,
           max_iter = 50,
           use_map = False,
           fix_B = False):
    """Fit template F, background B, noise sigma and offset prior by EM.

    Parameters
    ----------
    X : ndarray, shape (H, W, N)
        Observed images.
    h, w : int
        Template height and width.
    F, B, sigma, prior : optional
        Initial parameters; missing ones are drawn with the ``generate_*``
        helpers.
    tol : float
        Stop once consecutive lower-bound values differ by less than this.
    max_iter : int
        Maximum number of EM iterations.
    use_map : bool
        Use hard (MAP) offsets in both E- and M-steps.
    fix_B : bool
        Keep the supplied B fixed instead of re-estimating it.

    Returns
    -------
    tuple
        (F, B, sigma, prior, LL), where LL holds the lower bound after each
        iteration.

    Raises
    ------
    ValueError
        If ``fix_B`` is true but no B is given.
    """
    H, W, N = X.shape

    # Initialisation
    if F is None:
        F = generate_F(h, w)

    if B is None:
        if fix_B:
            raise ValueError('Provide B !!!')
        B = generate_B(H, W)

    if prior is None:
        prior = generate_prior(H, W, h, w)

    if sigma is None:
        sigma = generate_sigma()

    ll_x = likelihood(X, F, B, sigma)

    LL = []

    for i in range(max_iter):
        # e_step / lower_bound name this flag `is_map`; passing `use_map`
        # would raise TypeError against their signatures.
        q = e_step(X, F, B, sigma, prior, is_map = use_map, ll_x = ll_x)

        if fix_B:
            F, _, sigma, prior = m_step(X, q, h, w, use_map = use_map, B = B)
        else:
            F, B, sigma, prior = m_step(X, q, h, w, use_map = use_map)

        ll_x = likelihood(X, F, B, sigma)
        LL.append(lower_bound(X, F, B, sigma, prior, q, is_map = use_map, ll_x = ll_x))

        if i > 0 and abs(LL[-1] - LL[-2]) < tol:
            break

    return F, B, sigma, prior, np.array(LL)