In [2]:
import torch
import tqdm
import math
import numpy as np

from args import *
from dataset import *
from utils import *
from model import CEALNetwork, GCNNetwork, load_model

import os.path as osp
from train import make_data_loader, train_step, test_evaluations

from module.madgap import *

Prepare datasets and data loaders for evaluation

In [3]:
# Build the train/validation/test splits and wrap each one in a data loader.
batch_size = 100
train_dataset, validation_dataset, test_dataset = make_dataset()
train_loader, val_loader, test_loader = make_data_loader(
    train_dataset,
    validation_dataset,
    test_dataset,
    batch_size=batch_size,
)

print("ok")

ok


Find features of compounds with high prediction error

In [3]:
# from scipy.spatial.distance import cdist
# def data_stat(data_list):
#     avg_num_edges = 0
#     avg_num_nodes = 0
#     for i,d in enumerate(data_list):
#         avg_num_edges+=d.edge_index.shape[-1]
#         avg_num_nodes+=d.x.shape[0]
#     total_graphs = i+1
#     return avg_num_nodes/total_graphs,avg_num_edges/total_graphs

# def ase_data_stat(ase_data_list):
#     avg_volume = 0.0
#     avg_radius = 0.0
    
#     for i,d in enumerate(ase_data_list):
#         avg_volume+=d.get_volume()
        
#         positions = d.get_positions()
#         dist_matrix = cdist(positions, positions)
#         max_distances = dist_matrix.max(axis=1)
#         radius = max_distances.max() / 2.0
#         avg_radius+=radius
        
#     total = i+1
#     return avg_volume/total,avg_radius/total

# with torch.no_grad():
#     high_pred_compounds = []
#     for i, batch_data in enumerate(test_loader):
#         batch_data.to(get_device())
#         out = model(batch_data, node_embedding=False)
    
#         # reverse data scale
#         min, max = get_data_scale(args)
#         res_out = reverse_min_max_scalar_1d(out, min, max)
#         res_y =  reverse_min_max_scalar_1d( batch_data.y, min, max)
        
#         # get high prediction error compounds
#         error =  (res_out.squeeze() - res_y).abs()
#         index = torch.where(error > 1.0)[0]
#         compounds = [{'mid':data.mid,'idx':data.idx} for _,data in enumerate(batch_data[index])]
#         high_pred_compounds.extend(compounds)
        
#         # high_pred_compounds.append()
        
#         torch.cuda.empty_cache()
        
#     high_pred_compounds.sort(key=lambda x: int(x['idx']))
    
    
#     all_dataset = MPDataset(args)
#     dataset_stat = data_stat(all_dataset)
#     print(f"dataset.             avg_num_nodes:{round(dataset_stat[0],4)}, avg_num_edges:{round(dataset_stat[1],4)}")
#     # read all ase file from the whole dataset
#     whole_pred_ase = []
#     for i,d in enumerate(all_dataset):
#         path = osp.join(args["dataset_raw_dir"],f"CONFIG_{int(d.idx)}.poscar")
#         compound = ase_read(path, format="vasp")
#         whole_pred_ase.append(compound)
#     dataset_ase_stat = ase_data_stat(whole_pred_ase)
#     print(f"dataset.               avg_volume:{round(dataset_ase_stat[0],4)}, avg_radius:{round(dataset_ase_stat[1],4)}")

#     high_pred_data = [all_dataset[int(d['idx'])-1] for i,d in enumerate(high_pred_compounds)]
#     high_stat = data_stat(high_pred_data)
#     print(f"high_pred_compounds. avg_num_nodes:{round(high_stat[0],4)}, avg_num_edges:{round(high_stat[1],4)}")
#     # read all ase file from high pred error compounds
#     high_pred_ase = []
#     for i,d in enumerate(high_pred_data):
#         path = osp.join(args["dataset_raw_dir"],f"CONFIG_{int(d.idx)}.poscar")
#         compound = ase_read(path, format="vasp")
#         high_pred_ase.append(compound)
#     high_ase_stat = ase_data_stat(high_pred_ase)
#     print(f"high_pred_compounds.  avg_volume:{round(high_ase_stat[0],4)}, avg_radius:{round(high_ase_stat[1],4)}")

In [4]:
from scipy.spatial.distance import cdist


