In [None]:
import os

import numpy as np
import torchinfo
import networkx as nx

import matching.glema.common.utils as utils
from matching.glema.common.model import GLeMaNet
from matching.glema.common.model import InferenceGNN

In [None]:
# Build the run configuration: start from the project defaults, then pass them
# through utils.load_args together with the checkpoint path (presumably this
# restores the arguments the model was trained with -- TODO confirm).
args = utils.parse_args( use_default=True )
# Other checkpoint names that have been used with this notebook:
# SYNTHETIC_TINY_jump_directed_30e
# KKI_jump_directed_promising
model_ckpt = "training/save/KKI_jump_directed_promising/best_model.pt"
args = utils.load_args( args, model_ckpt )
args.ckpt = model_ckpt

# Manual overrides applied after load_args, so they take precedence over
# whatever load_args returned.
args.data_path = "data/data_real/datasets"
# Alternative dataset: "SYNTHETIC_TINY_train"
args.dataset = "KKI_test"
# args.source is used below both as the sample directory name and as the key
# of the source graph; args.query is the key of the query subgraph.
args.source = 49
args.query = 8
# args.iso selects the "iso_subgraphs*" files instead of "noniso_subgraphs*".
args.iso = True
args.directed = True
# Cut-off applied to the interaction matrix in plot_matching.
args.mapping_threshold = 0.5
args.tag = ""
# Bare expression: display the final configuration as the cell output.
args

In [None]:
# Resolve the dataset directory (<data_path>/<dataset>) through the project's
# path helper and show it so a wrong working directory is easy to spot.
data_path = utils.get_abs_file_path( os.path.join( args.data_path, args.dataset ) )
print( "data_path:", data_path )

In [None]:
# Display a torchinfo summary of a GLeMaNet instance built from `args`.
# This instance is only used for the summary; it is not kept in a variable.
torchinfo.summary( GLeMaNet( args ) )

In [None]:
# Inference wrapper used by every prediction below (predict_label /
# predict_embedding). Presumably it loads the weights from args.ckpt -- TODO confirm.
model = InferenceGNN( args )

In [None]:
def get_node_labels( G ):
    """Return the ``'label'`` node attribute of ``G`` as a ``{node: label}`` dict.

    Thin wrapper around ``networkx.get_node_attributes``; the result is used
    as the ``nodeLabels`` argument of ``utils.plot_graph``.
    """
    label_by_node = nx.get_node_attributes( G, 'label' )
    return label_by_node

In [None]:
# Load the query subgraphs that belong to the selected source sample.
# The file is "iso_subgraphs.lg" when args.iso is set, otherwise "noniso_subgraphs.lg".
# The directedness follows the configured args.directed flag instead of a
# hard-coded literal so the loaded graphs always agree with the configuration.
subgraphs = utils.read_graphs(
    f"{data_path}/{args.source}/{'non' if not args.iso else ''}iso_subgraphs.lg",
    directed=args.directed
)
# Pick the query subgraph selected in the configuration cell.
subgraph = subgraphs[ args.query ]
print( "subgraph exists", subgraph is not None )
print( "subgraph nodes", subgraph.number_of_nodes() )

In [None]:
# Visualise the selected query subgraph with its node labels.
utils.plot_graph( subgraph, nodeLabels=get_node_labels( subgraph ) )

In [None]:
# Fix the seed before sampling so the sampled subgraph is reproducible.
utils.set_seed( 8 )
# Sample a smaller subgraph from the query; the second argument (6) is
# presumably the number of nodes to keep -- TODO confirm against utils.random_subgraph.
sub_subgraph = utils.random_subgraph( subgraph, 6 )
utils.plot_graph( sub_subgraph, nodeLabels=get_node_labels( sub_subgraph ) )

In [None]:
# Load the source graph(s) of the selected sample. The directedness follows the
# configured args.directed flag instead of a hard-coded literal so the loaded
# graph always agrees with the configuration.
graphs = utils.read_graphs( f"{data_path}/{args.source}/source.lg", directed=args.directed )
# The source graph is looked up under the same id that names its directory.
graph = graphs[ args.source ]
print( "graph exists", graph is not None )
print( "graph nodes", graph.number_of_nodes() )

In [None]:
# Visualise the source graph with its own node labels. The labels must be
# taken from `graph`; labels read from the query subgraph would be keyed by
# the query's node ids and mislabel the source nodes.
utils.plot_graph( graph, nodeLabels=get_node_labels( graph ) )

In [None]:
# Load the ground-truth node mappings that accompany the subgraph file
# ("iso_subgraphs_mapping.lg" or "noniso_subgraphs_mapping.lg", chosen by args.iso).
mapping_gts = utils.read_mapping(
    f"{data_path}/{args.source}/{'non' if not args.iso else ''}iso_subgraphs_mapping.lg"
)
# Ground-truth mapping for the selected query; plot_matching reads it as
# {source node: query node or -1}.
mapping_gt = mapping_gts[ args.query ]
print( mapping_gt )

