In [1]:
import torch
from torch_geometric.nn import GCNConv
from torch.nn import Linear
from torch_geometric.nn import global_mean_pool
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from pooling_class import *

In [17]:
dataset = TUDataset('data', name='MUTAG')
dataset = dataset.shuffle()
n = (len(dataset) + 9) // 10
test_dataset = dataset[:n]
val_dataset = dataset[n:2 * n]
train_dataset = dataset[2 * n:]
data = train_dataset[0]
data

Data(edge_index=[2, 38], x=[17, 7], edge_attr=[38, 4], y=[1])

In [18]:
# Test whether the message can flow (gradient can be updated automatically)
class dispooling_GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        torch.manual_seed(1234)
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.pool1 = pooling(hidden_channels, score_method=2, normalize=True, self_add=2.5, aggregate_score_method='avg', upper_bound=5, greedy=True, select=True)
        # self.poo1 = EdgePooling(hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index).relu()
        edge_index, h, _, _, _ = self.pool1(h,edge_index)
        # h, edge_index, _, _ = self.pool1(x, edge_index, batch=None)
        h = self.conv2(h, edge_index).relu()
        h = self.conv3(h, edge_index).relu()
        h = self.lin(h).relu()
        h = global_mean_pool(h, batch=None)
        return F.log_softmax(h, dim=-1)

In [27]:
model = dispooling_GCN(in_channels=dataset.num_node_features, hidden_channels=20, out_channels=dataset.num_classes)
model.forward(data.x, data.edge_index)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

def train():
    model.train()
    for data in train_dataset:
        optimizer.zero_grad() 
        out = model(data.x, data.edge_index)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
    return(loss)


def test():
    model.eval()
    correct = 0
    for data in test_dataset:
        pred = model(data.x, data.edge_index).argmax(dim=1)
        correct += int((pred == data.y).sum())
    return correct / len(test_dataset)


def val():
    model.eval()
    correct = 0
    for data in val_dataset:
        pred = model(data.x, data.edge_index).argmax(dim=1)
        correct += int((pred == data.y).sum())
    return correct / len(test_dataset)



best_val_acc = 0
for epoch in range(1, 171):
    train_loss = train()
    val_acc = val()
    test_acc = test()
    if val_acc > best_val_acc:
        test_acc = test()
        best_val_acc = val_acc
    if epoch % 10 == 1:
        print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, '
            f'Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')

Epoch: 001, Train Loss: 0.5653, Val Acc: 0.5789, Test Acc: 0.3684
Epoch: 011, Train Loss: 0.5030, Val Acc: 0.5789, Test Acc: 0.3684
Epoch: 021, Train Loss: 0.5840, Val Acc: 0.5789, Test Acc: 0.3684
Epoch: 031, Train Loss: 0.6743, Val Acc: 0.5789, Test Acc: 0.3684
Epoch: 041, Train Loss: 0.7102, Val Acc: 0.6316, Test Acc: 0.4737
Epoch: 051, Train Loss: 0.9604, Val Acc: 0.6316, Test Acc: 0.6316
Epoch: 061, Train Loss: 1.0706, Val Acc: 0.6316, Test Acc: 0.6316
Epoch: 071, Train Loss: 1.1856, Val Acc: 0.6316, Test Acc: 0.6316
Epoch: 081, Train Loss: 1.2665, Val Acc: 0.6316, Test Acc: 0.6316
Epoch: 091, Train Loss: 1.3340, Val Acc: 0.6316, Test Acc: 0.6316
Epoch: 101, Train Loss: 1.3869, Val Acc: 0.6316, Test Acc: 0.6316
Epoch: 111, Train Loss: 1.4389, Val Acc: 0.6316, Test Acc: 0.6316
Epoch: 121, Train Loss: 1.4803, Val Acc: 0.6316, Test Acc: 0.6316
Epoch: 131, Train Loss: 1.5150, Val Acc: 0.6316, Test Acc: 0.6316
Epoch: 141, Train Loss: 1.5473, Val Acc: 0.6316, Test Acc: 0.6316
Epoch: 151

In [12]:
# upper bound is chose not to integrate the graph to a single node 5 * 0.5 = 2.5
import networkx as nx
d_min = 1000
for data in dataset:
    G = U.to_networkx(data, to_undirected=True)
    d = nx.diameter(G)
    if d < d_min:
        d_min = d
d_min

5