In [56]:
import collections
import itertools
import os
import json
import warnings
import math

import torch
import torch_geometric
from torch_geometric.data import Data, Batch
import numpy as np
import h5py

from deeph.model import get_spherical_from_cartesian, SphericalHarmonics
from deeph.from_pymatgen import find_neighbors, _one_to_three, _compute_cube_index, _three_to_one
from pymatgen.core.structure import Structure


In [57]:
folder = '/home/t.hsu/example2/work_dir/dataset/processed/0'
structure = Structure(np.loadtxt(os.path.join(folder, 'lat.dat')).T,
                        np.loadtxt(os.path.join(folder, 'element.dat')),
                        np.loadtxt(os.path.join(folder, 'site_positions.dat')).T,
                        coords_are_cartesian=True,
                        to_unit_cell=False)

In [58]:
print('length of frac_coord', len(structure.frac_coords))
print('length of cart_coord', len(structure.cart_coords))

length of frac_coord 36
length of cart_coord 36


In [59]:
# Define variable to get graph: 
default_dtype_torch = torch.get_default_dtype()
cart_coords = torch.tensor(structure.cart_coords, dtype=default_dtype_torch)
frac_coords = torch.tensor(structure.frac_coords, dtype=default_dtype_torch)
numbers = torch.tensor(structure.atomic_numbers)
structure.lattice.matrix.setflags(write=True)
lattice = torch.tensor(structure.lattice.matrix, dtype=default_dtype_torch)
stru_id = os.path.split(folder)[-1]
radius = 9
numerical_tol=1e-8
tb_folder=folder
interface = 'openmx'
num_l = 5
create_from_DFT = 'True'
huge_structure = 'True'
separate_onsite = 'False'
if_lcmp_graph = 'True'
max_num_nbr = 0

In [60]:
def get_graph(cart_coords, frac_coords, numbers, stru_id, r, max_num_nbr, numerical_tol, lattice,
              default_dtype_torch, tb_folder, interface, num_l, create_from_DFT, if_lcmp_graph,
              separate_onsite, target='hamiltonian', huge_structure=False, only_get_R_list=False, if_new_sp=False,
              if_require_grad=False, fid_rc=None, **kwargs):
    if create_from_DFT:
        key_atom_list = [[] for _ in range(len(numbers))]
        edge_idx, edge_fea, edge_idx_first = [], [], []
        if if_lcmp_graph:
            atom_idx_connect, edge_idx_connect = [], []
            edge_idx_connect_cursor = 0
        if if_require_grad:
            fid = fid_rc
        else:
            fid = h5py.File(os.path.join(tb_folder, 'rc.h5'), 'r')    
        for k in fid.keys():
            key = json.loads(k)
            key_tensor = torch.tensor([key[0], key[1], key[2], key[3] - 1, key[4] - 1]) # (R, i, j) i and j is 0-based index
            if separate_onsite:
                if key[0] == 0 and key[1] == 0 and key[2] == 0 and key[3] == key[4]:
                    continue
            key_atom_list[key[3] - 1].append(key_tensor) # Reorder the storage: start with all the configs from atom i = 0, atom i=1, ...
        for index_first, (cart_coord, keys_tensor) in enumerate(zip(cart_coords, key_atom_list)):
            keys_tensor = torch.stack(keys_tensor)
            cart_coords_j = cart_coords[keys_tensor[:, 4]] + keys_tensor[:, :3].type(default_dtype_torch).to(cart_coords.device) @ lattice.to(cart_coords.device)
            dist = torch.norm(cart_coords_j - cart_coord[None, :], dim=1)
            len_nn = keys_tensor.shape[0]
            edge_idx_first.extend([index_first] * len_nn)
            edge_idx.extend(keys_tensor[:, 4].tolist())
            edge_fea_single = torch.cat([dist.view(-1, 1), cart_coord.view(1, 3).expand(len_nn, 3)], dim=-1)
            edge_fea_single = torch.cat([edge_fea_single, cart_coords_j, cart_coords[keys_tensor[:, 4]]], dim=-1)
            edge_fea.append(edge_fea_single)

            if if_lcmp_graph:
                atom_idx_connect.append(keys_tensor[:, 4])
                edge_idx_connect.append(range(edge_idx_connect_cursor, edge_idx_connect_cursor + len_nn))
                edge_idx_connect_cursor += len_nn

        edge_fea = torch.cat(edge_fea).type(default_dtype_torch)
        edge_idx = torch.stack([torch.LongTensor(edge_idx_first), torch.LongTensor(edge_idx)])
    data = Data(x=numbers, edge_index=edge_idx, edge_attr=edge_fea, stru_id=stru_id, **kwargs)
    return data

