In [1]:
import lmdb
import pickle
from utils.pdb_parser import PDBProtein
import os.path as osp
import sys
# sys.path.append('/home/haotian/Molecule_Generation/MG/Flex-SBDD')
from tqdm import tqdm
from utils.data import ProteinLigandData, torchify_dict
from utils.protein_ligand import parse_sdf_file, read_ply
from utils.chem import read_sdf, read_pkl
import argparse
import torch
import numpy as np
from torch_geometric.transforms import Compose
from rdkit import Chem
from glob import glob
from utils.chem import write_pkl, read_pkl, write_sdf

from easydict import EasyDict
from rdkit.Chem.rdMolAlign import CalcRMS

def get_result(docked_sdf, ref_mol=None):
    suppl = Chem.SDMolSupplier(docked_sdf,sanitize=False)
    results = []
    for i, mol in enumerate(suppl):
        if mol is None:
            continue
        line = mol.GetProp('REMARK').splitlines()[0].split()[2:]
        try:
            rmsd = CalcRMS(ref_mol, mol)
        except:
            rmsd = np.nan
        results.append(EasyDict({
            'rdmol': mol,
            'mode_id': i,
            'affinity': float(line[0]),
            'rmsd_lb': float(line[1]),
            'rmsd_ub': float(line[2]),
            'rmsd_ref': rmsd
        }))
    return results

def checkatoms(mol, allowed_atoms=['C', 'N', 'O', 'F', 'P', 'S', 'Cl']):
    """
    """
    for atom in mol.GetAtoms():
        if atom.GetSymbol() not in allowed_atoms:
            return False
    return True

## create the protein-liagand pair files

In [2]:
surface_file = './causal_inference/4fny_protein_pocket_8.0.ply'
ligand_files = glob('./causal_inference/SDF/*_out.sdf')

index_list = []
for ligand_file in ligand_files:
    try:
        mol = read_sdf(ligand_file)[0]
        mol = Chem.RemoveHs(mol)
        # write_sdf([mol], ligand_file) # write the sdf file without H, if you encounter the atomic error
        if not checkatoms(mol):
            continue
        mol_impact = - torch.tensor(get_result(ligand_file)[0]['affinity'])
        index_list.append((surface_file, ligand_file, mol_impact)) # change the label to your own
    except:
        pass
print('Get the',len(index_list), 'mols')
    
write_pkl(index_list, './causal_inference/index_list.pkl')

Get the 2695 mols
pkl file saved at ./casual_inference/index_list.pkl


## create the LMDB database

In [5]:
index = read_pkl('./causal_inference/index_list.pkl')

processed_path = './data/causal_inference_data.lmdb' 
db = lmdb.open(
    processed_path,
    map_size=10*(1024*1024*1024),   # 10GB
    create=True,
    subdir=False,
    readonly=False, # Writable
)

num_skipped = 0
with db.begin(write=True, buffers=True) as txn:
    for i, (ply_file, sdf_file, mol_impact) in enumerate(tqdm(index)):
        if ply_file is None: continue
        try:
            pocket_dict = read_ply(ply_file)
            ligand_dict = parse_sdf_file(sdf_file)
            data = ProteinLigandData.from_protein_ligand_dicts(
                protein_dict=torchify_dict(pocket_dict),
                ligand_dict=torchify_dict(ligand_dict),
            )
            data.mol_impact = mol_impact # change the label to your own
            data.protein_filename = ply_file
            data.ligand_filename = sdf_file
            data.mol = ligand_dict['mol']
            txn.put(
                key = str(i).encode(),
                value = pickle.dumps(data)
            )
        except:
            num_skipped += 1
            if num_skipped%100 == 0:
                print('Skipping (%d)' % (num_skipped, ))
            continue
db.close()


100%|██████████| 2695/2695 [01:25<00:00, 31.68it/s]


## prepare the train data

In [9]:
from utils.transforms import *

config_file = './configs/causual_inference.yml'
config = load_config(config_file)

protein_featurizer = FeaturizeProteinAtom()
ligand_featurizer = FeaturizeLigandAtom()                   
masking = get_mask(config.train.transform.mask)
composer = AtomComposer(protein_featurizer.feature_dim, ligand_featurizer.feature_dim, config.model.encoder.knn)

edge_sampler = EdgeSample(config.train.transform.edgesampler)
cfg_ctr = config.train.transform.contrastive
contrastive_sampler = ContrastiveSample(cfg_ctr.num_real, cfg_ctr.num_fake, cfg_ctr.pos_real_std, cfg_ctr.pos_fake_std, config.model.field.knn)

transform = Compose([
    RefineData(),
    LigandCountNeighbors(),
    protein_featurizer,
    ligand_featurizer,
    masking,
    composer,

    FocalBuilder(),
    edge_sampler,
    contrastive_sampler,
])

In [7]:
name2id_path = './data/causal_inference_data_name2id.pt'
processed_path = './data/causal_inference_data.lmdb'

db = lmdb.open(
        processed_path,
        map_size=10*(1024*1024*1024),   # 10GB
        create=False,
        subdir=False,
        readonly=True,
        lock=False,
        readahead=False,
        meminit=False,
    )
with db.begin() as txn:
    keys = list(txn.cursor().iternext(values=False))
    
name2id = {}
for i in tqdm(range(len(keys))):
    try:

        for _ in range(20):
            data = transform(pickle.loads(db.begin().get(keys[i])))

        name = (data['protein_filename'], data['ligand_filename'])
        name2id[name] = i
    except Exception as e:
        print(i, e)
        continue

torch.save(name2id, name2id_path)
print('saved name2id at {}'.format(name2id_path))
print('Get',len(name2id.keys()), 'mols')

  0%|          | 0/2695 [00:00<?, ?it/s]

492 23
530 37
795 4
1050 25
1179 28
2580 empty range for randrange() (0, 0, 0)
saved name2id at ./data/causal_inference_data_name2id.pt
Get 2689 mols


In [8]:
name2id = torch.load(name2id_path)
name2id_list = list(name2id.keys())

split_name = {}
split_name['train'] = []
for i in range(len(name2id.keys())):
    split_name['train'].append(name2id_list[i])
split_name['val'] = split_name['train'][:20] 
torch.save(split_name, './data/causal_inference_data_split_name.pt')