<a href="https://colab.research.google.com/github/RFesser/hello-world/blob/master/Curvature%20comparison%20orc-frc-afrc.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import networkx as nx
import matplotlib.pyplot as plt
from time import perf_counter
import numpy as np

from GraphRicciCurvature.FormanRicci import FormanRicci
from GraphRicciCurvature.OllivierRicci import OllivierRicci


In [None]:
def plot_my_graph(G, pos, ax = None, node_col = "white", 
                  edge_lst = [], edge_col = "lightgrey", edge_lab = {},
                  bbox = None, color_map = "Set3", alpha = 1.0):
    '''
    draws a graph with labelled nodes, an optional extra edge list and
    optional edge labels

    Parameters
    ----------
    G : Graph
    pos : dict
        node positions, passed unchanged to the networkx drawing functions
    ax : matplotlib Axes, optional
        axes to draw on. If None, a new 15x15 inch figure is created and
        shown; if given, the graph is drawn on these axes and showing the
        figure is left to the caller.
    node_col : node colour(s), passed as ``node_color``
    edge_lst : list of edges passed to ``draw_networkx_edges`` (width 0.5)
    edge_col : edge colour(s), passed as ``edge_color``
    edge_lab : dict of edge labels, passed as ``edge_labels``
    bbox : passed as ``bbox`` to ``draw_networkx_edge_labels``
    color_map : name of the matplotlib colormap used for the nodes
    alpha : transparency passed to ``draw_networkx``

    Returns
    -------
    None
    '''
    # The list/dict defaults are only read here, never mutated, so sharing
    # them between calls is harmless and the signature stays unchanged.
    created_figure = ax is None
    if created_figure:
        plt.figure(figsize=(15,15))
        ax = plt.gca()

    node_options = {
        "font_size": 12,
        "font_color": "black",
        "node_size": 300,
        "cmap": plt.get_cmap(color_map),
        "alpha": alpha,
        "edgecolors": "black",
        "linewidths": 0.5,
        "with_labels": True,
        "edgelist": None
        }
    edge_options = {
        "width": 0.5
        }

    # Every drawing call receives the same axes explicitly, so a caller-supplied
    # `ax` is honoured instead of whatever happens to be the current axes.
    nx.draw_networkx (G, pos, node_color = node_col, edge_color = edge_col, ax = ax, **node_options)
    nx.draw_networkx_edges (G, pos, edge_lst, edge_color = edge_col, ax = ax, **edge_options)
    nx.draw_networkx_edge_labels(G, pos, label_pos = 0.5, edge_labels = edge_lab, rotate=False, bbox = bbox, ax = ax)
    ax.margins(0.20)

    # Only show figures this function created itself.
    if created_figure:
        plt.show()

In [None]:
def remove_permutations(ll):
    '''
    removes every list that is a permutation of an earlier list

    Two entries count as permutations of each other when their sorted
    contents are equal. The first occurrence is kept, all later ones are
    dropped, and the relative order of the kept entries is preserved.

    Parameters
    ----------
    ll : list of lists
        modified in place; the elements of the inner lists must be sortable
        and hashable (e.g. graph nodes)

    Returns
    -------
    ll : list of lists
        the same list object, now without permutation duplicates

    '''
    # NOTE(review): comparison is by sorted content only, so different cycles
    # running through the same set of nodes are treated as one entry - confirm
    # that this is what the cycle counting downstream expects.
    seen = set()
    unique = []
    for entry in ll:
        key = tuple(sorted(entry))
        if key not in seen:
            seen.add(key)
            unique.append(entry)
    # Slice assignment keeps the in-place contract of the original function.
    ll[:] = unique
    return ll

