In [None]:
import os
import pickle
from typing import Any

import matplotlib.patches as patches
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

import matching.glema.common.utils.arg_utils as arg_utils
import matching.glema.common.utils.graph_utils as graph_utils
import matching.glema.common.utils.io_utils as io_utils
import matching.glema.common.utils.model_utils as model_utils
import matching.glema.common.utils.plot_utils as plot_utils
import matching.misc.cpg_const as cpg_const
from matching.glema.common.model import InferenceGNN
from matching.glema.common.utils.plot_utils import ColorScheme

In [None]:
# Start from the default arguments, then pick the trained model to evaluate.
args = arg_utils.parse_args( use_default=True )

args.dataset = "CPG_augm_large"
# Alternative dataset: args.dataset = "dpdf"
args.directed = True
args.anchored = True
# The model lookup receives args, so the flags above must be set before these calls.
version = model_utils.get_latest_model_version( args )
model_name = model_utils.get_model_name( args, version )
# NOTE(review): result_dir is derived from args *before* load_args replaces them below —
# confirm the loaded args do not carry a different result_dir.
result_dir = os.path.join( args.result_dir, model_name )

# Replace args with those returned for this model (presumably its saved training args — confirm),
# then switch evaluation to iso mode on test data.
args = arg_utils.load_args( args, model_name )
args.iso = True
args.test_data = True

In [None]:
# Design patterns shown in the matching grids, in display order, mapped to readable names.
# Currently disabled: Abstract Factory, Facade, Memento, Prototype, Proxy, Visitor and
# cpg_const.NO_DESIGN_PATTERN ("None") — add them to the table below to include them.
DESIGN_PATTERN_MAPPING = {
    pattern.value: label
    for pattern, label in (
        (cpg_const.DesignPatternType.ADAPTER, "Adapter"),
        (cpg_const.DesignPatternType.BUILDER, "Builder"),
        (cpg_const.DesignPatternType.FACTORY_METHOD, "Factory Method"),
        (cpg_const.DesignPatternType.OBSERVER, "Observer"),
        (cpg_const.DesignPatternType.SINGLETON, "Singleton"),
        (cpg_const.DesignPatternType.DECORATOR, "Decorator"),
    )
}

# Match code -> plot color (2: anchor, 1: iso match, 0: plain node, -1: non-iso match).
matching_colors = {
    2: ColorScheme.HIGHLIGHT,
    1: ColorScheme.SECONDARY,
    0: ColorScheme.GREY_LIGHT,
    -1: ColorScheme.SECONDARY_COMP
}

# Legend entries (color -> label) in the order they should appear.
color_legend = {
    matching_colors[ code ]: label
    for code, label in ((2, "Anchor"), (1, "Iso"), (-1, "Non-Iso"), (0, "Node"))
}

# Placeholder pattern type for grid rows/columns that show a single graph without a partner.
NONE_TYPE = " "

In [None]:
# Example dicts for this model, keyed by (source_type, query_type).
matching_examples_file = io_utils.get_abs_file_path( os.path.join( result_dir, "matching_examples.pkl" ) )
# SECURITY: pickle.load can execute arbitrary code — only load result files produced by this project.
with open( matching_examples_file, 'rb' ) as handle:
    # typing.Any (not the builtin function `any`) is the "any value" annotation.
    matching_examples: dict[ tuple[ str, str ], dict[ str, Any ] ] = pickle.load( handle )
matching_examples.keys()

In [None]:
# Inspect one (source, query) pattern pair in detail.
source_type = cpg_const.DesignPatternType.FACTORY_METHOD.value
query_type = cpg_const.DesignPatternType.ADAPTER.value
selected_example = matching_examples[ (source_type, query_type) ]
source, query = selected_example[ "source" ], selected_example[ "query" ]

# Each input graph on its own, with its default node colors and labels.
for graph, graph_title in ( (source, "Source"), (query, "Query") ):
    plot_utils.plot_graph( graph, title=graph_title,
                           nodeColors=graph_utils.get_node_colors( graph ),
                           nodeLabels=graph_utils.get_node_labels( graph ) )

# Per-node match values: >= 1 counts as iso, < 0 as non-iso.
combined, node_matches, edge_matches = graph_utils.combine_normalized( source, query )
node_iso = [ code for code in node_matches if code >= 1 ]
node_non_iso = [ code for code in node_matches if code < 0 ]
matched_count = len( node_iso ) + len( node_non_iso )
print( f"Source nodes: {source.number_of_nodes()}" )
print( f"Query nodes: {query.number_of_nodes()}" )
print( f"Combined nodes: {combined.number_of_nodes()}" )
print( f"Matching {matched_count} -- Iso: {len( node_iso )} / Non-Iso: {len( node_non_iso )}" )

