In [None]:
from collections import Counter
import numpy as np
import pandas as pd
import igraph as ig
import itertools as it
import random
from scipy.special import loggamma
from pyitlib import discrete_random_variable as drv
import matplotlib.pyplot as plt
from joblib import Parallel,delayed
import pickle
import time
import itertools
import glob
from matplotlib import cm
import graph_tool as gt
from graph_tool.all import *


def synthetic_reconstruction(S,N,modes,alphas,betas,pis,K0,n_fails,n_reruns):
    """Average MDL reconstruction quality over repeated synthetic draws (discontiguous case).

    Every rerun draws data with ``generate_synthetic(S, N, modes, alphas, betas, pis)``,
    fits ``MDL_populations(D, N, K0, n_fails)`` (``initialize_clusters`` followed by
    ``run_sims``), and scores the fit against the planted labels and modes.

    Returns
    -------
    list of four values, each the mean over the ``n_reruns`` reruns:
        0. mode distance -- for each inferred mode, twice the size of its
           symmetric difference with the closest true mode, averaged over the
           inferred modes;
        1. partition distance -- ``drv.information_variation`` between the
           planted and the inferred labels;
        2. the ``L`` value returned by ``run_sims``;
        3. the number of inferred modes, ``len(A)``.
    """
    n_true = len(modes)
    mode_scores, label_scores, lengths, cluster_counts = [], [], [], []

    for _rerun in range(n_reruns):
        data, planted_labels = generate_synthetic(S,N,modes,alphas,betas,pis)
        model = MDL_populations(data,N,K0,n_fails)
        model.initialize_clusters()
        clusters, inferred_modes, length = model.run_sims()

        # Invert cluster -> members into node -> cluster, then list the
        # cluster ids in increasing node order.
        membership = {}
        for cluster_id in clusters:
            for node in clusters[cluster_id]:
                membership[node] = cluster_id
        inferred_labels = [membership[node] for node in sorted(membership)]
        label_scores.append(drv.information_variation(planted_labels, inferred_labels))

        # Distance from every inferred mode to its closest planted mode.
        nearest = []
        for key in inferred_modes:
            candidate = inferred_modes[key]
            nearest.append(min(
                2*(len(modes[t] - candidate) + len(candidate - modes[t]))
                for t in range(n_true)
            ))
        mode_scores.append(np.mean(nearest))

        lengths.append(length)
        cluster_counts.append(len(inferred_modes))

    return [np.mean(mode_scores), np.mean(label_scores), np.mean(lengths), np.mean(cluster_counts)]

def synthetic_reconstruction_contiguous(true_labels,N,modes,alphas,betas,n_reruns):
    """Average contiguous-MDL reconstruction quality over repeated synthetic draws.

    For each of ``n_reruns`` repetitions, one sample is generated per entry of
    ``true_labels``. Label ``lab`` draws from ``modes[lab % len(modes)]`` with
    the matching ``alphas``/``betas`` entries. The samples are then clustered
    with ``MDL_populations(D, N, K0=np.inf, n_fails=np.inf).dynamic_contiguous()``
    and the result is scored against the planted structure.

    Parameters
    ----------
    true_labels : sequence of int
        Planted label of each sample, in sample order. Labels may exceed the
        number of modes; they are mapped onto modes cyclically, so e.g. four
        contiguous blocks with two modes alternate between the two modes.
    N :
        Passed through to ``generate_synthetic`` and ``MDL_populations``
        (presumably the number of nodes -- confirm).
    modes : sequence of sets
        Planted modes, compared to inferred modes with set differences
        (in this notebook, sets of node-pair tuples).
    alphas, betas : indexable
        Per-mode generation parameters, indexed like ``modes``.
    n_reruns : int
        Number of independent repetitions to average over.

    Returns
    -------
    list of four values, each the mean over the reruns:
        [mode distance (2 * symmetric difference to the closest true mode,
        averaged over inferred modes), partition distance
        (``drv.information_variation`` between true and inferred labels),
        the ``L`` returned by ``dynamic_contiguous``, number of inferred modes].
    """
    n_modes = len(modes)
    all_mode_dists, all_label_dists, all_Ls, all_Ks = [], [], [], []
    for _ in range(n_reruns):
        # One sample per position, drawn from the mode its label maps to.
        D = [generate_synthetic(1, N, [modes[lab % n_modes]], [alphas[lab % n_modes]],
                                [betas[lab % n_modes]], [1.])[0][0]
             for lab in true_labels]
        MDLobj = MDL_populations(D, N, K0=np.inf, n_fails=np.inf)
        C, A, L = MDLobj.dynamic_contiguous()

        # C is iterated as cluster id -> member nodes; read off each node's
        # cluster id in increasing node order.
        node_to_cluster = {node: c for c in C for node in C[c]}
        inferred_labels = [node_to_cluster[node] for node in sorted(node_to_cluster)]
        label_dist = drv.information_variation(true_labels, inferred_labels)

        # For every inferred mode A[ki], twice the size of its symmetric
        # difference with the closest planted mode.
        mode_dists = [
            min(2 * (len(modes[kt] - A[ki]) + len(A[ki] - modes[kt])) for kt in range(n_modes))
            for ki in A
        ]
        all_mode_dists.append(np.mean(mode_dists))
        all_label_dists.append(label_dist)
        all_Ls.append(L)
        all_Ks.append(len(A))

    return [np.mean(all_mode_dists), np.mean(all_label_dists), np.mean(all_Ls), np.mean(all_Ks)]

