In [12]:
##########################################################################################
# Some of the code is adapted from:
# https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py
##########################################################################################

In [13]:
import os
import os.path as osp

import torch
from torch.nn import Linear

# from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TGNMemory, TransformerConv
from torch_geometric.nn.models.tgn import (LastNeighborLoader, IdentityMessage,LastAggregator)
from torch_geometric import *
from torch_geometric.utils import negative_sampling
from sklearn.metrics import roc_auc_score
from sklearn.metrics import confusion_matrix
from tqdm import tqdm
import networkx as nx
import numpy as np
import math
import copy
import re
import time

In [14]:
# Pick CUDA when it is available ...
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# ... but the assignment below overrides that choice, so everything in this
# notebook runs on the CPU. Remove the next line to use the GPU.
device = 'cpu'
# msg structure:      [src_node_feature,edge_attr,dst_node_feature]
data_dir = "../data/camflow250"
# NOTE: torch.load unpickles arbitrary Python objects; only load files that
# come from a trusted source.
train_data = torch.load(f"{data_dir}/graph_0.TemporalData")

# Capacity used for every per-node structure below (neighbor loader here; the
# TGN memory and the `assoc` lookup vector in later cells).
max_node_num = 5045000
# Inclusive bounds for negative-destination sampling. test_new() samples with
# torch.randint(min_dst_idx, max_dst_idx + 1, ...), whose upper bound is
# exclusive, so max_dst_idx must be the last valid node id (max_node_num - 1);
# using max_node_num itself could produce an out-of-range index.
min_dst_idx, max_dst_idx = 0, max_node_num - 1
neighbor_loader = LastNeighborLoader(max_node_num, size=5, device=device)

In [15]:
class GraphAttentionEmbedding(torch.nn.Module):
    """Attention-based embedding layer that mixes node states with edge context.

    The edge features handed to the convolution are the encoded relative time
    of each edge concatenated with the raw message of that edge.

    Args:
        in_channels: size of the incoming node features.
        out_channels: requested embedding size; the convolution is built with
            2 heads of ``out_channels // 2`` channels each (presumably so the
            concatenated heads give ``out_channels`` features -- confirm
            against the TransformerConv documentation).
        msg_dim: size of the raw message attached to each edge.
        time_enc: time-encoder module; must expose ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, msg_dim, time_enc):
        super().__init__()
        self.time_enc = time_enc
        # Edge features = time encoding followed by the raw message.
        self.conv = TransformerConv(
            in_channels,
            out_channels // 2,
            heads=2,
            dropout=0.1,
            edge_dim=msg_dim + time_enc.out_channels,
        )

    def forward(self, x, last_update, edge_index, t, msg):
        """Embed nodes ``x`` using the edges in ``edge_index``.

        ``last_update`` is indexed by the first row of ``edge_index``; ``t``
        and ``msg`` hold one entry per edge.
        """
        # Time gap between the source node's last update and the edge time.
        elapsed = last_update[edge_index[0]] - t
        elapsed_enc = self.time_enc(elapsed.to(x.dtype))
        edge_features = torch.cat((elapsed_enc, msg), dim=-1)
        return self.conv(x, edge_index, edge_features)

In [16]:
class LinkPredictor(torch.nn.Module):
    """Score a (source, destination) embedding pair with a single logit.

    Both embeddings are projected by their own linear layer, summed, passed
    through ReLU and reduced to one output value per pair.

    Args:
        in_channels: size of each node embedding.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.lin_src = Linear(in_channels, in_channels)
        self.lin_dst = Linear(in_channels, in_channels)
        self.lin_final = Linear(in_channels, 1)

    def forward(self, z_src, z_dst):
        """Return the raw (pre-sigmoid) score for each source/destination pair."""
        combined = self.lin_src(z_src) + self.lin_dst(z_dst)
        return self.lin_final(torch.relu(combined))

In [17]:
# Memory, time-encoding and embedding widths all share one value.
memory_dim = time_dim = embedding_dim = 200

# Width of the raw per-event message stored in the training graph.
raw_msg_dim = train_data.msg.size(-1)

