In [1]:
import numpy as np  # sometimes needed to avoid mkl-service error
import sys
import os
import argparse
import logging
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, WandbLogger
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.utilities import rank_zero_only
import torch
from torchmdnet.module import LNNP
from torchmdnet import datasets, priors, models
from torchmdnet.data import DataModule
from torchmdnet.models import output_modules
from torchmdnet.models.utils import rbf_class_mapping, act_class_mapping
from torchmdnet.utils import LoadFromFile, LoadFromCheckpoint, save_argparse, number
from pathlib import Path
import wandb
import json
import pandas as pd
from rdkit.Chem import AllChem
import copy
from rdkit.Geometry import Point3D
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole
IPythonConsole.ipython_3d = True
import py3Dmol
from rdkit.Chem import rdDepictor
from rdkit.Chem import rdDistGeom
import rdkit

In [2]:
# Map each experiment name to its JSON argument file (the name plus ".txt").
# Insertion order: baseline runs first, then "ours" runs.
command_dic = {
    run_name: f"{run_name}.txt"
    for run_name in (
        "command_args_baselinedefault_defnoise",
        "command_args_baselinedefault_rdkitnoise",
        "command_args_baselinedefault_rdkit04noise",
        "command_args_baselinedefault_8noise",
        "command_args_baseline8noise_defnoise",
        "command_args_baseline8noise_rdkitnoise",
        "command_args_baseline8noise_rdkit04noise",
        "command_args_baseline8noise_8noise",
        "command_args_oursdefault_defnoise",
        "command_args_oursdefault_rdkitnoise",
        "command_args_oursdefault_rdkit04noise",
        "command_args_oursdefault_8noise",
        "command_args_ours04noise_defnoise",
        "command_args_ours04noise_8noise",
        "command_args_ours04noise_rdkitnoise",
        "command_args_ours04noise_rdkit04noise",
    )
}




In [3]:
def load_data(argfile, root="arg_files/"):
    """Read a JSON argument file and build a DataModule from it.

    Args:
        argfile: name of the JSON argument file, relative to ``root``.
        root: directory holding the argument files. Defaults to
            ``"arg_files/"``, the location this notebook always used; a value
            with or without a trailing separator works.

    Returns:
        ``(data, args)`` where ``data`` is the ``DataModule`` after
        ``prepare_data()`` and ``setup("fit")`` have run, and ``args`` is the
        parsed JSON (presumably a dict of hyper-parameters — confirm against
        the argument files).
    """
    arg_path = os.path.join(root, argfile)
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(arg_path, "r", encoding="utf-8") as f:
        args = json.load(f)
    data = DataModule(args)
    data.prepare_data()
    data.setup("fit")
    return data, args

In [4]:
# Build the DataModules and argument dicts used below. Each call prepares and
# sets up a full DataModule. Results that are not needed are bound to `_`.
data_defnoise, arg_baseline = load_data(
    command_dic["command_args_baselinedefault_defnoise"]
)
data_8noise, _ = load_data(
    command_dic["command_args_baselinedefault_8noise"]
)
_, arg_ours = load_data(
    command_dic["command_args_oursdefault_defnoise"]
)
data_rdknoise, _ = load_data(
    command_dic["command_args_oursdefault_rdkitnoise"]
)





