In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib notebook
import numpy as np
import geometric_clustering as gc
from graph_library import generate_SBM, assign_graph_metadata
import networkx as nx
import scipy as sc
import pylab as plt
import matplotlib.gridspec as gridspec
from matplotlib.animation import FFMpegWriter
from geometric_clustering import plotting
import matplotlib.colors as col

Cupy is not installed, GPU functions will not work.


# Generate graph

In [11]:
# Stochastic block model with four blocks: block sizes ('n'), the symmetric
# matrix of within-/between-block edge probabilities ('p') and the seed.
sbm_params = {
    'n': [20, 30, 25, 15],
    'p': [
        [0.7, 0.1, 0.02, 0.02],
        [0.1, 0.8, 0.02, 0.02],
        [0.02, 0.02, 0.85, 0.1],
        [0.02, 0.02, 0.1, 0.6],
    ],
    'seed': 1,
}
graph, pos = generate_SBM(params=sbm_params)

# Relabel the nodes as consecutive integers, then hand the graph and its
# layout to assign_graph_metadata.
graph = assign_graph_metadata(nx.convert_node_labels_to_integers(graph), pos=pos)

# Plot graph and geodesics between two communities

In [15]:
plt.figure(figsize=(5, 5))

# Background layer: every node in black, every edge in faint black.
nx.draw_networkx_nodes(graph, pos, node_size=80, node_color='k')
nx.draw_networkx_edges(graph, pos, edge_color='k', width=1, alpha=0.3)

# Overlay in green one shortest path from each of nodes 0..19 to the node
# whose index is 30 higher (30..49).
for source in range(20):
    geodesic = nx.shortest_path(graph, source=source, target=source + 30)
    geodesic_edges = list(zip(geodesic[:-1], geodesic[1:]))
    nx.draw_networkx_nodes(graph, pos, nodelist=geodesic, node_size=80, node_color='g')
    nx.draw_networkx_edges(graph, pos, edgelist=geodesic_edges, edge_color='g', width=2)

plt.box(on=None)
#plt.savefig('geodesics.svg')

<IPython.core.display.Javascript object>

# Compute curvatures

In [13]:
# 50 diffusion times, logarithmically spaced from 1e-2 to 1e2.
times = np.logspace(start=-2, stop=2, num=50)

# Curvatures of `graph` for every entry of `times`; later cells index the
# result by time first (kappas[25], kappas[34]).
kappas = gc.compute_curvatures(graph, times)

  2%|▏         | 1/50 [00:01<01:07,  1.37s/it]

0.3800079822540283


  4%|▍         | 2/50 [00:01<00:40,  1.20it/s]

0.37186408042907715


  6%|▌         | 3/50 [00:02<00:32,  1.46it/s]

0.4242990016937256


  8%|▊         | 4/50 [00:02<00:28,  1.60it/s]

0.46805906295776367


 10%|█         | 5/50 [00:03<00:26,  1.69it/s]

0.46149396896362305


 12%|█▏        | 6/50 [00:03<00:25,  1.75it/s]

0.4565920829772949


 14%|█▍        | 7/50 [00:04<00:24,  1.76it/s]

0.4892420768737793


 16%|█▌        | 8/50 [00:05<00:23,  1.78it/s]

0.47633934020996094


 18%|█▊        | 9/50 [00:05<00:22,  1.79it/s]

0.4693310260772705


 20%|██        | 10/50 [00:06<00:22,  1.80it/s]

0.47695064544677734


 22%|██▏       | 11/50 [00:06<00:21,  1.81it/s]

0.4668450355529785


 24%|██▍       | 12/50 [00:07<00:21,  1.80it/s]

0.4824397563934326


 26%|██▌       | 13/50 [00:07<00:20,  1.82it/s]

0.4593799114227295


 28%|██▊       | 14/50 [00:08<00:19,  1.85it/s]

0.4527301788330078


 30%|███       | 15/50 [00:08<00:19,  1.84it/s]

