In [1]:
import numpy as np
import os
import csv
import json

from pymatgen.io.cif import CifWriter, CifParser
from pymatgen.core.structure import Structure
from torch_geometric.data import Data
import torch

## struc -> tensor([[node_index,...],[hyper_edge_index,...]]) ##

## subfunction for hedge_list responsible for singleton/node hedges, needs small tolerance
## to include itself in get_neighbor_list 

## might need to rework to include features in each subfunction
## atom_init from CGCNN for singletons

## should be passed first in hgraph construction, since it assumes to be first 

def struc2singletons(struc,  hedge_list = [[],[]], tol=0.01):
    singletons = struc.get_neighbor_list(r = tol, exclude_self=False)[0]
    hedge_counter = 0
    if hedge_list == [[],[]]:
        hedge_counter = 0
    else:
        hedge_counter = np.max(hedge_list[1]) + 1
    for node in singletons:
        hedge_list[0].append(node)
        hedge_list[1].append(hedge_counter)
        hedge_counter += 1
    return hedge_list




## CIF, tensor([[node_index,...],[hyper_edge_index,...]]) -> tensor([[node_index,...],[hyper_edge_index,...]]) ##

## subfunction for hedge_list responsible for pair-wise hedges
## can be based on threshold + min_rad or some interatomic radius

## might need to rework to include features in each subfunction
## gaussian distance for pairs

def struc2pairs(struc, hedge_list = [[],[]], radius: float = 3, min_rad: bool = True, tol: float = 0.1):
    if min_rad == False:
        nbr_lst = struc.get_neighbor_list(r = radius, exclude_self=True)
    elif min_rad == True:
        nbr_lst = struc.get_neighbor_list(r = 25, exclude_self=True)
        min_rad = np.min(nbr_lst[3])
        nbr_lst = struc.get_neighbor_list(r = min_rad+tol, exclude_self=True)

    pair_center_idx = nbr_lst[0]
    pair_neighbor_idx = nbr_lst[1]
    if hedge_list == []:
        hedge_counter = 0
    else:
        hedge_counter = np.max(hedge_list[1]) + 1

    ## currently double counts pair-wise edges
    for pair_1,pair_2 in zip(pair_center_idx, pair_neighbor_idx):
        hedge_list[0].append(pair_1)
        hedge_list[0].append(pair_2)

        hedge_list[1].append(hedge_counter)
        hedge_list[1].append(hedge_counter)
        hedge_counter += 1

    return hedge_list



## CIF, tensor([[node_index,...],[hyper_edge_index,...]]) -> tensor([[node_index,...],[hyper_edge_index,...]]) ##

## subfunction for hedge_list responsible for motif-wise hedges
## currently based only on distance (min+thresh or set range)

## might need to rework to include features in each subfunction
## motif order parameters

def struc2motifs(struc, hedge_list = [[],[]], radius: float = 3, min_rad: bool = True, tol: float = 0.1):
    if min_rad == False:
        nbr_lst = struc.get_neighbor_list(r = radius, exclude_self=True)
    elif min_rad == True:
        nbr_lst = struc.get_neighbor_list(r = 25, exclude_self=True)
        min_rad = np.min(nbr_lst[3])
        nbr_lst = struc.get_neighbor_list(r = min_rad+tol, exclude_self=True)

    pair_center_idx = nbr_lst[0]
    pair_neighbor_idx = nbr_lst[1]
    if hedge_list == [[],[]]:
        hedge_counter = 0
    else:
        hedge_counter = np.max(hedge_list[1]) + 1

    last_center = pair_center_idx[0]

    for pair_1,pair_2 in zip(pair_center_idx, pair_neighbor_idx):
        new_center = pair_1
        if last_center != new_center:
            hedge_list[0].append(last_pair_1)
            hedge_list[1].append(hedge_counter)
            last_center = pair_1
            hedge_counter += 1
            
        hedge_list[0].append(pair_2)
        hedge_list[1].append(hedge_counter)

        last_pair_1 = pair_1

    return hedge_list



In [2]:
struc = CifParser('test_cif.cif').get_structures()[0]

In [3]:
hedge_list=[[],[]]
hedge_list = struc2singletons(struc, hedge_list)
hedge_list = struc2pairs(struc, hedge_list)
hedge_list = struc2motifs(struc, hedge_list)
print(hedge_list)

