In [14]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
import pandas as pd
from sklearn.model_selection import train_test_split
from torch_geometric.data import DataLoader, Dataset
from torch_geometric.nn import dense_diff_pool, GCNConv, GraphConv, DenseGCNConv, JumpingKnowledge,DenseSAGEConv
from torch.utils.data import random_split
from torch_geometric.utils import to_dense_adj, to_dense_batch
import time

In [15]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    torch.cuda.manual_seed(777)
else:
    torch.manual_seed(777)

# D&D

In [17]:
class MyDataset(Dataset):
    def __init__(self, graphs):
        super(MyDataset, self).__init__()
        self.graphs = graphs

    def __getitem__(self,index):
        return self.graphs[index]

    def __len__(self):
        return len(self.graphs)


In [18]:
def data_transform(data):
    data.x = F.normalize(data.x, p=2,dim = -1)   # L_2归一化
    return data

DD = TUDataset(root='datasets/DD',name='DD', pre_transform=data_transform)

In [19]:
num_classes = DD.num_classes
num_features = DD.num_features
num_graphs = len(DD)
print(num_classes)
print(num_features)
print(num_graphs)

2
89
1178


In [20]:
batch_size = 32

num_train = int(num_graphs*0.8)
num_val = int(num_graphs*0.1)
num_test = num_graphs - (num_train+num_val)
training_set, validation_set, testing_set = random_split(DD, [num_train, num_val, num_test])

train_loader = DataLoader(training_set, batch_size= batch_size, shuffle=True)
val_loader = DataLoader(validation_set,batch_size = batch_size, shuffle=False)
test_loader = DataLoader(testing_set,batch_size=1,shuffle=False)

In [21]:
for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()
    break

Step 1:
Number of graphs in the current batch: 32
Batch(batch=[8211], edge_index=[2, 40850], x=[8211, 89], y=[32])



In [22]:
from torch_geometric.nn import dense_diff_pool, SAGPooling

## DiffPool on D&D 

`x`: $x \in \mathbb{R}^{B \times N \times F}$

`adj`: $adj \in \mathbb{R}^{B \times N \times N}$

In [23]:
max_node = np.max([x.num_nodes for x in DD])
ratio = 0.25  # Keep node ratio in each layer of DiffPool
# learning_rate = 0.001
reg = 0.0001
epochs = 300
num_hidden = 30
print(max_node)

5748


In [24]:
class DiffPoolLayer(nn.Module):
    def __init__(self, in_channel, out_channel, num_cluster = 10):
        super(DiffPoolLayer, self).__init__()
        self.gnn_pool = DenseGCNConv(in_channel, num_cluster)  # GCN with mask
        self.gnn_embed = DenseGCNConv(in_channel, out_channel)

    def forward(self, x, adj, mask = None):
        """
        x: feature matrix of batch graphs  x \in (B,N,F)   B is batch size
        adj: adj of batchs graphs        adj \in (B,N,N) 
        batch: Batch vector, which assigns each node to a specific graph (0,0,0,0,1,1,1,...)
        """
        S = F.relu(self.gnn_pool(x, adj, mask))
        X = F.relu(self.gnn_embed(x, adj, mask))
        X, adj, link_loss, ent_loss = dense_diff_pool(X,adj,S,mask)
        return X, adj, link_loss+ent_loss

In [25]:
class Net(nn.Module):
    def __init__(self, hidden_channel, Model, use_jumpingknowledge = False, dropout = .0):
        super(Net, self).__init__()
        ############################################### Out of Memory ###################################################
        # self.conv1 = DenseGCNConv(num_features, hidden_channel)
        # self.pool1 = DiffPoolLayer(hidden_channel, hidden_channel, np.ceil(ratio * max_node).astype(np.int))
        
        # self.conv2 = DenseGCNConv(hidden_channel, hidden_channel)
        # self.pool2 = DiffPoolLayer(hidden_channel, hidden_channel, np.ceil(ratio ** 2 * max_node).astype(np.int))

        # self.conv3 = DenseGCNConv(hidden_channel, hidden_channel)
        # self.pool3 = DiffPoolLayer(hidden_channel, hidden_channel, np.ceil(ratio ** 3 * max_node).astype(np.int))
        #################################################################################################################
        self.conv1 = Model(num_features, hidden_channel)
        self.conv2 = Model(hidden_channel, hidden_channel)
        self.pool1 = DiffPoolLayer(hidden_channel, hidden_channel, np.ceil(ratio * max_node).astype(np.int))
        self.conv3 = Model(hidden_channel, hidden_channel)
        self.conv4 = Model(hidden_channel, hidden_channel)
        ################################################################################################################# 

        self.use_jumpingknowledge = use_jumpingknowledge
        if self.use_jumpingknowledge:
            self.jump = JumpingKnowledge(mode='cat')
        self.dropout = dropout
        # self.linear1 = nn.Linear(3 * hidden_channel, hidden_channel)
        self.linear2 = nn.Linear(hidden_channel, hidden_channel * 2)
        self.linear3 = nn.Linear(hidden_channel * 2, num_classes)

    def forward(self, data):
        """
        data: a batch graphs like: Batch(batch=[33138], edge_index=[2, 164894], x=[33138, 89], y=[128])
        """
        # batch, edge_index, x, y = data.batch, data.edge_index, data.x, data.y
        # x, mask = to_dense_batch(x,batch)
        # adj = to_dense_adj(edge_index, batch)
        x, mask, adj = data[0], data[1],data[2]
        xs = []
        x = F.relu(self.conv1(x, adj, mask))
        x = F.relu(self.conv2(x, adj, mask))

        x, adj, reg = self.pool1(x, adj, mask)

        xs.append(x)

        x = F.relu(self.conv3(x, adj))
        # x, adj, loss2 = self.pool2(x, adj)

        xs.append(x)

        x = F.relu(self.conv4(x, adj))
        # x, adj, loss3 = self.pool3(x, adj)

        xs.append(x)

        if self.use_jumpingknowledge:
            x = self.jump(xs)
            x = F.relu(self.linear1(x))
        

        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        readout_x = torch.sum(x,dim = 1)
        return readout_x , reg

