In [1]:
import dgl
import torch
from tqdm import tqdm
import numpy as np
import collections
import itertools
import os
import copy
from torch.sparse import *
os.environ['DGLBACKEND'] = 'pytorch'
import random
import math
import sqlite3
import json
from tabulate import tabulate

In [2]:
def simplex_verifier(simplex, graph):
    """This function takes input a k-tuple [v_1,v_2,...,v_k] for k>1 and a graph G
    and verifies if it a simplex of G. This is done by simply checking if there is 
    an edge between v_i and v_j for i<j"""
    
    assert isinstance(graph, dgl.DGLHeteroGraph), \
        'Keyword argument \"graph\" must be a dgl.DGLHeteroGraph.'
    assert isinstance(simplex, list), \
        'Keyword argument \"simplex\" must be a list.'
    
    is_simplex = False
    
    simplex_reduced = copy.deepcopy(simplex)
    for i in simplex:
        simplex_reduced.pop(0)
        for j in simplex_reduced:
            #built-in DGL function that checks if there is an edge between nodes i and j.
            if graph.has_edges_between(i, j):
                continue
            else:
                return is_simplex
    return True


def edges_to_identify(graph):
    """Constructing the Hadamard product here to find the list of edges to contract.
    All the edges that are parts of a 2-simplex are identified in this function
    as indices of the Hadamard product. Moreover, other edges that are not part of the Hadamard
    product, but are part of the transitive closure of existing edges are also identified.
    These are the edges that form a complete 2-simplex."""

    assert isinstance(graph, dgl.DGLHeteroGraph), \
        'Keyword argument \"graph\" of create_hadamard must be a dgl.DGLHeteroGraph.'
    
    #First, we remove all diag entries from adj matrix A by removing all self-loops
    loopless = dgl.transforms.RemoveSelfLoop()
    graph = loopless(graph)
    
    #Then we remove all diagonals from A^2 and convert the matrix to one with binary entries
    adj_squared = torch.sparse.mm(graph.adj_external(),graph.adj_external())
    diagonal_mask = (adj_squared._indices()[0] == adj_squared._indices()[1])
    off_diagonal_mask = ~diagonal_mask
    adj_squared._values()[off_diagonal_mask] = 1.0
    new_indices = adj_squared._indices()[:, off_diagonal_mask]
    new_values = adj_squared._values()[off_diagonal_mask]
    new_size = adj_squared.size()
    squared_no_diag_binary = torch.sparse_coo_tensor(indices=new_indices, 
                                                    values=new_values, size=new_size)
    
    #the hadamard product is sparse, but keeps track of entries that are zero
    false_hadamard_product = graph.adj_external() * squared_no_diag_binary
    
    #we, therefore need to remove those entries
    false_hadamard_product = false_hadamard_product.coalesce()
    non_zero_mask = false_hadamard_product._values().nonzero().squeeze()
    non_zero_values = false_hadamard_product._values()[non_zero_mask]
    non_zero_indices = false_hadamard_product.indices()[:, non_zero_mask]
    hadamard_product = torch.sparse_coo_tensor(indices=non_zero_indices,
                                               values=non_zero_values,
                                               size=false_hadamard_product.size())
    
    """The following loop finds all edges that are
    part of a simplex need to be collapsed"""
    row_indices, col_indices = hadamard_product._indices()
    extra_edges = list()
    
    #for i,j in tqdm(zip(row_indices,col_indices), position=0, leave=False):
    for i,j in zip(row_indices,col_indices):
        out_nodes = set([int(v) for v in list(graph.successors(i))])
        in_nodes = set([int(v) for v in list(graph.predecessors(j))])
        #these are the elements in the (reverse?) transitive closure of (i,j)
        intersection = set.intersection(out_nodes,in_nodes)
        for k in intersection:
            extra_edges = extra_edges + [(int(i),int(k))] + [(int(k),int(j))]
        
    return hadamard_product, extra_edges

In [3]:
def generate_alternative_lists(mapping, simplex):
    """Takes as input the pre-images of a quotient for node projection
    and a simplex and spits out all possible simplices in the pre-image"""
    indices = [range(len(mapping[key])) for key in simplex]
    combinations = itertools.product(*indices)

    result_lists = []
    for combo in combinations:
        result_lists.append([mapping[simplex[i]][combo[i]] for i in range(len(simplex))])

    return result_lists

def edge_present_in_simplex(edge, simplices):
    """Takes as input an edge [i,j] and a simplex [v_0,v_1,...,v_n] and
    checks if the there are x,y such that [i,j] = [v_x,v_y]"""
    query = False
    for simplex in simplices:
        if edge[0] in simplex:
            if edge[1] in simplex:
                if simplex.index(edge[0]) < simplex.index(edge[1]):
                    query = True
                    break
    return query