[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 6, 0, 11, 0, 9, 0, 7, 0, 10, 0, 5, 1, 4, 1, 11, 1, 6, 1, 10, 1, 7, 1, 8, 2, 5, 2, 8, 2, 7, 2, 11, 2, 9, 2, 4, 3, 10, 3, 8, 3, 5, 3, 9, 3, 4, 3, 6, 4, 8, 4, 2, 4, 3, 4, 1, 5, 0, 5, 3, 5, 9, 5, 2, 6, 3, 6, 10, 6, 1, 6, 0, 7, 11, 7, 0, 7, 1, 7, 2, 8, 1, 8, 4, 8, 3, 8, 2, 9, 5, 9, 2, 9, 3, 9, 0, 10, 6, 10, 0, 10, 1, 10, 3, 11, 2, 11, 1, 11, 0, 11, 7, 6, 11, 9, 7, 10, 5, 0, 4, 11, 6, 10, 7, 8, 1, 5, 8, 7, 11, 9, 4, 2, 10, 8, 5, 9, 4, 6, 3, 8, 2, 3, 1, 4, 0, 3, 9, 2, 5, 3, 10, 1, 0, 6, 11, 0, 1, 2, 7, 1, 4, 3, 2, 8, 5, 2, 3, 0, 9, 6, 0, 1, 3, 10, 2, 1, 0, 7], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24, 25, 25, 26, 26, 27, 27, 28, 28, 29, 29, 30, 30, 31, 31, 32, 32, 33, 33, 34, 34, 35, 35, 36, 36, 37, 37, 38, 38, 39, 39, 40, 40, 41, 41, 42, 42, 43, 43, 44, 44, 45, 45, 46, 46, 47, 47, 48, 48, 49, 49, 50, 50, 51, 51, 52, 52, 53, 53, 54, 54, 55, 55, 56, 56

In [4]:
from pymatgen.analysis.local_env import LocalStructOrderParams
#lsop = LocalStructOrderParams(types)

#Dont forget to remove initial edges before neighborhood lists

hedge_list = [[],[]]
hedge_list = struc2motifs(struc, hedge_list)
motf_nindx = hedge_list[0][7:14]
motf_indx  = 1
print(motf_nindx)

#len(lsop.get_order_parameters(struc, motf_indx, indices_neighs = motf_nindx))

[4, 11, 6, 10, 7, 8, 1]


In [5]:
from pymatgen.analysis.local_env import LocalStructOrderParams


types = [ "cn",
        "sgl_bd",
        "bent",
        "tri_plan",
        "tri_plan_max",
        "reg_tri",
        "sq_plan",
        "sq_plan_max",
        "pent_plan",
        "pent_plan_max",
        "sq",
        "tet",
        "tet_max",
        "tri_pyr",
        "sq_pyr",
        "sq_pyr_legacy",
        "tri_bipyr",
        "sq_bipyr",
        "oct",
        "oct_legacy",
        "pent_pyr",
        "hex_pyr",
        "pent_bipyr",
        "hex_bipyr",
        "T",
        "cuboct",
        "cuboct_max",
        "see_saw_rect",
        "bcc",
        "q2",
        "q4",
        "q6",
        "oct_max",
        "hex_plan_max",
        "sq_face_cap_trig_pris"]


lsop = LocalStructOrderParams(types)

## CIF, tensor([[node_index,...],[hyper_edge_index,...]]) -> tensor([[node_index,...],[hyper_edge_index,...]]) ##

## subfunction for hedge_list responsible for motif-wise hedges
## currently based only on distance (min+thresh or set range)

## might need to rework to include features in each subfunction
## motif order parameters

