In [3]:
import pickle
import warnings
from elektronn.utils import generate_grid, model_kwargs, LoadModels
from elektronn.elektronn_ensemble_predict import data_from_rdkit, ElektroNN_Ensemble

import torch  # only needed here to choose the inference device

# NOTE(review): pickle.load can execute arbitrary code embedded in the file —
# only load basis-function parameters from a trusted source.
with open("ElektroNN/basisfunction_params.pkl", "rb") as file:
    basisfunction_params = pickle.load(file)

model_path = 'ElektroNN/modelparams/04-20-33/'
num_models = 5
# Prefer the GPU, but fall back to CPU instead of crashing on machines without
# CUDA; later cells reuse this device string for inference.
map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): with all_models=True every fold was loaded (see the cell output),
# so models_to_load appears to be ignored — confirm against LoadModels.
models_to_load = [2]
# Suppress warnings only while building/loading the checkpoints rather than
# globally for the whole kernel, so later cells still surface real problems.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    loader_specified = LoadModels(
        models_to_load, num_models, model_path, model_kwargs, map_location, all_models=True
    )
    loader_specified.load()
models_specified = loader_specified.models

loaded model_fold_1.pth
loaded model_fold_2.pth
loaded model_fold_3.pth
loaded model_fold_4.pth
loaded model_fold_5.pth


In [4]:
from torch.utils.data import Dataset
from rdkit import Chem
from rdkit.Chem import AllChem

class MolGraphData(Dataset):
    """Map-style dataset that turns SMILES strings into ElektroNN graph inputs.

    SMILES are parsed once at construction time. Hydrogens are added and a 3D
    conformer is embedded (ETKDGv3, fixed random seed) on every ``__getitem__``
    call. The fixed seed is intended to make repeated accesses reproduce the same
    geometry.

    Molecules are skipped, and counted, when:
      * RDKit cannot parse the SMILES (counted in ``invalid_smiles_mols``), or
      * they contain an element that has no entry in ``basisfunction_params``
        (counted in ``bad_atom_mols``).

    Args:
        smiles_dict: mapping of molecule name -> SMILES string.
        basisfunction_params: basis-set parameters, looked up by atomic number
            converted to ``float``.
    """

    def __init__(self, smiles_dict, basisfunction_params):
        super().__init__()
        self.mols = []
        self.names = []
        self.basisfunction_params = basisfunction_params
        self.params = AllChem.ETKDGv3()
        self.params.randomSeed = 42069  # fixed seed -> reproducible embeddings
        self.bad_atom_mols = 0
        self.invalid_smiles_mols = 0
        for name, smiles in smiles_dict.items():
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # MolFromSmiles returns None on a parse failure; without this
                # check the element scan below raises AttributeError.
                self.invalid_smiles_mols += 1
                continue
            if any(
                float(atom.GetAtomicNum()) not in self.basisfunction_params
                for atom in mol.GetAtoms()
            ):
                self.bad_atom_mols += 1
                continue
            self.mols.append(mol)
            self.names.append(name)

    def __len__(self):
        """Number of molecules kept after filtering."""
        return len(self.mols)

    def __getitem__(self, index):
        """Embed molecule ``index`` in 3D and convert it with ``data_from_rdkit``.

        Raises:
            ValueError: if RDKit cannot embed a conformer for the molecule.
        """
        name = self.names[index]
        # AddHs returns a new molecule, so the cached heavy-atom mol stays untouched.
        mol = Chem.AddHs(self.mols[index])
        conf_id = AllChem.EmbedMolecule(mol, self.params)
        if conf_id == -1:
            # EmbedMolecule signals failure with -1 and leaves no conformer behind.
            raise ValueError(f"RDKit failed to embed a 3D conformer for molecule {name!r}")
        return data_from_rdkit(mol, name, self.basisfunction_params)

    def get_mol(self, name):
        """Return the featurized item for ``name``.

        Raises:
            ValueError: if ``name`` was filtered out or never present.
        """
        index = self.names.index(name)
        return self[index]



