In [1]:
import os.path as osp
from tqdm.auto import tqdm
import numpy as np
import torch
from pathlib import Path
import os
from torch_geometric.utils import erdos_renyi_graph

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GraphSAGE
from torch_geometric.data import Data, Dataset

  from .autonotebook import tqdm as notebook_tqdm


In [18]:
MODELS_PATH = Path("./models")

if not MODELS_PATH.exists():
    os.makedirs(MODELS_PATH)

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [19]:
DEVICE

device(type='cpu')

In [3]:
n_graphs = 200
class RandomDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return ''
    
    @property
    def processed_file_names(self):
        l = []
        for i in range(n_graphs):
            l.append("data_" + str(i) + ".pt")
        return l

    def download(self):
        pass

    def process(self):
        idx = 0
        for g in range(n_graphs):                    
            x = torch.rand([torch.randint(low=20, high=70, size=[1]), 1], dtype=torch.float32)
            edge_index = erdos_renyi_graph(x.size()[0], 0.3)            
            data = Data(x=x, edge_index=edge_index.contiguous())

            torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))
            idx += 1

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
        return data

In [4]:
dataset = RandomDataset(root="./data/")

In [5]:
class SAGEConvModel(torch.nn.Module):
    def __init__(self, hidden_channels=64, num_layers=2, out_channels=1):
        super(SAGEConvModel, self).__init__()
        self.sage = GraphSAGE(dataset.num_features, hidden_channels, num_layers, out_channels)

    def forward(self, x, edge_index):
        x = self.sage(x, edge_index)
        return torch.softmax(x, dim=1)

In [6]:
def loss_fn(batchX, randomFriend, randomEnemies, Q=1):
    fst = -torch.log(torch.sigmoid(torch.sum(batchX * randomFriend, dim=1)))
    snd = -Nenem * torch.mean(torch.log(torch.sigmoid(-torch.sum(randomEnemies * batchX,  dim=2))), dim=0)
    return torch.mean(fst+snd)

def get_random_repr(node, x, edge_dict):
    variants = edge_dict.get(node, [])
    if variants:
        return x[np.random.choice(variants)]
    return torch.zeros(out_channels)

def get_dict_out_of_nodes(Nnodes, edge_index):
    edge_dict = {i:[] for i in range(Nnodes)}
    for edge in edge_index.reshape(-1,2).cpu().numpy():
        edge_dict[edge[0]].append(edge[1])
    return edge_dict

In [16]:
# criterion = torch.nn.CrossEntropyLoss()
from functools import partial
criterion = loss_fn

epochs = 50
n_models = 50
out_channels = 10
Nenem = 10
n_features = 1
num_layers = 3

models_result = []

def train():
    for i in range(n_models):
        torch.manual_seed(i)
        model = SAGEConvModel(num_layers=num_layers, out_channels=out_channels)
        model.to(DEVICE)
        # model.double()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        for epoch in range(epochs):
            train_loss = []
            for g in dataset:
                g.to(DEVICE)
                model.train()
                out = model(g.x.to(torch.float32), g.edge_index)

                idxperm = torch.randperm(out.size()[0])
                # pidxs  = idxperm[batch*Nbatch:(batch+1)*Nbatch]
                out = out[idxperm]
                edge_dict = get_dict_out_of_nodes(out.size()[0], g.edge_index)
                get_random_repr_p = partial(get_random_repr, x=out, edge_dict=edge_dict)
                randomFriend = torch.stack(list(map(get_random_repr_p, idxperm.numpy())), dim=0).to(DEVICE)
                randomEnemies = out[np.random.choice(g.x.size()[0], Nenem)].reshape(Nenem, 1, out_channels).to(DEVICE)
                
                loss = criterion(out, randomFriend, randomEnemies)            
                loss.backward()
                optimizer.step()
                
                train_loss.append(loss.item())
    
            if (epoch + 1) % 10 == 0:
                print(f'Model {i}: epoch: {epoch + 1:03d}, loss: {np.mean(train_loss)}')

        models_result.append(model(dataset[0].x.to(torch.float32), dataset[0].edge_index))
        for g in dataset[1:]:
            models_result[i] = torch.cat([models_result[i], model(g.x.to(torch.float32), g.edge_index)])
        torch.save(model.state_dict(), f'./models/{i+1}')


def eval_stability():
    disagr = 0
    for i in range(n_models):
        for j in range(i, n_models):   
            disagr += torch.sum(torch.abs(models_result[i] - models_result[j])) / models_result[i].size()[0]
    return disagr / (n_models * (n_models + 1) / 2)

In [17]:
train()
d = eval_stability()
print("Stability:", d)

Model 0: epoch: 010, loss: 8.101938920021057
Model 0: epoch: 020, loss: 8.098837537765503
Model 0: epoch: 030, loss: 8.109405617713929
Model 0: epoch: 040, loss: 8.106246495246888
Model 0: epoch: 050, loss: 8.093907990455627
Model 1: epoch: 010, loss: 8.09745801448822
Model 1: epoch: 020, loss: 8.09821361064911
Model 1: epoch: 030, loss: 8.102283091545106
Model 1: epoch: 040, loss: 8.108295373916626
Model 1: epoch: 050, loss: 8.113085570335388
Model 2: epoch: 010, loss: 8.094871244430543
Model 2: epoch: 020, loss: 8.096293849945068
Model 2: epoch: 030, loss: 8.101992859840394
Model 2: epoch: 040, loss: 8.10881190776825
Model 2: epoch: 050, loss: 8.106374373435974
Model 3: epoch: 010, loss: 8.09601707458496
Model 3: epoch: 020, loss: 8.100526375770569
Model 3: epoch: 030, loss: 8.102142906188964
Model 3: epoch: 040, loss: 8.105073771476746
Model 3: epoch: 050, loss: 8.11139690876007
Model 4: epoch: 010, loss: 8.095027899742126
Model 4: epoch: 020, loss: 8.098659772872924
Model 4: epoch:

In [238]:
model.float16()
for param in model.parameters():
    print(param.dtype)

AttributeError: 'SAGEConvModel' object has no attribute 'float16'

In [213]:
eval_stability()

tensor(0.2209, grad_fn=<DivBackward0>)

In [15]:
model = SAGEConvModel(num_layers=num_layers, out_channels=out_channels)
for param in model.parameters():
    print(param.dtype)

torch.float32
torch.float32
torch.float32


2 layers
float32 0.2209
float64 0.2333
bfloat16 0.1035

1 layer
bfloat16 0.2041
float64 0.7173
float32 0.7346

3 layers
float32 0.2181