In [1]:

import os
from pathlib import Path

import torch
from rdkit import Chem
from torch_geometric.data import Data

# Project root. Defaults to the original author's checkout; set the
# GRAPH_HDC_ROOT environment variable to run this notebook on another machine
# without editing the cell.
root: Path = Path(
    os.environ.get("GRAPH_HDC_ROOT", "/Users/arvandkaveh/Projects/kit/graph_hdc")
)
# Parent directory for the PyG dataset roots created further down.
DATASET_PATH = root / "_datasets"
# Directory holding the plain-text SMILES split files.
smiles_dir: Path = root / "Smiles/ZINC_smile"
# The three split files, in train / test / valid order.
files: list[Path] = [
    smiles_dir / "train_smile.txt",
    smiles_dir / "test_smile.txt",
    smiles_dir / "valid_smile.txt",
]


def iter_smiles(file: Path):
    """Yield the SMILES string from every non-blank line of ``file``.

    Only the first whitespace-separated token of a line is yielded, so extra
    columns after the SMILES string are ignored. If the very first line
    consists solely of the word ``smiles`` (case-insensitive) it is treated as
    a header and skipped. Blank and whitespace-only lines are skipped.

    Parameters
    ----------
    file : Path
        Text file with one SMILES entry per line.

    Yields
    ------
    str
        The SMILES token of each data line, in file order.
    """
    with open(file) as fh:
        for line_no, line in enumerate(fh):
            tokens = line.split()
            if not tokens:
                # Blank line: indexing tokens[0] here would raise IndexError.
                continue
            if line_no == 0 and line.strip().lower() == "smiles":
                continue  # optional header row
            yield tokens[0]

In [3]:
from rdkit.Chem import Atom, ValenceType

# Element symbol -> integer index, numbered in the order the symbols are listed.
atom_to_idx: dict[str, int] = {
    symbol: index
    for index, symbol in enumerate(("Br", "C", "Cl", "F", "I", "N", "O", "P", "S"))
}

# Reverse lookup: integer index -> element symbol.
idx_to_atom: dict[int, str] = {index: symbol for symbol, index in atom_to_idx.items()}

def atom_key(atom: Atom):
    """
    Collect the descriptors that make up an atom's categorical 'atom type'.

    The intent is to approximate the atom typing used in the PyG ZINC pickles.
    The result is a dict with these entries, in this order:

    - ``symbol``: element symbol
    - ``degree``: atom degree
    - ``valence``: explicit valence
    - ``charge``: formal charge
    - ``num_hydrogens``: total number of hydrogens
    - ``is_aromatic``: aromaticity flag
    - ``is_part_of_ring``: ring-membership flag

    Note from the original author: this is not fully aligned with ZINC yet,
    since it generates 45 categories whereas ZINC250 has 28 distinct atom types.
    """
    descriptors = dict(
        symbol=atom.GetSymbol(),
        degree=atom.GetDegree(),
        valence=atom.GetValence(ValenceType.EXPLICIT),
        charge=atom.GetFormalCharge(),
        num_hydrogens=atom.GetTotalNumHs(),
        is_aromatic=atom.GetIsAromatic(),
        is_part_of_ring=atom.IsInRing(),
    )
    return descriptors

In [29]:
def mol_to_data(mol: Chem.rdchem.Mol, atom2idx: dict[str, int] = atom_to_idx) -> Data:
    """
    Convert an RDKit molecule into a PyG ``Data`` graph.

    Node features (one row per atom, stored as float32) are, in column order:
    the element index looked up in ``atom2idx``, the atom degree, the formal
    charge and the total hydrogen count. The original annotations quote 9, 5,
    3 and 4 categories for these columns respectively.

    Every bond contributes two directed edges (i -> j and j -> i) to
    ``edge_index``. The SMILES string produced by ``Chem.MolToSmiles`` is
    attached as ``smiles``.

    Parameters
    ----------
    mol : Chem.rdchem.Mol
        Molecule to encode.
    atom2idx : dict[str, int]
        Element symbol -> integer index mapping; a symbol missing from it
        raises ``KeyError``.

    Returns
    -------
    Data
        Graph with ``x``, ``edge_index`` and ``smiles`` attributes.
    """
    node_features: list[list[float]] = []
    for atom in mol.GetAtoms():
        node_features.append(
            [
                float(atom2idx[atom.GetSymbol()]),
                float(atom.GetDegree()),
                float(atom.GetFormalCharge()),
                float(atom.GetTotalNumHs()),
            ]
        )

    sources: list[int] = []
    targets: list[int] = []
    for bond in mol.GetBonds():
        begin = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        # Store both directions so the graph is effectively undirected.
        sources.extend((begin, end))
        targets.extend((end, begin))

    x = torch.tensor(node_features, dtype=torch.float32)
    edge_index = torch.tensor([sources, targets], dtype=torch.long)
    return Data(x=x, edge_index=edge_index, smiles=Chem.MolToSmiles(mol))

