In [None]:
!pip install torch_geometric

# Simulated GCN (graph-level) Model

In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool

class gcn_graph_model(torch.nn.Module):
    """Three-layer GCN for graph-level classification, instrumented for export.

    ``forward`` returns every intermediate activation in a dict so that each
    one becomes a separately named ONNX output; dict insertion order is the
    output order.

    Args:
        hidden_channels (int): Width of each of the three GCN layers.
        in_channels (int, optional): Size of each input node feature vector.
            (default: 5)
        out_channels (int, optional): Number of logits produced per graph.
            (default: 2)
    """

    def __init__(self, hidden_channels, in_channels=5, out_channels=2):
        super(gcn_graph_model, self).__init__()
        # Reseed the global RNG so every instance starts from identical weights
        # (note: this also resets torch's global RNG state as a side effect).
        torch.manual_seed(12345)
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        """Run the network and collect every intermediate activation.

        Args:
            x (Tensor): Node features, presumably ``[num_nodes, in_channels]``.
            edge_index (Tensor): Graph connectivity passed to each ``GCNConv``.
            batch (Tensor): Graph id of every node, used by ``global_mean_pool``.

        Returns:
            dict[str, Tensor]: ``conv1``, ``conv2``, ``conv3`` (node embeddings),
            ``pooling`` (one vector per graph), ``dropout`` and ``final``
            (per-graph logits of width ``out_channels``).
        """
        outputs = {}
        # Ensure index tensors are int64 (a no-op when they already are).
        edge_index = edge_index.to(torch.int64)
        batch = batch.to(torch.int64)

        # 1. Node embeddings; conv3 deliberately has no activation.
        x = self.conv1(x, edge_index)
        x = x.relu()
        outputs['conv1'] = x
        x = self.conv2(x, edge_index)
        x = x.relu()
        outputs['conv2'] = x
        x = self.conv3(x, edge_index)
        outputs['conv3'] = x

        # 2. Readout: average node embeddings per graph -> [num_graphs, hidden_channels].
        x = global_mean_pool(x, batch)
        outputs["pooling"] = x

        # 3. Classifier. Dropout is only active while self.training is True.
        x = F.dropout(x, p=0.5, training=self.training)
        outputs["dropout"] = x
        x = self.lin(x)
        outputs['final'] = x
        return outputs

# Simulated GCN (node-level) Model

In [27]:
class gcn_node_model(torch.nn.Module):
    """Three-layer GCN for node-level classification, instrumented for export.

    ``forward`` returns every intermediate activation in a dict so that each
    one becomes a separately named ONNX output; dict insertion order is the
    output order.

    Args:
        hidden_channels (int): Width of each of the three GCN layers.
        in_channels (int, optional): Size of each input node feature vector.
            (default: 5)
        out_channels (int, optional): Number of logits produced per node.
            (default: 2)
    """

    def __init__(self, hidden_channels, in_channels=5, out_channels=2):
        super(gcn_node_model, self).__init__()
        # Reseed the global RNG so every instance starts from identical weights.
        torch.manual_seed(12345)
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        """Compute per-node logits and collect intermediate activations.

        Args:
            x (Tensor): Node features, presumably ``[num_nodes, in_channels]``.
            edge_index (Tensor): Graph connectivity passed to each ``GCNConv``.
            batch (Tensor): Unused; accepted so this model shares the
                ``(x, edge_index, batch)`` call signature used for export.

        Returns:
            dict[str, Tensor]: ``conv1``, ``conv2``, ``conv3`` (tanh-activated
            node embeddings) and ``final`` (per-node logits).
        """
        outputs = {}
        h = self.conv1(x, edge_index).tanh()
        outputs["conv1"] = h
        h = self.conv2(h, edge_index).tanh()
        outputs["conv2"] = h
        # Final GNN embedding space.
        h = self.conv3(h, edge_index).tanh()
        outputs["conv3"] = h
        # Linear classifier applied to every node independently.
        out = self.lin(h)
        outputs["final"] = out
        return outputs

# Simulated GCN (edge-level) Model

In [28]:
import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv

