In [1]:
from pymatgen.io.cif import CifWriter, CifParser
from pymatgen.core.structure import Structure
from torch_geometric.data import Data
import torch
import numpy as np

In [19]:
## CIF -> tuple(tensor([[node_index,connected_node_index],...]), tensor([dist,...])) ##

#takes cif file and returns array (2 x num_edges) of edge index
#found by collecting neighbors within radius, also adjoins distance
#associated with each edge in tuple
def cif2graphedges(cif_file, radius:float=3):
    struc = CifParser(cif_file).get_structures()[0]
    nbr_lst = struc.get_neighbor_list(radius, exclude_self=True)
    edge_list=np.stack((nbr_lst[0],nbr_lst[1])).transpose()
    edge_list=torch.tensor(edge_list)
    edge_list_w_dist = (edge_list,torch.tensor(nbr_lst[3]))
    return edge_list_w_dist



## CIF -> tuple(tensor([[node_pos], ... ]),tensor([node_atomic_num,...])) ##

#takes cif file and returns tuple of a tensor of node positions and a tensor
# of node's atomic number, indexed same as cif2graphedges
def cif2nodepos(cif_file):
    struc = CifParser(cif_file).get_structures()[0]
    site_lst = struc.sites
    nodepos_lst = []
    nodespec_lst = []
    for site in site_lst:
        nodepos_lst.append(site.coords) #Coordinate of sites
        z_site = [element.Z for element in site.species]
        nodespec_lst.append(z_site) #Atomic number list of site species (should always be single element list for crystal)
    nodepos_arr = np.array(nodepos_lst, dtype=float)
    nodespec_arr = np.squeeze(nodespec_lst)
    return  (torch.tensor(nodepos_arr),torch.tensor(nodespec_arr))


In [35]:

## CIF -> tensor([[node_index,...],[hyper_edge_index,...]]) ##

# takes cif file and returns array (2 x num_nodes_in_hedges) of hedge index
# (as specified in the HypergraphConv doc of PyTorch Geometric)
# found by collecting neighbors within spec radius for each node in one hedge
def cif2hyperedges(cif_file, radius: float = 3):
    struc = CifParser(cif_file).get_structures()[0]
    nbr_lst = struc.get_neighbor_list(radius, exclude_self=False)
    edge_list = np.stack((nbr_lst[0], nbr_lst[1])).transpose()
    edge_list = torch.tensor(edge_list)

    tk = edge_list[0][0]
    hedge_index = []
    node_index = []
    for i, j in edge_list:
        if i != tk:
            hedge_index.append(tk)
            node_index.append(tk)
            tk = i
        node_index.append(j)
        hedge_index.append(i)
    node_index.append(edge_list[-1][0])
    hedge_index.append(edge_list[-1][0])
    hedge_list = torch.stack((torch.tensor(node_index), torch.tensor(hedge_index)))
    return hedge_list

def cif2hgraph(cif, radius:float = 3):
    pos = cif2nodepos(cif)[0]
    x = cif2nodepos(cif)[1]
    hedge_indx = cif2hyperedges(cif, radius)
    chgraph = Data(x=x, hyperedge_index=hedge_indx, pos=pos)
    return chgraph


def hgraph_list_from_dir(directory='cif', root='', atom_vecs = True, radius:float=3.0):
    if root == '':
        root = os. getcwd()
    directory = root+'\\'+directory
    print(f'Searching {directory} for CIF data to convert to hgraphs')
    with open(f'{directory}\\id_prop.csv') as id_prop:
        id_prop = csv.reader(id_prop)
        id_prop_data = [row for row in id_prop]
    graph_data_list = []
    if atom_vecs:
        with open(f'{directory}\\atom_init.json') as atom_init:
            atom_vecs = json.load(atom_init)
            for filename, fileprop in id_prop_data:
                try:
                    file = directory+'\\'+filename+'.cif'
                    graph = cif2hgraph(file, radius=radius)
                    graph.y = torch.tensor(float(fileprop))
                    nodes_z = graph.x.tolist()
                    nodes_atom_vec = [atom_vecs[f'{z}'] for z in nodes_z]
                    graph.x = torch.tensor(nodes_atom_vec).float()
                    graph_data_list.append(graph)
                    print(f'Added {filename} to hgraph set')
                except:
                    print(f'Error with {filename}, confirm existence')
    else:
        for filename, fileprop in id_prop_data:
                try:
                    file = directory+'\\'+filename+'.cif'
                    graph = cif2hgraph(file, radius=radius)
                    graph.y = torch.tensor(float(fileprop))
                    hgraph_data_list.append(graph)
                    print(f'Added {filename} to hgraph set')
                except:
                    print(f'Error with {filename}, confirm existence')
    print('Done generating hypergraph data')
    return hgraph_data_list

