In [1]:
import torch
import tqdm
import math
import numpy as np

from args import *
from dataset import *
from model import CEALNetwork, GCNNetwork, load_model

import os.path as osp
from train import make_data_loader, train_step, test_evaluations

from module.madgap import *

  from .autonotebook import tqdm as notebook_tqdm


Prepare evaluation

In [2]:
train_dataset, validation_dataset, test_dataset = make_dataset()
train_loader, val_loader, test_loader = make_data_loader(train_dataset, validation_dataset, test_dataset)

model, model_data = load_model(osp.join(args["result_path"], "CEAL/1716963786873904"))
model.eval()

print("ok")

ok


Calculate the MAD

In [3]:
mad_total = 0.0
with torch.no_grad():
    for i, batch_data in enumerate(train_loader):
        batch_data.to(get_device())
        node_embeddings = model(batch_data, node_embedding=True)

        print(node_embeddings.shape)
        
        in_arr = node_embeddings.cpu().detach().numpy()

        num_nodes = batch_data.x.shape[0]
        adj = torch.zeros((num_nodes, num_nodes))
        adj[batch_data.edge_index[0], batch_data.edge_index[1]] = 1
        mask_arr = adj.numpy()

        mad_single = mad_value(in_arr, mask_arr)
        mad_total += mad_single
        torch.cuda.empty_cache()
        # print(i, "single MAD:", mad_single)

print("MAD:", mad_total / (i + 1))

torch.Size([5449, 100])
torch.Size([5946, 100])
torch.Size([5310, 100])
torch.Size([5300, 100])
torch.Size([5393, 100])
torch.Size([5134, 100])
torch.Size([5590, 100])
torch.Size([5890, 100])
torch.Size([5679, 100])
torch.Size([5417, 100])
torch.Size([5136, 100])
torch.Size([4932, 100])
torch.Size([6012, 100])


KeyboardInterrupt: 

Calculate the reachable nodes

In [5]:
def reachable_nodes(num_nodes, edge_index, num_layers=1, ret_mat=False, device=get_device()):
    adj = torch.tensor(np.zeros(num_nodes), dtype=int).to(device)
    adj[edge_index[0], edge_index[1]] = 1
    if num_layers == 1:
        return torch.sum(adj, dim=0).to(device) if ret_mat is False else adj

    res_adj = adj

    for i in range(0, num_layers - 1):
        columns = []
        for column in range(0, num_nodes):
            # get column data
            data = torch.where(res_adj[column] > 0)[0].to(device)
            # get all node index data
            data = torch.where(torch.sum(adj[data], dim=0) > 0)[0].to(device)

            line = torch.zeros(num_nodes, dtype=int).to(device)
            line[data] = 1
            columns.append(line)
        res_adj = torch.stack(columns).to(device)

    return torch.sum(res_adj, dim=0).to(device) if ret_mat is False else res_adj


def get_reachable_nodes(total_graphs, dataloader, num_layers, device=get_device()):
    predict_epochs = math.ceil(total_graphs / args["batch_size"])
    pbar = tqdm(total=predict_epochs)
    pbar.set_description("Progress")

    total_num_nodes = 0
    total_reachable_nodes = 0

    for i, batch_data in enumerate(dataloader):
        num_nodes = batch_data.num_nodes
        edge_index = batch_data.edge_index.to(device)

        batch_reachable_nodes = torch.sum(reachable_nodes(num_nodes, edge_index, num_layers, ret_mat=False, device=device), dim=0)
        total_reachable_nodes += batch_reachable_nodes

        total_num_nodes += num_nodes
        
        torch.cuda.empty_cache()
        pbar.update(1)
    pbar.close()

    avg_reachable_nodes = (total_reachable_nodes / total_num_nodes).item()
    avg_nodes_on_graph = total_num_nodes / total_graphs
    print(f"Layer {num_layers}")
    print("Average reachable nodes:", avg_reachable_nodes)
    print("Average nodes on a graph:", avg_nodes_on_graph)
    print(f"=================================")

    return avg_reachable_nodes, avg_nodes_on_graph

In [6]:
get_reachable_nodes(len(train_dataset), train_loader, 1)



AttributeError: 'int' object has no attribute 'shape'

Calculate the shared reachable nodes

In [103]:
# def get_shared_reachable_nodes(total_graphs, dataloader, num_layers, device=get_device()):

total_graphs = len(train_dataset)
dataloader = train_loader
num_layers = 2
device = get_device()


predict_epochs = math.ceil(total_graphs / args["batch_size"])
pbar = tqdm(total=predict_epochs)
pbar.set_description("Progress")

total_num_nodes = 0
total_reachable_nodes = 0

for i, data in enumerate(train_dataset):
    data = train_dataset[1]
    num_nodes = data.num_nodes
    edge_index = data.edge_index.to(device)

    batch_shared_reachable_nodes = reachable_nodes(num_nodes, edge_index, num_layers, ret_mat=False, device=device)
    print(batch_shared_reachable_nodes)
    # total_reachable_nodes += batch_shared_reachable_nodes

    total_num_nodes += num_nodes

    torch.cuda.empty_cache()
    pbar.update(1)

    # if i == 10:
    break
# pbar.close()

# avg_reachable_nodes = (total_reachable_nodes / total_num_nodes).item()
# avg_nodes_on_graph = total_num_nodes / total_graphs
# print(f"Layer {num_layers}")
# print("Average reachable nodes:", avg_reachable_nodes)
# print("Average nodes on a graph:", avg_nodes_on_graph)
# print(f"=================================")

# return avg_reachable_nodes, avg_nodes_on_graph# def get_shared_reachable_nodes(total_graphs, dataloader, num_layers, device=get_device()):

total_graphs = len(train_dataset)
dataloader = train_loader
num_layers = 2
device = get_device()


predict_epochs = math.ceil(total_graphs / args["batch_size"])
pbar = tqdm(total=predict_epochs)
pbar.set_description("Progress")

total_num_nodes = 0
total_reachable_nodes = 0

for i, data in enumerate(train_dataset):
    data = train_dataset[1]
    num_nodes = data.num_nodes
    edge_index = data.edge_index.to(device)

    batch_shared_reachable_nodes = reachable_nodes(num_nodes, edge_index, num_layers, ret_mat=False, device=device)
    print(batch_shared_reachable_nodes)
    # total_reachable_nodes += batch_shared_reachable_nodes

    total_num_nodes += num_nodes

    torch.cuda.empty_cache()
    pbar.update(1)

    # if i == 10:
    break
# pbar.close()

# avg_reachable_nodes = (total_reachable_nodes / total_num_nodes).item()
# avg_nodes_on_graph = total_num_nodes / total_graphs
# print(f"Layer {num_layers}")
# print("Average reachable nodes:", avg_reachable_nodes)
# print("Average nodes on a graph:", avg_nodes_on_graph)
# print(f"=================================")

# return avg_reachable_nodes, avg_nodes_on_graph

Progress:   1%|          | 1/126 [00:06<14:18,  6.87s/it]


tensor([4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')


Progress:   1%|          | 1/126 [00:00<00:01, 110.63it/s]
Progress:   0%|          | 0/126 [00:00<?, ?it/s]

tensor([4, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')


In [25]:
reachability = torch.tensor(np.array([[1, 1, 0, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]], dtype=bool))

# 使用 np.all 检查每个节点在所有节点之间是否都可达
overlapping_nodes = torch.all(reachability, axis=0)
print(overlapping_nodes)

tensor([ True,  True, False,  True,  True])
