In [2]:
import numpy as np
import pandas as pd
import polars as pl
import pickle
from pathlib import Path
from rdkit import Chem
from rdkit.Chem import Draw


import sys
sys.path.append('../')
from src.setup.setup_data import setup_train_x_data, setup_train_y_data, sample_data, setup_inference_data
from src.typing.xdata import XData, DataRetrieval

# Common Variables

In [3]:
DATA_PATH = Path('../data/')
SHRUNKEN_PATH = DATA_PATH / 'shrunken/'

# check if global variables are defined
if '_x_data' not in globals():
    _x_data = None

if '_y_data' not in globals():
    _y_data = None

if '_train_data' not in globals():
    _train_data = None

if '_test_data' not in globals():
    _test_data = None

# Dataset Utils

In [4]:
def get_train_x_data() -> XData:
    global _x_data

    if _x_data is not None:
        return _x_data
    
    print('Loading x_data from disk...')
    _x_data = setup_train_x_data(SHRUNKEN_PATH, get_train_data())

    return _x_data

def get_train_y_data():
    global _y_data

    if _y_data is not None:
        return _y_data

    print('Loading y_data from disk...')
    _y_data = setup_train_y_data(get_train_data())

    return _y_data

def get_test_x_data():
    global _test_data

    if _test_data is not None:
        return _test_data

    inference_data = pl.read_parquet(SHRUNKEN_PATH / "test.parquet")
    inference_data = inference_data.to_pandas(use_pyarrow_extension_array=True)
    _test_data = setup_inference_data(SHRUNKEN_PATH, inference_data)

    return _test_data

def get_train_data(sample_size: int = -1, sample_split: float = 0.5):
    global _train_data

    if _train_data is not None:
        return _train_data

    print('Loading train_data from disk...')
    _train_data = pl.read_parquet(SHRUNKEN_PATH / "train.parquet")
    _train_data = _train_data.to_pandas(use_pyarrow_extension_array=True)

    if sample_size > 0:
        _train_data = sample_data(_train_data, sample_size, sample_split)

    return _train_data

def visualize_molecule(molecule, desc = ""):
    print(desc)
    if isinstance(molecule, str):
        molecule = Chem.MolFromSmiles(molecule)
    molecule_image = Draw.MolToImage(molecule, size=(300, 300))
    display(molecule_image)


# Exploration

In [5]:
from rdkit.Chem import AllChem
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole
import os
from rdkit import RDConfig
import py3Dmol
from pprint import pprint

IPythonConsole.ipython_3d = True

smiles = "C#CCOc1ccc(CNc2nc(NCC3CCCN3c3cccnn3)nc(N[C@@H](CC#C)CC(=O)N[Dy])n2)cc1"
mol = Chem.MolFromSmiles(smiles)

# Generate Conformation
mol = Chem.AddHs(mol)
ps = AllChem.ETKDGv3()
ps.randomSeed = 0xf00d
AllChem.EmbedMolecule(mol, ps)

# Create Feature Map
fdef = AllChem.BuildFeatureFactory(os.path.join(RDConfig.RDDataDir,'BaseFeatures.fdef'))
pharmafeature_list = [fdef.GetMolFeature(mol, idx) for idx in range(fdef.GetNumMolFeatures(mol))]
# pharmafeature_list = fdef.GetFeaturesForMol(mol)

# print(pharmafeature_list)
# feature = pharmafeature_list[0]

# featList = [f for f in featList if f.GetFamily() in ['Donor','Acceptor','NegIonizable','PosIonizable','Aromatic','ZnBinder','LumpedHydrophobe','Hydrophobe']]
# featList = [f for f in featList if f.GetFamily() in ['Donor']]

# Colors
def colorToHex(rgb):
    return '0x'+''.join([f'{int(255*x):02x}' for x in rgb])
featColors = {
  'Donor': (0, 1, 1),           # cyan
  'Acceptor': (1, 0, 1),        # magenta
  'NegIonizable': (1, 0, 0),    # red
  'PosIonizable': (0, 0, 1),    # blue
  'ZnBinder': (1, .5, .5),      # pink
  'Aromatic': (1, .8, .2),      # orange
  'LumpedHydrophobe': (.5, .25, 0), # brown
  'Hydrophobe': (0, 1, 0),      # green
}
featColors_hex = {k:colorToHex(v) for k,v in featColors.items()}