class AdjGraph():
    
    
    empty_graph = dgl.heterograph({('node', 'to', 'node'): ([], [])})  
        
    def __init__(self,
        graph = empty_graph):
        
        assert isinstance(graph, dgl.DGLHeteroGraph), \
        'Keyword argument \"graph\" of AdjGraph\'s init methodmust be a dgl.DGLHeteroGraph.'
        
        self.seed_graph = graph
        self.nodes = [int(node) for node in torch.cat(
            self.seed_graph.edges(),dim=0).unique().numpy()]
        self.edges = [edge.tolist() for edge in torch.stack(
            self.seed_graph.edges(), dim = 1).numpy()]
        #Simplices of dimension >2 are imported into this variable 
        self.preexisting_simplices = dict()
        #this dictionary carries all simplices of the current graph
        self.simplices             = dict()
        #simplices that have not been inductively connected are placed in this dictionary
        self.new_simplices         = list()
        #simplex_id is updated externally once preexisting_simplices are fed in.
        #This is used to give new simplex IDs to the new quotient.
        self.simplex_id            = int
        
    def create_simplices_dict(self):
        old_edges          = list()
        old_nodes          = list()

        for key, element in self.preexisting_simplices.items():
            if key[0] == 0:
                old_nodes = old_nodes + [element]
                self.simplices.update({(0, self.simplex_id): element})
                self.simplex_id  = self.simplex_id + 1
                continue
            if key[0] == 1:
                old_edges = old_edges + [element]
                self.simplices.update({(1, self.simplex_id) : element})
                self.simplex_id  = self.simplex_id + 1
                continue
            else:
                #We want to include simplices of dimension at least 2 from the lift
                #as pre-existing simplices on which we 'inductively' search for higher simplices
                self.new_simplices = self.new_simplices + [element]
                self.simplices.update({(len(element)-1, self.simplex_id) : element})
                self.simplex_id  = self.simplex_id + 1  
                
        for node in self.nodes:
            #keeping track of old nodes is probably not important.
            if node not in old_nodes:
                self.simplices.update({(0,self.simplex_id):[node]})
                self.simplex_id = self.simplex_id + 1
                
        for edge in self.edges:
            #but we need to keep track of edges that appear in pre-image of a quotient
            #since they need to be inductively connected
            if edge not in old_edges:
                if edge_present_in_simplex(edge, self.new_simplices):
                    continue
                else:
                    self.simplices.update({(1,self.simplex_id):edge})
                    self.new_simplices = self.new_simplices + [edge] 
                    self.simplex_id = self.simplex_id + 1                
        
    def connectivity_update(self):
        """This is the method that 'connects up' our AdjGraph so that it jumps
        up from being the adjacency graph
        of a k-connected simplicial set to the adjacency graph of a (k+1)-connected
        simplicial set. This is done
        by adding in (k+1)-simplicies where ever our simplicial set contains a 
        non-degenerate boundary of a
        standard (k+1)-simplex.
        
        At a high level, this method procedes as follows:
        1. Locate all (0-simplex, k-simplex)-pairs (u, s) such that the vertex u 
        is not contained in the k-simplex s.
        2. For each such pair, let [v_0, v_1, ...., v_k] be the 0-skeleton sk_0(s) of our 
        k-simplex s.
        3. For each resulting (0-simplex, 0-simplex)-pair (u, v_i), for 0<=i<=k, query if 
        [u, v_i] is a directed edge in our original graph, for all i ([v_i,u]?)
        4. If the answer is affirmative for every pair (u, v_i) in Step 3 above, then 
        [u, v_0, v_1, ...., v_k] is the 0-skeleton of a (k+1)-simplex that needs to be 
        added to AdjGraph.
        
        Note that k varies.
        """
    
        # Create a list of the 0-skeleta of all top-dimensional simplices 
        #in our current simpliciat set:
        new_zero_skeleta = self.new_simplices
        
        # Begin list of all non-degenerate boundaries of standard (k+1)-simplices 
        #in our current simplicial set:
        self.new_boundaries = []
        
        
        #for src in tqdm(self.nodes, 
                        #position=5, desc="i", leave=False, colour='green', ncols=200):
        for src in self.nodes:
            # Step 2: iterate over k-simplices in our current simplicial set, 
            #extracting the 0-skeleton of each:
            #for zero_skel in tqdm(new_zero_skeleta, 
                                  #position=6, desc="i", leave=False, colour='green', ncols=200):
            for zero_skel in new_zero_skeleta:
                # Check that our new 0-simplex doesn't already lie in the 
                #0-skeleton of our k-simplex:
                if src in zero_skel:
                    edge_present_query = False
                # If it doesn't, begin Step 3: 
                else:
                    edge_present_query = True
                    # Check that our original graph contains all necessary edges:
                    for dst in zero_skel:
                        if [src, dst] in self.edges:
                            edge_present_query = True
                        else:
                            edge_present_query = False
                        if edge_present_query == False:
                            break
                    # If it does, adjoin a new 0-skeleton to our list of 
                    #(k+1)-simplices to add to AdjGraph:
                    if edge_present_query == True:
                        self.new_boundaries.append([src]+zero_skel)

                        
        # Update our dictionary of 0-skeleta associated to vertices in AdjGraph by introducing one new key
        # (simplex dimension, simplex index) for each new simplex in AdjGraph, and define the value at this
        # new key to be the 0-skeleton of this simplex:
        self.new_simplices = list()
        for zero_skeleton in self.new_boundaries:
            self.simplices.update({(len(zero_skeleton)-1, self.simplex_id): zero_skeleton})
            self.new_simplices = self.new_simplices + [zero_skeleton]
            self.simplex_id = self.simplex_id + 1
        
    def highest_dimension(self):
        if len(self.new_boundaries) == 0:
                    return True
        return False
    
    
def eliminate_duplicate_simplices(input_dictionary):
    #some duplicates might still creep in. These have to be eliminated before being put
    #in the database
    reverse_dict = {tuple(value): key for key, value in input_dictionary.items()}
    unique_values = {value: list(key) for key, value in reverse_dict.items()}
    return unique_values

