In [133]:
import numpy as np
import copy
import networkx as nx
import matplotlib.pyplot as plt
import time
import sys

def loadfile(filename1, filename2=None):
    ds1 = np.loadtxt(filename1, delimiter=",", dtype=int)
    if filename2:
        ds2 = np.loadtxt(filename2, delimiter=",", dtype=int)
        ds = np.vstack((ds1, ds2))
    else:
        ds = ds1
    return ds, ds.shape[0], ds.shape[1]

def count_matrix(ds, m, n):
    prob_xy = np.zeros((n, n, 4))
    for i in range(n):
        subds = ds[ds[:, i] == 0]
        for j in range(n):
            if prob_xy[i, j, 0] == 0:
                prob_xy[i, j, 0] = (subds[subds[:, j] == 0].shape[0]+1)/(m+4)
            if prob_xy[j, i, 0] == 0:
                prob_xy[j, i, 0] = prob_xy[i, j, 0]
            if prob_xy[i, j, 1] == 0:
                prob_xy[i, j, 1] = (subds[subds[:, j] == 1].shape[0]+1)/(m+4)
            if prob_xy[j, i, 2] == 0:
                prob_xy[j, i, 2] = prob_xy[i, j, 1]
            
        subds = ds[ds[:, i] == 1]
        for j in range(n):
            if prob_xy[i, j, 2] == 0:
                prob_xy[i, j, 2] = (subds[subds[:, j] == 0].shape[0]+1)/(m+4)
            if prob_xy[j, i, 1] == 0:
                prob_xy[j, i, 1] = prob_xy[i, j, 2]
            if prob_xy[i, j, 3] == 0:
                prob_xy[i, j, 3] = (subds[subds[:, j] == 1].shape[0]+1)/(m+4)
            if prob_xy[j, i, 3] == 0:
                prob_xy[j, i, 3] = prob_xy[i, j, 3]
    return prob_xy

def prob_matrix(ds, m, n, k=0):
    prob_xy = np.zeros((n, n, 4))
    l = 1
    for i in range(n):
        subds = ds[ds[:, i] == 0]
        for j in range(n):
            if prob_xy[i, j, 0] == 0:
                prob_xy[i, j, 0] = (np.sum(subds[subds[:, j] == 0][:, n+k]))
                cnt = subds[subds[:, j] == 0].shape[0]
                l = prob_xy[i, j, 0]/cnt if cnt>0 and prob_xy[i, j, 0]/cnt < l else l
            if prob_xy[j, i, 0] == 0:
                prob_xy[j, i, 0] = prob_xy[i, j, 0]
            if prob_xy[i, j, 1] == 0:
                prob_xy[i, j, 1] = (np.sum(subds[subds[:, j] == 1][:, n+k]))
                cnt = subds[subds[:, j] == 1].shape[0]
                l = prob_xy[i, j, 1]/cnt if cnt>0 and prob_xy[i, j, 1]/cnt < l else l
            if prob_xy[j, i, 2] == 0:
                prob_xy[j, i, 2] = prob_xy[i, j, 1]
            
        subds = ds[ds[:, i] == 1]
        for j in range(n):
            if prob_xy[i, j, 2] == 0:
                prob_xy[i, j, 2] = (np.sum(subds[subds[:, j] == 0][:, n+k]))
                cnt = subds[subds[:, j] == 0].shape[0]
                l = prob_xy[i, j, 2]/cnt if cnt>0 and prob_xy[i, j, 2]/cnt < l else l
            if prob_xy[j, i, 1] == 0:
                prob_xy[j, i, 1] = prob_xy[i, j, 2]
            if prob_xy[i, j, 3] == 0:
                prob_xy[i, j, 3] = (np.sum(subds[subds[:, j] == 1][:, n+k]))
                cnt = subds[subds[:, j] == 1].shape[0]
                l = prob_xy[i, j, 3]/cnt if cnt>0 and prob_xy[i, j, 3]/cnt < l else l
            if prob_xy[j, i, 3] == 0:
                prob_xy[j, i, 3] = prob_xy[i, j, 3]
    #l = np.sum(prob_xy)/(m*n*n)
    return (prob_xy+l)/(m+4*l)

