In [1]:
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score
from scipy.sparse.csgraph import shortest_path
import torch
import torch.nn.functional as F
from torch.nn import Conv1d, MaxPool1d, Linear, Dropout,BCEWithLogitsLoss
import torch_geometric
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, aggr, GATConv
from torch_geometric.utils import k_hop_subgraph, to_scipy_sparse_matrix

In [2]:
# Link-level split: 5% of the edges are held out for validation and 10% for
# testing. `split_labels=True` makes the transform expose separate
# `pos_edge_label_index` / `neg_edge_label_index` attributes, which the
# later cells read.
transform = RandomLinkSplit(
    num_val=0.05,
    num_test=0.1,
    is_undirected=True,
    split_labels=True,
)
dataset = Planetoid('.', name='Cora', transform=transform)
# Indexing the dataset applies the transform, which returns the three splits.
train_data, val_data, test_data = dataset[0]

In [3]:
def seal_processing(dataset, edge_label_index, y, num_hops=2, num_classes=200):
    """Build one SEAL enclosing subgraph per candidate link.

    For every ``(src, dst)`` column of ``edge_label_index`` this extracts the
    ``num_hops``-hop neighbourhood around both endpoints from
    ``dataset.edge_index``, removes the candidate link itself (both
    directions), and assigns every node a Double-Radius Node Labeling (DRNL)
    value ``z`` computed from its shortest-path distances to the two
    endpoints.

    Args:
        dataset: Graph object exposing ``edge_index`` and ``x``. If it also
            exposes ``num_nodes`` that value is forwarded to
            ``k_hop_subgraph`` so that endpoints without any incident edge
            do not fall outside the inferred node range.
        edge_label_index: Integer tensor whose transpose yields
            ``(src, dst)`` pairs, one per candidate link.
        y: Label stored on every produced graph (the callers in this
            notebook pass 1 for positive links and 0 for negative links).
        num_hops: Radius of the enclosing subgraph. Defaults to 2.
        num_classes: Width of the one-hot DRNL encoding. ``F.one_hot``
            raises if any label is ``>= num_classes``. Defaults to 200.

    Returns:
        list: One ``Data`` object per candidate link with
        ``x = [node features | one-hot(z)]``, ``z``, the pruned
        ``edge_index`` and ``y``.
    """
    # Fall back to the library's own inference when the attribute is absent.
    num_nodes = getattr(dataset, 'num_nodes', None)

    data_list = []
    for src, dst in edge_label_index.t().tolist():
        sub_nodes, sub_edge_index, mapping, _ = k_hop_subgraph(
            [src, dst], num_hops, dataset.edge_index,
            relabel_nodes=True, num_nodes=num_nodes)
        # From here on src/dst are positions inside the relabelled subgraph.
        src, dst = mapping.tolist()

        # Drop the target link in both directions so it cannot leak into
        # the structural features.
        mask1 = (sub_edge_index[0] != src) | (sub_edge_index[1] != dst)
        mask2 = (sub_edge_index[0] != dst) | (sub_edge_index[1] != src)
        sub_edge_index = sub_edge_index[:, mask1 & mask2]

        # Order the endpoints so that src <= dst; the ``dst - 1`` index and
        # the ``np.insert`` positions below rely on this ordering.
        src, dst = (dst, src) if src > dst else (src, dst)
        adj = to_scipy_sparse_matrix(
            sub_edge_index, num_nodes=sub_nodes.size(0)).tocsr()

        # Adjacency with the row/column of one endpoint removed, so that
        # distances to the other endpoint never route through it.
        idx = list(range(src)) + list(range(src + 1, adj.shape[0]))
        adj_wo_src = adj[idx, :][:, idx]
        idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))
        adj_wo_dst = adj[idx, :][:, idx]

        # Distances to src in the graph without dst; re-insert a
        # placeholder 0 at dst's slot to restore the original length.
        d_src = shortest_path(adj_wo_dst, directed=False, unweighted=True,
                              indices=src)
        d_src = np.insert(d_src, dst, 0, axis=0)
        d_src = torch.from_numpy(d_src)

        # Distances to dst in the graph without src. Removing src shifted
        # dst down by one position, hence ``dst - 1``.
        d_dst = shortest_path(adj_wo_src, directed=False, unweighted=True,
                              indices=dst - 1)
        d_dst = np.insert(d_dst, src, 0, axis=0)
        d_dst = torch.from_numpy(d_dst)

        # DRNL: z = 1 + min(d_src, d_dst) + (d // 2) * (d // 2 + d % 2 - 1)
        dist = d_src + d_dst
        z = 1 + torch.min(d_src, d_dst) + dist // 2 * (dist // 2 + dist % 2 - 1)
        # Both endpoints get label 1.
        z[src] = 1.
        z[dst] = 1.
        # NaN entries (produced for nodes with an infinite distance) get 0.
        z[torch.isnan(z)] = 0.
        z = z.to(torch.long)

        node_labels = F.one_hot(z, num_classes=num_classes).to(torch.float)
        node_emb = dataset.x[sub_nodes]
        node_x = torch.cat([node_emb, node_labels], dim=1)
        data = Data(x=node_x, z=z, edge_index=sub_edge_index, y=y)
        data_list.append(data)
    return data_list

