In [8]:
import numpy as np
import copy
import networkx as nx
import matplotlib.pyplot as plt
import time
import sys
import math

def loadfile(filename1, filename2=None):
    ds1 = np.loadtxt(filename1, delimiter=",", dtype=int)
    if filename2:
        ds2 = np.loadtxt(filename2, delimiter=",", dtype=int)
        ds = np.vstack((ds1, ds2))
    else:
        ds = ds1
    return ds, ds.shape[0], ds.shape[1]

def count_matrix(ds, m, n):
    prob_xy = np.zeros((n, n, 4))
    for i in range(n):
        subds = ds[ds[:, i] == 0]
        for j in range(n):
            if prob_xy[i, j, 0] == 0:
                prob_xy[i, j, 0] = (subds[subds[:, j] == 0].shape[0]+1)/(m+4)
            if prob_xy[j, i, 0] == 0:
                prob_xy[j, i, 0] = prob_xy[i, j, 0]
            if prob_xy[i, j, 1] == 0:
                prob_xy[i, j, 1] = (subds[subds[:, j] == 1].shape[0]+1)/(m+4)
            if prob_xy[j, i, 2] == 0:
                prob_xy[j, i, 2] = prob_xy[i, j, 1]
            
        subds = ds[ds[:, i] == 1]
        for j in range(n):
            if prob_xy[i, j, 2] == 0:
                prob_xy[i, j, 2] = (subds[subds[:, j] == 0].shape[0]+1)/(m+4)
            if prob_xy[j, i, 1] == 0:
                prob_xy[j, i, 1] = prob_xy[i, j, 2]
            if prob_xy[i, j, 3] == 0:
                prob_xy[i, j, 3] = (subds[subds[:, j] == 1].shape[0]+1)/(m+4)
            if prob_xy[j, i, 3] == 0:
                prob_xy[j, i, 3] = prob_xy[i, j, 3]
    return prob_xy

def prob_matrix(ds, m, n, k=0):
    prob_xy = np.zeros((n, n, 4))
    l = 1
    for i in range(n):
        subds = ds[ds[:, i] == 0]
        for j in range(n):
            if prob_xy[i, j, 0] == 0:
                prob_xy[i, j, 0] = (np.sum(subds[subds[:, j] == 0][:, n+k]))
                cnt = subds[subds[:, j] == 0].shape[0]
                l = prob_xy[i, j, 0]/cnt if cnt>0 and prob_xy[i, j, 0]/cnt < l else l
            if prob_xy[j, i, 0] == 0:
                prob_xy[j, i, 0] = prob_xy[i, j, 0]
            if prob_xy[i, j, 1] == 0:
                prob_xy[i, j, 1] = (np.sum(subds[subds[:, j] == 1][:, n+k]))
                cnt = subds[subds[:, j] == 1].shape[0]
                l = prob_xy[i, j, 1]/cnt if cnt>0 and prob_xy[i, j, 1]/cnt < l else l
            if prob_xy[j, i, 2] == 0:
                prob_xy[j, i, 2] = prob_xy[i, j, 1]
            
        subds = ds[ds[:, i] == 1]
        for j in range(n):
            if prob_xy[i, j, 2] == 0:
                prob_xy[i, j, 2] = (np.sum(subds[subds[:, j] == 0][:, n+k]))
                cnt = subds[subds[:, j] == 0].shape[0]
                l = prob_xy[i, j, 2]/cnt if cnt>0 and prob_xy[i, j, 2]/cnt < l else l
            if prob_xy[j, i, 1] == 0:
                prob_xy[j, i, 1] = prob_xy[i, j, 2]
            if prob_xy[i, j, 3] == 0:
                prob_xy[i, j, 3] = (np.sum(subds[subds[:, j] == 1][:, n+k]))
                cnt = subds[subds[:, j] == 1].shape[0]
                l = prob_xy[i, j, 3]/cnt if cnt>0 and prob_xy[i, j, 3]/cnt < l else l
            if prob_xy[j, i, 3] == 0:
                prob_xy[j, i, 3] = prob_xy[i, j, 3]
    #l = np.sum(prob_xy)/(m*n*n)
    return (prob_xy+l)/(m+4*l)

