In [23]:
import numpy as np  # sometimes needed to avoid mkl-service error
import sys
import os
import argparse
import logging
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, WandbLogger
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.utilities import rank_zero_only
import torch
from torchmdnet.module import LNNP
from torchmdnet import datasets, priors, models
from torchmdnet.data import DataModule
from torchmdnet.models import output_modules
from torchmdnet.models.utils import rbf_class_mapping, act_class_mapping
from torchmdnet.utils import LoadFromFile, LoadFromCheckpoint, save_argparse, number
from pathlib import Path
import wandb
import json
import pandas as pd
from rdkit.Chem import AllChem
import copy
from rdkit.Geometry import Point3D
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole
IPythonConsole.ipython_3d = True
import py3Dmol
from rdkit.Chem import rdDepictor
from rdkit.Chem import rdDistGeom
import rdkit

In [24]:
# Load the saved hyperparameters (a JSON dict) and build the DataModule from them.
args = json.loads(Path('commandline_args.txt').read_text())

data = DataModule(args)
# Prepare the data, then set up the "fit" stage. The split sizes and the
# mean/std progress bar in the output below are emitted during these calls.
data.prepare_data()
data.setup("fit")

train 400, val 50, test 37


computing mean and std:   0%|          | 0/4 [00:00<?, ?it/s]


In [25]:
def print_atoms(mol):
    """Print the element symbols of ``mol``'s atoms on one comma-separated line.

    Args:
        mol: any object exposing ``GetAtoms()`` whose items expose
            ``GetSymbol()`` (e.g. an RDKit molecule).
    """
    symbols = (atom.GetSymbol() for atom in mol.GetAtoms())
    line = ",".join(symbols)
    print(line)

In [26]:
import random

# Spot-check that the clean and noisy RDKit molecules of randomly chosen
# samples list their atoms in the same element order (printed as pairs).
N_SPOT_CHECKS = 10
spot_rng = random.Random(0)  # seeded so the same samples are drawn on every run
# NOTE(review): assumes dataset_maybe_noisy supports len(); the original
# sampled from the hard-coded range [0, 400) (the train size in the output).
n_dataset = len(data.dataset_maybe_noisy)
for _ in range(N_SPOT_CHECKS):
    rand = spot_rng.randrange(n_dataset)
    # Index the dataset once: if __getitem__ re-samples noise, two separate
    # lookups could pair `mol` and `noisy_mol` from different draws.
    item = data.dataset_maybe_noisy[rand]
    mol = item.mol
    noisy_mol = item.noisy_mol
    print_atoms(mol)
    print_atoms(noisy_mol)

C,C,C,C,C,C,H,H,H,H,H,H,H,H,H,H
C,C,C,C,C,C,H,H,H,H,H,H,H,H,H,H
C,C,C,C,C,H,H,H,H,H,H,H,H
C,C,C,C,C,H,H,H,H,H,H,H,H
O,C,N,C,N,H,H
O,C,N,C,N,H,H
C,C,O,C,O,H,H,H,H,H,H
C,C,O,C,O,H,H,H,H,H,H
C,C,O,C,O,H,H,H,H,H,H
C,C,O,C,O,H,H,H,H,H,H
N,C,C,C,C,N,H,H,H,H,H,H
N,C,C,C,C,N,H,H,H,H,H,H
C,C,C,C,O,C,H,H,H,H,H,H,H,H
C,C,C,C,O,C,H,H,H,H,H,H,H,H
C,C,C,C,C,H,H,H,H,H,H,H,H
C,C,C,C,C,H,H,H,H,H,H,H,H
C,O,C,C,O,H,H,H,H,H,H,H,H
C,O,C,C,O,H,H,H,H,H,H,H,H
C,N,C,N,C,O,H,H,H,H,H,H
C,N,C,N,C,O,H,H,H,H,H,H


In [5]:
# Pick one fixed example (index 15) to inspect. The cell has no print, yet two
# tensors appear below, so indexing the dataset presumably prints them — confirm.
sample=data.dataset_maybe_noisy[15]


tensor([[-1.7674,  3.5668, -2.2891],
        [ 1.0245,  1.0152,  0.5745],
        [ 0.6888,  0.3857,  1.9742],
        [ 1.5244,  2.3697,  0.4127],
        [ 0.5275,  2.4693,  1.6104],
        [ 0.0291,  0.7604,  0.7546],
        [ 1.9843,  0.2055, -1.6015],
        [-1.9092,  1.4258, -0.0547],
        [-3.4948, -1.8347,  2.4298]])
tensor([[-1.7674,  3.5668, -2.2891],
        [ 1.0245,  1.0152,  0.5745],
        [ 0.6888,  0.3857,  1.9742],
        [ 1.5244,  2.3697,  0.4127],
        [ 0.5275,  2.4693,  1.6104],
        [ 0.0291,  0.7604,  0.7546],
        [ 1.9843,  0.2055, -1.6015],
        [-1.9092,  1.4258, -0.0547],
        [-3.4948, -1.8347,  2.4298]])