In [4]:
# Turn every positive (label 1) and negative (label 0) candidate link of each
# split into its own enclosing subgraph. Evaluation order is unchanged:
# train, then validation, then test, positives before negatives.
train_pos_data_list, train_neg_data_list = (
    seal_processing(train_data, train_data.pos_edge_label_index, 1),
    seal_processing(train_data, train_data.neg_edge_label_index, 0),
)
val_pos_data_list, val_neg_data_list = (
    seal_processing(val_data, val_data.pos_edge_label_index, 1),
    seal_processing(val_data, val_data.neg_edge_label_index, 0),
)
test_pos_data_list, test_neg_data_list = (
    seal_processing(test_data, test_data.pos_edge_label_index, 1),
    seal_processing(test_data, test_data.neg_edge_label_index, 0),
)

In [5]:
# Each split is the positive subgraphs followed by the negative subgraphs.
train_dataset = [*train_pos_data_list, *train_neg_data_list]
val_dataset = [*val_pos_data_list, *val_neg_data_list]
test_dataset = [*test_pos_data_list, *test_neg_data_list]

In [6]:
# Mini-batches of 32 subgraphs; only the training loader shuffles.
BATCH_SIZE = 32
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [7]:
class DGCNN(torch.nn.Module):
    """DGCNN-style graph classifier used to score SEAL enclosing subgraphs.

    Four graph layers (``GATConv`` despite the ``gcn*`` attribute names)
    produce node embeddings that are concatenated, pooled to a fixed size
    with ``SortAggregation``, and passed through two 1-D convolutions and
    two linear layers. The output is a sigmoid probability per graph.

    Args:
        dim_in: Number of input features per node.
        hidden_channels: Output width of the first three graph layers.
            Defaults to 32.
        k: Number of nodes kept per graph by the sort pooling. Must be
            large enough for the max-pool + kernel-5 convolution to yield
            at least one position (``k >= 10``). Defaults to 30.

    With the defaults the derived sizes equal the previously hard-coded
    values (conv1 kernel/stride 97, linear1 input 352).
    """

    def __init__(self, dim_in, hidden_channels=32, k=30):
        super().__init__()
        conv1_out, conv2_out, conv2_kernel = 16, 32, 5

        self.gcn1 = GATConv(dim_in, hidden_channels)
        self.gcn2 = GATConv(hidden_channels, hidden_channels)
        self.gcn3 = GATConv(hidden_channels, hidden_channels)
        self.gcn4 = GATConv(hidden_channels, 1)
        self.global_pool = aggr.SortAggregation(k=k)

        # Width of one node after concatenating h1..h4 in forward().
        latent_dim = 3 * hidden_channels + 1
        # Kernel == stride == latent_dim, so conv1 emits exactly one
        # position per pooled node.
        self.conv1 = torch.nn.Conv1d(1, conv1_out, latent_dim, latent_dim)
        self.conv2 = torch.nn.Conv1d(conv1_out, conv2_out, conv2_kernel, 1)
        self.maxpool = torch.nn.MaxPool1d(2, 2)

        # Sequence length after MaxPool1d(2, 2) and then the conv2 kernel.
        pooled_len = (k - 2) // 2 + 1
        conv2_len = pooled_len - conv2_kernel + 1
        if conv2_len < 1:
            raise ValueError(
                f'k={k} is too small for the convolutional head; need k >= 10')

        self.linear1 = torch.nn.Linear(conv2_out * conv2_len, 128)
        self.linear2 = torch.nn.Linear(128, 1)
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, mat, edge_index, batch):
        """Return one link probability per graph in the batch.

        Args:
            mat: Node feature matrix of the batched graphs.
            edge_index: Edge index of the batched graphs.
            batch: Vector assigning each node to its graph, used by the
                sort pooling.

        Returns:
            Tensor with one sigmoid output per graph (last dimension 1).
        """
        h1 = self.gcn1(mat, edge_index).tanh()
        h2 = self.gcn2(h1, edge_index).tanh()
        h3 = self.gcn3(h2, edge_index).tanh()
        h4 = self.gcn4(h3, edge_index).tanh()
        h = torch.cat([h1, h2, h3, h4], dim=-1)
        h = self.global_pool(h, batch)
        # Add a channel axis so the pooled vector can be fed to Conv1d.
        h = h.view(h.size(0), 1, h.size(-1))
        h = self.conv1(h).relu()
        h = self.maxpool(h)
        h = self.conv2(h).relu()
        h = h.view(h.size(0), -1)
        h = self.linear1(h).relu()
        h = self.dropout(h)
        h = self.linear2(h).sigmoid()
        return h

