In [1]:
import torch
import tqdm
import math
import numpy as np

from args import *
from dataset import *
from utils import *
from model import CEALNetwork, GCNNetwork, load_model

import os.path as osp
from train import make_data_loader, train_step, test_evaluations

from module.madgap import *

  from .autonotebook import tqdm as notebook_tqdm


Prepare evaluations

In [2]:
# Build the three dataset splits, then wrap each split in its own loader.
train_dataset, validation_dataset, test_dataset = make_dataset()
train_loader, val_loader, test_loader = make_data_loader(
    train_dataset,
    validation_dataset,
    test_dataset,
)

# Restore the trained CEAL checkpoint and put it in inference mode.
checkpoint_dir = osp.join(args["result_path"], "CEAL/1717141328680483")
model, model_data = load_model(checkpoint_dir)
model.eval()

print("ok")

ok


Find features of compounds with high prediction error

In [None]:
# Load the complete dataset and display its first entry as a quick sanity check.
all_dataset = MPDataset(args)
first_entry = all_dataset[0]
print(first_entry)

In [None]:
from scipy.spatial.distance import cdist
def data_stat(data_list):
    """Return the mean node count and mean edge count of a collection of graphs.

    Args:
        data_list: Iterable of graph objects. Each item must expose ``x`` (its
            first dimension is counted as the number of nodes) and
            ``edge_index`` (its last dimension is counted as the number of edges).

    Returns:
        Tuple ``(avg_num_nodes, avg_num_edges)``. An empty input yields
        ``(0.0, 0.0)`` instead of failing on an undefined loop counter.
    """
    total_graphs = 0
    total_nodes = 0
    total_edges = 0
    for graph in data_list:
        total_edges += graph.edge_index.shape[-1]
        total_nodes += graph.x.shape[0]
        total_graphs += 1

    # Counting inside the loop (rather than reusing the loop index afterwards)
    # keeps the function well-defined for empty inputs and plain iterators.
    if total_graphs == 0:
        return 0.0, 0.0
    return total_nodes / total_graphs, total_edges / total_graphs

def ase_data_stat(ase_data_list):
    """Return the mean cell volume and mean radius of a collection of structures.

    The radius of one structure is half of the largest pairwise distance
    between the positions returned by ``get_positions()``.

    Args:
        ase_data_list: Iterable of objects providing ``get_volume()`` and
            ``get_positions()`` (the latter must be a 2-D array accepted by
            ``scipy.spatial.distance.cdist``).

    Returns:
        Tuple ``(avg_volume, avg_radius)``. An empty input yields
        ``(0.0, 0.0)`` instead of failing on an undefined loop counter.
    """
    total = 0
    volume_sum = 0.0
    radius_sum = 0.0

    for structure in ase_data_list:
        volume_sum += structure.get_volume()

        positions = structure.get_positions()
        # The global maximum of the distance matrix is the structure's diameter.
        diameter = cdist(positions, positions).max()
        radius_sum += diameter / 2.0
        total += 1

    if total == 0:
        return 0.0, 0.0
    return volume_sum / total, radius_sum / total

def _read_poscar_structures(graphs):
    """Load the structure stored as ``CONFIG_<idx>.poscar`` for every graph in ``graphs``."""
    return [
        ase_read(
            osp.join(args["dataset_raw_dir"], f"CONFIG_{int(graph.idx)}.poscar"),
            format="vasp",
        )
        for graph in graphs
    ]


# Absolute error (after undoing the target scaling) above which a test compound
# is recorded as a high-prediction-error compound.
threshold = 0.5
high_pred_compounds = []

# Fetched once, outside the batch loop, and bound to names that do not shadow
# the builtins ``min``/``max`` for the rest of the notebook.
# NOTE(review): assumes get_data_scale(args) returns the same bounds on every call — confirm.
scale_min, scale_max = get_data_scale(args)