In [None]:
def is_subgraph( graph, query, conf=0.5 ) -> bool:
    """Decide whether the model predicts ``query`` to be a subgraph of ``graph``.

    Args:
        graph: source graph passed to ``model.predict_label``.
        query: candidate subgraph passed to ``model.predict_label``.
        conf: threshold the first prediction must strictly exceed.

    Returns:
        True when the prediction for the (query, graph) pair is greater than ``conf``.
    """
    # The model works on batches, so wrap the single pair in one-element lists.
    scores = model.predict_label( [ query ], [ graph ] )
    exceeds_threshold = scores[ 0 ] > conf
    return exceeds_threshold.item()

In [None]:
# Classify the selected (source, query) pair. args.confidence is not set in
# this notebook, so it must come from parse_args/load_args -- TODO confirm.
print( "result", is_subgraph( graph, subgraph, args.confidence ) )

In [None]:
# Count how many of the loaded query subgraphs the model accepts as subgraphs
# of the source graph. The check is delegated to the is_subgraph() helper;
# the result must not be stored in a variable named `is_subgraph`, because
# that would rebind the function name to a bool and break every later call.
subgraph_count = 0
for query in subgraphs.values():
    if is_subgraph( graph, query, args.confidence ):
        subgraph_count += 1

print( f"Are subgraphs: {subgraph_count}/{len( subgraphs )}" )

