In [2]:
# Execute the external script test.py through IPython's %run magic.
# NOTE(review): the script's source is not part of this notebook, so what it
# builds or prints cannot be verified from here.
%run test.py

Graph(num_nodes=4, num_edges=3,
      ndata_schemes={}
      edata_schemes={})


In [3]:
import torch
from torch import nn

# Lookup table holding 10 embedding vectors of dimension 3.
embedding = nn.Embedding(10, 3)
# Call the module itself instead of `.forward(...)` so that any registered
# forward hooks are run; indices 3 and 3 map to the same embedding row.
embedding(torch.tensor([0, 1, 3, 3]))


tensor([[-1.2583, -0.0714,  0.4054],
        [-1.8833, -0.2661,  1.0849],
        [-1.7635,  1.0024, -0.5281],
        [-1.7635,  1.0024, -0.5281]], grad_fn=<EmbeddingBackward0>)

In [2]:
import numpy as np
from rdkit import Chem
from rdkit.Chem import rdchem, AllChem

def _generate_conformer(mol: rdchem.Mol, numConfs: int = 10) -> tuple[rdchem.Mol, rdchem.Conformer]:
    """Embed several 3D conformers and return the one with the lowest MMFF energy.

    Explicit hydrogens are added before embedding/optimisation and removed
    again from the returned molecule.

    Args:
        mol (rdchem.Mol): The molecule to embed. It is not modified; a copy
            with explicit hydrogens is used for embedding.
        numConfs (int): Number of conformers requested from the embedder.

    Returns:
        tuple[rdchem.Mol, rdchem.Conformer]: The hydrogen-stripped molecule
        (carrying all embedded conformers) and its lowest-energy conformer.

    Raises:
        ValueError: If no conformer could be embedded for the molecule.
    """
    mol_with_hs = Chem.AddHs(mol)

    # Keep the conformer ids returned by the embedder instead of discarding
    # them: an empty result means embedding failed.
    conf_ids = list(AllChem.EmbedMultipleConfs(mol_with_hs, numConfs=numConfs))
    if not conf_ids:
        raise ValueError('Could not embed any conformer for the given molecule.')

    # MMFF-optimise every embedded conformer; each result entry is a
    # (not_converged, energy) pair, in the same order as the conformers.
    results = AllChem.MMFFOptimizeMoleculeConfs(mol_with_hs)
    energies = [energy for _, energy in results]

    # Map the argmin position back to an actual conformer id rather than
    # assuming ids are always 0..n-1.
    best_conf_id = int(conf_ids[int(np.argmin(energies))])

    new_mol = Chem.RemoveHs(mol_with_hs)
    conf = new_mol.GetConformer(id=best_conf_id)
    return new_mol, conf

# Example: parse a SMILES string and obtain the conformer selected by
# _generate_conformer together with the hydrogen-stripped molecule.
smiles = 'OCc1ccccc1CN'
mol = Chem.MolFromSmiles(smiles)
new_mol, conf = _generate_conformer(mol)

In [50]:
import dgl, torch, numpy as np
from rdkit import Chem
from rdkit.Chem import rdchem, AllChem
from torch import nn, Tensor, IntTensor
from dgl import nn as gnn, DGLGraph
from typing import TypeAlias, Any, no_type_check, Callable, Literal
from enum import Enum
from dataclasses import dataclass


Atom: TypeAlias = rdchem.Atom
Bond: TypeAlias = rdchem.Bond
Mol: TypeAlias = rdchem.Mol
Conformer: TypeAlias = rdchem.Conformer

