In [1]:
import math
import time, os
import logging
from config import Config
import numpy as np
import torch
import out_manager as om
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, average_precision_score
from torch.nn import BCEWithLogitsLoss, Conv1d, MaxPool1d, ModuleList
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MLP, GCNConv, SortAggregation

In [2]:
# Build the experiment configuration and resolve the output directory for it.
config = Config()
out_dir = om.get_existing_out_dir(config)

# Store the configuration in the output directory, then hand the log-file
# path to the project's logging setup (presumably attaches a file handler —
# see out_manager.setup_logging).
om.save_config(config, out_dir)
log_path = os.path.join(out_dir, "ssseal_log.txt")
om.setup_logging(log_path)

# Seed both PyTorch and NumPy with the configured seed, in that order.
seed = config.seed
for seed_rng in (torch.manual_seed, np.random.seed):
    seed_rng(seed)

# Device on which the model and the batches are placed.
device = config.device

Configuration saved to: ./out\Cora_k60_hop3_Batch\config.json


In [3]:
# Mini-batch size shared by the train / validation / test loaders.
BATCH_SIZE = 32


def load_split(split):
    """Load one pre-computed split ('train', 'val' or 'test') from disk.

    The file name is derived from the dataset name, the score-sampler
    settings (k_min, num_hops) and the version tag stored in `config`.
    """
    path = (f'./data/{config.dataset}/split/ssseal_{split}_data_'
            f'k{config.scoresampler.k_min}_h{config.scoresampler.num_hops}_'
            f'{config.version}.pt')
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code — only load split files produced by this project.
    return torch.load(path)


train_data = load_split('train')
val_data = load_split('val')
test_data = load_split('test')

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)

In [4]:
class DGCNN(torch.nn.Module):
    """Graph classifier: stacked GNN layers, sort pooling, 1-D convs and an MLP.

    Args:
        hidden_dim: Width of every hidden GNN layer.
        num_layers: Number of hidden GNN layers (must be >= 1). One extra
            GNN layer with a single output channel is always appended.
        GNN: Message-passing layer class, called as ``GNN(in_dim, out_dim)``.
        k: Number of nodes kept per graph by the sort pooling. A value in
            (0, 1) is interpreted as a percentile of the graph sizes found in
            ``dataset``; the resulting node count is clamped to at least 10.
            A value >= 1 is used as an absolute node count and must be >= 10.
        dataset: Sequence of graphs used to resolve a percentile ``k`` and,
            when ``num_features`` is not given, the input feature width.
            Defaults to the notebook-level ``train_data``.
        num_features: Input node-feature width. Defaults to
            ``dataset[0].x.size(1)``.

    Raises:
        ValueError: If ``k`` is not positive, if ``num_layers`` < 1, or if the
            resolved ``k`` is too small for the convolution / pooling stack.
    """

    def __init__(self, hidden_dim, num_layers, GNN=GCNConv, k=0.6,
                 dataset=None, num_features=None):
        super().__init__()

        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        if k <= 0:
            # A non-positive percentile would wrap around to a negative list
            # index below and silently select an arbitrary graph size.
            raise ValueError(f"k must be positive, got {k}")

        def _graphs():
            # Resolve lazily so the global is only touched when really needed.
            return train_data if dataset is None else dataset

        if k < 1:  # Transform percentile to number.
            num_nodes = sorted(data.num_nodes for data in _graphs())
            k = num_nodes[int(math.ceil(k * len(num_nodes))) - 1]
            k = max(10, k)
        k = int(k)

        if num_features is None:
            num_features = _graphs()[0].x.size(1)

        conv1d_channels = [16, 32]
        # Width of the concatenated node embedding produced in forward():
        # num_layers hidden outputs plus the final 1-channel output.
        total_latent_dim = hidden_dim * num_layers + 1
        conv1d_kws = [total_latent_dim, 5]

        # Sequence length after conv1 (one step per kept node) is k; after the
        # max-pool it is int((k - 2) / 2 + 1); conv2 then needs at least
        # conv1d_kws[1] positions to produce any output.
        pooled_len = int((k - 2) / 2 + 1)
        conv2_len = pooled_len - conv1d_kws[1] + 1
        if conv2_len < 1:
            raise ValueError(
                f"k={k} is too small for the conv/pool stack; "
                f"need k >= {2 * conv1d_kws[1]}")

        # NOTE: parameterised modules are created in the same order as before
        # so that seeded weight initialisation is unchanged.
        self.convs = ModuleList()
        self.convs.append(GNN(num_features, hidden_dim))
        for _ in range(0, num_layers - 1):
            self.convs.append(GNN(hidden_dim, hidden_dim))
        self.convs.append(GNN(hidden_dim, 1))

        # Kernel size == stride == total_latent_dim: one conv step per node.
        self.conv1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0],
                            conv1d_kws[0])
        self.pool = SortAggregation(k)
        self.maxpool1d = MaxPool1d(2, 2)
        self.conv2 = Conv1d(conv1d_channels[0], conv1d_channels[1],
                            conv1d_kws[1], 1)
        dense_dim = conv2_len * conv1d_channels[1]
        self.mlp = MLP([dense_dim, 128, 1], dropout=0.5, norm=None)

    def forward(self, x, edge_index, batch):
        """Return the MLP output (last dimension of size 1) for each graph.

        Args:
            x: Node feature matrix.
            edge_index: Edge index tensor passed to every GNN layer.
            batch: Node-to-graph assignment vector used by the pooling.
        """
        xs = [x]
        for conv in self.convs:
            xs += [conv(xs[-1], edge_index).tanh()]
        # Concatenate the outputs of all GNN layers (the raw input is dropped).
        x = torch.cat(xs[1:], dim=-1)

        # Global pooling.
        x = self.pool(x, batch)
        x = x.unsqueeze(1)  # [num_graphs, 1, k * total_latent_dim]
        x = self.conv1(x).relu()
        x = self.maxpool1d(x)
        x = self.conv2(x).relu()
        x = x.view(x.size(0), -1)  # [num_graphs, dense_dim]

        return self.mlp(x)