0.4677128791809082


 32%|███▏      | 16/50 [00:09<00:18,  1.82it/s]

0.4858710765838623


 34%|███▍      | 17/50 [00:09<00:17,  1.85it/s]

0.44066405296325684


 36%|███▌      | 18/50 [00:10<00:17,  1.87it/s]

0.44226837158203125


 38%|███▊      | 19/50 [00:10<00:16,  1.89it/s]

0.43166589736938477


 40%|████      | 20/50 [00:11<00:15,  1.89it/s]

0.4468119144439697


 42%|████▏     | 21/50 [00:12<00:15,  1.92it/s]

0.41704797744750977


 44%|████▍     | 22/50 [00:12<00:14,  1.95it/s]

0.4140470027923584


 46%|████▌     | 23/50 [00:12<00:13,  1.98it/s]

0.39484381675720215


 48%|████▊     | 24/50 [00:13<00:13,  1.96it/s]

0.4376230239868164


 50%|█████     | 25/50 [00:14<00:12,  1.97it/s]

0.4039480686187744


 52%|█████▏    | 26/50 [00:14<00:12,  1.96it/s]

0.4056379795074463


 54%|█████▍    | 27/50 [00:15<00:11,  1.97it/s]

0.4051547050476074


 56%|█████▌    | 28/50 [00:15<00:11,  1.92it/s]

0.43030285835266113


 58%|█████▊    | 29/50 [00:16<00:11,  1.88it/s]

0.44464635848999023


 60%|██████    | 30/50 [00:16<00:10,  1.88it/s]

0.4081439971923828


 62%|██████▏   | 31/50 [00:17<00:10,  1.89it/s]

0.4030017852783203


 64%|██████▍   | 32/50 [00:17<00:09,  1.88it/s]

0.40465497970581055


 66%|██████▌   | 33/50 [00:18<00:08,  1.92it/s]

0.37302494049072266


 68%|██████▊   | 34/50 [00:18<00:08,  1.91it/s]

0.36447787284851074


 70%|███████   | 35/50 [00:19<00:07,  1.92it/s]

0.35311317443847656


 72%|███████▏  | 36/50 [00:19<00:07,  1.89it/s]

0.3600127696990967


 74%|███████▍  | 37/50 [00:20<00:06,  1.86it/s]

0.369736909866333


 76%|███████▌  | 38/50 [00:20<00:06,  1.79it/s]

0.36411213874816895


 78%|███████▊  | 39/50 [00:21<00:06,  1.70it/s]

0.3815498352050781


 80%|████████  | 40/50 [00:22<00:06,  1.61it/s]

0.36803126335144043


 82%|████████▏ | 41/50 [00:23<00:05,  1.53it/s]

0.36321187019348145


 84%|████████▍ | 42/50 [00:23<00:05,  1.45it/s]

0.3544337749481201


 86%|████████▌ | 43/50 [00:24<00:05,  1.35it/s]

0.3405601978302002


 88%|████████▊ | 44/50 [00:25<00:05,  1.11it/s]

0.33255600929260254


 90%|█████████ | 45/50 [00:27<00:05,  1.09s/it]

0.3227870464324951


 92%|█████████▏| 46/50 [00:29<00:04,  1.24s/it]

0.29859495162963867


 94%|█████████▍| 47/50 [00:30<00:04,  1.40s/it]

0.2758042812347412


 96%|█████████▌| 48/50 [00:32<00:03,  1.55s/it]

0.2975947856903076


 98%|█████████▊| 49/50 [00:34<00:01,  1.67s/it]

0.2498610019683838


100%|██████████| 50/50 [00:36<00:00,  1.35it/s]

0.26582813262939453





In [14]:
# Curvature of the edges as a function of diffusion time.
plotting.plot_edge_curvatures(times, kappas, figsize=(4, 3))
plt.ylabel(r'$\kappa$')
plt.xlabel('time')

