In [None]:
from torch_geometric.data import HeteroData

data = HeteroData()

data['atom'].x = None  ## [atom feature 1, atom feature 2, ... ] w/ atoms from atom_init according to Z
data['bond'].x = None  ## [edge feature 1, ... ] w/ Gaussian expansion of distance for features
data['motif'].x = None ## [motif feature 1, ...] w/ Local structure order parameters for features

##Cell for alternative crystal vector
data['cell'].x = None  ## [crystal feature] w/ one-hot point groups or random init

##Homogenous edges##
data['atom', 'connects', 'atom'].edge_index = None ##
data['bond', 'connects', 'bond'].edge_index = None ##
data['motif', 'connects', 'motif'].edge_index = None ##

##Hetergeneous edges##
data['bond', 'contains', 'atom'].edge_index = None
data['motif', 'contains', 'bond'].edge_index = None
data['motif', 'contains', 'atom'].edge_index = None

##Master node edges
data['cell', 'contains', 'atom'].edge_index = None
data['cell', 'contains', 'bond'].edge_index = None
data['cell', 'contains', 'motif'].edge_index = None

In [4]:
import numpy as np
import os
import csv
import json

from pymatgen.io.cif import CifWriter, CifParser
from pymatgen.core.structure import Structure
#from torch_geometric.data import Data
#import torch

##Import example cif as pymatgen structure object
struc = CifParser('test_cif.cif').get_structures()[0]

In [10]:
#generate hypergraph dictionary elements for singleton sets
def struc2singletons(struc,  hgraph = [], tol=0.01, import_feat: bool = False, directory: str = ""):
    singletons = struc.get_neighbor_list(r = tol, exclude_self=False)[0]
    atom_count = 0
    for node in singletons:
        hgraph.append(['atom', atom_count, [atom_count], None])
        atom_count += 1
    #extract features (Z and possibly coordinates) for atoms
    site_lst = struc.sites
    features = [[],[]]
    for site in site_lst:
        features[1].append(site.coords) #Coordinate of sites
        z_site = [element.Z for element in site.species]
        features[0].append(z_site[0]) #Atomic num of sites
    #import features from CGCNN atom_init file
    if import_feat == True:
        with open(f'{directory}atom_init.json') as atom_init:
            atom_vecs = json.load(atom_init)
            features[0] = [torch.tensor(atom_vecs[f'{z}']).float() for z in features[0]]
    for hedge, feature in zip(hgraph, features[0]):
        hedge[3] = feature
    return hgraph

In [43]:
hgraph = []
hgraph = struc2singletons(struc, hgraph)

In [13]:
import math

class gaussian_expansion(object):
    def __init__(self, dmin, dmax, steps):
        assert dmin<dmax
        self.dmin = dmin
        self.dmax = dmax
        self.steps = steps
        
    def expand(self, distance, sig=None):
        drange = self.dmax-self.dmin
        step_size = drange/self.steps
        if sig == None:
            sig = step_size/2
        ds = [self.dmin + i*step_size for i in range(self.steps)]
        expansion = [math.exp(-(distance-center)**2/(2*sig**2)) for center in ds]
        return expansion

In [37]:
#Add bond nodes to hgraph list
def struc2pairs(struc, hgraph = [], radius: float = 8, min_rad: bool = False, max_neighbor: float = 12, tol: float = 2, gauss_dim: int = 40):
    nbr_lst = struc.get_neighbor_list(r = radius, exclude_self=True)

    pair_center_idx = nbr_lst[0]
    pair_neighbor_idx = nbr_lst[1]
    distances = nbr_lst[3]
    
    if gauss_dim != 1:
        ge = gaussian_expansion(dmin = 0, dmax = radius, steps = gauss_dim)
            
    features = []
    n_count=0
    bond_index = 0
    pair_1_last = pair_center_idx[0]
    ## currently double counts pair-wise edges/makes undirected edges
    for pair_1,pair_2,dist in zip(pair_center_idx, pair_neighbor_idx,distances):
        #Accounts for max_neighbor
        if pair_1 == pair_1_last:
            n_count +=1
        else:
            pair_1_last = pair_1
            n_count = 1
        if n_count < max_neighbor:
            if gauss_dim != 1:
                dist = ge.expand(dist)
            print(f'here:{dist}')
            hgraph.append(['bond', bond_index, [pair_1, pair_2], dist])
        
    return hgraph

In [44]:
print(hgraph)
struc2pairs(struc, hgraph, gauss_dim = 1)

