In [9]:
from xyz2graph import MolGraph, to_networkx_graph, to_plotly_figure
from plotly.offline import init_notebook_mode, iplot
import networkx as nx
import pandas as pd
import re
import numpy as np
import os
import torch
import torch_geometric as pyg
from torch_geometric.utils import from_networkx
from torch_geometric.data import Data
import periodictable as pt
import glob



In [6]:
def understand_mol_graph(path):
    """Plot an .xyz molecule and report coordinates that are missing from its graph.

    The file is read directly to collect the raw atom coordinates and is also
    loaded through ``MolGraph.read_xyz``. The resulting graph is drawn with
    Plotly, its nodes, edges and node count are printed, and finally the set
    of coordinates present in the file but absent from the graph is printed.

    Args:
        path: Path to the .xyz file to inspect.

    Returns:
        None. All results are reported through the plot and ``print``.
    """
    # Parse the .xyz file, skipping the two header lines
    with open(path, 'r') as file:
        body_lines = file.readlines()[2:]

    # Keep only lines that carry a symbol plus three coordinates. A blank
    # line (e.g. a trailing newline at the end of the file) used to turn into
    # the empty tuple () and was then reported as a "missing" coordinate.
    atoms = []
    for line in body_lines:
        fields = line.split()
        if len(fields) < 4:
            continue
        atoms.append(tuple(map(float, fields[1:4])))

    # Create the MolGraph object
    mg = MolGraph()
    mg.read_xyz(path)

    # Plot the molecular graph
    fig = to_plotly_figure(mg)
    iplot(fig)

    # Convert the molecular graph to the NetworkX graph
    G = to_networkx_graph(mg)
    print(G.nodes(data=True))
    print(G.edges(data=True))
    print(len(G.nodes))

    # Compare the atoms to the nodes in the graph (exact float equality)
    graph_nodes = [data['xyz'] for node, data in G.nodes(data=True)]
    missing_atoms = set(atoms) - set(graph_nodes)
    print(f"Missing coordinates: {missing_atoms}")

# Run the inspection on one sample geometry.
understand_mol_graph('data/xyz_molsimp/monosubstituted_0001.xyz')