In [10]:
# Coordinates of conformer 0 of the noisy RDKit molecule, returned as a NumPy array.
sample.noisy_mol.GetConformer(0).GetPositions()

array([[ 0.1386645 ,  1.97381212,  0.38018896],
       [ 0.41072317,  0.48876456,  0.52645924],
       [-1.0129904 ,  0.98647968,  0.36468587],
       [ 0.39898532,  2.39192662, -0.61552726],
       [ 0.2348427 ,  2.55244206,  1.32331444],
       [ 0.66885703,  0.16878532,  1.5579608 ],
       [ 0.83563715,  0.00940425, -0.38130596],
       [-1.45151894,  0.81546012, -0.6416003 ],
       [-1.61611414,  0.97645888,  1.2967554 ]])

In [7]:
# Positions tensor stored on the sample, for comparison with the RDKit conformer above.
sample.pos

tensor([[-1.7674,  3.5668, -2.2891],
        [ 1.0245,  1.0152,  0.5745],
        [ 0.6888,  0.3857,  1.9742],
        [ 1.5244,  2.3697,  0.4127],
        [ 0.5275,  2.4693,  1.6104],
        [ 0.0291,  0.7604,  0.7546],
        [ 1.9843,  0.2055, -1.6015],
        [-1.9092,  1.4258, -0.0547],
        [-3.4948, -1.8347,  2.4298]])

In [27]:
# NOTE(review): re-evaluating sample.pos here shows different values than the
# earlier cell; either `sample` was rebound in between or the dataset re-samples
# noise on access. Execution counts are also out of order — confirm under Restart & Run All.
sample.pos

tensor([[-1.0634,  1.7419,  2.1481],
        [ 0.9792,  1.9879, -0.4576],
        [-5.3672,  1.0327,  0.5742],
        [ 2.0826,  3.3134, -2.5213],
        [ 1.1087,  3.1769, -2.5667],
        [ 1.6334,  2.0514,  0.1729],
        [-2.0669,  3.0521, -6.2096],
        [-2.9466, -1.5596,  3.1063],
        [-1.1847, -0.3315, -0.0414]])

In [27]:
# Build the LNNP model from the same args, normalised with the DataModule's mean/std.
# NOTE(review): no checkpoint is loaded here (LoadFromCheckpoint is imported but
# unused), so the predictions below come from freshly initialised weights — confirm intent.
model = LNNP(args, prior_model=None, mean=data.mean, std=data.std)

In [28]:
# Deep-copy sample 0 so conformers added to its RDKit molecules below do not
# mutate the dataset's own object (relevant if the dataset caches items — TODO confirm).
sample=copy.deepcopy(data.dataset_maybe_noisy[0])
# NOTE(review): copSample is not used anywhere in the visible notebook.
copSample=copy.deepcopy(sample)

In [29]:
# Single forward pass; the second of the three outputs is treated below as the
# predicted positional noise.
# NOTE(review): no batch vector is passed and model.eval() is never called —
# confirm this is appropriate for single-sample inference.
_,noise,_=model(sample.z,sample.pos)

In [30]:
# Model-denoised positions: input positions minus the predicted noise.
predpos=sample.pos-noise
# Presumably the clean positions, assuming pos_target holds the noise that was
# added to pos — TODO confirm against the dataset implementation.
gpos=sample.pos-sample.pos_target
# The input positions; this aliases sample.pos (no copy is made).
initpos=sample.pos

In [31]:
def make_conformer(mol, positions):
    """Append a new conformer built from ``positions`` to ``mol``.

    Args:
        mol: RDKit molecule; it is modified in place.
        positions: indexable of per-atom coordinates, where ``positions[i][0]``,
            ``[1]`` and ``[2]`` are the x, y, z of atom ``i``. Each coordinate may
            be anything ``float()`` accepts (0-d torch tensor, NumPy scalar,
            Python float). Rows beyond ``mol.GetNumAtoms()`` are ignored.

    Returns:
        The ID RDKit assigned to the new conformer.

    Raises:
        ValueError: if ``positions`` has fewer rows than ``mol`` has atoms.
    """
    n_atoms = mol.GetNumAtoms()
    if len(positions) < n_atoms:
        raise ValueError(
            f"positions has {len(positions)} rows but mol has {n_atoms} atoms"
        )
    conformer = Chem.Conformer(n_atoms)
    for i in range(n_atoms):
        xyz = positions[i]
        conformer.SetAtomPosition(
            i, Point3D(float(xyz[0]), float(xyz[1]), float(xyz[2]))
        )
    # assignId=True gives the conformer max(existing id) + 1, so it cannot
    # collide with an existing conformer even when IDs are non-contiguous
    # (using GetNumConformers() as the ID could collide after a removal).
    return mol.AddConformer(conformer, assignId=True)