def mutual_info(prob_xy, n):
    I_xy = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            if i < j:
                I_xy[i, j] = prob_xy[i, j, 0]*np.log(prob_xy[i, j, 0]/(prob_xy[i, i, 0]*prob_xy[j, j, 0])) \
                + prob_xy[i, j, 1]*np.log(prob_xy[i, j, 1]/(prob_xy[i, i, 0]*prob_xy[j, j, 3])) \
                + prob_xy[i, j, 2]*np.log(prob_xy[i, j, 2]/(prob_xy[i, i, 3]*prob_xy[j, j, 0])) \
                + prob_xy[i, j, 3]*np.log(prob_xy[i, j, 3]/(prob_xy[i, i, 3]*prob_xy[j, j, 3]))
    return I_xy

def draw_tree(edge_wts, prnt = False, k=0, step=0):
    edge_wts_cp = copy.deepcopy(edge_wts)
    edges = [np.unravel_index(np.argmax(edge_wts_cp), edge_wts_cp.shape)]
    visited = [[edges[-1][0],edges[-1][1]]]
    edge_wts_cp[edges[-1]] = 0
    while(len(edges) < edge_wts.shape[0]-1):
        i = j = -1
        edge = np.unravel_index(np.argmax(edge_wts_cp), edge_wts_cp.shape)
        for bag in range(len(visited)):
            if edge[0] in visited[bag]:
                i = bag
            if edge[1] in visited[bag]:
                j = bag
        if i == -1 and j != -1:
            edges.append(edge)
            visited[j].append(edge[0])
        elif i != -1 and j == -1:
            edges.append(edge)
            visited[i].append(edge[1])
        elif i == -1 and j == -1:
            edges.append(edge)
            visited.append([edge[0], edge[1]])
        elif i != -1 and j != -1 and i != j:
            edges.append(edge)
            visited[i] += visited[j]
            visited.remove(visited[j])
        elif i == j != -1:
            pass
            #print("Discarded", edge)
        else:
            print("Discarded in else", edge)
        edge_wts_cp[edge] = 0
    
    new_tree = []
    make_tree(edges, new_tree, edges[0][0])
    
    if prnt:
        G = nx.Graph()
        G.add_nodes_from(visited[0])
        G.add_edges_from(new_tree)
        plt.figure(str(step)+str(k))
        nx.draw_networkx(G, with_labels=True, arrows=True)
    
    return new_tree

def make_tree(ls, new_tree, parent):
    for node in [item for item in ls if parent in item]:
        if node[0] == parent:
            new_tree.append(node)
            ls.remove(node)
            #print(node, ls, new_tree)
            make_tree(ls, new_tree, node[1])
        else:
            new_tree.append((node[1],node[0]))
            ls.remove(node)
            #print(node, ls, new_tree)
            make_tree(ls, new_tree, node[0])

def count_matrix(ds, tree, cols):
    count_xy = np.zeros((len(tree), cols))
    for idx, node in enumerate(tree):
        i, j = node
        count_xy[idx] = [ds[(ds[:, i]==0) & (ds[:, j]==0)].shape[0], ds[(ds[:, i]==0) & (ds[:, j]==1)].shape[0], ds[(ds[:, i]==1) & (ds[:, j]==0)].shape[0], ds[(ds[:, i]==1) & (ds[:, j]==1)].shape[0]]
    #print(count_xy)
    return count_xy

def exist_matrix(ds, tree, cols):
    rows = ds.shape[0]
    exist_xy = np.zeros((rows, len(tree), cols))
    for idx, node in enumerate(tree):
        i, j = node
        exist_xy[:,idx,:] = np.hstack((((ds[:, i]==0) & (ds[:, j]==0)).astype(int).reshape(rows,1), ((ds[:, i]==0) & (ds[:, j]==1)).astype(int).reshape(rows,1), ((ds[:, i]==1) & (ds[:, j]==0)).astype(int).reshape(rows,1), ((ds[:, i]==1) & (ds[:, j]==1)).astype(int).reshape(rows,1)))
    return exist_xy