[(np.int64(0), {'element': 'Fe', 'xyz': (0.0, 0.0, 0.0)}), (np.int64(1), {'element': 'B', 'xyz': (0.0, 2.1, -0.0)}), (np.int64(20), {'element': 'O', 'xyz': (2.08, -0.0, -0.0)}), (np.int64(23), {'element': 'O', 'xyz': (-0.0, -2.08, 0.0)}), (np.int64(26), {'element': 'O', 'xyz': (-2.08, 0.0, 0.0)}), (np.int64(29), {'element': 'O', 'xyz': (0.0, 0.0, 2.08)}), (np.int64(32), {'element': 'O', 'xyz': (0.0, -0.0, -2.08)}), (np.int64(2), {'element': 'N', 'xyz': (-1.315822, 2.875, 0.265496)}), (np.int64(5), {'element': 'N', 'xyz': (1.315822, 2.875, -0.265496)}), (np.int64(3), {'element': 'C', 'xyz': (-1.629548, 3.720125, -0.895642)}), (np.int64(4), {'element': 'C', 'xyz': (-2.411063, 1.919746, 0.486484)}), (np.int64(8), {'element': 'H', 'xyz': (-1.746782, 3.104968, -1.763247)}), (np.int64(9), {'element': 'H', 'xyz': (-2.537891, 4.255123, -0.712365)}), (np.int64(10), {'element': 'H', 'xyz': (-0.832329, 4.415444, -1.056493)}), (np.int64(11), {'element': 'H', 'xyz': (-2.528296, 1.304585, -0.381118)

In [7]:
# Electronegativity lookup keyed by element symbol, covering H through Lr.
# Entries set to None have no value recorded in this table;
# get_element_properties returns them unchanged, so callers that build float
# tensors from the result must handle None themselves.
# NOTE(review): the numbers look like Pauling-scale values — confirm the source.
electronegativities = {
    "H": 2.20, "He": None,
    "Li": 0.98, "Be": 1.57, "B": 2.04, "C": 2.55, "N": 3.04, "O": 3.44, "F": 3.98, "Ne": None,
    "Na": 0.93, "Mg": 1.31, "Al": 1.61, "Si": 1.90, "P": 2.19, "S": 2.58, "Cl": 3.16, "Ar": None,
    "K": 0.82, "Ca": 1.00, "Sc": 1.36, "Ti": 1.54, "V": 1.63, "Cr": 1.66, "Mn": 1.55, "Fe": 1.83,
    "Co": 1.88, "Ni": 1.91, "Cu": 1.90, "Zn": 1.65, "Ga": 1.81, "Ge": 2.01, "As": 2.18, "Se": 2.55,
    "Br": 2.96, "Kr": 3.00,
    "Rb": 0.82, "Sr": 0.95, "Y": 1.22, "Zr": 1.33, "Nb": 1.6, "Mo": 2.16, "Tc": 1.9, "Ru": 2.2,
    "Rh": 2.28, "Pd": 2.20, "Ag": 1.93, "Cd": 1.69, "In": 1.78, "Sn": 1.96, "Sb": 2.05, "Te": 2.1,
    "I": 2.66, "Xe": 2.6,
    "Cs": 0.79, "Ba": 0.89, "La": 1.10, "Ce": 1.12, "Pr": 1.13, "Nd": 1.14, "Pm": None, "Sm": 1.17,
    "Eu": None, "Gd": 1.20, "Tb": None, "Dy": 1.22, "Ho": 1.23, "Er": 1.24, "Tm": 1.25, "Yb": None,
    "Lu": 1.27, "Hf": 1.3, "Ta": 1.5, "W": 2.36, "Re": 1.9, "Os": 2.2, "Ir": 2.20, "Pt": 2.28,
    "Au": 2.54, "Hg": 2.00, "Tl": 1.62, "Pb": 2.33, "Bi": 2.02, "Po": 2.0, "At": 2.2, "Rn": None,
    "Fr": 0.7, "Ra": 0.9, "Ac": 1.1, "Th": 1.3, "Pa": 1.5, "U": 1.38, "Np": 1.36, "Pu": 1.28,
    "Am": 1.3, "Cm": 1.3, "Bk": 1.3, "Cf": 1.3, "Es": 1.3, "Fm": 1.3, "Md": 1.3, "No": 1.3,
    "Lr": 1.3
}


def get_element_properties(element_symbol):
    """Return the four per-atom features used as node attributes.

    Args:
        element_symbol: Element symbol that ``pt.elements.symbol`` can resolve
            and that is a key of ``electronegativities``.

    Returns:
        A 4-tuple ``(atomic number / 118, mass / 294, electronegativity,
        covalent radius)``. The electronegativity is taken from the
        ``electronegativities`` table as-is, so it is ``None`` for symbols
        whose table entry is ``None``.

    Raises:
        KeyError: If the symbol has no entry in ``electronegativities``.
    """
    # Divisors that scale atomic number and mass into the features
    atomic_number_scale = 118.0
    atomic_mass_scale = 294.0

    atom = pt.elements.symbol(element_symbol)

    return (
        atom.number / atomic_number_scale,
        atom.mass / atomic_mass_scale,
        electronegativities[element_symbol],
        atom.covalent_radius,
    )

In [10]:
# List symbol, atomic number, mass and covalent radius for every element
# that periodictable provides.
for element in pt.elements:
    print(element.symbol, element.number, element.mass, element.covalent_radius)

n 0 1.00866491597 0.2
H 1 1.00794 0.31
He 2 4.002602 0.28
Li 3 6.941 1.28
Be 4 9.012182 0.96
B 5 10.811 0.84
C 6 12.0107 0.76
N 7 14.0067 0.71
O 8 15.9994 0.66
F 9 18.9984032 0.57
Ne 10 20.1797 0.58
Na 11 22.98977 1.66
Mg 12 24.305 1.41
Al 13 26.981538 1.21
Si 14 28.0855 1.11
P 15 30.973761 1.07
S 16 32.065 1.05
Cl 17 35.453 1.02
Ar 18 39.948 1.06
K 19 39.0983 2.03
Ca 20 40.078 1.76
Sc 21 44.95591 1.7
Ti 22 47.867 1.6
V 23 50.9415 1.53
Cr 24 51.9961 1.39
Mn 25 54.938049 1.39
Fe 26 55.845 1.32
Co 27 58.9332 1.26
Ni 28 58.6934 1.24
Cu 29 63.546 1.32
Zn 30 65.409 1.22
Ga 31 69.723 1.22
Ge 32 72.64 1.2
As 33 74.9216 1.19
Se 34 78.96 1.2
Br 35 79.904 1.2
Kr 36 83.798 1.16
Rb 37 85.4678 2.2
Sr 38 87.62 1.95
Y 39 88.90585 1.9
Zr 40 91.224 1.75
Nb 41 92.90638 1.64
Mo 42 95.94 1.54
Tc 43 98 1.47
Ru 44 101.07 1.46
Rh 45 102.9055 1.42
Pd 46 106.42 1.39
Ag 47 107.8682 1.45
Cd 48 112.411 1.44
In 49 114.818 1.42
Sn 50 118.71 1.39
Sb 51 121.76 1.39
Te 52 127.6 1.38
I 53 126.90447 1.39
Xe 54 131.293 1

In [11]:
# Spot-check the feature tuple returned for cobalt.
get_element_properties('Co')

(0.2288135593220339, 0.2004530612244898, 1.88, 1.26)

In [12]:
def to_pyg_graph_triple(name, gen_xyz_path, gen_csv_path):
    """Build one PyG ``Data`` object holding every available state of a complex.

    For ``name`` (two-letter element prefix followed by the rest of the
    identifier) up to four geometries are loaded from ``gen_xyz_path``:
    ``{element}_II{rest}_HS``, ``{element}_II{rest}_LS``,
    ``{element}_III{rest}_HS`` and ``{element}_III{rest}_LS``. Each geometry is
    converted with ``from_networkx`` (node features: ``xyz`` followed by the
    element properties; edge feature: ``length``) and the graphs are stacked
    into a single ``Data`` object as disjoint components.

    Args:
        name: Value to look up in the ``name`` column of the csv file.
        gen_xyz_path: Directory that contains the .xyz files.
        gen_csv_path: Path of the csv file WITHOUT the ``.csv`` extension.

    Returns:
        A ``Data`` object with ``x``, ``edge_index``, ``edge_attr`` and one
        attribute per column of the last 7 csv columns, or ``None`` when the
        csv file is missing, ``name`` is not in it, ``name`` has no two-letter
        prefix, or no geometry could be loaded. Each failure prints a message.
    """
    # Read the csv file
    try:
        df = pd.read_csv(f"{gen_csv_path}.csv")
    except FileNotFoundError:
        print(f"No csv file found at {gen_csv_path}.csv")
        return None

    # Filter the csv data to get the row associated with the name
    row = df[df['name'] == name]
    if row.empty:
        print(f"No data found for name {name} in the csv file {gen_csv_path}.csv")
        return None

    # The first two letters of the name symbolize the element. Guard against
    # names without such a prefix instead of crashing on ``None.group()``.
    prefix = re.match(r"^[A-Za-z]{2}", name)
    if prefix is None:
        print(f"Name {name} does not start with a two-letter element symbol")
        return None
    element = prefix.group()
    rest_of_the_name = name[prefix.end():]

    # Construct the filenames for the xyz data
    filenames = [
        f"{element}_{oxidation}{rest_of_the_name}_{state}.xyz"
        for oxidation in ('II', 'III')
        for state in ('HS', 'LS')
    ]

    x_parts, edge_index_parts, edge_attr_parts = [], [], []
    node_offset = 0  # number of node rows already stacked into x_parts
    for filename in filenames:
        xyz_file = f"{gen_xyz_path}/{filename}"
        try:
            # Create the MolGraph object for the file
            mg = MolGraph()
            mg.read_xyz(xyz_file)
        except FileNotFoundError:
            print(f"No xyz file found at {xyz_file}")
            continue

        # Convert the molecular graph to the NetworkX graph
        G = to_networkx_graph(mg)

        # Get element properties for each node and store them in the node features
        for node in G.nodes:
            G.nodes[node]['element_properties'] = get_element_properties(G.nodes[node]['element'])

        # Convert the NetworkX graph to the PyG graph
        graph = from_networkx(G, group_node_attrs=['xyz', 'element_properties'], group_edge_attrs=['length'])

        num_nodes = 0
        if graph.x is not None:
            x_parts.append(graph.x)
            num_nodes = graph.x.size(0)
        if graph.edge_index is not None:
            # Node rows of this graph start at ``node_offset`` in the stacked
            # feature matrix, so its edges must be shifted by the same amount.
            # Without the shift every edge would point into the first graph.
            edge_index_parts.append(graph.edge_index + node_offset)
        if graph.edge_attr is not None:
            edge_attr_parts.append(graph.edge_attr)
        node_offset += num_nodes

    if not (x_parts and edge_index_parts and edge_attr_parts):
        print("No valid data found in the files.")
        return None

    # Concatenate the PyG graph data
    data = Data(x=torch.cat(x_parts, dim=0),
                edge_index=torch.cat(edge_index_parts, dim=1),
                edge_attr=torch.cat(edge_attr_parts, dim=0))

    # Add the last 7 columns of the filtered row to the PyG graph data
    output_data = row.iloc[:, -7:]
    for col in output_data.columns:
        data[col] = output_data[col].values[0]

    return data

In [13]:
def to_pyg_graph_single(name, gen_xyz_path, gen_csv_path):
    """Build one PyG graph whose nodes carry the coordinates of every available state.

    The geometries ``{element}_II{rest}_LS``, ``{element}_II{rest}_HS``,
    ``{element}_III{rest}_LS`` and ``{element}_III{rest}_HS`` are looked up in
    ``gen_xyz_path`` (in that order). The first file found supplies the graph
    topology, the element properties and ``xyz``; each further file found
    contributes its coordinates as extra node columns ``xyz_1``, ``xyz_2``, ...
    The width of ``x`` therefore depends on how many files exist.

    Args:
        name: Value to look up in the ``name`` column of the csv file.
        gen_xyz_path: Directory that contains the .xyz files.
        gen_csv_path: Path of the csv file WITHOUT the ``.csv`` extension.

    Returns:
        A ``Data`` object with ``x``, ``edge_index`` (each edge stored once),
        ``edge_attr`` and one attribute per column of the last 7 csv columns,
        or ``None`` when the csv file is missing, ``name`` is not in it,
        ``name`` has no two-letter prefix, or no geometry exists. Each failure
        prints a message.
    """
    # Read the csv file
    try:
        df = pd.read_csv(f"{gen_csv_path}.csv")
    except FileNotFoundError:
        print(f"No csv file found at {gen_csv_path}.csv")
        return None

    # Filter the csv data to get the row associated with the name
    row = df[df['name'] == name]
    if row.empty:
        print(f"No data found for name {name} in the csv file {gen_csv_path}.csv")
        return None

    # The first two letters of the name symbolize the element. Guard against
    # names without such a prefix instead of crashing on ``None.group()``.
    prefix = re.match(r"^[A-Za-z]{2}", name)
    if prefix is None:
        print(f"Name {name} does not start with a two-letter element symbol")
        return None
    element = prefix.group()
    rest_of_the_name = name[prefix.end():]

    # Construct the filenames for the xyz data
    filenames = [
        f"{element}_{oxidation}{rest_of_the_name}_{state}.xyz"
        for oxidation in ('II', 'III')
        for state in ('LS', 'HS')
    ]

    graph_list = []
    for filename in filenames:
        filepath = f"{gen_xyz_path}/{filename}"
        if not os.path.exists(filepath):
            print(f"No xyz file found at {filepath}")
            continue

        # Create the MolGraph object and convert it to the NetworkX graph
        mg = MolGraph()
        mg.read_xyz(filepath)
        graph_list.append(to_networkx_graph(mg))

    if not graph_list:
        print(f"No valid graphs found for {name}")
        return None

    # The first graph is the reference: it gets the element properties and
    # collects the coordinates of the remaining graphs as xyz_1, xyz_2, ...
    merged_graph = graph_list[0]
    for node in merged_graph.nodes:
        merged_graph.nodes[node]['element_properties'] = get_element_properties(
            merged_graph.nodes[node]['element']
        )
    for i, G in enumerate(graph_list[1:], start=1):
        for node in G.nodes:
            merged_graph.nodes[node][f'xyz_{i}'] = G.nodes[node]['xyz']

    # Create the PyTorch Geometric Data object
    data = Data()

    # Freeze the node iteration order once: row k of ``x`` belongs to
    # ``node_order[k]``. The graph does not have to iterate its nodes in label
    # order, so labels must never be used as row indices directly.
    node_order = list(merged_graph.nodes)
    row_of = {node: idx for idx, node in enumerate(node_order)}
    reference_attrs = merged_graph.nodes[node_order[0]]

    # Define the desired order of node attributes
    node_attr_order = ['element_properties', 'xyz'] + [f'xyz_{i}' for i in range(1, len(graph_list))]

    # Set the node features
    node_attrs = []
    for attr in node_attr_order:
        if attr not in reference_attrs:
            continue
        values = [merged_graph.nodes[node][attr] for node in node_order]
        if isinstance(reference_attrs[attr], (int, float)):
            # Scalar attribute -> one column
            node_attrs.append(torch.tensor([values], dtype=torch.float).t())
        elif isinstance(reference_attrs[attr], (tuple, list)):
            # Vector attribute -> one column per component
            node_attrs.append(torch.tensor(values, dtype=torch.float))
    data.x = torch.cat(node_attrs, dim=-1)

    # Set the edge index and edge features. Endpoints are translated from node
    # labels to row positions so they line up with ``x``; ``reshape(-1, 2)``
    # keeps the result at shape [2, 0] for a graph without edges.
    # NOTE(review): each edge is stored in one direction only — confirm the
    # downstream model does not expect both directions.
    edges = list(merged_graph.edges)
    data.edge_index = (
        torch.tensor([[row_of[u], row_of[v]] for u, v in edges], dtype=torch.long)
        .reshape(-1, 2)
        .t()
        .contiguous()
    )
    data.edge_attr = torch.tensor(
        [list(merged_graph.edges[edge].values()) for edge in edges], dtype=torch.float
    )

    # Add the last 7 columns of the filtered row to the PyG graph data
    output_data = row.iloc[:, -7:]
    for col in output_data.columns:
        data[col] = output_data[col].values[0]

    return data

In [14]:
def to_pyg_graph_noxyz(name, gen_xyz_path, gen_csv_path):
    """Build a PyG graph with element-property node features only (no coordinates).

    Only the ``{element}_II{rest}_LS.xyz`` geometry in ``gen_xyz_path`` is
    used, and only for its topology and element symbols.

    Args:
        name: Value to look up in the ``name`` column of the csv file.
        gen_xyz_path: Directory that contains the .xyz file.
        gen_csv_path: Path of the csv file WITHOUT the ``.csv`` extension.

    Returns:
        A ``Data`` object with ``x`` (one row of element properties per node),
        ``edge_index`` (each edge stored once) and one attribute per column of
        the last 7 csv columns, or ``None`` when the csv file is missing,
        ``name`` is not in it, ``name`` has no two-letter prefix, or the .xyz
        file does not exist. Each failure prints a message.
    """
    # Read the csv file
    try:
        df = pd.read_csv(f"{gen_csv_path}.csv")
    except FileNotFoundError:
        print(f"No csv file found at {gen_csv_path}.csv")
        return None

    # Filter the csv data to get the row associated with the name
    row = df[df['name'] == name]
    if row.empty:
        print(f"No data found for name {name} in the csv file {gen_csv_path}.csv")
        return None

    # The first two letters of the name symbolize the element. Guard against
    # names without such a prefix instead of crashing on ``None.group()``.
    prefix = re.match(r"^[A-Za-z]{2}", name)
    if prefix is None:
        print(f"Name {name} does not start with a two-letter element symbol")
        return None
    element = prefix.group()
    rest_of_the_name = name[prefix.end():]

    # Construct the filename for the xyz data
    filepath = f"{gen_xyz_path}/{element}_II{rest_of_the_name}_LS.xyz"
    if not os.path.exists(filepath):
        print(f"No xyz file found at {filepath}")
        return None

    # Create the MolGraph object and convert it to the NetworkX graph
    mg = MolGraph()
    mg.read_xyz(filepath)
    G = to_networkx_graph(mg)

    # Get element properties for each node and store them in the node features
    for node in G.nodes:
        G.nodes[node]['element_properties'] = get_element_properties(G.nodes[node]['element'])

    # Freeze the node iteration order once: row k of ``x`` belongs to
    # ``node_order[k]``. The graph does not have to iterate its nodes in label
    # order, so labels must never be used as row indices directly.
    node_order = list(G.nodes)
    row_of = {node: idx for idx, node in enumerate(node_order)}

    # Create the PyTorch Geometric Data object
    data = Data()

    # Set the node features directly from the element properties
    data.x = torch.tensor([G.nodes[node]['element_properties'] for node in node_order], dtype=torch.float)

    # Set the edge index without edge features. Endpoints are translated from
    # node labels to row positions; ``reshape(-1, 2)`` keeps the result at
    # shape [2, 0] for a graph without edges.
    # NOTE(review): each edge is stored in one direction only — confirm the
    # downstream model does not expect both directions.
    data.edge_index = (
        torch.tensor([[row_of[u], row_of[v]] for u, v in G.edges], dtype=torch.long)
        .reshape(-1, 2)
        .t()
        .contiguous()
    )

    # Add the last 7 columns of the filtered row to the PyG graph data
    output_data = row.iloc[:, -7:]
    for col in output_data.columns:
        data[col] = output_data[col].values[0]

    return data

In [15]:
def clean_xyz_files(folder_path):
    """Strip everything from the first '|' onwards on line 2 of each .xyz file.

    Every file in ``folder_path`` whose name ends with ``.xyz`` is inspected.
    If its second line contains a '|', the line is cut just before it and the
    file is rewritten in place. Files with fewer than two lines, or whose
    second line has no '|', are left untouched, so running the function again
    on already-cleaned files changes nothing.

    Args:
        folder_path: Directory to scan (not recursive).
    """
    for filename in os.listdir(folder_path):
        if not filename.endswith('.xyz'):
            continue

        file_path = os.path.join(folder_path, filename)
        with open(file_path, 'r') as file:
            lines = file.readlines()

        # Nothing to strip. Previously such a line still had '\n' appended to
        # text that already ended in '\n', inserting a blank line after the
        # header on every run and shifting the atom records down.
        if len(lines) < 2 or '|' not in lines[1]:
            continue

        # Remove stoichiometry and whatever comes after in the second line
        lines[1] = lines[1].split('|', 1)[0] + '\n'

        # Write the cleaned lines back to the file
        with open(file_path, 'w') as file:
            file.writelines(lines)

# # Example usage: replace the path below with your own folder (files are rewritten in place)
# clean_xyz_files('Data/Cambridge/X1/')

In [16]:
def to_pyg_graph_no_output(filepath):
    """Build a label-free PyG graph from a single .xyz file.

    Args:
        filepath: Path to the .xyz file.

    Returns:
        A ``Data`` object with ``x`` (one row of element properties per node)
        and ``edge_index`` (each edge stored once), or ``None`` (after printing
        a message) when ``filepath`` does not exist.
    """
    # Check if the file exists
    if not os.path.exists(filepath):
        print(f"No xyz file found at {filepath}")
        return None

    # Create the MolGraph object and convert it to the NetworkX graph
    mg = MolGraph()
    mg.read_xyz(filepath)
    G = to_networkx_graph(mg)

    # Get element properties for each node and store them in the node features
    for node in G.nodes:
        G.nodes[node]['element_properties'] = get_element_properties(G.nodes[node]['element'])

    # Freeze the node iteration order once: row k of ``x`` belongs to
    # ``node_order[k]``. The graph does not have to iterate its nodes in label
    # order, so labels must never be used as row indices directly.
    node_order = list(G.nodes)
    row_of = {node: idx for idx, node in enumerate(node_order)}

    # Create the PyTorch Geometric Data object
    data = Data()

    # Set the node features directly from the element properties
    data.x = torch.tensor([G.nodes[node]['element_properties'] for node in node_order], dtype=torch.float)

    # Set the edge index without edge features. Endpoints are translated from
    # node labels to row positions; ``reshape(-1, 2)`` keeps the result at
    # shape [2, 0] for a graph without edges.
    # NOTE(review): each edge is stored in one direction only — confirm the
    # downstream model does not expect both directions.
    data.edge_index = (
        torch.tensor([[row_of[u], row_of[v]] for u, v in G.edges], dtype=torch.long)
        .reshape(-1, 2)
        .t()
        .contiguous()
    )

    return data

In [17]:
# Build a label-free graph for one sample geometry.
abemax = to_pyg_graph_no_output('data/xyz_molsimp/monosubstituted_0001.xyz')

In [20]:
# Inspect the node-feature matrix shape, the feature row at index 1,
# and the edge-index shape.
print(abemax.x.shape)
print(abemax.x[1,:])
print(abemax.edge_index.shape)

torch.Size([35, 4])
tensor([0.0424, 0.0368, 2.0400, 0.8400])
torch.Size([2, 34])