class gcn_edge_model(torch.nn.Module):
    """Three-layer GCN with an inner-product edge decoder (link prediction).

    ``forward`` returns every intermediate tensor in a dict so that each one
    becomes a separately named ONNX output; dict insertion order is the
    output order.

    NOTE(review): ``encode`` applies relu, relu, tanh, whereas ``forward``
    applies tanh after all three layers, so ``encode`` does not reproduce the
    embeddings of the exported graph. Confirm which activation set is intended.

    Args:
        hidden_channels (int): Width of each of the three GCN layers.
        in_channels (int, optional): Size of each input node feature vector.
            (default: 5)
        out_channels (int, optional): Output size of ``lin``. (default: 2)
    """

    def __init__(self, hidden_channels, in_channels=5, out_channels=2):
        super(gcn_edge_model, self).__init__()
        # Reseed the global RNG so every instance starts from identical weights.
        torch.manual_seed(12345)
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        # Optional head for binary label prediction. No method here calls it,
        # but its parameters are part of the state_dict.
        self.lin = Linear(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        """Return node embeddings using relu, relu, tanh activations."""
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = self.conv3(x, edge_index).tanh()
        return x

    def decode(self, z, edge_label_index):
        """Score each candidate edge by the inner product of its endpoints.

        Returns a 1-D tensor with one score per column of ``edge_label_index``.
        """
        x = z[edge_label_index[0]] * z[edge_label_index[1]]
        return x.sum(dim=-1)

    def decode_all(self, z):
        """Return ``[2, K]`` indices of every node pair with a positive inner product."""
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()

    def forward(self, x, edge_index, edge_label_index, batch):
        """Embed nodes, score candidate edges and collect intermediates.

        Args:
            x (Tensor): Node features, presumably ``[num_nodes, in_channels]``.
            edge_index (Tensor): Message-passing connectivity.
            edge_label_index (Tensor): ``[2, num_candidates]`` node pairs to score.
            batch (Tensor): Unused; kept for a uniform export signature.

        Returns:
            dict[str, Tensor]:
                - ``conv1``..``conv3``: tanh-activated embeddings.
                - ``decode_mul``: elementwise endpoint products.
                - ``decode_sum``: one score per candidate edge.
                - ``prob_adj``: raw pairwise inner products (no sigmoid is
                  applied).
                - ``decode_all_final``: ``[2, K]`` indices of the positive
                  entries of ``prob_adj``. These include the diagonal pair
                  (i, i) for every node whose embedding is non-zero.
        """
        outputs = {}
        h = self.conv1(x, edge_index).tanh()
        outputs["conv1"] = h
        h = self.conv2(h, edge_index).tanh()
        outputs["conv2"] = h
        h = self.conv3(h, edge_index).tanh()
        outputs["conv3"] = h

        # Inner-product decoder, split so both stages are exported.
        decode_mul = h[edge_label_index[0]] * h[edge_label_index[1]]
        decode_sum = decode_mul.sum(dim=-1)
        outputs["decode_mul"] = decode_mul
        outputs["decode_sum"] = decode_sum

        prob_adj = h @ h.t()
        outputs["prob_adj"] = prob_adj
        outputs["decode_all_final"] = (prob_adj > 0).nonzero(as_tuple=False).t()
        return outputs

## Simulated GAT (edge-level) Model

In [29]:
from torch_geometric.nn import GATConv

class gat_edge_model(torch.nn.Module):
    """Three single-head GAT layers with an inner-product edge decoder.

    Each ``GATConv`` uses ``heads=1, concat=False`` so every layer outputs
    ``hidden_channels`` features. ``forward`` returns every intermediate tensor
    in a dict so that each one becomes a separately named ONNX output; dict
    insertion order is the output order.

    NOTE(review): ``encode`` applies relu, relu, tanh, whereas ``forward``
    applies tanh after all three layers, so ``encode`` does not reproduce the
    embeddings of the exported graph. Confirm which activation set is intended.

    Args:
        hidden_channels (int): Width of each of the three GAT layers.
        in_channels (int, optional): Size of each input node feature vector.
            (default: 5)
        out_channels (int, optional): Output size of ``lin``. (default: 2)
    """

    def __init__(self, hidden_channels, in_channels=5, out_channels=2):
        super(gat_edge_model, self).__init__()
        # Reseed the global RNG so every instance starts from identical weights.
        torch.manual_seed(12345)
        self.conv1 = GATConv(in_channels, hidden_channels, heads=1, concat=False)
        self.conv2 = GATConv(hidden_channels, hidden_channels, heads=1, concat=False)
        self.conv3 = GATConv(hidden_channels, hidden_channels, heads=1, concat=False)
        # No method here calls this head, but its parameters are in the state_dict.
        self.lin = Linear(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        """Return node embeddings using relu, relu, tanh activations."""
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = self.conv3(x, edge_index).tanh()
        return x

    def decode(self, z, edge_label_index):
        """Score each candidate edge by the inner product of its endpoints."""
        x = z[edge_label_index[0]] * z[edge_label_index[1]]
        return x.sum(dim=-1)

    def decode_all(self, z):
        """Return ``[2, K]`` indices of every node pair with a positive inner product."""
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()

    def forward(self, x, edge_index, edge_label_index, batch):
        """Embed nodes, score candidate edges and collect intermediates.

        Args:
            x (Tensor): Node features, presumably ``[num_nodes, in_channels]``.
            edge_index (Tensor): Message-passing connectivity.
            edge_label_index (Tensor): ``[2, num_candidates]`` node pairs to score.
            batch (Tensor): Unused; kept for a uniform export signature.

        Returns:
            dict[str, Tensor]: ``conv1``..``conv3`` (tanh embeddings),
            ``decode_mul``, ``decode_sum``, ``prob_adj`` (raw pairwise inner
            products) and ``decode_all_final`` (``[2, K]`` positive-score pairs).
        """
        outputs = {}
        h = self.conv1(x, edge_index).tanh()
        outputs["conv1"] = h
        h = self.conv2(h, edge_index).tanh()
        outputs["conv2"] = h
        h = self.conv3(h, edge_index).tanh()
        outputs["conv3"] = h

        # Inner-product decoder, split so both stages are exported.
        decode_mul = h[edge_label_index[0]] * h[edge_label_index[1]]
        decode_sum = decode_mul.sum(dim=-1)
        outputs["decode_mul"] = decode_mul
        outputs["decode_sum"] = decode_sum

        prob_adj = h @ h.t()
        outputs["prob_adj"] = prob_adj
        outputs["decode_all_final"] = (prob_adj > 0).nonzero(as_tuple=False).t()
        return outputs

## Customized SAGEConv Layer

In [41]:
import torch
import torch.nn.functional as F
from torch.nn import Parameter
# from graphSAGE.CustomizedMessgPassing import CustomizedMessagePassing
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.nn.conv import MessagePassing
import math

def uniform(size, tensor):
    """Fill ``tensor`` in place from U(-1/sqrt(size), 1/sqrt(size)).

    The bound is always computed first, so an invalid ``size`` raises even
    when ``tensor`` is None. A None ``tensor`` is otherwise silently skipped.
    Returns None.
    """
    limit = 1.0 / math.sqrt(size)
    if tensor is None:
        return
    tensor.data.uniform_(-limit, limit)

class testClass:
    """Empty placeholder class; it is not referenced in the visible notebook."""

    def __init__(self):
        """Take no arguments and store no state."""


class ConvSAGE(MessagePassing):
    r"""The GraphSAGE operator from the `"Inductive Representation Learning on
    Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper, with mean
    aggregation.

    With :obj:`concat=True` (the default), each node's own features are
    concatenated with the mean of its neighbours' (optionally edge-weighted)
    features before the linear transform:

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{\Theta} \cdot \left[ \mathbf{x}_i
        \,\Vert\, \mathrm{mean}_{j \in \mathcal{N}(i)} \, w_{j,i}
        \mathbf{x}_j \right] + \mathbf{b}

    With :obj:`concat=False` and a single feature tensor, remaining self-loops
    (weight 1) are added instead, so the mean runs over
    :math:`\mathcal{N}(i) \cup \{ i \}`:

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{\Theta} \cdot
        \mathrm{mean}_{j \in \mathcal{N}(i) \cup \{ i \}} \, w_{j,i}
        \mathbf{x}_j + \mathbf{b}

    :math:`w_{j,i} = 1` when no :obj:`edge_weight` is given. If
    :obj:`normalize=True`, the result is additionally divided by its
    :math:`\ell_2` norm.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        normalize (bool, optional): If set to :obj:`True`, output features
            will be :math:`\ell_2`-normalized. (default: :obj:`False`)
        concat (bool, optional): If set to :obj:`True`, will concatenate
            current node features with aggregated ones. (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    def __init__(self, in_channels, out_channels, normalize=False,
                 concat=True, bias=True, **kwargs):
        super(ConvSAGE, self).__init__(aggr='mean', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.concat = concat

        # Concatenating [x_i || mean] doubles the input width of the weight.
        in_channels = 2 * in_channels if concat else in_channels
        # Allocated uninitialised; reset_parameters() fills it.
        self.weight = Parameter(torch.empty(in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw weight and bias from U(-1/sqrt(fan_in), 1/sqrt(fan_in))."""
        # fan_in is the (possibly doubled) row count of the weight matrix.
        uniform(self.weight.size(0), self.weight)
        uniform(self.weight.size(0), self.bias)

    def forward(self, x, edge_index, edge_weight=None, size=None,
                res_n_id=None):
        """
        Args:
            x (Tensor or tuple/list): Node features. For a tuple/list input
                with :obj:`concat=True`, ``x[0]`` is indexed with
                :obj:`res_n_id` to obtain the central node features.
            edge_index (Tensor): Graph connectivity.
            edge_weight (Tensor, optional): Per-edge scalar weights that are
                multiplied into each message. (default: :obj:`None`)
            size (tuple, optional): Forwarded unchanged to message passing.
            res_n_id (Tensor, optional): Residual node indices coming from
                :obj:`DataFlow` generated by :obj:`NeighborSampler` are used to
                select central node features in :obj:`x`.
                Required if operating in a bipartite graph and :obj:`concat` is
                :obj:`True`. (default: :obj:`None`)
        """
        if not self.concat and torch.is_tensor(x):
            # Without concatenation a node only sees itself through a self-loop.
            edge_index, edge_weight = add_remaining_self_loops(
                edge_index, edge_weight, 1, x.size(0))

        return self.propagate(edge_index, size=size, x=x,
                              edge_weight=edge_weight, res_n_id=res_n_id)

    def message(self, x_j, edge_weight):
        # Scale each neighbour message by its edge weight, when weights exist.
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def update(self, aggr_out, x, res_n_id):
        # Optionally prepend the central node's own features to the mean.
        if self.concat and torch.is_tensor(x):
            aggr_out = torch.cat([x, aggr_out], dim=-1)
        elif self.concat and isinstance(x, (tuple, list)):
            assert res_n_id is not None, \
                "res_n_id is required for tuple/list inputs when concat=True"
            aggr_out = torch.cat([x[0][res_n_id], aggr_out], dim=-1)

        aggr_out = torch.matmul(aggr_out, self.weight)

        if self.bias is not None:
            aggr_out = aggr_out + self.bias

        if self.normalize:
            aggr_out = F.normalize(aggr_out, p=2, dim=-1)

        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)

## Simulated GraphSAGE (edge-level) Model

In [39]:
from torch_geometric.nn import SAGEConv

class sage_edge_model(torch.nn.Module):
    """Three ``ConvSAGE`` layers with an inner-product edge decoder.

    The layers use ``ConvSAGE``'s defaults (``concat=True``,
    ``normalize=False``). ``forward`` returns every intermediate tensor in a
    dict so that each one becomes a separately named ONNX output; dict
    insertion order is the output order.

    NOTE(review): ``encode`` applies relu, relu, tanh, whereas ``forward``
    applies tanh after all three layers, so ``encode`` does not reproduce the
    embeddings of the exported graph. Confirm which activation set is intended.

    Args:
        hidden_channels (int): Width of each of the three SAGE layers.
        in_channels (int, optional): Size of each input node feature vector.
            (default: 5)
        out_channels (int, optional): Output size of ``lin``. (default: 2)
    """

    def __init__(self, hidden_channels, in_channels=5, out_channels=2):
        super(sage_edge_model, self).__init__()
        # Reseed the global RNG so every instance starts from identical weights.
        torch.manual_seed(12345)
        self.conv1 = ConvSAGE(in_channels, hidden_channels)
        self.conv2 = ConvSAGE(hidden_channels, hidden_channels)
        self.conv3 = ConvSAGE(hidden_channels, hidden_channels)
        # No method here calls this head, but its parameters are in the state_dict.
        self.lin = Linear(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        """Return node embeddings using relu, relu, tanh activations."""
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = self.conv3(x, edge_index).tanh()
        return x

    def decode(self, z, edge_label_index):
        """Score each candidate edge by the inner product of its endpoints."""
        x = z[edge_label_index[0]] * z[edge_label_index[1]]
        return x.sum(dim=-1)

    def decode_all(self, z):
        """Return ``[2, K]`` indices of every node pair with a positive inner product."""
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()

    def forward(self, x, edge_index, edge_label_index, batch):
        """Embed nodes, score candidate edges and collect intermediates.

        Args:
            x (Tensor): Node features, presumably ``[num_nodes, in_channels]``.
            edge_index (Tensor): Message-passing connectivity.
            edge_label_index (Tensor): ``[2, num_candidates]`` node pairs to score.
            batch (Tensor): Unused; kept for a uniform export signature.

        Returns:
            dict[str, Tensor]: ``conv1``..``conv3`` (tanh embeddings),
            ``decode_mul``, ``decode_sum``, ``prob_adj`` (raw pairwise inner
            products) and ``decode_all_final`` (``[2, K]`` positive-score pairs).
        """
        outputs = {}
        h = self.conv1(x, edge_index).tanh()
        outputs["conv1"] = h
        h = self.conv2(h, edge_index).tanh()
        outputs["conv2"] = h
        h = self.conv3(h, edge_index).tanh()
        outputs["conv3"] = h

        # Inner-product decoder, split so both stages are exported.
        decode_mul = h[edge_label_index[0]] * h[edge_label_index[1]]
        decode_sum = decode_mul.sum(dim=-1)
        outputs["decode_mul"] = decode_mul
        outputs["decode_sum"] = decode_sum

        prob_adj = h @ h.t()
        outputs["prob_adj"] = prob_adj
        outputs["decode_all_final"] = (prob_adj > 0).nonzero(as_tuple=False).t()
        return outputs

# Model Building

In [31]:
import numpy as np
import json 

# Load the test graph stored as JSON; this cell reads its "x" and "edge_index" entries.
with open("/kaggle/input/graphs/testing_graph.json", "r") as graph_file:
    testing_graph = json.load(graph_file)

# Node feature matrix, presumably [num_nodes, feature_dim] (the models are built for 5 features).
x = torch.tensor(testing_graph["x"], dtype=torch.float32)
# Edge index as 64-bit integers (torch.int64 is the same dtype as torch.long).
edge_index = torch.tensor(testing_graph["edge_index"], dtype=torch.int64)

In [42]:
# Shared example inputs for every export below.
x = torch.tensor(testing_graph['x'], dtype=torch.float32)
edge_index = torch.tensor(testing_graph['edge_index'], dtype=torch.int64)
batch = torch.tensor(testing_graph['batch'], dtype=torch.int64)

# ONNX output names per model family. Each list must follow the insertion
# order of the keys in the matching model's `forward` output dict.
gcn_graph_model_names = ['conv1', 'conv2', 'conv3', 'pooling', 'dropout', 'final']
gcn_node_model_names = ['conv1', 'conv2', 'conv3', 'final']
gcn_edge_model_names = ['conv1', 'conv2', 'conv3', 'decode_mul', 'decode_sum', 'prob_adj', 'decode_all_final']

# Each constructor calls torch.manual_seed(12345) itself, so weights are reproducible.
gcn_graph_model_var = gcn_graph_model(hidden_channels=16)
gcn_node_model_var = gcn_node_model(hidden_channels=16)
gcn_edge_model_var = gcn_edge_model(hidden_channels=16)

gat_edge_model_var = gat_edge_model(hidden_channels=16)
sage_edge_model_var = sage_edge_model(hidden_channels=16)

# The four lists below are consumed in lockstep (index i describes one model).
model_name_arrays = [
    gcn_graph_model_names,
    gcn_node_model_names,
    gcn_edge_model_names,
    gcn_edge_model_names,
    gcn_edge_model_names
]

models = [
    gcn_graph_model_var,
    gcn_node_model_var,
    gcn_edge_model_var,
    gat_edge_model_var,
    sage_edge_model_var
]

model_names = [
    'gcn_graph_model',
    'gcn_node_model',
    'gcn_edge_model',
    'gat_edge_model',
    'sage_edge_model'
]

dummy_input = (x, edge_index, batch)

# The smoke test scores the message-passing edges themselves. Clone so that
# edge_index and edge_label_index are distinct tensor objects: when the same
# tensor is passed in two argument positions, the tracer used by the ONNX
# exporter may bind both uses to a single graph input.
edge_label_index = edge_index.clone()
dummy_edge_input = (x, edge_index, edge_label_index, batch)

input_arrays = [
    dummy_input,
    dummy_input,
    dummy_edge_input,
    dummy_edge_input,
    dummy_edge_input
]

assert len(models) == len(model_names) == len(model_name_arrays) == len(input_arrays), \
    "model lists are out of sync"

def model_building(input_template, output_names, model, model_name,
                   input_names=None, out_dir="/kaggle/working"):
    """Dump ``model``'s weights to JSON and export the model to ONNX.

    Writes ``{out_dir}/simulated_{model_name}_weights.json`` (every state_dict
    tensor converted to nested lists) and ``{out_dir}/simulated_{model_name}.onnx``.

    Args:
        input_template (tuple): Example positional inputs used to trace ``model``.
        output_names (list[str]): ONNX output names, in the order ``forward``
            emits its outputs.
        model (torch.nn.Module): Model to export.
        model_name (str): Stem used in both output file names.
        input_names (list[str], optional): ONNX input names. Defaults to
            ``['x', 'edge_index', 'edge_label_index', 'batch']`` for 4 inputs
            and ``['x', 'edge_index', 'batch']`` otherwise, so the 4-input edge
            models no longer export ``edge_label_index`` under the name
            ``batch``.
        out_dir (str, optional): Directory for both files.
    """
    if input_names is None:
        if len(input_template) == 4:
            input_names = ['x', 'edge_index', 'edge_label_index', 'batch']
        else:
            input_names = ['x', 'edge_index', 'batch']

    state_dict = model.state_dict()
    json_state_dict = {k: v.tolist() for k, v in state_dict.items()}

    print(json_state_dict.keys())

    with open(f"{out_dir}/simulated_{model_name}_weights.json", "w") as f:
        json.dump(json_state_dict, f, indent=2)

    # dynamic_axes is keyed by input/output *name*. Declare axes only for
    # names that exist; the old 'output' key matched no output and was ignored.
    candidate_axes = {'x': {0: 'num_nodes'},
                      'edge_index': {1: 'num_edges'},
                      'edge_label_index': {1: 'num_label_edges'},
                      'batch': {0: 'num_nodes'}}
    dynamic_axes = {k: v for k, v in candidate_axes.items() if k in input_names}

    torch.onnx.export(model,                 # model being run
                      input_template,        # example inputs used for tracing
                      f"{out_dir}/simulated_{model_name}.onnx",  # where to save the model
                      export_params=True,    # store the parameter weights inside the model file
                      opset_version=17,      # the ONNX opset to export to
                      input_names=input_names,
                      output_names=output_names,
                      dynamic_axes=dynamic_axes)

In [43]:
# Export every model. The four lists are parallel, so zip keeps them aligned
# without hard-coding how many models there are.
for template, out_names, net, net_name in zip(input_arrays, model_name_arrays, models, model_names):
    model_building(template, out_names, net, net_name)

dict_keys(['conv1.bias', 'conv1.lin.weight', 'conv2.bias', 'conv2.lin.weight', 'conv3.bias', 'conv3.lin.weight', 'lin.weight', 'lin.bias'])
dict_keys(['conv1.bias', 'conv1.lin.weight', 'conv2.bias', 'conv2.lin.weight', 'conv3.bias', 'conv3.lin.weight', 'lin.weight', 'lin.bias'])
dict_keys(['conv1.bias', 'conv1.lin.weight', 'conv2.bias', 'conv2.lin.weight', 'conv3.bias', 'conv3.lin.weight', 'lin.weight', 'lin.bias'])
dict_keys(['conv1.att_src', 'conv1.att_dst', 'conv1.bias', 'conv1.lin.weight', 'conv2.att_src', 'conv2.att_dst', 'conv2.bias', 'conv2.lin.weight', 'conv3.att_src', 'conv3.att_dst', 'conv3.bias', 'conv3.lin.weight', 'lin.weight', 'lin.bias'])
dict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'conv3.weight', 'conv3.bias', 'lin.weight', 'lin.bias'])


