In [22]:
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt 
import random 
import networkx as nx 
import itertools 
import pickle 
from scipy.special import digamma
import timeit

# load data

In [23]:
# cora dataset
G_cora=nx.read_adjlist('data/Cora_enrich/idx_adjlist.txt',nodetype=int,create_using=nx.DiGraph)
# subgraph from G ,with nodes {0-99}
G_cora_mini=G_cora.subgraph(list(range(100)))
# cora texts
texts_cora=np.loadtxt('data/Cora_enrich/BOW_texts_3876.txt',dtype=np.int)
texts_cora_mini=texts_cora[:100,:]

# Block_PLSA

## utils

In [24]:
def index2ij(index,K):
    row=int(index/K) 
    col=index%K
    return row,col 

## input tranformation

In [25]:
# transfer input to observed variables
# PLSA part
ii,jj=np.nonzero(texts_cora)
WS=np.repeat(jj,texts_cora[ii,jj])
DS=np.repeat(ii,texts_cora[ii,jj])
# blockmodel part
SS=[]
RS=[]
for e in G_cora.edges:
    SS.append(e[0])
    RS.append(e[1])
SS=np.array(SS,dtype=np.int)
RS=np.array(RS,dtype=np.int)

In [26]:
# help utils
D_ids={}
for idx,d in enumerate(DS):
    if d in D_ids:
        D_ids[d].append(idx)
    else:
        D_ids[d]=[idx]
        
S_ids={}
for idx,s in enumerate(SS):
    if s in S_ids:
        S_ids[s].append(idx)
    else:
        S_ids[s]=[idx]
        
R_ids={}
for idx,r in enumerate(RS):
    if r in R_ids:
        R_ids[r].append(idx)
    else:
        R_ids[r]=[idx]
         
W_ids={}
for idx,w in enumerate(WS):
    if w in W_ids:
        W_ids[w].append(idx)
    else:
        W_ids[w]=[idx]

## initialization

In [38]:
# model hyperparameters 
alpha=1e-2
K=7
D=texts_cora.shape[0]
V=texts_cora.shape[1]
L=len(SS)
N=len(WS)
# runtime parameters
n_iter_EM=30
n_iter_VI=100
gamma_max_gap=0.2
phi_max_gap=K*V*0.0001

In [159]:
# Initialize EM parameters
omega=np.zeros((K,D))
omega[:]=1/D
phi=np.zeros((K,V))
phi[:]=1/V
pi=np.zeros(K)
pi[:]=1/K 
    
# initialize VI parameters
gamma=np.zeros(K**2)
gamma[:]=1e-1
delta=np.zeros((L,K**2))
delta[:]=1/K**2
epsilon=np.zeros((N,K))
epsilon[:]=1/K 

In [39]:
# Initialize EM parameters randomly (use diriclet to ensure normalization)
beta=1e-1

omega=np.zeros((K,D))
for k in range(K):
    omega[k,:]=stats.dirichlet.rvs(np.repeat(beta,D))
    
phi=np.zeros((K,V))
for k in range(K):
    phi[k,:]=stats.dirichlet.rvs(np.repeat(beta,V))
    
pi=stats.dirichlet.rvs(np.repeat(beta,K)).flatten()
    
# initialize VI parameters
gamma=stats.dirichlet.rvs(np.repeat(beta,K**2))

delta=np.zeros((L,K**2))
for l in range(L):
    delta[l,:]=stats.dirichlet.rvs(np.repeat(beta,K**2))

epsilon=np.zeros((N,K))
for n in range(N):
    epsilon[n,:]=stats.dirichlet.rvs(np.repeat(beta,K))

In [8]:
# Initialize parameters randomly
# EM parameters
omega=np.random.randint(1,10,(K,D))
omega=omega/omega.sum(axis=1)[:,np.newaxis]
phi=np.random.randint(1,10,(K,V))
phi=phi/phi.sum(axis=1)[:,np.newaxis]
pi=np.random.randint(1,10,K)
pi=pi/pi.sum() 
    
# VI parameters
gamma=np.repeat(1e-2,K**2)
delta=np.random.randint(1,10,(L,K**2))
delta=delta/delta.sum(axis=1)[:,np.newaxis]
epsilon=np.random.randint(1,10,(N,K))
epsilon=epsilon/epsilon.sum(axis=1)[:,np.newaxis]