# Same combination again, this time colored through matching_colors for plotting.
combined, node_colors, edge_colors = graph_utils.combine_normalized( source, query,
                                                                     matching_colors=matching_colors )
plot_utils.plot_graph( combined, title="Combined",
                       nodeColors=node_colors, edgeColors=edge_colors,
                       nodeLabels=graph_utils.get_node_labels( combined ) )

In [None]:
# Only the entries loaded from the pickle are real (source, query) pairs carrying "source"/"query"
# graphs. The single-graph entries this cell adds use NONE_TYPE for the missing half of the key and
# have no such graphs, so both passes below iterate only the real pairs. That keeps the cell safe
# to re-run.
pair_keys = [ key for key in matching_examples if NONE_TYPE not in key ]

# Combine every source/query pair into one colored graph for the matching grid.
for source_type, query_type in pair_keys:
    example = matching_examples[ (source_type, query_type) ]
    matched, node_colors, edge_colors = graph_utils.combine_normalized( example[ "source" ],
                                                                        example[ "query" ],
                                                                        matching_colors=matching_colors )
    example[ "matched" ] = matched
    example[ "node_colors" ] = node_colors
    example[ "edge_colors" ] = edge_colors


def _single_graph_example( graph ) -> dict:
    """Grid entry showing ``graph`` alone, colored with the anchor/plain-node colors."""
    return {
        "matched": graph,
        "node_colors": graph_utils.get_node_colors( graph,
                                                    anchor_color=matching_colors[ 2 ],
                                                    node_color=matching_colors[ 0 ] ),
        "edge_colors": matching_colors[ 0 ]
    }


# One entry per individual graph: (source_type, NONE_TYPE) shows a source graph alone and
# (NONE_TYPE, query_type) a query graph alone. The first pair seen for a type supplies its graph.
raw_type_examples = { }
for source_type, query_type in pair_keys:
    example = matching_examples[ (source_type, query_type) ]
    if (source_type, NONE_TYPE) not in raw_type_examples:
        raw_type_examples[ (source_type, NONE_TYPE) ] = _single_graph_example( example[ "source" ] )
    if (NONE_TYPE, query_type) not in raw_type_examples:
        raw_type_examples[ (NONE_TYPE, query_type) ] = _single_graph_example( example[ "query" ] )

# dict.update avoids a loop variable named `type` shadowing the builtin for the rest of the notebook.
matching_examples.update( raw_type_examples )