In [26]:
model = Net(hidden_channel = num_hidden, Model=DenseSAGEConv)
model.to(device)
print(model)
optimizer = torch.optim.Adam(model.parameters(), weight_decay=reg)
loss_func = nn.CrossEntropyLoss()

Net(
  (conv1): DenseSAGEConv(89, 30)
  (conv2): DenseSAGEConv(30, 30)
  (pool1): DiffPoolLayer(
    (gnn_pool): DenseGCNConv(30, 1437)
    (gnn_embed): DenseGCNConv(30, 30)
  )
  (conv3): DenseSAGEConv(30, 30)
  (conv4): DenseSAGEConv(30, 30)
  (linear2): Linear(in_features=30, out_features=60, bias=True)
  (linear3): Linear(in_features=60, out_features=2, bias=True)
)


In [27]:
def train(loader):
    model.train()
    train_loss = 0
    train_acc = 0
    for data in loader:
        data = data.to(device)
        batch, edge_index, x, y = data.batch, data.edge_index, data.x, data.y
        x, mask = to_dense_batch(x,batch)
        adj = to_dense_adj(edge_index, batch)
        data_input = [x, mask, adj]
        output,reg = model(data_input)  # torch.Size([64, 2])

        loss = loss_func(output, data.y) + reg

        optimizer.zero_grad()

        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
        train_loss += loss.item()

        optimizer.step()

        train_acc += (torch.argmax(output, -1) == data.y).sum().item()

    acc_current_epoch = train_acc / len(loader.dataset)
    return train_loss, acc_current_epoch
    

def val(loader):
    model.eval()
    val_loss = 0
    val_acc = 0
    correct = 0
    for data in loader:
        data = data.to(device)
        batch, edge_index, x, y = data.batch, data.edge_index, data.x, data.y

        x, mask = to_dense_batch(x,batch)
        adj = to_dense_adj(edge_index, batch)
        data_input = [x, mask, adj]
        output, _ = model(data_input)
        correct += (torch.argmax(output,-1)==data.y).sum().item()
        loss = loss_func(output, data.y)
        val_loss += loss.item()

    val_acc_epoch = correct / len(loader.dataset)
    return val_loss, val_acc_epoch

In [28]:
patience = 0
min_loss = 1e10
limit_patience = 50
for epoch in range(epochs):
    t = time.time()
    train_loss, acc_current_epoch = train(train_loader)
    val_loss, val_acc_epoch = val(val_loader)
    print("Epochs:{} Train loss:{} Train accuracy:{} Validation loss:{} Validation accuracy:{}".format(epoch, train_loss, acc_current_epoch, val_loss, val_acc_epoch))
    if val_loss < min_loss:
        torch.save(model.state_dict(), 'saved_model')
        min_loss = val_loss
        patience = 0
    else:
        patience += 1
    if patience > limit_patience:
        break

Epochs:0 Train loss:199.64890241622925 Train accuracy:0.643312101910828 Validation loss:5.620843529701233 Validation accuracy:0.7863247863247863
Epochs:1 Train loss:113.06811237335205 Train accuracy:0.7229299363057324 Validation loss:4.690021872520447 Validation accuracy:0.6666666666666666
Epochs:2 Train loss:110.85566115379333 Train accuracy:0.6740976645435244 Validation loss:4.256055951118469 Validation accuracy:0.4358974358974359
Epochs:3 Train loss:107.70110607147217 Train accuracy:0.6560509554140127 Validation loss:2.05905345082283 Validation accuracy:0.811965811965812
Epochs:4 Train loss:104.84012830257416 Train accuracy:0.6942675159235668 Validation loss:1.8513012826442719 Validation accuracy:0.8205128205128205
Epochs:5 Train loss:97.6878525018692 Train accuracy:0.6602972399150743 Validation loss:3.864527940750122 Validation accuracy:0.6581196581196581
Epochs:6 Train loss:111.07568550109863 Train accuracy:0.6518046709129511 Validation loss:2.2108646631240845 Validation accuracy:

In [35]:
test_loss, test_acc_epoch = val(test_loader)

In [36]:
test_loss

62.006263660092365

In [37]:
test_acc_epoch

0.7815126050420168