In [1]:
import numpy as np
import os
import csv
import json

from pymatgen.io.cif import CifWriter, CifParser
from pymatgen.core.structure import Structure

#from torch_geometric.data import Data
#import torch

## Example input: parse the CIF and keep its first structure as a pymatgen Structure object
struc = (
    CifParser('test_cif.cif')
    .get_structures()[0]
)

In [2]:
from pymatgen.analysis.local_env import \
    LocalStructOrderParams, \
    VoronoiNN, \
    CrystalNN, \
    JmolNN, \
    MinimumDistanceNN, \
    MinimumOKeeffeNN, \
    EconNN, \
    BrunnerNN_relative, \
    MinimumVIRENN


#generate custom neighbor list to be used by all struc2's with nearest neighbor determination technique as parameter
def gen_neighborlist(struc, nn_strategy = 'crys', max_nn=12):
    """Build a flat neighbor list for every site of ``struc``.

    Parameters
    ----------
    struc : pymatgen Structure
        Structure whose sites are visited in index order.
    nn_strategy : str
        Key selecting the pymatgen near-neighbor algorithm (see ``factories``).
    max_nn : int
        Maximum number of neighbors kept per site, in the order returned by the
        strategy's ``get_nn``.

    Returns
    -------
    list
        One ``[center_index, neighbor_index, offset, distance]`` entry per kept
        neighbor. ``offset`` is the Cartesian vector from the center to the
        neighbor's periodic image: the fractional difference plus the image,
        multiplied by the lattice matrix. ``distance`` is its norm.

    Raises
    ------
    KeyError
        If ``nn_strategy`` is not a supported key.
    """
    # Factories so that only the requested strategy is instantiated.
    factories = {
        # these methods consider too many neighbors which may lead to unphysical results
        'voro': lambda: VoronoiNN(tol=0.2),
        'econ': EconNN,
        'brunner': BrunnerNN_relative,

        # these two methods will consider motifs centered at anions
        'crys': CrystalNN,
        'jmol': JmolNN,

        # not sure
        'minokeeffe': MinimumOKeeffeNN,

        # maybe the best
        'mind': MinimumDistanceNN,
        'minv': MinimumVIRENN,
    }
    if nn_strategy not in factories:
        raise KeyError(f"unknown nn_strategy {nn_strategy!r}; expected one of {sorted(factories)}")
    nn = factories[nn_strategy]()

    # Loop invariants: the structure is not modified while neighbors are collected.
    lattice = struc.lattice.matrix
    frac = struc.frac_coords

    neigh_list = []
    for n in range(len(struc.sites)):
        # get_nn does not return the center itself, so keep exactly max_nn neighbors.
        for neighbor in nn.get_nn(struc, n)[:max_nn]:
            neighbor_index = neighbor.index
            offset = (frac[neighbor_index] - frac[n] + neighbor.image) @ lattice
            neigh_list.append([n, neighbor_index, offset, np.linalg.norm(offset)])

    return neigh_list

# Display the CrystalNN-based neighbor list of the example structure
gen_neighborlist(struc, nn_strategy='crys')



