In [12]:
import torch
import tqdm
import math
import numpy as np

from args import *
from dataset import *
from utils import *
from model import CEALNetwork, GCNNetwork, load_model

import os.path as osp
from train import make_data_loader, train_step, test_evaluations

from module.madgap import *

Prepare evaluation

In [80]:
# Build the three dataset splits and wrap each of them in a data loader.
train_dataset, validation_dataset, test_dataset = make_dataset()
train_loader, val_loader, test_loader = make_data_loader(
    train_dataset,
    validation_dataset,
    test_dataset,
)

# Restore a previously trained CEAL run from the results directory.
checkpoint_dir = osp.join(args["result_path"], "CEAL/1717141328680483")
model, model_data = load_model(checkpoint_dir)
# Every cell below only runs inference, so put the model in eval mode once here.
model.eval()

print("ok")

ok


Find compounds with high prediction error

In [79]:
# Collect every test-set compound whose absolute prediction error, measured
# after undoing the min-max scaling, is larger than `threshold`.
threshold = 0.5
high_pred_compounds = []

# Look the scaling bounds up once: they only depend on `args`, not on the batch.
# They are deliberately NOT named `min` / `max`, because rebinding those
# builtins at notebook scope breaks every later cell that calls min() or max().
scale_min, scale_max = get_data_scale(args)

with torch.no_grad():
    for batch_data in test_loader:
        # NOTE(review): relies on `.to()` moving the batch in place, exactly as
        # the original cell did — confirm if the batch type ever changes.
        batch_data.to(get_device())
        out = model(batch_data, node_embedding=False)

        # reverse data scale
        res_out = reverse_min_max_scalar_1d(out, scale_min, scale_max)
        res_y = reverse_min_max_scalar_1d(batch_data.y, scale_min, scale_max)

        # get high prediction error compounds
        error = (res_out.squeeze() - res_y).abs()
        index = torch.where(error > threshold)[0]
        high_pred_compounds.extend(
            {'mid': data.mid, 'idx': data.idx} for data in batch_data[index]
        )

        torch.cuda.empty_cache()

# `idx` is stored as a string, so sort numerically rather than lexically.
high_pred_compounds.sort(key=lambda x: int(x['idx']))
for d in high_pred_compounds:
    print(d)
        