In [5]:
# Instantiate the network from the `config.ssseal` hyper-parameters and move
# it to the configured device.
model = DGCNN(
    hidden_dim=config.ssseal.hidden_dim,
    num_layers=config.ssseal.num_layers,
    k=config.ssseal.k,
).to(device)

# Adam over all model parameters; the loss expects raw logits.
optimizer = torch.optim.Adam(model.parameters(), lr=config.ssseal.lr)
loss_fn = BCEWithLogitsLoss()

In [6]:
def train():
    """Run one optimisation pass over `train_loader`.

    Returns:
        The sum of per-batch losses, each weighted by the number of graphs in
        its batch, divided by `len(train_data)`.
    """
    model.train()

    running_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        # Cast the index tensors to int64 before handing them to the model.
        batch.batch = batch.batch.long()
        batch.edge_index = batch.edge_index.long()

        optimizer.zero_grad()
        logits = model(batch.x, batch.edge_index, batch.batch)
        targets = batch.y.to(torch.float)
        batch_loss = loss_fn(logits.view(-1), targets)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item() * batch.num_graphs

    return running_loss / len(train_data)

In [7]:
@torch.no_grad()
def test(loader):
    """Score every graph yielded by `loader` with the model in eval mode.

    Args:
        loader: Iterable of batches exposing `x`, `edge_index`, `batch`, `y`.

    Returns:
        Tuple `(roc_auc, average_precision)` computed from the raw logits
        against the labels collected over the whole loader.
    """
    model.eval()

    scores, labels = [], []
    for batch in loader:
        batch = batch.to(device)
        # Cast the index tensors to int64 before handing them to the model.
        batch.batch = batch.batch.long()
        batch.edge_index = batch.edge_index.long()

        logits = model(batch.x, batch.edge_index, batch.batch)
        scores.append(logits.view(-1).cpu())
        labels.append(batch.y.view(-1).cpu().to(torch.float))

    all_labels = torch.cat(labels)
    all_scores = torch.cat(scores)
    auc = roc_auc_score(all_labels, all_scores)
    ap = average_precision_score(all_labels, all_scores)
    return auc, ap

In [8]:
# Wall-clock duration of each epoch (training plus the three evaluations).
times = []
# Model selection: report the test metrics of the epoch with the best
# validation AUC seen so far.
best_val_auc = final_test_auc = final_test_ap = 0

for epoch in range(1, 1 + config.ssseal.epochs):
    # perf_counter is monotonic, so durations are immune to clock adjustments.
    start = time.perf_counter()
    loss = train()
    train_auc, train_ap = test(train_loader)
    val_auc, val_ap = test(val_loader)
    test_auc, test_ap = test(test_loader)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        final_test_auc = test_auc
        final_test_ap = test_ap

    logging.info(f'Epoch: {epoch:03d}, Loss: {loss:.4f} '
                 f'Train_AUC: {train_auc:.4f}, Train_AP: {train_ap:.4f} '
                 f'Val_AUC: {val_auc:.4f}, Val_AP: {val_ap:.4f} '
                 f'Test_AUC: {test_auc:.4f}, Test_AP: {test_ap:.4f}')
    times.append(time.perf_counter() - start)

# Guard against an empty run (epochs == 0): median of an empty tensor raises.
median_time = torch.tensor(times).median() if times else float('nan')
# The separator after the time was missing, which glued the two fields
# together in the log ("...0.1234sFinal Test AUC...").
logging.info(f'Median time per epoch: {median_time:.4f}s, '
             f'Final Test AUC: {final_test_auc:.4f}, AP: {final_test_ap:.4f}')

Epoch: 001, Loss: 0.6455 Train_AUC: 0.7464, Train_AP: 0.7632 Val_AUC: 0.7251, Val_AP: 0.7461 Test_AUC: 0.7681, Test_AP: 0.7854
Epoch: 002, Loss: 0.5748 Train_AUC: 0.7791, Train_AP: 0.7957 Val_AUC: 0.7603, Val_AP: 0.7805 Test_AUC: 0.8070, Test_AP: 0.8196
Epoch: 003, Loss: 0.5556 Train_AUC: 0.8002, Train_AP: 0.8173 Val_AUC: 0.7849, Val_AP: 0.8085 Test_AUC: 0.8185, Test_AP: 0.8365
Epoch: 004, Loss: 0.5417 Train_AUC: 0.8044, Train_AP: 0.8234 Val_AUC: 0.7823, Val_AP: 0.8054 Test_AUC: 0.8211, Test_AP: 0.8399
Epoch: 005, Loss: 0.5351 Train_AUC: 0.8095, Train_AP: 0.8305 Val_AUC: 0.7906, Val_AP: 0.8141 Test_AUC: 0.8241, Test_AP: 0.8448
Epoch: 006, Loss: 0.5289 Train_AUC: 0.8167, Train_AP: 0.8396 Val_AUC: 0.8029, Val_AP: 0.8261 Test_AUC: 0.8309, Test_AP: 0.8529
Epoch: 007, Loss: 0.5211 Train_AUC: 0.8288, Train_AP: 0.8474 Val_AUC: 0.8184, Val_AP: 0.8343 Test_AUC: 0.8393, Test_AP: 0.8570
Epoch: 008, Loss: 0.5148 Train_AUC: 0.8319, Train_AP: 0.8529 Val_AUC: 0.8243, Val_AP: 0.8401 Test_AUC: 0.8413, 

KeyboardInterrupt: 