In [1]:
import numpy as np
import os
import csv
import json

from pymatgen.io.cif import CifWriter, CifParser
from pymatgen.core.structure import Structure

#from torch_geometric.data import Data
#import torch

##Import example cif as pymatgen structure object
struc = CifParser('test_cif.cif').get_structures()[0]

In [2]:
from pymatgen.analysis.local_env import \
    LocalStructOrderParams, \
    VoronoiNN, \
    CrystalNN, \
    JmolNN, \
    MinimumDistanceNN, \
    MinimumOKeeffeNN, \
    EconNN, \
    BrunnerNN_relative, \
    MinimumVIRENN


#generate custom neighbor list to be used by all struc2's with nearest neighbor determination technique as parameter
def gen_neighborlist(struc, nn_strategy = 'crys', max_nn=12):
    NN = {
        # these methods consider too many neighbors which may lead to unphysical resutls
        'voro': VoronoiNN(tol=0.2),
        'econ': EconNN(),
        'brunner': BrunnerNN_relative(),

        # these two methods will consider motifs center at anions
        'crys': CrystalNN(),
        'jmol': JmolNN(),

        # not sure
        'minokeeffe': MinimumOKeeffeNN(),

        # maybe the best
        'mind': MinimumDistanceNN(),
        'minv': MinimumVIRENN()
    }

    nn = NN[nn_strategy]

    neigh_list = []
    neighborhoods = []

    for n in range(len(struc.sites)):
        neigh = []
        neigh = [neighbor for neighbor in nn.get_nn(struc, n)]
        for neighbor in neigh[:max_nn-1]:
                neighbor_index = neighbor.index
                offset = struc.frac_coords[neighbor_index] - struc.frac_coords[n] + neighbor.image
                m = struc.lattice.matrix
                offset = offset @ m
                distance = np.linalg.norm(offset)
                neigh_list.append([n, neighbor_index, offset, distance])

    return neigh_list

gen_neighborlist(struc)