class Utils:
    """Helpers for converting molecules (SMILES / RDKit) into featurised DGL graphs.

    The class is used as a namespace: it holds the feature definitions
    (`FEATURES`, `FEATURE_NAMES`) and static helper functions.
    """

    RdChemEnum: TypeAlias = Any     # RdChem's enums have no proper typings.

    # Molecules with more atoms than this are skipped by `smiles_to_graph`.
    MAX_NUM_ATOMS: int = 400

    @staticmethod
    def _rdchem_enum_to_list(rdchem_enum: RdChemEnum) -> list[RdChemEnum]:
        """Converts an enum from `rdkit.Chem.rdchem` (eg. `rdchem.ChiralType`)
        to a list of all the possible enum values.

        Args:
            rdchem_enum (RdChemEnum): An enum defined in `rdkit.Chem`.

        Returns:
            list[RdChemEnum]: All possible enum values in a list, ordered by
            their integer value.
        """
        # `values` maps the integer value to the enum member. Iterating over
        # the sorted keys (instead of `range(len(...))`) also works for enums
        # whose integer values are not contiguous from 0.
        return [rdchem_enum.values[key] for key in sorted(rdchem_enum.values)]


    FeatureCategory: TypeAlias = Literal['atom_feats', 'bond_feats']
    FeatureName: TypeAlias = str
    OneHotHeaders: TypeAlias = list[Any]
    # A feature getter takes an atom or a bond and returns the raw feature value.
    GetValueFn: TypeAlias = Callable[[Atom | Bond], Any]

    @dataclass
    class Feature:
        """Definition of a single atom or bond feature.

        Attributes:
            get_value: Gets the feature value from an `rdchem.Atom` /
                `rdchem.Bond` instance.
            onehot_headers: The one-hot headers for the feature.
        """
        get_value: Callable[[Atom | Bond], Any]
        onehot_headers: list[Any]

    # NOTE: `_rdchem_enum_to_list` is called below while the class body is
    # still executing; this relies on staticmethod objects being callable.
    FEATURES: dict[FeatureCategory, dict[FeatureName, Feature]] = {
        'atom_feats': {
            'atomic_num': Feature(
                lambda atom: atom.GetAtomicNum(),
                list(range(1, 119)) + ['misc'],
            ),
            'chiral_tag': Feature(
                lambda atom: atom.GetChiralTag(),
                _rdchem_enum_to_list(rdchem.ChiralType),
            ),
            'degree': Feature(
                lambda atom: atom.GetDegree(),
                [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
            ),
            'formal_charge': Feature(
                lambda atom: atom.GetFormalCharge(),
                [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
            ),
            'hybridization': Feature(
                lambda atom: atom.GetHybridization(),
                _rdchem_enum_to_list(rdchem.HybridizationType),
            ),
            'is_aromatic': Feature(
                lambda atom: int(atom.GetIsAromatic()),
                [0, 1],
            ),
            'total_numHs': Feature(
                lambda atom: atom.GetTotalNumHs(),
                [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
            ),
        },
        'bond_feats': {
            'bond_dir': Feature(
                lambda bond: bond.GetBondDir(),
                _rdchem_enum_to_list(rdchem.BondDir),
            ),
            'bond_type': Feature(
                lambda bond: bond.GetBondType(),
                _rdchem_enum_to_list(rdchem.BondType),
            ),
            'is_in_ring': Feature(
                lambda bond: int(bond.IsInRing()),
                [0, 1],
            ),
        }
    }
    FEATURE_NAMES: dict[FeatureCategory, list[FeatureName]] = {
        'atom_feats': list(FEATURES['atom_feats'].keys()),
        'bond_feats': list(FEATURES['bond_feats'].keys())
    }

    @staticmethod
    def smiles_to_graph(smiles: str) -> DGLGraph | None:
        """Convert a molecule's SMILES string into a DGL graph.

        Every atom becomes a node and every bond becomes a single edge from
        the bond's begin atom to its end atom (the graph is not made
        bidirectional here). The raw value of each feature in `FEATURES` is
        stored under the feature's name in `graph.ndata` / `graph.edata`.

        Args:
            smiles (str): A molecule's SMILES string.

        Returns:
            DGLGraph | None: The molecule in graph form, or `None` if the
            molecule has more than `MAX_NUM_ATOMS` atoms.

        Raises:
            ValueError: If the SMILES string cannot be parsed, or if no
                conformer can be embedded for the molecule.
        """
        mol = AllChem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f'Could not parse SMILES string: {smiles!r}')
        if mol.GetNumAtoms() > Utils.MAX_NUM_ATOMS:
            # Oversized molecules are deliberately skipped.
            return None

        # Only the returned molecule is used for the graph; the selected
        # conformer itself is not needed here.
        mol, _ = Utils._generate_conformer(mol)

        atoms = list(mol.GetAtoms())
        bonds = list(mol.GetBonds())

        # One edge per bond, begin atom -> end atom.
        src = torch.tensor([bond.GetBeginAtomIdx() for bond in bonds], dtype=torch.int32)
        dst = torch.tensor([bond.GetEndAtomIdx() for bond in bonds], dtype=torch.int32)

        # Pass the node count explicitly: otherwise it is inferred from the
        # largest node id in the edge list, which drops atoms that have no
        # bond and makes the node-feature assignment below fail.
        graph = dgl.graph((src, dst), num_nodes=len(atoms), idtype=torch.int32)

        # Add node features.
        for feat_name, feat in Utils.FEATURES['atom_feats'].items():
            graph.ndata[feat_name] = torch.tensor([feat.get_value(atom) for atom in atoms])

        # Add edge features.
        for feat_name, feat in Utils.FEATURES['bond_feats'].items():
            graph.edata[feat_name] = torch.tensor([feat.get_value(bond) for bond in bonds])

        return graph

    @staticmethod
    def _generate_conformer(mol: Mol, numConfs: int = 10) -> tuple[Mol, Conformer]:
        """Embed several 3D conformers and return the one with the lowest MMFF energy.

        Args:
            mol (Mol): The molecule to embed. A copy with explicit hydrogens
                is used for embedding and optimisation.
            numConfs (int): Number of conformers requested from the embedder.

        Returns:
            tuple[Mol, Conformer]: The hydrogen-stripped molecule and its
            lowest-energy conformer.

        Raises:
            ValueError: If no conformer could be embedded for the molecule.
        """
        mol_with_hs = Chem.AddHs(mol)

        # Keep the conformer ids returned by the embedder: an empty result
        # means embedding failed.
        conf_ids = list(AllChem.EmbedMultipleConfs(mol_with_hs, numConfs=numConfs))
        if not conf_ids:
            raise ValueError('Could not embed any conformer for the given molecule.')

        # MMFF-optimise every conformer; each result entry is a
        # (not_converged, energy) pair, in the same order as the conformers.
        results = AllChem.MMFFOptimizeMoleculeConfs(mol_with_hs)
        energies = [energy for _, energy in results]

        # Map the argmin position back to an actual conformer id rather than
        # assuming ids are always 0..n-1.
        best_conf_id = int(conf_ids[int(np.argmin(energies))])

        new_mol = Chem.RemoveHs(mol_with_hs)
        conf = new_mol.GetConformer(id=best_conf_id)
        return new_mol, conf

smiles = 'OCc1ccccc1CN'
mol = AllChem.MolFromSmiles(smiles)
graph = Utils.smiles_to_graph(smiles)
# smiles_to_graph stores node features under the atom feature names
# (e.g. 'atomic_num'); no feature called 'h' is ever written, so the previous
# `graph.srcdata['h']` lookup raised KeyError: 'h'. Show the stored node
# features instead.
graph.ndata

In [25]:
# Build a (2, num_bonds) index tensor: row 0 holds each bond's begin-atom
# index, row 1 the matching end-atom index.
bonds_tensor = torch.tensor(
    [
        [bond.GetBeginAtomIdx() for bond in mol.GetBonds()],
        [bond.GetEndAtomIdx() for bond in mol.GetBonds()],
    ],
    dtype=torch.long,
)

In [32]:
# NOTE(review): in the visible cells `edges` is only assigned as a local
# variable inside Utils.smiles_to_graph; confirm it is defined at top level
# before this cell runs on a fresh kernel.
edges

In [14]:
# Materialise the molecule's atoms as a list for inspection.
list(mol.GetAtoms())

[<rdkit.Chem.rdchem.Atom at 0x7f3f36799bd0>,
 <rdkit.Chem.rdchem.Atom at 0x7f3f36799c40>,
 <rdkit.Chem.rdchem.Atom at 0x7f3f36799cb0>,
 <rdkit.Chem.rdchem.Atom at 0x7f3f36799d20>,
 <rdkit.Chem.rdchem.Atom at 0x7f3f36799d90>,
 <rdkit.Chem.rdchem.Atom at 0x7f3f36799e00>,
 <rdkit.Chem.rdchem.Atom at 0x7f3f36799e70>,
 <rdkit.Chem.rdchem.Atom at 0x7f3f36799ee0>,
 <rdkit.Chem.rdchem.Atom at 0x7f3f36799f50>,
 <rdkit.Chem.rdchem.Atom at 0x7f3f36799fc0>]