In [None]:
def plot_graph_matching( graph_types: list[ str ],
                         graph_matching_examples: dict,
                         color_legend: dict[ str, str ] = None,
                         font_size=6,
                         save_name=None,
                         label_mapping: dict[ str, str ] = None,
                         save_dir="plots" ):
    """Draw an (n+1) x (n+1) grid of matched graphs, one cell per (source type, query type) pair.

    Rows are ``graph_types`` followed by a ``NONE_TYPE`` row (query graph alone). Columns are a
    ``NONE_TYPE`` column (source graph alone) followed by ``graph_types``. Cells whose key is
    missing from ``graph_matching_examples`` are left blank.

    :param graph_types: pattern type values laid out along both axes.
    :param graph_matching_examples: maps ``(source_type, query_type)`` to a dict with a
        ``"matched"`` graph and optional ``"pred"``, ``"node_colors"`` and ``"edge_colors"``.
    :param color_legend: optional ``{color: label}`` legend, anchored to the bottom-left grid
        cell (normally empty).
    :param font_size: base font size; axis and legend labels are scaled from it.
    :param save_name: if given, save to ``<save_dir>/<save_name>.png`` instead of ``plt.show()``.
    :param label_mapping: ``{type value: display name}`` for row/column labels; defaults to
        ``DESIGN_PATTERN_MAPPING``. Types missing from it get no label.
    :param save_dir: directory for saved figures; created if missing.
    """
    if label_mapping is None:
        label_mapping = DESIGN_PATTERN_MAPPING

    n = len( graph_types ) + 1
    # squeeze=False always yields a 2D array of axes, including the 1x1 case.
    fig, axes = plt.subplots( n, n, figsize=(4 * n, 4 * n), squeeze=False )

    x_graph_type = [ NONE_TYPE, *graph_types ]
    y_graph_type = [ *graph_types, NONE_TYPE ]

    for i, source in enumerate( y_graph_type ):
        for j, query in enumerate( x_graph_type ):
            ax = axes[ i ][ j ]
            key = (source, query)
            if key in graph_matching_examples:
                example = graph_matching_examples[ key ]
                matching_pred = example.get( "pred", None )
                title = f"p={matching_pred:.4f}" if matching_pred is not None else ""
                ax = plot_utils.plot_graph( example[ "matched" ], font_size=font_size,
                                            ax=ax, title=title, show_title=True, with_label=False,
                                            node_sizes=70, edge_width=1.8,
                                            nodeColors=example.get( "node_colors", None ),
                                            edgeColors=example.get( "edge_colors", None ) )
                # Frame each plotted cell. Diagonal (same pattern) cells get a 3x thicker border;
                # single-graph cells (involving NONE_TYPE) get a dashed one.
                border_width = 1.5 * 3 if source == query else 1.5
                line_style = "--" if NONE_TYPE in (source, query) else None
                # Rectangle in axes coordinates, i.e. exactly the cell's extent (0,0)-(1,1).
                ax.add_patch( patches.Rectangle( (0, 0), 1, 1, transform=ax.transAxes,
                                                 fill=False, edgecolor="black",
                                                 linewidth=border_width, linestyle=line_style ) )
            else:
                ax.axis( "off" )
            # No tick labels, for a cleaner grid.
            ax.set_xticks( [ ] )
            ax.set_yticks( [ ] )

    # Leave room for the outer labels. The label positions read below depend on this layout.
    fig.subplots_adjust( left=0.1, bottom=0.1, top=0.95, right=0.95 )

    # Query (column) labels, centered slightly below the bottom row.
    for j, query in enumerate( x_graph_type ):
        if query not in label_mapping:
            continue
        pos = axes[ -1 ][ j ].get_position()
        fig.text( (pos.x0 + pos.x1) / 2.0, pos.y0 - 0.02, label_mapping[ query ],
                  ha="center", va="top", fontsize=font_size * 1.8 )

    # Source (row) labels, centered slightly left of the first column.
    for i, source in enumerate( y_graph_type ):
        if source not in label_mapping:
            continue
        pos = axes[ i ][ 0 ].get_position()
        fig.text( pos.x0 - 0.02, (pos.y0 + pos.y1) / 2.0, label_mapping[ source ],
                  ha="right", va="center", fontsize=font_size * 1.8, rotation='vertical' )

    fig.supxlabel( "Query Graphs", fontsize=font_size * 1.9 )
    fig.supylabel( "Source Graphs", fontsize=font_size * 1.9 )

    if color_legend is not None:
        legend_elements = [ Patch( facecolor=color, edgecolor='black', label=label )
                            for color, label in color_legend.items() ]
        legend_pos = axes[ n - 1 ][ 0 ].get_position()
        fig.legend( handles=legend_elements, bbox_to_anchor=legend_pos, fontsize=font_size * 1.3 )

    if save_name is None:
        plt.show()
    else:
        # savefig does not create missing directories.
        os.makedirs( save_dir, exist_ok=True )
        fig.savefig( os.path.join( save_dir, f"{save_name}.png" ) )


# Full grid over every enabled design pattern.
plot_graph_matching( list( DESIGN_PATTERN_MAPPING ),
                     matching_examples,
                     font_size=16,
                     color_legend=color_legend,
                     save_name="matching_examples" )

In [None]:
# The same grid split into two smaller figures: patterns from index 3 on, then the first three.
pattern_keys = list( DESIGN_PATTERN_MAPPING )
for pattern_subset, suffix in ( (pattern_keys[ 3: ], "3_1"), (pattern_keys[ :3 ], "3_2") ):
    plot_graph_matching( pattern_subset,
                         matching_examples, font_size=14,
                         color_legend=color_legend,
                         save_name=f"matching_examples_{suffix}" )

In [None]:
# Inference model built from the args loaded above; passed to get_interaction_mapping in the
# node-prediction figures at the end of the notebook.
model = InferenceGNN( args )

In [None]:
def predict( model, G_source, G_query ) -> float:
    """Return the score that ``model`` assigns to ``G_query`` being a subgraph of ``G_source``.

    ``model.predict`` returns ``(score, (label, extra))``; only the score is kept.
    """
    score, (_, _) = model.predict( G_source, G_query )
    return score