# Visualize
def drawit(m, feats, p=None, confId=-1, removeHs=True):
        if p is None:
            p = py3Dmol.view(width=400, height=400)
        p.removeAllModels()
        if removeHs:
            m = Chem.RemoveHs(m)
        IPythonConsole.addMolToView(m,p,confId=confId)
        for feat in feats:
            print("Feat ID: ", feat.GetId(), "Mol IDs", feat.GetAtomIds(), "Family: ", feat.GetFamily())
            pos = feat.GetPos()
            clr = featColors_hex.get(feat.GetFamily())
            p.addSphere({'center':{'x':pos.x,'y':pos.y,'z':pos.z},'radius':.5,'color': clr})
        p.zoomTo()
        return p.show()

drawit(mol,pharmafeature_list)

Feat ID:  1 Mol IDs (9,) Family:  Donor
Feat ID:  2 Mol IDs (13,) Family:  Donor
Feat ID:  3 Mol IDs (28,) Family:  Donor
Feat ID:  4 Mol IDs (36,) Family:  Donor
Feat ID:  5 Mol IDs (3,) Family:  Acceptor
Feat ID:  6 Mol IDs (11,) Family:  Acceptor
Feat ID:  7 Mol IDs (24,) Family:  Acceptor
Feat ID:  8 Mol IDs (25,) Family:  Acceptor
Feat ID:  9 Mol IDs (26,) Family:  Acceptor
Feat ID:  10 Mol IDs (35,) Family:  Acceptor
Feat ID:  11 Mol IDs (38,) Family:  Acceptor
Feat ID:  12 Mol IDs (4, 5, 6, 7, 39, 40) Family:  Aromatic
Feat ID:  13 Mol IDs (10, 11, 12, 26, 27, 38) Family:  Aromatic
Feat ID:  14 Mol IDs (20, 21, 22, 23, 24, 25) Family:  Aromatic
Feat ID:  15 Mol IDs (5,) Family:  Hydrophobe
Feat ID:  16 Mol IDs (6,) Family:  Hydrophobe
Feat ID:  17 Mol IDs (7,) Family:  Hydrophobe
Feat ID:  18 Mol IDs (16,) Family:  Hydrophobe
Feat ID:  19 Mol IDs (17,) Family:  Hydrophobe
Feat ID:  20 Mol IDs (21,) Family:  Hydrophobe
Feat ID:  21 Mol IDs (22,) Family:  Hydrophobe
Feat ID:  22 M

[12:26:10] UFFTYPER: Unrecognized charge state for atom: 37
[12:26:10] UFFTYPER: Unrecognized atom type: Dy5+3 (37)


In [8]:
from tqdm import tqdm
from cProfile import Profile
from pstats import SortKey, Stats
import numpy.typing as npt
import torch
from rdkit.Chem.rdchem import Mol

def create_features(smiles: str) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    mol = Chem.MolFromSmiles(smiles)

    nodes = torch.from_numpy(create_feature_nodes(mol))
    edges = create_feture_edges(mol)
    edge_indices = torch.from_numpy(edges[0])
    edge_attr = torch.from_numpy(edges[1]) if edges[1] is not None else None

    return nodes, edge_indices, edge_attr

# Setup Feature Factory
fdef = AllChem.BuildFeatureFactory(os.path.join(RDConfig.RDDataDir,'BaseFeatures.fdef'))
fdef_dict = {fdef.GetFeatureFamilies()[i]: i for i in range(len(fdef.GetFeatureFamilies()))}
fdef_len = len(fdef.GetFeatureFamilies())
fdef_eye = np.eye(fdef_len)

# Setup Atomic Translation
atomic_nums = [5, 6, 7, 8, 9, 14, 16, 17, 35, 66]
atomic_nums_len = len(atomic_nums)

def create_feature_nodes(mol: Mol) -> npt.NDArray[np.uint8]:
    mol = Chem.AddHs(mol)
    pharmafeature_list = fdef.GetFeaturesForMol(mol)
    mol = Chem.RemoveHs(mol)

    def get_pharma_vector(feature_family: str):
        """Returns a one-hot encoded vector of the feature family."""
        return fdef_eye[fdef_dict[feature_family]]
    
    def get_atomic_vector(atomic_num: int):
        return np.eye(atomic_nums_len, dtype=np.uint8)[atomic_nums.index(atomic_num)]

    nodes = np.empty((mol.GetNumAtoms(), 3 + fdef_len + atomic_nums_len), dtype=np.uint8)
    for idx, atom in enumerate(mol.GetAtoms()):
        nodes[idx][:10] = get_atomic_vector(atom.GetAtomicNum())
        nodes[idx][10:13] = np.array([atom.GetDegree(), atom.GetHybridization(), atom.GetIsotope()])
        nodes[idx][13:] = np.sum([get_pharma_vector(f.GetFamily()) for f in pharmafeature_list if idx in f.GetAtomIds()], axis=0)
        # print(atom.GetAtomicNum(), atom.GetDegree(), atom.GetHybridization(), atom.GetIsotope(), [f.GetFamily() for f in pharmafeature_list if idx in f.GetAtomIds()])
    return nodes

