In [None]:
import os
import sys
from collections import Counter

import pandas as pd

sys.path.append('..')

from create_pyg_dataset import AdsorptionGraphDataset
from gnn_eads.graph_tools import graph_plotter

In [None]:
# Location of the ASE database. Defaults to the original author's path;
# set the FG_DATASET_DB environment variable to point to a local copy instead
# of editing this user-specific absolute path.
ASE_DB_PATH = os.environ.get(
    "FG_DATASET_DB",
    "/home/smorandi/Desktop/gnn_eads/data/FG_dataset/FG_DATASET.db",
)
STRUCTURE_DICT = {"tolerance": 0.5, "scaling_factor": 1.2, "second_order_nn": False}
FEATURES_DICT = {"adsorbate": False, "ring": False, "aromatic": False, "radical": False, "facet": False}
GRAPH_PARAMS = {"structure": STRUCTURE_DICT, "features": FEATURES_DICT, "target": "scaled_e_ads"}

# The third argument looks like an ASE-db style row selection (key=value,...);
# NOTE(review): confirm its exact semantics in AdsorptionGraphDataset.
FG_dataset = AdsorptionGraphDataset(ASE_DB_PATH, GRAPH_PARAMS, "calc_type=adsorption,family=group2")
FG_dataset.print_summary()

In [None]:
print(FG_dataset)

In [None]:
print(FG_dataset.data)

In [None]:
print(FG_dataset.counter_isomorphism, FG_dataset.counter_H_filter, FG_dataset.counter_C_filter, FG_dataset.counter_fragment_filter, FG_dataset.counter_adsorption_filter, len(FG_dataset), FG_dataset.database_size)
        

In [None]:
FG_dataset.node_dim

In [None]:
random_graph = FG_dataset[59]
print(random_graph)

In [None]:
random_graph.x

In [None]:
# `.numpy()` returns a plain ndarray, which has no `.columns`; wrap the node
# feature matrix in a DataFrame so every column carries its feature name.
df = pd.DataFrame(
    random_graph.x.detach().numpy(),
    columns=FG_dataset.node_feature_list,
)
df.head()

In [None]:
import pprint
pprint.pprint(FG_dataset.node_feature_list)
# Spot-check a few individual node-feature columns of the sample graph.
for column_index in (2, 11, -1):
    pprint.pprint(random_graph.x[:, column_index])

In [None]:
import pandas as pd
random_graph = FG_dataset[3000]
df = pd.DataFrame(random_graph.x.numpy())
df.columns = FG_dataset.node_feature_list
print(df)

In [None]:
print(random_graph)

In [None]:
random_graph = FG_dataset[4000]
graph_plotter(random_graph)

In [None]:
dir(FG_dataset)

In [None]:
print(FG_dataset.dataset_id)

In [None]:
database_size = FG_dataset.database_size
graph_dataset_size = len(FG_dataset)
bin_C_filter = FG_dataset.counter_C_filter
bin_H_filter = FG_dataset.counter_H_filter
bin_fragment_filter = FG_dataset.counter_fragment_filter
bin_adsorption_filter = FG_dataset.counter_adsorption_filter
bin_isomorphism = FG_dataset.counter_isomorphism

print("ASE database size: ", database_size)
print("Graph dataset size: ", graph_dataset_size)
print("C filter: ", bin_C_filter)
print("H filter: ", bin_H_filter)
print("Fragment filter: ", bin_fragment_filter)
print("Adsorption filter: ", bin_adsorption_filter)
print("Isomorphism: ", bin_isomorphism)


# Representation study

In [None]:
# Count adsorbate vs. catalyst nodes over the whole dataset, one vectorised
# pass per graph instead of a Python loop over every node row.
# The last node-feature column is treated as a 0/1 flag (1 = adsorbate,
# 0 = catalyst); any other value is rejected.
# NOTE(review): FEATURES_DICT sets "adsorbate": False — confirm the last
# column really is the adsorbate flag for this configuration.
adsorbate_nodes = 0
catalyst_nodes = 0
for graph in FG_dataset:
    node_flags = graph.x[:, -1]
    is_adsorbate = node_flags == 1
    is_catalyst = node_flags == 0
    if not (is_adsorbate | is_catalyst).all():
        raise ValueError("Node type not recognized")
    adsorbate_nodes += int(is_adsorbate.sum())
    catalyst_nodes += int(is_catalyst.sum())
print("Adsorbate nodes: ", adsorbate_nodes)
print("Catalyst nodes: ", catalyst_nodes)

In [None]:
# Count graphs per surface facet label; any other label is an error.
facet_labels = ("fcc(111)", "hcp(0001)", "bcc(110)")
facet_totals = [0, 0, 0]
for graph in FG_dataset:
    for position, label in enumerate(facet_labels):
        if graph.facet == label:
            facet_totals[position] += 1
            break
    else:
        raise ValueError("Facet not recognized")
facet_111, facet_0001, facet_110 = facet_totals

print("Facet 111: ", facet_111)
print("Facet 0001: ", facet_0001)
print("Facet 110: ", facet_110)

