In [1]:
# We assume that PyTorch is already installed
import torch
torchversion = torch.__version__

# Install PyTorch Scatter, PyTorch Sparse, and PyTorch Geometric
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-{torchversion}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-{torchversion}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

# Numpy for matrices
import numpy as np
np.random.seed(0)
import torch.nn.functional as F
from torch.nn import Linear, Dropout
from torch_geometric.nn import GCNConv

!pip install onnx onnxruntime
import onnx
import onnxruntime
!pip install onnx2torch

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.9/10.9 MB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.1/5.1 MB[0m [31m51.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for torch-geometric (pyproject.toml) ... [?25l[?25hdone
Collecting onnx
  Downloading onnx-1.16.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m15.9/15.9 MB[0m [31m54.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting onnxruntime
  Downloading onnxruntime-1.18.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m75.2 MB/s[0m eta [36m0:00:00[0m
Collecting coloredlogs (fro

In [2]:
class GCN(torch.nn.Module):
  """Graph Convolutional Network"""
  def __init__(self, dim_in, dim_h, dim_out):
    super().__init__()
    self.gcn1 = GCNConv(dim_in, dim_h)
    self.gcn2 = GCNConv(dim_h, dim_out)
    self.optimizer = torch.optim.Adam(self.parameters(),
                                      lr=0.01,
                                      weight_decay=5e-4)

  def forward(self, x, edge_index):
    h = F.dropout(x, p=0.5, training=self.training)
    h = self.gcn1(h, edge_index)
    h = torch.relu(h)
    h = F.dropout(h, p=0.5, training=self.training)
    h = self.gcn2(h, edge_index)
    return h, F.log_softmax(h, dim=1)

In [3]:
def accuracy(pred_y, y):
    """Calculate accuracy."""
    return ((pred_y == y).sum() / len(y)).item()

def train(model, data):
    """Train a GNN model and return the trained model."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = model.optimizer
    epochs = 200

    model.train()
    for epoch in range(epochs+1):
        # Training
        optimizer.zero_grad()
        _, out = model(data.x, data.edge_index)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        acc = accuracy(out[data.train_mask].argmax(dim=1), data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        # Validation
        val_loss = criterion(out[data.val_mask], data.y[data.val_mask])
        val_acc = accuracy(out[data.val_mask].argmax(dim=1), data.y[data.val_mask])

        # Print metrics every 10 epochs
        if(epoch % 10 == 0):
            print(f'Epoch {epoch:>3} | Train Loss: {loss:.3f} | Train Acc: '
                  f'{acc*100:>6.2f}% | Val Loss: {val_loss:.2f} | '
                  f'Val Acc: {val_acc*100:.2f}%')

    return model

def test(model, data):
    """Evaluate the model on test set and print the accuracy score."""
    model.eval()
    _, out = model(data.x, data.edge_index)
    acc = accuracy(out.argmax(dim=1)[data.test_mask], data.y[data.test_mask])
    print(out)
    return acc

In [4]:
from torch_geometric.datasets import Planetoid

# Import dataset from PyTorch Geometric
dataset = Planetoid(root=".", name="CiteSeer")

data = dataset[0]

# Print information about the dataset
print(f'Dataset: {dataset}')
print('-------------------')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of nodes: {data.x.shape[0]}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

# Print information about the graph
print(f'\nGraph:')
print('------')
print(f'Edges are directed: {data.is_directed()}')
print(f'Graph has isolated nodes: {data.has_isolated_nodes()}')
print(f'Graph has loops: {data.has_self_loops()}')

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.test.index
Processing...


Dataset: CiteSeer()
-------------------
Number of graphs: 1
Number of nodes: 3327
Number of features: 3703
Number of classes: 6

Graph:
------
Edges are directed: False
Graph has isolated nodes: True
Graph has loops: False


Done!


In [5]:
%%time

# Create GCN model
gcn = GCN(dataset.num_features, 16, dataset.num_classes)
print(gcn)

# Train
train(gcn, data)

# Test
acc = test(gcn, data)
print(f'\nGCN test accuracy: {acc*100:.2f}%\n')

GCN(
  (gcn1): GCNConv(3703, 16)
  (gcn2): GCNConv(16, 6)
)
Epoch   0 | Train Loss: 1.786 | Train Acc:  21.67% | Val Loss: 1.80 | Val Acc: 14.80%
Epoch  10 | Train Loss: 0.424 | Train Acc:  95.83% | Val Loss: 1.38 | Val Acc: 55.00%
Epoch  20 | Train Loss: 0.205 | Train Acc:  95.00% | Val Loss: 1.43 | Val Acc: 56.20%
Epoch  30 | Train Loss: 0.047 | Train Acc: 100.00% | Val Loss: 1.61 | Val Acc: 57.20%
Epoch  40 | Train Loss: 0.055 | Train Acc: 100.00% | Val Loss: 1.57 | Val Acc: 58.00%
Epoch  50 | Train Loss: 0.068 | Train Acc:  99.17% | Val Loss: 1.71 | Val Acc: 56.80%
Epoch  60 | Train Loss: 0.046 | Train Acc: 100.00% | Val Loss: 1.75 | Val Acc: 56.00%
Epoch  70 | Train Loss: 0.047 | Train Acc: 100.00% | Val Loss: 1.51 | Val Acc: 59.80%
Epoch  80 | Train Loss: 0.063 | Train Acc:  98.33% | Val Loss: 1.74 | Val Acc: 59.20%
Epoch  90 | Train Loss: 0.051 | Train Acc:  99.17% | Val Loss: 1.53 | Val Acc: 58.60%
Epoch 100 | Train Loss: 0.044 | Train Acc: 100.00% | Val Loss: 1.58 | Val Acc: 6

In [6]:
torch.onnx.export(gcn, (data.x, data.edge_index), "gcn.onnx", opset_version=13)
onnx_model = onnx.load("gcn.onnx")
ort_session = onnxruntime.InferenceSession("gcn.onnx")



In [7]:
# Запуск модели из файла ONNX
# Get input name and shape
input_name = ort_session.get_inputs()[0].name
input_shape = ort_session.get_inputs()[0].shape

# Prepare input data
x = data.x.numpy()
edge_index = data.edge_index.numpy()

# Run the model with ONNX Runtime
outputs = ort_session.run(None, {'x.1': x, 'edge_index.1': edge_index})

# Print the result
print(outputs)

[array([[-3.5778384 , -0.25357652, -1.8256342 ,  7.038357  , -2.0389197 ,
        -2.615196  ],
       [-4.1751566 ,  5.707779  , -4.9034753 , -3.9429836 , -3.9813735 ,
         0.351418  ],
       [-0.49726263, -4.143748  , -5.80157   , -4.9154763 , -0.9774149 ,
         7.435755  ],
       ...,
       [-2.1670246 ,  1.4986513 , -2.0194976 ,  6.024235  , -3.7016306 ,
        -3.9124808 ],
       [ 1.6548642 ,  0.5736364 ,  0.7832375 , -4.5990286 , -2.0648942 ,
        -2.6529593 ],
       [-0.16095512, -5.0551696 , -7.413631  , -3.999258  , -1.6690872 ,
         8.920065  ]], dtype=float32), array([[-1.0617221e+01, -7.2929578e+00, -8.8650160e+00, -1.0247938e-03,
        -9.0783014e+00, -9.6545782e+00],
       [-9.8878431e+00, -4.9080607e-03, -1.0616161e+01, -9.6556702e+00,
        -9.6940603e+00, -5.3612690e+00],
       [-7.9336133e+00, -1.1580098e+01, -1.3237921e+01, -1.2351827e+01,
        -8.4137659e+00, -5.9598801e-04],
       ...,
       [-8.2027283e+00, -4.5370522e+00, -8.055201

In [8]:
from onnx2torch.converter import convert
torch_model = convert("gcn.onnx")

NotImplementedError: Converter is not implemented (OperationDescription(domain='', operation_type='ScatterElements', version=13))