In [40]:
# variational-EM
for it_em in range(n_iter_EM):
    # E-step
    for it_vi in range(n_iter_VI):
        gamma_last=gamma[:]
        # solve gamma&delta
        #gamma=np.array([alpha+delta[:,k].sum() for k in range(K**2)])
        
        gamma=delta.sum(axis=0)+alpha
        
        #delta=np.array([[omega[index2ij(k,K)[0],SS[l]]*omega[index2ij(k,K)[1],RS[l]]*np.exp(digamma(gamma[k])) for k in range(K**2)]\
                        #for l in range(L)])
        
        for k in range(K**2):
            delta[:,k]=omega[index2ij(k,K)[0],SS[range(L)]]*omega[index2ij(k,K)[1],RS[range(L)]]*np.exp(digamma(gamma[k]))
        
        delta=delta/delta.sum(axis=1)[:,np.newaxis]
        # check convergence 
        gamma_gap=np.abs(gamma-gamma_last).sum()
        print('iter:%d,gamma_gap:%f'%(it_vi,gamma_gap))
        if gamma_gap<gamma_max_gap:
            break
        
    # solve epsilon  
    start=timeit.default_timer()
    #epsilon=np.array([[omega[k,DS[n]]*phi[k,WS[n]]*pi[k] for k in range(K)] for n in range(N)])
    
    for k in range(K):
        epsilon[:,k]=omega[k,DS[range(N)]]*phi[k,WS[range(N)]]*pi[k]
    
    epsilon=epsilon/epsilon.sum(axis=1)[:,np.newaxis]
    end=timeit.default_timer()
    print('epsilon time:%s'%str(end-start))
    # M-step
    phi_last=phi.copy()
    pi_last=pi[:]
    omega_last=omega.copy() 
    # omega
    start=timeit.default_timer()
    S_dist=np.zeros((L,K))
    R_dist=np.zeros((L,K))
    for l in range(L):
        S_dist[l,:]=delta[l,:].reshape(K,K).sum(axis=1)
        R_dist[l,:]=delta[l,:].reshape(K,K).sum(axis=0)
    term_1=term_2=term_3=0
    
    for d in range(D):
        ep_id=D_ids.get(d,[])
        term_1=epsilon[ep_id,:].sum(axis=0)
        S_id=S_ids.get(d,[])
        term_2=S_dist[S_id,:].sum(axis=0)
        R_id=R_ids.get(d,[])
        term_3=R_dist[R_id,:].sum(axis=0)
        omega[:,d]=term_1+term_2+term_3
    omega=omega/omega.sum(axis=1)[:,np.newaxis]
    end=timeit.default_timer()
    print('omega time:%s'%str(end-start))
    # phi
    start=timeit.default_timer()
           
    for w in range(V):
        ep_id=W_ids.get(w,[])
        phi[:,w]=epsilon[ep_id,:].sum(axis=0)
    phi=phi/phi.sum(axis=1)[:,np.newaxis]
    end=timeit.default_timer()
    print('phi time:%s'%str(end-start))
    
    # pi
    start=timeit.default_timer()
    pi=epsilon.sum(axis=0)
    pi=pi/pi.sum()
    end=timeit.default_timer()
    print('pi time:%s'%str(end-start))
    # check convergence
    phi_gap=np.abs(phi-phi_last).sum()
    pi_gap=np.abs(pi-pi_last).sum()
    omega_gap=np.abs(omega-omega_last).sum() 
    print('iter:%d,phi_gap:%f,pi_gap:%f,omega_gap:%f'%(it_em,phi_gap,pi_gap,omega_gap))
    if phi_gap<phi_max_gap:
        #break 
        pass 

iter:0,gamma_gap:5428.490000
iter:1,gamma_gap:540.556730
iter:2,gamma_gap:236.172073
iter:3,gamma_gap:122.593474
iter:4,gamma_gap:72.095380
iter:5,gamma_gap:46.625527
iter:6,gamma_gap:31.077203
iter:7,gamma_gap:21.095841
iter:8,gamma_gap:14.380286
iter:9,gamma_gap:9.847435
iter:10,gamma_gap:6.788282
iter:11,gamma_gap:4.715963
iter:12,gamma_gap:3.301267
iter:13,gamma_gap:2.338076
iter:14,gamma_gap:1.662006
iter:15,gamma_gap:1.185525
iter:16,gamma_gap:0.848382
iter:17,gamma_gap:0.608938
iter:18,gamma_gap:0.438287
iter:19,gamma_gap:0.316270
iter:20,gamma_gap:0.228765
iter:21,gamma_gap:0.166390
epsilon time:3.846603199999663
omega time:0.26490259999991395
phi time:0.23404470000059518
pi time:0.030825899999399553
iter:0,phi_gap:10.132278,pi_gap:0.878802,omega_gap:8.996097
iter:0,gamma_gap:0.121542
epsilon time:3.868558199999825
omega time:0.2519212000006519
phi time:0.2196524999999383
pi time:0.03331689999868104
iter:1,phi_gap:1.261946,pi_gap:0.022776,omega_gap:1.244581
iter:0,gamma_gap:938