In [4]:
class graph_towers():
    src=list()
    dst=list()
    empty_graph = dgl.heterograph({('node', 'to', 'node'): (src, dst)})
    ratio = 0.0
    bottom_level = 0
    
    assert isinstance(empty_graph, dgl.DGLHeteroGraph), \
        'Keyword argument \"graph\" of graph_towers\'s init method must be a dgl.DGLHeteroGraph.'
    
    assert isinstance(ratio, float), \
        'Keyword argument \"ratio\" of graph_towers\'s init method must be a float.'
    
    assert ratio<=1 and ratio>=0, \
        'Keyword argument \"ratio\" of graph_towers\'s init method must be between 0 and 1.'

    assert isinstance(bottom_level, int), \
        'Keyword argument \"bottom_level\" of graph_towers\'s init method must be an integer.'
    

    def __init__(self, file_path, ratio, database_name, max_dimension, graph=empty_graph,
                 bottom_level=bottom_level):
        
        self.seed_graph        = graph
        self.srcs_and_dsts     = self.seed_graph.edges()   
        self.file_path         = file_path        
        self.ratio             = ratio
        self.updated_graph     = dgl.heterograph({('node', 'to', 'node'): ([], [])})
        self.bottom_level      = bottom_level
        self.database_name     = database_name
        self.maximum_dimension = max_dimension
        self.connection        = None
        self.cursor            = None
        self.number_of_nodes   = len(self.seed_graph.nodes())
        self.number_of_edges   = len(self.seed_graph.edges()[0])
        self.selected_edges    = None
        self.quotient_number   = 0
        self.simplex_id        = 0
        
        #Find list of edges that will be used to create a quotient graph
        self.hadamard_product, self.extra_edges  = edges_to_identify(self.seed_graph)
        rows, columns = self.hadamard_product._indices()
        self.edges_to_collapse_as_pairs = torch.cat(
            (torch.transpose(self.hadamard_product._indices(),0,1)
             ,torch.tensor(self.extra_edges)),dim=0)
        
        #Some edges from the 'extra_edges' and those given by the Hadamard product
        #are duplicated. We need to combine these in one variable.
        self.edges_to_collapse_as_pairs = torch.unique(self.edges_to_collapse_as_pairs, dim=0)
    
        self.all_nodes_to_identify  = torch.cat((rows,columns),dim=0).unique()
        self._equivalenceclasses    = dict()
        self.appendage_index        = len(self.seed_graph.nodes())
        self.edge_index             = 0        
        edge_pairs                  = torch.stack(self.seed_graph.edges(), dim = 1).int()
        self.all_edges_as_pairs     = edge_pairs
        self.edges_carry_fwd        = list()
        self.edges_never_contracted = None
        self.simplex_id             = 0
        
        #Find all node classes to yield maximum class size. This
        #maximum size then becomes number of columns for simplices, too
        self._globes = {element:partition for partition in relation(
            self.edges_to_collapse_as_pairs) for element in partition}
        
        self.maximum_class_size = max(len(set_value) for set_value in self._globes.values())
        self.loop_indicator     = torch.eq(self.srcs_and_dsts[0],self.srcs_and_dsts[1])
        self.existing_loops     = (self.seed_graph.edges()[0][self.loop_indicator],
                                   self.seed_graph.edges()[1][self.loop_indicator])
        self.existing_loops     = torch.stack(self.existing_loops, dim = 1).int()
        
        #Finds maximum dimension of a simplex.
        self.in_degrees  = self.seed_graph.in_degrees()
        self.out_degrees = self.seed_graph.out_degrees()
        potential_max    = min(int(torch.max(self.in_degrees)),
                               int(torch.max(self.out_degrees)),self.maximum_class_size)
        if self.maximum_dimension > potential_max:
            #print("The given graph cannot have simplices of dimension", self.maximum_dimension)
            #print("Changing ",self.maximum_dimension, "to", potential_max)
            self.maximum_dimension = potential_max
            
    
    def _close_db(self):
        self.connection.commit()
        self.connection.close()
        
        
    def _connect_db(self):
        self.connection = sqlite3.connect(self.database_name)
        self.cursor = self.connection.cursor()
        
        
    def _view_db(self,table_name):
        #for easier visualization of db
        self._connect_db()
        self.cursor.execute(f"PRAGMA table_info({table_name})")
        if table_name == 'edge_details':
            columns_info = self.cursor.fetchall()
            columns = [column[1] for column in columns_info[1:9]]
            self.cursor.execute(f"SELECT {', '.join(columns)} FROM {table_name}")
            rows = self.cursor.fetchall()
        else:
            self.cursor.execute(f"PRAGMA table_info({table_name})")
            columns = [column[1] for column in self.cursor.fetchall()]
            self.cursor.execute(f"SELECT * FROM {table_name}")
            rows = self.cursor.fetchall()
        table = tabulate(rows, headers=columns, tablefmt="pretty")
        
        print(table)
        self._close_db()
        
        
    def create_table(self):
        """This function creates the table, but does not intialize
        data. Therefore, this function only needs to be run the first time. """
        self._connect_db()

        self.cursor.execute('''
            CREATE TABLE IF NOT EXISTS graph (
                quotient_number INTEGER PRIMARY KEY,
                number_of_nodes INTEGER,
                number_of_edges INTEGER,
                number_of_simplices INTEGER,
                FOREIGN KEY (number_of_nodes) REFERENCES node_classes(node_class_id),
                FOREIGN KEY (number_of_edges) REFERENCES edge_details(edge_id),
                FOREIGN KEY (number_of_simplices) REFERENCES simplices(simplex_id)
                                            )
                            ''')
        self.cursor.execute(f'''
            CREATE TABLE IF NOT EXISTS node_classes (
                node_class_id INTEGER PRIMARY KEY,
                quotient_id INTEGER,
                number_of_nodes INTEGER,
                {', '.join(f'node_{i} INTEGER DEFAULT NULL' for i in range(0, self.maximum_class_size))}
                                                    )
                            ''')
        self.cursor.execute('''
            CREATE TABLE IF NOT EXISTS edge_details (
                edge_id INTEGER PRIMARY KEY,
                e_src INTEGER,
                e_dst INTEGER,
                edge_changed BOOLEAN,
                to_contract BOOLEAN,
                sampled BOOLEAN,
                contracted BOOLEAN,
                quotient_id INTEGER,
                multiplicity INTEGER,
                number_of_edges INTEGER
                                                    )
                            ''')
        self.cursor.execute(f'''
            CREATE TABLE IF NOT EXISTS simplices (
                simplex_id INTEGER PRIMARY KEY,
                quotient_id INTEGER,
                dimension INTEGER,
                number_of_simplices INTEGER,
                {', '.join(f'index_{i} INTEGER DEFAULT NULL' for i in range(0, self.maximum_class_size))}
                                                    )
                            ''')

        self._close_db()
        
        
    def initial_db_fill(self):
        """Fills in the database with the details from the (unquotiented) graph itself
        i.e., before the quotienting process """
        self._connect_db()
        self.cursor.execute('''INSERT INTO graph 
                            (quotient_number, number_of_nodes, number_of_edges, 
                            number_of_simplices) VALUES
                            (0, ?, ?, 0)
                            ''', (len(self.seed_graph.nodes()), 
                                  len(self.seed_graph.edges()[0]))
                           )
        
        for node in self.seed_graph.nodes():
            self.cursor.execute('''INSERT INTO node_classes
                                (node_class_id, node_0, quotient_id, number_of_nodes) VALUES 
                                (?, ?, 0, ?)
                                ''', (int(node), int(node),  
                                      len(self.seed_graph.nodes())))
        
        for edge in self.all_edges_as_pairs:
            if torch.any(torch.all(self.edges_to_collapse_as_pairs == edge, dim=1)):
                self.cursor.execute('''INSERT INTO edge_details
                                    (edge_id, e_src, e_dst, edge_changed, 
                                    to_contract, quotient_id, multiplicity, sampled,
                                    number_of_edges, contracted) VALUES 
                                    (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                                    ''', (self.edge_index, int(edge[0]), int(edge[1]), 
                                      False, True, 0, 1, False, len(self.seed_graph.edges()[0]),
                                          False)
                                   )
            elif torch.any(torch.all(self.existing_loops == edge, dim=1)):
                self.cursor.execute('''INSERT INTO edge_details
                                    (edge_id, e_src, e_dst, edge_changed, 
                                    to_contract, quotient_id, multiplicity, sampled,
                                    number_of_edges, contracted) VALUES 
                                    (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                                    ''', (self.edge_index, int(edge[0]), int(edge[1]), 
                                      False, False, 0, 1, False, len(self.seed_graph.edges()[0]),
                                          True)
                                   )

            else:
                self.cursor.execute('''INSERT INTO edge_details
                                    (edge_id, e_src, e_dst, edge_changed, 
                                    to_contract, quotient_id, multiplicity, sampled,
                                    number_of_edges, contracted) VALUES 
                                    (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                                    ''', (self.edge_index, int(edge[0]), int(edge[1]), 
                                      False, False, 0, 1, False, len(self.seed_graph.edges()[0]),
                                         False)
                                   )
            self.edge_index = self.edge_index + 1
        
        #Only edges from the Hadamard product and its transitive closure
        #need to be collapsed. The rest are gathered here. Query and execution
        #only made for debugging.
        self.cursor.execute(f'''SELECT e_src, e_dst
                            FROM edge_details 
                            WHERE to_contract = False
                            ''')
        
        self.edges_never_contracted = self.cursor.fetchall()
        
        self._close_db()
        
        
    def make_quotient(self):
        """This section can be combined with the db fill function above
        The only reason this is kept distinct is to keep the class modular"""
        #an empty graph where we add the quotient. This is only used as a container
        #and is forwarded to be added to the db.
        self.updated_graph = dgl.heterograph({('node', 'to', 'node'): ([], [])})
        self.new_edges_added = None
        
        self._sets = {id(self._equivalenceclasses[key]):self._equivalenceclasses[key] for key in self._equivalenceclasses.keys()}
        #create mapping of class names
        self._classesnamesmapping = dict()
        
        for setid in self._sets.keys():
            self._classesnamesmapping[setid] = self.appendage_index
            self.appendage_index = self.appendage_index + 1
        
        #each set (node class) is assigned a new node label

        
        #add the collapsed edges with their new source and target labels
        for edge in self.selected_edges:
            nodeclass    = self._equivalenceclasses[edge[0]]
            newnodelabel = self._classesnamesmapping[id(nodeclass)]
            self.updated_graph.add_edges(newnodelabel,newnodelabel)
        
        #since the node labels are changed, the edges that still need to be
        #contracted will have their src and dst changed. To keep track of
        #these un-contracted edges, we put them in the edges_carry_fwd
        #variable
        for edge in self.edges_carry_fwd:
            srcnodeclass    = self._equivalenceclasses.get(edge[0], edge[0])
            srcnewnodelabel = self._classesnamesmapping.get(id(srcnodeclass),edge[0])
            dstnodeclass    = self._equivalenceclasses.get(edge[1], edge[1])
            dstnewnodelabel = self._classesnamesmapping.get(id(dstnodeclass),edge[1])
            self.updated_graph.add_edges(srcnewnodelabel,dstnewnodelabel)
            
        self.new_edges_added = torch.stack(self.updated_graph.edges(), dim = 1).int()
        
        edges_never_contracted_copy = copy.copy(self.edges_never_contracted)
                        
        for edge in edges_never_contracted_copy:
            src = edge[0]
            dst = edge[1]
            #if the edge has been changed, remove it from the list of edges that never
            #need to be contracted and replace it with its new node labels
            if src in self._equivalenceclasses.keys() or dst in self._equivalenceclasses.keys():
                index_of_edge   = self.edges_never_contracted.index(edge)
                self.edges_never_contracted.pop(index_of_edge)
                srcnodeclass    = self._equivalenceclasses.get(edge[0], edge[0])
                srcnewnodelabel = self._classesnamesmapping.get(id(srcnodeclass),edge[0])
                dstnodeclass    = self._equivalenceclasses.get(edge[1], edge[1])
                dstnewnodelabel = self._classesnamesmapping.get(id(dstnodeclass),edge[1])
                new_edgelabel   = (srcnewnodelabel, dstnewnodelabel)
                self.edges_never_contracted.append(new_edgelabel)
            else:
                self.updated_graph.add_edges(src,dst)
                                    
        #print("Created a quotient")
        
        #remove variables to save space
        self.edges_carry_fwd = None
        self.selected_edges  = None
        
        self.quotient_number = self.quotient_number + 1
        
        
    def db_fill(self):
        """This function saves the details of a quotiented graph in the database.
        Therefore, this function should be called immediately after the quotient is made.
        All the edges of the graph are added, and are given different IDs, even if the edges
        are already present in a previous quotient. However, we have different collections
        here to ensure that the edges that have been changed are discarded from the edges
        that need to be sampled. To this end, we add two new columns viz. to_contract, sampled
        and contracted."""
        number_of_edges     = len(torch.unique(torch.stack(self.updated_graph.edges(), dim=1)))
        new_edges           = [tuple(edge.numpy()) for edge in self.new_edges_added]
        #to count the multiplicities of the edges
        edge_counter        = collections.Counter(new_edges)
        number_of_nodes     = len(torch.unique(torch.cat(self.updated_graph.edges())))
        
        self._connect_db()
        self.cursor.execute('''INSERT INTO graph 
                            (quotient_number, number_of_nodes, number_of_edges, 
                            number_of_simplices) VALUES
                            (?, ?, ?, 0)
                            ''', (self.quotient_number, 
                                  number_of_nodes,
                                  number_of_edges)
                           )
    
        for edge, count in edge_counter.items():
            if edge[0] == edge[1]:
                self.cursor.execute('''INSERT INTO edge_details
                                    (edge_id, e_src, e_dst, edge_changed, 
                                    to_contract, quotient_id, multiplicity, sampled,
                                    number_of_edges, contracted) VALUES 
                                    (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                                    ''', (self.edge_index, int(edge[0]), int(edge[1]), 
                                    False, True, self.quotient_number, count, True, 
                                    number_of_edges, True)
                                   )
            else:
                self.cursor.execute('''INSERT INTO edge_details
                                    (edge_id, e_src, e_dst, edge_changed, 
                                    to_contract, quotient_id, multiplicity, sampled,
                                    number_of_edges, contracted) VALUES 
                                    (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                                    ''', (self.edge_index, int(edge[0]), int(edge[1]), 
                                      False, True, self.quotient_number, count, False, 
                                      number_of_edges, False
                                         )
                                   )
                
            self.edge_index = self.edge_index + 1
            
        for edge in self.edges_never_contracted:
            if edge not in edge_counter.keys():
                self.cursor.execute('''INSERT INTO edge_details
                                    (edge_id, e_src, e_dst, edge_changed, 
                                    to_contract, quotient_id, multiplicity, sampled,
                                    number_of_edges, contracted) VALUES 
                                    (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                                    ''', (self.edge_index, int(edge[0]), int(edge[1]), 
                                          False, False, self.quotient_number, None, None, 
                                          number_of_edges, None)
                                   )
                self.edge_index = self.edge_index + 1
        
        #removing variable to save space
        self.new_edges_added = None
        
        #add new nodes to the database
        for setid, classlabel in self._classesnamesmapping.items():
            columns = ['node_class_id'] + ['quotient_id'] + ['number_of_nodes']
            node_data = self._sets[setid]
            columns = columns + [f'node_{i}' for i in range(0, len(node_data))]
            self.cursor.execute('''INSERT INTO node_classes ({})
                                VALUES ({})
                                '''.format(','.join(columns), ','.join(['?'] * len(columns))),
                                (classlabel, self.quotient_number, 
                                 number_of_nodes,) + tuple(node_data)
                               )       
        self._close_db()
        
                
    def empty_table(self, table_name):
        """Made this function to empty database without deleting database. 
        Helpful for debugging"""
        self._connect_db()
        self.cursor.execute(f"DELETE FROM {table_name}")
        self._close_db()
        
        
    def sampling(self):
        """samples edges to collapse, keeps track of changes 
        in db for edges remaining to collapse"""
        self._connect_db()
        #return edges that have not yet been contracted
        self.cursor.execute(f'''SELECT e_src, e_dst
                            FROM edge_details 
                            WHERE to_contract = True AND sampled = False
                            AND edge_changed = False
                            AND contracted = False AND quotient_id = {self.quotient_number}
                            ORDER BY RANDOM()
                            LIMIT {math.ceil(self.ratio * len(self.edges_to_collapse_as_pairs))}
                            ''')
        
        self.selected_edges = self.cursor.fetchall()
        
        
        #print("Creating equivalence classes of nodes..")
        self._equivalenceclasses = {
            element:partition for partition in relation(self.selected_edges) for element in partition}
        #print("equivalence classes created! These are",self._equivalenceclasses)
        
        #When an edge collapses, we assign a new node label. The following
        #list contains all the nodes that need to be changed.
        changed_nodes = list(self._equivalenceclasses.keys())
        
        #To update entries of sampled edges in db, we first find all
        #the unique classes constructed above.
        unique_sets = set()
        for value in self._equivalenceclasses.values():
            unique_sets.add(frozenset(value))
        
        selected_edges = []
        
        for equivalence_class in unique_sets:
        #updates the column entry `sampled' for edges that have been
        #sampled and those that naturally won't be available for future quotienting
            self.cursor.execute('''UPDATE edge_details
                                SET sampled = True
                                WHERE e_src IN ({}) AND e_dst IN ({}) AND quotient_id = ?
                                '''.format(','.join(['?']*len(equivalence_class)),
                                           ','.join(['?']*len(equivalence_class))),
                                list(equivalence_class) + list(equivalence_class) +
                                [self.quotient_number]
                               )
            
            self.cursor.execute('''SELECT e_src, e_dst
                                FROM edge_details
                                WHERE contracted = False AND
                                e_src IN ({}) AND e_dst IN ({})
                                AND quotient_id = ?
                                '''.format(','.join(['?']*len(equivalence_class)), 
                                            ','.join(['?']*len(equivalence_class))),
                                    list(equivalence_class) + list(equivalence_class) +
                               [self.quotient_number]
                               )
            selected_edges = selected_edges + self.cursor.fetchall()
        self.selected_edges = selected_edges

        #edges to carry forward
        self.cursor.execute(f'''SELECT e_src, e_dst
                            FROM edge_details 
                            WHERE to_contract = True AND sampled = False 
                            AND edge_changed = False AND contracted = False
                            AND quotient_id ={self.quotient_number}
                            ''')
        
        #These are the egdes we do  not collapse after sampling and 
        #considering transitive closure
        self.edges_carry_fwd = self.cursor.fetchall()
                
        self.cursor.execute('''UPDATE edge_details
                            SET edge_changed = True
                            WHERE e_src IN ({}) AND e_dst IN ({});
                            '''.format(','.join(['?']*len(changed_nodes)),
                                       ','.join(['?']*len(changed_nodes))), 
                            changed_nodes + changed_nodes)
        self._close_db()
        
            
            
    def create_towers(self):
        self.initial_db_fill()
        for _ in range(self.bottom_level+1):
            self.sampling()
            if len(self.selected_edges) == 0:
                #print("The graph cannot be quotiented further")
                #change bottom level to wherever the smallest possible quotient level is
                self.bottom_level = self.quotient_number
                break 
            self.make_quotient()
            self.db_fill()
        #self._view_db('edge_details')
        #self._view_db('node_classes')
            
            
    def extract_graph_from_db(self,quotient_number):
        """This function extracts all edges, given the quotient number.
        It then uses these edges to construct a DGL graph, which is necessary
        to search for simplices.
        
        In addition, it also picks up simplices from a graph below it in the hierarchy,
        and computes the lift of each simplex. These lifts are fed to the simplicial search
        function AdjGraph"""
        
        extracted_graph = dgl.heterograph({('node', 'to', 'node'): ([], [])})
        
        self._connect_db()
        #the contracted = False ensures that we don't pick up loops.
        self.cursor.execute(f'''SELECT e_src, e_dst
                            FROM edge_details
                            WHERE quotient_id = {quotient_number} AND contracted = False
                            ''')
        edges = self.cursor.fetchall()
        
        self.cursor.execute(f'''SELECT * FROM simplices
                            WHERE quotient_id = {quotient_number+1}
                            ''')
        simplex_details = self.cursor.fetchall()
        self._close_db()
        
        simplices_ids = list()
        simplices_to_lift = dict()
        
        #create the graph
        for src, dst in edges:
            extracted_graph.add_edges(src,dst)
        
        prev_nodes = list()
            
        #change simplices extracted in a dictionary (but why..?)
        for row in simplex_details:
            simplex_id, dimension = row[0], row[2]
            if dimension > 1:
                index_values = row[4:]
                index_values = list(value for value in index_values if value is not None)
                simplices_to_lift[(dimension, simplex_id)] = index_values
            simplices_ids = simplices_ids + [simplex_id]
            if dimension == 0:
                prev_nodes = prev_nodes + [row[4]]
        
        simplex_index = max(simplices_ids, default=0) + 1
        
        #node IDs are kept the same as their labels. Since a node may be present in different
        #quotient, the nodes for each quotient are not present in the database. This information
        #is extracted when edges are constructed.
        new_nodes = torch.unique(torch.cat(extracted_graph.edges()))
        new_nodes = new_nodes.tolist()
        nodes = new_nodes + prev_nodes
        new_nodes = list()
        prev_nodes = list()
        nodes = list(set(nodes))
        
        #Construct a dictionary of pre-images of the quotient ----> quotient+1 map on nodes
        columns = [f'node_{i}' for i in range(0, self.maximum_class_size)]
        column_names = ', '.join(columns)
        in_placeholders = ', '.join(['?' for _ in nodes])
        self._connect_db()
        self.cursor.execute(f'''SELECT {column_names}
                            FROM node_classes
                            WHERE node_class_id IN ({in_placeholders})
                            ''', tuple(nodes))
        node_mapping_db = self.cursor.fetchall()
        self._close_db()
        #create a mapping of node labels to their equivalence classes
        node_mapping = dict()
        for node, row in zip(nodes,node_mapping_db):
            key    = node
            values = row[0:]
            values = list(value for value in values if value is not None)
            node_mapping.update({key: values})
        
        #create potiential pre-image of each simplex already found in a previous quotient 
        potential_lifts = dict()
        
        for key, simplex in simplices_to_lift.items():
            #all potential lifts, for now, get the same keys
            potential_lifts.update({key: generate_alternative_lists(node_mapping, simplex)})
                    
        unique_lifts = dict()
        
        for key, values in potential_lifts.items():
            #if there are multiple pre-images
            if type(values[0]) == int:
                if simplex_verifier(values, extracted_graph):
                    unique_lifts.update({key: values})
            else:
                for simplex in values:
                    if simplex_verifier(simplex, extracted_graph):
                        unique_lifts.update({(key[0],simplex_index): simplex})
                        simplex_index = simplex_index + 1
                    
        return unique_lifts, extracted_graph, simplex_index
            
        
    def simplices_of_quotient(self,quotient_number):
        lifts, quotient_graph, simplex_index = self.extract_graph_from_db(quotient_number)
        simplex_finder = AdjGraph(quotient_graph)
        simplex_finder.preexisting_simplices = lifts
        simplex_finder.simplex_id = simplex_index
        simplex_finder.create_simplices_dict()
        #for _ in tqdm(range(1,self.maximum_dimension)):
        for _ in range(1,self.maximum_dimension):
            simplex_finder.connectivity_update()
            if simplex_finder.highest_dimension() == True:
                simplex_finder.simplices = eliminate_duplicate_simplices(simplex_finder.simplices)
                break
        #print("updated dictionary of simplices=",simplex_finder.simplices)
        
        number_of_simplices = len(simplex_finder.simplices.values())
        number_of_simplices = number_of_simplices + len(lifts.values())
        
        #add the found simplices to the database
        self._connect_db()
        
        for key, element in simplex_finder.simplices.items():
            dimension  = key[0]
            simplex_id = key[1]
            columns = ['dimension', 'simplex_id', 'quotient_id', 'number_of_simplices']
            columns = columns + [f'index_{i}' for i in range(len(element))]
            placeholders = ','.join(['?'] * len(columns))
            query = f'''INSERT INTO simplices ({','.join(columns)})
                        VALUES ({placeholders})
                    '''
            values = (dimension, simplex_id, quotient_number, number_of_simplices) + tuple(element)
            self.cursor.execute(query, values) 
            
        self._close_db()
        
    def simplicial_search(self):
        
        
        for q_id in range(self.bottom_level, -1, -1):
            self.simplices_of_quotient(q_id)
        self._view_db('simplices')
            
            
