In [1]:
import os
import numpy as np
import networkx as nx
import torchinfo
import torchviz

import matching.glema.common.utils as utils
from matching.glema.common.model import GLeMaNet
from matching.glema.common.model import InferenceGNN

In [2]:
args = utils.parse_args( use_default=True )
model_ckpt = f"{args.ckpt_dir}/CPG_best_with_pivot_emb/best_model.pt"
args = utils.load_args( args, model_ckpt )
args.ckpt = model_ckpt
args.directed = True

In [3]:
def gen_example_graph( size, args ):
    G = utils.generate_graph( size, directed=args.directed )
    for id, data in G.nodes( data=True ):
        data[ "label" ] = 1
        data[ "pivot" ] = 1 if id == 0 else 0
    return G

In [4]:
max_n = 55
batch_size = args.batch_size

In [5]:
def get_example_input( args, batch_size, max_n ):
    inf_model = InferenceGNN( args )
    input_queries = [ gen_example_graph( int( max_n * 0.3 ), args ) for i in range( batch_size ) ]
    input_sources = [ gen_example_graph( int( max_n * 0.7 ), args ) for i in range( batch_size ) ]
    list_inputs = inf_model.prepare_multi_input( input_queries, input_sources )
    return inf_model.input_to_tensor( list_inputs )

In [6]:
input_shapes = [
    (batch_size, max_n, args.embedding_dim * 2),
    (batch_size, max_n, max_n),
    (batch_size, max_n, max_n),
    (batch_size, max_n)
]
input_shapes

[(128, 55, 12), (128, 55, 55), (128, 55, 55), (128, 55)]

In [7]:
input_tensors = get_example_input( args, batch_size, max_n )
input_shapes = utils.get_shape_of_tensors( input_tensors )
input_shapes

Init default model ...


[(128, 54, 12), (128, 54, 54), (128, 54, 54), (128, 54)]

In [8]:
model = utils.initialize_model( GLeMaNet( args ), utils.get_device() )
model.eval()

Init default model ...


GLeMaNet(
  (gconv1): ModuleList(
    (0-3): 4 x GLeMa(
      (W_h): Linear(in_features=140, out_features=140, bias=True)
      (W_beta): Linear(in_features=280, out_features=1, bias=True)
    )
  )
  (FC): ModuleList(
    (0): Linear(in_features=140, out_features=128, bias=True)
    (1-2): 2 x Linear(in_features=128, out_features=128, bias=True)
    (3): Linear(in_features=128, out_features=1, bias=True)
  )
  (embede): Linear(in_features=12, out_features=140, bias=False)
)

In [9]:
output_tensor = model( input_tensors )
output_tensor

tensor([0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147,
        0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147,
        0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147,
        0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147,
        0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147,
        0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147,
        0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147,
        0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147,
        0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147,
        0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147,
        0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147,
        0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147, 0.5147,
        0.5147, 0.5147, 0.5147, 0.5147, 

In [10]:
torchinfo.summary( model,
                   #input_size=input_shapes,
                   #input_data=input_tensors,
                   device=utils.get_device() )

Layer (type:depth-idx)                   Param #
GLeMaNet                                 --
├─ModuleList: 1-1                        --
│    └─GLeMa: 2-1                        19,600
│    │    └─Linear: 3-1                  19,740
│    │    └─Linear: 3-2                  281
│    └─GLeMa: 2-2                        19,600
│    │    └─Linear: 3-3                  19,740
│    │    └─Linear: 3-4                  281
│    └─GLeMa: 2-3                        19,600
│    │    └─Linear: 3-5                  19,740
│    │    └─Linear: 3-6                  281
│    └─GLeMa: 2-4                        19,600
│    │    └─Linear: 3-7                  19,740
│    │    └─Linear: 3-8                  281
├─ModuleList: 1-2                        --
│    └─Linear: 2-5                       18,048
│    └─Linear: 2-6                       16,512
│    └─Linear: 2-7                       16,512
│    └─Linear: 2-8                       129
├─Linear: 1-3                            1,680
Total params: 211,3

In [11]:
if False:
    torchviz.make_dot(
        output_tensor,
        params=dict( list( model.named_parameters() ) )
    ).render( f"model_graph/{args.dataset}", format="png" )