iter:51,gamma_gap:6.518309
iter:52,gamma_gap:6.448984
iter:53,gamma_gap:6.384037
iter:54,gamma_gap:6.323259
iter:55,gamma_gap:6.266344
iter:56,gamma_gap:6.216869
iter:57,gamma_gap:6.170935
iter:58,gamma_gap:6.125548
iter:59,gamma_gap:6.089063
iter:60,gamma_gap:6.038004
iter:61,gamma_gap:5.920232
iter:62,gamma_gap:5.593439
iter:63,gamma_gap:5.347149
iter:64,gamma_gap:5.268627
iter:65,gamma_gap:5.144796
iter:66,gamma_gap:4.840234
iter:67,gamma_gap:4.490403
iter:68,gamma_gap:4.405261
iter:69,gamma_gap:4.358573
iter:70,gamma_gap:4.315595
iter:71,gamma_gap:4.272697
iter:72,gamma_gap:4.231510
iter:73,gamma_gap:4.204241
iter:74,gamma_gap:4.176097
iter:75,gamma_gap:4.147485
iter:76,gamma_gap:4.118705
iter:77,gamma_gap:4.090329
iter:78,gamma_gap:4.062881
iter:79,gamma_gap:4.035283
iter:80,gamma_gap:4.007763
iter:81,gamma_gap:3.980495
iter:82,gamma_gap:3.953613
iter:83,gamma_gap:3.932207
iter:84,gamma_gap:3.913556
iter:85,gamma_gap:3.895309
iter:86,gamma_gap:3.878230
iter:87,gamma_gap:3.861737
i

iter:33,gamma_gap:9.163090
iter:34,gamma_gap:8.917180
iter:35,gamma_gap:8.683331
iter:36,gamma_gap:8.460715
iter:37,gamma_gap:8.248495
iter:38,gamma_gap:8.046040
iter:39,gamma_gap:7.852769
iter:40,gamma_gap:7.668145
iter:41,gamma_gap:7.491671
iter:42,gamma_gap:7.322888
iter:43,gamma_gap:7.161367
iter:44,gamma_gap:7.006710
iter:45,gamma_gap:6.858542
iter:46,gamma_gap:6.716516
iter:47,gamma_gap:6.580300
iter:48,gamma_gap:6.449586
iter:49,gamma_gap:6.325083
iter:50,gamma_gap:6.208964
iter:51,gamma_gap:6.097139
iter:52,gamma_gap:5.989363
iter:53,gamma_gap:5.885387
iter:54,gamma_gap:5.784948
iter:55,gamma_gap:5.687755
iter:56,gamma_gap:5.593457
iter:57,gamma_gap:5.501593
iter:58,gamma_gap:5.412185
iter:59,gamma_gap:5.324085
iter:60,gamma_gap:5.234133
iter:61,gamma_gap:5.136716
iter:62,gamma_gap:5.015044
iter:63,gamma_gap:4.809716
iter:64,gamma_gap:4.414581
iter:65,gamma_gap:4.169063
iter:66,gamma_gap:4.091748
iter:67,gamma_gap:4.026082
iter:68,gamma_gap:3.959705
iter:69,gamma_gap:3.890655
i