def mutual_info(prob_xy, n):
    I_xy = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            if i < j:
                I_xy[i, j] = prob_xy[i, j, 0]*np.log(prob_xy[i, j, 0]/(prob_xy[i, i, 0]*prob_xy[j, j, 0])) \
                + prob_xy[i, j, 1]*np.log(prob_xy[i, j, 1]/(prob_xy[i, i, 0]*prob_xy[j, j, 3])) \
                + prob_xy[i, j, 2]*np.log(prob_xy[i, j, 2]/(prob_xy[i, i, 3]*prob_xy[j, j, 0])) \
                + prob_xy[i, j, 3]*np.log(prob_xy[i, j, 3]/(prob_xy[i, i, 3]*prob_xy[j, j, 3]))
    return I_xy

def draw_tree(edge_wts, prnt = False, k=0, step=0):
    edge_wts_cp = copy.deepcopy(edge_wts)
    edges = [np.unravel_index(np.argmax(edge_wts_cp), edge_wts_cp.shape)]
    visited = [[edges[-1][0],edges[-1][1]]]
    edge_wts_cp[edges[-1]] = 0
    while(len(edges) < edge_wts.shape[0]-1):
        i = j = -1
        edge = np.unravel_index(np.argmax(edge_wts_cp), edge_wts_cp.shape)
        for bag in range(len(visited)):
            if edge[0] in visited[bag]:
                i = bag
            if edge[1] in visited[bag]:
                j = bag
        if i == -1 and j != -1:
            edges.append(edge)
            visited[j].append(edge[0])
        elif i != -1 and j == -1:
            edges.append(edge)
            visited[i].append(edge[1])
        elif i == -1 and j == -1:
            edges.append(edge)
            visited.append([edge[0], edge[1]])
        elif i != -1 and j != -1 and i != j:
            edges.append(edge)
            visited[i] += visited[j]
            visited.remove(visited[j])
        elif i == j != -1:
            pass
            #print("Discarded", edge)
        else:
            print("Discarded in else", edge)
        edge_wts_cp[edge] = 0
    
    if prnt:
        G = nx.Graph()
        G.add_nodes_from(visited[0])
        G.add_edges_from(edges)
        plt.figure(str(step)+str(k))
        nx.draw_networkx(G, with_labels=True, arrows=True)
    
    return edges

def count_matrix(ds, tree, cols):
    count_xy = np.zeros((len(tree), cols))
    for idx, node in enumerate(tree):
        i, j = node
        count_xy[idx] = [ds[(ds[:, i]==0) & (ds[:, j]==0)].shape[0], ds[(ds[:, i]==0) & (ds[:, j]==1)].shape[0], ds[(ds[:, i]==1) & (ds[:, j]==0)].shape[0], ds[(ds[:, i]==1) & (ds[:, j]==1)].shape[0]]
    #print(count_xy)
    return count_xy

def exist_matrix(ds, tree, cols):
    rows = ds.shape[0]
    exist_xy = np.zeros((rows, len(tree), cols))
    for idx, node in enumerate(tree):
        i, j = node
        exist_xy[:,idx,:] = np.hstack((((ds[:, i]==0) & (ds[:, j]==0)).astype(int).reshape(rows,1), ((ds[:, i]==0) & (ds[:, j]==1)).astype(int).reshape(rows,1), ((ds[:, i]==1) & (ds[:, j]==0)).astype(int).reshape(rows,1), ((ds[:, i]==1) & (ds[:, j]==1)).astype(int).reshape(rows,1)))
    return exist_xy

def M_step(ds, m, n, k, step, prnt, pk):
    trees = []
    cond_probs = []
    for ki in range(k):
        #prob_xy = prob_matrix(np.hstack((ds[:,:n], pk)), m, n, ki)
        prob_xy = prob_matrix(ds, m, n, ki)
        I_xy = mutual_info(prob_xy, n)        
        tree = draw_tree(I_xy, prnt, ki, step)
        tree = [(tree[0][0], tree[0][0])] + tree
        trees.append(tree)
        cond_prob = np.zeros((len(tree), prob_xy.shape[2]))
        for idx, node in enumerate(tree):
            if node[0] == node[1]:
                #cond_prob[idx] = np.log(prob_xy[node[0], node[1],:])
                cond_prob[idx] = prob_xy[node[0], node[1],:]
            else:
                #cond_prob[idx] = np.log(np.hstack(((prob_xy[node[0], node[1],:2]/prob_xy[node[0], node[0], 0]),(prob_xy[node[0], node[1],2:]/prob_xy[node[0], node[0], 3]))))
                cond_prob[idx] = np.hstack(((prob_xy[node[0], node[1],:2]/prob_xy[node[0], node[0], 0]),(prob_xy[node[0], node[1],2:]/prob_xy[node[0], node[0], 3])))
        cond_probs.append(cond_prob)
    return trees, cond_probs

