In [1]:
import numpy as np
import os
import csv
import json

from pymatgen.io.cif import CifWriter, CifParser
from pymatgen.core.structure import Structure
from torch_geometric.data import Data
import torch

## struc -> tensor([[node_index,...],[hyper_edge_index,...]]) ##

## subfunction for hedge_list responsible for singleton/node hedges, needs small tolerance
## to include itself in get_neighbor_list 

## might need to rework to include features in each subfunction
## atom_init from CGCNN for singletons

## should be passed first in hgraph construction, since it assumes to be first 

def struc2singletons(struc,  hedge_list = [[],[]], tol=0.01):
    singletons = struc.get_neighbor_list(r = tol, exclude_self=False)[0]
    if hedge_list == [[],[]]:
        hedge_counter = 0
    else:
        hedge_counter = np.max(hedge_list[1]) + 1
    for node in singletons:
        hedge_list[0].append(node)
        hedge_list[1].append(hedge_counter)
        hedge_counter += 1
    return hedge_list




## CIF, tensor([[node_index,...],[hyper_edge_index,...]]) -> tensor([[node_index,...],[hyper_edge_index,...]]) ##

## subfunction for hedge_list responsible for pair-wise hedges
## can be based on threshold + min_rad or some interatomic radius

## might need to rework to include features in each subfunction
## gaussian distance for pairs

def struc2pairs(struc, hedge_list = [[],[]], radius: float = 3, min_rad: bool = True, tol: float = 0.1):
    if min_rad == False:
        nbr_lst = struc.get_neighbor_list(r = radius, exclude_self=True)
    elif min_rad == True:
        nbr_lst = struc.get_neighbor_list(r = 25, exclude_self=True)
        min_rad = np.min(nbr_lst[3])
        nbr_lst = struc.get_neighbor_list(r = min_rad+tol, exclude_self=True)

    pair_center_idx = nbr_lst[0]
    pair_neighbor_idx = nbr_lst[1]
    if hedge_list == []:
        hedge_counter = 0
    else:
        hedge_counter = np.max(hedge_list[1]) + 1

    ## currently double counts pair-wise edges
    for pair_1,pair_2 in zip(pair_center_idx, pair_neighbor_idx):
        hedge_list[0].append(pair_1)
        hedge_list[0].append(pair_2)

        hedge_list[1].append(hedge_counter)
        hedge_list[1].append(hedge_counter)
        hedge_counter += 1

    return hedge_list



## CIF, tensor([[node_index,...],[hyper_edge_index,...]]) -> tensor([[node_index,...],[hyper_edge_index,...]]) ##

## subfunction for hedge_list responsible for motif-wise hedges
## currently based only on distance (min+thresh or set range)

## might need to rework to include features in each subfunction
## motif order parameters

def struc2motifs(struc, hedge_list = [[],[]], radius: float = 3, min_rad: bool = True, tol: float = 0.1):
    if min_rad == False:
        nbr_lst = struc.get_neighbor_list(r = radius, exclude_self=True)
    elif min_rad == True:
        nbr_lst = struc.get_neighbor_list(r = 25, exclude_self=True)
        min_rad = np.min(nbr_lst[3])
        nbr_lst = struc.get_neighbor_list(r = min_rad+tol, exclude_self=True)

    pair_center_idx = nbr_lst[0]
    pair_neighbor_idx = nbr_lst[1]
    if hedge_list == [[],[]]:
        hedge_counter = 0
    else:
        hedge_counter = np.max(hedge_list[1]) + 1

    last_center = pair_center_idx[0]

    for pair_1,pair_2 in zip(pair_center_idx, pair_neighbor_idx):
        new_center = pair_1
        if last_center != new_center:
            hedge_list[0].append(last_pair_1)
            hedge_list[1].append(hedge_counter)
            last_center = pair_1
            hedge_counter += 1
            
        hedge_list[0].append(pair_2)
        hedge_list[1].append(hedge_counter)

        last_pair_1 = pair_1

    return hedge_list


## CIF -> tensor([[node_index,...],[hyper_edge_index,...]]) ##

# takes cif file and returns array (2 x num_nodes_in_hedges) of hedge index
# (as specified in the HypergraphConv doc of PyTorch Geometric)
# found by collecting neighbors within spec radius for each node in one hedge


def cif2hedges(cif_file, radius: float = 3, min_rad: bool = True, tol: float = 0.1):
    struc = CifParser(cif_file).get_structures()[0]
    hedge_list = struc2singletons(struc)
    hedge_list = struc2pairs(struc, hedge_list, radius, min_rad, tol)
    hedge_list = struc2motifs(struc, hedge_list, radius, min_rad, tol)
    return hedge_list

