In [2]:
import numpy as np
from scipy import linalg 

In [126]:
def Gibbs(X, maxiter = 1000, sigma_X = 1, sigma_A = 1, a = 1, b = 1):
    N, D = X.shape

    ### Initialize
    Z = np.zeros((N, 1))
    while np.sum(Z) == 0:
        Z = np.random.binomial(1, 0.5, size = (N, 1))
    K_plus = 1
    alpha = np.random.gamma(a, b, size = 1)
    H_N = sum(1 / np.arange(1, N + 1))

    ### Likelihood
    def LR(X, Z, i, k, sigma_X, sigma_A):
        # Let Z_{ik} be 1
        Z0 = Z * 1
        Z1 = Z * 1
        Z0[i, k] = 0
        Z1[i, k] = 1
        K = Z.shape[1]
        Mat0 = (sigma_A ** 2 / sigma_X ** 2) * Z0.T @ Z0 + np.eye(K)
        Mat1 = (sigma_A ** 2 / sigma_X ** 2) * Z1.T @ Z1 + np.eye(K)
        #print('det', linalg.det(Mat0), linalg.det(Mat1))
        logcoef = D / 2 * (np.log(linalg.det(Mat0)) - np.log(linalg.det(Mat1)))
        expo = -sigma_A**2 / (2*sigma_X**4) * np.trace(X.T @ (Z0 @ linalg.inv(Mat0) @ Z0.T - Z1 @ linalg.inv(Mat1) @ Z1.T) @ X)
        #print(expo, logcoef)
        return np.exp(expo - logcoef)

    ### MCMC
    for it in range(maxiter):
        print(it)
        res = {'K': np.zeros(maxiter), 'alpha': np.zeros(maxiter), 'Z': [None] * maxiter}
        ## Update Z matrix
        for i in range(N):
            # 1. update Z_{ik}
            count = 0
            for k in range(K_plus):
                r = sum(Z[:, k]) - Z[i, k]   # m_{-i, k}
                if r < 1:
                    Z[i, k] = 0
                    count += 1
                else:
                    prob_ratio = LR(X, Z, i, k, sigma_X, sigma_A) * r / (N - r)
                    prob = prob_ratio / (1 + prob_ratio)
                    Z[i, k] = np.random.binomial(1, prob, size = 1)
                    #print("prob = ", prob, f"Z[{i}, {k}]", Z[i, k])
            # 2. delete unused columns
            # print(np.sum(Z, axis = 0))
            out = np.sum(Z, axis = 0) - Z[i, :]
            print(count, sum(out == 0), K_plus)
            Z = Z[:, out != 0]
            _, K_plus = Z.shape
            # 3. generate new columns
            K_new = int(np.random.poisson(alpha / N, size = 1)) 
            # The number of new features follows Poisson.
            if K_new:
                Z = np.c_[Z, np.zeros((N, K_new))]
                Z[i, -K_new::] = 1
            K_plus += K_new            
            
        ## Update alpha
        alpha = np.random.gamma(a + K_plus, b + H_N, size = 1)
        
        res['K'][it] = K_plus
        res['alpha'][it] = alpha
        res['Z'][it] = Z * 1
    return res

In [127]:
X = np.random.random((100,2)) @ np.random.random((2, 50)) + 1
res = Gibbs(X, maxiter = 5, sigma_A = 1)

0
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 1
0 0 2
0 0 2
0 0 2
0 0 2
0 0 2
0 0 3
0 0 4
0 0 4
0 0 4
0 0 4
0 0 4
0 0 4
0 0 4
0 0 4
0 0 4
0 0 5
0 0 6
0 0 6
0 0 6
0 0 6
0 0 6
0 0 6
0 0 7
0 0 7
0 0 7
0 0 7
0 0 7
0 0 7
0 0 7
0 0 7
0 0 7
0 0 7
0 0 8
0 0 8
0 0 8
0 0 8
0 0 8
0 0 9
0 0 9
0 0 9
0 0 9
0 0 9
0 0 9
0 0 9
0 0 9
0 0 10
0 0 10
0 0 10
0 0 10
0 0 10
0 0 10
0 0 10
0 0 10
0 0 10
0 0 10
0 

KeyboardInterrupt: 

In [116]:
res['K']

array([  0.,   0.,   0.,   0., 108.])

In [36]:
X = np.random.random((100,50))
np.sum(X, axis = 0).shape

(50,)

In [86]:
X.T @ X + (1.7 ** 2 / 0.5 ** 2) * np.eye(5)

array([[13.30326217,  2.25111066,  3.71189545,  1.02025584,  2.09428824],
       [ 2.25111066, 14.47766474,  4.77933367,  1.32334789,  2.72124749],
       [ 3.71189545,  4.77933367, 19.48166406,  2.16481968,  4.43753757],
       [ 1.02025584,  1.32334789,  2.16481968, 12.16031459,  1.23488985],
       [ 2.09428824,  2.72124749,  4.43753757,  1.23488985, 14.10238599]])