"""found from https://stackoverflow.com/questions/42069187/
create-a-list-of-unique-numbers-by-applying-transitive-closure"""
def relation(array):

    mapping = {}

    def parent(u):
        if mapping[u] == u:
            return u
        mapping[u] = parent(mapping[u])
        return mapping[u]

    for u, v in array:
        u = int(u)
        v = int(v)
        if u not in mapping:
            mapping[u] = u
        if v not in mapping:
            mapping[v] = v
        mapping[parent(u)] = parent(v)

    results = collections.defaultdict(set)
    

    for u in mapping.keys():
        results[parent(u)].add(u)

        
    return [x for x in results.values()]

            
def find_common_tensors(tensor_A,tensor_B):
    equal_pairs = torch.all(tensor_A[:, None, :] == tensor_B[None, :, :], dim=2)
    common_pair_indices = torch.nonzero(equal_pairs, as_tuple=False)
    return tensor_A[common_pair_indices[:, 0]]

In [5]:
class SimplexCreator():
    """Create standard simplex"""
    def __init__(self, dimension):
        self.input_dimension = dimension
        self.src=list()
        self.dst=list()
        for i in range(self.input_dimension+1):
            for j in range(self.input_dimension+1):
                if (i < j):
                    self.src = self.src + [i]
                    self.dst = self.dst + [j]