def struc2motifs(struc, hedge_list = [[],[]], radius: float = 3, min_rad: bool = True, tol: float = 0.1):
    if min_rad == False:
        nbr_lst = struc.get_neighbor_list(r = radius, exclude_self=True)
    elif min_rad == True:
        nbr_lst = struc.get_neighbor_list(r = 25, exclude_self=True)
        min_rad = np.min(nbr_lst[3])
        nbr_lst = struc.get_neighbor_list(r = min_rad+tol, exclude_self=True)

    pair_center_idx = nbr_lst[0]
    pair_neighbor_idx = nbr_lst[1]
    if hedge_list == [[],[]]:
        hedge_counter = 0
    else:
        hedge_counter = np.max(hedge_list[1]) + 1

    last_center = pair_center_idx[0]
    neighborhoods = ([],[],[])
    neighborhood  = []

    for pair_1,pair_2 in zip(pair_center_idx, pair_neighbor_idx):
        new_center = pair_1
        if last_center != new_center:
            neighborhoods[0].append(hedge_counter)
            neighborhoods[1].append(last_pair_1)
            neighborhoods[2].append(neighborhood)
            hedge_list[0].append(last_pair_1)
            hedge_list[1].append(hedge_counter)
            last_center = pair_1
            neighborhood = []
            hedge_counter += 1
            
        hedge_list[0].append(pair_2)
        hedge_list[1].append(hedge_counter)
        neighborhood.append(pair_2)
        
        last_pair_1 = pair_1
        
    hedge_list[0].append(last_pair_1)
    hedge_list[1].append(hedge_counter)
    
    neighborhoods[0].append(hedge_counter)
    neighborhoods[1].append(last_pair_1)
    neighborhoods[2].append(neighborhood)
    
    lsop = LocalStructOrderParams(types)
    
    features = [[],[]]
    for hedge_idx, center_idx, neighbor_lst in zip(neighborhoods[0],neighborhoods[1],neighborhoods[2]):
        feature = lsop.get_order_parameters(struc, center_idx, indices_neighs = neighbor_lst)
        features[0].append(hedge_idx)
        features[1].append(feature)
            
    return hedge_list, features


def struc2singletons(struc,  hedge_list = [[],[]], tol=0.01, import_feat: bool = False, directory: str = ""):
    singletons = struc.get_neighbor_list(r = tol, exclude_self=False)[0]
    if hedge_list == [[],[]] or hedge_list == []:
        hedge_counter = 0
    else:
        hedge_counter = np.max(hedge_list[1]) + 1
    
    features = [[],[],[]]
    for node in singletons:
        hedge_list[0].append(node)
        hedge_list[1].append(hedge_counter)
        features[0].append(hedge_counter)
        hedge_counter += 1
        
    
    site_lst = struc.sites
    for site in site_lst:
        features[2].append(site.coords) #Coordinate of sites
        z_site = [element.Z for element in site.species]
        features[1].append(z_site[0]) #Atomic num of sites
    
    if import_feat == True:
        with open(f'{directory}atom_init.json') as atom_init:
            atom_vecs = json.load(atom_init)
            features[1] = [atom_vecs[f'{z}'] for z in features[1]]

    return hedge_list, features


In [6]:
hedge_list = [[],[]]
hedge_list = struc2motifs(struc, hedge_list)[1][1][0]
print(hedge_list)

[6.0, 2.220446049250313e-16, 7.121562665051939e-05, 0.0008384951159906882, 0.028653472314188584, 1.0658041506816435e-205, 0.14243183801701953, 0.3934481096907446, 0.030746076199275815, 0.389181844384812, 6.315412975334043e-31, 0.007254860759403706, 0.1562813946707834, 0.5113057454736281, 0.5814634668618692, 0.11932998001665218, 0.49161942485188304, 0.5308165961703892, 0.20006814693155736, 0.023694994598985445, 0.5354550624747225, 0.43823066521432336, 0.5223879964881926, 0.42443864403871856, 0.43036784535257916, 0.29334660406606244, 0.46704357385544876, 0.4068229648171845, 0.01447498552276311, 0.4058885720384663, 0.5165345556934031, 0.2803649979112953, 0.5308165961703892, 0.21853026521726338, 0.3604484545887907]


In [7]:
## struc -> tuple(tensor([[node_pos], ... ]),tensor([node_atomic_num,...])) ##

#takes cif file and returns tuple of a tensor of node positions and a tensor
# of nodes atomic number, indexed same as cif2graphedges
def cif2nodepos(struc):
    site_lst = struc.sites
    nodepos_lst = []
    nodespec_lst = []
    for site in site_lst:
        nodepos_lst.append(site.coords) #Coordinate of sites
        z_site = [element.Z for element in site.species]
        nodespec_lst.append(z_site) #Atomic number list of site species (should always be single element list for crystal)
    nodepos_arr = np.array(nodepos_lst, dtype=float)
    nodespec_arr = np.squeeze(nodespec_lst)
    return  (torch.tensor(nodepos_arr),torch.tensor(nodespec_arr))