In [None]:
def simple_cycles(G, limit):
    """Yield simple cycles of G that contain fewer than ``limit`` nodes.

    The search runs on a new graph of the same class as G that is built from
    G's edge list only. It proceeds one strongly connected component at a
    time: a start node is taken from the component, a depth-first search with
    backtracking enumerates the paths leaving it, and a cycle is reported
    whenever a neighbour of the current path end is the start node again.
    Afterwards the start node is removed from the working graph and the rest
    of the component is split into strongly connected components again.

    Parameters
    ----------
    G : graph
        Presumably a directed networkx graph, since
        ``nx.strongly_connected_components`` is called on it - verify against caller.
    limit : int
        A path is only extended, and a cycle only reported, while the current
        path holds fewer than ``limit`` nodes.

    Yields
    ------
    list
        A copy of the current path: the nodes of one cycle, beginning with
        the start node. Each direction of traversal is reported separately,
        and very short cycles (a single node with a self-loop, or two nodes
        joined by edges in both directions) are reported as well.
    """
    # Work on a fresh graph so that removing nodes below never touches G.
    subG = type(G)(G.edges())
    sccs = list(nx.strongly_connected_components(subG))
    while sccs:
        scc = sccs.pop()
        # Any node of the component can serve as the root of the search.
        startnode = scc.pop()
        path = [startnode]
        # `blocked` holds exactly the nodes that are on the current path.
        blocked = set()
        blocked.add(startnode)
        # Each stack entry pairs a path node with its not yet visited neighbours.
        stack = [(startnode, list(subG[startnode]))]

        while stack:
            thisnode, nbrs = stack[-1]

            if nbrs and len(path) < limit:
                nextnode = nbrs.pop()
                if nextnode == startnode:
                    # Closed a cycle; hand out a copy because `path` keeps changing.
                    yield path[:]
                elif nextnode not in blocked:
                    # Extend the path and descend into the new node first.
                    path.append(nextnode)
                    stack.append((nextnode, list(subG[nextnode])))
                    blocked.add(nextnode)
                    continue
            # Backtrack once the neighbours are used up or the length limit is reached.
            if not nbrs or len(path) >= limit:
                blocked.remove(thisnode)
                stack.pop()
                path.pop()
        # All cycles through startnode have been reported; drop it and
        # re-partition what is left of its component.
        subG.remove_node(startnode)
        H = subG.subgraph(scc)
        sccs.extend(list(nx.strongly_connected_components(H)))

In [None]:
def fr_curvature (G, ni, nj):
    '''
    Forman-Ricci curvature of a single edge: 4 minus the degrees of its
    two end nodes

    Parameters
    ----------
    G : Graph
        must provide ``degree(node)``
    ni : first end node of the edge
    nj : second end node of the edge

    Returns
    -------
    int
        Forman-Ricci curvature of the edge between ni and nj

    '''
    deg_i = G.degree(ni)
    deg_j = G.degree(nj)
    return 4 - deg_i - deg_j

In [None]:
def afr_curvature (G, ni, nj, m):
    '''
    Augmented Forman-Ricci curvature of a single edge, taking 3-cycles
    into account: 4 minus both node degrees plus 3 per triangle

    Parameters
    ----------
    G : Graph
        must provide ``degree(node)``
    ni : first end node of the edge
    nj : second end node of the edge
    m : number of triangles that contain the edge

    Returns
    -------
    int
        augmented Forman-Ricci curvature of the edge between ni and nj
    '''
    curvature = 4 - G.degree(ni) - G.degree(nj)
    curvature += 3*m
    return curvature

In [None]:
def afr4_curvature (G, ni, nj, t, q):
    '''
    Augmented Forman-Ricci curvature of a single edge, taking 3- and
    4-cycles into account: 4 minus both node degrees, plus 3 per triangle
    and 2 per quadrangle

    Parameters
    ----------
    G : Graph
        must provide ``degree(node)``
    ni : first end node of the edge
    nj : second end node of the edge
    t : number of triangles that contain the edge
    q : number of quadrangles that contain the edge

    Returns
    -------
    int
        augmented Forman-Ricci curvature of the edge between ni and nj
    '''
    curvature = 4 - G.degree(ni) - G.degree(nj)
    curvature += 3*t
    curvature += 2*q
    return curvature

