In [None]:
import os
import numpy as np
import networkx as nx
import torchinfo
import torchviz

import matching.glema.common.utils as utils
from matching.glema.common.model import GLeMaNet
from matching.glema.common.model import InferenceGNN

In [None]:
# Start from the default CLI arguments, then overlay what utils.load_args
# restores for the chosen checkpoint (merge semantics live in that helper).
args = utils.parse_args( use_default=True )
model_ckpt = "{}/SYNTHETIC_TINY_jump_directed_30e/best_model.pt".format( args.ckpt_dir )
args = utils.load_args( args, model_ckpt )
# Assigned after load_args so these two values win over anything it restored.
args.ckpt, args.directed = model_ckpt, True

In [None]:
def gen_example_graph( size, args ):
    """Generate an example graph and give every node ``label = 1``.

    Parameters
    ----------
    size : int
        Size argument forwarded to ``utils.generate_graph`` (presumably the
        node count -- confirm in utils).
    args :
        Run arguments; only ``args.directed`` is read here.

    Returns
    -------
    The graph produced by ``utils.generate_graph``, with each node's
    ``"label"`` attribute set to 1 in place.
    """
    graph = utils.generate_graph( size, directed=args.directed )
    # nodes(data=True) yields (node, attribute-dict) pairs; writing into those
    # dicts updates the graph itself, which is what the caller receives.
    node_attr_dicts = ( attrs for _node, attrs in graph.nodes( data=True ) )
    for attrs in node_attr_dicts:
        attrs.update( label=1 )
    return graph

In [None]:
# Example-batch sizing: the batch size comes from the loaded run arguments,
# and max_n is the reference node count the example graphs are sized from.
batch_size = args.batch_size
max_n = 55  # NOTE(review): magic number; its origin is not documented here

In [None]:
def get_example_input( args, batch_size, max_n, query_frac=0.3, source_frac=0.7 ):
    """Build one batch of example model inputs from generated query/source graphs.

    Parameters
    ----------
    args :
        Run arguments; passed to ``InferenceGNN`` and to ``gen_example_graph``.
    batch_size : int
        Number of (query, source) graph pairs to generate.
    max_n : int
        Reference node count; query and source sizes are fractions of it.
    query_frac, source_frac : float, optional
        Fractions of ``max_n`` used as the query / source graph size
        (truncated with ``int``). The defaults 0.3 / 0.7 reproduce the
        original hard-coded split. NOTE(review): presumably they should sum
        to at most 1 so a pair fits within ``max_n`` -- confirm.

    Returns
    -------
    Whatever ``InferenceGNN.input_to_tensor`` produces for the prepared pairs.
    """
    inf_model = InferenceGNN( args )
    # Sizes are identical for every pair, so compute them once, not per item.
    query_size = int( max_n * query_frac )
    source_size = int( max_n * source_frac )
    # NOTE(review): no seed is set here; if generate_graph is random, the
    # example inputs will differ between runs.
    input_queries = [ gen_example_graph( query_size, args ) for _ in range( batch_size ) ]
    input_sources = [ gen_example_graph( source_size, args ) for _ in range( batch_size ) ]
    list_inputs = inf_model.prepare_multi_input( input_queries, input_sources )
    return inf_model.input_to_tensor( list_inputs )

In [None]:
# Hand-derived expectation of the four model input shapes. The roles in the
# trailing comments are guesses from the shapes alone -- TODO confirm against
# GLeMaNet / InferenceGNN.input_to_tensor.
# NOTE(review): input_shapes is reassigned from the real tensors in a later
# cell, so this list only serves as a visual reference.
input_shapes = (
    [ (batch_size, max_n, args.embedding_dim * 2) ]  # per-node features, 2 * embedding_dim wide?
    + [ (batch_size, max_n, max_n) ] * 2             # two max_n x max_n matrices (adjacency?)
    + [ (batch_size, max_n) ]                        # per-node vector (mask?)
)
input_shapes

In [None]:
# Materialise one example batch and record the shapes it actually has.
input_tensors = get_example_input( args=args, batch_size=batch_size, max_n=max_n )
input_shapes = utils.get_shape_of_tensors( input_tensors )
input_shapes

In [None]:
# Construct GLeMaNet from the loaded args and hand it, with the target
# device, to utils.initialize_model; then switch to eval mode. The cell
# displays the value returned by eval().
model = utils.initialize_model(
    GLeMaNet( args ),
    utils.get_device(),
)
model.eval()

In [None]:
# Run one forward pass on the example batch. Gradient tracking is not disabled
# here; the torchviz cell below relies on that to trace output_tensor's graph.
# NOTE(review): assumes input_tensors already sit on the model's device --
# confirm what InferenceGNN.input_to_tensor does.
output_tensor = model( input_tensors )
output_tensor

In [None]:
# Layer-by-layer summary of the model. Supplying the example batch lets
# torchinfo run a real forward pass and report per-layer output shapes;
# with neither input_data nor input_size it only lists modules, and the
# device argument has no effect.
# The batch is wrapped in a one-element list on purpose: torchinfo unpacks a
# list/tuple input_data into positional forward() arguments, while this model
# takes the whole batch as a single argument (see `model( input_tensors )`).
torchinfo.summary( model,
                   input_data=[ input_tensors ],
                   device=utils.get_device() )

In [None]:
# Draw the autograd graph behind output_tensor, labelling parameter nodes with
# their qualified names. graphviz writes the DOT source to
# model_graph/<dataset> and the PNG next to it; the cell shows the PNG path.
torchviz.make_dot( output_tensor, params=dict( model.named_parameters() ) ).render(
    f"model_graph/{args.dataset}",
    format="png",
)