In [None]:
# from model.CrystalGraph import CrystalGraph
from pymatgen.core.structure import Structure
from pymatgen.analysis.local_env import LocalStructOrderParams ,VoronoiNN
import pymatgen.core as mg
import multiprocessing
from multiprocessing import Pool
from tqdm import tqdm
import os
import pymatgen
from pymatgen.analysis.local_env import BrunnerNN_real
from pymatgen.analysis.local_env import NearNeighbors
import pickle
from multiprocessing import Pool
import numpy as np

## Functional Component

In [2]:
def ABX_position(structure : pymatgen.core.structure.Structure) -> dict:
    '''
        Input:
            structure   --  object exposing `.species` (one entry per site);
                            every species must expose `average_cationic_radius`
        Output:
            dict{
                'A':[pymatgen.core.periodic_table.Species], 
                'B':[pymatgen.core.periodic_table.Species], 
                'X':[pymatgen.core.periodic_table.Species]
                }
        Raises:
            ValueError  --  if the structure has no sites, or fewer than two
                            distinct species remain once X has been removed
        Description:
            X is the most abundant species in the structure.
            The remaining species are ordered by average cationic radius
            (average_cationic_radius); the larger ones are assigned to A and
            the smaller ones to B:
                2 species -> larger is A, smaller is B
                3 species -> largest is A, smallest is B, and the middle one
                             joins B if its radius is below the midpoint of
                             the two extremes, otherwise it joins A
                4 species -> the two largest are A, the two smallest are B
            Ties (equal counts or equal radii) are resolved by the order in
            which the species first appear in `structure.species`, so the
            result is reproducible from run to run.
    '''
    # Count every species in a single pass. A dict keeps first-appearance
    # order, which gives a deterministic tie-break (a set would iterate in
    # hash order, which can change between interpreter runs).
    counts = {}
    for species in structure.species:
        counts[species] = counts.get(species, 0) + 1
    if not counts:
        raise ValueError("ABX_position() requires a structure with at least one site")

    # max() returns the first maximal key, i.e. the earliest-appearing species
    # among those sharing the highest count.
    x_species = max(counts, key=counts.get)

    # sorted() is stable, so equal radii also keep first-appearance order.
    cations = sorted((sp for sp in counts if sp != x_species),
                     key=lambda sp: sp.average_cationic_radius)
    if len(cations) < 2:
        raise ValueError(
            "ABX_position() needs at least two distinct species besides X, "
            "found %d" % len(cations)
        )

    if len(cations) == 2:
        Apos = [cations[1]]
        Bpos = [cations[0]]
    elif len(cations) == 3:
        smallest, middle, largest = cations
        midpoint = (smallest.average_cationic_radius
                    + largest.average_cationic_radius) / 2
        if middle.average_cationic_radius < midpoint:
            Apos = [largest]
            Bpos = [middle, smallest]
        else:
            Apos = [largest, middle]
            Bpos = [smallest]
    else:
        # NOTE(review): with more than four non-X species only the four
        # smallest radii are assigned; any further species appears in neither
        # list. Confirm whether such structures can occur in the dataset.
        Apos = [cations[3], cations[2]]
        Bpos = [cations[1], cations[0]]

    return {'A' : Apos, 'B' : Bpos, 'X' : [x_species]}

def coordination_number(structure : pymatgen.core.structure.Structure
                        , ABX_dict : dict) -> list:
    '''
        Input:
            structure,
            ABX_dict    --  the result of ABX_position()
        Output:
            list[int]   -- one coordination number per site, in site order
        Description:
            Sites whose species is listed under ABX_dict['B'] are given a
            fixed coordination number of 6; for every other site the value
            comes from VoronoiNN.get_cn().
    '''
    voronoi = VoronoiNN()
    b_species = ABX_dict['B']
    return [
        6 if site.specie in b_species else voronoi.get_cn(structure, index)
        for index, site in enumerate(structure.sites)
    ]

## Generate Pickle Files
One pickle file per structure is written for each of the following features: neighbour, abx, distance, atom, disrank, lattice, volume.