def create_feture_edges(mol: Mol, *, get_bond_attributes: bool = False) -> tuple[npt.NDArray[np.uint16], npt.NDArray[np.uint8]]:
    bonds = mol.GetBonds()

    # Extract the attributes and the edge index
    edge_features = np.empty((2 * len(bonds), 2), dtype=np.uint8)
    edge_indices = np.empty((2 * len(bonds), 2), dtype=np.uint16)

    idx = 0
    for bond in bonds:
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edge_indices[idx] = [start, end]
        edge_indices[idx + 1] = [end, start]

        if get_bond_attributes:
            edge_features[idx] = [int(bond.GetIsConjugated()), int(bond.IsInRing())]
            edge_features[idx + 1] = [int(bond.GetIsConjugated()), int(bond.IsInRing())]
        
        idx += 2

    return edge_indices.T, edge_features if get_bond_attributes else None


smiles = "C#CCOc1ccc(CNc2nc(NCC3CCCN3c3cccnn3)nc(N[C@@H](CC#C)CC(=O)N[Dy])n2)cc1"

with Profile() as profile:
    result = create_features(smiles)
    (
        Stats(profile)
        .strip_dirs()
        .sort_stats(SortKey.CALLS)
        .print_stats()
    )
# print(result[0])

         755 function calls in 0.016 seconds

   Ordered by: call count

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       82    0.000    0.000    0.000    0.000 {built-in method _operator.index}
       50    0.000    0.000    0.000    0.000 1465406190.py:33(get_pharma_vector)
       42    0.000    0.000    0.000    0.000 {built-in method builtins.isinstance}
       41    0.000    0.000    0.000    0.000 {method 'items' of 'dict' objects}
       41    0.000    0.000    0.000    0.000 {method 'index' of 'list' objects}
       41    0.000    0.000    0.000    0.000 {built-in method builtins.getattr}
       41    0.000    0.000    0.000    0.000 1465406190.py:37(get_atomic_vector)
       41    0.000    0.000    0.000    0.000 1465406190.py:44(<listcomp>)
       41    0.000    0.000    0.000    0.000 {built-in method numpy.array}
       41    0.000    0.000    0.000    0.000 {built-in method numpy.zeros}
       41    0.000    0.000    0.000    0.000 {method 're

In [6]:
import sys
sys.path.append('../')
import time
from cProfile import Profile

from src.modules.training.dataset_steps.graphs.smiles_to_graph import SmilesToGraph

x = get_train_x_data()
x.retrieval = DataRetrieval.SMILES_MOL
y = get_train_y_data()
smiles_to_graph = SmilesToGraph(
    use_atom_chem_features=False,
    use_atom_pharmacophore_features=True,
    use_bond_features=False)

avg_time = 0
batch_size = 512
# for i in tqdm(range(0, len(x), batch_size)):
for i in range(0, len(x), batch_size):
    current_time = time.time()
    batch_x = x[i:i + batch_size]
    batch_y = y[i:i + batch_size]
    with Profile() as profile:
        _ = smiles_to_graph.train(batch_x, batch_y)
        (
            Stats(profile)
            .strip_dirs()
            .sort_stats(SortKey.CALLS)
            .print_stats()
        )
    print(time.time() - current_time)
    break
    

Loading x_data from disk...
Loading train_data from disk...
Loading y_data from disk...
         59989 function calls in 0.864 seconds

   Ordered by: call count

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    23489    0.001    0.000    0.001    0.000 {built-in method builtins.len}
     7680    0.002    0.000    0.004    0.000 {built-in method builtins.getattr}
     3072    0.000    0.000    0.000    0.000 {method 'append' of 'list' objects}
     2560    0.001    0.000    0.007    0.000 storage.py:100(__setattr__)
     2048    0.001    0.000    0.001    0.000 {built-in method numpy.empty}
     1550    0.000    0.000    0.000    0.000 {method 'values' of 'dict' objects}
     1536    0.000    0.000    0.000    0.000 {method 'items' of 'dict' objects}
     1536    0.002    0.000    0.002    0.000 {built-in method torch.from_numpy}
     1536    0.001    0.000    0.004    0.000 storage.py:82(_pop_cache)
     1536    0.001    0.000    0.004    0.000 storage.py:12