In [15]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import pandas as pd
from sklearn.model_selection import train_test_split
from torch_geometric.data import DataLoader, Dataset, InMemoryDataset, download_url, extract_zip
from torch_geometric.nn import dense_diff_pool, GCNConv, GraphConv, SAGPooling, global_mean_pool, global_max_pool 
from torch.utils.data import random_split
from torch_geometric.utils import to_dense_adj, to_dense_batch
import os
import os.path as osp
import shutil
import time
from torch_geometric.io import read_tu_data
from tu_dataset import TUDataset
# from torch_geometric.datasets import TUDataset
import networkx as nx

In [16]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    torch.cuda.manual_seed(777)
else:
    torch.manual_seed(777)

In [17]:
def data_transform(data, index):
    data.x = F.normalize(data.x, p=2,dim = -1)   # L_2归一化
    data.i = index
    return data
DD = TUDataset(root='datasets/DD',name='DD', pre_transform=data_transform)

In [18]:
num_classes = DD.num_classes
num_features = DD.num_features
num_graphs = len(DD)

In [59]:
batch_size = 128

num_train = int(num_graphs*0.8)
num_val = int(num_graphs*0.1)
num_test = num_graphs - (num_train+num_val)
training_set, validation_set, testing_set = random_split(DD, [num_train, num_val, num_test])

train_loader = DataLoader(training_set, batch_size= batch_size, shuffle=True)
val_loader = DataLoader(validation_set,batch_size = batch_size, shuffle=False)
test_loader = DataLoader(testing_set,batch_size=1,shuffle=False)

In [60]:
max_node = np.max([x.num_nodes for x in DD])
# learning_rate = 0.001
reg = 0.0001
epochs = 300
num_hidden = 128
pooling_ratio = .8
dropout_ratio = .5
print(max_node)

5748


In [98]:
class Net(nn.Module):
    def __init__(self, num_features, num_hid, num_classes, pooling_ratio, dropout_ratio, score_conv = GCNConv):
        super(Net, self).__init__()
        self.num_features = num_features
        self.num_hid = num_hid
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        
        self.conv1 = GCNConv(num_features, num_hid)
        self.pool1 = SAGPooling(num_hid, ratio=pooling_ratio, GNN = score_conv)

        self.conv2 = GCNConv(num_hid, num_hid)
        self.pool2 = SAGPooling(num_hid,ratio=pooling_ratio, GNN = GCNConv)

        self.conv3 = GCNConv(num_hid, num_hid)
        self.pool3 = SAGPooling(num_hid,ratio=pooling_ratio, GNN = GCNConv)

        self.linear1 = nn.Linear(num_hid * 2, num_hid)
        self.linear2 = nn.Linear(num_hid, num_hid //2)
        self.linear3 = nn.Linear(num_hid//2, num_classes)

    def forward(self, data):
        # Batch(batch=[35631], edge_index=[2, 179034], i=[128], x=[35631, 89], y=[128])
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, _,_ = self.pool1(x, edge_index, batch = batch)
        x1 = torch.cat([global_max_pool(x, batch), global_mean_pool(x, batch)], dim=1)

        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, _ ,_= self.pool2(x, edge_index,batch = batch)
        x2 = torch.cat([global_max_pool(x, batch), global_mean_pool(x, batch)], dim=1)

        x = F.relu(self.conv3(x, edge_index))
        x, edge_index, _, batch, _,_ = self.pool3(x, edge_index, batch = batch)
        x3 = torch.cat([global_max_pool(x, batch), global_mean_pool(x, batch)], dim=1)

        x = x1 + x2 + x3

        x = F.relu(self.linear1(x))
        x = F.dropout(x, p=self.dropout_ratio, training=self.training)
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        return x
    
    def loss(self, output, labels):
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output, labels)
        return loss

In [99]:
model = Net(num_features,num_hidden,num_classes, pooling_ratio, dropout_ratio)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), weight_decay=reg)
print(model)