In [6]:
"""Code testing"""
K_5 = dgl.heterograph({('paper', 'cites', 'paper'): (SimplexCreator(dimension=10).src, SimplexCreator(dimension=10).dst)})
filepath = 'K_10'
db = 'K_10.db'
K_5_preprocessing = graph_towers(filepath,graph=K_5, database_name=db, ratio=0.01,bottom_level = 10, max_dimension = 23)
#K_5_preprocessing.create_table()
K_5_preprocessing.empty_table('edge_details')
K_5_preprocessing.empty_table('graph')
K_5_preprocessing.empty_table('node_classes')
K_5_preprocessing.empty_table('simplices')
K_5_preprocessing.create_towers()
K_5_preprocessing.simplicial_search()

  adj_squared = torch.sparse.mm(graph.adj_external(),graph.adj_external())


+------------+-------------+-----------+---------------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+----------+
| simplex_id | quotient_id | dimension | number_of_simplices | index_0 | index_1 | index_2 | index_3 | index_4 | index_5 | index_6 | index_7 | index_8 | index_9 | index_10 |
+------------+-------------+-----------+---------------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+---------+----------+
|     1      |      9      |     0     |          3          |   11    |         |         |         |         |         |         |         |         |         |          |
|     2      |      9      |     0     |          3          |   19    |         |         |         |         |         |         |         |         |         |          |
|     3      |      9      |     1     |          3          |   11    |   19    |         |         |         |         |        