def is_subgraph( model: InferenceGNN, G_source, G_query, conf=0.5 ) -> bool:
    """Report and return whether ``model`` labels ``G_query`` as a subgraph of ``G_source``.

    The predicted label (first element of the second tuple returned by ``model.predict``) must
    equal 1.0. The score is printed alongside the decision.
    """
    score, (label, _) = model.predict( G_source, G_query, conf=conf )
    found = label == 1.0
    print( f"query is subgraph of source: {found} [{score:.3}]" )
    return found

In [None]:
def _classify_source_node( n_source, ground_truth, pred_mapping ) -> int:
    """Return the match code (-1, 0, 1 or 2) of one source node; see ``get_interaction_mapping``."""
    if n_source not in ground_truth:
        # Not part of the true embedding: any prediction for it is a wrong match.
        return -1 if n_source in pred_mapping else 0
    if n_source not in pred_mapping:
        # Part of the true embedding but not predicted. Deliberately reported as a plain match (1)
        # rather than a miss (-1).
        return 1
    # Exact match when the true query partner is among the predicted partners.
    return 2 if ground_truth[ n_source ] in pred_mapping[ n_source ] else 1


def get_interaction_mapping( args, model, source_idx, query_idx, color_mapping=None, conf=0.5 ):
    """Compare the model's predicted node interactions with the ground-truth node mapping.

    Each source node gets a code:

    *  2 -- in the ground truth, and its true query node is among the predicted partners
    *  1 -- in the ground truth, but unpredicted or predicted only to other query nodes
    *  0 -- not in the ground truth and not predicted
    * -1 -- not in the ground truth, yet predicted

    Each edge gets the smaller code of its two endpoints.

    :param args: arguments forwarded to the ``graph_utils`` loaders.
    :param model: model forwarded to ``graph_utils.compute_interactions``.
    :param source_idx: index of the source graph.
    :param query_idx: index of the query subgraph of that source graph.
    :param color_mapping: optional ``{code: color}``; when given, colors are returned instead of codes.
    :param conf: threshold forwarded to ``graph_utils.compute_interactions``.
    :returns: ``(source, node_values, edge_values)``. Both value lists follow ``source.nodes`` /
        ``source.edges`` order, with one entry per edge (parallel edges included).
    """
    source = graph_utils.load_source_graph( args, source_idx, relabel=False )
    query = graph_utils.load_query_graph( args, source_idx, query_idx, relabel=False )
    # Keyed by source node; the value is the true query node for that source node.
    ground_truth = graph_utils.load_query_id_mapping( args, source_idx, query_idx, flip=False )
    interactions = graph_utils.compute_interactions( model, source, query, threshold=conf )

    # Interaction keys are (query node, source node); invert them into
    # source node -> list of predicted query nodes.
    pred_mapping = { }
    for pred_query, pred_source in interactions:
        pred_mapping.setdefault( pred_source, [ ] ).append( pred_query )

    node_codes = { n: _classify_source_node( n, ground_truth, pred_mapping ) for n in source.nodes }
    # A list (not a dict keyed by (u, v)) keeps parallel edges of a multigraph aligned with source.edges.
    edge_codes = [ min( node_codes[ u ], node_codes[ v ] ) for u, v in source.edges ]

    if color_mapping is None:
        return source, list( node_codes.values() ), edge_codes
    return (source,
            [ color_mapping[ code ] for code in node_codes.values() ],
            [ color_mapping[ code ] for code in edge_codes ])


def get_ground_truth( args, source_graph_idx, query_subgraph_idx, color_mapping=None ):
    """Mark which source nodes and edges belong to the ground-truth embedding of a query subgraph.

    Nodes present in the ground-truth mapping get code 1, all others 0. An edge gets the smaller
    code of its endpoints, so it is 1 only when both endpoints are matched.

    :param args: arguments forwarded to the ``graph_utils`` loaders.
    :param source_graph_idx: index of the source graph.
    :param query_subgraph_idx: index of the query subgraph of that source graph.
    :param color_mapping: optional ``{code: color}``; when given, colors are returned instead of codes.
    :returns: ``(source, node_values, edge_values)``. Both value lists follow ``source.nodes`` /
        ``source.edges`` order, with one entry per edge (parallel edges included).
    """
    source = graph_utils.load_source_graph( args, source_graph_idx, relabel=False )
    # Keyed by source node id; values are the matching query node ids.
    ground_truth = graph_utils.load_query_id_mapping( args, source_graph_idx, query_subgraph_idx, flip=False )

    node_codes = { n: 1 if n in ground_truth else 0 for n in source.nodes }
    # A list (not a dict keyed by (u, v)) keeps parallel edges of a multigraph aligned with source.edges.
    edge_codes = [ min( node_codes[ u ], node_codes[ v ] ) for u, v in source.edges ]

    if color_mapping is None:
        return source, list( node_codes.values() ), edge_codes
    return (source,
            [ color_mapping[ code ] for code in node_codes.values() ],
            [ color_mapping[ code ] for code in edge_codes ])