def random_init(ds, m, n, k, step, prnt):
    trees = []
    cond_probs = []
    for ki in range(k):
        
        prob_xy = prob_matrix(ds[np.random.choice(m, 8, replace=False), :], m, n, ki)

        I_xy = mutual_info(prob_xy, n)
        
        tree = draw_tree(I_xy, prnt, ki, step)
        tree = [(tree[0][0], tree[0][0])] + tree
        trees.append(tree)
        cond_prob = np.zeros((len(tree), prob_xy.shape[2]))
        for idx, node in enumerate(tree):
            if node[0] == node[1]:
                #cond_prob[idx] = np.log(prob_xy[node[0], node[1],:])
                cond_prob[idx] = prob_xy[node[0], node[1],:]
            else:
                #cond_prob[idx] = np.log(np.hstack(((prob_xy[node[0], node[1],:2]/prob_xy[node[0], node[0], 0]),(prob_xy[node[0], node[1],2:]/prob_xy[node[0], node[0], 3]))))
                cond_prob[idx] = np.hstack(((prob_xy[node[0], node[1],:2]/prob_xy[node[0], node[0], 0]),(prob_xy[node[0], node[1],2:]/prob_xy[node[0], node[0], 3])))
        cond_probs.append(cond_prob)
    return trees, cond_probs

def E_step(ds, m, n, k, trees, cond_probs):
    
    ph = ds[:, n:].sum(axis = 0)/m
    weight_ij = np.zeros((m, k))

    #for i in range(m):
    #    for j in range(k):
    #        weight_ij[i, j] = ph[j] * np.sum(count_matrix(ds[i:i+1, 0:n], trees[j], 4)*cond_probs[j])

    for j in range(k):
        if j < 0:
            print("ph[j] * np.sum(exist_matrix(ds, trees[j], 4)*cond_probs[j], axis = 2).prod(axis=1)")
            print(ph[j] * np.sum(exist_matrix(ds, trees[j], 4)*cond_probs[j], axis = 2).prod(axis=1))
            print("ph[j]" )
            print(ph[j] )
            print("np.sum(exist_matrix(ds, trees[j], 4)*cond_probs[j], axis = 2).prod(axis=1)")
            print(np.sum(exist_matrix(ds, trees[j], 4)*cond_probs[j], axis = 2).prod(axis=1))
            print("np.sum(exist_matrix(ds, trees[j], 4)*cond_probs[j], axis = 2)")
            print(np.sum(exist_matrix(ds, trees[j], 4)*cond_probs[j], axis = 2))
            print("exist_matrix(ds, trees[j], 4)*cond_probs[j]")
            print(exist_matrix(ds, trees[j], 4)*cond_probs[j])
            print("exist_matrix(ds, trees[j], 4)")
            print(exist_matrix(ds, trees[j], 4))
            print("cond_probs[j]")
            print(cond_probs[j])
        #weight_ij[:, j] = ph[j] * np.sum(exist_matrix(ds, trees[j], 4)*cond_probs[j], axis = 1).sum(axis=1)
        weight_ij[:, j] = ph[j] * np.sum(exist_matrix(ds, trees[j], 4)*cond_probs[j], axis = 2).prod(axis=1)

    weight_ij = weight_ij/np.sum(weight_ij, axis = 1).reshape(m, 1)

    ll1 = 0
    #for i in range(m):
    #    for j in range(k):
    #        ll1 += (weight_ij[i, j] * (np.log(ph[j]) + np.log(np.sum(count_matrix(ds[i:i+1, 0:n], trees[j], 4)*cond_probs[j], axis = 1)).sum()))
    ll2 = 0
    for j in range(k):
        ll2 += (weight_ij[:, j] * (np.log(ph[j]) + np.log(np.sum(exist_matrix(ds, trees[j], 4)*cond_probs[j], axis = 2).prod(axis=1)))).sum()
        #print("LL2", j, ll2)
    print("LL:", ll1, ll2)
    
    ds[:, n:] = weight_ij
    pk = weight_ij / np.sum(weight_ij, axis = 0).reshape(1, k)
    return ds, pk