[['atom', 0, [0], 25], ['atom', 1, [1], 25], ['atom', 2, [2], 25], ['atom', 3, [3], 25], ['atom', 4, [4], 16], ['atom', 5, [5], 16], ['atom', 6, [6], 16], ['atom', 7, [7], 16], ['atom', 8, [8], 16], ['atom', 9, [9], 16], ['atom', 10, [10], 16], ['atom', 11, [11], 16]]
here:7.683058909258419
here:7.194221723488871
here:5.982994564559721
here:7.68305890925842
here:7.095164547238011
here:7.810984347699079
here:5.2740802834207186
here:7.095164547238012
here:7.683058909258419
here:6.764510873670025
here:7.095164547238011
here:7.683058909258419
here:7.194221723488871
here:5.982994564559722
here:5.92355390220965
here:6.764510873670024
here:7.095164547238012
here:5.340754250397146
here:5.982994564559721
here:5.340754250397146
here:5.2740802834207186
here:7.68305890925842
here:7.68305890925842
here:7.194221723488871
here:5.274080283420719
here:7.095164547238012
here:7.683058909258419
here:6.764510873670024
here:7.095164547238012
here:4.532477781473526
here:7.194221723488871
here:7.6830589092584

[['atom', 0, [0], 25],
 ['atom', 1, [1], 25],
 ['atom', 2, [2], 25],
 ['atom', 3, [3], 25],
 ['atom', 4, [4], 16],
 ['atom', 5, [5], 16],
 ['atom', 6, [6], 16],
 ['atom', 7, [7], 16],
 ['atom', 8, [8], 16],
 ['atom', 9, [9], 16],
 ['atom', 10, [10], 16],
 ['atom', 11, [11], 16],
 ['bond', 0, [0, 10], 7.683058909258419],
 ['bond', 0, [0, 7], 7.194221723488871],
 ['bond', 0, [0, 10], 5.982994564559721],
 ['bond', 0, [0, 7], 7.68305890925842],
 ['bond', 0, [0, 5], 7.095164547238011],
 ['bond', 0, [0, 0], 7.810984347699079],
 ['bond', 0, [0, 8], 5.2740802834207186],
 ['bond', 0, [0, 6], 7.095164547238012],
 ['bond', 0, [0, 5], 7.683058909258419],
 ['bond', 0, [0, 1], 6.764510873670025],
 ['bond', 0, [0, 11], 7.095164547238011],
 ['bond', 0, [1, 8], 7.683058909258419],
 ['bond', 0, [1, 7], 7.194221723488871],
 ['bond', 0, [1, 10], 5.982994564559722],
 ['bond', 0, [1, 5], 5.92355390220965],
 ['bond', 0, [1, 0], 6.764510873670024],
 ['bond', 0, [1, 11], 7.095164547238012],
 ['bond', 0, [1, 10

In [45]:
types = [ "cn",
        "sgl_bd",
        "bent",
        "tri_plan",
        "tri_plan_max",
        "reg_tri",
        "sq_plan",
        "sq_plan_max",
        "pent_plan",
        "pent_plan_max",
        "sq",
        "tet",
        "tet_max",
        "tri_pyr",
        "sq_pyr",
        "sq_pyr_legacy",
        "tri_bipyr",
        "sq_bipyr",
        "oct",
        "oct_legacy",
        "pent_pyr",
        "hex_pyr",
        "pent_bipyr",
        "hex_bipyr",
        "T",
        "cuboct",
        "cuboct_max",
        "see_saw_rect",
        "bcc",
        "q2",
        "q4",
        "q6",
        "oct_max",
        "hex_plan_max",
        "sq_face_cap_trig_pris"]

In [None]:
def struc2motifs(struc, hgraph = [], radius: float = 6, min_rad: bool = False, tol: float = 1.5):
    nbr_lst = struc.get_neighbor_list(r = radius, exclude_self=True)

    pair_center_idx = nbr_lst[0]
    pair_neighbor_idx = nbr_lst[1]
    motif_index = 0

    last_center = pair_center_idx[0]
    neighborhoods = ([],[],[])
    neighborhood  = []

    for pair_1,pair_2 in zip(pair_center_idx, pair_neighbor_idx):
        new_center = pair_1
        if last_center != new_center:
            neighborhood.append(last_pair_1)
            hgraph.append(['motif', motif_index, neighborhood, None])
            last_center = pair_1
            neighborhood = []
            motif_index += 1
        neighborhood.append(pair_2)
        last_pair_1 = pair_1
    if last_center == 
    neighborhood.append(last_pair_1)
    hgraph.append(['motif', motif_index, neighborhood, None])
    
    lsop = LocalStructOrderParams(types)
    
    features = [[],[]]
    for hedge_idx, center_idx, neighbor_lst in zip(neighborhoods[0],neighborhoods[1],neighborhoods[2]):
        feature = np.array(lsop.get_order_parameters(struc, center_idx, indices_neighs = neighbor_lst), dtype = float)
        features[0].append(hedge_idx)
        features[1].append(torch.tensor(np.nan_to_num(feature)).float())
            
    return hedge_list, features