def data_stat(data_list):
    """Average node and edge counts over a collection of graph samples.

    Each element must expose ``x`` (its first dimension is counted as nodes)
    and ``edge_index`` (its last dimension is counted as edges).

    Args:
        data_list: Any iterable of graph samples (dataset, list, generator, ...).

    Returns:
        Tuple ``(avg_num_nodes, avg_num_edges)``. Both are ``nan`` when
        ``data_list`` is empty, e.g. when no compound passed an error threshold.
    """
    total_nodes = 0
    total_edges = 0
    total_graphs = 0
    for d in data_list:
        total_edges += d.edge_index.shape[-1]
        total_nodes += d.x.shape[0]
        total_graphs += 1
    if total_graphs == 0:
        # Nothing to average; nan keeps the caller's print/round calls working.
        return math.nan, math.nan
    return total_nodes / total_graphs, total_edges / total_graphs


def ase_data_stat(ase_data_list):
    """Average cell volume and "radius" over a collection of ASE-like structures.

    Each element must provide ``get_volume()`` and ``get_positions()`` (as ASE
    ``Atoms`` objects do). A structure's radius is half of the largest
    pairwise Euclidean distance between its positions, as returned by
    ``get_positions()``. No periodic images are considered.

    Args:
        ase_data_list: Any iterable of structures.

    Returns:
        Tuple ``(avg_volume, avg_radius)``. Both are ``nan`` when the iterable
        is empty.
    """
    total_volume = 0.0
    total_radius = 0.0
    total = 0
    for d in ase_data_list:
        total_volume += d.get_volume()
        positions = d.get_positions()
        # With fewer than two atoms there is no pairwise distance, so the radius
        # is 0. This also avoids calling .max() on an empty distance matrix.
        if len(positions) > 1:
            total_radius += cdist(positions, positions).max() / 2.0
        total += 1
    if total == 0:
        return math.nan, math.nan
    return total_volume / total, total_radius / total


def _read_ase_structures(data_list):
    """Read each sample's raw ``CONFIG_<idx>.poscar`` file with ``ase_read``."""
    structures = []
    for d in data_list:
        path = osp.join(args["dataset_raw_dir"], f"CONFIG_{int(d.idx)}.poscar")
        structures.append(ase_read(path, format="vasp"))
    return structures


def get_high_pred_error_stats(dataloader, model_path, threshold=0.5):
    """Compare size statistics of high-error compounds with the whole dataset.

    Runs the model loaded from ``model_path`` over ``dataloader``, converts
    predictions and targets back to the original scale, and collects every
    sample whose absolute error exceeds ``threshold``. It then prints the
    average node/edge counts and the average cell volume/radius, first for the
    full ``MPDataset`` and then for the collected subset.

    Args:
        dataloader: Loader yielding batched graphs that carry ``y``, ``mid`` and ``idx``.
        model_path: Path accepted by ``load_model``.
        threshold: Absolute error, in un-scaled target units, above which a
            compound counts as a high-error prediction.

    Returns:
        None. Results are printed.
    """
    device = get_device()
    model, _ = load_model(model_path)
    model.eval()

    # The scale depends only on ``args``, so fetch it once rather than per batch.
    # y_min/y_max avoid shadowing the min/max builtins.
    y_min, y_max = get_data_scale(args)

    high_pred_compounds = []
    with torch.no_grad():
        for batch_data in dataloader:
            batch_data = batch_data.to(device)
            out = model(batch_data, node_embedding=False)

            # reverse data scale
            res_out = reverse_min_max_scalar_1d(out, y_min, y_max)
            res_y = reverse_min_max_scalar_1d(batch_data.y, y_min, y_max)

            # Flatten both sides so the error is one value per graph. This also
            # stops a single-graph batch from collapsing to a 0-d tensor.
            error = (res_out.reshape(-1) - res_y.reshape(-1)).abs()
            index = torch.where(error > threshold)[0]
            high_pred_compounds.extend({"mid": data.mid, "idx": data.idx} for data in batch_data[index])

            torch.cuda.empty_cache()

    high_pred_compounds.sort(key=lambda x: int(x["idx"]))

    all_dataset = MPDataset(args)
    dataset_stat = data_stat(all_dataset)
    print(f"dataset.             avg_num_nodes:{round(dataset_stat[0],4)}, avg_num_edges:{round(dataset_stat[1],4)}")
    dataset_ase_stat = ase_data_stat(_read_ase_structures(all_dataset))
    print(f"dataset.               avg_volume:{round(dataset_ase_stat[0],4)}, avg_radius:{round(dataset_ase_stat[1],4)}")

    if not high_pred_compounds:
        # Nothing to summarise; averaging an empty selection would fail.
        print(f"high_pred_compounds. none with absolute error > {threshold}")
        return

    # NOTE(review): assumes ``idx`` is a 1-based position into MPDataset — confirm.
    high_pred_data = [all_dataset[int(d["idx"]) - 1] for d in high_pred_compounds]
    high_stat = data_stat(high_pred_data)
    print(f"high_pred_compounds. avg_num_nodes:{round(high_stat[0],4)}, avg_num_edges:{round(high_stat[1],4)}")
    high_ase_stat = ase_data_stat(_read_ase_structures(high_pred_data))
    print(f"high_pred_compounds.  avg_volume:{round(high_ase_stat[0],4)}, avg_radius:{round(high_ase_stat[1],4)}")