memory = TGNMemory(
    max_node_num,
    raw_msg_dim,
    memory_dim,
    time_dim,
    message_module=IdentityMessage(raw_msg_dim, memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)

# The embedding layer reuses the memory module's time encoder.
gnn = GraphAttentionEmbedding(
    in_channels=memory_dim,
    out_channels=embedding_dim,
    msg_dim=raw_msg_dim,
    time_enc=memory.time_enc,
).to(device)

link_pred = LinkPredictor(in_channels=embedding_dim).to(device)

# A set union is used so that parameters reachable from more than one module
# (the time encoder is handed to both `memory` and `gnn`) are passed to the
# optimizer only once.
trainable_params = (
    set(memory.parameters())
    | set(gnn.parameters())
    | set(link_pred.parameters())
)
optimizer = torch.optim.Adam(trainable_params, lr=0.0001)
criterion = torch.nn.BCEWithLogitsLoss()

# Helper vector to map global node indices to local ones.
assoc = torch.empty(max_node_num, dtype=torch.long, device=device)
# Number of events processed per step in test_new().
BATCH=1024

In [18]:
@torch.no_grad()
def test_new(inference_data):
    """Replay one temporal graph through the TGN and measure its link-prediction loss.

    The memory and the neighbor loader are reset first, then the events of
    ``inference_data`` are processed in sequential batches of ``BATCH``. For
    every event the true destination and one randomly sampled negative
    destination are scored; after scoring, the batch is written into the
    memory and the neighbor loader.

    The modules are not switched to eval mode here; the notebook does that
    right after loading them.

    Args:
        inference_data: temporal graph exposing ``seq_batches``, ``t``,
            ``msg`` and ``num_events``.

    Returns:
        Tuple ``(ap, auc, pos_out, neg_out, loss)``:
        * ``ap`` and ``auc`` are always NaN -- they are not computed here and
          are kept only so callers can still unpack five values;
        * ``pos_out`` / ``neg_out`` are the sigmoid scores of the LAST batch
          only, moved to the CPU;
        * ``loss`` is the event-weighted mean of the per-batch losses.

    Raises:
        ValueError: if ``inference_data`` contains no events.
    """
    if inference_data.num_events == 0:
        raise ValueError("test_new() received a graph without events")

    total_loss = 0
    memory.reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.
    torch.manual_seed(12345)  # Ensure deterministic sampling across epochs.

    # torch.randint's upper bound is exclusive. Sampled ids are used to index
    # `assoc`, so the bound must never exceed its length.
    neg_upper = min(max_dst_idx + 1, assoc.size(0))

    # test_loader = TemporalDataLoader(inference_data, batch_size=BATCH)
    for batch in inference_data.seq_batches(batch_size=BATCH):
        batch = batch.to(device)
        src, pos_dst, t, msg = batch.src, batch.dst, batch.t, batch.msg
        neg_dst = torch.randint(min_dst_idx, neg_upper, (src.size(0),),
                                dtype=torch.long, device=device)

        # Collect every node touched by this batch plus its stored neighbors,
        # and remember where each global id sits in the local tensors.
        n_id = torch.cat([src, pos_dst, neg_dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        z, last_update = memory(n_id)

        z = gnn(z, last_update, edge_index, inference_data.t[e_id], inference_data.msg[e_id])

        pos_out = link_pred(z[assoc[src]], z[assoc[pos_dst]])
        neg_out = link_pred(z[assoc[src]], z[assoc[neg_dst]])

        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))
        total_loss += float(loss) * batch.num_events

        # Only after scoring do the batch's events become visible to the model.
        memory.update_state(src, pos_dst, t, msg)
        neighbor_loader.insert(src, pos_dst)

    loss = total_loss / inference_data.num_events
    not_computed = float('nan')
    return (not_computed, not_computed,
            pos_out.sigmoid().cpu(), neg_out.sigmoid().cpu(), loss)

In [19]:
# Restore the trained modules and the neighbor loader saved as one tuple.
# NOTE: torch.load unpickles arbitrary Python objects; only load checkpoints
# that come from a trusted source.
# map_location remaps the stored tensors onto the device chosen above, so a
# checkpoint written on a GPU machine can still be opened on a CPU-only one.
m = torch.load(f"{data_dir}/model_saved.pt", map_location=device)
memory, gnn, link_pred, neighbor_loader = m
# Inference only: put every module in evaluation mode.
memory.eval()
gnn.eval()
link_pred.eval()


LinkPredictor(
  (lin_src): Linear(in_features=200, out_features=200, bias=True)
  (lin_dst): Linear(in_features=200, out_features=200, bias=True)
  (lin_final): Linear(in_features=200, out_features=1, bias=True)
)

In [23]:
# Validation pass: compute the per-graph loss of the validation graphs (cached
# on disk) and derive the detection threshold from those losses.
val_cache_path = f"{data_dir}/val_ans_old.pt"

if not os.path.exists(val_cache_path):
    graph_label = []
    all_loss = []
    start = time.time()
    # Three groups of validation graphs; every one of them is labelled 0.
    for graph_ids in (range(0, 10), range(85, 95), range(50, 60)):
        for i in tqdm(graph_ids):
            path = f"{data_dir}/graph_" + str(i) + ".TemporalData"
            test_graph = torch.load(path)
            test_ap, test_auc, pos_out_test, neg_out_test, loss_test = test_new(test_graph)
            print(f'Graph:{i}, Loss: {loss_test:.4f}')
            all_loss.append(loss_test)
            graph_label.append(0)

    print(f"test cost time:{time.time() - start}")

    val_ans_old = [ all_loss, graph_label]
    torch.save(val_ans_old, val_cache_path)

# Always read the losses back from the cache, so a first run and a cached
# re-run take the same path and `threshold` is defined either way.
val_ans = torch.load(val_cache_path)
loss_list = list(val_ans[0])
# threshold = max(loss_list)
# The 100th percentile is the maximum validation loss; lower the percentile
# to get a stricter threshold.
threshold = np.percentile(loss_list, 100)
print(loss_list)
print(f'Threshold: {threshold}')

[1.405900239944458, 1.3365707235150044, 1.3183009507393084, 1.3128296777532291, 1.3322297561500245, 1.3351100542706664, 1.3225050029356387, 1.3415021725618532, 1.3274185846294593, 1.3398106955165212, 1.3322762468697287, 1.3338331351276786, 1.3332143187656091, 1.3255792100753836, 1.3309522127227822, 1.3295645526742925, 1.3285021885428487, 1.337964107117759, 1.3322496228240153, 1.324715163767951, 1.3205544175068284, 1.3293767518624418, 1.3377300664416114, 1.3333816487297652, 1.3268602331870873, 1.3318964754859903, 1.3309308142580822, 1.3267603535192587, 1.326288911810971, 1.335318206003018]
Threshold: 1.405900239944458


In [24]:
def classifier_evaluation(y_test, y_test_pred):
    """Print and return binary classification metrics for hard predictions.

    Args:
        y_test: ground-truth labels (0/1 or False/True), one per sample.
        y_test_pred: predicted labels, same length as ``y_test``.

    Returns:
        Tuple ``(precision, recall, fscore, accuracy, auc_val)``. A ratio
        whose denominator is zero (e.g. precision when nothing was predicted
        positive) is reported as 0.0 instead of NaN. ``auc_val`` is NaN when
        ``roc_auc_score`` rejects the input with a ValueError.
    """
    tn, fp, fn, tp =confusion_matrix(y_test, y_test_pred,labels=[False, True]).ravel()
    print('tn:',tn)
    print('fp:',fp)
    print('fn:',fn)
    print('tp:',tp)

    def _ratio(numerator, denominator):
        # Guard against 0/0, which would otherwise produce NaN and a warning.
        return numerator / denominator if denominator else 0.0

    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    accuracy = _ratio(tp + tn, tp + tn + fp + fn)
    fscore = _ratio(2 * (precision * recall), precision + recall)
    try:
        auc_val = roc_auc_score(y_test, y_test_pred)
    except ValueError as err:
        # Raised e.g. when y_test contains a single class; AUC is undefined.
        print("auc_val could not be computed:", err)
        auc_val = float('nan')
    print("precision:",precision)
    print("recall:",recall)
    print("fscore:",fscore)
    print("accuracy:",accuracy)
    print("auc_val:",auc_val)
    return precision,recall,fscore,accuracy,auc_val


In [25]:
# Test pass: compute the per-graph loss of the test graphs (cached on disk),
# then classify each graph by comparing its loss with `threshold`.
test_cache_path = f"{data_dir}/test_ans_old.pt"

if not os.path.exists(test_cache_path):
    graph_label = []
    all_loss = []
    start = time.time()
    # (graph ids, label) groups: graphs 30-95 are labelled 0, graphs 129-198
    # are labelled 1. Graphs 211-222 are not evaluated.
    # NOTE(review): range(30, 96) contains the validation graphs 50-59 and
    # 85-94 that the threshold is derived from -- confirm this overlap is
    # intended, since those graphs cannot exceed a max-based threshold.
    for graph_ids, graph_class in ((range(30, 96), 0), (range(129, 199), 1)):
        for i in tqdm(graph_ids):
            path = f"{data_dir}/graph_" + str(i) + ".TemporalData"
            test_graph = torch.load(path)
            test_ap, test_auc, pos_out_test, neg_out_test, loss_test = test_new(test_graph)
            print(f'Graph:{i}, Loss: {loss_test:.4f}')
            all_loss.append(loss_test)
            graph_label.append(graph_class)

    print(f"test cost time:{time.time() - start}")
    test_ans_old = [all_loss, graph_label]
    torch.save(test_ans_old, test_cache_path)

# Always read the results back from the cache so the evaluation below also
# runs on the very first execution, not only on a re-run.
test_ans = torch.load(test_cache_path)

if "threshold" not in globals():
    # The threshold comes from the validation cell; without it there is
    # nothing to compare the losses against.
    print("threshold is not defined: run the validation cell, then re-run this cell")
else:
    labels = []
    preds = []

    # `index` is the 1-based position of the graph within the cached lists.
    for index, (temp_loss, label) in enumerate(zip(test_ans[0], test_ans[1]), start=1):
        # A loss above the threshold is predicted as class 1.
        pred = 1 if temp_loss > threshold else 0

        labels.append(label)
        preds.append(pred)

        # If prediction is incorrect, print the information of the tested graph
        if pred != label:
            print(f"{index=} {temp_loss=} {label=} {pred=} {pred == label}")

    classifier_evaluation(labels, preds)

tn: 66
fp: 0
fn: 0
tp: 70
precision: 1.0
recall: 1.0
fscore: 1.0
accuracy: 1.0
auc_val: 1.0
