In [16]:
%load_ext autoreload
%autoreload 2
%matplotlib notebook
import numpy as np
import geometric_clustering as gc
from graph_library import generate_SBM, assign_graph_metadata
import networkx as nx
import scipy as sc
import pylab as plt
import matplotlib.gridspec as gridspec
from matplotlib.animation import FFMpegWriter
from geometric_clustering import plotting
import matplotlib.colors as col

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Generate graph

In [25]:
# Four-block stochastic block model: 'n' lists the block sizes and 'p'
# presumably the edge probabilities within/between blocks; 'seed' is fixed so
# the same graph is produced on every run (assuming generate_SBM forwards it).
sbm_params = {
    'n': [30, 40, 35, 50],
    'p': [
        [0.7, 0.1, 0.02, 0.02],
        [0.1, 0.8, 0.02, 0.02],
        [0.02, 0.02, 0.9, 0.1],
        [0.02, 0.02, 0.1, 0.6],
    ],
    'seed': 1,
}
graph, pos = generate_SBM(params=sbm_params)

# Relabel nodes as 0..N-1 so node ids can be used directly as array indices.
graph = nx.convert_node_labels_to_integers(graph)
graph = assign_graph_metadata(graph, pos=pos)

# Plot graph and geodesics between two communities

In [26]:
# Whole graph in black, then the shortest paths from nodes 0..29 to nodes
# 30..59 overlaid in green.
fig, ax = plt.subplots(figsize=(5, 5))
nx.draw_networkx_nodes(graph, pos, node_size=80, node_color='k', ax=ax)
nx.draw_networkx_edges(graph, pos, edge_color='k', width=1, alpha=0.3, ax=ax)

for source in range(30):
    geodesic = nx.shortest_path(graph, source=source, target=source + 30)
    geodesic_edges = list(zip(geodesic[:-1], geodesic[1:]))
    nx.draw_networkx_nodes(graph, pos, nodelist=geodesic, node_size=80, node_color='g', ax=ax)
    nx.draw_networkx_edges(graph, pos, edgelist=geodesic_edges, edge_color='g', width=2, ax=ax)

plt.box(on=None)  # toggles the axes frame (on by default, so this hides it)
#plt.savefig('geodesics.svg')

<IPython.core.display.Javascript object>

# Compute curvatures

In [27]:
# 50 log-spaced diffusion times from 1e-2 to 1e2. Below, kappas[k] is used as
# the per-edge curvature at times[k].
times = np.logspace(start=-2, stop=2, num=50)
kappas = gc.compute_curvatures(graph, times)

  2%|▏         | 1/50 [00:01<01:03,  1.30s/it]

0.4348888397216797


  4%|▍         | 2/50 [00:01<00:45,  1.06it/s]

0.5586550235748291


  6%|▌         | 3/50 [00:02<00:39,  1.18it/s]

0.5825350284576416


  8%|▊         | 4/50 [00:03<00:36,  1.26it/s]

0.5967023372650146


 10%|█         | 5/50 [00:04<00:37,  1.21it/s]

0.756134033203125


 12%|█▏        | 6/50 [00:05<00:40,  1.09it/s]

0.9922580718994141


 14%|█▍        | 7/50 [00:06<00:46,  1.08s/it]

1.2725870609283447


 16%|█▌        | 8/50 [00:08<00:54,  1.30s/it]

1.6665589809417725


 18%|█▊        | 9/50 [00:10<01:02,  1.52s/it]

1.8863210678100586


 20%|██        | 10/50 [00:13<01:13,  1.85s/it]

2.4552218914031982


 22%|██▏       | 11/50 [00:15<01:20,  2.07s/it]

2.4654340744018555


 24%|██▍       | 12/50 [00:18<01:23,  2.21s/it]

2.3934381008148193


 26%|██▌       | 13/50 [00:20<01:26,  2.33s/it]

2.4737980365753174


 28%|██▊       | 14/50 [00:23<01:29,  2.49s/it]

2.7292377948760986


 30%|███       | 15/50 [00:26<01:31,  2.62s/it]

2.7976157665252686


 32%|███▏      | 16/50 [00:29<01:32,  2.72s/it]

2.8146181106567383


 34%|███▍      | 17/50 [00:33<01:37,  2.97s/it]

3.430461883544922


 36%|███▌      | 18/50 [00:37<01:43,  3.23s/it]

3.6996099948883057


 38%|███▊      | 19/50 [00:40<01:45,  3.40s/it]

3.6612141132354736


 40%|████      | 20/50 [00:44<01:46,  3.54s/it]