train 400, val 50, test 37


  rank_zero_warn(
computing mean and std:   0%|          | 0/4 [00:00<?, ?it/s]


train 400, val 50, test 37


computing mean and std:   0%|          | 0/4 [00:00<?, ?it/s]


train 400, val 50, test 37


computing mean and std:   0%|          | 0/4 [00:00<?, ?it/s]


train 400, val 50, test 37


computing mean and std:   0%|          | 0/4 [00:00<?, ?it/s]


In [5]:
# Instantiate the baseline and "ours" models from their argument dicts,
# without a prior model or output normalisation (mean/std).
# NOTE(review): no checkpoint is loaded explicitly here. Unless the argument
# files point LNNP at saved weights, these models start from freshly
# initialised parameters — confirm before trusting the predictions below.
model_baseline_def = LNNP(
    arg_baseline,
    prior_model=None,
    mean=None,
    std=None,
)
model_ours = LNNP(
    arg_ours,
    prior_model=None,
    mean=None,
    std=None,
)


In [6]:
# Put both models into evaluation mode before running inference.
# The cell displays the value of its last expression (the module returned by
# `model_ours.eval()`), so keep that call as the final line.
model_baseline_def.eval()
model_ours.eval()

LNNP(
  (model): TorchMD_Net(
    (representation_model): TorchMD_ET(hidden_channels=256, num_layers=8, num_rbf=64, rbf_type=expnorm, trainable_rbf=False, activation=silu, attn_activation=silu, neighbor_embedding=NeighborEmbeddingJittable_d595c0(
      (embedding): Embedding(100, 256)
      (distance_proj): Linear(in_features=64, out_features=256, bias=True)
      (combine): Linear(in_features=512, out_features=256, bias=True)
      (cutoff): CosineCutoff()
    ), num_heads=8, distance_influence=both, cutoff_lower=0.0, cutoff_upper=5.0)
    (output_model): EquivariantScalar(
      (output_network): ModuleList(
        (0): GatedEquivariantBlock(
          (vec1_proj): Linear(in_features=256, out_features=256, bias=False)
          (vec2_proj): Linear(in_features=256, out_features=128, bias=False)
          (update_net): Sequential(
            (0): Linear(in_features=512, out_features=256, bias=True)
            (1): SiLU()
            (2): Linear(in_features=256, out_features=256, bia

In [7]:
# Number of molecules to convert and export per noise setting.
N_SAMPLES = 100

# Take the first N_SAMPLES entries of each DataModule's `dataset_maybe_noisy`.
samples_defnoise = data_defnoise.dataset_maybe_noisy[:N_SAMPLES]
samples_8noise = data_8noise.dataset_maybe_noisy[:N_SAMPLES]
samples_rdnoise = data_rdknoise.dataset_maybe_noisy[:N_SAMPLES]


In [8]:
def make_conformer(mol, positions):
    """Append a new conformer built from ``positions`` to ``mol`` (in place).

    Args:
        mol: RDKit molecule that receives the new conformer.
        positions: per-atom coordinates indexable as ``positions[i][k]`` for
            atom ``i`` and axis ``k`` in 0..2. Examples are an (N, 3) torch
            tensor, a numpy array, or a nested list. Each entry must be
            convertible with ``float()``.

    Returns:
        The id RDKit assigned to the new conformer.

    Raises:
        ValueError: if the number of position rows differs from
            ``mol.GetNumAtoms()``.
    """
    num_atoms = mol.GetNumAtoms()
    if len(positions) != num_atoms:
        raise ValueError(
            f"got {len(positions)} positions for a molecule with {num_atoms} atoms"
        )

    conformer = Chem.Conformer(num_atoms)
    for atom_idx in range(num_atoms):
        row = positions[atom_idx]
        # float() accepts 0-d tensors, numpy scalars and plain Python numbers alike.
        conformer.SetAtomPosition(
            atom_idx, Point3D(float(row[0]), float(row[1]), float(row[2]))
        )
    # assignId=True gives the conformer (max existing id) + 1, so ids never
    # collide even if earlier conformers were removed non-contiguously.
    return mol.AddConformer(conformer, assignId=True)

In [29]:
# Inspect the type of the RDKit molecule attached to a sample. Use an explicit
# index: `i` is only defined by the conversion loop further down, so relying
# on it fails on Restart & Run All.
type(samples_defnoise[0].mol)

rdkit.Chem.rdchem.Mol

In [26]:
# Empty RDKit molecule (scratch). The module-level `mol` is not read by any
# other cell shown in this notebook.
mol = Chem.Mol()

In [9]:
def write_conformer(path, mol, name, root="result"):
    """Write every conformer of ``mol`` to ``<root>/<path>/<name>.sdf``.

    Each conformer becomes one record in the SD file, in the order returned
    by ``mol.GetConformers()``. The output directory is created if it does
    not exist.

    Args:
        path: sub-directory under ``root`` (e.g. one per noise setting).
        mol: RDKit molecule whose conformers are written.
        name: file stem for the ``.sdf`` file; must be a ``str``.
        root: top-level output directory. Defaults to ``"result"``, the
            previously hard-coded location.
    """
    out_dir = os.path.join(root, path)
    os.makedirs(out_dir, exist_ok=True)
    writer = Chem.SDWriter(os.path.join(out_dir, name + ".sdf"))
    try:
        # Use the real conformer ids rather than assuming they are 0..n-1.
        for conformer in mol.GetConformers():
            writer.write(mol, confId=conformer.GetId())
    finally:
        # Always flush/close the file, even if a write fails.
        writer.close()

In [10]:
def _comparison_mol(template, reference_pos, baseline_pos, ours_pos):
    """Return a deep copy of ``template`` holding exactly three new conformers.

    Conformers are added in this order: ``reference_pos`` (the sample's own
    stored geometry), ``baseline_pos`` (baseline model's denoised positions)
    and ``ours_pos`` (our model's denoised positions).
    """
    comparison = copy.deepcopy(template)
    comparison.RemoveAllConformers()
    for positions in (reference_pos, baseline_pos, ours_pos):
        make_conformer(comparison, positions)
    return comparison


# For each sample, denoise the three noisy variants with both models and write
# one SD file per noise setting containing reference, baseline and ours.
for i in range(len(samples_8noise)):
    sample_defnoise = samples_defnoise[i]
    sample_8noise = samples_8noise[i]
    sample_rdnoise = samples_rdnoise[i]
    # All comparison molecules reuse the default-noise sample's molecule.
    mol_copy = copy.deepcopy(sample_defnoise.mol)
    name = sample_8noise.name

    # Both models return a 3-tuple. Only the middle element is used; it is
    # subtracted from the noisy input positions below, i.e. treated as the
    # predicted positional noise.
    _, base_noise_def, _ = model_baseline_def(sample_defnoise.z, sample_defnoise.pos)
    _, base_noise_8, _ = model_baseline_def(sample_8noise.z, sample_8noise.pos)
    _, base_noise_rdkit, _ = model_baseline_def(sample_rdnoise.z, sample_rdnoise.pos)

    _, ours_noise_def, _ = model_ours(sample_defnoise.z, sample_defnoise.pos)
    _, ours_noise_8, _ = model_ours(sample_8noise.z, sample_8noise.pos)
    _, ours_noise_rdkit, _ = model_ours(sample_rdnoise.z, sample_rdnoise.pos)

    # Reference geometry: the conformer stored on the sample's RDKit molecule.
    # Reuse the sample fetched above instead of indexing the dataset again.
    gpos = sample_8noise.mol.GetConformer().GetPositions()

    # Denoised positions = noisy input positions - predicted noise.
    base_def_pred = sample_defnoise.pos - base_noise_def
    # Our model's noise (previously the baseline's was used here by mistake).
    ours_def_pred = sample_defnoise.pos - ours_noise_def

    base_8_pred = sample_8noise.pos - base_noise_8
    ours_8_pred = sample_8noise.pos - ours_noise_8

    base_rdk_pred = sample_rdnoise.pos - base_noise_rdkit
    ours_rdk_pred = sample_rdnoise.pos - ours_noise_rdkit

    mol_defnoise = _comparison_mol(mol_copy, gpos, base_def_pred, ours_def_pred)
    mol_8noise = _comparison_mol(mol_copy, gpos, base_8_pred, ours_8_pred)
    mol_rdnoise = _comparison_mol(mol_copy, gpos, base_rdk_pred, ours_rdk_pred)

    write_conformer("Default_noise", mol_defnoise, name)
    write_conformer("noise_0.8", mol_8noise, name)
    write_conformer("rdkit_noise", mol_rdnoise, name)

    

    
    
    

In [23]:
# Render conformer id 2 of `mol_defnoise`, the last molecule built by the
# conversion loop.
IPythonConsole.drawMol3D(
    mol_defnoise,
    confId=2,
)

In [20]:
def asd(samples, noises):
    """Replace each sample's conformers with predicted, reference and input geometries.

    For the i-th sample, all conformers on ``sample.mol`` are removed. Three
    new ones are then added, in this order:

    1. ``sample.pos - noises[i]`` — positions after removing the predicted noise;
    2. ``sample.pos - sample.pos_target`` — presumably the clean geometry,
       assuming ``pos_target`` holds the noise that was applied (TODO confirm);
    3. ``sample.pos`` — the input positions.

    Mutates ``sample.mol`` in place and returns None.

    Args:
        samples: iterable of samples with ``mol``, ``pos`` and ``pos_target``.
        noises: indexable with one predicted-noise entry per sample, aligned
            by position with ``samples``.
    """
    for i, sample in enumerate(samples):
        noise = noises[i]
        sample.mol.RemoveAllConformers()
        predpos = sample.pos - noise
        gpos = sample.pos - sample.pos_target
        initpos = sample.pos
        make_conformer(sample.mol, predpos)
        make_conformer(sample.mol, gpos)
        make_conformer(sample.mol, initpos)

In [23]:
# Count the conformers on `sample.mol`.
# NOTE(review): no cell shown here defines a module-level `sample`; it seems to
# be leftover kernel state (possibly from an earlier top-level version of the
# `asd` loop). This cell fails on Restart & Run All — bind `sample` explicitly
# or remove the cell.
sample.mol.GetNumConformers()

3

In [28]:
# Render conformer id 2 of `sample.mol` in 3D.
# NOTE(review): relies on a module-level `sample` that no cell shown here
# defines (leftover kernel state) — fails on Restart & Run All.
IPythonConsole.drawMol3D(sample.mol,confId=2)