In [6]:
import os.path as osp

from tqdm import tqdm
import torch
from torch.nn import BCEWithLogitsLoss

import networkx as nx
import numpy as np
from torch_geometric.loader import DataLoader
from sklearn.metrics import roc_auc_score

from datasets import SEALDataset, SEALDatasetInMemory, SEALPredData
from models import DGCNN
from utils import read_temporary_graph_data, read_graph_from_edgefile

Super parameters:
 - `T`: split number of the time series.
 - `N`: number of predicted continuous time series.
 - `datasets`: the graph datasets, instance of `SEALDataset`.
 - `BS`: batch size.
 - `EPOCH`: number of epochs.
 - `LEARNING_RATE`: learning rate.

In [7]:
T = 400
N = 6
datasets = ['SuperUser',]
BS = 8192
EPOCH = 50
LEARN_RATE = 0.0001

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x1704f596310>

In [4]:
def train(model, optimizer, criterion, data_loader, device='cuda'):
    model.train()

    total_loss = 0
    for data in data_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out.view(-1), data.y.to(torch.float))
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * data.num_graphs

    return total_loss / len(data_loader.dataset)


@torch.no_grad()
def test(model, data_loader, device='cuda'):
    model.eval()

    y_pred, y_true = [], []
    for data in data_loader:
        data = data.to(device)
        logits = model(data.x, data.edge_index, data.batch)
        y_pred.append(logits.view(-1).cpu())
        y_true.append(data.y.view(-1).cpu().to(torch.float))

    return roc_auc_score(torch.cat(y_true), torch.cat(y_pred))


@torch.no_grad()
def predict(model, data_loader, threshold=0.5, device='cuda'):
    pred_edge_index = torch.zeros(0, 2).to(device)
    model.eval()
    
    for j, data in enumerate(data_loader):
        data = data.to(device)
        logits = model(data.x, data.edge_index, data.batch)
        logits = logits.view(-1).sigmoid()
        batch_mask = (logits > threshold).nonzero(as_tuple=True)[0]
        pred_edge_index = torch.cat((pred_edge_index, data.pred_edge[batch_mask]), dim=0)
    
    return pred_edge_index.type(torch.int).cpu()


def build_dataloader(dataset_name: str, t: int, additional_graphs=None, pred_from_add=True):
    n = 0 if additional_graphs is None else len(additional_graphs)
    print(f'\nLoading {dataset_name} Dataset with T={t} + {n}...')
    
    if n > 0 and pred_from_add:
        pred_graph = additional_graphs[-1]
    else:
        pred_graph = read_temporary_graph_data(
            f'data/SEALDataset/{dataset_name}/raw/{SEALDataset.info[dataset_name]["file"]}', 
            SEALDataset.info[dataset_name]["timespan"], t)[-1]
    assert pred_graph.number_of_nodes() > 30, f'The last temporal graph of {dataset_name} Dataset with T={t} is too small, consider to decrease T.'
    
    if dataset_name in ['WikiTalk', 'StackOverflow', 'SuperUser']:
        # For large dataset, we use SEALDataset just load last temporal graph data.
        # Cause splited data from last temporal graph's node feature size is much smaller than the whole last temporal graph.
        # So we should first load the whole last temporal graph data to get the max node feature size.
        pred_pos_data, pred_neg_data = SEALPredData.get_pos_neg_data(pred_graph)
        pred_data_list = SEALPredData.toSEAL_pred_datalist(pred_pos_data, pred_neg_data)
        print(f'Last temporal graph T={t} + {n} with {SEALPredData._max_z} node feature size.')
        
        # change the pred_idx range to choose the best training data.
        params = {'pred_idx': slice(-1, None), 'num_hops': 2, 'T': t, 'max_z': SEALPredData._max_z, 'additional_graphs': additional_graphs}
        train_dataset = SEALDataset('data/SEALDataset', dataset_name, 'train', **params)
        val_dataset = SEALDataset('data/SEALDataset', dataset_name, 'val', **params)
        test_dataset = SEALDataset('data/SEALDataset', dataset_name, 'test', **params)
        # Avoid splited data have more node feature size than the whole last temporal graph.
        if SEALPredData._max_z < train_dataset.num_features-1:
            pred_data_list = SEALPredData.toSEAL_pred_datalist(pred_pos_data, pred_neg_data, num_features=train_dataset.num_features)
            print(f'Update pred data node feature size to {train_dataset.num_features}.')
    else:
        # For small dataset, we use SEALDatasetInMemory to load all temporal graph data.
        params = {'num_hops': 2, 'T': t, 'additional_graphs': additional_graphs}
        train_dataset = SEALDatasetInMemory('data/SEALDataset', dataset_name, 'train', **params)
        val_dataset = SEALDatasetInMemory('data/SEALDataset', dataset_name, 'val', **params)
        test_dataset = SEALDatasetInMemory('data/SEALDataset', dataset_name, 'test', **params)
        # Obviously, the whole temporal graph's node feature size with the max node feature size.
        # So we use this num_features as the pred_data_list's node feature size.
        pred_pos_data, pred_neg_data = SEALPredData.get_pos_neg_data(pred_graph)
        pred_data_list = SEALPredData.toSEAL_pred_datalist(pred_pos_data, pred_neg_data, num_features=train_dataset.num_features)
        print(f'Last temporal graph T={t} + {n} with {SEALPredData._max_z} node feature size.')
    
    train_loader = DataLoader(train_dataset, batch_size=BS, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BS)
    test_loader = DataLoader(test_dataset, batch_size=BS)
    pred_loader = DataLoader(pred_data_list, batch_size=BS)
    return train_loader, val_loader, test_loader, pred_loader