In [None]:
def afr5_curvature (G, ni, nj, t, q, p):
    '''
    Augmented Forman-Ricci curvature of a single edge, taking 3-, 4- and
    5-cycles into account: 4 minus both node degrees, plus 3 per triangle,
    2 per quadrangle and 1 per pentagon

    Parameters
    ----------
    G : Graph
        must provide ``degree(node)``
    ni : first end node of the edge
    nj : second end node of the edge
    t : number of triangles that contain the edge
    q : number of quadrangles that contain the edge
    p : number of pentagons that contain the edge

    Returns
    -------
    int
        augmented Forman-Ricci curvature of the edge between ni and nj
    '''
    curvature = 4 - G.degree(ni) - G.degree(nj)
    curvature += 3*t
    curvature += 2*q
    curvature += 1*p
    return curvature

In [None]:
def init_edge_attributes(G):
    '''
    resets the cycle lists and curvature values on every edge of G

    For each edge, the attributes named by ``cyc_names[3]``, ``cyc_names[4]``
    and ``cyc_names[5]`` are set to a new empty list, and the attributes
    "frc", "afrc", "afrc4" and "afrc5" are set to 0.

    Parameters
    ----------
    G : Graph
        modified in place
    '''
    curvature_keys = ("frc", "afrc", "afrc4", "afrc5")
    for u, v in list(G.edges()):
        attrs = G.edges[u, v]
        # every edge gets its own fresh list per cycle length
        for length in (3, 4, 5):
            attrs[cyc_names[length]] = []
        for key in curvature_keys:
            attrs[key] = 0

In [None]:
def set_edge_attributes_2 (G, ll, i):
    '''
    registers every cycle on each of the edges it runs along

    Parameters
    ----------
    G : Graph
        modified in place; each touched edge must already carry a list
        under the attribute ``cyc_names[i]``
    ll : list of cycles, each a sequence of nodes
    i : number of nodes per cycle; also selects the attribute via ``cyc_names``
    '''
    for cycle in ll:
        for pos in range(i):
            u = cycle[pos]
            # the last node is connected back to the first one
            v = cycle[(pos + 1) % i]
            G.edges[u, v][cyc_names[i]].append(cycle)

In [None]:
def get_orc_edge_curvatures (G):
    '''
    computes the Ollivier-Ricci curvature of G and stores it on the edges

    An ``OllivierRicci`` object is created for G with alpha=0.5; after the
    computation the "ricciCurvature" value of every edge of its graph is
    copied to the "orc" attribute of the same edge in G.

    Parameters
    ----------
    G : Graph
        modified in place
    '''
    ricci = OllivierRicci(G, alpha=0.5, verbose="INFO")
    ricci.compute_ricci_curvature()
    # copy the results from the graph held by the OllivierRicci object back to G
    curved = ricci.G
    for u, v in list(curved.edges()):
        G.edges[u, v]["orc"] = curved.edges[u, v]["ricciCurvature"]

In [None]:
def get_edge_curvatures (G):
    '''
    computes the four Forman-type curvatures for every edge of G

    The number of triangles, quadrangles and pentagons of an edge is taken
    from the length of the cycle lists stored under the ``cyc_names``
    attributes; the results are written to the edge attributes "frc",
    "afrc", "afrc4" and "afrc5".

    Parameters
    ----------
    G : Graph
        modified in place
    '''
    for u, v in list(G.edges()):
        attrs = G.edges[u, v]
        n_tri = len(attrs[cyc_names[3]])
        n_quad = len(attrs[cyc_names[4]])
        n_pent = len(attrs[cyc_names[5]])
        attrs["frc"] = fr_curvature(G, u, v)
        attrs["afrc"] = afr_curvature(G, u, v, n_tri)
        attrs["afrc4"] = afr4_curvature(G, u, v, n_tri, n_quad)
        attrs["afrc5"] = afr5_curvature(G, u, v, n_tri, n_quad, n_pent)

In [None]:
def show_curv_min_max_values (h_data):
    '''
    prints one line per curvature type with its "bin_min" and "bin_max"

    Parameters
    ----------
    h_data : dict
        maps a curvature name to a dict holding integer "bin_min" and
        "bin_max" entries
    '''
    print("\nMin/Max Curvature values:")
    for name, entry in h_data.items():
        label = str(name).ljust(8)
        min_text = "Min:".ljust(5) + " " + format(entry["bin_min"], "4d")
        max_text = "Max:".ljust(5) + " " + format(entry["bin_max"], "4d")
        print(label, min_text, "  ", max_text)
    print()