In [2]:
struc = CifParser('test_cif.cif').get_structures()[0]

In [3]:
hedge_list=[[],[]]
hedge_list = struc2singletons(struc, hedge_list)
hedge_list = struc2pairs(struc, hedge_list)
hedge_list = struc2motifs(struc, hedge_list)
print(hedge_list)

[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 6, 0, 11, 0, 9, 0, 7, 0, 10, 0, 5, 1, 4, 1, 11, 1, 6, 1, 10, 1, 7, 1, 8, 2, 5, 2, 8, 2, 7, 2, 11, 2, 9, 2, 4, 3, 10, 3, 8, 3, 5, 3, 9, 3, 4, 3, 6, 4, 8, 4, 2, 4, 3, 4, 1, 5, 0, 5, 3, 5, 9, 5, 2, 6, 3, 6, 10, 6, 1, 6, 0, 7, 11, 7, 0, 7, 1, 7, 2, 8, 1, 8, 4, 8, 3, 8, 2, 9, 5, 9, 2, 9, 3, 9, 0, 10, 6, 10, 0, 10, 1, 10, 3, 11, 2, 11, 1, 11, 0, 11, 7, 6, 11, 9, 7, 10, 5, 0, 4, 11, 6, 10, 7, 8, 1, 5, 8, 7, 11, 9, 4, 2, 10, 8, 5, 9, 4, 6, 3, 8, 2, 3, 1, 4, 0, 3, 9, 2, 5, 3, 10, 1, 0, 6, 11, 0, 1, 2, 7, 1, 4, 3, 2, 8, 5, 2, 3, 0, 9, 6, 0, 1, 3, 10, 2, 1, 0, 7], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24, 25, 25, 26, 26, 27, 27, 28, 28, 29, 29, 30, 30, 31, 31, 32, 32, 33, 33, 34, 34, 35, 35, 36, 36, 37, 37, 38, 38, 39, 39, 40, 40, 41, 41, 42, 42, 43, 43, 44, 44, 45, 45, 46, 46, 47, 47, 48, 48, 49, 49, 50, 50, 51, 51, 52, 52, 53, 53, 54, 54, 55, 55, 56, 56

In [4]:
types = [ "cn",
        "sgl_bd",
        "bent",
        "tri_plan",
        "tri_plan_max",
        "reg_tri",
        "sq_plan",
        "sq_plan_max",
        "pent_plan",
        "pent_plan_max",
        "sq",
        "tet",
        "tet_max",
        "tri_pyr",
        "sq_pyr",
        "sq_pyr_legacy",
        "tri_bipyr",
        "sq_bipyr",
        "oct",
        "oct_legacy",
        "pent_pyr",
        "hex_pyr",
        "pent_bipyr",
        "hex_bipyr",
        "T",
        "cuboct",
        "cuboct_max",
        "see_saw_rect",
        "bcc",
        "q2",
        "q4",
        "q6",
        "oct_max",
        "hex_plan_max",
        "sq_face_cap_trig_pris"]

In [5]:
from pymatgen.analysis.local_env import LocalStructOrderParams
lsop = LocalStructOrderParams(types)

#Dont forget to remove initial edges before neighborhood lists

hedge_list = [[],[]]
hedge_list = struc2motifs(struc, hedge_list)
motf_nindx = hedge_list[0][7:14]
motf_indx  = 1
print(motf_nindx)

len(lsop.get_order_parameters(struc, motf_indx, indices_neighs = motf_nindx))

[4, 11, 6, 10, 7, 8, 1]


  rijnorm.append(rij[j] / dist[j])


35

In [6]:
from pymatgen.analysis.local_env import LocalStructOrderParams
lsop = LocalStructOrderParams(types)

types = [ "cn",
        "sgl_bd",
        "bent",
        "tri_plan",
        "tri_plan_max",
        "reg_tri",
        "sq_plan",
        "sq_plan_max",
        "pent_plan",
        "pent_plan_max",
        "sq",
        "tet",
        "tet_max",
        "tri_pyr",
        "sq_pyr",
        "sq_pyr_legacy",
        "tri_bipyr",
        "sq_bipyr",
        "oct",
        "oct_legacy",
        "pent_pyr",
        "hex_pyr",
        "pent_bipyr",
        "hex_bipyr",
        "T",
        "cuboct",
        "cuboct_max",
        "see_saw_rect",
        "bcc",
        "q2",
        "q4",
        "q6",
        "oct_max",
        "hex_plan_max",
        "sq_face_cap_trig_pris"]