def build_model(train_dataset, hidden_channels=32, num_layers=3):
    model = DGCNN(hidden_channels=hidden_channels, num_layers=num_layers, train_dataset=train_dataset).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARN_RATE)
    criterion = BCEWithLogitsLoss()
    return model, optimizer, criterion
    

def load_predict_graph(file_name: str):
    graph = read_graph_from_edgefile(file_name)
    # remove isolate node and self loop
    graph.remove_nodes_from(nx.isolates(graph))
    graph.remove_edges_from(nx.selfloop_edges(graph))
    return graph


def is_pred_edge_num_too_small(pred_edge_index, threshold=40):
    DG = nx.DiGraph()
    DG.add_edges_from(pred_edge_index.numpy())
    DG.remove_nodes_from(list(nx.isolates(DG)))
    DG.remove_edges_from(nx.selfloop_edges(DG))
    return DG.number_of_edges() < threshold


def is_pred_edge_num_too_many(pred_edge_index, threshold=7000):
    DG = nx.DiGraph()
    DG.add_edges_from(pred_edge_index.numpy())
    DG.remove_nodes_from(list(nx.isolates(DG)))
    DG.remove_edges_from(nx.selfloop_edges(DG))
    return DG.number_of_edges() > threshold


for dataset in datasets:
    predict_graphs = []
    t = T
    for n in range(1, N+1):
        # If current graph has been predicted, skip it.
        if osp.exists(f'data/SEALDataset/{dataset}/T{t}+{n}_pred_edge.pt'):
            print(f'{dataset}/T{t}+{n} has been predicted, just load.')
            predict_graphs.append(load_predict_graph(f'data/SEALDataset/{dataset}/T{t}+{n}_pred_edge.pt'))
            continue
        
        train_loader, val_loader, test_loader, pred_loader = build_dataloader(dataset, T, additional_graphs=predict_graphs)
        model, optimizer, criterion = build_model(train_loader.dataset)
        
        if osp.exists(f'data/models/SEAL_{dataset}_T{t}+{n-1}.pth'):
            model.load_state_dict(torch.load(f'data/models/SEAL_{dataset}_T{t}+{n-1}.pth'))
        else:
            print(f'\nTraining {dataset} Dataset with T={t}+{n-1}...')
            best_val_auc = test_auc = 0
            for epoch in tqdm(range(EPOCH)):
                loss = train(model, optimizer, criterion, train_loader)
                val_auc = test(model, val_loader)
                if val_auc > best_val_auc:
                    best_val_auc = val_auc
                    test_auc = test(model, test_loader)
                if epoch % 10 == 0 or epoch == EPOCH-1:
                    print(f'Epoch: [{epoch+1:02d}]/[{EPOCH:02d}], Loss: {loss:.4f}, Val: {val_auc:.4f}, Test: {test_auc:.4f}')
            print(f'Loss: {loss:.4f}, Val_best: {best_val_auc:.4f}, Test: {test_auc:.4f}')
            
            torch.save(model.state_dict(), f'data/models/SEAL_{dataset}_T{T}+{n-1}.pth')
        
        print(f'\nPredicting {dataset} Dataset with T={t} + {n}...')
        pred_edge_index = predict(model, pred_loader, threshold=0.5)
        if is_pred_edge_num_too_small(pred_edge_index):
            # Avoid no edge predicted, we will descend the threshold by 0.01 until we get some edges.
            for threshold in range(49, 0, -1):
                print(f'No edge predicted in {dataset} Dataset with T={t} + {n}, try threshold={threshold/100}...')
                pred_edge_index = predict(model, pred_loader, threshold=threshold/100)
                if not is_pred_edge_num_too_small(pred_edge_index):
                    break
        assert not is_pred_edge_num_too_small(pred_edge_index), f'No edge predicted in {dataset} Dataset with T={t} + {n}...'
        if is_pred_edge_num_too_many(pred_edge_index):
            # Avoid too many edges predicted, we will ascend the threshold by 0.01 until we get less edges.
            for threshold in range(51, 100, 1):
                print(f'Too many edges predicted in {dataset} Dataset with T={t} + {n}, try threshold={threshold/100}...')
                pred_edge_index = predict(model, pred_loader, threshold=threshold/100)
                if not is_pred_edge_num_too_many(pred_edge_index):
                    break
        assert not is_pred_edge_num_too_many(pred_edge_index), f'Too many edges predicted in {dataset} Dataset with T={t} + {n}...'
        torch.save(pred_edge_index, f'data/SEALDataset/{dataset}/T{t}+{n}_pred_edge.pt')
        # load current pred graph to predict_graphs
        print(f'load {dataset}/T{t}+{n}_pred_edge.pt to dataset')
        predict_graphs.append(load_predict_graph(f'data/SEALDataset/{dataset}/T{t}+{n}_pred_edge.pt'))

        torch.cuda.empty_cache()


