In [1]:
import os.path as osp
from tqdm.auto import tqdm
import numpy as np
import torch
from pathlib import Path
import os
from torch_geometric.utils import erdos_renyi_graph

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GraphSAGE
from torch_geometric.data import Data, Dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
MODELS_PATH = Path("./models")

if not MODELS_PATH.exists():
    os.makedirs(MODELS_PATH)

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
DEVICE

device(type='cuda', index=0)

In [5]:
n_graphs = 200
class RandomDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return ''
    
    @property
    def processed_file_names(self):
        l = []
        for i in range(n_graphs):
            l.append("data_" + str(i) + ".pt")
        return l

    def download(self):
        pass

    def process(self):
        idx = 0
        for g in range(n_graphs):                    
            x = torch.rand([torch.randint(low=20, high=70, size=[1]), 1], dtype=torch.float32)
            edge_index = erdos_renyi_graph(x.size()[0], 0.3)            
            data = Data(x=x, edge_index=edge_index.contiguous())

            torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))
            idx += 1

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
        return data

In [6]:
dataset = RandomDataset(root="./data/")

Processing...
Done!


In [7]:
class SAGEConvModel(torch.nn.Module):
    def __init__(self, hidden_channels=64, num_layers=2, out_channels=1):
        super(SAGEConvModel, self).__init__()
        self.sage = GraphSAGE(dataset.num_features, hidden_channels, num_layers, out_channels)

    def forward(self, x, edge_index):
        x = self.sage(x, edge_index)
        return torch.softmax(x, dim=1)

In [8]:
def loss_fn(batchX, randomFriend, randomEnemies, Q=1):
    fst = -torch.log(torch.sigmoid(torch.sum(batchX * randomFriend, dim=1)))
    snd = -Nenem * torch.mean(torch.log(torch.sigmoid(-torch.sum(randomEnemies * batchX,  dim=2))), dim=0)
    return torch.mean(fst+snd)

def get_random_repr(node, x, edge_dict):
    variants = edge_dict.get(node, [])
    if variants:
        return x[np.random.choice(variants)]
    return torch.zeros(out_channels)

def get_dict_out_of_nodes(Nnodes, edge_index):
    edge_dict = {i:[] for i in range(Nnodes)}
    for edge in edge_index.reshape(-1,2).cpu().numpy():
        edge_dict[edge[0]].append(edge[1])
    return edge_dict

In [45]:
from functools import partial
criterion = loss_fn

epochs = 50
n_models = 50
out_channels = 10
Nenem = 10
n_features = 1
num_layers = 3

models_result = []

def train():
    for i in range(n_models):
        torch.manual_seed(i)
        model = SAGEConvModel(num_layers=num_layers, out_channels=out_channels)
        model.to(DEVICE)
        model.bfloat16()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        for epoch in range(epochs):
            train_loss = []
            for g in dataset:
                g.to(DEVICE)
                model.train()
                out = model(g.x.to(torch.bfloat16), g.edge_index)

                idxperm = torch.randperm(out.size()[0])
                # pidxs  = idxperm[batch*Nbatch:(batch+1)*Nbatch]
                out = out[idxperm]
                edge_dict = get_dict_out_of_nodes(out.size()[0], g.edge_index)
                get_random_repr_p = partial(get_random_repr, x=out.cpu(), edge_dict=edge_dict)
                randomFriend = torch.stack(list(map(get_random_repr_p, idxperm.numpy())), dim=0).to(DEVICE)
                randomEnemies = out[np.random.choice(g.x.size()[0], Nenem)].reshape(Nenem, 1, out_channels).to(DEVICE)
                
                loss = criterion(out, randomFriend, randomEnemies)            
                loss.backward()
                optimizer.step()
                
                train_loss.append(loss.item())
    
            if (epoch + 1) % 10 == 0:
                print(f'Model {i}: epoch: {epoch + 1:03d}, loss: {np.mean(train_loss)}')

        models_result.append(model(dataset[0].x.to(torch.bfloat16).to(DEVICE), dataset[0].edge_index.to(DEVICE)))
        for g in dataset[1:]:
            g.to(DEVICE)
            models_result[i] = torch.cat([models_result[i], model(g.x.to(torch.bfloat16), g.edge_index)])
        torch.save(model.state_dict(), f'./models/{i+1}')


def eval_stability():
    disagr = 0
    for i in range(n_models):
        for j in range(i, n_models):   
            disagr += torch.sum(torch.abs(models_result[i] - models_result[j])) / models_result[i].size()[0]
    return disagr / (n_models * (n_models - 1) / 2)

In [46]:
train()
d = eval_stability()
print("Stability:", d)

Model 0: epoch: 010, loss: 8.124897985458373
Model 0: epoch: 020, loss: 8.124892420768738
Model 0: epoch: 030, loss: 8.124884567260743
Model 0: epoch: 040, loss: 8.124888019561768
Model 0: epoch: 050, loss: 8.1248969745636
Model 1: epoch: 010, loss: 8.124892659187317
Model 1: epoch: 020, loss: 8.124895482063293
Model 1: epoch: 030, loss: 8.125185613632203
Model 1: epoch: 040, loss: 8.124885783195495
Model 1: epoch: 050, loss: 8.1249036693573
Model 2: epoch: 010, loss: 8.124896202087402
Model 2: epoch: 020, loss: 8.124893164634704
Model 2: epoch: 030, loss: 8.124888343811035
Model 2: epoch: 040, loss: 8.124899797439575
Model 2: epoch: 050, loss: 8.124900584220887
Model 3: epoch: 010, loss: 8.124897508621215
Model 3: epoch: 020, loss: 8.124896349906921
Model 3: epoch: 030, loss: 8.12488651752472
Model 3: epoch: 040, loss: 8.124895310401916
Model 3: epoch: 050, loss: 8.12489914894104
Model 4: epoch: 010, loss: 8.124896216392518
Model 4: epoch: 020, loss: 8.124889855384827
Model 4: epoch: 

2 layers \
float32 0.2288476824760437 \
float64 0.21959084776105314 \
bfloat16 0.10107421875 

1 layer \
bfloat16 0.201171875 \
float32 0.7191293835639954 \
float64 0.6989535146025564 

3 layers \
float64 0.2054728962831928 \
float32 0.2074834555387497\
bfloat16 0.05126953125