In [None]:
#List of imports

import networkx as nx
import numpy as np
import pickle
from pathlib import Path
import os
import torch
import time
import dgl
import WLColorRefinement as wl
import CSL_data
import Molecules_data


In [None]:
"""
    Download MOLECULES datasets
"""
# if not os.path.isfile('molecules.zip'):
#     print('downloading..')
#     !curl https://www.dropbox.com/s/feo9qle74kg48gy/molecules.zip?dl=1 -o molecules.zip -J -L -k
# else:
#     print('File already downloaded')

In [None]:
"""
    Get CSL graphs and send to color function
"""
# CSL_graphs = CSL_data.get_CSL_graphs()

"""
    Get ZINC graphs
"""
# Load the ZINC molecule dataset; the `.train` split of this object is what
# the following cells draw and send to the coloring function.
ZINC_graphs = Molecules_data.MoleculeDatasetDGL('ZINC')

In [None]:
"""
    Send the CSL graphs to the coloring function
"""
# CSL_graphs_colored = list()

# for i in CSL_graphs:
#     CSL_graphs_colored.append(wl.wl_coloring(i))

"""
    Send the ZINC graphs to the coloring function
"""
# Draw the first training graph as loaded, before the coloring function is run.
# Element 0 of a training sample is the object that exposes to_networkx()/ndata.
print('Original Graph', '\n')
x = ZINC_graphs.train[0][0]
nx_x = x.to_networkx()
pos = nx.kamada_kawai_layout(nx_x)
nx.draw(nx_x, pos, with_labels=True, node_color=x.ndata['feat'])

# Run the WL coloring on element 0 of every training sample, keeping the
# results in the same order as the training split.
ZINC_graphs_colored = [wl.wl_coloring(sample[0]) for sample in ZINC_graphs.train]

In [None]:
# Draw the first entry of the colored graphs; node colors are taken from
# y.ndata['feat'].
y = ZINC_graphs_colored[0]
print('Colored Graph \n')

# For a closer look, print y, y.nodes(), y.ndata, y.edges() or y.edata here.
nx_y = y.to_networkx()
pos = nx.kamada_kawai_layout(nx_y)
nx.draw(nx_y, pos=pos, with_labels=True, node_color=y.ndata['feat'])

In [None]:
"""
    super_node(dgl.DGLgraph)
    
    Function to create a supernode
    The idea is to have a function that iterates over the graph and
    finds the nodes with the same colors and then places them in one supernode
"""

def super_node(graph):
    """
    Group the nodes of a colored graph into supernodes, one per color.

    The node colors are read from ``graph.ndata['feat']`` and the node ids
    from ``graph.nodes()``. Nodes are bucketed by color, every node gets the
    edges that start at it (via ``create_edge_list`` and ``super_edge``), and
    the result is passed on to ``prepare_DGLgraph``.

    @graph - graph object providing ``ndata['feat']``, ``nodes()`` and
             ``edges()``; colors are assumed to be one hashable value per
             node (1-D feature tensor) - TODO confirm

    Returns the value returned by ``prepare_DGLgraph``.
    """
    colors = graph.ndata['feat'].numpy()
    nodes = graph.nodes().numpy()

    # color -> {node id -> list of edges, filled in by super_edge below}
    nodes_by_color = {}
    for color, node in zip(colors, nodes):
        nodes_by_color.setdefault(color, {})[node] = []

    edge_list = create_edge_list(graph)
    nodes_by_color = super_edge(edge_list, nodes_by_color)
    node_to_color = create_mapping(nodes_by_color)
    return prepare_DGLgraph(nodes_by_color, node_to_color)

In [None]:
"""
    Create_edge_list(dgl.DGLgraph)
    
    Function to create a list with (source node, destination node)
"""

def create_edge_list(graph):
    """
    Build a list of ``[source node, destination node]`` pairs for a graph.

    @graph - object whose ``edges(form='uv', order='srcdst')`` returns a
             (sources, destinations) pair of tensors, e.g. a DGL graph

    Returns a list with one two-element list per edge, in the order in which
    ``graph.edges`` delivers them. The inner containers are real lists (not
    tuples) because later steps overwrite their entries in place.
    """
    # Query the edges once and unpack both endpoint tensors from that result.
    src_edges, dst_edges = graph.edges(form='uv', order='srcdst')
    return [[src, dst] for src, dst in zip(src_edges.numpy(), dst_edges.numpy())]

In [None]:
"""
    super_edge(list, dict)
    
    Function to find the edges between supernodes.
    
    Make sure that the edge list returned by create_edge_list is sorted by source node, from low to high.
"""

def super_edge(edges, super_node):
    """
    Attach to every original node the edges that start at that node.

    @edges      - list of ``[source node, destination node]`` pairs
    @super_node - dict ``{color: {node: list}}``; the inner lists are
                  extended in place

    Returns the same ``super_node`` dict that was passed in. The appended
    entries are the very objects from ``edges`` (no copies), kept in the
    order in which they appear in ``edges``.

    The edges are bucketed by source node in a single pass, so every node
    receives all of its edges regardless of how ``edges`` is ordered.
    """
    # source node -> edges starting there, preserving the input order
    edges_by_source = {}
    for edge in edges:
        edges_by_source.setdefault(edge[0], []).append(edge)

    for members in super_node.values():
        for node, node_edges in members.items():
            node_edges.extend(edges_by_source.get(node, ()))
    return super_node