In [7]:
# """Code testing"""
src = [0,0,0,1,1,2] + [1] + [4,4,4,5,5,6] 
dst = [1,2,3,2,3,3] + [4] + [5,6,7,6,7,7] 
twosimplices = dgl.heterograph({('paper', 'cites', 'paper'): (src, dst)})
filepath = 'twosimplices'
db2 = 'twosimplices.db'
twosimplices_preprocessing = graph_towers(filepath,database_name = db2, graph=twosimplices,ratio=0.1,bottom_level = 5, max_dimension = 20)
#twosimplices_preprocessing.create_table()
twosimplices_preprocessing.empty_table('edge_details')
twosimplices_preprocessing.empty_table('graph')
twosimplices_preprocessing.empty_table('node_classes')
twosimplices_preprocessing.empty_table('simplices')
twosimplices_preprocessing.create_towers()
twosimplices_preprocessing.simplicial_search()

+------------+-------------+-----------+---------------------+---------+---------+---------+---------+
| simplex_id | quotient_id | dimension | number_of_simplices | index_0 | index_1 | index_2 | index_3 |
+------------+-------------+-----------+---------------------+---------+---------+---------+---------+
|     1      |      3      |     0     |          4          |    5    |         |         |         |
|     2      |      3      |     0     |          4          |   12    |         |         |         |
|     3      |      3      |     1     |          4          |   12    |    5    |         |         |
|     4      |      3      |     1     |          4          |    5    |   12    |         |         |
|     5      |      2      |     0     |         13          |    1    |         |         |         |
|     6      |      2      |     0     |         13          |    5    |         |         |         |
|     7      |      2      |     0     |         13          |    7    | 