In [30]:
## Debugging
# Smoke test of the SMILES -> Data pipeline on the debug file: read the strings,
# parse each one with RDKit, convert every molecule and print the first graph.
debug_smiles = list(iter_smiles(smiles_dir / "debug_smile.txt"))
# NOTE(review): Chem.MolFromSmiles can return None for unparsable input (the
# dataset class checks for this); a None here would make mol_to_data fail --
# presumably the debug file only contains valid SMILES, confirm.
debug_mols = [Chem.MolFromSmiles(s) for s in debug_smiles]
debug_set = [mol_to_data(m, atom_to_idx) for m in debug_mols]
# Inspect the first converted molecule.
data = debug_set[0]
print(data)

Data(x=[32, 4], edge_index=[2, 68], smiles='C/C=C(\C)[C@@H]1C=C[C@@H]2C[C@H](C)C[C@H](C)[C@@H]2[C@@H]1C(=O)C1=C([O-])[C@H](C[C@](C)(O)C(=O)[O-])NC1=O')


In [37]:
from typing import Optional, Callable
from torch.serialization import clear_safe_globals, add_safe_globals
from torch_geometric.data import InMemoryDataset
# Reset torch's list of globals allow-listed for `weights_only` unpickling.
# NOTE(review): `add_safe_globals` is imported above but never called, and the
# dataset below loads with `weights_only=False` -- presumably this call is a
# leftover from experimenting with safe loading; confirm it is still needed.
clear_safe_globals()
class ZincSmiles(InMemoryDataset):
    """
    A lightweight `InMemoryDataset` wrapping lists of SMILES strings stored in
    plain-text files (one per line).

    Every parsable SMILES string of the split file is converted with
    ``mol_to_data``; the collated result is saved as ``data_<split>.pt`` and
    loaded back into memory on construction.

    Parameters
    ----------
    root : str | Path
        Root directory that contains ``raw/`` and ``processed/`` sub-dirs.
    split : {'train', 'valid', 'test'}
        Which split file to load (``<split>_smile.txt`` is expected in ``root/raw``).
        Matching is case-insensitive.
    transform, pre_transform, pre_filter : callable | None
        See PyG `Dataset` API – they are executed on-the-fly (``transform``) or once
        at processing time (``pre_*``).

    Raises
    ------
    AssertionError
        If ``split`` is not one of the supported split names.
    FileNotFoundError
        If processing is required but the raw split file does not exist.
    """

    def __init__(
        self,
        root: str | Path,
        split: str = "train",
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
    ):
        # `split` must be set before super().__init__ because the file-name
        # properties below depend on it.
        self.split = split.lower()
        valid_splits = {"train", "valid", "test"}
        assert self.split in valid_splits, (
            f"split must be one of {sorted(valid_splits)}, got {split!r}"
        )
        super().__init__(root, transform, pre_transform, pre_filter)
        # SECURITY: weights_only=False performs a full pickle load, which can
        # execute arbitrary code. Only load processed files that this class
        # wrote itself; never point `root` at untrusted data.
        self.data, self.slices = torch.load(self.processed_paths[0], weights_only=False)

    # ── filenames -----------------------------------------------------------------
    @property
    def raw_file_names(self):        # → list[str]
        """Name of the raw SMILES file for the selected split."""
        return [f"{self.split}_smile.txt"]

    @property
    def processed_file_names(self):  # → list[str]
        """Name of the cached, collated file for the selected split."""
        return [f"data_{self.split}.pt"]

    # ── I/O stubs (no download, local files only) ---------------------------------
    def download(self):  # noqa: D401  (nothing to download)
        """Intentionally a no-op: the raw files must already be present locally."""
        pass

    # ── heavy lifting --------------------------------------------------------------
    def process(self):
        """
        Parse the raw SMILES file, build one ``Data`` object per molecule, then
        collate and save them to the processed path.

        Unparsable SMILES strings are skipped. ``pre_filter`` is applied before
        ``pre_transform``.

        Raises
        ------
        FileNotFoundError
            If the raw split file is missing (``download`` cannot fetch it).
        """
        raw_path = Path(self.raw_paths[0])
        if not raw_path.exists():
            # Fail with an actionable message instead of the bare open() error
            # that would otherwise surface from inside iter_smiles.
            raise FileNotFoundError(
                f"Raw SMILES file not found: {raw_path}. "
                "ZincSmiles does not download data; place the file there manually."
            )

        data_list: list[Data] = []
        for s in iter_smiles(raw_path):
            mol = Chem.MolFromSmiles(s)
            if mol is None:
                continue  # skip unparsable strings
            data = mol_to_data(mol)
            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            data_list.append(data)

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

In [38]:
# Build the validation split. Per the class's file-name properties this reads
# `valid_smile.txt` and caches the result as `data_valid.pt` under the given root.
ds = ZincSmiles(root=DATASET_PATH / "ZincSmiles", split="valid")