with torch.no_grad():
    for batch_data in test_loader:
        batch_data = batch_data.to(get_device())
        out = model(batch_data, node_embedding=False)

        # reverse data scale
        res_out = reverse_min_max_scalar_1d(out, scale_min, scale_max)
        res_y = reverse_min_max_scalar_1d(batch_data.y, scale_min, scale_max)

        # get high prediction error compounds
        error = (res_out.squeeze() - res_y).abs()
        index = torch.where(error > threshold)[0]
        high_pred_compounds.extend(
            {'mid': data.mid, 'idx': data.idx} for data in batch_data[index]
        )

        torch.cuda.empty_cache()

high_pred_compounds.sort(key=lambda x: int(x['idx']))

# Statistics over the whole dataset.
all_dataset = MPDataset(args)
dataset_stat = data_stat(all_dataset)
print(f"dataset.             avg_num_nodes:{round(dataset_stat[0],4)}, avg_num_edges:{round(dataset_stat[1],4)}")
# read all ase file from the whole dataset
whole_pred_ase = _read_poscar_structures(all_dataset)
dataset_ase_stat = ase_data_stat(whole_pred_ase)
print(f"dataset.               avg_volume:{round(dataset_ase_stat[0],4)}, avg_radius:{round(dataset_ase_stat[1],4)}")

# Statistics over the high-error subset only.
# NOTE(review): ``idx - 1`` assumes ``idx`` is the 1-based position of the graph
# inside ``all_dataset`` — confirm against the dataset implementation.
high_pred_data = [all_dataset[int(d['idx'])-1] for d in high_pred_compounds]
high_stat = data_stat(high_pred_data)
print(f"high_pred_compounds. avg_num_nodes:{round(high_stat[0],4)}, avg_num_edges:{round(high_stat[1],4)}")
# read all ase file from high pred error compounds
high_pred_ase = _read_poscar_structures(high_pred_data)
high_ase_stat = ase_data_stat(high_pred_ase)
print(f"high_pred_compounds.  avg_volume:{round(high_ase_stat[0],4)}, avg_radius:{round(high_ase_stat[1],4)}")

Calculate the MAD

In [3]:
# Average the per-batch MAD value of the model's node embeddings over the
# training loader. ``mad_value`` receives the embeddings and a dense 0/1 mask
# with a 1 at [edge_index[0], edge_index[1]] for every edge in the batch.
mad_total = 0.0
num_batches = 0
with torch.no_grad():
    for i, batch_data in enumerate(train_loader):
        batch_data = batch_data.to(get_device())
        node_embeddings = model(batch_data, node_embedding=True)

        print(node_embeddings.shape)

        in_arr = node_embeddings.cpu().detach().numpy()

        num_nodes = batch_data.x.shape[0]
        # The mask is built on the CPU, so the index tensors must be on the CPU
        # too; indexing a CPU tensor with accelerator-resident indices fails.
        edge_index = batch_data.edge_index.cpu()
        adj = torch.zeros((num_nodes, num_nodes))
        adj[edge_index[0], edge_index[1]] = 1
        mask_arr = adj.numpy()

        mad_single = mad_value(in_arr, mask_arr)
        mad_total += mad_single
        num_batches += 1
        torch.cuda.empty_cache()
        print(i, "single MAD:", mad_single)

# Divide by an explicit batch counter: the loop variable ``i`` is undefined for
# an empty loader, or stale from an earlier cell, which would silently give a
# wrong average.
print("MAD:", mad_total / num_batches if num_batches else float("nan"))