## CIF, tensor([[node_index,...],[hyper_edge_index,...]]) -> tensor([[node_index,...],[hyper_edge_index,...]]) ##

## subfunction for hedge_list responsible for motif-wise hedges
## currently based only on distance (min+thresh or set range)

## might need to rework to include features in each subfunction
## motif order parameters

def struc2motifs(struc, hedge_list = [[],[]], radius: float = 3, min_rad: bool = True, tol: float = 0.1):
    if min_rad == False:
        nbr_lst = struc.get_neighbor_list(r = radius, exclude_self=True)
    elif min_rad == True:
        nbr_lst = struc.get_neighbor_list(r = 25, exclude_self=True)
        min_rad = np.min(nbr_lst[3])
        nbr_lst = struc.get_neighbor_list(r = min_rad+tol, exclude_self=True)

    pair_center_idx = nbr_lst[0]
    pair_neighbor_idx = nbr_lst[1]
    if hedge_list == [[],[]]:
        hedge_counter = 0
    else:
        hedge_counter = np.max(hedge_list[1]) + 1

    last_center = pair_center_idx[0]
    neighborhoods = ([],[],[])
    neighborhood  = []

    for pair_1,pair_2 in zip(pair_center_idx, pair_neighbor_idx):
        new_center = pair_1
        if last_center != new_center:
            neighborhoods[0].append(hedge_counter)
            neighborhoods[1].append(last_pair_1)
            neighborhoods[2].append(neighborhood)
            hedge_list[0].append(last_pair_1)
            hedge_list[1].append(hedge_counter)
            last_center = pair_1
            neighborhood = []
            hedge_counter += 1
            
        hedge_list[0].append(pair_2)
        hedge_list[1].append(hedge_counter)
        neighborhood.append(pair_2)
        
        last_pair_1 = pair_1
        
    hedge_list[0].append(last_pair_1)
    hedge_list[1].append(hedge_counter)
    
    neighborhoods[0].append(hedge_counter)
    neighborhoods[1].append(last_pair_1)
    neighborhoods[2].append(neighborhood)
    
    lsop = LocalStructOrderParams(types)
    
    features = [[],[]]
    for hedge_idx, center_idx, neighbor_lst in zip(neighborhoods[0],neighborhoods[1],neighborhoods[2]):
        feature = lsop.get_order_parameters(struc, center_idx, indices_neighs = neighbor_lst)
        features[0].append(hedge_idx)
        features[1].append(feature)
            
    return hedge_list, features

In [7]:
hedge_list = [[],[]]
hedge_list = struc2motifs(struc, hedge_list)[1][1][0]
print(hedge_list)

[6.0, 2.220446049250313e-16, 7.121562665051939e-05, 0.0008384951159906882, 0.028653472314188584, 1.0658041506816435e-205, 0.14243183801701953, 0.3934481096907446, 0.030746076199275815, 0.389181844384812, 6.315412975334043e-31, 0.007254860759403706, 0.1562813946707834, 0.5113057454736281, 0.5814634668618692, 0.11932998001665218, 0.49161942485188304, 0.5308165961703892, 0.20006814693155736, 0.023694994598985445, 0.5354550624747225, 0.43823066521432336, 0.5223879964881926, 0.42443864403871856, 0.43036784535257916, 0.29334660406606244, 0.46704357385544876, 0.4068229648171845, 0.01447498552276311, 0.4058885720384663, 0.5165345556934031, 0.2803649979112953, 0.5308165961703892, 0.21853026521726338, 0.3604484545887907]


In [8]:
## struc -> tuple(tensor([[node_pos], ... ]),tensor([node_atomic_num,...])) ##

#takes cif file and returns tuple of a tensor of node positions and a tensor
# of nodes atomic number, indexed same as cif2graphedges
def cif2nodepos(struc):
    site_lst = struc.sites
    nodepos_lst = []
    nodespec_lst = []
    for site in site_lst:
        nodepos_lst.append(site.coords) #Coordinate of sites
        z_site = [element.Z for element in site.species]
        nodespec_lst.append(z_site) #Atomic number list of site species (should always be single element list for crystal)
    nodepos_arr = np.array(nodepos_lst, dtype=float)
    nodespec_arr = np.squeeze(nodespec_lst)
    return  (torch.tensor(nodepos_arr),torch.tensor(nodespec_arr))

## struc -> tensor([[node_index,...],[hyper_edge_index,...]]) ##

## subfunction for hedge_list responsible for singleton/node hedges, needs small tolerance
## to include itself in get_neighbor_list 

## might need to rework to include features in each subfunction
## atom_init from CGCNN for singletons