In [None]:
class AMDGCNN(torch.nn.Module):
    """Variant of the DGCNN classifier whose fourth graph layer is ``AltGraphConv``.

    The constructor and ``forward`` accept the same arguments as ``DGCNN``
    so the model can be dropped into the same training/evaluation loops:
    ``dim_in`` defaults to the previously hard-coded 16 and ``batch``
    defaults to ``None`` (which is what the pooling received before).

    NOTE(review): ``AltGraphConv`` is not defined in the visible part of
    this notebook; it must be defined or imported before this class is
    instantiated. The conv1 kernel of 97 presumes it outputs a single
    channel (3 * 32 + 1) -- confirm.
    NOTE(review): the GAT layers are built with ``edge_dim=2`` but
    ``forward`` never passes edge attributes -- confirm this is intended.

    Args:
        dim_in: Number of input features per node. Defaults to 16.
    """

    def __init__(self, dim_in=16):
        super().__init__()
        self.gcn1 = torch_geometric.nn.GATConv(dim_in, 32, edge_dim=2)
        self.gcn2 = torch_geometric.nn.GATConv(32, 32, edge_dim=2)
        self.gcn3 = torch_geometric.nn.GATConv(32, 32, edge_dim=2)
        self.gcn4 = AltGraphConv()
        self.global_pool = torch_geometric.nn.aggr.SortAggregation(k=30)
        self.conv1 = torch.nn.Conv1d(1, 16, 97, 97)
        self.conv2 = torch.nn.Conv1d(16, 32, 5, 1)
        self.maxpool = torch.nn.MaxPool1d(2, 2)
        self.linear1 = torch.nn.Linear(352, 128)
        self.linear2 = torch.nn.Linear(128, 1)
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, mat, edge_index, batch=None):
        """Return one sigmoid output per graph.

        Args:
            mat: Node feature matrix.
            edge_index: Edge index of the (batched) graphs.
            batch: Optional vector assigning each node to its graph. It is
                forwarded to the sort pooling so that mini-batches are
                pooled per graph; ``None`` reproduces the previous call
                ``self.global_pool(h)``.
        """
        h1 = self.gcn1(mat, edge_index).tanh()
        h2 = self.gcn2(h1, edge_index).tanh()
        h3 = self.gcn3(h2, edge_index).tanh()
        h4 = self.gcn4(h3, edge_index).tanh()
        h = torch.cat([h1, h2, h3, h4], dim=-1)
        h = self.global_pool(h, batch)
        # Add a channel axis so the pooled vector can be fed to Conv1d.
        h = h.view(h.size(0), 1, h.size(-1))
        h = self.conv1(h).relu()
        h = self.maxpool(h)
        h = self.conv2(h).relu()
        h = h.view(h.size(0), -1)
        h = self.linear1(h).relu()
        h = self.dropout(h)
        h = self.linear2(h).sigmoid()
        return h

