In [1]:
from dataclasses import dataclass, field
from typing import List
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem import Draw
IPythonConsole.ipython_useSVG=True
from rdkit.Chem import rdmolfiles, rdmolops, AllChem
from rdkit.Chem.Scaffolds import MurckoScaffold
from collections import defaultdict
import itertools
import numpy as np
import dgl.backend as F
from enum import Enum, unique
from rdkit.Chem import BondType
import dgl
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from cl_featuriser import AtomFeaturiser, BondFeaturiser

Using backend: pytorch


In [5]:
def construct_bigraph_from_mol(mol):
    """Build a bidirected DGL graph mirroring the bond topology of an RDKit molecule.

    Node ``k`` of the graph corresponds to RDKit atom index ``k``. Each bond
    ``u-v`` (visited in bond-index order) contributes two directed edges,
    ``u->v`` followed by ``v->u``, so edges ``2k`` and ``2k+1`` both come from
    bond ``k``.

    Args:
        mol: RDKit molecule.

    Returns:
        A DGL graph with ``idtype=torch.int32`` and no node/edge features.
    """
    graph = dgl.graph(([], []), idtype=torch.int32)
    graph.add_nodes(mol.GetNumAtoms())

    begin_atoms, end_atoms = [], []
    for bond in (mol.GetBondWithIdx(k) for k in range(mol.GetNumBonds())):
        a, b = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        # Store the bond in both directions so message passing is symmetric.
        begin_atoms += (a, b)
        end_atoms += (b, a)

    graph.add_edges(torch.IntTensor(begin_atoms), torch.IntTensor(end_atoms))
    return graph

def _scaffold_first_order(num_atoms, scaffold_match):
    """Return an atom permutation that puts the matched scaffold atoms first.

    ``scaffold_match[j]`` is the molecule index of scaffold atom ``j``. Those
    indices come first (in scaffold order), followed by every remaining atom
    in ascending index order. The result is a ``newOrder`` for
    ``Chem.RenumberAtoms``.
    """
    # Set membership keeps this O(n) instead of rescanning a list per atom.
    in_scaffold = set(scaffold_match)
    return list(scaffold_match) + [i for i in range(num_atoms) if i not in in_scaffold]


def _build_actions(g, num_core_atoms, num_atoms):
    """Describe how to grow graph ``g`` atom-by-atom starting from its scaffold.

    Atoms ``0 .. num_core_atoms - 1`` of ``g`` are the scaffold. Each later
    atom ``i`` yields ``[atom feature, bond features, earlier neighbours]``,
    keeping only bonds to atoms with index ``< i``. An atom with no such bond
    yields ``[atom feature, tensor([bond_featuriser.max_bond_type()]), -1]``.
    A final ``[tensor([atom_featuriser.max_atom_type()]), -1, -1]`` entry
    terminates the list.

    Reads the module-level ``atom_featuriser`` / ``bond_featuriser`` objects
    and the ``"atom_type"`` / ``"bond_type"`` entries of ``g``'s node/edge data.
    """
    actions = []
    src, dest = g.edges()
    for i in range(num_core_atoms, num_atoms):
        node_feat = g.ndata["atom_type"][i].unsqueeze(0)
        # Ids of the directed edges leaving atom i (each bond is stored both ways).
        out_edge_ids = torch.nonzero(src == i, as_tuple=False).flatten()
        neighbours = dest[out_edge_ids]
        # Keep only bonds pointing back to atoms already placed (lower index).
        earlier = torch.nonzero(neighbours < i, as_tuple=False).flatten()
        neighbours = neighbours[earlier]
        edge_feat = g.edata["bond_type"][out_edge_ids][earlier]
        # [atom type, edge type, destination]
        if len(neighbours) > 0:
            actions.append([node_feat, edge_feat, neighbours])
        else:
            actions.append([
                node_feat,
                torch.tensor([bond_featuriser.max_bond_type()]),
                -1,
            ])
    # Terminating action.
    actions.append([
        torch.tensor([atom_featuriser.max_atom_type()]),
        -1,
        -1,
    ])
    return actions