In [8]:
import random
def generate_random_graph(num_nodes):
    src_edges =[]
    dst_edges = []
    edges = []
    for i in range(2*num_nodes):
        src_edges.append(random.randint(0,num_nodes))
        dst_edges.append(random.randint(0,num_nodes))
        edges.append((src_edges[i],dst_edges[i]))
    graph = dgl.heterograph({('paper', 'cites', 'paper'): (src_edges, dst_edges)})
    return graph, edges

In [11]:
import matplotlib.pyplot as plt
import networkx as nx

dgl_G, edges = generate_random_graph(40)
print(edges)
print("number of edges originally=",len(edges))
nx_G = nx.DiGraph()
nx_G.add_edges_from(edges)
options = {
    'node_color': 'black',
    'node_size': 20,
    'width': 1,
}
#pos = nx.spring_layout(nx_G, seed=42)
#pos = nx.planar_layout(nx_G)
#nx.draw_networkx(nx_G, pos, with_labels=True, node_color='lightblue', node_size=200, font_size=10, font_color='black', arrows=True)
#plt.show()

[(20, 39), (10, 1), (27, 1), (38, 26), (31, 5), (28, 38), (10, 39), (29, 11), (12, 15), (25, 29), (40, 30), (32, 22), (40, 12), (17, 3), (14, 39), (0, 16), (23, 37), (31, 24), (21, 27), (22, 31), (13, 29), (8, 28), (26, 12), (6, 38), (16, 5), (20, 26), (40, 12), (18, 8), (13, 6), (32, 22), (35, 23), (14, 11), (7, 16), (10, 29), (28, 37), (24, 13), (13, 23), (23, 1), (25, 11), (3, 37), (23, 18), (20, 38), (25, 15), (5, 20), (8, 20), (9, 2), (11, 22), (32, 2), (20, 18), (14, 28), (4, 3), (28, 22), (40, 6), (5, 18), (30, 20), (35, 32), (16, 0), (14, 6), (17, 12), (7, 36), (0, 3), (4, 17), (14, 31), (15, 11), (39, 5), (19, 4), (8, 6), (25, 32), (15, 38), (3, 15), (9, 12), (4, 22), (25, 4), (28, 35), (34, 40), (30, 33), (4, 21), (15, 12), (19, 9), (17, 33)]
number of edges originally= 80