In [None]:
import matplotlib.pyplot as plt

# Distribution of metals across the graph dataset (one entry per graph).
metals = [graph.metal for graph in FG_dataset]

# Number of graphs per metal; Counter keeps first-seen key order like a dict.
metal_dict = Counter(metals)

# Bar chart (not a boxplot) of the number of graphs per metal.
fig, ax = plt.subplots()
ax.bar(list(metal_dict.keys()), list(metal_dict.values()))
ax.set(xlabel="Metal", ylabel="Number of graphs", title="Metal distribution in the FG dataset");

In [None]:
from ase.db import connect
from rdkit import Chem

# Element symbols treated as metals in this notebook.
metals = "Ag Au Cd Co Cu Fe Ir Ni Os Pd Pt Rh Ru Zn".split()

# Open the ASE database and pull one sample structure by row id.
db = connect(ASE_DB_PATH)
atoms_obj = db.get_atoms(id=1000)
atoms_obj

In [None]:
from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np
import copy
from ase.atoms import Atoms
from rdkit.Chem import rdDetermineBonds

# Element symbols treated as metals in this notebook.
metals = [
    "Ag", "Au", "Cd", "Co", "Cu", "Fe", "Ir",
    "Ni", "Os", "Pd", "Pt", "Rh", "Ru", "Zn",
]

def get_aromatic_atoms(atoms_obj: Atoms,
                       molecule_elements: list[str]) -> list[int]:
    """
    Get the aromatic atoms of the molecule contained in an atoms object.

    Atoms whose chemical symbol is in ``molecule_elements`` are copied, in
    their original order and with the same cell and PBC, into a new Atoms
    object. RDKit then perceives bonds from their Cartesian coordinates, and
    the indices of the atoms it flags as aromatic are returned.

    Args:
        atoms_obj: ASE atoms object
        molecule_elements: chemical symbols of the atoms to keep; every atom
            whose symbol is not listed is discarded before bond perception.

    Returns:
        aromatic_atoms: list of aromatic atom indices. The indices refer to
            the filtered molecule (kept atoms renumbered from 0), not to
            ``atoms_obj``. Empty if no atom matches ``molecule_elements``.

    Raises:
        ValueError: if RDKit cannot build a molecule from the filtered atoms.
    """
    molecule_atoms_obj = Atoms()
    molecule_atoms_obj.set_cell(atoms_obj.get_cell())
    molecule_atoms_obj.set_pbc(atoms_obj.get_pbc())
    for atom in atoms_obj:
        if atom.symbol in molecule_elements:
            molecule_atoms_obj.append(atom)
    if len(molecule_atoms_obj) == 0:
        # Nothing matched molecule_elements: no molecule, so no aromatic atoms.
        return []

    atomic_symbols = molecule_atoms_obj.get_chemical_symbols()
    coordinates = molecule_atoms_obj.get_positions()
    atom_lines = '\n'.join(f'{symbol} {x} {y} {z}' for symbol, (x, y, z) in zip(atomic_symbols, coordinates))
    # XYZ layout: atom count, an (empty) comment line, then one line per atom.
    xyz = "{}\n\n{}".format(len(molecule_atoms_obj), atom_lines)
    rdkit_mol = Chem.MolFromXYZBlock(xyz)
    if rdkit_mol is None:
        # MolFromXYZBlock reports a parse failure by returning None.
        raise ValueError("RDKit could not parse the XYZ block built from atoms_obj")
    # The XYZ block carries no bonds; DetermineBonds perceives them in place,
    # which is required before aromaticity flags mean anything.
    rdDetermineBonds.DetermineBonds(rdkit_mol)
    return [atom.GetIdx() for atom in rdkit_mol.GetAtoms() if atom.GetIsAromatic()]

In [None]:
get_aromatic_atoms(atoms_obj, metals)

In [None]:
atoms_obj

In [None]:
print(aromatic_atoms)

In [None]:
# Remove all metal atoms from an RDKit molecule built from atoms_obj.
# `rdkit_mol` only exists inside get_aromatic_atoms, so it is rebuilt here
# (coordinates only, no bond perception).
xyz_lines = [
    f"{symbol} {x} {y} {z}"
    for symbol, (x, y, z) in zip(atoms_obj.get_chemical_symbols(), atoms_obj.get_positions())
]
rdkit_mol = Chem.MolFromXYZBlock("{}\n\n{}".format(len(xyz_lines), "\n".join(xyz_lines)))
if rdkit_mol is None:
    raise ValueError("RDKit could not parse the XYZ block built from atoms_obj")

# A plain Mol has no RemoveAtom: edit an RWMol instead, and delete from the
# highest index down so earlier removals do not shift the pending indices.
editable_mol = Chem.RWMol(rdkit_mol)
metal_indices = [atom.GetIdx() for atom in editable_mol.GetAtoms() if atom.GetSymbol() in metals]
for atom_index in sorted(metal_indices, reverse=True):
    editable_mol.RemoveAtom(atom_index)
rdkit_mol = editable_mol.GetMol()