In [8]:
hedge_list = [[],[]]
hedge_list = struc2singletons(struc, hedge_list)[0]
print(hedge_list)

[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]


In [9]:
## struc, tensor([[node_index,...],[hyper_edge_index,...]]) -> tensor([[node_index,...],[hyper_edge_index,...]]) ##

## subfunction for hedge_list responsible for pair-wise hedges
## can be based on threshold + min_rad or some interatomic radius

## might need to rework to include features in each subfunction
## gaussian distance for pairs

def struc2pairs(struc, hedge_list = [[],[]], radius: float = 3, min_rad: bool = True, tol: float = 0.1, gauss_dim: int = 1):
    if min_rad == False:
        nbr_lst = struc.get_neighbor_list(r = radius, exclude_self=True)
    elif min_rad == True:
        nbr_lst = struc.get_neighbor_list(r = 25, exclude_self=True)
        min_rad = np.min(nbr_lst[3])
        nbr_lst = struc.get_neighbor_list(r = min_rad+tol, exclude_self=True)

    pair_center_idx = nbr_lst[0]
    pair_neighbor_idx = nbr_lst[1]
    distances = nbr_lst[3]
    
    if hedge_list == [[],[]]:
        hedge_counter = 0
    else:
        hedge_counter = np.max(hedge_list[1]) + 1

    features = [[],[]]
    ## currently double counts pair-wise edges
    for pair_1,pair_2,dist in zip(pair_center_idx, pair_neighbor_idx,distances):
        hedge_list[0].append(pair_1)
        hedge_list[0].append(pair_2)

        hedge_list[1].append(hedge_counter)
        hedge_list[1].append(hedge_counter)
        
        features[0].append(hedge_counter)
        features[1].append(dist)
        
        hedge_counter += 1
    if gauss_dim != 1:
        ge = gaussian_expansion(dmin = 0 ,dmax = radius + 5*tol, steps = gauss_dim)
        features[1]=[ge.expand(dist) for dist in features[1]]

    
    return hedge_list, features


## Gaussian distance expansion function for pair hedge features

import math

class gaussian_expansion(object):
    def __init__(self, dmin, dmax, steps):
        assert dmin<dmax
        self.dmin = dmin
        self.dmax = dmax
        self.steps = steps
        
    def expand(self, distance, sig=None):
        drange = self.dmax-self.dmin
        step_size = drange/self.steps
        if sig == None:
            sig = step_size/2
        ds = [self.dmin + i*step_size for i in range(self.steps)]
        expansion = [math.exp(-(distance-center)**2/(2*sig**2)) for center in ds]
        return expansion

    
    

In [10]:
hedge_list = [[],[]]
hedge_list = struc2pairs(struc, hedge_list, gauss_dim=5)[1][1][0]
print(hedge_list)

[4.1871975449302247e-10, 2.8937889790041914e-05, 0.03662961367938433, 0.8492193684981157, 0.36060314951808503]


In [11]:
ge = gaussian_expansion(2,10,8)
ge.expand(5)

[1.522997974471263e-08,
 0.00033546262790251185,
 0.1353352832366127,
 1.0,
 0.1353352832366127,
 0.00033546262790251185,
 1.522997974471263e-08,
 1.2664165549094176e-14]

In [12]:
def hedge_packer(hedge_list):
    new_hedges = [[],[]]
    hedge_nodes = []
    old_hedge = hedge_list[1][0]
    for idx,(node,hedge) in enumerate(zip(hedge_list[0],hedge_list[1])):
        new_hedge = hedge_list[1][idx]
        if new_hedge != old_hedge:
            new_hedges[0].append(old_hedge)
            old_hedge = hedge_list[1][idx]
            new_hedges[1].append(hedge_nodes)
            hedge_nodes = []
        hedge_nodes.append(node)
    new_hedges[0].append(new_hedge)
    new_hedges[1].append(hedge_nodes)
    return new_hedges



In [13]:
def relatives(hedge_pack):
    relative_list = [[],[]]
    for idx1, hedge in enumerate(hedge_pack[0]):
        #print(idx1)
        relative_list[0].append(hedge)
        relatives = []
        for idx2,nodes_contained in enumerate(hedge_pack[1]):
            #print(nodes_contained)
            if  all(item in nodes_contained for item in hedge_pack[1][idx1]) or all(item in hedge_pack[1][idx1] for item in nodes_contained):
                relatives.append(hedge_pack[0][idx2])
        relatives.remove(hedge)
        relative_list[1].append(relatives)
    return relative_list
    