# Summarise the test-set compounds whose absolute prediction error exceeds 0.5.
high_error_model_path = osp.join(args["result_path"], "CEAL/1717229364248831")
get_high_pred_error_stats(test_loader, high_error_model_path, threshold=0.5)

dataset.             avg_num_nodes:28.3963, avg_num_edges:164.8967
dataset.               avg_volume:710.3466, avg_radius:6.4753
high_pred_compounds. avg_num_nodes:17.8939, avg_num_edges:94.5455
high_pred_compounds.  avg_volume:1409.8142, avg_radius:6.3831


In [5]:
# Graph-size statistics over the whole dataset.
all_dataset = MPDataset(args)
dataset_stat = data_stat(all_dataset)
print(f"dataset.             avg_num_nodes:{round(dataset_stat[0],4)}, avg_num_edges:{round(dataset_stat[1],4)}")

# Geometric statistics, read from each sample's raw POSCAR file.
whole_pred_ase = [
    ase_read(osp.join(args["dataset_raw_dir"], f"CONFIG_{int(d.idx)}.poscar"), format="vasp")
    for d in all_dataset
]
dataset_ase_stat = ase_data_stat(whole_pred_ase)
print(f"dataset.               avg_volume:{round(dataset_ase_stat[0],4)}, avg_radius:{round(dataset_ase_stat[1],4)}")

dataset.             avg_num_nodes:28.3963, avg_num_edges:353.3166
dataset.               avg_volume:710.3466, avg_radius:6.4753


Calculate the MAD (Mean Average Distance) of the node embeddings

