# Imports

In [1]:
import h5py
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import torch_geometric
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
import os
import glob
from dataloader import SimpleGraphDataLoader
from collections import Counter

# Load data

In [4]:
with h5py.File("/home/thomas/graph-diffusion-project/graphs_h5/graph_AntiFlourite_Ra2O_r5.h5", 'r') as file:
        edge_features = torch.tensor(file['edge_attributes'][:], dtype=torch.float32) # Edge attributes
        edge_indices = torch.tensor(file['edge_indices'][:], dtype=torch.long) # Edge (sparse) adjecency matrix
        node_features = torch.tensor(file['node_attributes'][:], dtype=torch.float32) # Node attributes

        xyz = node_features[:, 4:7]
        print(xyz.shape, xyz.max(), xyz.min())

torch.Size([23, 3]) tensor(5.3914) tensor(-5.3914)


In [143]:
mol_paths = glob.glob("/home/thomas/graph-diffusion-project/graphs_h5/*")
print(len(mol_paths))

1728


In [144]:
def load_structure(path):

    with h5py.File(path, 'r') as file:
        edge_features = torch.tensor(file['edge_attributes'][:], dtype=torch.float32) # Edge attributes
        edge_indices = torch.tensor(file['edge_indices'][:], dtype=torch.long) # Edge (sparse) adjecency matrix
        node_features = torch.tensor(file['node_attributes'][:], dtype=torch.float32) # Node attributes

        r = file['r'][...] # Real space (x-axis)
        pdf = file['pdf'][...] # G(r) (y-axis)

        # Here you can do some normalisation of the node features and perhaps pick out which you want to include.

        graph = Data(x = node_features, y = pdf, edge_attr = edge_features, edge_index = edge_indices)
    return graph

In [145]:
loader = DataLoader([load_structure(path) for path in mol_paths], batch_size=1, shuffle=True)

In [146]:
loader = [d for d in loader if d.x.shape[0] == 23]

In [147]:
len([d for d in loader if d.x.shape[0] <= 23])

3

In [148]:
g = loader[0]
print(g)

DataBatch(x=[23, 7], edge_index=[2, 56], edge_attr=[56], y=[1], batch=[23], ptr=[2])


In [149]:
import pandas as pd
from mendeleev import element

In [150]:
def xyz_to_str(xyz, atom_species=None):
    """
    Write a string in format ovito can read
    args:
        xyz: np.array of shape (n_atoms, 3)
        atom_species: np.array of shape (n_atoms, 1)
    returns:
        s: string in ovito format
    """
    n_atoms = xyz.shape[0]
    if atom_species is None:
        atom_species = np.array(["C"]*n_atoms).reshape(-1, 1)

    # if atom_species is number, convert to atomic symbol
    if atom_species.dtype != str:
        atom_species = np.array(
            [element(int(atom)).symbol if atom else atom for atom in atom_species]).reshape(-1, 1)

    s = ""
    vals = np.concatenate((atom_species, xyz), axis=1)

    # Number of atoms
    s += str(n_atoms) + "\n"
    
    # Comment line, just keep empty for now
    s += "\n"
    
    # Coordinates for each atom
    for atom, x, y, z in vals:
        s += f"{atom} {x} {y} {z} \n"

    return s


def save_to_csv(file_name, s):
    with open(file_name, "w") as f:
        f.write(s)

In [154]:
# Find the path to a graph with 23 nodes
for p in mol_paths:
    g = load_structure(p)
    if g.x.shape[0] == 23:
        print(p)
        break

/home/thomas/graph-diffusion-project/graphs_h5/graph_AntiFlourite_Ra2O_r5.h5