Loading SuperUser Dataset with T=400 + 0...
Last temporal graph T=400 + 0 with 31 node feature size.

Training SuperUser Dataset with T=400+0...


  2%|▏         | 1/50 [00:04<03:29,  4.27s/it]

Epoch: [01]/[50], Loss: 0.6694, Val: 0.5032, Test: 0.4930


 22%|██▏       | 11/50 [00:06<00:10,  3.61it/s]

Epoch: [11]/[50], Loss: 0.6666, Val: 0.5967, Test: 0.5771


 44%|████▍     | 22/50 [00:09<00:05,  4.82it/s]

Epoch: [21]/[50], Loss: 0.6639, Val: 0.6734, Test: 0.6580


 62%|██████▏   | 31/50 [00:11<00:04,  4.23it/s]

Epoch: [31]/[50], Loss: 0.6615, Val: 0.6539, Test: 0.6867


 82%|████████▏ | 41/50 [00:13<00:01,  5.08it/s]

Epoch: [41]/[50], Loss: 0.6578, Val: 0.6143, Test: 0.6867


100%|██████████| 50/50 [00:15<00:00,  3.26it/s]

Epoch: [50]/[50], Loss: 0.6550, Val: 0.6106, Test: 0.6867
Loss: 0.6550, Val_best: 0.6846, Test: 0.6867

Predicting SuperUser Dataset with T=400 + 1...





load SuperUser/T400+1_pred_edge.pt to dataset

Loading SuperUser Dataset with T=400 + 1...
Last temporal graph T=400 + 1 with 57 node feature size.


Processing...
Embedding: 100%|██████████| 1/1 [00:13<00:00, 13.09s/it]
Done!



Training SuperUser Dataset with T=400+1...


  2%|▏         | 1/50 [00:00<00:38,  1.26it/s]

Epoch: [01]/[50], Loss: 0.6991, Val: 0.5232, Test: 0.4762


 22%|██▏       | 11/50 [00:04<00:14,  2.62it/s]

Epoch: [11]/[50], Loss: 0.6963, Val: 0.5020, Test: 0.4762


 42%|████▏     | 21/50 [00:08<00:11,  2.60it/s]