In [None]:
def show_histos (h_data, title_str, my_nrows = 2, my_ncols = 3, bin_num_lim = 40):
    '''
    shows one histogram per curvature type in a grid of subplots

    Parameters
    ----------
    h_data : dict
        maps a curvature name to a dict with the entries "curv" (values),
        "bin_min", "bin_max" and "title"
    title_str : str
        title of the whole figure
    my_nrows : int
        number of subplot rows
    my_ncols : int
        number of subplot columns; the histograms fill the grid row by row
    bin_num_lim : int
        the bin width is chosen as (bin_max - bin_min) // bin_num_lim + 1,
        which keeps the number of bins close to this limit
    '''
    # squeeze=False keeps `axes` two-dimensional even for a single row or
    # column, so the [row, column] indexing below works for every grid shape.
    fig, axes = plt.subplots(nrows=my_nrows, ncols=my_ncols, sharey = True,
                             figsize=(16,10), squeeze = False)
    for i,k in enumerate(h_data.keys()):
        r, c = divmod(i, my_ncols)
        entry = h_data[k]
        lo = entry["bin_min"]
        hi = entry["bin_max"]
        bin_width = (hi - lo) // bin_num_lim + 1
        ax = axes[r,c]
        ax.hist(entry["curv"], bins = np.arange(lo, hi + bin_width, bin_width), edgecolor = "white")
        ax.set_title(entry["title"])
        ax.title.set_size(16)
        ax.tick_params(axis='both', labelsize=16)
        ax.grid(visible=True, axis="both")
    fig.suptitle(title_str, size=16)
    plt.show()

In [None]:
def show_correlation_coeffs (h_data):
    '''
    prints the correlation coefficient for every pair of curvature types

    Parameters
    ----------
    h_data : dict
        maps a curvature name to a dict with the entries "curv" (values)
        and "title"; each pair is printed once, grouped by its first member
    '''
    print("\nCorrelation coefficients:")
    entries = list(h_data.values())
    for idx, first in enumerate(entries):
        # pair the current entry with all entries that follow it
        for second in entries[idx + 1:]:
            label = first["title"] + " / " + second["title"]
            coeff = np.corrcoef(first["curv"], second["curv"])[1][0]
            print(label.ljust(55,"."), f"{coeff:8.5f}")
        print()

In [None]:
def show_curv_data (G, title_str):
    '''
    collects the curvature values of all edges and shows min/max values,
    histograms and pairwise correlation coefficients

    Parameters
    ----------
    G : Graph
        every edge must carry the attributes "orc", "frc", "afrc", "afrc4"
        and "afrc5"
    title_str : str
        title of the histogram figure
    '''
    titles = {"orc":   "Ollivier Ricci (OR)",
              "frc":   "Forman Ricci (FR)",
              "afrc":  "Augm. FR curv. (triangles)",
              "afrc4": "AFR curv. (tri/quad)",
              "afrc5": "AFR curv. (tri/quad/pent)"
              }
    h_data = {key: {"curv": [d[key]  for u,v,d in G.edges.data()],
                    "bin_min": 0, "bin_max": 0, "title": title}
              for key, title in titles.items()}

    for k in h_data.keys():
        # Round outwards: int() would truncate towards zero, so a minimum of
        # -1.3 became -1 and a maximum of 0.8 became 0, and the histogram bins
        # no longer covered the extreme (fractional) curvature values.
        h_data[k]["bin_min"] = int(np.floor(min(h_data[k]["curv"])))
        h_data[k]["bin_max"] = int(np.ceil(max(h_data[k]["curv"])))

    show_curv_min_max_values (h_data)
    show_histos (h_data, title_str, my_nrows = 2, my_ncols = 3, bin_num_lim = 40)
    show_correlation_coeffs(h_data)

In [None]:
# Name of the edge attribute that stores the cycles of a given length,
# keyed by the number of nodes in the cycle.
cyc_names = {
    3: "triangles",
    4: "quadrangles",
    5: "pentagons",
}