In [None]:
def plot_matching( graph, subgraph, mapping_gt ):
    """Plot the source graph, the query graph and the predicted node matching.

    Three figures are drawn through ``utils.plot_graph``: the source graph,
    the query graph, and the source graph again with nodes/edges coloured
    according to how the predicted matching compares with the ground truth.

    Reads the notebook globals ``model`` (for ``predict_embedding``) and
    ``args.mapping_threshold``.

    Args:
        graph: source graph (its ``nodes`` and ``edges`` are iterated).
        subgraph: query graph.
        mapping_gt: ground-truth mapping indexed by source node; the value is
            the matching query node, or -1 when the source node is not part of
            the query.

    Node colours in the "Matching" plot:
        gray  - no prediction and not part of the ground truth
        gold  - no prediction although the ground truth maps the node
        lime  - at least one predicted query node equals the ground truth
        pink  - predicted, but no predicted query node equals the ground truth
    """
    # Prefix every label with the node id ("id: label") so nodes can be
    # cross-referenced between the plots.
    source_labels = get_node_labels( graph )
    source_labels = { key: f"{key}: {value}" for key, value in source_labels.items() }
    subgraph_labels = get_node_labels( subgraph )
    subgraph_labels = { key: f"{key}: {value}" for key, value in subgraph_labels.items() }
    utils.plot_graph( graph, title="Source", nodeLabels=source_labels )
    utils.plot_graph( subgraph, title="Query", nodeLabels=subgraph_labels )

    # NOTE(review): constant condition -- the block below always runs.
    if True:
        interactions = model.predict_embedding( [ subgraph ], [ graph ] )
        # print("interactions", interactions[0])
        interactions = interactions[ 0 ].cpu().detach().numpy()
        n_subgraph_atom = subgraph.number_of_nodes()
        # Index pairs whose interaction score exceeds the mapping threshold.
        x_coord, y_coord = np.where( interactions > args.mapping_threshold )

        print( "Embedding: (subgraph node, graph node)" )
        # (query index, source index) -> interaction score.
        # Indices below n_subgraph_atom are treated as query nodes, the
        # remaining ones as source nodes shifted by n_subgraph_atom.
        interaction_dict = { }
        for x, y in zip( x_coord, y_coord ):
            # Row is a query node, column is a source node.
            if x < n_subgraph_atom <= y:
                interaction_dict[ (x, y - n_subgraph_atom) ] = interactions[ x ][ y ]
                # print("(", x, y-n_ligand_atom, ")")

            # Mirrored entry (row is a source node, column is a query node);
            # only used when the pair has not been recorded yet.
            if (
                    x >= n_subgraph_atom > y
                    and (y, x - n_subgraph_atom) not in interaction_dict
            ):
                interaction_dict[ (y, x - n_subgraph_atom) ] = interactions[ x ][ y ]
                # print("(", y, x-n_ligand_atom, ")")

        # For every query node keep the source node(s) with the highest score
        # (ties are all kept); query nodes without a candidate get an empty list.
        # NOTE(review): compares node ids with matrix indices, so this assumes
        # query nodes are numbered 0..n-1 -- confirm.
        list_mapping = list( interaction_dict.keys() )
        mapping_dict = { }
        for node in subgraph.nodes:
            cnode_mapping = list(
                map(
                    lambda y: (y[ 1 ], interaction_dict[ y ]),
                    filter( lambda x: x[ 0 ] == node, list_mapping ),
                )
            )
            if len( cnode_mapping ) == 0:
                mapping_dict[ node ] = [ ]
                continue

            max_prob = max( cnode_mapping, key=lambda x: x[ 1 ] )[ 1 ]
            mapping_dict[ node ] = list(
                map( lambda x: x[ 0 ], filter( lambda y: y[ 1 ] == max_prob, cnode_mapping ) )
            )

        print( "Mapping:", mapping_dict )

        # Invert the mapping: each source node is labelled with the
        # comma-separated query node(s) predicted for it ("" when none).
        node_labels = { n: "" for n in graph.nodes }
        for sgn, list_gn in mapping_dict.items():
            for gn in list_gn:
                if len( node_labels[ gn ] ) == 0:
                    node_labels[ gn ] = str( sgn )
                else:
                    node_labels[ gn ] += ",%d" % sgn

        # Colour source nodes by comparing the prediction with the ground truth.
        node_colors = { n: "gray" for n in graph.nodes }
        for node, nmaping in node_labels.items():
            if not nmaping:
                # No prediction: gold if the ground truth expected a match.
                if mapping_gt[ node ] != -1:
                    node_colors[ node ] = "gold"
                continue

            list_nm = nmaping.split( "," )
            for nm in list_nm:
                # Any predicted query node matching the ground truth wins.
                if mapping_gt[ node ] == int( nm ):
                    node_colors[ node ] = "lime"
                    break

                if mapping_gt[ node ] != int( nm ) and node_colors[ node ] == "gray":
                    node_colors[ node ] = "pink"

        # Source nodes the model missed are labelled with their ground-truth
        # query node so that gold nodes carry a label in the plot.
        for gn, sgn in mapping_gt.items():
            if node_labels[ gn ] == "" and sgn != -1:
                node_labels[ gn ] = str( sgn )

        # Colour source edges; edges touching an unmatched (gray) node keep
        # the default colour.
        edge_colors = { n: "whitesmoke" for n in graph.edges }
        for edge in graph.edges:
            n1, n2 = edge
            # map node from graph to node in subgraph
            n1_sgs, n2_sgs = node_labels[ n1 ], node_labels[ n2 ]

            if node_colors[ n1 ] == "gray" or node_colors[ n2 ] == "gray":
                continue

            # Check whether there is a link between n1 and n2 in the subgraph.
            # count_pair counts the label pairs that are NOT connected in the
            # query (in either direction); total_pair is the number of pairs.
            total_pair = len( n1_sgs.split( "," ) ) * len( n2_sgs.split( "," ) )
            count_pair = 0
            for n1_sg in n1_sgs.split( "," ):
                n1_sg = int( n1_sg )
                for n2_sg in n2_sgs.split( "," ):
                    n2_sg = int( n2_sg )
                    if (n1_sg, n2_sg) not in subgraph.edges and (
                            n2_sg,
                            n1_sg,
                    ) not in subgraph.edges:
                        count_pair += 1

            if count_pair != total_pair:
                # At least one label pair is an edge of the query.
                if node_colors[ n1 ] == "lime" and node_colors[ n2 ] == "lime":
                    edge_colors[ edge ] = "black"
                elif node_colors[ n1 ] == "gold" or node_colors[ n2 ] == "gold":
                    edge_colors[ edge ] = "goldenrod"
                elif node_colors[ n1 ] == "pink" or node_colors[ n2 ] == "pink":
                    edge_colors[ edge ] = "palevioletred"
            else:
                # No label pair is an edge of the query: only edges touching a
                # wrongly matched (pink) node are highlighted.
                if node_colors[ n1 ] == "pink" or node_colors[ n2 ] == "pink":
                    edge_colors[ edge ] = "palevioletred"

        # Colours are passed as lists in the iteration order of graph.nodes /
        # graph.edges, which is the order the dicts were built in.
        utils.plot_graph( graph,
                          nodeLabels=node_labels,
                          nodeColors=list( node_colors.values() ),
                          edgeColors=list( edge_colors.values() ),
                          title="Matching",
                          with_label=True )

In [None]:
# Demo: visualise the matching for one (source, query) pair.
utils.set_seed( 8 )
# NOTE(review): `graphs` was read from the directory named after args.source,
# so this literal presumably has to stay equal to args.source -- confirm.
source_graph_idx = 49
query_graph_idx = 10

graph = graphs[ source_graph_idx ]
subgraph = subgraphs[ query_graph_idx ]
# Optional perturbations of the query for robustness experiments:
#subgraph = utils.inject_edge_errors( subgraph, 12 )
#subgraph = utils.random_subgraph( subgraph, 6 )
mapping_gt = mapping_gts[ query_graph_idx ]

plot_matching( graph, subgraph, mapping_gt )
# Bare last expression: shows the classification at a stricter 0.9 threshold.
# Requires the name `is_subgraph` to still refer to the helper function.
is_subgraph( graph, subgraph, 0.9 )