In [14]:
hedge_list = [[],[]]
hedge_list = struc2singletons(struc, hedge_list)[0]
hedge_list = struc2pairs(struc, hedge_list, gauss_dim=1)[0]
hedge_list = struc2motifs(struc, hedge_list)[0]
print(hedge_list)
print(hedge_packer(hedge_list))

[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 6, 0, 11, 0, 9, 0, 7, 0, 10, 0, 5, 1, 4, 1, 11, 1, 6, 1, 10, 1, 7, 1, 8, 2, 5, 2, 8, 2, 7, 2, 11, 2, 9, 2, 4, 3, 10, 3, 8, 3, 5, 3, 9, 3, 4, 3, 6, 4, 8, 4, 2, 4, 3, 4, 1, 5, 0, 5, 3, 5, 9, 5, 2, 6, 3, 6, 10, 6, 1, 6, 0, 7, 11, 7, 0, 7, 1, 7, 2, 8, 1, 8, 4, 8, 3, 8, 2, 9, 5, 9, 2, 9, 3, 9, 0, 10, 6, 10, 0, 10, 1, 10, 3, 11, 2, 11, 1, 11, 0, 11, 7, 6, 11, 9, 7, 10, 5, 0, 4, 11, 6, 10, 7, 8, 1, 5, 8, 7, 11, 9, 4, 2, 10, 8, 5, 9, 4, 6, 3, 8, 2, 3, 1, 4, 0, 3, 9, 2, 5, 3, 10, 1, 0, 6, 11, 0, 1, 2, 7, 1, 4, 3, 2, 8, 5, 2, 3, 0, 9, 6, 0, 1, 3, 10, 2, 1, 0, 7, 11], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24, 25, 25, 26, 26, 27, 27, 28, 28, 29, 29, 30, 30, 31, 31, 32, 32, 33, 33, 34, 34, 35, 35, 36, 36, 37, 37, 38, 38, 39, 39, 40, 40, 41, 41, 42, 42, 43, 43, 44, 44, 45, 45, 46, 46, 47, 47, 48, 48, 49, 49, 50, 50, 51, 51, 52, 52, 53, 53, 54, 54, 55, 55, 56

In [15]:
# generates edge pairs for relatives graph from relatives function output

def relatives2graphedges(relative_list):
    rel_ge = [[],[]]
    for idx,pair_center in enumerate(relative_list[0]):
        for neighbor in relative_list[1][idx]:
            rel_ge[0].append(pair_center)
            rel_ge[1].append(neighbor)
    return rel_ge

In [16]:
relative = relatives(hedge_packer(hedge_list))
print(relative)
print(relatives2graphedges(relative))