In [12]:
"""Code testing"""
db3 = 'random2.db'
random_preprocessing = graph_towers('random',database_name = db3, graph=dgl_G,ratio=0.7,bottom_level = 5, max_dimension = 20)
random_preprocessing.create_table()
#random_preprocessing.empty_table('edge_details')
#random_preprocessing.empty_table('graph')
#random_preprocessing.empty_table('node_classes')
#random_preprocessing.empty_table('simplices')
random_preprocessing.create_towers()
random_preprocessing.simplicial_search()

+------------+-------------+-----------+---------------------+---------+---------+---------+---------+---------+
| simplex_id | quotient_id | dimension | number_of_simplices | index_0 | index_1 | index_2 | index_3 | index_4 |
+------------+-------------+-----------+---------------------+---------+---------+---------+---------+---------+
|     1      |      0      |     0     |         124         |    0    |         |         |         |         |
|     2      |      0      |     0     |         124         |    1    |         |         |         |         |
|     3      |      0      |     0     |         124         |    2    |         |         |         |         |
|     4      |      0      |     0     |         124         |    3    |         |         |         |         |
|     5      |      0      |     0     |         124         |    4    |         |         |         |         |
|     6      |      0      |     0     |         124         |    5    |         |         |    

plot sparseness vs time

build correspondence between graph + subgraph in our context
imporve lift
investigate slow down!
make proper sql queries
send AOPS