In [3]:
import pandas as pd
import torch
import os
from rdkit import Chem
from rdkit.Chem import rdchem
from tqdm import tqdm
from torch_geometric.data import Data

In [6]:
# Define atom features
def atom_features(atom):
    """Encode a single atom as a 1-D float tensor of six descriptors.

    The entries are, in order: atomic number, formal charge, aromaticity
    flag (0/1), hybridization enum value, degree, and total hydrogen count.

    Args:
        atom: An object exposing the RDKit atom getters used below.

    Returns:
        torch.Tensor: shape ``[6]``, dtype ``torch.float``.
    """
    atomic_number = atom.GetAtomicNum()
    formal_charge = atom.GetFormalCharge()
    aromatic_flag = int(atom.GetIsAromatic())
    # `.real` turns the hybridization enum member into its plain integer value.
    hybridization_code = atom.GetHybridization().real
    degree = atom.GetDegree()
    hydrogen_count = atom.GetTotalNumHs()

    descriptors = [
        atomic_number,
        formal_charge,
        aromatic_flag,
        hybridization_code,
        degree,
        hydrogen_count,
    ]
    return torch.tensor(descriptors, dtype=torch.float)

# Define bond features
def bond_features(bond):
    """Encode a single bond as a 1-D float tensor of three descriptors.

    The entries are, in order: bond order as a float, conjugation flag (0/1),
    and ring-membership flag (0/1).

    The bond order is kept as a float on purpose: truncating it to an integer
    would map an aromatic bond (1.5) to 1 and make it indistinguishable from
    a single bond.

    Args:
        bond: An object exposing the RDKit bond getters used below.

    Returns:
        torch.Tensor: shape ``[3]``, dtype ``torch.float``.
    """
    return torch.tensor([
        float(bond.GetBondTypeAsDouble()),  # Single=1.0, Aromatic=1.5, Double=2.0, Triple=3.0
        int(bond.GetIsConjugated()),
        int(bond.IsInRing())
    ], dtype=torch.float)

# Convert SMILES to PyG graph
def smiles_to_graph(smiles, mol_id):
    """Convert a SMILES string into a PyTorch Geometric ``Data`` graph.

    Nodes are atoms (featurized by ``atom_features``) and every bond becomes
    two directed edges (i -> j and j -> i) sharing the same ``bond_features``.

    Args:
        smiles: SMILES string of the molecule.
        mol_id: Identifier stored on the resulting graph as ``mol_id``.

    Returns:
        A ``Data`` object with ``x``, ``edge_index``, ``edge_attr``, ``smiles``
        and ``mol_id``, or ``None`` when ``smiles`` is not a string, cannot be
        parsed, or describes a molecule with no atoms.
    """
    # Missing CSV cells arrive from pandas as NaN (a float); RDKit only accepts strings.
    if not isinstance(smiles, str):
        return None

    mol = Chem.MolFromSmiles(smiles)
    # RDKit can hand back a molecule with zero atoms (e.g. for an empty string);
    # stacking an empty feature list below would raise, so treat it as unparseable.
    if mol is None or mol.GetNumAtoms() == 0:
        return None

    # Node features
    x = torch.stack([atom_features(atom) for atom in mol.GetAtoms()])

    # Edge indices and features
    edge_index = []
    edge_attr = []

    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        feat = bond_features(bond)

        # Undirected edge (i <-> j)
        edge_index += [[i, j], [j, i]]
        edge_attr += [feat, feat]

    if edge_index:
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        edge_attr = torch.stack(edge_attr)
    else:
        # Bond-free molecules (single atoms / ions): torch.stack([]) raises and
        # an empty tensor would have shape [0] rather than the [2, 0] layout
        # used for edge_index, so build the empty tensors explicitly.
        # The width 3 matches the per-bond feature vector from bond_features.
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, 3), dtype=torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=smiles, mol_id=mol_id)

# Main function
def generate_graphs(smiles_csv="data/step2_kinase_inhibitors_smiles.csv", output_dir="data/graphs/"):
    """Build one graph file per molecule listed in a SMILES CSV.

    Reads ``smiles_csv`` (columns ``canonical_smiles`` and
    ``molecule_chembl_id``), converts every row with ``smiles_to_graph`` and
    saves each successful graph to ``<output_dir>/<molecule_chembl_id>.pt``.
    Rows whose SMILES is missing or cannot be converted are skipped and
    counted in the final summary.

    Args:
        smiles_csv: Path of the input CSV file.
        output_dir: Directory that receives the ``.pt`` files; created if absent.

    Returns:
        None. Progress and a summary are printed.
    """
    os.makedirs(output_dir, exist_ok=True)

    df = pd.read_csv(smiles_csv)
    saved = 0

    # Iterate the two needed columns directly instead of building a Series per row.
    records = zip(df["canonical_smiles"], df["molecule_chembl_id"])
    for smiles, mol_id in tqdm(records, total=len(df), desc="Creating graphs"):
        # An empty CSV cell is read as NaN; skip it rather than abort the whole run.
        if pd.isna(smiles):
            continue

        data = smiles_to_graph(smiles, mol_id)

        # Compare against the None sentinel explicitly; a truthiness test would
        # depend on how the Data class defines its own length/boolean value.
        if data is not None:
            torch.save(data, os.path.join(output_dir, f"{mol_id}.pt"))
            saved += 1

    print(f"✓ Saved {saved} molecular graphs to {output_dir}")

    skipped = len(df) - saved
    if skipped:
        print(f"Skipped {skipped} rows with missing or unparseable SMILES")

# Entry point: build the graph files using the default CSV path and output
# directory of generate_graphs. The guard passed when this cell was executed,
# as the progress bar and summary in the cell output show.
if __name__ == "__main__":
    generate_graphs()


Creating graphs:   0%|          | 0/10584 [00:00<?, ?it/s]

Creating graphs: 100%|██████████| 10584/10584 [01:24<00:00, 125.68it/s]

✓ Saved 10584 molecular graphs to data/graphs/





In [8]:
# Sanity check: reload one saved graph and print its summary.
# weights_only=False makes torch.load perform full pickle unpickling, which can
# execute arbitrary code -- only use it on trusted files such as the ones
# written by generate_graphs; never on files from an untrusted source.
# NOTE(review): presumably needed because the file stores a whole PyG Data
# object rather than bare tensors -- confirm.
data = torch.load("data/graphs/CHEMBL10.pt", weights_only=False)
print(data)

Data(x=[27, 6], edge_index=[2, 60], edge_attr=[60, 3], smiles='C[S+]([O-])c1ccc(-c2nc(-c3ccc(F)cc3)c(-c3ccncc3)[nH]2)cc1', mol_id='CHEMBL10')