In [32]:
# Append three conformers to sample.mol, in order: model prediction,
# target positions (gpos), and the initial input positions.
for positions in (predpos, gpos, initpos):
    make_conformer(sample.mol, positions)

In [33]:
# One element symbol per line, in atom-index order, for the clean molecule.
for atom in sample.mol.GetAtoms():
    print(atom.GetSymbol())

C
H
H
H
H


In [48]:
# One element symbol per line, in atom-index order, for the noisy molecule.
for atom in sample.noisy_mol.GetAtoms():
    print(atom.GetSymbol())

O
C
C
C
H
H
H
H
H
H


In [49]:
# Every conformer attached to sample.mol. Note that this is a sequence of
# conformers, not a single Conformer object.
conf = sample.mol.GetConformers()
for conformer in conf:
    conformer_id = conformer.GetId()
    print(conformer_id)

0
1
2
3


In [37]:
# Render conformer 3 of sample.mol. With conformer IDs 0-3 (see the listing above),
# this is the last one added by make_conformer, i.e. the initial positions.
IPythonConsole.drawMol3D(sample.mol,confId=3)

In [69]:
# Re-embed the noisy molecule, perturb the embedding with unit-variance
# Gaussian noise, relax it briefly with UFF, and express the result as a
# displacement from the positions in sample.pos.
embed_id = rdDistGeom.EmbedMolecule(sample.noisy_mol)
assert embed_id >= 0, "RDKit failed to embed sample.noisy_mol"
noisy_conf = sample.noisy_mol.GetConformer(embed_id)
noisy_conf_positions = noisy_conf.GetPositions()
# Size the noise from the conformer itself. sample.pos belongs to sample.mol,
# whose atom count need not match sample.noisy_mol.
gaussian_noise = np.random.randn(*noisy_conf_positions.shape)
noisy_positions = noisy_conf_positions + gaussian_noise
for i, (x, y, z) in enumerate(noisy_positions):
    noisy_conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z)))
# Chem.AddHs returns a NEW molecule, so assign it back to sample.noisy_mol
# for the UFF step and the later drawing to see it.
sample.noisy_mol = Chem.AddHs(sample.noisy_mol, addCoords=True)
# AddHs copies the conformers; re-fetch the perturbed one from the new molecule.
noisy_conf = sample.noisy_mol.GetConformer(embed_id)
Chem.rdForceFieldHelpers.UFFOptimizeMolecule(sample.noisy_mol, confId=embed_id, maxIters=40)
noisy_conf_positions = torch.tensor(noisy_conf.GetPositions(), dtype=torch.float)
assert noisy_conf_positions.shape == sample.pos.shape, (
    f"noisy_mol has {noisy_conf_positions.shape[0]} atoms but sample.pos has "
    f"{sample.pos.shape[0]} rows"
)
noise = noisy_conf_positions - sample.pos

In [70]:
# Render conformer 0 of the noisy molecule, which the previous cell perturbed and relaxed with UFF.
IPythonConsole.drawMol3D(sample.noisy_mol,confId=0)

In [24]:
# Overwrite conformer 0 of sample.mol with the model-denoised positions.
# sample.mol.GetConformers() returns a sequence, not a Conformer, so fetch
# conformer 0 explicitly before setting positions on it.
conf = sample.mol.GetConformer(0)
for i in range(sample.mol.GetNumAtoms()):
    x, y, z = (predpos[i][k].item() for k in range(3))
    conf.SetAtomPosition(i, Point3D(x, y, z))
    

In [25]:
# Render conformer 0 of sample.mol; after the cell above it holds the model-denoised positions.
IPythonConsole.drawMol3D(sample.mol,confId=0)

In [26]:
# Overwrite conformer 0 of sample.mol with the input positions (sample.pos).
# sample.mol.GetConformers() returns a sequence, not a Conformer, so fetch
# conformer 0 explicitly before setting positions on it.
conf = sample.mol.GetConformer(0)
for i in range(sample.mol.GetNumAtoms()):
    x, y, z = (sample.pos[i][k].item() for k in range(3))
    conf.SetAtomPosition(i, Point3D(x, y, z))

In [42]:
# Render conformer 0 of sample.mol, then relax that conformer in place with UFF.
# NOTE(review): the view is drawn BEFORE optimisation, so it shows the unrelaxed
# geometry. The cell's displayed value is the UFF return code (0 below, which
# presumably means converged — confirm in the RDKit docs).
IPythonConsole.drawMol3D(sample.mol,confId=0)
Chem.rdForceFieldHelpers.UFFOptimizeMolecule(sample.mol,confId=0,maxIters=300)

0