3.7416508197784424


 42%|████▏     | 21/50 [00:48<01:45,  3.63s/it]

3.7053871154785156


 44%|████▍     | 22/50 [00:52<01:42,  3.68s/it]

3.645810127258301


 46%|████▌     | 23/50 [00:55<01:39,  3.67s/it]

3.5333471298217773


 48%|████▊     | 24/50 [00:59<01:35,  3.69s/it]

3.603685140609741


 50%|█████     | 25/50 [01:03<01:32,  3.72s/it]

3.6450932025909424


 52%|█████▏    | 26/50 [01:07<01:30,  3.75s/it]

3.6999809741973877


 54%|█████▍    | 27/50 [01:11<01:27,  3.79s/it]

3.752208948135376


 56%|█████▌    | 28/50 [01:15<01:24,  3.84s/it]

3.8206851482391357


 58%|█████▊    | 29/50 [01:19<01:21,  3.88s/it]

3.8146419525146484


 60%|██████    | 30/50 [01:23<01:18,  3.92s/it]

3.843441963195801


 62%|██████▏   | 31/50 [01:26<01:13,  3.89s/it]

3.629126787185669


 64%|██████▍   | 32/50 [01:30<01:08,  3.83s/it]

3.5389821529388428


 66%|██████▌   | 33/50 [01:34<01:03,  3.73s/it]

3.3531341552734375


 68%|██████▊   | 34/50 [01:37<00:58,  3.65s/it]

3.2915520668029785


 70%|███████   | 35/50 [01:41<00:53,  3.59s/it]

3.304259777069092


 72%|███████▏  | 36/50 [01:44<00:49,  3.55s/it]

3.283776044845581


 74%|███████▍  | 37/50 [01:47<00:45,  3.49s/it]

3.1981959342956543


 76%|███████▌  | 38/50 [01:51<00:41,  3.44s/it]

3.152423858642578


 78%|███████▊  | 39/50 [01:54<00:37,  3.40s/it]

3.1437718868255615


 80%|████████  | 40/50 [01:58<00:34,  3.45s/it]

3.3725428581237793


 82%|████████▏ | 41/50 [02:01<00:31,  3.45s/it]

3.1987497806549072


 84%|████████▍ | 42/50 [02:04<00:27,  3.42s/it]

3.1533379554748535


 86%|████████▌ | 43/50 [02:07<00:23,  3.32s/it]

2.886530876159668


 88%|████████▊ | 44/50 [02:11<00:19,  3.27s/it]

2.947361946105957


 90%|█████████ | 45/50 [02:14<00:16,  3.21s/it]

2.8432979583740234


 92%|█████████▏| 46/50 [02:17<00:12,  3.17s/it]

2.845172882080078


 94%|█████████▍| 47/50 [02:20<00:09,  3.15s/it]

2.8456010818481445


 96%|█████████▌| 48/50 [02:23<00:06,  3.18s/it]

2.892444133758545


 98%|█████████▊| 49/50 [02:26<00:03,  3.23s/it]

2.970365047454834


100%|██████████| 50/50 [02:30<00:00,  3.01s/it]

2.93913197517395





In [72]:
# Curvature of every edge as a function of diffusion time. Vertical lines mark
# time indices 28 and 33, whose curvatures are drawn on the graph next.
plotting.plot_edge_curvatures(times, kappas, figsize=(4, 3))
plt.xlabel('time')
plt.ylabel(r'$\kappa$')

# NOTE(review): the markers sit at log10(t); if plot_edge_curvatures plots
# against log10(t), the 'time' label above should say so — confirm.
for marked_index in (28, 33):
    plt.axvline(np.log10(times[marked_index]))

plt.savefig('curvature_trajectories.svg')

<IPython.core.display.Javascript object>

In [73]:
# Edge curvatures on the graph at the two time indices marked in the previous
# figure, one panel each, with a shared symmetric colour range.
plt.figure(figsize=(8, 4))
for panel, time_index in enumerate((28, 33), start=1):
    plt.subplot(1, 2, panel)
    kappa = kappas[time_index]
    plotting.plot_graph(
        graph,
        edge_color=kappa,
        node_size=20,
        edge_width=1,
        node_colors='k',
        colormap="standard",
        vmin=-.5,
        vmax=0.5,
    )

plt.savefig('curvature_on_graph.svg')

<IPython.core.display.Javascript object>

# Compute geodesic distance matrix

In [74]:
# Pairwise node distances (presumably shortest-path lengths). `dist` is reused
# below as the ground cost for optimal transport.
dist = gc.curvature._compute_distance_geodesic(graph)