iter:19,gamma_gap:4.407073
iter:20,gamma_gap:4.116987
iter:21,gamma_gap:3.859175
iter:22,gamma_gap:3.621072
iter:23,gamma_gap:3.400757
iter:24,gamma_gap:3.196562
iter:25,gamma_gap:3.007026
iter:26,gamma_gap:2.830866
iter:27,gamma_gap:2.666943
iter:28,gamma_gap:2.514246
iter:29,gamma_gap:2.371869
iter:30,gamma_gap:2.239001
iter:31,gamma_gap:2.114907
iter:32,gamma_gap:1.998925
iter:33,gamma_gap:1.890452
iter:34,gamma_gap:1.788941
iter:35,gamma_gap:1.698878
iter:36,gamma_gap:1.620533
iter:37,gamma_gap:1.546859
iter:38,gamma_gap:1.477542
iter:39,gamma_gap:1.412291
iter:40,gamma_gap:1.353517
iter:41,gamma_gap:1.305401
iter:42,gamma_gap:1.259962
iter:43,gamma_gap:1.217040
iter:44,gamma_gap:1.176488
iter:45,gamma_gap:1.138167
iter:46,gamma_gap:1.101948
iter:47,gamma_gap:1.068813
iter:48,gamma_gap:1.042663
iter:49,gamma_gap:1.017891
iter:50,gamma_gap:0.994427
iter:51,gamma_gap:0.972205
iter:52,gamma_gap:0.958536
iter:53,gamma_gap:0.947835
iter:54,gamma_gap:0.940122
iter:55,gamma_gap:0.933319
i

iter:6,gamma_gap:7.476434
iter:7,gamma_gap:6.199000
iter:8,gamma_gap:5.377202
iter:9,gamma_gap:4.750978
iter:10,gamma_gap:4.381888
iter:11,gamma_gap:4.087976
iter:12,gamma_gap:3.836157
iter:13,gamma_gap:3.617652
iter:14,gamma_gap:3.434231
iter:15,gamma_gap:3.282941
iter:16,gamma_gap:3.146164
iter:17,gamma_gap:3.039537
iter:18,gamma_gap:2.937988
iter:19,gamma_gap:2.841673
iter:20,gamma_gap:2.750547
iter:21,gamma_gap:2.664448
iter:22,gamma_gap:2.583156
iter:23,gamma_gap:2.506420
iter:24,gamma_gap:2.433978
iter:25,gamma_gap:2.365573
iter:26,gamma_gap:2.300952
iter:27,gamma_gap:2.239879
iter:28,gamma_gap:2.182127
iter:29,gamma_gap:2.127487
iter:30,gamma_gap:2.075764
iter:31,gamma_gap:2.026776
iter:32,gamma_gap:1.980355
iter:33,gamma_gap:1.936345
iter:34,gamma_gap:1.894604
iter:35,gamma_gap:1.854998
iter:36,gamma_gap:1.817405
iter:37,gamma_gap:1.781710
iter:38,gamma_gap:1.747810
iter:39,gamma_gap:1.715606
iter:40,gamma_gap:1.685009
iter:41,gamma_gap:1.655934
iter:42,gamma_gap:1.628305
iter:

iter:2,gamma_gap:36.768359
iter:3,gamma_gap:20.481133
iter:4,gamma_gap:12.543091
iter:5,gamma_gap:7.995408
iter:6,gamma_gap:5.459188
iter:7,gamma_gap:3.928524
iter:8,gamma_gap:2.858452
iter:9,gamma_gap:2.101348
iter:10,gamma_gap:1.559366
iter:11,gamma_gap:1.169230
iter:12,gamma_gap:0.898071
iter:13,gamma_gap:0.694936
iter:14,gamma_gap:0.541617
iter:15,gamma_gap:0.425045
iter:16,gamma_gap:0.335783
iter:17,gamma_gap:0.266966
iter:18,gamma_gap:0.214170
iter:19,gamma_gap:0.179128
epsilon time:4.099353999999948
omega time:0.2674771999991208
phi time:0.23189339999953518
pi time:0.03148030000011204
iter:21,phi_gap:0.259504,pi_gap:0.006282,omega_gap:0.278849
iter:0,gamma_gap:0.151288
epsilon time:4.116052599998511
omega time:0.2955658000009862
phi time:0.2662694000009651
pi time:0.03506319999905827
iter:22,phi_gap:0.249134,pi_gap:0.006145,omega_gap:0.270175
iter:0,gamma_gap:50.955558
iter:1,gamma_gap:71.464123
iter:2,gamma_gap:34.212588
iter:3,gamma_gap:18.983940
iter:4,gamma_gap:11.809394
ite

In [41]:
# derive expectation from posterior (variational) distribution
theta=gamma/gamma.sum()
theta=theta.reshape(K,K)

# evaluation 

## utils

In [31]:
with open('data/Cora_enrich/tokens_3876.pickle','rb') as f:
    tokens=pickle.load(f)
tokens=np.array(tokens)

In [32]:
labels=[]
with open('data/Cora_enrich/labels.txt') as f:
    for line in f:
        labels.append(line.strip())
labels=np.array(labels)
labels_mini=labels[:100]