[[0, 11, array([-0.65836544, -0.65836544,  2.10323456]), 2.300105588783085],
 [0, 7, array([ 0.65836544,  0.65836544, -2.10323456]), 2.300105588783085],
 [0, 9, array([-2.10323456,  0.65836544, -0.65836544]), 2.300105588783085],
 [0, 10, array([0.65836544, 2.10323456, 0.65836544]), 2.300105588783085],
 [0, 6, array([-0.65836544, -2.10323456, -0.65836544]), 2.300105588783085],
 [0, 5, array([ 2.10323456, -0.65836544,  0.65836544]), 2.300105588783085],
 [1, 10, array([ 0.65836544, -0.65836544, -2.10323456]), 2.300105588783085],
 [1, 6, array([-0.65836544,  0.65836544,  2.10323456]), 2.300105588783085],
 [1, 7, array([ 0.65836544, -2.10323456,  0.65836544]), 2.300105588783085],
 [1, 8, array([2.10323456, 0.65836544, 0.65836544]), 2.300105588783085],
 [1, 4, array([-2.10323456, -0.65836544, -0.65836544]), 2.300105588783085],
 [1, 11, array([-0.65836544,  2.10323456, -0.65836544]), 2.300105588783085],
 [2, 4, array([ 0.65836544,  2.10323456, -0.65836544]), 2.300105588783085],
 [2, 7, array(

In [28]:
import numpy as np
import math

import os
import os.path as osp
import csv
import json
import itertools
import time

from pymatgen.io.cif import CifParser
from pymatgen.core.structure import Structure
from pymatgen.analysis.local_env import \
    LocalStructOrderParams, \
    VoronoiNN, \
    CrystalNN, \
    JmolNN, \
    MinimumDistanceNN, \
    MinimumOKeeffeNN, \
    EconNN, \
    BrunnerNN_relative, \
    MinimumVIRENN

#generate custom neighbor list to be used by all struc2's with nearest neighbor determination technique as parameter
def get_nbrlist(struc, nn_strategy = 'mind', max_nn=12):
    NN = {
        # these methods consider too many neighbors which may lead to unphysical resutls
        'voro': VoronoiNN(tol=0.2),
        'econ': EconNN(),
        'brunner': BrunnerNN_relative(),

        # these two methods will consider motifs center at anions
        'crys': CrystalNN(),
        'jmol': JmolNN(),

        # not sure
        'minokeeffe': MinimumOKeeffeNN(),

        # maybe the best
        'mind': MinimumDistanceNN(),
        'minv': MinimumVIRENN()
    }

    nn = NN[nn_strategy]

    center_idxs = []
    neighbor_idxs = []
    offsets = []
    distances = []

    reformat_nbr_lst = []

    for n in range(len(struc.sites)):
        neigh = []
        neigh = [neighbor for neighbor in nn.get_nn(struc, n)]

        neighbor_reformat=[]
        for neighbor in neigh[:max_nn-1]:
            neighbor_index = neighbor.index
            offset = struc.frac_coords[neighbor_index] - struc.frac_coords[n] + neighbor.image
            m = struc.lattice.matrix
            offset = offset @ m
            distance = np.linalg.norm(offset)
            center_idxs.append(n)
            neighbor_idxs.append(neighbor_index)
            offsets.append(offset)
            distances.append(distance)

            neighbor_reformat.append((neighbor_index, offset, distance))
        reformat_nbr_lst.append((n,neighbor_reformat))
    nbr_list = [center_idxs, neighbor_idxs, offsets, distances]

    return nbr_list, reformat_nbr_lst


#generate hypergraph dictionary elements for singleton sets
def struc2singletons(struc,  hgraph = [], tol=0.01, import_feat: bool = False, directory: str = "cif"):
    singletons = struc.get_neighbor_list(r = tol, exclude_self=False)[0]
    atom_count = 0
    for node in singletons:
        hgraph.append(['atom', atom_count, [atom_count], None])
        atom_count += 1
    #extract features (Z and possibly coordinates) for atoms
    site_lst = struc.sites
    features = [[],[]]
    i = 0
    for site in site_lst:
        #### IMPORTANT: ASSOCIATES SITE INDEX FOR LATER REFERENCE IN MOTIF GENERATION #####
        site.properties = {'index': i}
        features[1].append(site.coords) #Coordinate of sites
        z_site = [element.Z for element in site.species]
        features[0].append(z_site[0]) #Atomic num of sites
        i+=1
    #import features from CGCNN atom_init file
    if import_feat == True:
        with open(osp.join(directory,'atom_init.json')) as atom_init:
            atom_vecs = json.load(atom_init)
            features[0] = [torch.tensor(atom_vecs[f'{z}']).float() for z in features[0]]
    for hedge, feature in zip(hgraph, features[0]):
        hedge[3] = feature
    return hgraph




#define gaussian expansion for distance features

class gaussian_expansion(object):
    def __init__(self, dmin, dmax, steps):
        assert dmin<dmax
        self.dmin = dmin
        self.dmax = dmax
        self.steps = steps-1
        
    def expand(self, distance, sig=None, tolerance = 0.01):
        drange = self.dmax-self.dmin
        step_size = drange/self.steps
        if sig == None:
            sig = step_size/2
        ds = [self.dmin + i*step_size for i in range(self.steps+1)]
        expansion = [math.exp(-(distance-center)**2/(2*sig**2))  for center in ds]
        expansion = [i if i > tolerance else 0 for i in expansion]
        return expansion
    
#Add bond nodes to hgraph list
def struc2pairs(struc, hgraph, nbr_lst, radius = 4.5, gauss_dim: int = 24):

    pair_center_idx = nbr_lst[0]
    pair_neighbor_idx = nbr_lst[1]
    distances = nbr_lst[3]
    
    if gauss_dim != 1:
        ge = gaussian_expansion(dmin = 0, dmax = radius, steps = gauss_dim)
            
    bond_index = 0
    ## currently double counts pair-wise edges/makes undirected edges
    for pair_1,pair_2,dist in zip(pair_center_idx, pair_neighbor_idx,distances):
        if gauss_dim != 1:
            dist = ge.expand(dist)
        hgraph.append(['bond', bond_index, [pair_1, pair_2], dist])
        bond_index += 1
    
    return hgraph



#Add ALIGNN-like triplets with angle feature vector
def struc2triplets(struc, hgraph, nbr_lst, gauss_dim = 10):

    #requires second output of get_nbr_lst!!!!! leaving this as reminder
    reformat_nbr_lst = nbr_lst
    
    if gauss_dim != 1:
        ge = gaussian_expansion(dmin = -1, dmax = 1, steps = gauss_dim)
    
    triplet_index = 0
    for cnt_idx, neighborset in reformat_nbr_lst:
            for i in itertools.combinations(neighborset, 2):
                (pair_1_idx, offset_1, distance_1), (pair_2_idx, offset_2, distance_2) = i

                offset_1 = np.stack(offset_1)
                offset_2 = np.stack(offset_2)
                cos_angle = (offset_1 * offset_2).sum(-1) / (np.linalg.norm(offset_1, axis=-1) * np.linalg.norm(offset_2, axis=-1))

                #Stop-gap to fix nans from zero displacement vectors
                cos_angle = np.nan_to_num(cos_angle, nan=1)
                
                if gauss_dim != 1:
                    cos_angle = ge.expand(cos_angle)
            
                hgraph.append(['triplet', triplet_index, [cnt_idx, pair_1_idx, pair_2_idx], cos_angle])
                triplet_index += 1

    return hgraph

#Types of structure-order parameters to calculate
all_types = [ "cn",
        "sgl_bd",
        "bent",
        "tri_plan",
        "tri_plan_max",
        "reg_tri",
        "sq_plan",
        "sq_plan_max",
        "pent_plan",
        "pent_plan_max",
        "sq",
        "tet",
        "tet_max",
        "tri_pyr",
        "sq_pyr",
        "sq_pyr_legacy",
        "tri_bipyr",
        "sq_bipyr",
        "oct",
        "oct_legacy",
        "pent_pyr",
        "hex_pyr",
        "pent_bipyr",
        "hex_bipyr",
        "T",
        "cuboct",
        "cuboct_max",
        "see_saw_rect",
        "bcc",
        "q2",
        "q4",
        "q6",
        "oct_max",
        "hex_plan_max",
        "sq_face_cap_trig_pris"]

#Add motif hedges to hgraph
def struc2motifs(struc, hgraph, nbr_lst, types = all_types, lsop_tol = 0.05):
    reformat_nbr_lst = nbr_lst

    neighborhoods = []
    ####IMPORTANT: REQUIRES YOU RUN struc2atoms FIRST####
    for n, neighborset in reformat_nbr_lst:
        neigh_idxs = []
        for i in neighborset:
            neigh_idxs.append(i[0])
        neighborhoods.append([n, neigh_idxs])
        
    lsop = LocalStructOrderParams(types)
    motif_index = 0
    for site, neighs in neighborhoods:
        ##Calculate order parameters for Voronoii neighborhood (excluding center)
        feat = lsop.get_order_parameters(struc, site, indices_neighs = neighs)
        for n,f in enumerate(feat):
            if f == None:
                feat[n] = 0
            elif f > 1:
                feat[n] = f
            ##Account for tolerance:
            elif f > lsop_tol:
                feat[n] = f
            else:
                feat[n] = 0
        ##Add center to node set before adding to hgraph list
        neighs.append(site)
        hgraph.append(['motif', motif_index, neighs, feat])
        motif_index += 1

    return hgraph


#Include unit cell hedge for pooling
def struc2cell(struc, hgraph, random_x = True):
    if random_x == True:
        feat = np.random.rand(64)
    else:
        feat = None
    nodes = list(range(len(struc.sites)))
    hgraph.append(['cell', 0, nodes, feat])
    return hgraph

## Now bring together process into overall hgraph generation
def hgraph_gen(struc, dir = 'cif', cell = False):
    hgraph = []
    nbr_lst,reformat_nbr_lst = get_nbrlist(struc)
    hgraph = struc2singletons(struc, hgraph, directory= dir)
    hgraph = struc2pairs(struc, hgraph, nbr_lst)
    hgraph = struc2triplets(struc, hgraph, reformat_nbr_lst)
    hgraph = struc2motifs(struc, hgraph, reformat_nbr_lst)
    if cell == True:
        hgraph = struc2cell(struc, hgraph)
    
    return hgraph

## Helper functions for relatives heterograph construction

def ordertype(hgraph, string):
    order_hedge = []
    for hedge in hgraph:
        if hedge[0] == string:
            order_hedge.append(hedge)
    return order_hedge    

def decompose(hgraph, order_types=['atom','bond','triplet','motif']):
    sep = []
    for string in order_types:
        sep.append(ordertype(hgraph, string))
    return sep
    
def contains(big, small):
    if all(item in big for item in small):
        return True
    else:
        return False
    
def touches(one, two):
    if any(item in one for item in two):
        return True
    else:
        return False
    
##Define function that generates relatives heterograph edge indices
def hetero_rel_edges(hgraph, cell_vector = True):
    atoms, bonds, triplets, motifs = decompose(hgraph)
    edges = {}
    atom_atom_hom = [[],[]]
    for bond in bonds:
        atom_atom_hom[0].append(bond[2][0])
        atom_atom_hom[1].append(bond[2][1])
    edges['atom','bonds','atom'] = atom_atom_hom

    atom_bonds_het = [[],[]]
    for atom in atoms:
        for bond in bonds:
            if contains(bond[2],atom[2]):
                atom_bonds_het[0].append(atom[1])
                atom_bonds_het[1].append(bond[1])
    edges['atom','in','bond'] = atom_bonds_het

    atom_trip_het = [[],[]]
    for atom in atoms:
        for triplet in triplets:
            if contains(triplet[2],atom[2]):
                atom_trip_het[0].append(atom[1])
                atom_trip_het[1].append(triplet[1])
    edges['atom','in','triplet'] = atom_trip_het

    atom_motifs_het = [[],[]]
    for atom in atoms:
        for motif in motifs:
            if contains(motif[2],atom[2]):
                atom_motifs_het[0].append(atom[1])
                atom_motifs_het[1].append(motif[1])
    edges['atom','in','motif'] = atom_motifs_het

    bond_bond_hom = [[],[]]
    for bond1 in bonds:
        for bond2 in bonds:
            if bond1!=bond2:
                if touches(bond1[2], bond2[2]):
                    bond_bond_hom[0].append(bond1[1])
                    bond_bond_hom[1].append(bond2[1])

    edges['bond','touches','bond'] = bond_bond_hom

    bond_trip_het = [[],[]]
    for bond in bonds:
        for triplet in triplets:
            if contains(triplet[2],bond[2]):
                bond_trip_het[0].append(bond[1])
                bond_trip_het[1].append(triplet[1])
    edges['bond','in','triplet'] = bond_trip_het

    bond_motifs_het = [[],[]]
    for bond in bonds:
        for motif in motifs:
            if contains(motif[2],bond[2]):
                bond_motifs_het[0].append(bond[1])
                bond_motifs_het[1].append(motif[1])
    edges['bond','in','motif'] = bond_motifs_het


    trip_trip_hom = [[],[]]
    for t1 in triplets:
        for t2 in triplets:
            if t1!= t2:
                if touches(t1, t2):
                    trip_trip_hom[0].append(t1[1])
                    trip_trip_hom[1].append(t2[1])
    edges['triplet', 'touches', 'triplet'] = trip_trip_hom

    trip_motifs_het = [[],[]]
    for triplet in triplets:
        for motif in motifs:
            if contains(motif, triplet):
                trip_motifs_het[0].append(triplet[1])
                trip_motifs_het[1].append(motif[1])
    edges['triplet', 'in', 'motif'] = trip_motifs_het

    mot_mot_hom = [[],[]]
    for m1 in motifs:
        for m2 in motifs:
            if m1!=m2:
                if touches(m1, m2):
                    mot_mot_hom[0].append(m1[1])
                    mot_mot_hom[1].append(m2[1])
    edges['motif','touches','motif'] = mot_mot_hom

    if cell_vector == True:
        orders = ['motif', 'atom', 'bond']
        for string, order in zip(orders, decompose(hgraph, orders)):
            edge_idx = [[],[]]
            for ent in order:
                edge_idx[0].append(0)
                edge_idx[1].append(ent[1])
            edges['cell', 'contains', string] = edge_idx
    
    return edges


In [29]:
hgraph = hgraph_gen(struc)
print(hgraph)

[['atom', 0, [0], 25], ['atom', 1, [1], 25], ['atom', 2, [2], 25], ['atom', 3, [3], 25], ['atom', 4, [4], 16], ['atom', 5, [5], 16], ['atom', 6, [6], 16], ['atom', 7, [7], 16], ['atom', 8, [8], 16], ['atom', 9, [9], 16], ['atom', 10, [10], 16], ['atom', 11, [11], 16], ['bond', 0, [0, 6], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.318746233716891, 0.8878263808759485, 0.045293210065534495, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], ['bond', 1, [0, 11], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.318746233716891, 0.8878263808759485, 0.045293210065534495, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], ['bond', 2, [0, 9], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.318746233716891, 0.8878263808759485, 0.045293210065534495, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], ['bond', 3, [0, 5], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.318746233716891, 0.8878263808759485, 0.045293210065534495, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], ['bond', 4, [0, 7], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.318746233716891, 0.8878263808759485, 0.045293210065534495, 0, 0, 0, 0, 0, 0, 0, 0, 0, 