In [8]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Input width is read from the first training subgraph.
# Use model = AMDGCNN for benchmarking AMDGCNN
num_features = train_dataset[0].num_features
model = DGCNN(num_features).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# The model already applies a sigmoid, so the loss takes probabilities.
criterion = torch.nn.BCELoss()

In [9]:
def train():
    """Run one optimisation pass over ``train_loader``.

    Returns:
        float: Sum of the per-batch losses weighted by the number of graphs
        in each batch, divided by ``len(train_dataset)``.
    """
    model.train()
    running_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        probs = model(batch.x, batch.edge_index, batch.batch)
        targets = batch.y.to(torch.float)
        loss = criterion(probs.view(-1), targets)

        loss.backward()
        optimizer.step()

        # The criterion returns a batch mean; re-weight by batch size so
        # the final division yields a per-graph average.
        running_loss += loss.item() * batch.num_graphs
    return running_loss / len(train_dataset)

In [10]:
@torch.no_grad()
def test(loader):
    """Evaluate the global ``model`` on every batch of ``loader``.

    Args:
        loader: Iterable of batched graphs exposing ``x``, ``edge_index``,
            ``batch`` and ``y``.

    Returns:
        tuple: ``(roc_auc, average_precision)`` computed over all graphs.
    """
    model.eval()
    scores, labels = [], []
    for batch in loader:
        batch = batch.to(device)
        probs = model(batch.x, batch.edge_index, batch.batch)
        scores.append(probs.view(-1).cpu())
        labels.append(batch.y.view(-1).cpu().to(torch.float))

    # Concatenate once and reuse the tensors for both metrics.
    y_true = torch.cat(labels)
    y_score = torch.cat(scores)
    auc = roc_auc_score(y_true, y_score)
    ap = average_precision_score(y_true, y_score)
    return auc, ap

In [11]:
# Train for a fixed number of epochs, reporting validation metrics after each.
NUM_EPOCHS = 31
for epoch in range(NUM_EPOCHS):
    loss = train()
    val_auc, val_ap = test(val_loader)
    print(
        f'Epoch {epoch:>2} | Loss: {loss:.4f} | ValAUC: {val_auc:.4f} | Val AP: {val_ap:.4f}'
    )

Epoch  0 | Loss: 0.6131 | ValAUC: 0.8594 | Val AP: 0.8815
Epoch  1 | Loss: 0.4576 | ValAUC: 0.8919 | Val AP: 0.9123
Epoch  2 | Loss: 0.4140 | ValAUC: 0.9115 | Val AP: 0.9284
Epoch  3 | Loss: 0.3802 | ValAUC: 0.9232 | Val AP: 0.9354
Epoch  4 | Loss: 0.3537 | ValAUC: 0.9286 | Val AP: 0.9411
Epoch  5 | Loss: 0.3333 | ValAUC: 0.9264 | Val AP: 0.9415
Epoch  6 | Loss: 0.3213 | ValAUC: 0.9260 | Val AP: 0.9415
Epoch  7 | Loss: 0.3105 | ValAUC: 0.9251 | Val AP: 0.9411
Epoch  8 | Loss: 0.2989 | ValAUC: 0.9230 | Val AP: 0.9390
Epoch  9 | Loss: 0.2895 | ValAUC: 0.9197 | Val AP: 0.9366
Epoch 10 | Loss: 0.2787 | ValAUC: 0.9198 | Val AP: 0.9368
Epoch 11 | Loss: 0.2705 | ValAUC: 0.9215 | Val AP: 0.9383
Epoch 12 | Loss: 0.2646 | ValAUC: 0.9224 | Val AP: 0.9386
Epoch 13 | Loss: 0.2576 | ValAUC: 0.9238 | Val AP: 0.9403
Epoch 14 | Loss: 0.2481 | ValAUC: 0.9197 | Val AP: 0.9368
Epoch 15 | Loss: 0.2437 | ValAUC: 0.9191 | Val AP: 0.9362
Epoch 16 | Loss: 0.2415 | ValAUC: 0.9196 | Val AP: 0.9383
Epoch 17 | Los

In [12]:
# Final evaluation on the held-out test links.
test_metrics = test(test_loader)
test_auc, test_ap = test_metrics
print(f'Test AUC: {test_auc:.4f} | Test AP {test_ap:.4f}')

Test AUC: 0.8887 | Test AP 0.9123
