In [21]:
import torch
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-{torchversion}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-{torchversion}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
import numpy as np
np.random.seed(0)
import torch.nn.functional as F
from torch.nn import Linear, Dropout
from torch_geometric.nn import GCNConv
!pip install onnx onnxruntime
import onnx
import onnxruntime
!pip install onnx2torch

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [22]:
# Создание модели GCN
class GCN(torch.nn.Module):
  def __init__(self, dim_in, dim_h, dim_out):
    super().__init__()
    self.gcn1 = GCNConv(dim_in, dim_h)
    self.gcn2 = GCNConv(dim_h, dim_out)
    self.optimizer = torch.optim.Adam(self.parameters(),
                                      lr=0.01,
                                      weight_decay=5e-4)

  def forward(self, x, edge_index):
    h = F.dropout(x, p=0.5, training=self.training)
    h = self.gcn1(h, edge_index)
    h = torch.relu(h)
    h = F.dropout(h, p=0.5, training=self.training)
    h = self.gcn2(h, edge_index)
    return h, F.log_softmax(h, dim=1)

  def accuracy(pred_y, y):
    return ((pred_y == y).sum() / len(y)).item()

def train(model, data):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = model.optimizer
    epochs = 200

    model.train()
    for epoch in range(epochs+1):
        # Обучение
        optimizer.zero_grad()
        _, out = model(data.x, data.edge_index)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        acc = accuracy(out[data.train_mask].argmax(dim=1), data.y[data.train_mask])
        loss.backward()
        optimizer.step()

    return model

def test(model, data):
    model.eval()
    _, out = model(data.x, data.edge_index)
    acc = accuracy(out.argmax(dim=1)[data.test_mask], data.y[data.test_mask])
    print(out)
    return acc

In [23]:
# Загрузка датасета
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root=".", name="CiteSeer")
data = dataset[0]

In [24]:
# Обучение и запуск модели
gcn = GCN(dataset.num_features, 16, dataset.num_classes)
train(gcn, data)
acc = test(gcn, data)
print(f'\Тестовая точность: {acc*100:.2f}%\n')

tensor([[-1.0811e+01, -1.1549e+01, -9.4690e+00, -1.6295e-04, -1.0123e+01,
         -1.1056e+01],
        [-6.7831e+00, -4.4258e-03, -6.1391e+00, -6.9418e+00, -8.8053e+00,
         -1.1521e+01],
        [-9.6899e+00, -1.9580e+01, -1.2986e+01, -8.9659e+00, -1.2185e+01,
         -1.9703e-04],
        ...,
        [-6.5827e+00, -6.1197e+00, -3.6955e+00, -3.0511e-02, -7.1129e+00,
         -7.1088e+00],
        [-4.4577e+00, -6.5145e-01, -8.0605e-01, -4.4080e+00, -5.2005e+00,
         -5.8748e+00],
        [-6.7035e+00, -1.2847e+01, -8.6602e+00, -7.1166e+00, -8.3440e+00,
         -2.4550e-03]], grad_fn=<LogSoftmaxBackward0>)
\Тестовая точность: 67.10%



In [25]:
# Конвертация в ONNX
torch.onnx.export(gcn, (data.x, data.edge_index), "gcn.onnx", opset_version=13)
onnx_model = onnx.load("gcn.onnx")
ort_session = onnxruntime.InferenceSession("gcn.onnx")



In [26]:
# Запуск модели в ONNX
input_name = ort_session.get_inputs()[0].name
input_shape = ort_session.get_inputs()[0].shape
x = data.x.numpy()
edge_index = data.edge_index.numpy()
outputs = ort_session.run(None, {'x.1': x, 'edge_index.1': edge_index})
print(outputs)

[array([[ -5.2798457 ,  -5.990191  ,  -1.8487921 ,   6.267871  ,
         -3.6762595 ,  -2.5198135 ],
       [ -4.0579696 ,   4.9101915 ,  -0.10661659,  -0.7850874 ,
         -1.1388807 ,  -2.6390142 ],
       [ -1.1025225 , -10.607943  ,  -2.875979  ,  -0.9262146 ,
         -2.8024988 ,   9.351824  ],
       ...,
       [ -2.3104267 ,  -2.5289018 ,  -3.8067794 ,   6.439511  ,
         -5.726099  ,  -6.306202  ],
       [  1.9849607 ,   1.3787694 ,  -0.23785113,  -2.5923834 ,
         -1.5183591 ,  -2.727569  ],
       [ -0.71692884, -11.272264  ,  -3.6094556 ,  -1.3425553 ,
         -2.6587367 ,  10.055194  ]], dtype=float32), array([[-1.1548229e+01, -1.2258575e+01, -8.1171761e+00, -5.1342178e-04,
        -9.9446430e+00, -8.7881975e+00],
       [-8.9810781e+00, -1.2917649e-02, -5.0297256e+00, -5.7081966e+00,
        -6.0619898e+00, -7.5621233e+00],
       [-1.0454420e+01, -1.9959839e+01, -1.2227876e+01, -1.0278111e+01,
        -1.2154396e+01, -7.3311028e-05],
       ...,
       [-8.75

In [27]:
# Обратная конвертация
from onnx2torch.converter import convert
torch_model = convert("gcn.onnx")

NotImplementedError: Converter is not implemented (OperationDescription(domain='', operation_type='ScatterElements', version=13))