def M_step(ds, m, n, k, step, prnt, pk):
    trees = []
    cond_probs = []
    for ki in range(k):
        prob_xy = prob_matrix(np.hstack((ds[:,:n], pk)), m, n, ki)
        #prob_xy = prob_matrix(ds, m, n, ki)
        I_xy = mutual_info(prob_xy, n)        
        tree = draw_tree(I_xy, prnt, ki, step)
        tree = [(tree[0][0], tree[0][0])] + tree
        trees.append(tree)
        cond_prob = np.zeros((len(tree), prob_xy.shape[2]))
        for idx, node in enumerate(tree):
            if node[0] == node[1]:
                #cond_prob[idx] = np.log(prob_xy[node[0], node[1],:])
                cond_prob[idx] = prob_xy[node[0], node[1],:]
            else:
                #cond_prob[idx] = np.log(np.hstack(((prob_xy[node[0], node[1],:2]/prob_xy[node[0], node[0], 0]),(prob_xy[node[0], node[1],2:]/prob_xy[node[0], node[0], 3]))))
                cond_prob[idx] = np.hstack(((prob_xy[node[0], node[1],:2]/prob_xy[node[0], node[0], 0]),(prob_xy[node[0], node[1],2:]/prob_xy[node[0], node[0], 3])))
        cond_probs.append(cond_prob)
    return trees, cond_probs

def random_init(ds, m, n, k, step, prnt):
    trees = []
    cond_probs = []
    for ki in range(k):
        
        prob_xy = prob_matrix(ds[np.random.choice(m, 8, replace=False), :], m, n, ki)

        I_xy = mutual_info(prob_xy, n)
        
        tree = draw_tree(I_xy, prnt, ki, step)
        tree = [(tree[0][0], tree[0][0])] + tree
        trees.append(tree)
        cond_prob = np.zeros((len(tree), prob_xy.shape[2]))
        for idx, node in enumerate(tree):
            if node[0] == node[1]:
                #cond_prob[idx] = np.log(prob_xy[node[0], node[1],:])
                cond_prob[idx] = prob_xy[node[0], node[1],:]
            else:
                #cond_prob[idx] = np.log(np.hstack(((prob_xy[node[0], node[1],:2]/prob_xy[node[0], node[0], 0]),(prob_xy[node[0], node[1],2:]/prob_xy[node[0], node[0], 3]))))
                cond_prob[idx] = np.hstack(((prob_xy[node[0], node[1],:2]/prob_xy[node[0], node[0], 0]),(prob_xy[node[0], node[1],2:]/prob_xy[node[0], node[0], 3])))
        cond_probs.append(cond_prob)
    return trees, cond_probs

def E_step(ds, m, n, k, trees, cond_probs):
    
    ph = ds[:, n:].sum(axis = 0)/m
    weight_ij = np.zeros((m, k))

    #for i in range(m):
    #    for j in range(k):
    #        weight_ij[i, j] = ph[j] * np.sum(count_matrix(ds[i:i+1, 0:n], trees[j], 4)*cond_probs[j])

    for j in range(k):
        if j < 0:
            print("ph[j] * np.sum(exist_matrix(ds, trees[j], 4)*cond_probs[j], axis = 2).prod(axis=1)")
            print(ph[j] * np.sum(exist_matrix(ds, trees[j], 4)*cond_probs[j], axis = 2).prod(axis=1))
            print("ph[j]" )
            print(ph[j] )
            print("np.sum(exist_matrix(ds, trees[j], 4)*cond_probs[j], axis = 2).prod(axis=1)")
            print(np.sum(exist_matrix(ds, trees[j], 4)*cond_probs[j], axis = 2).prod(axis=1))
            print("np.sum(exist_matrix(ds, trees[j], 4)*cond_probs[j], axis = 2)")
            print(np.sum(exist_matrix(ds, trees[j], 4)*cond_probs[j], axis = 2))
            print("exist_matrix(ds, trees[j], 4)*cond_probs[j]")
            print(exist_matrix(ds, trees[j], 4)*cond_probs[j])
            print("exist_matrix(ds, trees[j], 4)")
            print(exist_matrix(ds, trees[j], 4))
            print("cond_probs[j]")
            print(cond_probs[j])
        #weight_ij[:, j] = ph[j] * np.sum(exist_matrix(ds, trees[j], 4)*cond_probs[j], axis = 1).sum(axis=1)
        weight_ij[:, j] = ph[j] * np.sum(exist_matrix(ds, trees[j], 4)*cond_probs[j], axis = 2).prod(axis=1)

    weight_ij = weight_ij/np.sum(weight_ij, axis = 1).reshape(m, 1)

    ll1 = 0
    #for i in range(m):
    #    for j in range(k):
    #        ll1 += (weight_ij[i, j] * (np.log(ph[j]) + np.log(np.sum(count_matrix(ds[i:i+1, 0:n], trees[j], 4)*cond_probs[j], axis = 1)).sum()))
    ll2 = 0
    for j in range(k):
        ll2 += (weight_ij[:, j] * (np.log(ph[j]) + np.log(np.sum(exist_matrix(ds, trees[j], 4)*cond_probs[j], axis = 2).prod(axis=1)))).sum()
        #print("LL2", j, ll2)
    print("LL:", ll1, ll2)
    print(ph)
    if math.isnan(ll2):        
        print(weight_ij)
    
    ds[:, n:] = weight_ij
    pk = weight_ij / np.sum(weight_ij, axis = 0).reshape(1, k)
    return ds, pk