In [61]:
y = get_graph(cart_coords, frac_coords, numbers, stru_id, radius, max_num_nbr, numerical_tol, lattice,
              default_dtype_torch, tb_folder, interface, num_l, create_from_DFT, if_lcmp_graph,
              separate_onsite, target='hamiltonian', huge_structure=False, only_get_R_list=False, if_new_sp=False,
              if_require_grad=False, fid_rc=None)

In [62]:
y

Data(edge_attr=[1436, 10], edge_index=[2, 1436], stru_id="0", x=[36])

### This is the case for stru_id = 0, which is the first configuration. Imagine we have 81 different configurations, we just loop over and keep concatenating 

In [63]:
fid = h5py.File(os.path.join(tb_folder, 'rc.h5'), 'r')
key_atom_list = [[] for _ in range(len(numbers))]

for k in fid.keys():
    key = json.loads(k)
    key_tensor = torch.tensor([key[0], key[1], key[2], key[3] - 1, key[4] - 1]) # (R, i, j) i and j is 0-based index
    if separate_onsite:
        if key[0] == 0 and key[1] == 0 and key[2] == 0 and key[3] == key[4]:
            continue
    key_atom_list[key[3] - 1].append(key_tensor) # Reorder the storage: start with all the configs from atom i = 0, atom i=1, ...

In [64]:
key_atom_list[0]

[tensor([-1, -1,  0,  0, 17]),
 tensor([-1, -1,  0,  0, 26]),
 tensor([-1, -1,  0,  0, 27]),
 tensor([-1, -1,  0,  0,  5]),
 tensor([-1, -1,  0,  0,  7]),
 tensor([-1, -1,  0,  0,  8]),
 tensor([-1,  0,  0,  0, 10]),
 tensor([-1,  0,  0,  0, 15]),
 tensor([-1,  0,  0,  0, 16]),
 tensor([-1,  0,  0,  0, 24]),
 tensor([-1,  0,  0,  0, 28]),
 tensor([-1,  0,  0,  0, 33]),
 tensor([-1,  0,  0,  0,  6]),
 tensor([-1,  0,  0,  0,  7]),
 tensor([ 0, -1,  0,  0, 11]),
 tensor([ 0, -1,  0,  0, 20]),
 tensor([ 0, -1,  0,  0, 30]),
 tensor([ 0, -1,  0,  0,  2]),
 tensor([ 0, -1,  0,  0,  5]),
 tensor([0, 0, 0, 0, 9]),
 tensor([ 0,  0,  0,  0, 11]),
 tensor([ 0,  0,  0,  0, 12]),
 tensor([ 0,  0,  0,  0, 13]),
 tensor([ 0,  0,  0,  0, 14]),
 tensor([ 0,  0,  0,  0, 16]),
 tensor([ 0,  0,  0,  0, 17]),
 tensor([ 0,  0,  0,  0, 18]),
 tensor([ 0,  0,  0,  0, 19]),
 tensor([ 0,  0,  0,  0, 21]),
 tensor([ 0,  0,  0,  0, 22]),
 tensor([0, 0, 0, 0, 1]),
 tensor([ 0,  0,  0,  0, 31]),
 tensor([ 0,  0,  