In [None]:
# Synthetic recovery benchmark for the MDL reconstructions, keyed by
# (number of samples S, setting) with setting in {'non', 'cont2', 'cont4'}.
# Each value is a list with one [mode dist, partition dist, L, K] entry per flip
# probability. The sweep is expensive (20 flip probabilities x 200 reruns per
# setting), so by default the cached pickle is loaded. Set RECOMPUTE = True to
# regenerate results/reconstructions.pkl first.
RECOMPUTE = False

if RECOMPUTE:
    flip_ps = np.linspace(0.01, 0.5, 20)
    N = 8
    # Candidate planted networks on N = 8 nodes, as sets of node pairs.
    candidate_modes = [
        {(0,1),(1,5),(4,5),(0,4),(0,5),(1,4),(1,2),(3,7),(6,7)},
        {(1,2),(2,6),(5,6),(1,5),(1,6),(2,5),(0,1),(0,4),(2,3)},
        {(2,3),(3,7),(6,7),(2,6),(2,7),(3,6),(4,5),(1,5),(5,6)},
    ]
    # Only the first and third candidates are used as the planted modes.
    modes = [candidate_modes[0], candidate_modes[2]]
    n_modes = len(modes)

    all_recovery_results = {}
    param_combos = [(100,'non'),(100,'cont2'),(100,'cont4')]
    for combo in param_combos:
        S, setting = combo
        if setting == 'non':
            # Discontiguous: modes assigned at random with equal weights.
            all_recovery_results[combo] = Parallel(n_jobs=10)(
                delayed(synthetic_reconstruction)(
                    S=S, N=N, modes=modes,
                    alphas=(1-p)*np.ones(n_modes), betas=p*np.ones(n_modes),
                    pis=[1 / n_modes] * n_modes, K0=1, n_fails=500, n_reruns=200)
                for p in flip_ps)
        else:
            # Contiguous: S samples split into 2 or 4 equal consecutive blocks.
            n_blocks = {'cont2': 2, 'cont4': 4}[setting]
            true_labels = [k for k in range(n_blocks) for _ in range(S // n_blocks)]
            all_recovery_results[combo] = Parallel(n_jobs=10)(
                delayed(synthetic_reconstruction_contiguous)(
                    true_labels=true_labels, N=N, modes=modes,
                    alphas=(1-p)*np.ones(n_modes), betas=p*np.ones(n_modes),
                    n_reruns=200)
                for p in flip_ps)

    with open('results/reconstructions.pkl','wb') as f:
        pickle.dump(all_recovery_results, f)

# NOTE: pickle.load can execute arbitrary code -- only load locally produced,
# trusted result files here.
with open('results/reconstructions.pkl','rb') as f:
    all_recovery_results = pickle.load(f)

# Gibbs baseline results used for comparison in the figure; this file is not
# generated in this cell -- TODO record how/where it was produced.
with open('results/recovery_new_modes_equal_density.pkl','rb') as f:
    gibbs_results = pickle.load(f)

In [None]:
# Four-panel figure: recovery quality of the MDL reconstructions versus flip
# probability for the discontiguous and contiguous (K = 2, 4) settings, with
# the Gibbs results overlaid on panels (a) and (b).
colors = ['orangered','dodgerblue','gold','mediumpurple']
markers = ['o','^','s']
curve_labels = [r'discontiguous',r'contiguous, $K=2$',r'contiguous, $K=4$']
yaxislabels = ['Network distance','Partition distance','Inverse compression ratio','Number of clusters']
yticks_arr = [[0,5,10,15],[0,0.5,1,1.5,2,2.5,3],[0,0.2,0.4,0.6,0.8,1,1.2],[0,2,4,6,8,10]]
xticks = [0,0.1,0.2,0.3,0.4,0.5]
flip_ps = np.linspace(0.01,0.5,20)
param_combos = [(100,'non'),(100,'cont2'),(100,'cont4')]

fontProperties = {'family':'Times New Roman', 'size' : 18}
# Set before creating the figure so every mathtext ($...$) label uses Computer Modern.
plt.rcParams["mathtext.fontset"] = "cm"
fig, ax = plt.subplots(4,1,figsize=(6, 18),sharex=True,sharey=False)

for y in range(4):
    for v, combo in enumerate(param_combos):
        # Stored network distances are 2 * (edge symmetric difference); halve them.
        values = [r[y]/2 if y == 0 else r[y] for r in all_recovery_results[combo]]
        # Label curves only on panel (a); panel (d)'s legend lists only the K lines.
        label_kw = {'label': curve_labels[v]} if y == 0 else {}
        ax[y].plot(flip_ps, values, marker=markers[0], markerfacecolor=colors[v],
                   markeredgecolor='k', color=colors[v], **label_kw)

    ax[y].tick_params(which='both', direction='in', labelsize=18, pad=10,
                      left=True, bottom=True, top=True, right=True)
    ax[y].tick_params(which='major', length=12, labelleft=True, labelbottom=True)
    ax[y].tick_params(which='minor', length=6)
    if y == 2:
        # Reference level 1 = naive transmission (drawn once, labelled for the legend).
        ax[y].axhline(1,linewidth=1.5,c='k',linestyle='--',label='Naive transmission')
    ax[y].set_ylabel(yaxislabels[y], fontdict=fontProperties, labelpad=14 if y == 3 else 10)

    ax[y].set_yticks(yticks_arr[y])
    ax[y].set_yticklabels([r'$'+str(np.round(i,2))+'$' for i in yticks_arr[y]], fontproperties=fontProperties)
    if y == 3:
        ax[y].set_xlabel(r'Flip probability',fontdict=fontProperties,labelpad=15)
        # Reference lines at the planted number of clusters.
        ax[y].axhline(2,linewidth=1.5,c='k',linestyle='--',label=r'$K=2$')
        ax[y].axhline(4,linewidth=1.5,c='k',linestyle=':',label=r'$K=4$')
    ax[y].set_xticks(xticks)
    ax[y].set_xticklabels([r'$'+str(np.round(i,2))+'$' for i in xticks], fontproperties=fontProperties)
    ax[y].set_xlim([-0.01,0.51])
    ax[y].minorticks_on()
    ax[y].text(-0.3,1,['(a)','(b)','(c)','(d)'][y],fontproperties={'family':'Times New Roman', 'size' : 20},
               transform = ax[y].transAxes)

# Gibbs baseline. NOTE(review): the [2::4] slice and the 1/4 scaling of the
# network distance (vs 1/2 for the MDL curves) depend on how gibbs_results was
# generated -- confirm against that run.
ax[0].plot(flip_ps,[np.mean(r[0])/4 for r in gibbs_results[0]][2::4],marker=markers[0],
           markerfacecolor='mediumpurple',markeredgecolor='k',color='mediumpurple',label='Gibbs')
ax[1].plot(flip_ps,[np.mean(r[1]) for r in gibbs_results[0]][2::4],marker=markers[0],
           markerfacecolor='mediumpurple',markeredgecolor='k',color='mediumpurple')
# NOTE(review): 14 is presumably the expected distance of a coin-flip guess,
# i.e. half of the 28 possible node pairs for N = 8 -- confirm.
ax[0].axhline(14,linewidth=1.5,c='k',linestyle='--',label='coin flip')

# Frameless legends; the anchor boxes place them clear of the curves.
legend_anchors = [(0, (0.3,0.2,0.35,0.23)), (2, (0.6,0.0,0.35,0.23)), (3, (0.1,0.75,0.35,0.23))]
for panel, anchor in legend_anchors:
    ax[panel].legend(ncol=1, prop=fontProperties, frameon=False, bbox_to_anchor=anchor)

ax[2].yaxis.set_label_coords(-0.12, 0.4)
# `bbox` is not a savefig option (ignored by old Matplotlib, TypeError in new);
# the tight crop is requested with `bbox_inches`.
fig.savefig('figs/reconstruction.pdf', bbox_inches='tight')