## should be passed first in hgraph construction, since it assumes to be first 

def struc2singletons(struc,  hedge_list = [[],[]], tol=0.01, import_feat: bool = False, directory: str = ""):
    singletons = struc.get_neighbor_list(r = tol, exclude_self=False)[0]
    if hedge_list == [[],[]]:
        hedge_counter = 0
    else:
        hedge_counter = np.max(hedge_list[1]) + 1
    for node in singletons:
        hedge_list[0].append(node)
        hedge_list[1].append(hedge_counter)
        hedge_counter += 1
        
    
    features = [hedge_list[1],[],[]]
    
    site_lst = struc.sites
    for site in site_lst:
        features[1].append(site.coords) #Coordinate of sites
        z_site = [element.Z for element in site.species]
        features[2].append(z_site[0]) #Atomic num of sites
    
    if import_feat == True:
        with open(f'{directory}atom_init.json') as atom_init:
            atom_vecs = json.load(atom_init)
            features[2] = [atom_vecs[f'{z}'] for z in features[2]]

    return hedge_list, features


In [9]:
hedge_list = [[],[]]
hedge_list = struc2singletons(struc, hedge_list, import_feat= True)[0]
print(hedge_list)

[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]


In [10]:
## struc, tensor([[node_index,...],[hyper_edge_index,...]]) -> tensor([[node_index,...],[hyper_edge_index,...]]) ##

## subfunction for hedge_list responsible for pair-wise hedges
## can be based on threshold + min_rad or some interatomic radius

## might need to rework to include features in each subfunction
## gaussian distance for pairs

def struc2pairs(struc, hedge_list = [[],[]], radius: float = 3, min_rad: bool = True, tol: float = 0.1, gauss_dim: int = 1):
    if min_rad == False:
        nbr_lst = struc.get_neighbor_list(r = radius, exclude_self=True)
    elif min_rad == True:
        nbr_lst = struc.get_neighbor_list(r = 25, exclude_self=True)
        min_rad = np.min(nbr_lst[3])
        nbr_lst = struc.get_neighbor_list(r = min_rad+tol, exclude_self=True)

    pair_center_idx = nbr_lst[0]
    pair_neighbor_idx = nbr_lst[1]
    distances = nbr_lst[3]
    
    if hedge_list == [[],[]]:
        hedge_counter = 0
    else:
        hedge_counter = np.max(hedge_list[1]) + 1

    features = [[],[]]
    ## currently double counts pair-wise edges
    for pair_1,pair_2,dist in zip(pair_center_idx, pair_neighbor_idx,distances):
        hedge_list[0].append(pair_1)
        hedge_list[0].append(pair_2)

        hedge_list[1].append(hedge_counter)
        hedge_list[1].append(hedge_counter)
        
        features[0].append(hedge_counter)
        features[1].append(dist)
        
        hedge_counter += 1
    if gauss_dim != 1:
        ge = gaussian_expansion(dmin = 0 ,dmax = radius + 5*tol, steps = gauss_dim)
        features[1]=[ge.expand(dist) for dist in features[1]]

    
    return hedge_list, features


## Gaussian distance expansion function for pair hedge features

import math

class gaussian_expansion(object):
    def __init__(self, dmin, dmax, steps):
        assert dmin<dmax
        self.dmin = dmin
        self.dmax = dmax
        self.steps = steps
        
    def expand(self, distance, sig=None):
        drange = self.dmax-self.dmin
        step_size = drange/self.steps
        if sig == None:
            sig = step_size/2
        ds = [self.dmin + i*step_size for i in range(self.steps)]
        expansion = [math.exp(-(distance-center)**2/(2*sig**2)) for center in ds]
        return expansion


In [11]:
hedge_list = [[],[]]
hedge_list = struc2pairs(struc, hedge_list, gauss_dim=5)[1][1][0]
print(hedge_list)

[4.1871975449302247e-10, 2.8937889790041914e-05, 0.03662961367938433, 0.8492193684981157, 0.36060314951808503]


In [12]:
ge = gaussian_expansion(2,10,8)
ge.expand(5)

[1.522997974471263e-08,
 0.00033546262790251185,
 0.1353352832366127,
 1.0,
 0.1353352832366127,
 0.00033546262790251185,
 1.522997974471263e-08,
 1.2664165549094176e-14]