fig, ax = plt.subplots(figsize=(4, 3.5), dpi=300)
image = ax.imshow(dist, aspect='auto', origin='lower', cmap='Greys')
ax.set_xlabel('Node id')
ax.set_ylabel('Node id')
fig.colorbar(image, ax=ax, label=r'$d_{ij}$')
ax.axis('square')

fig.savefig('distance.svg', bbox_inches='tight')

<IPython.core.display.Javascript object>

# Functions to compute measures and make plots

In [58]:
# compute the diffusion measure started from node i at every time in T
def mx_comp(graph, T, i):
    """Diffuse a unit mass from node ``i`` and record it at each time in ``T``.

    The measure evolves as ``dp/dt = -L D^{-1} p``, where ``L`` is the graph
    Laplacian returned by ``nx.laplacian_matrix`` and ``D`` is the diagonal
    matrix of (weighted) node degrees. The columns of ``L`` sum to zero, so the
    total mass of ``p`` stays 1.

    Parameters
    ----------
    graph : networkx graph
        Nodes are taken in ``graph.nodes`` order; ``i`` is a position in it.
    T : iterable of float
        Diffusion times, assumed non-decreasing (each step advances from the
        previous time).
    i : int
        Position of the starting node.

    Returns
    -------
    list of scipy.sparse.lil_matrix
        One 1 x N row per entry of ``T``: the measure at that time.

    Raises
    ------
    ValueError
        If a node has zero degree (``D^{-1}`` would be undefined).
    """
    # `import scipy` alone is not guaranteed to load the sparse.linalg submodule.
    import scipy.sparse.linalg  # noqa: F401

    # Weighted degrees, so D matches nx.laplacian_matrix, which uses the
    # 'weight' edge attribute by default. Edges without it count as 1, so
    # unweighted graphs behave exactly as before.
    degrees = np.array(
        [graph.degree(node, weight='weight') for node in graph.nodes], dtype=float
    )
    if np.any(degrees == 0):
        raise ValueError('mx_comp: graph has a node with zero degree')
    L = nx.laplacian_matrix(graph).dot(sc.sparse.diags(1.0 / degrees))

    measure = np.zeros(len(graph.nodes))
    measure[i] = 1.0  # all mass on the starting node at time 0

    mx_all = []
    previous_time = 0.0
    for current_time in T:
        # Advance only by the increment since the previous time; this is
        # cheaper than exponentiating from time 0 at every step.
        measure = sc.sparse.linalg.expm_multiply(-(current_time - previous_time) * L, measure)
        mx_all.append(sc.sparse.lil_matrix(measure))
        previous_time = current_time

    return mx_all


# compute the optimal transport plan between the measures at the ends of edge e
def zeta_comp(mx_all, dist, it, e):
    """Optimal transport plan between the measures of the two endpoints of ``e``.

    Parameters
    ----------
    mx_all : sequence
        ``mx_all[k]`` must be a pair ``(measures, supports)``. ``measures[it]``
        is a sparse row holding node k's measure restricted to its support, and
        ``supports[it]`` holds the node indices of that support.
        NOTE(review): this is not the format returned by ``mx_comp`` in this
        notebook (a plain list of full-length measures), so this function is
        currently unused here — TODO confirm the intended producer.
    dist : 2-D array
        Pairwise node distances (e.g. the geodesic matrix computed above).
    it : int
        Time index into ``measures`` / ``supports``.
    e : pair of int
        The edge ``(i, j)``.

    Returns
    -------
    The plan returned by ``ot.emd``.
    """
    import ot

    i, j = e[0], e[1]

    Nx = np.array(mx_all[i][1][it]).flatten()
    Ny = np.array(mx_all[j][1][it]).flatten()
    mx = mx_all[i][0][it].toarray().flatten()
    my = mx_all[j][0][it].toarray().flatten()

    # Ground cost restricted to the two supports, copied into a C-contiguous
    # array for ot.emd.
    dNxNy = dist[Nx, :][:, Ny].copy(order='C')
    return ot.emd(mx, my, dNxNy)

