In [1]:
import os
import numpy as np
import networkx as nx
import torchinfo
import torchviz
import torch

import matching.glema.common.utils.arg_utils as arg_utils
import matching.glema.common.utils.graph_utils as graph_utils
import matching.glema.common.utils.model_utils as model_utils
from matching.glema.common.model import GLeMaNet
from matching.glema.common.model import InferenceGNN

In [2]:
args = arg_utils.parse_args( use_default=True )

args.dataset = "CPG_augm_large"
args.directed = False
args.anchored = True
version = model_utils.get_latest_model_version( args )
model_name = model_utils.get_model_name( args, version )

args = arg_utils.load_args( args, model_name )
args.iso = True
args.test_data = True

In [3]:
def gen_example_graph( size, args ):
    G = graph_utils.generate_graph( size, directed=args.directed )
    for id, data in G.nodes( data=True ):
        data[ "label" ] = 1
        data[ "anchor" ] = 1 if id == 0 else 0
    return G

In [4]:
max_n = 55
batch_size = args.batch_size

In [5]:
def get_example_input( args, batch_size, max_n ):
    inf_model = InferenceGNN( args )
    input_queries = [ gen_example_graph( int( max_n * 0.3 ), args ) for i in range( batch_size ) ]
    input_sources = [ gen_example_graph( int( max_n * 0.7 ), args ) for i in range( batch_size ) ]
    list_inputs = inf_model.prepare_multi_input( input_queries, input_sources )
    return inf_model.input_to_tensor( list_inputs )

In [6]:
input_shapes = [
    (batch_size, max_n, args.embedding_dim * 2),
    (batch_size, max_n, max_n),
    (batch_size, max_n, max_n),
    (batch_size, max_n)
]
input_shapes

[(128, 55, 12), (128, 55, 55), (128, 55, 55), (128, 55)]

In [7]:
input_tensors = get_example_input( args, batch_size, max_n )
input_shapes = model_utils.get_shape_of_tensors( input_tensors )
input_shapes

Loading model from /Users/jeanjour/Documents/projects/python/dpd-subgraph-matching/matching/glema/training/save/CPG_augm_large_undirected_anchored_v1/model.pt ...


[(128, 54, 12), (128, 54, 54), (128, 54, 54), (128, 54)]

In [8]:
model = model_utils.initialize_model( GLeMaNet( args ), model_utils.get_device() )
model.eval()

Init default model ...


GLeMaNet(
  (gconv1): ModuleList(
    (0-3): 4 x GLeMa(
      (W_h): Linear(in_features=140, out_features=140, bias=True)
      (W_beta): Linear(in_features=280, out_features=1, bias=True)
    )
  )
  (FC): ModuleList(
    (0): Linear(in_features=140, out_features=128, bias=True)
    (1-2): 2 x Linear(in_features=128, out_features=128, bias=True)
    (3): Linear(in_features=128, out_features=1, bias=True)
  )
  (embede): Linear(in_features=12, out_features=140, bias=False)
)

In [9]:
output_tensor = model( input_tensors )
output_tensor

tensor([0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110,
        0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110,
        0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110,
        0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110,
        0.5110, 0.5110, 0.5110, 0.5110, 0.5111, 0.5110, 0.5110, 0.5110, 0.5110,
        0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110,
        0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110,
        0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110,
        0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110,
        0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110,
        0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5111, 0.5110, 0.5110,
        0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110, 0.5110,
        0.5110, 0.5110, 0.5110, 0.5110, 

In [10]:
torchinfo.summary( model,
                   #input_size=input_shapes,
                   #input_data=input_tensors,
                   device=model_utils.get_device() )

Layer (type:depth-idx)                   Param #
GLeMaNet                                 --
├─ModuleList: 1-1                        --
│    └─GLeMa: 2-1                        19,600
│    │    └─Linear: 3-1                  19,740
│    │    └─Linear: 3-2                  281
│    └─GLeMa: 2-2                        19,600
│    │    └─Linear: 3-3                  19,740
│    │    └─Linear: 3-4                  281
│    └─GLeMa: 2-3                        19,600
│    │    └─Linear: 3-5                  19,740
│    │    └─Linear: 3-6                  281
│    └─GLeMa: 2-4                        19,600
│    │    └─Linear: 3-7                  19,740
│    │    └─Linear: 3-8                  281
├─ModuleList: 1-2                        --
│    └─Linear: 2-5                       18,048
│    └─Linear: 2-6                       16,512
│    └─Linear: 2-7                       16,512
│    └─Linear: 2-8                       129
├─Linear: 1-3                            1,680
Total params: 211,3

In [16]:
if True:
    torchviz.make_dot(
        output_tensor,
        params=dict( list( model.named_parameters() ) )
    ).render( f"model_graph/{args.dataset}", format="png" )

In [12]:
#torch.onnx.export(model, list( input_tensors ), f"model_graph/{args.dataset}.onnx", dynamo=False )