In [5]:
def calculate_MAD(dataset, model_path, device=get_device()):
    """Average per-graph MAD of the model's node embeddings over ``dataset``.

    For each graph, the node embeddings from ``model(data, node_embedding=True)``
    are passed to ``mad_value``. Alongside them goes a 0/1 mask built from the
    graph's ``edge_index`` (``mask[src, dst] = 1``). Every graph has equal
    weight in the final mean.

    Args:
        dataset: Sequence of graph samples (must support ``len``).
        model_path: Path accepted by ``load_model``.
        device: Device the samples are moved to. The default is evaluated once,
            when this function is defined.

    Returns:
        Mean MAD over all graphs.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    total_data_size = len(dataset)
    if total_data_size == 0:
        raise ValueError("calculate_MAD needs a non-empty dataset")

    model, _ = load_model(model_path)
    model.eval()

    mad_total = 0.0
    # NOTE(review): `tqdm` must be callable here. The top-level `import tqdm`
    # binds the module, so this relies on a later star import rebinding the name.
    with torch.no_grad(), tqdm(total=total_data_size) as pbar:
        pbar.set_description("Progress")
        for data in dataset:
            data = data.to(device)
            node_embeddings = model(data, node_embedding=True)
            in_arr = node_embeddings.cpu().detach().numpy()

            # Build the mask on the CPU with CPU indices. Indexing a CPU tensor
            # with CUDA indices raises, and the mask is consumed as numpy anyway.
            num_nodes = data.x.shape[0]
            edge_index = data.edge_index.cpu()
            adj = torch.zeros((num_nodes, num_nodes))
            adj[edge_index[0], edge_index[1]] = 1
            mask_arr = adj.numpy()

            mad_total += mad_value(in_arr, mask_arr)
            torch.cuda.empty_cache()
            pbar.update(1)

    return mad_total / total_data_size

# MAD of the CEAL model's node embeddings on the test split.
mad_model_path = osp.join(args["result_path"], "CEAL/1716963786873904")
test_mad = calculate_MAD(test_dataset, mad_model_path)
print("MAD:", test_mad)

Progress: 100%|██████████| 6272/6272 [00:59<00:00, 105.22it/s]

MAD: 0.22039625318877537





Calculate the average number of reachable nodes and shared reachable nodes for different numbers of layers

In [None]:
# Show the processed-dataset directory in use (presumably provided by one of the star imports).
print(f"{DATASET_PROCESSED_DIR}")

# def reachable_nodes(num_nodes, edge_index, num_layers=1, ret_mat=False, device=get_device()):
#     adj = torch.tensor(np.eye(num_nodes), dtype=int).to(device)
#     adj[edge_index[0], edge_index[1]] = 1
#     if num_layers == 1:
#         return torch.sum(adj, dim=0).to(device) if ret_mat is False else adj

#     res_adj = adj

#     for i in range(0, num_layers - 1):
#         columns = []
#         for column in range(0, num_nodes):
#             # get column data
#             data = torch.where(res_adj[column] > 0)[0].to(device)
#             # get all node index data
#             data = torch.where(torch.sum(adj[data], dim=0) > 0)[0].to(device)

#             line = torch.zeros(num_nodes, dtype=int).to(device)
#             line[data] = 1
#             columns.append(line)
#         res_adj = torch.stack(columns).to(device)

#     return torch.sum(res_adj, dim=0).to(device) if ret_mat is False else res_adj


def reachable_nodes_mat(num_nodes, edge_index, num_layers=1, device=None):
    """0/1 matrix of nodes reachable within ``num_layers`` hops.

    Entry ``[i, j]`` is 1.0 when ``j`` can be reached from ``i`` by a walk of at
    most ``num_layers`` edges following ``edge_index[0] -> edge_index[1]``.
    Otherwise it is 0.0. The diagonal is always 0, so a node never counts as
    reachable from itself.

    Args:
        num_nodes: Number of nodes in the graph.
        edge_index: Pair of index rows ``(src, dst)``, e.g. a ``2 x E`` tensor.
        num_layers: Hop count; must be >= 1.
        device: Device for the returned matrix. ``None`` means ``get_device()``,
            resolved at call time.

    Returns:
        ``num_nodes x num_nodes`` float64 tensor on ``device``.

    Raises:
        ValueError: If ``num_layers`` < 1.
    """
    if num_layers < 1:
        raise ValueError(f"num_layers must be >= 1, got {num_layers}")
    if device is None:
        device = get_device()

    # Identity plus edges: each multiplication either stays put or follows one
    # edge, so the k-th power covers every walk of length <= k.
    adj = torch.eye(num_nodes, dtype=torch.float64, device=device)
    adj[edge_index[0], edge_index[1]] = 1

    res_adj = adj.clone()
    for _ in range(num_layers - 1):
        # Only "reachable or not" matters. Clamping back to 0/1 each step keeps
        # the exponentially growing walk counts bounded.
        res_adj = (res_adj @ adj).clamp_(max=1.0)

    return res_adj.fill_diagonal_(0.0)


# def get_avg_reachable_nodes(total_graphs, dataloader, num_layers, device=get_device()):
#     predict_epochs = math.ceil(total_graphs / batch_size)
#     pbar = tqdm(total=predict_epochs)
#     pbar.set_description("Progress")

#     total_num_nodes = 0
#     total_reachable_nodes = 0

#     total_batch_size = 0
#     total_shared_nodes = 0

#     for i, batch_data in enumerate(dataloader):
#         num_nodes = batch_data.num_nodes
#         edge_index = batch_data.edge_index.to(device)

#         # reachable matrix
#         reachable_mat = reachable_nodes_mat(num_nodes, edge_index, num_layers, device=device)
#         reachable_nodes = torch.sum(reachable_mat)
#         total_reachable_nodes += reachable_nodes

#         # shared reachable matrix
#         shared_mat = reachable_mat @ reachable_mat
#         shared_mat = shared_mat.fill_diagonal_(0)
#         shared_nodes = torch.sum(shared_mat)
#         # solve avg. shared reachable nodes per node for this batch
#         shared_nodes = shared_nodes / (num_nodes * (num_nodes + 1))
#         total_shared_nodes += shared_nodes

#         total_batch_size += 1
#         total_num_nodes += num_nodes

#         torch.cuda.empty_cache()
#         pbar.update(1)
#     pbar.close()

#     avg_reachable_nodes = (total_reachable_nodes / total_num_nodes).item()
#     avg_shared_reachable_nodes = (total_shared_nodes / total_batch_size).item()
#     avg_nodes_on_graph = total_num_nodes / total_graphs
#     print(f"Layer {num_layers}")
#     print("Average reachable nodes:", round(avg_reachable_nodes, 4))
#     print("Average shared reachable nodes:", round(avg_shared_reachable_nodes, 4))
#     print("Average nodes on a graph:", round(avg_nodes_on_graph, 4))
#     print(f"=================================")

#     return avg_reachable_nodes, avg_nodes_on_graph

./dataset.max_cutoff==5.0/processed


In [None]:
def get_avg_reachable_nodes(dataset, num_layers, device=get_device()):
    """Print and return reachability statistics of ``dataset`` for ``num_layers`` hops.

    For every graph, builds the 0/1 ``num_layers``-hop reachability matrix ``R``
    with ``reachable_nodes_mat`` and accumulates:

    * ``R.sum()``, the number of reachable (node, node) pairs;
    * a shared-reachability score. This is the off-diagonal sum of ``R @ R``,
      i.e. for each ordered pair ``(i, j)`` the number of nodes ``k`` with
      ``R[i, k] == R[k, j] == 1``, divided by ``num_nodes * (num_nodes + 1)``.

    Args:
        dataset: Sequence of graph samples with ``num_nodes`` and ``edge_index``.
        num_layers: Hop count passed to ``reachable_nodes_mat``.
        device: Device the matrices are built on. The default is evaluated once,
            when this function is defined.

    Returns:
        ``(avg_reachable_nodes, avg_shared_reachable_nodes, avg_nodes_on_graph)``.
        These are the reachable nodes per node, the mean per-graph shared score,
        and the mean node count per graph.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    total_data_size = len(dataset)
    if total_data_size == 0:
        raise ValueError("get_avg_reachable_nodes needs a non-empty dataset")

    total_num_nodes = 0
    total_reachable_nodes = 0
    total_shared_nodes = 0

    # NOTE(review): relies on `tqdm` being the callable rebound by a star import
    # (the top-level `import tqdm` binds the module).
    with tqdm(total=total_data_size) as pbar:
        pbar.set_description("Progress")
        for data in dataset:
            num_nodes = data.num_nodes
            edge_index = data.edge_index.to(device)

            # reachable matrix
            reachable_mat = reachable_nodes_mat(num_nodes, edge_index, num_layers, device=device)
            total_reachable_nodes += torch.sum(reachable_mat)

            # shared reachable matrix
            shared_mat = (reachable_mat @ reachable_mat).fill_diagonal_(0)
            # Per-graph shared score.
            # NOTE(review): there are num_nodes * (num_nodes - 1) ordered off-diagonal
            # pairs, so the (num_nodes + 1) factor looks like it should be
            # (num_nodes - 1). With 1 hop, a triangle scores 0.5 here even though
            # each pair shares exactly one node. The formula is kept as-is so
            # reported numbers stay comparable — confirm intent.
            total_shared_nodes += torch.sum(shared_mat) / (num_nodes * (num_nodes + 1))

            total_num_nodes += num_nodes

            torch.cuda.empty_cache()
            pbar.update(1)

    avg_reachable_nodes = (total_reachable_nodes / total_num_nodes).item()
    avg_shared_reachable_nodes = (total_shared_nodes / total_data_size).item()
    avg_nodes_on_graph = total_num_nodes / total_data_size
    print(f"Layer {num_layers}")
    print("Average reachable nodes:", round(avg_reachable_nodes, 4))
    print("Average shared reachable nodes:", round(avg_shared_reachable_nodes, 4))
    print("Average nodes on a graph:", round(avg_nodes_on_graph, 4))
    print("=================================")

    return avg_reachable_nodes, avg_shared_reachable_nodes, avg_nodes_on_graph