In [2]:
##DOWNLOAD SILICONE CIF FOR FORWARD TEST
from mp_api.client import MPRester
with MPRester(api_key="TzwPvkaJdQKn2Eish81buwJfgezz3ukz") as mpr:
    data = mpr.materials.get_data_by_id("mp-1455", fields = 'structure')
    crys_cif=CifWriter(data.structure, significant_figures=4)
    crys_cif.write_file('test_cif.cif')



Retrieving MaterialsDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

In [78]:
cif2hyperedges('test_cif.cif', radius=20, min_rad = True)

tensor([[ 6, 11,  9,  7, 10,  5,  0,  4, 11,  6, 10,  7,  8,  1,  5,  8,  7, 11,
          9,  4,  2, 10,  8,  5,  9,  4,  6,  3,  8,  2,  3,  1,  4,  0,  3,  9,
          2,  5,  3, 10,  1,  0,  6, 11,  0,  1,  2,  7,  1,  4,  3,  2,  8,  5,
          2,  3,  0,  9,  6,  0,  1,  3, 10,  2,  1,  0,  7, 11],
        [ 0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,
          2,  2,  2,  3,  3,  3,  3,  3,  3,  3,  4,  4,  4,  4,  4,  5,  5,  5,
          5,  5,  6,  6,  6,  6,  6,  7,  7,  7,  7,  7,  8,  8,  8,  8,  8,  9,
          9,  9,  9,  9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11]],
       dtype=torch.int32)

In [63]:
cif2hyperedges('test_cif.cif')

tensor([[ 6, 11,  9,  5,  7, 10,  0,  4, 11,  6, 10,  7,  8,  1,  5,  8,  7, 11,
          4,  9,  2, 10,  8,  5,  9,  4,  6,  3,  8,  2,  3,  1,  4,  0,  3,  2,
          9,  5,  1,  3,  0, 10,  6, 11,  0,  1,  2,  7,  1,  2,  3,  4,  8,  5,
          2,  0,  3,  9,  0,  1,  6,  3, 10,  2,  1,  0,  7, 11],
        [ 0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,
          2,  2,  2,  3,  3,  3,  3,  3,  3,  3,  4,  4,  4,  4,  4,  5,  5,  5,
          5,  5,  6,  6,  6,  6,  6,  7,  7,  7,  7,  7,  8,  8,  8,  8,  8,  9,
          9,  9,  9,  9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11]],
       dtype=torch.int32)

In [27]:
struc = CifParser('test_cif.cif').get_structures()[0]
radii_list = struc.get_neighbor_list(r = 25, exclude_self=False)[3]
min_rad = np.min(radii_list)
tolerance = 0.05
print(struc.get_neighbor_list(r = min_rad + tolerance))

(array([], dtype=int32), array([], dtype=int32), array([], shape=(0, 3), dtype=float64), array([], dtype=float64))


In [35]:
struc = CifParser('test_cif.cif').get_structures()[0]
neigh_list = struc.get_neighbor_list(r = 3, exclude_self=True)
print(neigh_list[0])
print(neigh_list[1])

[ 0  0  0  0  0  0  1  1  1  1  1  1  2  2  2  2  2  2  3  3  3  3  3  3
  4  4  4  4  5  5  5  5  6  6  6  6  7  7  7  7  8  8  8  8  9  9  9  9
 10 10 10 10 11 11 11 11]
[ 6 11  9  5  7 10  4 11  6 10  7  8  5  8  7 11  4  9 10  8  5  9  4  6
  8  2  3  1  0  3  2  9  1  3  0 10 11  0  1  2  1  2  3  4  5  2  0  3
  0  1  6  3  2  1  0  7]


In [62]:
###

def cif2hyperedges(cif_file, radius: float = 3, min_rad = False, tolerance = 0.1):
    struc = CifParser(cif_file).get_structures()[0]
    ##Determines minimum radius and returns neighbor list for within min radius + tolerance
    if min_rad == True:
        nbr_lst = struc.get_neighbor_list(r = 25, exclude_self=True)
        min_rad = np.min(nbr_lst[3])
        nbr_lst = struc.get_neighbor_list(r = min_rad + tolerance, exclude_self=True)
    else:
        nbr_lst = struc.get_neighbor_list(r = radius, exclude_self=True)
    edge_list = np.stack((nbr_lst[0], nbr_lst[1])).transpose()
    edge_list = torch.tensor(edge_list)

    tk = edge_list[0][0]
    hedge_index = []
    node_index = []
    for i, j in edge_list:
        if i != tk:
            hedge_index.append(tk)
            node_index.append(tk)
            tk = i
        node_index.append(j)
        hedge_index.append(i)
    node_index.append(edge_list[-1][0])
    hedge_index.append(edge_list[-1][0])
    hedge_list = torch.stack((torch.tensor(node_index), torch.tensor(hedge_index)))
    return hedge_list