# plot two node measures on the graph as left/right half-disc markers
def plot_measure_graph(t, mx1, mx2, kappas, graph, pos, node_size = 100, edge_width = 1, ax=None):
    """Draw two node measures on the graph.

    Each node gets a left half-disc (colour C0) whose size follows ``mx1`` and
    a right half-disc (colour C1) whose size follows ``mx2``. Each measure is
    rescaled so that its largest entry is drawn with size ``node_size``.
    Edges are not drawn; callers draw them separately.

    Parameters
    ----------
    t, kappas, edge_width :
        Unused; kept so existing calls keep working.
    mx1, mx2 : 1-D array-like
        Node measures in ``graph.nodes`` order. They are not modified.
    graph : networkx graph
    pos : dict
        Node positions.
    node_size : float
        Marker size for the node holding a measure's maximum.
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to the current axes.

    Returns
    -------
    The axes that were drawn on.
    """
    from matplotlib.markers import MarkerStyle

    if ax is None:
        ax = plt.gca()

    def _marker_sizes(measure):
        # Work on a copy so the caller's array is left untouched. Tiny negative
        # round-off values are clipped to 0, and an all-zero measure simply
        # yields zero-size markers instead of dividing by zero.
        values = np.clip(np.asarray(measure, dtype=float), 0.0, None)
        peak = values.max() if values.size else 0.0
        if peak <= 0:
            return np.zeros_like(values)
        return values / peak * node_size

    for measure, colour, side in ((mx1, 'C0', 'left'), (mx2, 'C1', 'right')):
        nx.draw_networkx_nodes(
            graph,
            pos=pos,
            node_size=_marker_sizes(measure),
            node_color=colour,
            node_shape=MarkerStyle('o', fillstyle=side),
            ax=ax,
        )

    ax.axis('off')
    return ax

# Make video of diffusion evolution

In [None]:
# Endpoints of the highlighted pair whose diffusion measures are animated.
i = 1
j = 62

mx_1 = mx_comp(graph, times, i)
mx_2 = mx_comp(graph, times, j)

fig, ax = plt.subplots(figsize=(5, 4))

# Static background, drawn once: all edges faintly, the (i, j) pair in red.
nx.draw_networkx_edges(graph, pos=pos, width=1, alpha=0.3, ax=ax)
nx.draw_networkx_edges(graph, pos=pos, edgelist=[(i, j)], edge_color='r', width=3, ax=ax)
n_background = len(ax.collections)

metadata = dict(title='Movie Test', artist='Matplotlib', comment='Movie support!')
writer = FFMpegWriter(fps=1, metadata=metadata)
with writer.saving(fig, "diffusion_between.mp4", 100):
    for t, diffusion_time in enumerate(times):
        # Remove the previous frame's node markers. Without this, every frame is
        # drawn on top of all earlier ones and stale measures stay visible.
        for artist in list(ax.collections)[n_background:]:
            artist.remove()
        mx1 = mx_1[t].toarray().flatten()
        mx2 = mx_2[t].toarray().flatten()
        plot_measure_graph(t, mx1, mx2, kappas, graph, pos, node_size=100, edge_width=1, ax=ax)
        ax.set_title('Diffusion time ' + str(np.round(diffusion_time, 2)))
        writer.grab_frame()

# Plot graph with snapshots of diffusion measures

In [67]:
# Diffusion measures of two node pairs at the same time index.
# Left: nodes 2 and 5, both in the first SBM block. Right: nodes 35 and 5, in
# different blocks. This assumes generate_SBM numbers nodes block by block,
# with the first block being nodes 0-29.
t = 26
node_pairs = [(2, 5), (35, 5)]

plt.figure(figsize=(8, 4))
for panel, (i, j) in enumerate(node_pairs, start=1):
    plt.subplot(1, 2, panel)

    # Only the measure at time index t is plotted, so diffuse just up to it.
    # The incremental steps are identical to those of the full run.
    mx1 = mx_comp(graph, times[: t + 1], i)[-1].toarray().flatten()
    mx2 = mx_comp(graph, times[: t + 1], j)[-1].toarray().flatten()

    plot_measure_graph(t, mx1, mx2, kappas, graph, pos, node_size=1000, edge_width=1)
    nx.draw_networkx_edges(graph, pos=pos, width=1, alpha=0.3)
    nx.draw_networkx_edges(graph, pos=pos, edgelist=[(i, j)], edge_color='g', width=3, alpha=0.2)
    plt.title(rf'$\log_{{10}}(t) = {np.log10(times[t]):.2f}$')

plt.savefig('mxs.svg', bbox_inches='tight')

<IPython.core.display.Javascript object>

# Plot transport maps

In [None]:
import ot