if __name__ == "__main__":
    ds, m, n = loadfile("small-10-datasets/nltcs.ts.data")
    #ds, m, n = ds[:10, :3], 10, 3
    k = 2
    weight_ij = np.random.rand(m, k)
    weight_ij = weight_ij/np.sum(weight_ij, axis = 1).reshape(m, 1)
    ds = np.hstack((ds, weight_ij))
    trees, cond_probs = random_init(ds, m, n, k, 0, False)

    
    #ds, pk = E_step(ds, m, n, k, trees, cond_probs)
    #print(pk)
    #trees, cond_probs = M_step(ds, m, n, k, step, False, pk)
    #print(trees)
    #print(cond_probs)
    
    for step in range(1,30): 
        ds, pk = E_step(ds, m, n, k, trees, cond_probs)
        #print(pk)
        trees, cond_probs = M_step(ds, m, n, k, step, False, pk)        

    ts = np.loadtxt("small-10-datasets/nltcs.test.data", delimiter=",", dtype=int)
    lambda_k = ds[:, n:].sum(axis = 0)/m
    L = 0
    ll1 = 0
    ll2 = 0
    for j in range(k):
        #t0=time.time()
        #L += ts.shape[0]*np.log(lambda_k[j]) + np.sum(count_matrix(ts, trees[j], 4)*np.log(cond_probs[j]))
        #t1=time.time()
        ll1 +=  np.log(lambda_k[j]*(np.sum(exist_matrix(ts, trees[j], 4)*cond_probs[j], axis = 2).prod(axis=1))).sum()
        #t2=time.time()
        #ll2 += (np.log(lambda_k[j]) + np.log(np.sum(exist_matrix(ts, trees[j], 4)*cond_probs[j], axis = 2).prod(axis=1))).sum()
        #t3=time.time()
        #print(t1-t0, t2-t1, t3-t2)
    print("LL:", L,"LL1", ll1/ts.shape[0], "LL2", ll2)
    
    #print(edge_wts[np.unravel_index(np.argmax(edge_wts), edge_wts.shape)] == edge_wts.max())

LL: 0 -242205.58209980346
LL: 0 -122346.03458183493
LL: 0 -119522.79941509123
LL: 0 -115118.62463463368
LL: 0 -111701.91345738183
LL: 0 -110407.49762976074
LL: 0 -110100.51292950283
LL: 0 -109981.36063221797
LL: 0 -109842.07315150074
LL: 0 -109920.35492120196
LL: 0 -109965.869045493
LL: 0 -109911.14557820334
LL: 0 -109864.05812897913
LL: 0 -109815.31439119179
LL: 0 -109750.11523556376
LL: 0 -109724.53344900104
LL: 0 -109702.53125139774
LL: 0 -109710.01445394265
LL: 0 -110980.18330440384
LL: 0 -110930.29809870235
LL: 0 -110945.81099679609
LL: 0 -110937.86870320693
LL: 0 -110941.32878324192
LL: 0 -110932.14725041742
LL: 0 -110960.52904777652
LL: 0 -110952.4730259808
LL: 0 -110923.4877642873
LL: 0 -110871.70019825072
LL: 0 -110851.71515090956
LL: 0 LL1 -46.05173372880206 LL2 0


In [122]:
29.527144146457992 
#LL: 0 LL1 -139430.12802308943 LL2 0
-110642.78809424426/ts.shape[0]

-31.775642761127013

In [159]:
print(np.sum(count_matrix(ds[-3:-2, 0:n], trees[0], 4)*cond_probs[0]))

-6.513819707245405


In [111]:
np.array([0 1 0 0 0 0 1 0]).reshape[]

array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       ...,
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 1.]])