In [13]:
def hedge_packer(hedge_list):
    new_hedges = [[],[]]
    hedge_nodes = []
    old_hedge = hedge_list[1][0]
    for idx,(node,hedge) in enumerate(zip(hedge_list[0],hedge_list[1])):
        new_hedge = hedge_list[1][idx]
        if new_hedge != old_hedge:
            new_hedges[0].append(old_hedge)
            old_hedge = hedge_list[1][idx]
            new_hedges[1].append(hedge_nodes)
            hedge_nodes = []
        hedge_nodes.append(node)
    new_hedges[0].append(new_hedge)
    new_hedges[1].append(hedge_nodes)
    return new_hedges



In [22]:
def relatives(hedge_pack):
    relative_list = [[],[]]
    for idx1, hedge in enumerate(hedge_pack[0]):
        #print(idx1)
        relative_list[0].append(hedge)
        relatives = []
        for idx2,nodes_contained in enumerate(hedge_pack[1]):
            #print(nodes_contained)
            if  all(item in nodes_contained for item in hedge_pack[1][idx1]) or all(item in hedge_pack[1][idx1] for item in nodes_contained):
                relatives.append(hedge_pack[0][idx2])
        relatives.remove(hedge)
        relative_list[1].append(relatives)
    return relative_list
    

In [16]:
hedge_list = [[],[]]
hedge_list = struc2singletons(struc, hedge_list)[0]
hedge_list = struc2pairs(struc, hedge_list, gauss_dim=1)[0]
print(hedge_list)
print(hedge_packer(hedge_list))

[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 6, 0, 11, 0, 9, 0, 7, 0, 10, 0, 5, 1, 4, 1, 11, 1, 6, 1, 10, 1, 7, 1, 8, 2, 5, 2, 8, 2, 7, 2, 11, 2, 9, 2, 4, 3, 10, 3, 8, 3, 5, 3, 9, 3, 4, 3, 6, 4, 8, 4, 2, 4, 3, 4, 1, 5, 0, 5, 3, 5, 9, 5, 2, 6, 3, 6, 10, 6, 1, 6, 0, 7, 11, 7, 0, 7, 1, 7, 2, 8, 1, 8, 4, 8, 3, 8, 2, 9, 5, 9, 2, 9, 3, 9, 0, 10, 6, 10, 0, 10, 1, 10, 3, 11, 2, 11, 1, 11, 0, 11, 7], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24, 25, 25, 26, 26, 27, 27, 28, 28, 29, 29, 30, 30, 31, 31, 32, 32, 33, 33, 34, 34, 35, 35, 36, 36, 37, 37, 38, 38, 39, 39, 40, 40, 41, 41, 42, 42, 43, 43, 44, 44, 45, 45, 46, 46, 47, 47, 48, 48, 49, 49, 50, 50, 51, 51, 52, 52, 53, 53, 54, 54, 55, 55, 56, 56, 57, 57, 58, 58, 59, 59, 60, 60, 61, 61, 62, 62, 63, 63, 64, 64, 65, 65, 66, 66, 67, 67]]
[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,

In [18]:
# generates edge pairs for relatives graph from relatives function output

def relatives2graphedges(relative_list):
    rel_ge = [[],[]]
    for idx,pair_center in enumerate(relative_list[0]):
        for neighbor in relative_list[1][idx]:
            rel_ge[0].append(pair_center)
            rel_ge[1].append(neighbor)
    return rel_ge

In [23]:

relative = relatives(hedge_packer(hedge_list))
print(relative)
print(relatives2graphedges(relative))

[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67], [[12, 13, 14, 15, 16, 17, 40, 47, 49, 59, 61, 66], [18, 19, 20, 21, 22, 23, 39, 46, 50, 52, 62, 65], [24, 25, 26, 27, 28, 29, 37, 43, 51, 55, 57, 64], [30, 31, 32, 33, 34, 35, 38, 41, 44, 54, 58, 63], [18, 29, 34, 36, 37, 38, 39, 53], [17, 24, 32, 40, 41, 42, 43, 56], [12, 20, 35, 44, 45, 46, 47, 60], [15, 22, 26, 48, 49, 50, 51, 67], [23, 25, 31, 36, 52, 53, 54, 55], [14, 28, 33, 42, 56, 57, 58, 59], [16, 21, 30, 45, 60, 61, 62, 63], [13, 19, 27, 48, 64, 65, 66, 67], [0, 6, 47], [0, 11, 66], [0, 9, 59], [0, 7, 49], [0, 10, 61], [0, 5, 40], [1, 4, 39], [1, 11, 65], [1, 6, 46], [1, 10, 62], [1, 7, 50], [1, 8, 52], [2, 5, 43], [2, 8, 55], [2, 7, 51], [2, 11, 64], [2, 9, 57], [2, 4, 37], [3, 10, 63], [3, 8, 54], [3, 5, 41], [3, 