In [4]:
def process_function(tmp):
    '''
        Input:
            tmp     --  file name (not a full path) of a CIF file located in
                        the module-level `dataset` directory
        Output:
            None. For the given structure one pickle file is written into each
            of the module-level directories neighbour_result, abx_result,
            distance_result, atom_result, disrank_result, lattice_result and
            volume_result. The file is named after `tmp` without its '.cif'
            suffix.
        Description:
            For every site a "sub-graph" is built: the site itself followed by
            its N nearest neighbours (cutoff radius 8), where N is the site's
            coordination number from coordination_number(). Per sub-graph the
            following lists are stored:
                neighbour   --  the site objects themselves
                atom        --  atomic number Z of each member
                abx         --  'A' / 'B' / 'X' role of each member
                distance    --  distance from the central site to each member
                                (the central site itself gets the placeholder 1)
                disrank     --  0, 1, 2, ... (members are already sorted by
                                increasing distance)
            In addition the lattice (a, b, c, and the three angles in radians)
            and the cell volume are stored once per structure.
    '''
    data_path = os.path.join(dataset, tmp)
    structure = Structure.from_file(data_path)
    ABX_dict = ABX_position(structure)
    # species -> 'A' | 'B' | 'X'
    reversed_dict = {value: key for key, value_list in ABX_dict.items() for value in value_list}
    coordination_nb = coordination_number(structure, ABX_dict)
    # Recorded in the module-level list. NOTE(review): when this function runs
    # inside a Pool worker the append presumably only affects that worker's
    # copy of the list - confirm whether the parent needs these values.
    coordination.append(coordination_nb)

    all_nbs = structure.get_all_neighbors(8, include_index=True)
    all_nbs = [sorted(site_nbs, key=lambda nbr: nbr.nn_distance) for site_nbs in all_nbs]

    nbs = []
    abx = []
    dis = []
    atom = []
    disrank = []
    for i, site in enumerate(structure.sites):
        nearest = all_nbs[i][:coordination_nb[i]]
        sub_graph = [site] + nearest
        nbs.append(sub_graph)
        atom.append([member.specie.Z for member in sub_graph])
        abx.append([reversed_dict[member.specie] for member in sub_graph])
        # Use the distance stored on each neighbour. structure.get_distance(i, j)
        # would return the distance to the *nearest* periodic image of site j,
        # which is not necessarily the image that was selected as neighbour
        # (and is 0 when the neighbour is an image of site i itself).
        dis.append([1] + [nbr.nn_distance for nbr in nearest])
        disrank.append(list(range(len(sub_graph))))

    # lattice: lengths followed by angles converted from degrees to radians
    lattice = list(structure.lattice.abc) + [np.deg2rad(angle) for angle in structure.lattice.angles]
    # volume
    volume = [structure.volume]

    # str.rstrip('.cif') strips any trailing '.', 'c', 'i', 'f' characters
    # (e.g. 'abc.cif' -> 'ab'); remove the literal suffix only.
    suffix = '.cif'
    stem = tmp[:-len(suffix)] if tmp.endswith(suffix) else tmp

    outputs = (
        (neighbour_result, nbs),
        (abx_result, abx),
        (distance_result, dis),
        (atom_result, atom),
        (disrank_result, disrank),
        (lattice_result, lattice),
        (volume_result, volume),
    )
    for directory, payload in outputs:
        with open(os.path.join(directory, stem), 'wb') as f_pickle:
            pickle.dump(payload, f_pickle)


In [None]:
# TODO: Change the path to your data * 8
dataset = "../../../../../dataset/water_split/data/"
abx_result = "../../../result/WaterSplit_PGB/abx/"
distance_result = "../../../result/WaterSplit_PGB/distance/"
neighbour_result = "../../../result/WaterSplit_PGB/neighbour/"
atom_result = "../../../result/WaterSplit_PGB/atom/"
disrank_result = "../../../result/WaterSplit_PGB/disrank/"
lattice_result = "../../../result/WaterSplit_PGB/lattice/"
volume_result = "../../../result/WaterSplit_PGB/volume"
file_list = os.listdir(dataset)

All_A_position = set()
All_B_position = set()
All_X_position = set()
coordination = []
nb_ids = []

# process_function() opens its output files with open(..., 'wb'), which does
# not create missing directories, so make sure every target directory exists.
for result_dir in (abx_result, distance_result, neighbour_result, atom_result,
                   disrank_result, lattice_result, volume_result):
    os.makedirs(result_dir, exist_ok=True)

# Create Process Pool
num_processes = 80  # number of processes
pool = Pool(num_processes)

# Parameters of Process: only the CIF files in the dataset directory
task_parameters = [name for name in file_list if name.endswith('.cif')]

# (file name, exception) for every task whose worker raised
failed_tasks = []