if __name__ == "__main__":
    ds, m, n = loadfile("small-10-datasets/nltcs.ts.data")
    #ds, m, n = ds[:10, :3], 10, 3
    k = 2
    weight_ij = np.random.rand(m, k)
    weight_ij = weight_ij/np.sum(weight_ij, axis = 1).reshape(m, 1)
    ds = np.hstack((ds, weight_ij))
    trees, cond_probs = random_init(ds, m, n, k, 0, False)
    
    for step in range(1,100): 
        print("For step:", step)
        ds, pk = E_step(ds, m, n, k, trees, cond_probs)
        #print(pk)
        trees, cond_probs = M_step(ds, m, n, k, step, False, pk)        

    ts, m1, n1 = loadfile("small-10-datasets/nltcs.test.data")
    lambda_k = ds[:, n:].sum(axis = 0)/m
    L = 0
    ll1 = 0
    ll2 = 0
    for j in range(k):
        #t0=time.time()
        #L += ts.shape[0]*np.log(lambda_k[j]) + np.sum(count_matrix(ts, trees[j], 4)*np.log(cond_probs[j]))
        #t1=time.time()
        ll1 +=  np.log(lambda_k[j]*(np.sum(exist_matrix(ts, trees[j], 4)*cond_probs[j], axis = 2).prod(axis=1))).sum()
        #t2=time.time()
        #ll2 += (np.log(lambda_k[j]) + np.log(np.sum(exist_matrix(ts, trees[j], 4)*cond_probs[j], axis = 2).prod(axis=1))).sum()
        #t3=time.time()
        #print(t1-t0, t2-t1, t3-t2)
    print("LL:", L,"LL1", ll1/ts.shape[0], "LL2", ll2)
    
    #print(edge_wts[np.unravel_index(np.argmax(edge_wts), edge_wts.shape)] == edge_wts.max())

For step: 1
LL: 0 -248571.90424392378
[0.50177753 0.49822247]
For step: 2
LL: 0 -264402.0061215316
[0.61018722 0.38981278]
For step: 3
LL: 0 -262541.27951605595
[0.56277461 0.43722539]
For step: 4
LL: 0 -261697.61225406738
[0.51808563 0.48191437]
For step: 5
LL: 0 -261172.11036618234
[0.48796647 0.51203353]
For step: 6
LL: 0 -260852.95394825004
[0.46805048 0.53194952]
For step: 7
LL: 0 -260663.56373030823
[0.45483848 0.54516152]
For step: 8
LL: 0 -260475.8778139172
[0.44600014 0.55399986]
For step: 9
LL: 0 -260358.291273482
[0.43986292 0.56013708]
For step: 10
LL: 0 -260289.4733034735
[0.43582899 0.56417101]
For step: 11
LL: 0 -260246.32508314477
[0.4331333 0.5668667]
For step: 12
LL: 0 -260217.77663079102
[0.43132228 0.56867772]
For step: 13
LL: 0 -260198.0688652753
[0.43010082 0.56989918]
For step: 14
LL: 0 -260184.01303793804
[0.42927312 0.57072688]
For step: 15
LL: 0 -260173.73256083563
[0.42870902 0.57129098]
For step: 16
LL: 0 -260166.06393472708
[0.42832191 0.57167809]
For step:

In [122]:
k=2 : LL: 0 LL1 -38.08617375805374 LL2 0
k=3 : LL: 0 LL1 -99.32483270920115 LL2 0

-31.775642761127013

In [159]:
print(np.sum(count_matrix(ds[-3:-2, 0:n], trees[0], 4)*cond_probs[0]))

-6.513819707245405


In [111]:
np.array([0 1 0 0 0 0 1 0]).reshape[]

array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       ...,
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 1.]])