In [3]:
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

from torch_geometric.datasets import TUDataset
from torch_geometric.utils import to_dense_adj, degree
from torch_geometric.transforms import OneHotDegree

import random
import numpy as np

In [8]:
import torch
from sparsemax import Sparsemax

sparsemax = Sparsemax(dim=2)
softmax = torch.nn.Softmax(dim=1)

logits = torch.randn(2, 5,3).cuda()
print("\nLogits")
print(logits)

sparsemax_probs = sparsemax(logits)
print("\nSparsemax probabilities")
print(sparsemax_probs)


Logits
tensor([[[ 1.2134, -0.0457,  0.1644],
         [ 2.0272, -0.3087,  2.3172],
         [ 1.1514, -2.5599, -0.8436],
         [-0.1408, -1.5938,  0.3637],
         [ 1.9940, -0.6176,  1.8082]],

        [[-0.1017,  0.0801,  2.0556],
         [-2.2154, -0.1930, -0.3883],
         [ 0.7791, -1.0272,  0.1358],
         [-0.6180,  1.4808, -0.2632],
         [ 0.4972,  0.7544, -0.7078]]], device='cuda:0')

Sparsemax probabilities
tensor([[[1.0000, 0.0000, 0.0000],
         [0.3550, 0.0000, 0.6450],
         [1.0000, 0.0000, 0.0000],
         [0.2478, 0.0000, 0.7522],
         [0.5929, 0.0000, 0.4071]],

        [[0.0000, 0.0000, 1.0000],
         [0.0000, 0.5976, 0.4024],
         [0.8216, 0.0000, 0.1784],
         [0.0000, 1.0000, 0.0000],
         [0.3714, 0.6286, 0.0000]]], device='cuda:0')


In [7]:
data = TUDataset("./datasets/COLLAB.", name = "COLLAB", use_node_attr=True, use_edge_attr=True)

In [17]:
max_num_nodes = 0
for g in data:
    max_num_nodes = max(max_num_nodes, g.x.shape[0])
max_num_nodes

492

In [20]:
data[1]

Data(edge_index=[2, 1572], y=[1], num_nodes=44, x=[44, 492])

In [5]:
class GraphDataset(torch.utils.data.Dataset):
    def __init__(self, data, max_num_nodes = None) -> None:
        super().__init__()
        self.adj_list        = []
        self.x_list          = []
        self.y_list          = []
        self.edge_index_list = []
        self.max_num_nodes = max_num_nodes
        self.prepareData(data, max_num_nodes)
    
    def prepareData(self, data, max_num_nodes = None):
        for g in data:
            f = torch.zeros((self.max_num_nodes, g.x.shape[1]))
            f[:g.x.shape[0], :g.x.shape[1]] = g.x
            self.x_list.append(f)
            self.y_list.append(g.y)
            self.edge_index_list.append(g.edge_index)
            adj = to_dense_adj(g.edge_index)
            self.adj_list.append(adj[0])

    def __len__(self):
        return len(self.adj_list)

    def __getitem__(self, idx):
        adj = self.adj_list[idx]
        num_nodes = adj.shape[0]
        adj_padded = torch.zeros((self.max_num_nodes, self.max_num_nodes))
        adj_padded[:num_nodes, :num_nodes] = adj
        #adj_padded = np.zeros((self.max_num_nodes, self.max_num_nodes))
        #adj_padded[:num_nodes, :num_nodes] = adj

        return {'adj':adj_padded,
                'x':self.x_list[idx],
                'y':self.y_list[idx],
                'num_nodes': num_nodes 
                }


In [6]:
data = data.shuffle()

In [7]:
train_data = data[:int(len(data)*0.8)]
test_data  = data[int(len(data)*0.8):] 

In [8]:
train_dataset = GraphDataset(train_data, max_num_nodes)
train_loader  = DataLoader(train_dataset, batch_size = 32)

test_dataset  = GraphDataset(test_data, max_num_nodes)
test_loader   = DataLoader(test_dataset, batch_size = 32)

In [21]:
device = torch.device("cuda:0")
model = DiffPool(None, number_of_labels= 4).to(device)
optimizer = torch.optim.Adam(model.parameters() ,lr = 1e-4)

In [None]:
model.train()
for epoch in range(150):
    loss_sum = 0
    total = 0
    for idx, graph in enumerate(train_loader):
        optimizer.zero_grad()

        x = graph['x'].to(device)
        adj = graph['adj'].to(device)
        y = graph['y']

        y_pred, loss_lp, loss_e = model(x, adj)
        loss = F.cross_entropy(y_pred.cpu(), y.view(-1), reduction='mean') + loss_lp+loss_e
        loss.backward()
        optimizer.step()
        loss_sum += loss
    print(loss_sum/len(train_data))
    loss_sum = 0
    total = 0

In [23]:
model.eval()
correct = 0
total = 0
for idx, graph in enumerate(test_loader):
    x = graph['x'].to(device)
    adj = graph['adj'].to(device)
    y = graph['y']

    y_pred, _, _ = model(x, adj)
    prediction = torch.argmax(y_pred, 1).cpu()
    correct += (prediction == y.view(-1)).sum()
    total += len(y.view(-1)) 
print((correct/total).detach().data.numpy())

0.75784755
