In [7]:
from rdkit import Chem
from torch_geometric.data import Data
import torch
from library.functions_to_abstract_data import extract_qm9_data
from torch_geometric.datasets import QM9
from library.GCN import *

In [8]:
def smiles_to_graph(smiles, y_value):
    """Convert a SMILES string into a PyTorch Geometric ``Data`` graph.

    Every atom of the parsed molecule becomes a node whose single feature
    is its atomic number, and every bond becomes a pair of directed edges
    (one per direction) so the graph is effectively undirected.

    Args:
        smiles: SMILES string describing the molecule.
        y_value: Scalar regression target attached to the graph.

    Returns:
        A ``Data`` object with
        ``x`` of shape ``(num_atoms, 1)`` and dtype ``torch.long``,
        ``edge_index`` of shape ``(2, 2 * num_bonds)`` and dtype
        ``torch.long``, and ``y`` of shape ``(1,)`` and dtype ``torch.float``.

    Raises:
        ValueError: If ``smiles`` cannot be parsed into a molecule.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # The parser reports failure by returning None; fail loudly here
        # instead of with an opaque AttributeError a few lines later.
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")

    # Node features: one column holding the atomic number of each atom.
    atomic_numbers = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
    x = torch.tensor(atomic_numbers, dtype=torch.long).unsqueeze(1)

    # Edges: record each bond in both directions.
    pairs = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        pairs.append((i, j))
        pairs.append((j, i))

    if pairs:
        edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()
    else:
        # A molecule without bonds (e.g. a single heavy atom) would
        # otherwise yield a tensor of shape (0,); keep the (2, 0) layout.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    y = torch.tensor([y_value], dtype=torch.float)

    return Data(x=x, edge_index=edge_index, y=y)

In [9]:
# --- Load data ---
# QM9 dataset rooted at ../data/QM9, then converted by the project helper
# (presumably into a DataFrame-like table; verify in
# library.functions_to_abstract_data).
dataset_qm9 = QM9(root="../data/QM9")
df_qm9 = extract_qm9_data(dataset_qm9)

# Columns used by later cells: SMILES strings and the 'gap' values.
smiles, gaps = df_qm9["smiles"], df_qm9["gap"]

# TODO: saving the dataset is not implemented in this cell yet.

In [19]:
# Build the graph for one sample molecule (entry 100 of `smiles`) and show
# its node matrix followed by its adjacency matrix.
graph = Graph(molecule_smiles=smiles[100], node_vec_len=9)
print(graph.node_mat, graph.adj_mat, sep="\n")

[[0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0.]]
[[1.         0.71232943 0.         0.         0.         0.99974586
  0.         0.         0.         0.         0.        ]
 [0.72262436 1.         0.64350064 0.64350064 0.         0.
  0.89333508 0.         0.         0.         0.        ]
 [0.         0.67842605 1.         0.64350064 0.         0.
  0.         0.89333508 0.89333508 0.         0.        ]
 [0.         0.67842605 0.67842605 1.         0.71232943 0.
  0.         0.         0.         0.89333508 0.        ]
 [0.         0.         0.         0.72262436 1.         0.
  0.         0.         0.         0.         0.99974586]
 [1.02014355 0.         0.         0.         0.         1.
  0.       