try:
    # Using tqdm in multiprocessing to Run The Working Function
    with tqdm(total=len(task_parameters)) as progress_bar:
        # Update progress bar when a task succeeds
        def update(*args):
            progress_bar.update()

        # Without an error_callback a failing task is dropped silently:
        # AsyncResult.wait() never re-raises and the success callback is not
        # invoked, so the bar would stall. Record the failure and keep going.
        def make_error_handler(file_name):
            def on_error(exc):
                failed_tasks.append((file_name, exc))
                progress_bar.update()
            return on_error

        results = []
        for parameter in task_parameters:
            result = pool.apply_async(
                process_function,
                args=(parameter,),
                callback=update,
                error_callback=make_error_handler(parameter),
            )
            results.append(result)

        for result in results:
            result.wait()
finally:
    # Close the Process Pool even if submitting tasks raised
    pool.close()
    pool.join()

if failed_tasks:
    print(f"{len(failed_tasks)} of {len(task_parameters)} files could not be processed:")
    for file_name, exc in failed_tasks:
        print(f"  {file_name}: {exc!r}")

## Padding

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
# from transformers.modeling_bert import BertPredictionHeadTransform, BertAttention, BertIntermediate, BertOutput
from transformers.models.bert.modeling_bert import BertPredictionHeadTransform, BertAttention, BertIntermediate, BertOutput

from transformers.configuration_utils import PretrainedConfig
from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertPooler

import os
import pickle
import time
import numpy as np
import tqdm
import itertools

max_subgraph_size = 0
max_atom_size = 0
# TODO: Change the path to your data (this is the only place it is set)
RESULT_ROOT = "/root/AI4Sci/workplace/atom2vec/result"

# Scan the 'atom' feature to find the padding targets:
#   max_atom_size     -- largest number of sub-graphs (rows) in any file
#   max_subgraph_size -- longest sub-graph (row) in any file
# NOTE: pickle.load() can execute arbitrary code; only point RESULT_ROOT at
# files generated by this notebook.
_atom_dir = os.path.join(RESULT_ROOT, "atom")
for _name in os.listdir(_atom_dir):
    with open(os.path.join(_atom_dir, _name), "rb") as f1:
        _rows = pickle.load(f1)
    max_atom_size = max(max_atom_size, len(_rows))
    for sub_graph in _rows:
        max_subgraph_size = max(max_subgraph_size, len(sub_graph))


def _pad_feature(feature, fill, first_fill=None):
    """Pad every pickle in RESULT_ROOT/<feature> and write it to RESULT_ROOT/<feature>_padded.

    Each file holds a list of rows (lists). Every row is right-padded with
    `fill` up to `max_subgraph_size`, then whole rows are appended until the
    file has `max_atom_size` rows.

    Args:
        feature: sub-directory name below RESULT_ROOT ("atom", "abx", ...).
        fill: value used to extend short rows and to fill appended rows.
        first_fill: if given, the first element of each *appended* row is
            this value instead of `fill`; the remaining elements are `fill`.
    """
    source_dir = os.path.join(RESULT_ROOT, feature)
    target_dir = os.path.join(RESULT_ROOT, feature + "_padded")
    os.makedirs(target_dir, exist_ok=True)

    if first_fill is None:
        filler_row = [fill] * max_subgraph_size
    else:
        filler_row = [first_fill] + [fill] * (max_subgraph_size - 1)

    for name in os.listdir(source_dir):
        with open(os.path.join(source_dir, name), "rb") as source_file:
            rows = pickle.load(source_file)
        padded = [row + [fill] * (max_subgraph_size - len(row)) for row in rows]
        # A fresh copy per appended row so the rows are independent objects.
        padded.extend(list(filler_row) for _ in range(max_atom_size - len(rows)))
        with open(os.path.join(target_dir, name), "wb") as f_pickle:
            pickle.dump(padded, f_pickle)


# NOTE: padding of the 'neighbour' feature was disabled (commented out) in the
# original notebook and is likewise not performed here.

#atom
_pad_feature("atom", 0)

#abx
_pad_feature("abx", "pad")

#distance
_pad_feature("distance", 1000000, first_fill=0)

#disrank
_pad_feature("disrank", -1, first_fill=0)

## View Result

In [None]:
max_seq_length = 64
# Inspect one generated feature file.
# Alternative file to look at: "../result/disrank/15169"
filename = "../result/atom/15169"
with open(filename, "rb") as pickle_file:
    loaded_data = pickle.load(pickle_file)
# First sub-graph only, followed by the full content as the cell output.
print(loaded_data[0])
loaded_data