In [33]:
def get_top_tokens(phi,tokens,top=10):
    results=[]
    for i in range(phi.shape[0]):
        results.append(tokens[np.argsort(-phi[i,:])[:top]])
    return results 

In [34]:
def get_top_docs(omega,labels,top=10):
    results=[]
    for i in range(omega.shape[0]):
        results.append(labels[np.argsort(-omega[i,:])[:top]])
    return results 

## evaluate 

In [42]:
theta 

array([[1.54988442e-01, 1.84179361e-06, 1.84179361e-06, 1.84179361e-06,
        1.84179361e-06, 1.84179361e-06, 1.84179361e-06],
       [1.84179361e-06, 2.11502536e-01, 1.84179361e-06, 1.84179361e-06,
        1.84179361e-06, 1.84179361e-06, 1.84179361e-06],
       [1.84179361e-06, 1.84179361e-06, 1.53836132e-01, 1.84179361e-06,
        3.72236338e-02, 1.84179361e-06, 1.84179361e-06],
       [1.84179361e-06, 1.84179361e-06, 1.84179361e-06, 1.26098837e-01,
        1.84179361e-06, 1.84179361e-06, 1.84179361e-06],
       [1.84179361e-06, 1.84179361e-06, 1.84179361e-06, 1.84179361e-06,
        1.84179361e-06, 1.84179361e-06, 1.84179361e-06],
       [1.84179361e-06, 1.84179361e-06, 1.84179361e-06, 1.84179361e-06,
        1.84179361e-06, 7.55227321e-02, 1.84179361e-06],
       [1.84179361e-06, 6.68002006e-02, 1.14633733e-02, 5.73865877e-02,
        1.84179361e-06, 1.84179361e-06, 1.05105695e-01]])

In [46]:
get_top_tokens(phi,tokens,20)

[array(['class', 'distribut', 'probabl', 'bound', 'case', 'theori',
        'concept', 'given', 'instanc', 'queri', 'bayesian', 'sampl',
        'condit', 'follow', 'consid', 'theorem', 'number', 'work', 'relat',
        'definit'], dtype='<U15'),
 array(['network', 'data', 'train', 'neural', 'featur', 'tree', 'input',
        'et', 'weight', 'al', 'estim', 'rule', 'error', 'valu', 'perform',
        'test', 'classif', 'predict', 'linear', 'classifi'], dtype='<U15'),
 array(['gen', 'knowledg', 'program', 'domain', 'rule', 'task', 'process',
        'robot', 'search', 'state', 'solv', 'agent', 'strategi', 'action',
        'induct', 'control', 'represent', 'design', 'behavior', 'oper'],
       dtype='<U15'),
 array(['genet', 'number', 'program', 'fit', 'perform', 'sequenc', 'ga',
        'size', 'optim', 'popul', 'select', 'point', 'code', 'data',
        'time', 'neural', 'comput', 'error', 'studi', 'work'], dtype='<U15'),
 array(['case', 'reason', 'plan', 'explan', 'similar', 'adapt',

In [44]:
get_top_docs(omega,labels,10)

[array(['Theory', 'Probabilistic_Methods', 'Theory', 'Rule_Learning',
        'Probabilistic_Methods', 'Theory', 'Theory', 'Theory', 'Theory',
        'Probabilistic_Methods'], dtype='<U22'),
 array(['Theory', 'Neural_Networks', 'Probabilistic_Methods', 'Theory',
        'Neural_Networks', 'Theory', 'Neural_Networks', 'Neural_Networks',
        'Theory', 'Theory'], dtype='<U22'),
 array(['Genetic_Algorithms', 'Reinforcement_Learning', 'Rule_Learning',
        'Reinforcement_Learning', 'Reinforcement_Learning',
        'Reinforcement_Learning', 'Case_Based', 'Theory',
        'Genetic_Algorithms', 'Theory'], dtype='<U22'),
 array(['Genetic_Algorithms', 'Genetic_Algorithms', 'Genetic_Algorithms',
        'Genetic_Algorithms', 'Genetic_Algorithms', 'Neural_Networks',
        'Genetic_Algorithms', 'Genetic_Algorithms', 'Neural_Networks',
        'Genetic_Algorithms'], dtype='<U22'),
 array(['Case_Based', 'Probabilistic_Methods', 'Case_Based', 'Case_Based',
        'Case_Based', 'Case_Based

In [45]:
pi 

array([0.17951863, 0.21200674, 0.15018488, 0.2016221 , 0.06841285,
       0.1479296 , 0.04032521])