In [5]:
import csv
import os

# Assay export to load: <root>/<assay_id>_<outcome>.csv with 'SMILES' and 'CID' columns.
assay_id = 'AID435034'
outcome = 'actives'
root = './data/qsar/'
filename = f"{assay_id}_{outcome}.csv"
filepath = os.path.join(root, filename)

smiles_dict = {}
# The csv module requires files to be opened with newline='' so that quoted
# fields containing line breaks are parsed correctly.
with open(filepath, newline='') as csvfile:
    for row in csv.DictReader(csvfile):
        # A repeated CID overwrites the earlier row's SMILES.
        smiles_dict[row['CID']] = row['SMILES']

dataset = MolGraphData(smiles_dict, basisfunction_params)
print(dataset[0])
print(dataset.bad_atom_mols)

Data(x=[61, 18], pos=[61, 3], exp=[61, 127], norm=[61, 127], atomic_numbers=[61], filename='1263872')
3


In [6]:
%load_ext autoreload
%autoreload 2
from torch_geometric.loader import DataLoader
from tqdm import tqdm

preds = {}

dl = DataLoader(dataset, pin_memory=True, batch_size=64, num_workers=8)
for mol in tqdm(dl):
    mol.to(map_location)
    pred = ElektroNN_Ensemble(mol, map_location, models_specified)
    for i in range(len(mol.filename)):
        preds[ mol.filename[i] ] = pred[ mol.batch == i ]

len(preds)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


  0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 2/2 [00:08<00:00,  4.22s/it]


75

In [7]:
from elektronn.utils import coeff_unperm_gau2grid_density_kdtree_ml_only
import numpy as np

# Coefficient layout passed to the density evaluation — presumably
# (multiplicity, l) pairs matching the model's output basis; TODO confirm.
Rs = [(14, 0), (14, 1), (5, 2), (4, 3), (2, 4)]

# CIDs not to render; the first remaining molecule in prediction order is used.
SKIP_CIDS = {"1263872"}
name = next((cid for cid in preds if cid not in SKIP_CIDS), None)
if name is None:
    # Fail loudly rather than leaving stale `molecule`/`ml` from an earlier run
    # for the cube-writing cell to pick up.
    raise ValueError("No predicted molecule left after excluding SKIP_CIDS")
pred = preds[name]

molecule = dataset.get_mol(name)
print("Generate Grid for", name)
x, y, z, vol, x_spacing, y_spacing, z_spacing = generate_grid(molecule, spacing=0.25, buffer=2.0)
# Ensure x, y, z are flat float64 arrays for gau2grid
x_flat = x.flatten().astype(np.float64)
y_flat = y.flatten().astype(np.float64)
z_flat = z.flatten().astype(np.float64)
ml = coeff_unperm_gau2grid_density_kdtree_ml_only(x_flat, y_flat, z_flat, molecule, pred, Rs)

Generate Grid for 655606


In [15]:
# NOTE(review): leftover inspection cell. Its execution count is out of order,
# and its output ('1263872') differs from the molecule the grid cell processed
# ("Generate Grid for 655606"), so `name` was rebound by cells that no longer
# exist. Remove this cell, or re-run the notebook top to bottom.
name

'1263872'

In [8]:
import os

from elektronn.utils import write_cube_file

# Writes the predicted density of the molecule chosen in the grid-generation
# cell. It depends on `name`, `molecule`, `x`, `y`, `z` and `ml` from that cell,
# so run that cell immediately before this one.
output_path = "data/qsar/"
os.makedirs(output_path, exist_ok=True)
cube_path = os.path.join(output_path, f"predicted_density_{name}.cube")

print("Generate Cube")
write_cube_file(cube_path, molecule.atomic_numbers, molecule.pos, x, y, z, ml)
print("Generated Cube")

Generated Grid
Generate ML object
Generated ML Object
Generate Cube
Generated Cube