# Vertical guides at two selected diffusion times.
# NOTE(review): the graph snapshots further down use kappas[25] and
# kappas[34], whereas the guides sit at indices 28 and 34 - confirm which
# index is intended.
first_marker, second_marker = times[28], times[34]
plt.axvline(first_marker)
plt.axvline(second_marker)
#plt.savefig('curvature_trajectories.svg')

<IPython.core.display.Javascript object>

<matplotlib.lines.Line2D at 0x7fc638fb4850>

In [None]:
plt.figure(figsize=(8, 4))

# Two panels with identical styling: the edge curvatures at time index 25 on
# the left (121) and at time index 34 on the right (122).
for panel, time_index in ((121, 25), (122, 34)):
    plt.subplot(panel)
    kappa = kappas[time_index]
    plotting.plot_graph(
        graph,
        edge_color=kappa,
        node_size=20,
        edge_width=1,
        node_colors='k',
        colormap="standard",
        vmin=-.5,
        vmax=0.5,
    )

#plt.savefig('curvature_on_graph.svg')

# Compute geodesic distance matrix

In [None]:
# Matrix of geodesic distances between the nodes of the graph; it is also
# the ground cost handed to ot.emd in the transport-plan cells below.
dist = gc.curvature.compute_distance_geodesic(graph)

plt.figure(figsize=(4,3.5))
# `origin` only accepts 'upper' or 'lower'; the previous value 'auto' is
# rejected by current matplotlib with a ValueError.
# NOTE(review): 'lower' is chosen because older matplotlib treated any value
# other than 'upper' like 'lower' - confirm the orientation is the wanted one.
plt.imshow(dist, aspect='auto', origin='lower', cmap='Greys')

# Dashed guides at column 6 and row 16
plt.axvline(6, c='C0',lw=3,ls='--')
plt.axhline(16, c='C1',lw=3,ls='--')

plt.xlabel('Node id')
plt.ylabel('Node id')
plt.colorbar(label=r'$d_{ij}$')
plt.axis('square')

#plt.savefig('distance.svg', bbox_inches='tight')

# Functions to compute measures and make plots

In [None]:
# diffuse a unit mass from one node and record the measure at every time
def mx_comp(graph, T, i):
    """Diffuse a unit mass placed on node ``i`` and return it at each time in ``T``.

    The operator is ``L = nx.laplacian_matrix(graph) . diag(1 / degree)`` and
    the measure at time ``t`` is ``expm(-t * L)`` applied to the indicator
    vector of node ``i``. The exponential is applied increment by increment
    (from one time to the next) rather than from time 0 each time.

    Parameters
    ----------
    graph : networkx graph
        ``i`` is used as a position into a length-N vector.
        NOTE(review): this assumes iterating ``graph.nodes`` yields the nodes
        in the same order as the rows of ``nx.laplacian_matrix`` and that
        node labels are 0..N-1 - confirm against the caller.
    T : sequence of float
        Diffusion times, in the order they are to be visited. Time 0 is
        prepended internally and is not part of the output.
    i : int
        Index of the node that carries the initial unit mass.

    Returns
    -------
    list of scipy.sparse.lil_matrix
        One single-row sparse matrix (shape ``(1, N)``) per entry of ``T``.

    Raises
    ------
    IndexError
        If ``i`` is not a valid index into a length-N vector.
    """
    degrees = np.array([graph.degree[node] for node in graph.nodes], dtype=float)

    # Isolated nodes have degree 0. Give them an inverse degree of 0 instead
    # of inf so that no inf/nan can enter the operator; their mass then
    # simply stays where it is.
    inv_degrees = np.divide(1.0, degrees, out=np.zeros_like(degrees), where=degrees > 0)
    laplacian = nx.laplacian_matrix(graph).dot(sc.sparse.diags(inv_degrees))

    # Initial condition: all mass on node i
    measure = np.zeros(len(graph.nodes))
    measure[i] = 1.

    time_points = [0,] + list(T)  # add time 0

    measures = []
    for t_prev, t_next in zip(time_points[:-1], time_points[1:]):
        # compute exponential by increments (faster than from 0)
        measure = sc.sparse.linalg.expm_multiply(-(t_next - t_prev) * laplacian, measure)
        measures.append(sc.sparse.lil_matrix(measure))

    return measures