def mol_to_graph(mol, canonical_atom_order=False):
    """Convert an RDKit molecule into a molecule graph, a scaffold graph and growth actions.

    The molecule is kekulized and its Murcko scaffold extracted
    (``MurckoScaffold.GetScaffoldForMol``). Atoms are then renumbered so the
    scaffold atoms come first, in the scaffold's own atom order, followed by
    the remaining atoms.

    Args:
        mol: RDKit molecule, or None. The input molecule is not modified.
        canonical_atom_order: if True, first renumber atoms into RDKit
            canonical-rank order (this affects which scaffold match is used
            and the order of the non-scaffold atoms).

    Returns:
        ``(g, g_scaffold, actions)``:

        * ``g``: the bidirected graph of the renumbered molecule.
        * ``g_scaffold``: the graph of the scaffold.
        * ``actions``: the list built by ``_build_actions``.

        Both graphs carry node/edge data from the module-level
        ``atom_featuriser`` / ``bond_featuriser``. Returns ``None`` (after
        printing a message) when ``mol`` is None.

    Raises:
        ValueError: if the scaffold has no substructure match in the molecule,
            e.g. an acyclic molecule whose Murcko scaffold is empty.
    """
    if mol is None:
        print('Invalid mol found')
        return None

    # Work on a copy: Kekulize below mutates the molecule in place.
    mol = Chem.Mol(mol)

    if canonical_atom_order:
        # CanonicalRankAtoms returns each atom's rank, whereas RenumberAtoms
        # wants the old atom index for each new position, i.e. argsort(ranks).
        ranks = list(rdmolfiles.CanonicalRankAtoms(mol))
        new_order = sorted(range(len(ranks)), key=ranks.__getitem__)
        mol = rdmolops.RenumberAtoms(mol, new_order)

    Chem.rdmolops.Kekulize(mol)
    core = MurckoScaffold.GetScaffoldForMol(mol)

    matches = mol.GetSubstructMatches(core)
    if not matches:
        raise ValueError(
            f"Murcko scaffold ({core.GetNumAtoms()} atoms) has no substructure "
            f"match in molecule ({mol.GetNumAtoms()} atoms)"
        )
    num_atoms = mol.GetNumAtoms()
    # After this, atom j < core.GetNumAtoms() of mol is scaffold atom j.
    mol = Chem.RenumberAtoms(mol, tuple(_scaffold_first_order(num_atoms, matches[0])))

    g = construct_bigraph_from_mol(mol)
    g_scaffold = construct_bigraph_from_mol(core)

    g.ndata.update(atom_featuriser(mol))
    g_scaffold.ndata.update(atom_featuriser(core))

    g.edata.update(bond_featuriser(mol))
    g_scaffold.edata.update(bond_featuriser(core))

    actions = _build_actions(g, core.GetNumAtoms(), num_atoms)
    return g, g_scaffold, actions

# Convert every molecule from START_INDEX onwards into (molecule graph, scaffold graph).
# On the first failure, report the error plus the scaffold substructure matches and
# stop, so the offending molecule can be inspected.
DATA_PATH = "/BiO/pekim/GRAPHNET/data/normalised_filtered_data.csv"
START_INDEX = 26001  # first row processed; rows 0..26000 are skipped

df = pd.read_csv(DATA_PATH)
gl, sl = [], []
atom_featuriser, bond_featuriser = AtomFeaturiser(), BondFeaturiser()
for index in range(START_INDEX, df.shape[0]):
    single_smi = df.loc[index].smiles
    try:
        mol = Chem.MolFromSmiles(single_smi)
        g, g_scaffold, action = mol_to_graph(mol)
    except Exception as exc:  # not bare: KeyboardInterrupt/SystemExit must still propagate
        print(f"Failed on row {index} ({single_smi!r}): {exc!r}")
        mol = Chem.MolFromSmiles(single_smi)
        # An unparsable SMILES gives None; Kekulize(None) would hide the real error.
        if mol is not None:
            Chem.rdmolops.Kekulize(mol)
            core = MurckoScaffold.GetScaffoldForMol(mol)
            sub_order = list(mol.GetSubstructMatches(core))
            print(sub_order)
        break
    gl.append(g)
    sl.append(g_scaffold)
            






[(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17)]