[[0, 11, array([-0.65836544, -0.65836544,  2.10323456]), 2.300105588783085],
 [0, 7, array([ 0.65836544,  0.65836544, -2.10323456]), 2.300105588783085],
 [0, 9, array([-2.10323456,  0.65836544, -0.65836544]), 2.300105588783085],
 [0, 10, array([0.65836544, 2.10323456, 0.65836544]), 2.300105588783085],
 [0, 6, array([-0.65836544, -2.10323456, -0.65836544]), 2.300105588783085],
 [0, 5, array([ 2.10323456, -0.65836544,  0.65836544]), 2.300105588783085],
 [1, 10, array([ 0.65836544, -0.65836544, -2.10323456]), 2.300105588783085],
 [1, 6, array([-0.65836544,  0.65836544,  2.10323456]), 2.300105588783085],
 [1, 7, array([ 0.65836544, -2.10323456,  0.65836544]), 2.300105588783085],
 [1, 8, array([2.10323456, 0.65836544, 0.65836544]), 2.300105588783085],
 [1, 4, array([-2.10323456, -0.65836544, -0.65836544]), 2.300105588783085],
 [1, 11, array([-0.65836544,  2.10323456, -0.65836544]), 2.300105588783085],
 [2, 4, array([ 0.65836544,  2.10323456, -0.65836544]), 2.300105588783085],
 [2, 7, array(

In [None]:
import numpy as np
import math

import os
import os.path as osp
import csv
import json
import itertools
import time

from pymatgen.io.cif import CifParser
from pymatgen.core.structure import Structure
from pymatgen.analysis.local_env import \
    LocalStructOrderParams, \
    VoronoiNN, \
    CrystalNN, \
    JmolNN, \
    MinimumDistanceNN, \
    MinimumOKeeffeNN, \
    EconNN, \
    BrunnerNN_relative, \
    MinimumVIRENN


#generate custom neighbor list to be used by all struc2's with nearest neighbor determination technique as parameter
def get_nbrlist(struc, nn_strategy = 'mind', max_nn=12):
    """Column-oriented neighbor list for every site of ``struc``.

    Parameters
    ----------
    struc : pymatgen Structure
        Structure whose sites are visited in index order.
    nn_strategy : str
        Key selecting the pymatgen near-neighbor algorithm (see ``factories``).
    max_nn : int
        Maximum number of neighbors kept per site, in ``get_nn`` order.

    Returns
    -------
    list
        ``[center_idxs, neighbor_idxs, offsets, distances]``: four parallel lists
        with one entry per kept (center, neighbor) pair.
        - Entries are grouped by center in site order.
        - ``offsets`` are Cartesian vectors from the center to the neighbor's
          periodic image: fractional difference plus image, times the lattice matrix.
        - ``distances`` are the norms of ``offsets``.

    Raises
    ------
    KeyError
        If ``nn_strategy`` is not a supported key.
    """
    # Factories so that only the requested strategy is instantiated.
    factories = {
        # these methods consider too many neighbors which may lead to unphysical results
        'voro': lambda: VoronoiNN(tol=0.2),
        'econ': EconNN,
        'brunner': BrunnerNN_relative,

        # these two methods will consider motifs centered at anions
        'crys': CrystalNN,
        'jmol': JmolNN,

        # not sure
        'minokeeffe': MinimumOKeeffeNN,

        # maybe the best
        'mind': MinimumDistanceNN,
        'minv': MinimumVIRENN,
    }
    if nn_strategy not in factories:
        raise KeyError(f"unknown nn_strategy {nn_strategy!r}; expected one of {sorted(factories)}")
    nn = factories[nn_strategy]()

    # Loop invariants: the structure is not modified while neighbors are collected.
    lattice = struc.lattice.matrix
    frac = struc.frac_coords

    center_idxs = []
    neighbor_idxs = []
    offsets = []
    distances = []

    for n in range(len(struc.sites)):
        # get_nn does not return the center itself, so keep exactly max_nn neighbors.
        for neighbor in nn.get_nn(struc, n)[:max_nn]:
            neighbor_index = neighbor.index
            offset = (frac[neighbor_index] - frac[n] + neighbor.image) @ lattice
            center_idxs.append(n)
            neighbor_idxs.append(neighbor_index)
            offsets.append(offset)
            distances.append(np.linalg.norm(offset))

    return [center_idxs, neighbor_idxs, offsets, distances]


#generate hypergraph dictionary elements for singleton sets
def struc2singletons(struc,  hgraph = None, tol=0.01, import_feat: bool = True, directory: str = "cif"):
    """Append one 'atom' hyperedge per site of ``struc`` to ``hgraph``.

    Each hyperedge is ``['atom', i, [i], feature]`` for site ``i``.
    - By default ``feature`` is the atomic number of the site's first species.
    - When ``import_feat`` is true, ``feature`` is that element's vector from
      ``<directory>/atom_init.json`` (CGCNN file keyed by the atomic number as a
      string), converted to a float torch tensor.

    Side effect: every site's ``properties`` gains ``'index': i``; other existing
    properties are kept. ``struc2motifs`` relies on this to map neighbor sites
    back to site indices.

    Parameters
    ----------
    struc : pymatgen Structure
    hgraph : list, optional
        Hypergraph list to extend in place; a fresh list is used when omitted.
    tol : float
        Unused; kept for backward compatibility.
    import_feat : bool
        Load CGCNN feature vectors instead of raw atomic numbers (requires torch).
    directory : str
        Directory containing ``atom_init.json``.

    Returns
    -------
    list
        ``hgraph`` with the atom hyperedges appended.
    """
    if hgraph is None:
        hgraph = []

    atomic_numbers = []
    for i, site in enumerate(struc.sites):
        #### IMPORTANT: ASSOCIATES SITE INDEX FOR LATER REFERENCE IN MOTIF GENERATION #####
        site.properties = {**(site.properties or {}), 'index': i}
        atomic_numbers.append(next(iter(site.species)).Z)  # first species of the site

    if import_feat:
        import torch  # only needed for CGCNN feature vectors
        with open(osp.join(directory, 'atom_init.json')) as atom_init:
            atom_vecs = json.load(atom_init)
        features = [torch.tensor(atom_vecs[f'{z}']).float() for z in atomic_numbers]
    else:
        features = atomic_numbers

    # Build the atom hedges together with their features so pre-existing
    # entries of hgraph are never overwritten.
    for i, feature in enumerate(features):
        hgraph.append(['atom', i, [i], feature])
    return hgraph




#define gaussian expansion for distance features

class gaussian_expansion(object):
    """Soft one-hot encoding of a scalar on evenly spaced Gaussian centers.

    ``steps`` centers are placed evenly from ``dmin`` to ``dmax``, both inclusive.
    """
    def __init__(self, dmin, dmax, steps):
        """
        Parameters
        ----------
        dmin, dmax : float
            Range covered by the centers; requires ``dmin < dmax``.
        steps : int
            Number of centers; must be >= 2 so that the spacing is defined.

        Raises
        ------
        AssertionError
            If ``dmin >= dmax``.
        ValueError
            If ``steps < 2``.
        """
        assert dmin<dmax
        if steps < 2:
            raise ValueError(f"steps must be >= 2 to span [dmin, dmax], got {steps}")
        self.dmin = dmin
        self.dmax = dmax
        # stored as the number of intervals between consecutive centers
        self.steps = steps-1

    def expand(self, distance, sig=None, tolerance = 0.01):
        """Return the Gaussian weights of ``distance`` at every center.

        Parameters
        ----------
        distance : float
            Value to encode.
        sig : float, optional
            Gaussian width; defaults to half the center spacing.
        tolerance : float
            Weights not exceeding this value are set to 0.

        Returns
        -------
        list
            One weight per center, ordered from ``dmin`` to ``dmax``.
        """
        step_size = (self.dmax - self.dmin) / self.steps
        if sig is None:
            sig = step_size / 2
        centers = [self.dmin + i * step_size for i in range(self.steps + 1)]
        weights = [math.exp(-(distance - c) ** 2 / (2 * sig ** 2)) for c in centers]
        return [w if w > tolerance else 0 for w in weights]
    
#Add bond nodes to hgraph list
def struc2pairs(struc, hgraph, nbr_lst = [], radius: float = 4, min_rad: bool = False, max_neighbor: float = 13, tol: float = 2, gauss_dim: int = 24):
    """Append one 'bond' hyperedge per (center, neighbor) pair to ``hgraph``.

    Each hyperedge is ``['bond', k, [center, neighbor], feature]``.
    - ``feature`` is the raw distance when ``gauss_dim == 1``.
    - Otherwise it is the distance's ``gauss_dim``-term Gaussian expansion
      over [0, radius].
    - Both directions of a pair are kept (undirected edges are double counted).

    Parameters
    ----------
    struc : pymatgen Structure
        Only used when ``nbr_lst`` is empty.
    hgraph : list
        Extended in place and returned.
    nbr_lst : sequence, optional
        ``[center_idxs, neighbor_idxs, offsets, distances]`` (e.g. from
        ``get_nbrlist``). Entries for one center must be contiguous. When empty,
        it is computed with ``struc.get_neighbor_list(r=radius, exclude_self=True)``.
    radius : float
        Neighbor cutoff and upper end of the Gaussian range.
    max_neighbor : int
        A pair is kept while the running count for its center is below
        ``max_neighbor``, i.e. at most ``max_neighbor - 1`` pairs per center.
    min_rad, tol : unused, kept for backward compatibility.
    gauss_dim : int
        Number of Gaussian centers; 1 disables the expansion.
    """
    # len() instead of `== []` so tuples/arrays are accepted as well
    if nbr_lst is None or len(nbr_lst) == 0:
        nbr_lst = struc.get_neighbor_list(r = radius, exclude_self=True)

    center_idxs = nbr_lst[0]
    neighbor_idxs = nbr_lst[1]
    distances = nbr_lst[3]

    if gauss_dim != 1:
        ge = gaussian_expansion(dmin = 0, dmax = radius, steps = gauss_dim)

    bond_index = 0
    last_center = None
    n_count = 0
    ## currently double counts pair-wise edges/makes undirected edges
    for center, neighbor, dist in zip(center_idxs, neighbor_idxs, distances):
        # running neighbor count of the current center, used for max_neighbor
        n_count = n_count + 1 if center == last_center else 1
        last_center = center
        if n_count < max_neighbor:
            feature = ge.expand(dist) if gauss_dim != 1 else dist
            hgraph.append(['bond', bond_index, [center, neighbor], feature])
            bond_index += 1

    return hgraph



#Add ALIGNN-like triplets with angle feature vector
def _cos_angle(edge1, edge2):
    """Cosine of the angle between two bond vectors, clipped to [-1, 1].

    Returns 1.0 when either vector has zero length. This is a stop-gap for
    coincident sites, whose angle is undefined.
    """
    norm_product = np.linalg.norm(edge1) * np.linalg.norm(edge2)
    if norm_product == 0:
        return 1.0
    return float(np.clip(np.dot(edge1, edge2) / norm_product, -1.0, 1.0))


def struc2triplets(struc, hgraph, nbr_lst = [], radius = 4, max_neighbor=12, gauss_dim = 40):
    """Append ALIGNN-like 'triplet' hyperedges with a bond-angle feature to ``hgraph``.

    For every center, each unordered pair of its first ``max_neighbor`` neighbors
    yields ``['triplet', k, [center, nbr_1, nbr_2], feature]``.
    - ``feature`` is the cosine of the angle between the two bond vectors.
    - When ``gauss_dim != 1`` it is instead that cosine's ``gauss_dim``-term
      Gaussian expansion over [-1, 1].

    Parameters
    ----------
    struc : pymatgen Structure
        Only used when ``nbr_lst`` is empty.
    hgraph : list
        Extended in place and returned.
    nbr_lst : sequence, optional
        ``[center_idxs, neighbor_idxs, offsets, ...]``.
        - ``offsets`` are Cartesian bond vectors from center to neighbor, the
          format produced by ``get_nbrlist``.
        - Entries for one center must be contiguous.
        - When empty, the list is computed with ``struc.get_neighbor_list`` and
          its image offsets are converted to Cartesian bond vectors.
    radius : float
        Neighbor cutoff for the default neighbor list.
    max_neighbor : int
        Maximum number of neighbors per center used to form triplets.
    gauss_dim : int
        Number of Gaussian centers; 1 disables the expansion.
    """
    if nbr_lst is None or len(nbr_lst) == 0:
        centers, neighbors, images, _ = struc.get_neighbor_list(r = radius, exclude_self=True)
        cart = struc.cart_coords
        lattice = struc.lattice.matrix
        # get_neighbor_list reports periodic image offsets, not bond vectors
        offsets = [cart[j] - cart[i] + np.asarray(image) @ lattice
                   for i, j, image in zip(centers, neighbors, images)]
        nbr_lst = [centers, neighbors, offsets]

    if gauss_dim != 1:
        # cosines lie in [-1, 1]; a [0, 1] range would zero out every obtuse angle
        ge = gaussian_expansion(dmin = -1, dmax = 1, steps = gauss_dim)

    triplet_index = 0
    pairs = zip(nbr_lst[0], nbr_lst[1], nbr_lst[2])
    for center_idx, group in itertools.groupby(pairs, key=lambda pair: pair[0]):
        local_neighs = [(nbr_idx, np.asarray(offset, dtype=float))
                        for _, nbr_idx, offset in itertools.islice(group, max_neighbor)]
        for (pair_1_idx, edge1), (pair_2_idx, edge2) in itertools.combinations(local_neighs, 2):
            cos_angle = _cos_angle(edge1, edge2)
            feature = ge.expand(cos_angle) if gauss_dim != 1 else cos_angle
            hgraph.append(['triplet', triplet_index, [center_idx, pair_1_idx, pair_2_idx], feature])
            triplet_index += 1

    return hgraph

#Types of structure-order parameters to calculate
# One order-parameter name per entry, passed to pymatgen's LocalStructOrderParams;
# the order of this list is the order of the motif feature vector.
types = [
    "cn", "sgl_bd", "bent", "tri_plan", "tri_plan_max",
    "reg_tri", "sq_plan", "sq_plan_max", "pent_plan", "pent_plan_max",
    "sq", "tet", "tet_max", "tri_pyr", "sq_pyr",
    "sq_pyr_legacy", "tri_bipyr", "sq_bipyr", "oct", "oct_legacy",
    "pent_pyr", "hex_pyr", "pent_bipyr", "hex_bipyr", "T",
    "cuboct", "cuboct_max", "see_saw_rect", "bcc", "q2",
    "q4", "q6", "oct_max", "hex_plan_max", "sq_face_cap_trig_pris",
]

#Add motif hedges to hgraph
def struc2motifs(struc, hgraph, types = types, lsop_tol = 0.05):
    """Append one 'motif' hyperedge per site, built from its Voronoi neighborhood.

    Each hyperedge is ``['motif', i, neighbor_indices + [i], order_params]``.
    - ``order_params`` are the LocalStructOrderParams values for ``types`` at
      site ``i``, computed over its VoronoiNN(tol=0.1) neighbors (center excluded).
    - ``None`` values become 0.
    - Values that are neither > 1 nor > ``lsop_tol`` become 0.

    IMPORTANT: every site's ``properties['index']`` must already be set (done by
    ``struc2singletons``); it maps neighbor sites back to site indices.
    """
    vnn = VoronoiNN(tol=0.1, targets=None)
    neighborhoods = [
        (center, [nbr.properties['index'] for nbr in vnn.get_nn(struc, center)])
        for center in range(len(struc.sites))
    ]

    lsop = LocalStructOrderParams(types)
    for motif_index, (site, neighs) in enumerate(neighborhoods):
        # order parameters over the Voronoi neighborhood, center excluded
        raw = lsop.get_order_parameters(struc, site, indices_neighs = neighs)
        feat = [
            f if f is not None and (f > 1 or f > lsop_tol) else 0
            for f in raw
        ]
        # the center joins the node set only after the parameters are computed
        hgraph.append(['motif', motif_index, neighs + [site], feat])

    return hgraph

#Include unit cell hedge for pooling
def struc2cell(struc, hgraph, random_x = True, rng = None):
    """Append a single 'cell' hyperedge containing every site, for pooling.

    The hyperedge is ``['cell', 0, [0, ..., n_sites - 1], feat]``.

    Parameters
    ----------
    struc : pymatgen Structure
        Only ``len(struc.sites)`` is used.
    hgraph : list
        Extended in place and returned.
    random_x : bool
        If true, ``feat`` is 64 uniform random values in [0, 1); otherwise None.
    rng : numpy.random.Generator, optional
        Source of the random feature. When omitted, numpy's global RNG
        (``np.random.rand``) is used. Pass a seeded generator for reproducible runs.
    """
    if random_x:
        feat = rng.random(64) if rng is not None else np.random.rand(64)
    else:
        feat = None
    hgraph.append(['cell', 0, list(range(len(struc.sites))), feat])
    return hgraph

## Now bring together process into overall hgraph generation
def hgraph_gen(struc, dir = 'cif', cell = False):
    """Assemble the hypergraph list for ``struc``.

    Order of work:
    1. One neighbor list from ``get_nbrlist``, shared by the bond and triplet stages.
    2. Atom hyperedges (``dir`` is passed as ``struc2singletons``' ``directory``).
    3. Bond and triplet hyperedges.
    4. A single 'cell' hyperedge, only when ``cell == True``.

    Motif hyperedges are currently not generated.
    """
    shared_nbrs = get_nbrlist(struc)
    graph = struc2singletons(struc, [], directory= dir)
    graph = struc2pairs(struc, graph, nbr_lst = shared_nbrs)
    graph = struc2triplets(struc, graph, nbr_lst = shared_nbrs)
    # motif stage deliberately skipped (struc2motifs is not called)
    if cell == True:
        graph = struc2cell(struc, graph)
    return graph

## Helper functions for relatives heterograph construction

def ordertype(hgraph, string):
    """Return, in order, the hyperedges of ``hgraph`` whose type tag (element 0) equals ``string``."""
    return [hedge for hedge in hgraph if hedge[0] == string]

def decompose(hgraph, order_types=['atom','bond','triplet','motif']):
    """Split ``hgraph`` into one list of hyperedges per type, in ``order_types`` order."""
    return [ordertype(hgraph, kind) for kind in order_types]
    
def contains(big, small):
    """True iff every item of ``small`` also occurs in ``big`` (True for an empty ``small``)."""
    return all(item in big for item in small)
    
def touches(one, two):
    """True iff at least one item of ``two`` also occurs in ``one``."""
    return any(item in one for item in two)
    
##Define function that generates relatives heterograph edge indices
def _hedges_of(hgraph, kind):
    """Return the hyperedges of ``hgraph`` whose type tag (element 0) is ``kind``."""
    return [hedge for hedge in hgraph if hedge[0] == kind]


def _containment_edges(smalls, bigs):
    """Edge index for every (small, big) pair where all nodes of small are nodes of big.

    Returns ``[[small ids], [big ids]]``.
    """
    big_nodes = [(big[1], set(big[2])) for big in bigs]
    edge_idx = [[], []]
    for small in smalls:
        nodes = set(small[2])
        for big_id, members in big_nodes:
            if nodes <= members:
                edge_idx[0].append(small[1])
                edge_idx[1].append(big_id)
    return edge_idx


def _touching_edges(hedges):
    """Edge index for every ordered pair of distinct hyperedges that share a node.

    Returns ``[[first ids], [second ids]]``.
    """
    node_sets = [set(hedge[2]) for hedge in hedges]
    edge_idx = [[], []]
    for i, nodes_i in enumerate(node_sets):
        for j, nodes_j in enumerate(node_sets):
            if i != j and not nodes_i.isdisjoint(nodes_j):
                edge_idx[0].append(hedges[i][1])
                edge_idx[1].append(hedges[j][1])
    return edge_idx


def hetero_rel_edges(hgraph, cell_vector = True):
    """Build heterogeneous relation edge indices between the hyperedges of ``hgraph``.

    ``hgraph`` is a list of ``[type, index, nodes, feature]`` hyperedges of type
    'atom', 'bond', 'triplet' or 'motif'.

    Returns
    -------
    dict
        Keyed by ``(src_type, relation, dst_type)``; each value is
        ``[[src indices], [dst indices]]``:
        - ('atom','bonds','atom'): the two node ids of every bond.
        - (X,'in',Y): every node of the X hyperedge is a node of the Y hyperedge.
        - (X,'touches',X): two distinct X hyperedges share at least one node.
        - ('cell','contains',X) for X in motif/atom/bond, only when
          ``cell_vector`` is true; the source is always cell 0.
    """
    atoms, bonds, triplets, motifs = (
        _hedges_of(hgraph, kind) for kind in ('atom', 'bond', 'triplet', 'motif')
    )
    edges = {}
    edges['atom', 'bonds', 'atom'] = [[bond[2][0] for bond in bonds], [bond[2][1] for bond in bonds]]
    edges['atom', 'in', 'bond'] = _containment_edges(atoms, bonds)
    edges['atom', 'in', 'triplet'] = _containment_edges(atoms, triplets)
    edges['atom', 'in', 'motif'] = _containment_edges(atoms, motifs)
    edges['bond', 'touches', 'bond'] = _touching_edges(bonds)
    edges['bond', 'in', 'triplet'] = _containment_edges(bonds, triplets)
    edges['bond', 'in', 'motif'] = _containment_edges(bonds, motifs)
    # Node lists (element 2) are compared, never whole hyperedges: whole hyperedges
    # share their type tag (so every pair would "touch") and never contain one
    # another.
    edges['triplet', 'touches', 'triplet'] = _touching_edges(triplets)
    edges['triplet', 'in', 'motif'] = _containment_edges(triplets, motifs)
    edges['motif', 'touches', 'motif'] = _touching_edges(motifs)

    if cell_vector:
        for kind in ('motif', 'atom', 'bond'):
            members = _hedges_of(hgraph, kind)
            edges['cell', 'contains', kind] = [[0] * len(members), [hedge[1] for hedge in members]]

    return edges


In [3]:
# Build the hypergraph of the example structure. Which hgraph_gen runs depends on
# which hgraph_gen cell was executed last; the saved NameError below means no
# definition had been executed when this cell ran.
hgraph = hgraph_gen(struc)


NameError: name 'hgraph_gen' is not defined

In [85]:
#generate hypergraph dictionary elements for singleton sets
def struc2singletons(struc,  hgraph = None, tol=0.01, import_feat: bool = False, directory: str = ""):
    """Append one 'atom' hyperedge per site of ``struc`` to ``hgraph``.

    Each hyperedge is ``['atom', i, [i], feature]`` for site ``i``.
    - By default ``feature`` is the atomic number of the site's first species.
    - When ``import_feat`` is true, ``feature`` is that element's vector from
      ``atom_init.json`` inside ``directory`` (CGCNN file keyed by the atomic
      number as a string), converted to a float torch tensor.

    Side effect: every site's ``properties`` gains ``'index': i``; other existing
    properties are kept. ``struc2motifs`` relies on this to map neighbor sites
    back to site indices.

    Parameters
    ----------
    struc : pymatgen Structure
    hgraph : list, optional
        Hypergraph list to extend in place; a fresh list is used when omitted.
    tol : float
        Unused; kept for backward compatibility.
    import_feat : bool
        Load CGCNN feature vectors instead of raw atomic numbers (requires torch).
    directory : str
        Directory containing ``atom_init.json``; "" means the working directory.

    Returns
    -------
    list
        ``hgraph`` with the atom hyperedges appended.
    """
    if hgraph is None:
        hgraph = []

    atomic_numbers = []
    for i, site in enumerate(struc.sites):
        #### IMPORTANT: ASSOCIATES SITE INDEX FOR LATER REFERENCE IN MOTIF GENERATION #####
        site.properties = {**(site.properties or {}), 'index': i}
        atomic_numbers.append(next(iter(site.species)).Z)  # first species of the site

    if import_feat:
        import torch  # only needed for CGCNN feature vectors
        # os.path.join so that a directory without a trailing separator works too
        with open(os.path.join(directory, 'atom_init.json')) as atom_init:
            atom_vecs = json.load(atom_init)
        features = [torch.tensor(atom_vecs[f'{z}']).float() for z in atomic_numbers]
    else:
        features = atomic_numbers

    # Build the atom hedges together with their features so pre-existing
    # entries of hgraph are never overwritten.
    for i, feature in enumerate(features):
        hgraph.append(['atom', i, [i], feature])
    return hgraph

In [86]:
# Start a fresh global hypergraph holding only the atom hyperedges of the example structure
hgraph = []
hgraph = struc2singletons(struc, hgraph)

In [124]:
import math

class gaussian_expansion(object):
    """Soft one-hot encoding of a scalar on ``steps`` evenly spaced Gaussian centers.

    Centers are ``dmin + i * (dmax - dmin) / steps`` for ``i = 0 .. steps - 1``,
    so ``dmax`` itself is not a center.
    """
    def __init__(self, dmin, dmax, steps):
        """
        Parameters
        ----------
        dmin, dmax : float
            Range covered by the centers; requires ``dmin < dmax``.
        steps : int
            Number of centers; must be >= 1.

        Raises
        ------
        AssertionError
            If ``dmin >= dmax``.
        ValueError
            If ``steps < 1``.
        """
        assert dmin<dmax
        if steps < 1:
            raise ValueError(f"steps must be >= 1, got {steps}")
        self.dmin = dmin
        self.dmax = dmax
        self.steps = steps

    def expand(self, distance, sig=None, tolerance = 0.01):
        """Return the Gaussian weights of ``distance`` at every center.

        Parameters
        ----------
        distance : float
            Value to encode.
        sig : float, optional
            Gaussian width; defaults to half the center spacing.
        tolerance : float
            Weights not exceeding this value are set to 0.

        Returns
        -------
        list
            One weight per center, in increasing center order.
        """
        step_size = (self.dmax - self.dmin) / self.steps
        if sig is None:
            sig = step_size / 2
        centers = [self.dmin + i * step_size for i in range(self.steps)]
        weights = [math.exp(-(distance - c) ** 2 / (2 * sig ** 2)) for c in centers]
        return [w if w > tolerance else 0 for w in weights]

In [187]:
#Add bond nodes to hgraph list
def struc2pairs(struc, hgraph, radius: float = 8, min_rad: bool = False, max_neighbor: float = 13, tol: float = 2, gauss_dim: int = 7):
    """Append one 'bond' hyperedge per (center, neighbor) pair within ``radius``.

    Pairs come from ``struc.get_neighbor_list(r=radius, exclude_self=True)``.
    - Each hyperedge is ``['bond', k, [center, neighbor], feature]``.
    - ``feature`` is the raw distance when ``gauss_dim == 1``; otherwise it is
      the distance's ``gauss_dim``-term Gaussian expansion over [0, radius].
    - Both directions of a pair are kept.

    Parameters
    ----------
    struc : pymatgen Structure
    hgraph : list
        Extended in place and returned.
    radius : float
        Neighbor cutoff and upper end of the Gaussian range.
    max_neighbor : int
        A pair is kept while the running count for its center is below
        ``max_neighbor``, i.e. at most ``max_neighbor - 1`` pairs per center
        (entries of one center are assumed contiguous).
    min_rad, tol : unused, kept for backward compatibility.
    gauss_dim : int
        Number of Gaussian centers; 1 disables the expansion.
    """
    nbr_lst = struc.get_neighbor_list(r = radius, exclude_self=True)
    center_idxs = nbr_lst[0]
    neighbor_idxs = nbr_lst[1]
    distances = nbr_lst[3]

    if gauss_dim != 1:
        ge = gaussian_expansion(dmin = 0, dmax = radius, steps = gauss_dim)

    bond_index = 0
    last_center = None
    n_count = 0
    ## currently double counts pair-wise edges/makes undirected edges
    for center, neighbor, dist in zip(center_idxs, neighbor_idxs, distances):
        # running neighbor count of the current center, used for max_neighbor
        n_count = n_count + 1 if center == last_center else 1
        last_center = center
        if n_count < max_neighbor:
            feature = ge.expand(dist) if gauss_dim != 1 else dist
            hgraph.append(['bond', bond_index, [center, neighbor], feature])
            bond_index += 1

    return hgraph

In [48]:
# Show the current hypergraph, then append raw-distance bond hyperedges (gauss_dim=1).
# NOTE(review): the saved output below ("here:" lines, every bond with index 0)
# does not match any struc2pairs definition in this notebook; rerun to refresh it.
print(hgraph)
struc2pairs(struc, hgraph, gauss_dim = 1)

[['atom', 0, [0], 25], ['atom', 1, [1], 25], ['atom', 2, [2], 25], ['atom', 3, [3], 25], ['atom', 4, [4], 16], ['atom', 5, [5], 16], ['atom', 6, [6], 16], ['atom', 7, [7], 16], ['atom', 8, [8], 16], ['atom', 9, [9], 16], ['atom', 10, [10], 16], ['atom', 11, [11], 16]]
here:7.683058909258419
here:7.194221723488871
here:5.982994564559721
here:7.68305890925842
here:7.095164547238011
here:7.810984347699079
here:5.2740802834207186
here:7.095164547238012
here:7.683058909258419
here:6.764510873670025
here:7.095164547238011
here:7.683058909258419
here:7.194221723488871
here:5.982994564559722
here:5.92355390220965
here:6.764510873670024
here:7.095164547238012
here:5.340754250397146
here:5.982994564559721
here:5.340754250397146
here:5.2740802834207186
here:7.68305890925842
here:7.68305890925842
here:7.194221723488871
here:5.274080283420719
here:7.095164547238012
here:7.683058909258419
here:6.764510873670024
here:7.095164547238012
here:4.532477781473526
here:7.194221723488871
here:7.6830589092584

[['atom', 0, [0], 25],
 ['atom', 1, [1], 25],
 ['atom', 2, [2], 25],
 ['atom', 3, [3], 25],
 ['atom', 4, [4], 16],
 ['atom', 5, [5], 16],
 ['atom', 6, [6], 16],
 ['atom', 7, [7], 16],
 ['atom', 8, [8], 16],
 ['atom', 9, [9], 16],
 ['atom', 10, [10], 16],
 ['atom', 11, [11], 16],
 ['bond', 0, [0, 10], 7.683058909258419],
 ['bond', 0, [0, 7], 7.194221723488871],
 ['bond', 0, [0, 10], 5.982994564559721],
 ['bond', 0, [0, 7], 7.68305890925842],
 ['bond', 0, [0, 5], 7.095164547238011],
 ['bond', 0, [0, 0], 7.810984347699079],
 ['bond', 0, [0, 8], 5.2740802834207186],
 ['bond', 0, [0, 6], 7.095164547238012],
 ['bond', 0, [0, 5], 7.683058909258419],
 ['bond', 0, [0, 1], 6.764510873670025],
 ['bond', 0, [0, 11], 7.095164547238011],
 ['bond', 0, [1, 8], 7.683058909258419],
 ['bond', 0, [1, 7], 7.194221723488871],
 ['bond', 0, [1, 10], 5.982994564559722],
 ['bond', 0, [1, 5], 5.92355390220965],
 ['bond', 0, [1, 0], 6.764510873670024],
 ['bond', 0, [1, 11], 7.095164547238012],
 ['bond', 0, [1, 10

In [65]:
# Same order-parameter names as the earlier `types` cell (redefined here);
# the order of this list is the order of the motif feature vector.
types = [
    "cn", "sgl_bd", "bent", "tri_plan", "tri_plan_max",
    "reg_tri", "sq_plan", "sq_plan_max", "pent_plan", "pent_plan_max",
    "sq", "tet", "tet_max", "tri_pyr", "sq_pyr",
    "sq_pyr_legacy", "tri_bipyr", "sq_bipyr", "oct", "oct_legacy",
    "pent_pyr", "hex_pyr", "pent_bipyr", "hex_bipyr", "T",
    "cuboct", "cuboct_max", "see_saw_rect", "bcc", "q2",
    "q4", "q6", "oct_max", "hex_plan_max", "sq_face_cap_trig_pris",
]

In [138]:
def struc2motifs(struc, hgraph, types = types, lsop_tol = 0.05):
    """Append one 'motif' hyperedge per site, built from its Voronoi neighborhood.

    Each hyperedge is ``['motif', i, neighbor_indices + [i], order_params]``.
    - ``order_params`` are the LocalStructOrderParams values for ``types`` at
      site ``i``, computed over its VoronoiNN(tol=0.1) neighbors (center excluded).
    - ``None`` values become 0.
    - Values that are neither > 1 nor > ``lsop_tol`` become 0.

    IMPORTANT: every site's ``properties['index']`` must already be set (done by
    ``struc2singletons``); it maps neighbor sites back to site indices.
    """
    vnn = VoronoiNN(tol=0.1, targets=None)
    neighborhoods = [
        (center, [nbr.properties['index'] for nbr in vnn.get_nn(struc, center)])
        for center in range(len(struc.sites))
    ]

    lsop = LocalStructOrderParams(types)
    for motif_index, (site, neighs) in enumerate(neighborhoods):
        # order parameters over the Voronoi neighborhood, center excluded
        raw = lsop.get_order_parameters(struc, site, indices_neighs = neighs)
        # TODO(author): consider inverting the values (abs(f - 1), mapping 0 to 1)
        # so that 1 corresponds to the shape instead of 0.
        feat = [
            f if f is not None and (f > 1 or f > lsop_tol) else 0
            for f in raw
        ]
        # the center joins the node set only after the parameters are computed
        hgraph.append(['motif', motif_index, neighs + [site], feat])

    return hgraph

In [116]:
# Compute motif hyperedges alone on a fresh global list; requires the sites'
# 'index' properties, which struc2singletons sets.
hgraph = []
struc2motifs(struc, hgraph, types = types)

[['motif',
  0,
  [11, 9, 6, 5, 7, 10, 0],
  [6.0,
   0,
   0,
   0,
   0,
   0,
   0.14243183801701953,
   0.39344810969074473,
   0,
   0.3891818443848119,
   0,
   0,
   0.1562813946707832,
   0.5113057454736281,
   0.5814634668618692,
   0.11932998001665247,
   0.4916194248518829,
   0.5308165961703892,
   0.2000681469315574,
   0,
   0.5354550624747224,
   0.4382306652143231,
   0.5223879964881923,
   0.4244386440387184,
   0.43036784535257916,
   0.2933466040660625,
   0.46704357385544915,
   0.40682296481718483,
   0,
   0.40588857203846634,
   0.5165345556934031,
   0.2803649979112953,
   0.5308165961703892,
   0.21853026521726338,
   0.36044845458879066]],
 ['motif',
  1,
  [4, 11, 6, 8, 10, 7, 1],
  [6.0,
   0,
   0,
   0,
   0,
   0,
   0.14243183801701953,
   0.39344810969074473,
   0,
   0.38918184438481196,
   0,
   0,
   0.1562813946707834,
   0.511305745473628,
   0.5814634668618692,
   0.11932998001665218,
   0.49161942485188287,
   0.5308165961703892,
   0.20006814693

In [133]:
## playing with motif generation techniques, NN calculations
# Prints LocalStructOrderParams for site `n` in three cases:
#   1. over its VoronoiNN neighbors,
#   2. over those neighbors plus the center itself,
#   3. with no explicit neighbor list.
# Requires the sites' 'index' properties (set by struc2singletons).
lsop = LocalStructOrderParams(types)
n = 1

vnn = VoronoiNN(tol=0.1, targets=None)
neigh = [neighbor.properties['index'] for neighbor in vnn.get_nn(struc, n)]
print(f'neighs: {neigh} for site {n}')
# every call uses `n` so changing it above explores a different site
print(f'w/neighs: {lsop.get_order_parameters(struc, n, neigh)}')
neigh.append(n)
print(f'w/center: {lsop.get_order_parameters(struc, n, neigh)}')
print(f'default: {lsop.get_order_parameters(struc, n)}')

neighs: [4, 11, 6, 8, 10, 7] for site 1
w/neighs: [6.0, 0.0, 7.121562665051939e-05, 0.0008384951159906882, 0.028653472314188584, 1.0658041506817687e-205, 0.14243183801701953, 0.39344810969074473, 0.03074607619927582, 0.38918184438481196, 6.315412975334334e-31, 0.007254860759403704, 0.1562813946707834, 0.511305745473628, 0.5814634668618692, 0.11932998001665218, 0.49161942485188287, 0.5308165961703892, 0.20006814693155725, 0.02369499459898541, 0.5354550624747225, 0.43823066521432336, 0.5223879964881925, 0.42443864403871856, 0.430367845352579, 0.2933466040660625, 0.46704357385544876, 0.4068229648171844, -0.08899561704693024, 0.40588857203846623, 0.516534555693403, 0.28036499791129554, 0.5308165961703892, 0.2185302652172633, 0.36044845458879077]
w/center: [7.0, 1.0, 0.2857651540190361, 0.00047970206688081066, 0.02468715011624237, 5.9866e-320, 0.2203925800154723, 0.39344810969074473, 0.017694246798495492, 0.4259131200679392, 1.3202480190584596e-38, 0.00414563471965926, 0.13023449555898617, 

In [208]:
def struc2cell(struc, hgraph, random_x = True, rng = None):
    """Append a single 'cell' hyperedge containing every site, for pooling.

    The hyperedge is ``['cell', 0, [0, ..., n_sites - 1], feat]``.

    Parameters
    ----------
    struc : pymatgen Structure
        Only ``len(struc.sites)`` is used.
    hgraph : list
        Extended in place and returned.
    random_x : bool
        If true, ``feat`` is 64 uniform random values in [0, 1); otherwise None.
    rng : numpy.random.Generator, optional
        Source of the random feature. When omitted, numpy's global RNG
        (``np.random.rand``) is used. Pass a seeded generator for reproducible runs.
    """
    if random_x:
        feat = rng.random(64) if rng is not None else np.random.rand(64)
    else:
        feat = None
    hgraph.append(['cell', 0, list(range(len(struc.sites))), feat])
    return hgraph

In [204]:
## Now bring together process into overall hgraph generation
def hgraph_gen(struc):
    """Assemble atom, bond, motif and cell hyperedges for ``struc``, in that order.

    Each stage extends the list returned by the previous one. Once this cell
    runs, it shadows any earlier ``hgraph_gen`` definition.
    """
    graph = struc2singletons(struc, [])   # atoms first: also tags sites with 'index'
    graph = struc2pairs(struc, graph)
    graph = struc2motifs(struc, graph)    # needs the 'index' tags set above
    graph = struc2cell(struc, graph)
    return graph

In [207]:
# Display the full hypergraph of the example structure (uses whichever
# hgraph_gen definition was executed last)
hgraph_gen(struc)

[['atom', 0, [0], 25],
 ['atom', 1, [1], 25],
 ['atom', 2, [2], 25],
 ['atom', 3, [3], 25],
 ['atom', 4, [4], 16],
 ['atom', 5, [5], 16],
 ['atom', 6, [6], 16],
 ['atom', 7, [7], 16],
 ['atom', 8, [8], 16],
 ['atom', 9, [9], 16],
 ['atom', 10, [10], 16],
 ['atom', 11, [11], 16],
 ['bond', 0, [0, 10], [0, 0, 0, 0, 0, 0, 0.3518596941096958]],
 ['bond', 1, [0, 7], [0, 0, 0, 0, 0, 0.03495268574223037, 0.8403103964874236]],
 ['bond',
  2,
  [0, 10],
  [0, 0, 0, 0, 0.04730950585495063, 0.8953298704588116, 0.3103414666759601]],
 ['bond', 3, [0, 7], [0, 0, 0, 0, 0, 0, 0.35185969410969503]],
 ['bond',
  4,
  [0, 5],
  [0, 0, 0, 0, 0, 0.053942972761110757, 0.9169045192625017]],
 ['bond', 5, [0, 0], [0, 0, 0, 0, 0, 0, 0.24829144078684012]],
 ['bond',
  6,
  [0, 8],
  [0, 0, 0, 0, 0.4695371415713126, 0.7432469782428341, 0.02154856851752471]],
 ['bond', 7, [0, 6], [0, 0, 0, 0, 0, 0.05394297276111057, 0.9169045192625012]],
 ['bond', 8, [0, 5], [0, 0, 0, 0, 0, 0, 0.3518596941096958]],
 ['bond', 9, [0

In [151]:
def ordertype(hgraph, string):
    """Return, in order, the hyperedges of ``hgraph`` whose type tag (element 0) equals ``string``."""
    return [hedge for hedge in hgraph if hedge[0] == string]

def decompose(hgraph, order_types=['atom','bond','motif']):
    """Split ``hgraph`` into one list of hyperedges per type, in ``order_types`` order."""
    return [ordertype(hgraph, kind) for kind in order_types]
    
def contains(big, small):
    """True iff every item of ``small`` also occurs in ``big`` (True for an empty ``small``)."""
    return all(item in big for item in small)
    
def touches(one, two):
    """True iff at least one item of ``two`` also occurs in ``one``."""
    return any(item in one for item in two)

In [176]:
# Inspect the global hgraph; its contents depend on which cell last assigned it
print(hgraph)

[['motif', 0, [11, 9, 6, 5, 7, 10, 0], [6.0, 0, 0, 0, 0, 0, 0.14243183801701953, 0.39344810969074473, 0, 0.3891818443848119, 0, 0, 0.1562813946707832, 0.5113057454736281, 0.5814634668618692, 0.11932998001665247, 0.4916194248518829, 0.5308165961703892, 0.2000681469315574, 0, 0.5354550624747224, 0.4382306652143231, 0.5223879964881923, 0.4244386440387184, 0.43036784535257916, 0.2933466040660625, 0.46704357385544915, 0.40682296481718483, 0, 0.40588857203846634, 0.5165345556934031, 0.2803649979112953, 0.5308165961703892, 0.21853026521726338, 0.36044845458879066]], ['motif', 1, [4, 11, 6, 8, 10, 7, 1], [6.0, 0, 0, 0, 0, 0, 0.14243183801701953, 0.39344810969074473, 0, 0.38918184438481196, 0, 0, 0.1562813946707834, 0.511305745473628, 0.5814634668618692, 0.11932998001665218, 0.49161942485188287, 0.5308165961703892, 0.20006814693155725, 0, 0.5354550624747225, 0.43823066521432336, 0.5223879964881925, 0.42443864403871856, 0.430367845352579, 0.2933466040660625, 0.46704357385544876, 0.40682296481718

In [201]:
def _hedges_of(hgraph, kind):
    """Return the hyperedges of ``hgraph`` whose type tag (element 0) is ``kind``."""
    return [hedge for hedge in hgraph if hedge[0] == kind]


def _containment_edges(smalls, bigs):
    """Edge index for every (small, big) pair where all nodes of small are nodes of big.

    Returns ``[[small ids], [big ids]]``.
    """
    big_nodes = [(big[1], set(big[2])) for big in bigs]
    edge_idx = [[], []]
    for small in smalls:
        nodes = set(small[2])
        for big_id, members in big_nodes:
            if nodes <= members:
                edge_idx[0].append(small[1])
                edge_idx[1].append(big_id)
    return edge_idx


def _touching_edges(hedges):
    """Edge index for every ordered pair of distinct hyperedges that share a node.

    Returns ``[[first ids], [second ids]]``.
    """
    node_sets = [set(hedge[2]) for hedge in hedges]
    edge_idx = [[], []]
    for i, nodes_i in enumerate(node_sets):
        for j, nodes_j in enumerate(node_sets):
            if i != j and not nodes_i.isdisjoint(nodes_j):
                edge_idx[0].append(hedges[i][1])
                edge_idx[1].append(hedges[j][1])
    return edge_idx


def hetero_rel_edges(hedge, cell_vector = False):
    """Build relation edge indices between the atom, bond and motif hyperedges.

    Parameters
    ----------
    hedge : list
        The full hypergraph: ``[type, index, nodes, feature]`` hyperedges. This
        argument is now actually used; previously the global ``hgraph`` was read.
    cell_vector : bool
        Also emit ('cell','contains',X) edges for X in motif/atom/bond, with
        cell 0 as the source.

    Returns
    -------
    dict
        Keyed by ``(src_type, relation, dst_type)``; each value is
        ``[[src indices], [dst indices]]``:
        - ('atom','bonds','atom'): the two node ids of every bond.
        - (X,'in',Y): every node of the X hyperedge is a node of the Y hyperedge.
        - (X,'touches',X): two distinct X hyperedges share at least one node.
    """
    hgraph = hedge
    atoms, bonds, motifs = (_hedges_of(hgraph, kind) for kind in ('atom', 'bond', 'motif'))

    edges = {}
    edges['atom', 'bonds', 'atom'] = [[bond[2][0] for bond in bonds], [bond[2][1] for bond in bonds]]
    edges['atom', 'in', 'bond'] = _containment_edges(atoms, bonds)
    edges['atom', 'in', 'motif'] = _containment_edges(atoms, motifs)
    edges['bond', 'touches', 'bond'] = _touching_edges(bonds)
    edges['bond', 'in', 'motif'] = _containment_edges(bonds, motifs)
    # compare node lists, not whole hyperedges (whole motifs always share the 'motif' tag)
    edges['motif', 'touches', 'motif'] = _touching_edges(motifs)

    if cell_vector:
        for kind in ('motif', 'atom', 'bond'):
            members = _hedges_of(hgraph, kind)
            edges['cell', 'contains', kind] = [[0] * len(members), [h[1] for h in members]]

    return edges


In [202]:
# Rebuild the global hypergraph with whichever hgraph_gen was executed last and
# print its relation edge indices, including the cell edges.
hgraph = []
hgraph = hgraph_gen(struc)
print(hetero_rel_edges(hgraph, cell_vector = True))

{('atom', 'bonds', 'atom'): [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11], [10, 7, 10, 7, 5, 0, 8, 6, 5, 1, 11, 10, 8, 7, 10, 5, 0, 11, 10, 8, 7, 5, 4, 1, 9, 7, 10, 5, 4, 0, 11, 10, 9, 8, 7, 5, 11, 10, 9, 8, 7, 6, 5, 4, 2, 1, 0, 10, 10, 10, 7, 4, 11, 8, 6, 5, 1, 11, 10, 9, 10, 5, 8, 8, 7, 5, 8, 5, 5, 10, 8, 7, 11, 10, 5, 4, 0, 11, 10, 9, 8, 7, 6, 5, 10, 10, 7, 8, 8, 7, 5, 8, 6, 5, 1, 10, 10, 5, 10, 8, 7, 5, 8, 8, 5, 8, 7, 5, 10, 5, 4, 0, 10, 9, 7, 4, 7, 8, 6, 9, 10, 8, 6, 5, 1, 10, 8, 7, 5, 5, 11, 8, 7, 10, 5, 0, 11, 10, 8, 7, 5, 4, 0, 8]], ('atom', 'in', 'bond'): [[0, 0, 0, 0, 0, 