[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79], [[12, 13, 14, 15, 16, 17, 40, 47, 49, 59, 61, 66, 68, 73, 74, 75, 77, 78, 79], [18, 19, 20, 21, 22, 23, 39, 46, 50, 52, 62, 65, 69, 72, 74, 75, 76, 78, 79], [24, 25, 26, 27, 28, 29, 37, 43, 51, 55, 57, 64, 70, 72, 73, 75, 76, 77, 79], [30, 31, 32, 33, 34, 35, 38, 41, 44, 54, 58, 63, 71, 72, 73, 74, 76, 77, 78], [18, 29, 34, 36, 37, 38, 39, 53, 69, 70, 71, 72, 76], [17, 24, 32, 40, 41, 42, 43, 56, 68, 70, 71, 73, 77], [12, 20, 35, 44, 45, 46, 47, 60, 68, 69, 71, 74, 78], [15, 22, 26, 48, 49, 50, 51, 67, 68, 69, 70, 75, 79], [23, 25, 31, 36, 52, 53, 54, 55, 69, 70, 71, 72, 76], [14, 28, 33, 42, 56, 57, 58, 59, 68, 70, 71, 73, 77], [16, 21, 30, 45, 60, 61, 62, 63, 68, 69, 71, 74, 7

In [43]:

## CIF -> [[node_index,...],[hyper_edge_index,...]], (singleton_feats, pairs, motifs) ##

# takes cif file and returns array (2 x num_nodes_in_hedges) of hedge index
# (as specified in the HypergraphConv doc of PyTorch Geometric)
# found by collecting neighbors within spec radius for each node in one hedge


def cif2hedges(cif_file, radius: float = 3, min_rad: bool = True, tol: float = 0.1, features: bool = False, gauss_dim: int = 5):
    struc = CifParser(cif_file).get_structures()[0]
    hedge_list = [[],[]]
    feats = []
    hedge_list, singleton_feat = struc2singletons(struc, hedge_list, import_feat = True)
    hedge_list, pair_feat = struc2pairs(struc, hedge_list, radius, min_rad, tol, gauss_dim = gauss_dim)
    hedge_list, motif_feat = struc2motifs(struc, hedge_list, radius, min_rad, tol)
    if features == True:
        return hedge_list, (singleton_feat, pair_feat, motif_feat)
    elif features == False:
        return hedge_list
    
## CIF -> array([[graph_edge_1_node_1,...],[graph_edge_1_node_2,...]]) ##


def cif2reledges(cif_file, radius: float = 3, min_rad: bool = True, tol: float = 0.1, features: bool = False, gauss_dim: int = 5):
    hedge_list, feats = cif2hedges(cif_file, radius, min_rad, tol, features, gauss_dim)
    relative = relatives(hedge_packer(hedge_list))
    edges = np.array(relatives2graphedges(relative))
    if features == True:
        return edges, feats, hedge_list
    else:
        return edges

In [61]:
hedge_list, feats = cif2hedges('./test_cif.cif', features = True, tol = 0.1)
print(len(hedge_list[0]))
relative = relatives(hedge_packer(hedge_list))
edges = np.array(relatives2graphedges(relative))
print(edges.shape)
print(cif2reledges('test_cif.cif', features = True))

192
(2, 792)
(array([[ 0,  0,  0, ..., 79, 79, 79],
       [12, 13, 14, ..., 66, 67, 75]]), ([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 

In [63]:
from torch_geometric.data import HeteroData
data = HeteroData()
data['singletons'].x = np.array(feats[0][1])
data['pairs'].x = np.array(feats[1][1])
data['motifs'].x = np.array(feats[2][1])
data['singletons','pairs','motifs'].edge_index = edges

In [64]:
from torch_geometric.data import Data
data = Data()
data.x = np.array(feats[0][1])
data.x = np.array(feats[1][1])
data.x = np.array(feats[2][1])
data.edge_index = edges

In [20]:
print(data.x_index_dict)

{}


In [None]:
## Load data objects with relative edges defined already
###### save regular graph data to file in list
## Load heterogeneous features
###### save feature list as independent object to graph data file
## Feed features into projection layer
###### load feature list in network structure
## Associate homogenized features with data
###### set data.x after projection in network structure
## Associate labels with data
###### set data.y in network structure
## Pass to convolutional layers


In [38]:
try:
    import cPickle as pickle
except ModuleNotFoundError:
    import pickle


with open('graph_list.pkl', 'wb') as storage:
    pickle.dump(hedge_list, storage, pickle.HIGHEST_PROTOCOL)
    pickle.dump(relative, storage, pickle.HIGHEST_PROTOCOL)
    pickle.dump(feats, storage, pickle.HIGHEST_PROTOCOL)
    
del hedge_list
del relative
del feats

In [40]:
try:
    import cPickle as pickle
except ModuleNotFoundError:
    import pickle

with open('graph_list.pkl', 'rb') as storage:
    hedge_list = pickle.load(storage)
    relative =  pickle.load(storage)
    feats = pickle.load(storage)

In [95]:
from torch_geometric.data import Data
import csv
import os

def relgraph_list_from_dir(directory='cif', root='', atom_vecs = True, radius:float=3.0):
    if root == '':
        root = os. getcwd()
    directory = root+'\\'+directory
    print(f'Searching {directory} for CIF data to convert to hgraphs')
    with open(f'{directory}\\id_prop.csv') as id_prop:
        id_prop = csv.reader(id_prop)
        id_prop_data = [row for row in id_prop]
    relgraphs = []
    hedges = []
    feats_list = []

    for filename, fileprop in id_prop_data:
            try:
                file = directory+'\\'+filename+'.cif'
                edges, feats, hedge_list = cif2reledges(file, radius=radius, features = True)
                graph = Data()
                graph.edge_index = torch.tensor(edges, dtype = int)
                graph.y = torch.tensor(float(fileprop))
                relgraphs.append(graph)
                hedges.append(hedge_list)
                feats_list.append(feats)
                print(f'Added {filename} to relgraph set')
            except:
                print(f'Error with {filename}, confirm existence')
                
    print('Done generating relatives graph data with features')
    return relgraphs, feats_list, hedges

In [96]:
relgraph_list = relgraph_list_from_dir(directory='cif', root='', atom_vecs = True)


with open('relgraph_list.pkl', 'wb') as storage:
    pickle.dump(relgraph_list, storage, pickle.HIGHEST_PROTOCOL)
    
del relgraph_list

Searching C:\Users\ajh01\Desktop\chgcnn\cif for CIF data to convert to hgraphs
Added mp-1 to relgraph set
Added mp-2 to relgraph set
Added mp-4 to relgraph set
Added mp-9 to relgraph set
Added mp-10 to relgraph set
Added mp-13 to relgraph set
Done generating relatives graph data with features


In [162]:
with open('relgraph_list.pkl', 'rb') as storage:
    relgraphs, feats_list, hedges = pickle.load(storage)

In [196]:
import torch
import torch.nn.functional as F
from torch import nn

class ProjectionLayer(nn.Module):
    def __init__(self, feats, hidden_dims):
        input_dims = []
        for feat_type in feats:
            super(ProjectionLayer, self).__init__()
            input_dims.append(len(feat_type[1][0]))
        self.linears = nn.ModuleList([nn.Linear(in_dim, hidden_dims, dtype=float) for in_dim in input_dims])
    def forward(self, feats):
        feat_homog = []
        for feat_type, feat_proj in zip(feats, self.linears):
            feat_homog.append(feat_proj(feat_type))
        feat_list = torch.cat(feat_homog, dim = 0)
        return feat_list

feats = feats_list[0]
    
feats_tens = [torch.tensor(np.nan_to_num(np.array(feat[1], dtype=float))) for feat in feats]

proj = ProjectionLayer(feats, 64)

proj.forward(feats_tens)

tensor([[-0.1567,  0.1008, -0.2885,  0.1977,  0.0321, -0.2686, -0.1807, -0.0601,
          0.0832,  0.1999, -0.0798, -0.2872, -0.1723, -0.2098,  0.3836,  0.1426,
         -0.0612, -0.2512, -0.0785, -0.1298, -0.1479,  0.0667,  0.3088, -0.0557,
         -0.1801,  0.0283,  0.3386,  0.1750,  0.0057, -0.1860, -0.0488, -0.3855,
         -0.0219, -0.0627,  0.2324,  0.0257, -0.0712,  0.0934, -0.2909, -0.4823,
          0.0625, -0.1202,  0.0725,  0.2366, -0.0715, -0.0872, -0.0783, -0.0060,
          0.2094,  0.0252, -0.0185,  0.0748,  0.1049, -0.0290, -0.3065,  0.1063,
          0.0550, -0.4192, -0.1399,  0.0762, -0.0880,  0.1361,  0.0785,  0.1529],
        [ 0.3927, -0.1792, -0.2581, -0.1792,  0.1945, -0.3664, -0.0551, -0.2715,
          0.2714, -0.0291, -0.3146,  0.2764,  0.1113, -0.1643,  0.3655,  0.0948,
         -0.1770,  0.4336, -0.2163, -0.1716,  0.1812,  0.4331,  0.3026, -0.3239,
          0.2701, -0.1478,  0.0823, -0.0307, -0.0345,  0.2538,  0.2220,  0.0478,
         -0.0484,  0.0449, 

In [194]:
with open('relgraph_list.pkl', 'rb') as storage:
    relgraphs, feats_list, hedges = pickle.load(storage)

feat_ex = feats_list[0]
proj = ProjectionLayer(feat_ex, 64)

for graph, feat in zip(relgraphs, feats_list):
    print(feat[0])
    featsl = proj.forward(feat)
    print(featsl)
    graph.x = featsl

[[0], [[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]], [array([0., 0., 0.])]]


TypeError: linear(): argument 'input' (position 1) must be Tensor, not list