{'mid': 'mp-1096329', 'idx': '124'}
{'mid': 'mp-1215149', 'idx': '145'}
{'mid': 'mp-34418', 'idx': '253'}
{'mid': 'mp-1345658', 'idx': '318'}
{'mid': 'mp-1183035', 'idx': '326'}
{'mid': 'mp-1096119', 'idx': '396'}
{'mid': 'mp-1096521', 'idx': '417'}
{'mid': 'mp-1096644', 'idx': '454'}
{'mid': 'mp-1215160', 'idx': '562'}
{'mid': 'mp-1097119', 'idx': '708'}
{'mid': 'mp-1079481', 'idx': '777'}
{'mid': 'mp-569025', 'idx': '1294'}
{'mid': 'mp-1183204', 'idx': '1352'}
{'mid': 'mp-1207049', 'idx': '1362'}
{'mid': 'mp-1214876', 'idx': '1408'}
{'mid': 'mp-1096502', 'idx': '1417'}
{'mid': 'mp-1097309', 'idx': '1424'}
{'mid': 'mp-1183274', 'idx': '1453'}
{'mid': 'mp-1064865', 'idx': '1546'}
{'mid': 'mp-866292', 'idx': '1556'}
{'mid': 'mp-1093785', 'idx': '1560'}
{'mid': 'mp-1096475', 'idx': '1561'}
{'mid': 'mp-1013708', 'idx': '1801'}
{'mid': 'mp-1228028', 'idx': '1820'}
{'mid': 'mp-1095925', 'idx': '2117'}
{'mid': 'mp-23861', 'idx': '2208'}
{'mid': 'mp-1097489', 'idx': '2418'}
{'mid': 'mp-109356

Calculate the MAD (Mean Average Distance) of the node embeddings

In [81]:
# Average the MAD value of the node embeddings over all training batches.
# For each batch the mask handed to `mad_value` is the dense adjacency matrix
# built from `edge_index` (1 where an edge exists, 0 elsewhere).
mad_total = 0.0
num_batches = 0  # explicit counter instead of relying on a leaked loop index
with torch.no_grad():
    for batch_data in train_loader:
        batch_data.to(get_device())
        node_embeddings = model(batch_data, node_embedding=True)

        print(node_embeddings.shape)

        in_arr = node_embeddings.cpu().detach().numpy()

        num_nodes = batch_data.x.shape[0]
        # The adjacency matrix lives on the CPU (it is converted to numpy just
        # below), so the indices must be on the CPU as well; the batch may
        # have been moved to a GPU by the `.to()` call above.
        edge_index = batch_data.edge_index.cpu()
        adj = torch.zeros((num_nodes, num_nodes))
        adj[edge_index[0], edge_index[1]] = 1
        mask_arr = adj.numpy()

        mad_single = mad_value(in_arr, mask_arr)
        mad_total += mad_single
        num_batches += 1
        torch.cuda.empty_cache()

# Unweighted mean over batches; NaN when the loader produced no batch at all.
mad_mean = mad_total / num_batches if num_batches else float("nan")
print("MAD:", mad_mean)

torch.Size([27398, 75])
torch.Size([27710, 75])
torch.Size([26402, 75])
torch.Size([27019, 75])
torch.Size([28867, 75])
torch.Size([27756, 75])
torch.Size([27194, 75])
torch.Size([27445, 75])
torch.Size([27830, 75])
torch.Size([27525, 75])
torch.Size([28118, 75])
torch.Size([30471, 75])
torch.Size([28136, 75])
torch.Size([26250, 75])
torch.Size([28366, 75])
torch.Size([28582, 75])
torch.Size([28020, 75])
torch.Size([28507, 75])
torch.Size([29366, 75])
torch.Size([27411, 75])
torch.Size([28173, 75])
torch.Size([29171, 75])
torch.Size([28525, 75])
torch.Size([28564, 75])
torch.Size([29747, 75])
torch.Size([3086, 75])
MAD: 0.11007307692307691


Calculate the reachable nodes

In [3]:
def reachable_nodes(num_nodes, edge_index, num_layers=1, ret_mat=False, device=get_device()):
    """Compute which nodes can reach which within ``num_layers`` hops.

    A reachability matrix ``R`` of shape ``(num_nodes, num_nodes)`` is built in
    which ``R[i, j] == 1`` when node ``j`` can be reached from node ``i`` by
    following at most ``num_layers`` edges in the direction
    ``edge_index[0] -> edge_index[1]``. Every node reaches itself.

    Args:
        num_nodes: Number of nodes in the (possibly batched) graph.
        edge_index: Indexable pair ``(sources, targets)`` of node indices.
        num_layers: Number of hops. Values below 2 yield the 1-hop result.
        ret_mat: When exactly ``False`` the column sums of ``R`` are returned;
            any other value returns ``R`` itself.
        device: Device on which the matrices are allocated.

    Returns:
        An int64 tensor: either the column sums of ``R`` (length ``num_nodes``)
        or the full matrix ``R``.
    """
    # 1-hop reachability: self-loops on the diagonal plus the listed edges.
    adj = torch.eye(num_nodes, dtype=torch.long, device=device)
    adj[edge_index[0], edge_index[1]] = 1

    reach = adj
    if num_layers > 1:
        # Extending reachability by one hop is a boolean matrix product:
        # j is reachable from i when some k has reach[i, k] and adj[k, j].
        # The product is done in float32 because integer matmul is not
        # available on every backend; only "> 0" is read, so it stays exact.
        one_hop = adj.to(torch.float32)
        frontier = one_hop
        for _ in range(num_layers - 1):
            frontier = ((frontier @ one_hop) > 0).to(torch.float32)
        reach = frontier.to(torch.long)

    return torch.sum(reach, dim=0) if ret_mat is False else reach


def get_reachable_nodes(total_graphs, dataloader, num_layers, device=get_device()):
    """Report the average number of reachable nodes over a whole data loader.

    For every batch, ``reachable_nodes`` is evaluated on the batch's
    ``edge_index`` and its per-node counts are summed. The grand total is then
    divided by the total number of nodes seen.

    Args:
        total_graphs: Number of graphs served by ``dataloader``; used for the
            progress-bar length and for the nodes-per-graph average.
        dataloader: Iterable of batches exposing ``num_nodes`` and
            ``edge_index``.
        num_layers: Hop count forwarded to ``reachable_nodes``.
        device: Device on which the reachability computation runs.

    Returns:
        Tuple ``(avg_reachable_nodes, avg_nodes_on_graph)``.
    """
    # Import the progress-bar class explicitly: the notebook's own
    # `import tqdm` binds the *module*, which is not callable, so the bare
    # name only works when a star import happens to rebind it.
    from tqdm import tqdm

    predict_epochs = math.ceil(total_graphs / args["batch_size"])

    total_num_nodes = 0
    total_reachable_nodes = 0

    # The context manager closes the bar even if a batch raises.
    with tqdm(total=predict_epochs) as pbar:
        pbar.set_description("Progress")

        for batch_data in dataloader:
            num_nodes = batch_data.num_nodes
            edge_index = batch_data.edge_index.to(device)

            per_node_counts = reachable_nodes(num_nodes, edge_index, num_layers, ret_mat=False, device=device)
            total_reachable_nodes += torch.sum(per_node_counts, dim=0)
            total_num_nodes += num_nodes

            torch.cuda.empty_cache()
            pbar.update(1)

    avg_reachable_nodes = (total_reachable_nodes / total_num_nodes).item()
    avg_nodes_on_graph = total_num_nodes / total_graphs
    print(f"Layer {num_layers}")
    print("Average reachable nodes:", avg_reachable_nodes)
    print("Average nodes on a graph:", avg_nodes_on_graph)
    print(f"=================================")

    return avg_reachable_nodes, avg_nodes_on_graph

In [6]:
# Three-hop reachability statistics over the training split.
get_reachable_nodes(
    total_graphs=len(train_dataset),
    dataloader=train_loader,
    num_layers=3,
)

Progress: 100%|██████████| 126/126 [04:31<00:00,  2.15s/it]

Layer 3
Average reachable nodes: 9.347262382507324
Average nodes on a graph: 28.122070779531324





(9.347262382507324, 28.122070779531324)

Calculate the shared reachable nodes

In [None]:
# Prototype for a future `get_shared_reachable_nodes(total_graphs, dataloader,
# num_layers, device)` helper. For now it only inspects one training graph and
# prints its per-node reachable-node counts.
#
# TODO: once the per-graph result looks right, accumulate
# `batch_shared_reachable_nodes` into `total_reachable_nodes` over all graphs
# and report the same averages as `get_reachable_nodes`.

total_graphs = len(train_dataset)
dataloader = train_loader
num_layers = 2
device = get_device()

predict_epochs = math.ceil(total_graphs / args["batch_size"])
pbar = tqdm(total=predict_epochs)
pbar.set_description("Progress")

total_num_nodes = 0
total_reachable_nodes = 0

# Index of the single graph inspected while prototyping. It is stated
# explicitly instead of overwriting a loop variable and breaking out of the
# loop after the first pass.
debug_graph_index = 1

data = train_dataset[debug_graph_index]
num_nodes = data.num_nodes
edge_index = data.edge_index.to(device)

batch_shared_reachable_nodes = reachable_nodes(num_nodes, edge_index, num_layers, ret_mat=False, device=device)
print(batch_shared_reachable_nodes)

total_num_nodes += num_nodes

torch.cuda.empty_cache()
pbar.update(1)
# Close the bar so it does not linger half-finished in the cell output.
pbar.close()

In [None]:
# Toy reachability matrix: one row per source node, one column per target node.
reachability = torch.tensor(
    [
        [1, 1, 0, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
    ],
    dtype=torch.bool,
)

# A column is True only when every row can reach that node, i.e. the node is
# shared by all of the reachability sets.
overlapping_nodes = torch.all(reachability, dim=0)
print(overlapping_nodes)