In [None]:
"""
    create_mapping(dict)
    @super_node - the supernode with original nodes added to it and a list of original edges
    
    Create a mapping from each original node to the supernode (color) it belongs to.
"""

def create_mapping(super_node):
    """
    Invert the supernode dict into a node -> color lookup.

    @super_node - dict ``{color: {node: ...}}``

    Returns a dict ``{node: color}``. If a node is listed under several
    colors, the color that comes last in ``super_node`` wins.
    """
    return {
        node: color
        for color, members in super_node.items()
        for node in members
    }

In [None]:
"""
    prepare_DGLgraph(dict, dict)
    
    Prepare a DGLgraph with the following steps:
        1. Overwrite the original node with the color of the supernode
        2. Calculate and add the weights to the edges
           The weight of the edge is the total of edges that go from one node within the supernode 
           to another node in another supernode. If the number of edges from each node going out are not the same, 
           the smallest common value is taken.
"""

def prepare_DGLgraph(super_node, mapping):
    """
    Translate node-level edges into supernode-level edges and count them.

    Both endpoints of every edge stored in ``super_node`` are replaced by the
    color found for them in ``mapping``. For every original node the number of
    its edges per destination color is counted. Reducing those per-node
    counts to one weight per supernode edge is not done here.

    @super_node - dict ``{color: {node: [[src, dst], ...]}}``
    @mapping    - dict ``{node: color}``

    Returns a tuple ``(super_node_reformat, src_edge, dst_edge)``:
        super_node_reformat - ``{color: {node: {'edges': ..., 'weights': ...}}}``
                              where 'edges' is the node's edge list with both
                              endpoints mapped to colors and 'weights' is
                              ``{destination color: number of edges}``
        src_edge, dst_edge  - parallel lists with the source / destination
                              color of every edge, in iteration order

    ``super_node`` and the edge lists inside it are not modified; the mapped
    edges are stored in freshly created lists, so the function can be called
    again on the same input.

    Raises KeyError if an edge endpoint is missing from ``mapping``.
    """
    # Diagnostic output, kept so existing notebook output stays the same.
    print('mapping: ', mapping)
    src_edge = []
    dst_edge = []
    super_node_reformat = {}
    for color, nodes in super_node.items():
        reformatted_nodes = super_node_reformat[color] = {}

        for node, edges in nodes.items():
            mapped_edges = []
            weights = {}
            for edge in edges:
                mapping_src = mapping[edge[0]]
                mapping_dst = mapping[edge[1]]
                # Build a new pair instead of overwriting the caller's edge.
                mapped_edges.append([mapping_src, mapping_dst])
                src_edge.append(mapping_src)
                dst_edge.append(mapping_dst)
                weights[mapping_dst] = weights.get(mapping_dst, 0) + 1
            reformatted_nodes[node] = {'edges': mapped_edges, 'weights': weights}
    return super_node_reformat, src_edge, dst_edge
        
"""
    create_DGLgraph(dict)
"""

def create_DGLgraph(graph):
    """
    Build a DGL graph from the tuple returned by ``prepare_DGLgraph``.

    @graph - tuple ``(super_node_reformat, src_edge, dst_edge)``; entries 1
             and 2 supply the source / destination id of every edge, entry 0
             is only used for its length (the number of supernodes)

    Returns the graph created by ``dgl.graph`` with one edge per
    ``(src_edge[k], dst_edge[k])`` pair and ``len(graph[0])`` nodes.
    No node or edge data is attached yet.
    """
    # Fix the dtype explicitly: for an empty edge list torch.tensor([]) would
    # infer float32, which DGL is expected to reject as a node-id dtype.
    U = torch.tensor(graph[1], dtype=torch.int64)
    V = torch.tensor(graph[2], dtype=torch.int64)

    # NOTE(review): the ids in U/V are supernode colors while num_nodes is the
    # number of supernodes; this presumably requires the colors to lie in
    # 0..num_nodes-1 - confirm against the output of the coloring function.
    #
    # TODO: attach node data and edge data. The edge weight is meant to be
    # derived from the per-node 'weights' dicts in graph[0], taking the
    # smallest common count when the nodes of one supernode disagree.
    return dgl.graph((U, V), num_nodes=len(graph[0]))
    
    

In [None]:
# Collapse the first colored ZINC graph into supernodes. super_node returns
# the (super_node_reformat, src_edge, dst_edge) tuple built by prepare_DGLgraph.
newGraph = super_node(ZINC_graphs_colored[0])

In [None]:
# Build the DGL graph from the supernode tuple computed in the previous cell.
graph = create_DGLgraph(newGraph)

In [None]:
# Print the three parts of the supernode tuple: the reformatted supernode
# dict, the source list and the destination list.
for part in newGraph[:3]:
    print(part)

# List the supernode keys and compare the manual count with len().
count = 0
for count, i in enumerate(newGraph[0], start=1):
    print('node: ', i)
print('count = ', count)
print(len(newGraph[0]))

In [None]:
# Dump every supernode together with its member nodes, their mapped edges and
# their per-destination edge counts.
# A separate name is used here so that `graph`, which holds the result of
# create_DGLgraph from the earlier cell, is not overwritten with a plain dict.
super_nodes = newGraph[0]
for color, members in super_nodes.items():
    print('i: ', color)

    for node, info in members.items():
        print('j: ', node)
        print('edges: ', info['edges'])
        print('weights: ', info['weights'])
    print('\n')