# Model Testing

In [44]:
!pip install onnxruntime

Collecting onnxruntime
  Downloading onnxruntime-1.22.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.6 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnxruntime-1.22.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (16.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.5/16.5 MB[0m [31m63.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m862.0 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m3.7 MB/s[0m eta [36m0:00

In [45]:
import onnxruntime

def test_all_onnx_models(model_names, input_arrays):
    print("Starting ONNX model testing...\n")
    
    for model_name, input_data in zip(model_names, input_arrays):
        model_path = f"/kaggle/working/simulated_{model_name}.onnx"
        print(f"Testing model: {model_name}")
        
        # 加载 ONNX 模型
        session = onnxruntime.InferenceSession(model_path, providers=['CPUExecutionProvider'])
        
        # 提取输入名称（顺序和 onnx 导出时一致）
        input_names = [inp.name for inp in session.get_inputs()]
        input_feed = {}

        # 构造输入映射：PyTorch → NumPy
        for name, tensor in zip(input_names, input_data):
            input_feed[name] = tensor.detach().cpu().numpy()

        # 提取输出名称
        output_names = [out.name for out in session.get_outputs()]

        # 执行 ONNX 推理
        outputs = session.run(output_names, input_feed)
        
        for i, output in enumerate(outputs):
            print(f"  Output {i} ({output_names[i]}): shape={output.shape}")
        
        print("-" * 50)

In [46]:
test_all_onnx_models(model_names=model_names, input_arrays=input_arrays)

Starting ONNX model testing...

Testing model: gcn_graph_model
  Output 0 (conv1): shape=(5, 16)
  Output 1 (conv2): shape=(5, 16)
  Output 2 (conv3): shape=(5, 16)
  Output 3 (pooling): shape=(1, 16)
  Output 4 (dropout): shape=(1, 16)
  Output 5 (final): shape=(1, 2)
--------------------------------------------------
Testing model: gcn_node_model
  Output 0 (conv1): shape=(5, 16)
  Output 1 (conv2): shape=(5, 16)
  Output 2 (conv3): shape=(5, 16)
  Output 3 (final): shape=(5, 2)
--------------------------------------------------
Testing model: gcn_edge_model
  Output 0 (conv1): shape=(5, 16)
  Output 1 (conv2): shape=(5, 16)
  Output 2 (conv3): shape=(5, 16)
  Output 3 (decode_mul): shape=(14, 16)
  Output 4 (decode_sum): shape=(14,)
  Output 5 (prob_adj): shape=(5, 5)
  Output 6 (decode_all_final): shape=(2, 25)
--------------------------------------------------
Testing model: gat_edge_model
  Output 0 (conv1): shape=(5, 16)
  Output 1 (conv2): shape=(5, 16)
  Output 2 (conv3): shap

[0;93m2025-07-16 18:24:33.043091310 [W:onnxruntime:, execution_frame.cc:876 VerifyOutputSizes] Expected shape from model of {-1} does not match actual shape of {5,1} for output /conv1/Identity_3_output_0[m
[0;93m2025-07-16 18:24:33.043182377 [W:onnxruntime:, execution_frame.cc:876 VerifyOutputSizes] Expected shape from model of {-1} does not match actual shape of {5,1} for output /conv1/Identity_4_output_0[m
[0;93m2025-07-16 18:24:33.043217389 [W:onnxruntime:, execution_frame.cc:876 VerifyOutputSizes] Expected shape from model of {-1} does not match actual shape of {5,1} for output /conv1/Identity_7_output_0[m
[0;93m2025-07-16 18:24:33.043557850 [W:onnxruntime:, execution_frame.cc:876 VerifyOutputSizes] Expected shape from model of {-1} does not match actual shape of {5,1} for output /conv2/Identity_3_output_0[m
[0;93m2025-07-16 18:24:33.043648961 [W:onnxruntime:, execution_frame.cc:876 VerifyOutputSizes] Expected shape from model of {-1} does not match actual shape of {5,1} f