# compute the optimal transport plan between the measures of an edge ij
def zeta_comp(mx_all, dist, it, e):
    """Return the transport plan between the measures of the endpoints of ``e``.

    Despite the historical "curvature" label, this function returns the
    result of ``ot.emd`` (the transport plan), not a curvature value.

    Parameters
    ----------
    mx_all : indexable
        For each node ``k``, ``mx_all[k][0][it]`` must be a sparse measure
        (something with ``.toarray()``) and ``mx_all[k][1][it]`` the indices
        of the nodes it is supported on. Note that this is a different layout
        from a plain list of per-time measures.
    dist : numpy.ndarray
        Square matrix of distances between nodes; it is sliced down to the
        two supports before being used as the transport cost.
    it : int
        Index of the time step to use.
    e : sequence
        The edge; ``e[0]`` and ``e[1]`` are the two endpoint indices.

    Returns
    -------
    The value returned by ``ot.emd(mx, my, cost)``.
    """
    import ot  # optimal transport package; imported on demand

    i, j = e[0], e[1]

    # Supports of the two measures (node indices)
    Nx = np.array(mx_all[i][1][it]).flatten()
    Ny = np.array(mx_all[j][1][it]).flatten()

    # The two measures as dense 1-D vectors
    mx = mx_all[i][0][it].toarray().flatten()
    my = mx_all[j][0][it].toarray().flatten()

    # Cost restricted to support-of-x rows and support-of-y columns, copied
    # into C order before being passed to the solver
    dNxNy = dist[Nx,:][:,Ny].copy(order='C')

    return ot.emd(mx, my, dNxNy)

# draw two measures on the graph as half-filled node markers
def plot_measure_graph(t, mx1, mx2, kappas, graph, pos, node_size = 100, edge_width = 1, ax=None):
    """Draw two node measures on the graph as left/right half-filled markers.

    Each measure is divided by its own maximum and multiplied by
    ``node_size`` to obtain the marker sizes. ``mx1`` is drawn as the left
    half of each marker in colour 'C0', ``mx2`` as the right half in 'C1'.
    Edges are not drawn here.

    Parameters
    ----------
    t, kappas, edge_width
        Accepted for backward compatibility; not used by the drawing.
    mx1, mx2 : array-like
        One value per node. They are NOT modified (earlier versions rescaled
        the caller's arrays in place).
    graph, pos
        Passed through to ``nx.draw_networkx_nodes``.
    node_size : float
        Marker size given to the node with the largest value.
    ax : matplotlib axes, optional
        Axes to draw on; the current axes are used when omitted.

    Returns
    -------
    The axes that were drawn on.
    """
    from matplotlib.markers import MarkerStyle

    if ax is None:
        ax = plt.gca()

    # Scale into new arrays so the caller's data is left untouched
    node_size1 = np.asarray(mx1, dtype=float) / np.max(mx1) * node_size
    node_size2 = np.asarray(mx2, dtype=float) / np.max(mx2) * node_size

    # The colours are fixed colour names, so no colormap/vmin/vmax is passed
    # (they have no effect without numeric colour data).
    for sizes, colour, fill in ((node_size1, 'C0', 'left'), (node_size2, 'C1', 'right')):
        nx.draw_networkx_nodes(graph, pos = pos, node_size = sizes, node_color = colour,
                               node_shape = MarkerStyle('o', fillstyle = fill), ax = ax)

    #edges = nx.draw_networkx_edges(graph, pos = pos, width = edge_width, alpha=0.3, ax=ax)

    ax.axis('off') #turn axis off on the axes actually drawn on

    return ax


