# In [None]:
from numpy.random import multinomial
from numpy import log, exp

class GibbsSamplingDMM:
    """Gibbs-sampling document clusterer (the name suggests a Dirichlet
    Multinomial Mixture model for short texts).

    Each document is a sequence of hashable word tokens.  ``fit`` assigns
    every document to one of ``K`` clusters and repeatedly re-samples each
    assignment from the conditional distribution computed by ``score``.

    Attributes filled in by ``fit``:
        number_docs: number of documents passed to the last ``fit`` call.
        vocab_size: vocabulary size passed to the last ``fit`` call.
        cluster_doc_count: per cluster, number of documents assigned.
        cluster_word_count: per cluster, total number of word tokens.
        cluster_word_distribution: per cluster, ``{word: count}``.
        document_word_distribution: per document, ``{word: count}``.
        iterations_document_cluster_distribution: per iteration, the cluster
            chosen for each document in that sweep (in document order).
    """

    def __init__(self, K=300, D=1000, alpha=0.1, beta=0.1, n_iters=10):
        """Store the hyper-parameters and allocate empty count tables.

        Args:
            K: number of clusters available to the sampler.
            D: expected number of documents; only used to pre-size the
                per-document table, which ``fit`` grows when needed.
            alpha: smoothing constant added to cluster document counts.
            beta: smoothing constant added to cluster word counts.
            n_iters: maximum number of sampling sweeps over the corpus.
        """
        self.K = K
        self.D = D
        self.alpha = alpha
        self.beta = beta
        self.n_iters = n_iters

        # slots for computed variables
        self.number_docs = None
        self.vocab_size = None
        self._reset_state(D)

    def _reset_state(self, n_docs):
        """Allocate fresh, empty count tables with room for ``n_docs`` docs."""
        K = self.K
        self.cluster_doc_count = [0] * K
        self.cluster_word_count = [0] * K
        self.cluster_word_distribution = [{} for _ in range(K)]
        self.document_word_distribution = [{} for _ in range(n_docs)]
        self.iterations_document_cluster_distribution = [
            [] for _ in range(self.n_iters)
        ]

    @staticmethod
    def _sample(p):
        """Draw one index from the categorical distribution ``p``.

        A single multinomial trial yields a one-hot vector; the position of
        its only non-zero entry is the sampled index.
        """
        return int(multinomial(1, p).argmax())

    def _add_doc(self, z, doc):
        """Add ``doc`` to cluster ``z`` in all cluster-level counters."""
        self.cluster_doc_count[z] += 1
        self.cluster_word_count[z] += len(doc)
        words = self.cluster_word_distribution[z]
        for word in doc:
            words[word] = words.get(word, 0) + 1

    def _remove_doc(self, z, doc):
        """Remove ``doc`` from cluster ``z`` in all cluster-level counters."""
        self.cluster_doc_count[z] -= 1
        self.cluster_word_count[z] -= len(doc)
        words = self.cluster_word_distribution[z]
        for word in doc:
            words[word] -= 1
            # compact dictionary to save space
            if words[word] == 0:
                del words[word]

    def fit(self, docs, vocab_size):
        """Cluster ``docs`` and return the final cluster index per document.

        Args:
            docs: sequence of documents, each a sequence of word tokens.
            vocab_size: number of distinct words, used for smoothing.

        Returns:
            A list with one cluster index (``0 <= z < K``) per document.
        """
        K = self.K
        n_docs = len(docs)
        self.number_docs = n_docs
        self.vocab_size = vocab_size

        # Start every fit from clean tables so repeated calls do not
        # accumulate stale counts, and make sure the per-document table can
        # hold every document even when there are more than ``D`` of them.
        self._reset_state(max(self.D, n_docs))

        m_z = self.cluster_doc_count
        n_d_w = self.document_word_distribution
        i_d_z = self.iterations_document_cluster_distribution
        d_z = [None] * n_docs
        cluster_count = K

        # initialize the clusters
        uniform = [1.0 / K for _ in range(K)]
        for i, doc in enumerate(docs):
            # choose a random initial cluster for the doc
            z = self._sample(uniform)
            d_z[i] = z

            counts = n_d_w[i]
            for word in doc:
                counts[word] = counts.get(word, 0) + 1

            self._add_doc(z, doc)

        for _iter in range(self.n_iters):
            total_transfers = 0

            for i, doc in enumerate(docs):
                # remove the doc from its current cluster
                z_old = d_z[i]
                self._remove_doc(z_old, doc)

                # draw sample from distribution to find new cluster
                z_new = self._sample(self.score(doc, i))

                if z_new != z_old:
                    total_transfers += 1

                # transfer doc to the new cluster
                d_z[i] = z_new
                i_d_z[_iter].append(z_new)
                self._add_doc(z_new, doc)

            cluster_count_new = sum(1 for v in m_z if v > 0)
            print("In stage %d: transferred %d clusters with %d clusters populated" % (
                _iter, total_transfers, cluster_count_new))
            # Stop once a full sweep moved nothing and the number of
            # populated clusters is stable, but never before iteration 6.
            if total_transfers == 0 and cluster_count_new == cluster_count and _iter > 5:
                print("Converged.  Breaking out.")
                break
            cluster_count = cluster_count_new

        return d_z

    def score(self, doc, index):
        """Return the probability of each cluster for one document.

        The document must already be removed from the cluster counters;
        its per-word counts are read from
        ``document_word_distribution[index]``.

        Args:
            doc: the document's word tokens (its length is used).
            index: position of the document in the fitted corpus.

        Returns:
            A list of ``K`` probabilities that sums to 1.
        """
        K, V = self.K, self.vocab_size
        alpha, beta = self.alpha, self.beta
        m_z = self.cluster_doc_count
        n_z = self.cluster_word_count
        n_z_w = self.cluster_word_distribution
        word_counts = self.document_word_distribution[index]

        #  For each cluster z the unnormalised score is N1*N2/(D1*D2):
        #  N1 = m_z[z] + alpha
        #  D1 = D - 1 + K*alpha             (same for all z, so it cancels)
        #  N2 = product over DISTINCT words w of the doc,
        #       product over j = 1..count(w) of (n_z_w[z][w] + beta + j - 1)
        #  D2 = product over i = 1..len(doc) of (n_z[z] + V*beta + i - 1)
        #  Everything is accumulated in log space.
        doc_size = len(doc)
        log_p = []
        for label in range(K):
            cluster_words = n_z_w[label]
            log_score = log(m_z[label] + alpha)

            # Iterate distinct words: looping over the raw token list would
            # apply a repeated word's whole product once per occurrence.
            for word, count in word_counts.items():
                base = cluster_words.get(word, 0) + beta
                for j in range(count):
                    log_score += log(base + j)

            denom_base = n_z[label] + V * beta
            for i in range(doc_size):
                log_score -= log(denom_base + i)

            log_p.append(log_score)

        if not log_p:
            return []

        # Normalise relative to the largest log score so that long documents
        # do not underflow every exp() to zero.
        shift = max(log_p)
        if shift == float("-inf"):
            # Every cluster scored zero (degenerate smoothing); return the
            # all-zero vector as before.
            return [0.0 for _ in log_p]

        weights = [exp(lp - shift) for lp in log_p]
        total = sum(weights)
        return [w / total for w in weights]
