In [106]:
import os.path as osp
from tqdm.auto import tqdm
import numpy as np
import torch
from torch_geometric.transforms import RandomLinkSplit
from pathlib import Path
import os

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GraphSAGE

In [107]:
MODELS_PATH = Path("./models")

In [108]:
if not MODELS_PATH.exists():
    os.makedirs(MODELS_PATH)

In [53]:
dataset = Planetoid(root="./data/", name="Cora", split="public")
transform = RandomLinkSplit(num_val=0.1, num_test=0.20, split_labels=True)
train_data, val_data, test_data = transform(dataset[0])
data = dataset[0]

In [76]:
class SAGEConvModel(torch.nn.Module):
    def __init__(self, hidden_channels=64, num_layers=2, out_channels=dataset.num_classes):
        super(SAGEConvModel, self).__init__()
        self.sage = GraphSAGE(dataset.num_features, hidden_channels, num_layers, out_channels)

    def forward(self, x, edge_index):
        x = self.sage(x, edge_index)
        return torch.softmax(x, dim=1)

In [99]:
criterion = torch.nn.CrossEntropyLoss()

epochs = 200
n_models = 50

models_result = []

def train():
    for i in range(n_models):
        model = SAGEConvModel()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        train_loss = []
        for epoch in range(epochs):
            model.train()
            out = model(train_data.x, train_data.edge_index)
            loss = criterion(out, train_data.y)            
            
            loss.backward()
            optimizer.step()
            
            train_loss.append(loss.item())
        print(f'Model {i}: epoch: {epoch + 1:03d}, loss: {np.mean(train_loss)}')

        models_result.append(model(test_data.x, test_data.edge_index), dim=1))
        torch.save(model.state_dict(), f'./models/{i+1}')


def eval_stability():
    dis_sum = 0
    for i in range(n_models):
        for j in range(i, n_models):
            dis_sum += torch.sum(torch.abs(models_result[i] - models_result[j]))
    return dis_sum / (n_models * (n_models + 1) / 2)


def eval_stability1():
    dis_sum = 0
    for i in range(n_models):
        for j in range(i, n_models):
            eq_mask = torch.eq(models_result[i], models_result[j])
            dis_sum += torch.sum(eq_mask) / eq_mask.size()[0]
    return dis_sum / (n_models * (n_models + 1) / 2)

In [100]:
train()
print(eval_stability())

Model 0: 200, loss: 1.4210438048839569
Model 1: 200, loss: 1.4350572943687439
Model 2: 200, loss: 1.4065681314468383
Model 3: 200, loss: 1.4377384531497954
Model 4: 200, loss: 1.4045356917381286
Model 5: 200, loss: 1.3982780426740646
Model 6: 200, loss: 1.4160704296827316
Model 7: 200, loss: 1.4020868039131165
Model 8: 200, loss: 1.4070432579517365
Model 9: 200, loss: 1.3889498561620712
Model 10: 200, loss: 1.4154532074928283
Model 11: 200, loss: 1.4106796818971634
Model 12: 200, loss: 1.3918729048967362
Model 13: 200, loss: 1.409529229402542
Model 14: 200, loss: 1.4157434397935866
Model 15: 200, loss: 1.4098780786991119
Model 16: 200, loss: 1.4237270164489746
Model 17: 200, loss: 1.4252563017606734
Model 18: 200, loss: 1.426661447286606
Model 19: 200, loss: 1.4151845079660417
Model 20: 200, loss: 1.407782748937607
Model 21: 200, loss: 1.400235372185707
Model 22: 200, loss: 1.4013875275850296
Model 23: 200, loss: 1.4118175631761551
Model 24: 200, loss: 1.414202960729599
Model 25: 200, 

In [94]:
models_result[0].shape[0]

2708