Epoch: [21]/[50], Loss: 0.6927, Val: 0.4850, Test: 0.4762


 62%|██████▏   | 31/50 [00:13<00:07,  2.49it/s]

Epoch: [31]/[50], Loss: 0.6891, Val: 0.4921, Test: 0.4762


 82%|████████▏ | 41/50 [00:17<00:03,  2.34it/s]

Epoch: [41]/[50], Loss: 0.6841, Val: 0.4687, Test: 0.4762


100%|██████████| 50/50 [00:20<00:00,  2.41it/s]

Epoch: [50]/[50], Loss: 0.6788, Val: 0.4668, Test: 0.4762
Loss: 0.6788, Val_best: 0.5232, Test: 0.4762

Predicting SuperUser Dataset with T=400 + 2...





Too many edges predicted in SuperUser Dataset with T=400 + 2, try threshold=0.51...
Too many edges predicted in SuperUser Dataset with T=400 + 2, try threshold=0.52...
Too many edges predicted in SuperUser Dataset with T=400 + 2, try threshold=0.53...
load SuperUser/T400+2_pred_edge.pt to dataset

Loading SuperUser Dataset with T=400 + 2...
Last temporal graph T=400 + 2 with 1 node feature size.


Processing...
Embedding: 100%|██████████| 1/1 [00:13<00:00, 13.08s/it]
Done!


Update pred data node feature size to 58.

Training SuperUser Dataset with T=400+2...


  2%|▏         | 1/50 [00:00<00:37,  1.32it/s]

Epoch: [01]/[50], Loss: 0.6895, Val: 0.4922, Test: 0.4688


 22%|██▏       | 11/50 [00:04<00:15,  2.53it/s]

Epoch: [11]/[50], Loss: 0.6864, Val: 0.4826, Test: 0.4688


 42%|████▏     | 21/50 [00:08<00:12,  2.39it/s]

Epoch: [21]/[50], Loss: 0.6835, Val: 0.4778, Test: 0.4688


 62%|██████▏   | 31/50 [00:13<00:08,  2.34it/s]

Epoch: [31]/[50], Loss: 0.6796, Val: 0.4897, Test: 0.4755


 82%|████████▏ | 41/50 [00:17<00:03,  2.35it/s]

Epoch: [41]/[50], Loss: 0.6755, Val: 0.4872, Test: 0.4755


100%|██████████| 50/50 [00:21<00:00,  2.34it/s]

Epoch: [50]/[50], Loss: 0.6709, Val: 0.4958, Test: 0.4755
Loss: 0.6709, Val_best: 0.5022, Test: 0.4755

Predicting SuperUser Dataset with T=400 + 3...
load SuperUser/T400+3_pred_edge.pt to dataset

Loading SuperUser Dataset with T=400 + 3...
Last temporal graph T=400 + 3 with 3 node feature size.



Processing...
Embedding: 100%|██████████| 1/1 [00:13<00:00, 13.35s/it]
Done!


Update pred data node feature size to 44.

Training SuperUser Dataset with T=400+3...


  2%|▏         | 1/50 [00:00<00:43,  1.13it/s]

Epoch: [01]/[50], Loss: 0.6910, Val: 0.4522, Test: 0.5217


 22%|██▏       | 11/50 [00:05<00:15,  2.44it/s]

Epoch: [11]/[50], Loss: 0.6876, Val: 0.5056, Test: 0.5691


 42%|████▏     | 21/50 [00:09<00:12,  2.37it/s]

Epoch: [21]/[50], Loss: 0.6836, Val: 0.5207, Test: 0.5778


 62%|██████▏   | 31/50 [00:13<00:08,  2.28it/s]

Epoch: [31]/[50], Loss: 0.6796, Val: 0.5271, Test: 0.5693


 82%|████████▏ | 41/50 [00:18<00:03,  2.40it/s]

Epoch: [41]/[50], Loss: 0.6743, Val: 0.5295, Test: 0.5712


100%|██████████| 50/50 [00:21<00:00,  2.28it/s]

Epoch: [50]/[50], Loss: 0.6683, Val: 0.5292, Test: 0.5711
Loss: 0.6683, Val_best: 0.5304, Test: 0.5711