In [None]:
def build_size_list (k, l):
    '''
    returns a list that contains the value k exactly l times
    '''
    return [k] * l

In [None]:
def build_prob_list (n, p_in, p_out):
    '''
    returns an n x n list of lists with p_in on the diagonal and p_out
    everywhere else
    '''
    return [[p_in if col == row else p_out  for col in range(n)]
            for row in range(n)]

In [None]:
def calculate_SBM(k, l, p_in, p_out, title_str, show_graph = False):
    '''
    builds one stochastic block model graph, determines its 3-, 4- and
    5-cycles, computes all edge curvatures and shows the results

    Parameters
    ----------
    k : int
        size of each community
    l : int
        number of communities
    p_in : float
        edge probability inside a community
    p_out : float
        edge probability between different communities
    title_str : str
        title of the histogram figure
    show_graph : bool, optional
        if True, the graph is additionally plotted with a Kamada-Kawai
        layout and nodes coloured by block. The layout is only computed
        when it is actually needed.

    Returns
    -------
    d : dict
        maps the cycle length (3, 4, 5) to the list of cycles of that length
        after removal of permutations
    '''
    print("k:",k," l:",l," p_in:",p_in," p_out:",p_out)
    sizes = build_size_list(k, l)
    probs = build_prob_list(l, p_in, p_out)

    # fixed seed: the same parameters always produce the same graph
    G = nx.stochastic_block_model(sizes, probs, seed = 0)
    init_edge_attributes(G)

    # the cycle search runs on the directed version of the graph
    H = G.to_directed()

    if show_graph:
        layout = nx.kamada_kawai_layout(H)
        blocks = [data["block"]  for node, data in H.nodes.data()]
        plot_my_graph(H, layout, node_col = blocks)

    # limit 6: only cycles with fewer than 6 nodes are searched
    cycles = list(simple_cycles(H, 6))

    d = dict()
    for i in range(3,6):
        d[i] = [c  for c in cycles  if len(c) == i]
        d[i] = remove_permutations(d[i])
        set_edge_attributes_2(G, d[i], i)

    get_orc_edge_curvatures (G)
    get_edge_curvatures (G)
    show_curv_data (G, title_str)

    return d

In [None]:
def calculate_SBMs():
    '''
    runs calculate_SBM for four parameter sweeps: community size k, number
    of communities l, p_in and p_out. In each sweep the other three
    parameters stay at their default values.

    Returns
    -------
    d : dict
        result of the last run of the second sweep (variation of l); the
        results of all other runs are not kept
    '''
    k_values, k_def = [5,10,15,20], 15
    l_values, l_def = [2,3,4,5], 5
    p_in_values, p_in_def = [0.6, 0.7, 0.8, 0.9], 0.7
    p_out_values, p_out_def = [0.05, 0.03, 0.02, 0.01], 0.05

    for k in k_values:
        title = (f"Variation of community size / k = {k}\n"
                 f"k={k} l={l_def} p_in:{p_in_def} p_out:{p_out_def}")
        calculate_SBM(k, l_def, p_in_def, p_out_def, title)

    for l in l_values:
        title = (f"Variation of number of communities / l = {l}\n"
                 f"k={k_def}  l={l}  p_in={p_in_def}  p_out={p_out_def}")
        # only this sweep stores the returned cycle dictionary
        d = calculate_SBM(k_def, l, p_in_def, p_out_def, title)

    for p_in in p_in_values:
        title = (f"Variation of p_in / p_in = {p_in}\n"
                 f"k={k_def} l={l_def} p_in:{p_in} p_out:{p_out_def}")
        calculate_SBM(k_def, l_def, p_in, p_out_def, title)

    for p_out in p_out_values:
        title = (f"Variation of p_out / p_out = {p_out}\n"
                 f"k={k_def} l={l_def} p_in:{p_in_def} p_out:{p_out}")
        calculate_SBM(k_def, l_def, p_in_def, p_out, title)

    return d

In [None]:
# Run all parameter sweeps; `d` receives the cycle dictionary returned by calculate_SBMs.
d = calculate_SBMs()  