def plot_transport_plan(zeta, mx1, mx2, ax1, ax2, ax3,
                        boundaries=(29.5, 69.5, 104.5), p_max=0.02):
    """Plot a transport plan together with its two marginal measures.

    Parameters
    ----------
    zeta : 2-D array
        Transport plan. Rows are expected to correspond to ``mx1`` and columns
        to ``mx2`` (as returned by ``ot.emd(mx1, mx2, M)``). It is drawn
        transposed, so the x-axis indexes ``mx1`` and the y-axis ``mx2``.
    mx1, mx2 : 1-D arrays
        Source and target measures. They are drawn as bars above (``ax2``) and
        to the right (``ax3``) of the plan.
    ax1, ax2, ax3 : matplotlib Axes
        Main panel, top marginal panel and right marginal panel.
    boundaries : sequence of float, optional
        Node-id positions of the dashed white community separators on ``ax1``.
        The default matches SBM block sizes 30, 40, 35 and 50 (cumulative
        30, 70, 105), shifted by 0.5 to fall between pixels.
    p_max : float, optional
        Upper limit of both marginal bar axes.

    Returns
    -------
    tuple
        ``(ax1, ax2, ax3)``.
    """
    # Colour scale saturates at 1% of the largest entry so small flows stay visible.
    norm = col.Normalize(vmin=np.min(zeta), vmax=0.01 * np.max(zeta))
    ax1.imshow(zeta.T, cmap='viridis', norm=norm, aspect='auto', origin='lower')
    ax1.set_xlabel('Node id')
    ax1.set_ylabel('Node id')

    # Top marginal: source measure as vertical bars aligned with ax1's columns.
    ax2.bar(np.arange(len(mx1)), mx1, color='C0', log=False)
    ax2.set_xlim(-0.5, len(mx1) - 0.5)
    # Hide the x-axis ticks (bottom and top) and their labels.
    ax2.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    ax2.set_ylabel('$p_x$')
    ax2.set_ylim([0, p_max])

    # Right marginal: target measure as horizontal bars aligned with ax1's rows.
    ax3.barh(np.arange(len(mx2)), mx2, color='C1', log=False)
    ax3.set_ylim(-0.5, len(mx2) - 0.5)
    ax3.set_xlabel('$p_y$')
    ax3.set_xlim([0, p_max])
    # Hide the y-axis ticks (left and right) and their labels.
    ax3.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)

    # One dashed separator per community boundary, in both directions.
    for boundary in boundaries:
        ax1.axvline(boundary, c='w', ls='--', lw=0.8)
        ax1.axhline(boundary, c='w', ls='--', lw=0.8)

    return ax1, ax2, ax3

# Transport plan between the measures of nodes in different communities.
t = 20
i = 35
j = 5

# The plan needs only the two endpoint measures at time index t. Diffuse from
# i and j alone, and only up to t, instead of from every node over all times.
mx = mx_comp(graph, times[: t + 1], i)[-1].toarray().flatten()
my = mx_comp(graph, times[: t + 1], j)[-1].toarray().flatten()

# Optimal transport plan, with the distance matrix `dist` as ground cost.
zeta = ot.emd(mx, my, dist)

# Large plan panel bottom-left; thin marginal strips above and to the right.
fig = plt.figure(figsize=(5, 5), dpi=300)
gs = gridspec.GridSpec(2, 2, height_ratios=[0.2, 1], width_ratios=[1, 0.2],
                       wspace=0.0, hspace=0)
ax1 = plt.subplot(gs[1, 0])
ax2 = plt.subplot(gs[0, 0])
ax3 = plt.subplot(gs[1, 1])

plot_transport_plan(zeta, mx, my, ax1, ax2, ax3)

plt.savefig('zeta_between.svg', bbox_inches='tight')

In [None]:
import ot

# Transport plan between the measures of two nodes in the same community.
t = 20
i = 2
j = 5

# The plan needs only the two endpoint measures at time index t. Diffuse from
# i and j alone, and only up to t, instead of from every node over all times.
mx = mx_comp(graph, times[: t + 1], i)[-1].toarray().flatten()
my = mx_comp(graph, times[: t + 1], j)[-1].toarray().flatten()

# Optimal transport plan, with the distance matrix `dist` as ground cost.
zeta = ot.emd(mx, my, dist)

# Large plan panel bottom-left; thin marginal strips above and to the right.
fig = plt.figure(figsize=(5, 5), dpi=300)
gs = gridspec.GridSpec(2, 2, height_ratios=[0.2, 1], width_ratios=[1, 0.2],
                       wspace=0.0, hspace=0)
ax1 = plt.subplot(gs[1, 0])
ax2 = plt.subplot(gs[0, 0])
ax3 = plt.subplot(gs[1, 1])

plot_transport_plan(zeta, mx, my, ax1, ax2, ax3)

plt.savefig('zeta_within.svg', bbox_inches='tight')