Predicting SuperUser Dataset with T=400 + 4...
load SuperUser/T400+4_pred_edge.pt to dataset

Loading SuperUser Dataset with T=400 + 4...





Last temporal graph T=400 + 4 with 31 node feature size.


Processing...
Embedding: 100%|██████████| 1/1 [00:12<00:00, 12.88s/it]
Done!


Update pred data node feature size to 56.

Training SuperUser Dataset with T=400+4...


  2%|▏         | 1/50 [00:00<00:36,  1.34it/s]

Epoch: [01]/[50], Loss: 0.7067, Val: 0.5530, Test: 0.4988


 22%|██▏       | 11/50 [00:04<00:16,  2.38it/s]

Epoch: [11]/[50], Loss: 0.7016, Val: 0.5287, Test: 0.4869


 42%|████▏     | 21/50 [00:09<00:11,  2.56it/s]

Epoch: [21]/[50], Loss: 0.6971, Val: 0.4776, Test: 0.4869


 62%|██████▏   | 31/50 [00:13<00:07,  2.42it/s]

Epoch: [31]/[50], Loss: 0.6927, Val: 0.4662, Test: 0.4869


 82%|████████▏ | 41/50 [00:17<00:04,  2.00it/s]

Epoch: [41]/[50], Loss: 0.6873, Val: 0.4644, Test: 0.4869


100%|██████████| 50/50 [00:21<00:00,  2.37it/s]

Epoch: [50]/[50], Loss: 0.6824, Val: 0.4658, Test: 0.4869
Loss: 0.6824, Val_best: 0.5564, Test: 0.4869

Predicting SuperUser Dataset with T=400 + 5...
load SuperUser/T400+5_pred_edge.pt to dataset

Loading SuperUser Dataset with T=400 + 5...





Last temporal graph T=400 + 5 with 90 node feature size.


Processing...
Embedding: 100%|██████████| 1/1 [00:12<00:00, 12.88s/it]
Done!



Training SuperUser Dataset with T=400+5...


  2%|▏         | 1/50 [00:00<00:44,  1.10it/s]

Epoch: [01]/[50], Loss: 0.7051, Val: 0.4381, Test: 0.4321


 22%|██▏       | 11/50 [00:05<00:17,  2.20it/s]

Epoch: [11]/[50], Loss: 0.6999, Val: 0.4592, Test: 0.4500


 42%|████▏     | 21/50 [00:09<00:13,  2.13it/s]

Epoch: [21]/[50], Loss: 0.6953, Val: 0.4906, Test: 0.4862


 62%|██████▏   | 31/50 [00:14<00:08,  2.30it/s]

Epoch: [31]/[50], Loss: 0.6901, Val: 0.4985, Test: 0.4896


 82%|████████▏ | 41/50 [00:18<00:03,  2.46it/s]

Epoch: [41]/[50], Loss: 0.6838, Val: 0.5037, Test: 0.4975


100%|██████████| 50/50 [00:22<00:00,  2.18it/s]

Epoch: [50]/[50], Loss: 0.6773, Val: 0.5081, Test: 0.5021
Loss: 0.6773, Val_best: 0.5081, Test: 0.5021

Predicting SuperUser Dataset with T=400 + 6...
load SuperUser/T400+6_pred_edge.pt to dataset





In [5]:
predict_graphs

[<networkx.classes.digraph.DiGraph at 0x1700c5d6640>,
 <networkx.classes.digraph.DiGraph at 0x1704dd1bd30>,
 <networkx.classes.digraph.DiGraph at 0x17017d679a0>,
 <networkx.classes.digraph.DiGraph at 0x170fe7680a0>,
 <networkx.classes.digraph.DiGraph at 0x1725a368130>,
 <networkx.classes.digraph.DiGraph at 0x170fe769100>]

In [None]:
import networkx as nx
from matplotlib import pyplot as plt

for g in predict_graphs:    
    options = {
        'node_color': '#1f78b4',
        'node_size': 500,
        'width': 3,
        'arrowstyle': '-|>',
        'arrowsize': 12,
    }
    ax = plt.figure(figsize=(10, 10), dpi=100)
    nx.draw_networkx(g, pos=nx.spring_layout(g), **options)
    plt.show()