# Node-level figures for one fixed source/query pair, rendered once with args.iso=True and once
# with args.iso=False. The settings below are shared by both runs.
source_graph_idx = 0
query_subgraph_idx = 5
figsize = (12, 12)
font_size = 15
edge_width = 2
node_size = 700
conf = 0.6  # interaction threshold passed to get_interaction_mapping

# Keyword arguments common to every plot_graph call below.
plot_style = dict( show_title=False, figsize=figsize, font_size=font_size,
                   node_sizes=node_size, edge_width=edge_width )
anchor_legend = {
    matching_colors[ 0 ]: "Node",
    matching_colors[ 2 ]: "Anchor"
}

# The loop mutates the shared args object (args.iso is presumably read by the graph_utils
# loaders — confirm). Restore the previous value afterwards so later cells and re-runs don't
# silently inherit the last loop state.
iso_before = args.iso
try:
    for iso in [ True, False ]:
        args.iso = iso
        save_prefix = "node_pred_iso" if args.iso else "node_pred_non_iso"

        # Source and query on their own: plain nodes vs. anchor.
        source = graph_utils.load_source_graph( args, source_graph_idx )
        query = graph_utils.load_query_graph( args, source_graph_idx, query_subgraph_idx )
        for graph, suffix in ( (source, "source"), (query, "query") ):
            plot_utils.plot_graph( graph, save_name=f"plots/{save_prefix}_{suffix}.png",
                                   color_legend=anchor_legend,
                                   nodeColors=graph_utils.get_node_colors( graph,
                                                                           node_color=matching_colors[ 0 ],
                                                                           anchor_color=matching_colors[ 2 ] ),
                                   **plot_style )

        # Query overlaid on the source, starting from the source's anchor.
        combined, combined_node_colors, combined_edge_colors = graph_utils.combine_graph(
            source, query, anchor=graph_utils.get_anchor( source ), matching_colors=matching_colors )
        plot_utils.plot_graph( combined, save_name=f"plots/{save_prefix}_combined.png",
                               color_legend={
                                   matching_colors[ 0 ]: "Node",
                                   matching_colors[ 1 ]: "Match",
                                   matching_colors[ 2 ]: "Anchor",
                                   matching_colors[ -1 ]: "No Match",
                               },
                               nodeColors=combined_node_colors, edgeColors=combined_edge_colors,
                               with_label=False, **plot_style )

        # Ground-truth embedding of the query inside the source.
        gt_source, gt_node_colors, gt_edge_colors = get_ground_truth( args, source_graph_idx, query_subgraph_idx,
                                                                      color_mapping=matching_colors )
        plot_utils.plot_graph( gt_source, save_name=f"plots/{save_prefix}_gt.png",
                               color_legend={
                                   matching_colors[ 0 ]: "Node",
                                   matching_colors[ 1 ]: "Match"
                               },
                               nodeColors=gt_node_colors, edgeColors=gt_edge_colors,
                               with_label=False, **plot_style )

        # Model predictions compared against the ground truth.
        source, node_colors, edge_colors = get_interaction_mapping( args, model,
                                                                    source_graph_idx, query_subgraph_idx,
                                                                    color_mapping=matching_colors, conf=conf )
        plot_utils.plot_graph( source, save_name=f"plots/{save_prefix}_p.png",
                               color_legend={
                                   matching_colors[ 0 ]: "Node",
                                   matching_colors[ 1 ]: "Match",
                                   matching_colors[ 2 ]: "Exact Match",
                                   matching_colors[ -1 ]: "No Match",
                               },
                               nodeColors=node_colors, edgeColors=edge_colors,
                               with_label=False, **plot_style )
finally:
    args.iso = iso_before