In [None]:
# Reachability statistics of the training graphs for 1, 2 and 3 hops, keyed by
# layer count. Keeping all three results avoids discarding the first two, as the
# repeated calls did.
reachable_stats = {
    num_layers: get_avg_reachable_nodes(train_dataset, num_layers)
    for num_layers in (1, 2, 3)
}
reachable_stats

In [None]:
# edge_index = torch.tensor(
#     [
#         [0, 0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 8, 8],
#         [0, 1, 2, 3, 0, 1, 8, 0, 2, 0, 3, 4, 5, 3, 4, 5, 6, 3, 4, 5, 4, 6, 7, 6, 7, 1, 8],
#     ]
# )

# reachable_mat = reachable_nodes_mat(9, edge_index, 10, device=get_device())
# print(reachable_mat)
# # reachable_nodes = torch.sum(reachable_mat)
# # print(reachable_nodes)

# shared_mat = reachable_mat @ reachable_mat
# shared_mat = shared_mat.fill_diagonal_(0)
# shared_nodes = torch.sum(shared_mat)
# print(shared_mat)
# print((shared_nodes / (7 * (7 + 1))))

# # reachable_mat_idx = torch.where(reachable_mat >= 1.0)
# # reachable_mat[shared_mat_idx] = 1.0

# # reachable_mat = reachable_nodes(9, edge_index, 2, ret_mat=True, device=get_device())
# # print(reachable_mat)

# # reachable_mat = reachable_mat.float()
# # mat = (reachable_mat @ reachable_mat)
# # print(mat)