def plot_transport_plan(zeta, mx1, mx2, ax1, ax2, ax3, boundaries=(29.5, 59.5, 89.5)):
    """Show a transport plan as an image with its two marginals alongside.

    Parameters
    ----------
    zeta : numpy.ndarray
        Transport plan; its transpose is displayed on ``ax1``. The colour
        scale runs from ``min(zeta)`` to 5% of ``max(zeta)``.
    mx1 : array-like
        First measure, drawn as vertical bars on ``ax2`` (y-range 0 to 0.03).
    mx2 : array-like
        Second measure, drawn as horizontal bars on ``ax3`` (x-range 0 to 0.03).
    ax1, ax2, ax3 : matplotlib axes
        Image panel, top panel and right panel respectively.
    boundaries : iterable of float, optional
        Positions of the dashed white guide lines drawn both vertically and
        horizontally on ``ax1``. The default keeps the previously hard-coded
        values.
        NOTE(review): for a graph with block sizes [20, 30, 25, 15] the block
        boundaries would be (19.5, 49.5, 74.5) - pass them explicitly if the
        guides are meant to separate the blocks.

    Returns
    -------
    tuple
        ``(ax1, ax2, ax3)``.
    """
    # `origin` only accepts 'upper' or 'lower'; the previous value 'auto' is
    # rejected by current matplotlib.
    # NOTE(review): 'lower' is chosen because older matplotlib treated any
    # value other than 'upper' like 'lower' - confirm the orientation.
    ax1.imshow(
        zeta.T,
        cmap='viridis',
        norm=col.Normalize(vmin=np.min(zeta), vmax=0.05*np.max(zeta)),
        aspect='auto',
        origin='lower',
    )
    ax1.set_xlabel('Node id')
    ax1.set_ylabel('Node id')

    ax2.bar(np.arange(len(mx1)), mx1, color='C0', log=False)
    ax2.set_xlim(-0.5, len(mx1)-0.5)
    ax2.tick_params(
        axis='x',          # changes apply to the x-axis
        which='both',      # both major and minor ticks are affected
        bottom=False,      # ticks along the bottom edge are off
        top=False,         # ticks along the top edge are off
        labelbottom=False) # labels along the bottom edge are off
    ax2.set_ylabel('$p_x$')
    ax2.set_ylim([0,0.03])

    ax3.barh(np.arange(len(mx2)), mx2, color='C1', log=False)
    ax3.set_ylim(-0.5, len(mx2)-0.5)
    ax3.set_xlabel('$p_y$')
    ax3.set_xlim([0,0.03])

    ax3.tick_params(
        axis='y',          # changes apply to the y-axis
        which='both',      # both major and minor ticks are affected
        left=False,        # ticks along the left edge are off
        top=False,         # ticks along the top edge are off
        labelleft=False)   # labels along the left edge are off

    # One dashed guide per boundary in each direction (each one used to be
    # drawn twice on top of itself).
    for boundary in boundaries:
        ax1.axvline(boundary, c='w', ls='--', lw=0.8)
        ax1.axhline(boundary, c='w', ls='--', lw=0.8)

    return ax1, ax2, ax3

# Make video of diffusion evolution

In [None]:
i = 1
j = 62

# Diffusion measures started from node i and from node j, one per time
mx_1 = mx_comp(graph, times, i)
mx_2 = mx_comp(graph, times, j)

fig = plt.figure(figsize=(5,4))
ax = plt.subplot(111)

# Static background: all edges, plus a red line joining nodes i and j
edges = nx.draw_networkx_edges(graph, pos = pos, width = 1, alpha=0.3, ax=ax)
nx.draw_networkx_edges(graph, pos = pos, edgelist=[(i,j),], edge_color='r',width = 3,ax=ax)

# Everything on the axes so far belongs to the background. Collections added
# after this point are per-frame node markers and must not pile up.
n_static = len(ax.collections)

