In [1]:
import numpy as np

# Toy count matrix used to exercise PLSA below: 11 rows by 9 columns.
# PLSA.fit treats each row as a word and each column as a document, so
# entry (w, d) is the number of times word w occurs in document d.
X = np.array(
    [
        [0, 0, 1, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 1],
        [0, 1, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 1],
        [1, 0, 0, 0, 0, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 1],
        [0, 0, 0, 0, 0, 2, 0, 0, 1],
        [1, 0, 1, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 1, 1, 0, 0, 0, 0],
    ]
)

In [2]:
class PLSA:
    """Probabilistic latent semantic analysis fitted with EM.

    Factorises a non-negative word-by-document count matrix ``X`` of shape
    ``(M, N)`` into ``P(w|z)`` and ``P(z|d)`` using ``n_components`` latent
    topics.

    Parameters
    ----------
    n_components : int, default 1
        Number of latent topics ``K``.
    tol : float, default 1e-3
        EM stops once ``mean|zd_new - zd| + mean|wz_new - wz|`` is no longer
        greater than this value.
    max_iter : int or None, default None
        Optional upper bound on the number of EM iterations. ``None`` means
        iterate until the tolerance is met.
    random_state : int or None, default None
        Seed for the random initialisation. ``None`` draws from NumPy's
        global generator, so ``np.random.seed`` still controls the result.

    Attributes (set by ``fit``)
    ---------------------------
    wz : ndarray of shape (M, K)
        ``P(w|z)``; every column sums to 1.
    zd : ndarray of shape (K, N)
        ``P(z|d)``; every column sums to 1.
    n_iter_ : int
        Number of EM iterations that were run.
    """

    def __init__(self, n_components=1, tol=1e-3, max_iter=None, random_state=None):
        self.n_components = n_components
        self.tol = tol
        self.max_iter = max_iter
        self.random_state = random_state

    @staticmethod
    def _normalize_columns(a):
        """Scale each column of ``a`` to sum to 1; zero-sum columns become uniform."""
        totals = a.sum(axis=0, keepdims=True)
        has_mass = totals > 0
        # Divide by 1 where the total is zero so no 0/0 is ever evaluated;
        # those columns are overwritten with the uniform fallback below.
        scaled = a / np.where(has_mass, totals, 1.0)
        return np.where(has_mass, scaled, 1.0 / a.shape[0])

    def fit(self, X):
        """Estimate ``P(w|z)`` and ``P(z|d)`` from the count matrix ``X``.

        Parameters
        ----------
        X : array-like of shape (M, N)
            Non-negative counts, rows are words and columns are documents.

        Raises
        ------
        ValueError
            If ``X`` is not a non-empty 2-D array, contains negative values,
            or ``n_components`` is smaller than 1.

        Notes
        -----
        Results are stored on the instance (``wz``, ``zd``, ``n_iter_``);
        nothing is returned. A document with no counts carries no evidence,
        so its topic distribution is set to uniform instead of becoming NaN.
        """
        X = np.asarray(X, dtype=float)
        if X.ndim != 2 or X.size == 0:
            raise ValueError("X must be a non-empty 2-D array of counts")
        if np.any(X < 0):
            raise ValueError("X must not contain negative counts")
        M, N = X.shape
        K = self.n_components
        if K < 1:
            raise ValueError("n_components must be at least 1")

        ## initialization
        # The global np.random module is used when no seed is given, which
        # keeps np.random.seed(...) working as it did before.
        if self.random_state is None:
            rng = np.random
        else:
            rng = np.random.RandomState(self.random_state)
        wz = self._normalize_columns(rng.random_sample((M, K)))
        zd = self._normalize_columns(rng.random_sample((K, N)))

        # Counts arranged as (N, M, 1) so they broadcast over the topic axis.
        counts = X.T[:, :, np.newaxis]

        error = np.inf
        n_iter = 0
        while error > self.tol and (self.max_iter is None or n_iter < self.max_iter):
            ## calculate P(z|w,d)
            # joint[d, w, k] = P(z=k|d) * P(w|z=k), shape (N, M, K)
            joint = zd.T[:, np.newaxis, :] * wz[np.newaxis, :, :]
            denom = joint.sum(axis=2, keepdims=True)
            # Where the denominator is zero the posterior is left at 0
            # rather than 0/0 = NaN, which would otherwise survive the
            # multiplication by a zero count and poison every sum below.
            posterior = np.divide(
                joint, denom, out=np.zeros_like(joint), where=denom > 0
            )

            ## update P(w|z) and P(z|d)
            weighted = posterior * counts
            wz_new = self._normalize_columns(weighted.sum(axis=0))
            zd_new = self._normalize_columns(weighted.sum(axis=1).T)

            error = np.mean(np.abs(zd_new - zd)) + np.mean(np.abs(wz_new - wz))
            wz, zd = wz_new, zd_new
            n_iter += 1

        self.wz = wz
        self.zd = zd
        self.n_iter_ = n_iter

In [3]:
# PLSA.fit draws its starting P(w|z) and P(z|d) from NumPy's global random
# generator, so the seed is fixed first to make Restart & Run All repeatable.
# Without it, each run gives a different factorisation. Outputs saved from
# earlier unseeded runs may not match a fresh run.
np.random.seed(0)
clf = PLSA(n_components=4)
clf.fit(X)

In [4]:
# Fitted P(w|z): one row per row of X, one column per topic, 3 decimals.
print(np.round(clf.wz, decimals=3))

[[0.    0.154 0.    0.   ]
 [0.    0.    0.266 0.011]
 [0.398 0.    0.    0.   ]
 [0.    0.    0.    0.35 ]
 [0.    0.077 0.137 0.   ]
 [0.4   0.307 0.184 0.29 ]
 [0.    0.154 0.    0.   ]
 [0.    0.    0.    0.35 ]
 [0.    0.    0.412 0.   ]
 [0.201 0.153 0.    0.   ]
 [0.    0.154 0.    0.   ]]


In [5]:
# Fitted P(z|d): one row per topic, one column per column of X, 3 decimals.
print(np.round(clf.zd, decimals=3))

[[0.    1.    0.004 0.    0.    0.    0.    1.    0.   ]
 [1.    0.    0.996 1.    1.    0.    0.    0.    0.   ]
 [0.    0.    0.    0.    0.    1.    0.    0.    0.456]
 [0.    0.    0.    0.    0.    0.    1.    0.    0.544]]