Net(
  (conv1): GCNConv(89, 128)
  (pool1): SAGPooling(GCNConv, 128, ratio=0.8, multiplier=1)
  (conv2): GCNConv(128, 128)
  (pool2): SAGPooling(GCNConv, 128, ratio=0.8, multiplier=1)
  (conv3): GCNConv(128, 128)
  (pool3): SAGPooling(GCNConv, 128, ratio=0.8, multiplier=1)
  (linear1): Linear(in_features=256, out_features=128, bias=True)
  (linear2): Linear(in_features=128, out_features=64, bias=True)
  (linear3): Linear(in_features=64, out_features=2, bias=True)
)


### Early Stopping Function

In [100]:
patience = 0
limit_patience = 50
min_loss = float("inf")
def early_stopping(val_loss, min_loss, epoch, dataset="DD" , method = "SAGPool"):
    if val_loss < min_loss:
        torch.save(model.state_dict(), 'saved_model_{}_{}.pth'.format(dataset, method))
        # print("model saved at epoch {}".format(epoch))
        min_loss = val_loss
        patience = 0
    else:
        patience += 1
    return patience

In [101]:
def train(loader):
    model.train()
    train_loss = 0
    train_acc = 0
    for data in loader:
        data = data.to(device)
        output = model(data)
        optimizer.zero_grad()
        loss = model.loss(output, data.y)

        loss.backward()

        train_loss += loss.item()


        optimizer.step()

        train_acc += torch.eq(torch.argmax(output, -1),data.y).sum().item()

    acc_current_epoch = train_acc / len(loader.dataset)
    return train_loss, acc_current_epoch


def val(loader):
    model.eval()
    val_loss = 0
    val_acc = 0
    correct = 0
    for data in loader:
        data = data.to(device)
        output = model(data)
        optimizer.zero_grad()
        loss = model.loss(output, data.y)
        correct += torch.eq(torch.argmax(output,-1),data.y).sum().item()
        val_loss += loss.item()
    

    val_acc_epoch = correct / len(loader.dataset)
    return val_loss, val_acc_epoch


In [102]:
def do_training():
    for epoch in range(epochs):
        train_loss, acc_current_epoch = train(train_loader)
        val_loss, val_acc_epoch = val(val_loader)
        print("Epochs:{} Train loss:{} Train accuracy:{} Validation loss:{} Validation accuracy:{}".format(epoch, train_loss, acc_current_epoch, val_loss, val_acc_epoch))
        patience = early_stopping(val_loss, min_loss, epoch)

        if patience > limit_patience:
            break


In [103]:
do_training()

on accuracy:0.8290598290598291
Epochs:166 Train loss:0.5807328429073095 Train accuracy:0.9851380042462845 Validation loss:0.8404719829559326 Validation accuracy:0.8205128205128205
Epochs:167 Train loss:0.6125918552279472 Train accuracy:0.9787685774946921 Validation loss:1.0166735649108887 Validation accuracy:0.8376068376068376
Epochs:168 Train loss:0.60365847684443 Train accuracy:0.9830148619957537 Validation loss:0.9318631291389465 Validation accuracy:0.811965811965812
Epochs:169 Train loss:0.46387396566569805 Train accuracy:0.9872611464968153 Validation loss:0.9437280893325806 Validation accuracy:0.811965811965812
Epochs:170 Train loss:0.4949536882340908 Train accuracy:0.9872611464968153 Validation loss:1.0321846008300781 Validation accuracy:0.811965811965812
Epochs:171 Train loss:0.4562909146770835 Train accuracy:0.9872611464968153 Validation loss:1.0456277132034302 Validation accuracy:0.8290598290598291
Epochs:172 Train loss:0.4178936704993248 Train accuracy:0.9851380042462845 Vali

### load best model

In [104]:
path = "saved_model_DD_SAGPool.pth"
model = Net(num_features,num_hidden,num_classes, pooling_ratio, dropout_ratio)
model.load_state_dict(torch.load(path))
model.to(device)

In [108]:
test_loss, test_acc_epoch = val(test_loader)
print(test_loss,test_acc_epoch)

283.01559376443504 0.7226890756302521