metadata = dict(title='Movie Test', artist='Matplotlib',comment='Movie support!')
writer = FFMpegWriter(fps=1, metadata=metadata)
with writer.saving(fig, "diffusion_between.mp4", 100):
    for t in range(len(times)):
        # Remove the node markers of the previous frame; otherwise every
        # frame is drawn on top of all earlier ones.
        for stale in list(ax.collections)[n_static:]:
            stale.remove()

        mx1, mx2 = mx_1[t].toarray().flatten(), mx_2[t].toarray().flatten()
        plot_measure_graph(t, mx1, mx2, kappas, graph, pos, node_size = 100, edge_width = 1, ax=ax)
        ax.set_title('Diffusion time ' + str(np.round(times[t],2)))
        writer.grab_frame()

# Plot graph with snapshots of diffusion measures

In [None]:
plt.figure(figsize=(8,4))

t = 26
i = 1

# The measure started from node i is shared by both panels
mx_1 = mx_comp(graph, times, i)

# Left panel pairs node i with node 5, right panel pairs it with node 62.
# Each panel keeps a handle on its own axes: the previous version called
# set_title on `ax1`/`ax2`, which are not defined at this point of the
# notebook and fail on a fresh kernel.
for panel, j in ((121, 5), (122, 62)):
    ax = plt.subplot(panel)

    mx_2 = mx_comp(graph, times, j)
    mx1, mx2 = mx_1[t].toarray().flatten(), mx_2[t].toarray().flatten()

    plot_measure_graph(t, mx1, mx2, kappas, graph, pos, node_size = 1000, edge_width = 1, ax=ax)
    nx.draw_networkx_edges(graph, pos = pos, width = 1, alpha=0.3, ax=ax)
    nx.draw_networkx_edges(graph, pos = pos, edgelist=[(i,j),], edge_color='g',width = 3, ax=ax)
    # Title: log10 of the diffusion time shown
    ax.set_title(str(np.log10(times[t])))

#plt.savefig('mxs.svg', bbox_inches='tight')

# Plot transport maps

In [None]:
import ot

t = 25
i = 1
j = 5

# Only the measures started from nodes i and j are needed, so only those two
# diffusions are computed (previously one diffusion per node of the graph was
# run and all but two were discarded).
mx = mx_comp(graph, times, i)[t].toarray().flatten()
my = mx_comp(graph, times, j)[t].toarray().flatten()

# Transport plan between the two measures, with the full distance matrix as
# ground cost
zeta = ot.emd(mx, my, dist)

# Layout: large image panel bottom-left, thin marginal panels above and right
fig= plt.figure(figsize=(5,5))
gs = gridspec.GridSpec(2, 2, height_ratios = [ 0.2, 1], width_ratios = [1,0.2] )
gs.update(wspace=0.00)
gs.update(hspace=0)
ax1 = plt.subplot(gs[1, 0])
ax2 = plt.subplot(gs[0, 0])
ax3 = plt.subplot(gs[1, 1])

plot_transport_plan(zeta, mx, my, ax1, ax2, ax3)

#plt.savefig('zeta_within.svg', bbox_inches='tight')

In [None]:
import ot

t = 25
i = 1
j = 62

# Only the measures started from nodes i and j are needed, so only those two
# diffusions are computed (previously one diffusion per node of the graph was
# run and all but two were discarded).
mx = mx_comp(graph, times, i)[t].toarray().flatten()
my = mx_comp(graph, times, j)[t].toarray().flatten()

# Transport plan between the two measures, with the full distance matrix as
# ground cost
zeta = ot.emd(mx, my, dist)

# Layout: large image panel bottom-left, thin marginal panels above and right
fig= plt.figure(figsize=(5,5))
gs = gridspec.GridSpec(2, 2, height_ratios = [ 0.2, 1], width_ratios = [1,0.2] )
gs.update(wspace=0.00)
gs.update(hspace=0)
ax1 = plt.subplot(gs[1, 0])
ax2 = plt.subplot(gs[0, 0])
ax3 = plt.subplot(gs[1, 1])

plot_transport_plan(zeta, mx, my, ax1, ax2, ax3)

#plt.savefig('zeta_between.svg', bbox_inches='tight')