torch.Size([2388, 75])
0 single MAD: 0.1265
torch.Size([3061, 75])
1 single MAD: 0.1302
torch.Size([3024, 75])
2 single MAD: 0.1089
torch.Size([2922, 75])
3 single MAD: 0.1347
torch.Size([2842, 75])
4 single MAD: 0.0861
torch.Size([2468, 75])
5 single MAD: 0.1028
torch.Size([2836, 75])
6 single MAD: 0.0958
torch.Size([2464, 75])
7 single MAD: 0.0764
torch.Size([2388, 75])
8 single MAD: 0.1098
torch.Size([3005, 75])
9 single MAD: 0.1032
torch.Size([2444, 75])
10 single MAD: 0.1398
torch.Size([2690, 75])
11 single MAD: 0.0796
torch.Size([2746, 75])
12 single MAD: 0.0943
torch.Size([2844, 75])
13 single MAD: 0.1314
torch.Size([2787, 75])
14 single MAD: 0.1272
torch.Size([3103, 75])
15 single MAD: 0.0796
torch.Size([2879, 75])
16 single MAD: 0.101
torch.Size([2800, 75])
17 single MAD: 0.1141
torch.Size([2619, 75])
18 single MAD: 0.0713
torch.Size([2798, 75])
19 single MAD: 0.1365
torch.Size([2480, 75])
20 single MAD: 0.1022
torch.Size([2656, 75])
21 single MAD: 0.0795
torch.Size([2407, 75]

Calculate the reachable nodes

In [None]:
# Display the processed-dataset directory constant.
# NOTE(review): presumably provided by one of the star imports at the top — confirm.
print(DATASET_PROCESSED_DIR)

def reachable_nodes(num_nodes, edge_index, num_layers=1, ret_mat=False, device=None):
    """Compute which nodes are reachable within ``num_layers`` hops.

    A 0/1 matrix ``adj`` is built with ones on the diagonal (self-loops) and at
    ``[edge_index[0], edge_index[1]]`` for every edge. After ``num_layers``
    hops, entry ``[i, k]`` of the result is 1 when ``k`` can be reached from
    ``i`` by following at most ``num_layers`` entries of ``adj``.

    Args:
        num_nodes: Number of nodes; the matrices are ``num_nodes x num_nodes``.
        edge_index: Integer tensor whose rows 0 and 1 hold the two endpoints of
            each edge.
        num_layers: Number of hops. Values below 2 return the one-hop matrix.
        ret_mat: When exactly ``False`` the column sums of the reachability
            matrix are returned; any other value returns the matrix itself.
        device: Device to compute on. ``None`` (default) resolves to
            ``get_device()`` at call time rather than once at definition time.

    Returns:
        An int64 tensor: shape ``(num_nodes,)`` when ``ret_mat`` is ``False``,
        otherwise shape ``(num_nodes, num_nodes)``.
    """
    if device is None:
        device = get_device()

    adj = torch.eye(num_nodes, dtype=torch.int64, device=device)
    edge_index = edge_index.to(device)
    adj[edge_index[0], edge_index[1]] = 1

    res_adj = adj
    if num_layers > 1:
        # One extra hop is a boolean matrix product: row i of (reach @ hop) is
        # positive at k exactly when some j reachable from i has adj[j, k] set.
        # The product is done in float32 because integer matmul is not available
        # on every backend; entries are 0/1 and row sums stay far below 2**24,
        # so the ``> 0`` test is exact.
        hop = adj.to(torch.float32)
        reach = hop
        for _ in range(num_layers - 1):
            reach = (reach @ hop > 0).to(torch.float32)
        res_adj = reach.to(torch.int64)

    return torch.sum(res_adj, dim=0) if ret_mat is False else res_adj

def get_reachable_nodes(total_graphs, dataloader, num_layers, device=None):
    """Report the average number of reachable nodes per node over a data loader.

    For every batch, ``reachable_nodes`` is evaluated on the batch's
    ``edge_index`` and its per-node counts are summed; the grand total is then
    divided by the total number of nodes seen.

    Args:
        total_graphs: Number of graphs served by ``dataloader``; used for the
            progress-bar length and for the nodes-per-graph average.
        dataloader: Iterable of batches exposing ``num_nodes`` and ``edge_index``.
        num_layers: Hop count forwarded to ``reachable_nodes``.
        device: Device for the computation. ``None`` (default) resolves to
            ``get_device()`` at call time.

    Returns:
        Tuple ``(avg_reachable_nodes, avg_nodes_on_graph)``. Either value is
        ``0.0`` when its denominator is zero (no nodes / no graphs).
    """
    if device is None:
        device = get_device()

    predict_epochs = math.ceil(total_graphs / args["batch_size"])

    total_num_nodes = 0
    total_reachable_nodes = 0

    # NOTE(review): ``tqdm`` must be the progress-bar class here, while the
    # notebook's own ``import tqdm`` binds the module; presumably one of the star
    # imports re-exports the class — confirm.
    # The context manager closes the bar even if a batch raises.
    with tqdm(total=predict_epochs) as pbar:
        pbar.set_description("Progress")

        for batch_data in dataloader:
            num_nodes = batch_data.num_nodes
            edge_index = batch_data.edge_index.to(device)

            per_node_counts = reachable_nodes(num_nodes, edge_index, num_layers, ret_mat=False, device=device)
            total_reachable_nodes += torch.sum(per_node_counts, dim=0)

            total_num_nodes += num_nodes

            torch.cuda.empty_cache()
            pbar.update(1)

    # With at least one node, ``total_reachable_nodes`` is a tensor; otherwise
    # it is still the integer 0 and has no ``.item()``.
    if total_num_nodes > 0:
        avg_reachable_nodes = (total_reachable_nodes / total_num_nodes).item()
    else:
        avg_reachable_nodes = 0.0
    avg_nodes_on_graph = total_num_nodes / total_graphs if total_graphs else 0.0

    print(f"Layer {num_layers}")
    print("Average reachable nodes:", avg_reachable_nodes)
    print("Average nodes on a graph:", avg_nodes_on_graph)
    print(f"=================================")

    return avg_reachable_nodes, avg_nodes_on_graph

./dataset.max_cutoff==3.0/processed


In [5]:
# Average reachable-node count over the training set with num_layers=3.
# NOTE(review): the saved output below reports "Layer 2", so it appears to come
# from an earlier run with a different num_layers — re-run this cell to refresh it.
get_reachable_nodes(len(train_dataset), train_loader, 3)

Progress: 100%|██████████| 51/51 [04:30<00:00,  5.31s/it]

Layer 2
Average reachable nodes: 16.329570770263672
Average nodes on a graph: 28.281324725011956





(16.329570770263672, 28.281324725011956)

Calculate the shared reachable nodes

In [None]:
# Work-in-progress body of the planned
#   get_shared_reachable_nodes(total_graphs, dataloader, num_layers, device=get_device())
# The cell body used to be pasted twice, which ran everything twice and left two
# progress bars open; it is kept once here and the bar is closed on exit.

total_graphs = len(train_dataset)
dataloader = train_loader
num_layers = 2
device = get_device()


predict_epochs = math.ceil(total_graphs / args["batch_size"])

total_num_nodes = 0
total_reachable_nodes = 0

with tqdm(total=predict_epochs) as pbar:
    pbar.set_description("Progress")

    for i, data in enumerate(train_dataset):
        # Debugging override: always inspect the graph at position 1,
        # regardless of which graph the loop yielded.
        data = train_dataset[1]
        num_nodes = data.num_nodes
        edge_index = data.edge_index.to(device)

        batch_shared_reachable_nodes = reachable_nodes(num_nodes, edge_index, num_layers, ret_mat=False, device=device)
        print(batch_shared_reachable_nodes)
        # total_reachable_nodes += batch_shared_reachable_nodes

        total_num_nodes += num_nodes

        torch.cuda.empty_cache()
        pbar.update(1)

        # Only the first iteration is executed while this cell is a prototype.
        # if i == 10:
        break

# TODO: aggregation and reporting still to be enabled:
# avg_reachable_nodes = (total_reachable_nodes / total_num_nodes).item()
# avg_nodes_on_graph = total_num_nodes / total_graphs
# print(f"Layer {num_layers}")
# print("Average reachable nodes:", avg_reachable_nodes)
# print("Average nodes on a graph:", avg_nodes_on_graph)
# print(f"=================================")

# return avg_reachable_nodes, avg_nodes_on_graph

In [None]:
# Toy reachability table: one row per source, one boolean column per node.
reachability = torch.tensor(
    np.array(
        [
            [1, 1, 0, 1, 1],
            [1, 1, 1, 1, 1],
            [1, 1, 1, 1, 1],
            [1, 1, 1, 1, 1],
        ],
        dtype=bool,
    )
)

# Reduce over the rows to check, for each node, whether it is marked
# reachable in every row.
overlapping_nodes = reachability.all(dim=0)
print(overlapping_nodes)