# 3D metrics based on the surface distance

In [9]:
import os
import sys
import glob

from pathlib import Path
import torch
import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit import RDLogger, DataStructs
from rdkit.Chem import PandasTools
from rdkit.Chem.MolStandardize import rdMolStandardize
import copy


# The project-local `shape_alignment` package is imported while the working
# directory is temporarily switched to the parent folder (presumably so the
# import resolves from there -- confirm against the repository layout).
old_cwd = Path.cwd()
os.chdir(old_cwd.parent)
try:
    from shape_alignment.molecule import Molecules, MoleculeInfo
finally:
    # Always restore the working directory, even if the import fails; otherwise
    # re-running this cell would climb one directory higher each time.
    os.chdir(old_cwd)

# Make the folder three levels above the working directory importable so that
# `utils` can be found; skip the append when the cell is re-run.
utils_root = os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd())))
if utils_root not in sys.path:
    sys.path.append(utils_root)
from utils.chem_transforms import remove_atom_indices, replace_atom_indices

In [10]:
# Import from the public `tqdm.notebook` module; `tqdm._tqdm_notebook` is a
# deprecated private shim around it.
from tqdm.notebook import tqdm_notebook
# Register `progress_apply` on pandas objects.
tqdm_notebook.pandas()

In [11]:
# Only let RDKit messages of CRITICAL level through for the rest of the session.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

## Load data

In [12]:
# Run configuration: both values are interpolated into the input/output paths below.
pdb = '7JTO'
# 'difflinker' selects the generated-pose branches; any other value selects the shape-aligned ones.
method = 'difflinker'

In [13]:
# Folder (relative to the working directory) holding the generated-molecule CSV files.
gen_folder = 'data/generated'
# Root folder of the pose files, located under the current user's home directory.
poses_root = os.path.join(os.path.expanduser('~'), 'data/protacs/preprint')
# The pattern contains no wildcard, so glob returns a one-element list when the
# file exists and an empty list otherwise.
filepath = glob.glob(os.path.join(gen_folder, f'{pdb}_sampled_{method}_valid_fil.csv'))
print(filepath)
# NOTE: filepath[0] raises IndexError when the filtered CSV is missing.
df = pd.read_csv(filepath[0])
df.head()

['data/generated/7JTO_sampled_difflinker_valid_fil.csv']


Unnamed: 0,ID,reference,lig_id,protac_smiles,linker_smiles,anchor_smiles,warhead_smiles,anchor_ev,warhead_ev,POI,...,embedded_path,rmsd_anc,rmsd_wrh,sc_rdkit,vinardo,cd_protac_embed2xtal,cd_protac_method2xtal,tanimoto_ptc,cd_aligend_linker_embed2aligned,cd_aligned_linker_embed2aligned
0,7JTO_difflinker_678,7JTO,MS33,Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O...,CCCCCC(=O)OCOCCOC,Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O...,CN1CCN(c2ccc(-c3cccc(CN4CCNCC4)c3)cc2NC(=O)C2=...,Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O...,CN1CCN(c2ccc(-c3cccc(CN4CCN([*:2])CC4)c3)cc2NC...,WDR5,...,selected_min_conf6__7JTO_difflinker_678.sdf,0.423307,0.400898,0.788413,-17.00314,1.152685,1.108036,0.964643,5.571145,5.692617
1,7JTO_difflinker_1584,7JTO,MS33,Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O...,CCCCCC(=O)N[C@H](C)CCCC,Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O...,CN1CCN(c2ccc(-c3cccc(CN4CCNCC4)c3)cc2NC(=O)C2=...,Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O...,CN1CCN(c2ccc(-c3cccc(CN4CCN([*:2])CC4)c3)cc2NC...,WDR5,...,selected_min_conf2__7JTO_difflinker_1584.sdf,0.404619,0.547204,0.757546,-18.05297,1.250658,1.151019,0.998622,3.292739,3.470201
2,7JTO_difflinker_3051,7JTO,MS33,Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O...,CCOCCC[C@@H]1CCC[C@H]1CC=O,Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O...,CN1CCN(c2ccc(-c3cccc(CN4CCNCC4)c3)cc2NC(=O)C2=...,Cc1ncsc1-c1ccc(CNC(=O)[C@@H]2C[C@@H](O)CN2C(=O...,CN1CCN(c2ccc(-c3cccc(CN4CCN([*:2])CC4)c3)cc2NC...,WDR5,...,selected_min_conf1__7JTO_difflinker_3051.sdf,0.399479,0.538391,0.775668,-18.3847,1.204676,1.194152,0.973136,2.414808,2.486691


### Get constrained embedded conformers

In [15]:
# Locate the folder with the selected constrained-embedding conformers for this pdb/method.
pose_folders = os.path.join(poses_root, 'const_embed')
const_embed_path = os.path.join(pose_folders, f'{pdb}_{method}_embed')
# glob without a wildcard: a one-element list if the folder exists, [] otherwise.
# get_embedded_confs indexes element 0 of this list.
const_embed_pose_folder = glob.glob(os.path.join(const_embed_path,'selected_min_confs'))
print(const_embed_pose_folder)

['/home/rebeccaneeser/data/protacs/preprint/const_embed/7JTO_difflinker_embed/selected_min_confs']


In [16]:
def get_embedded_confs(sdfpath: str,) -> "Chem.Mol | None":
    """Load the first record of a constrained-embedding SDF file.

    Args:
        sdfpath: File name of the SDF, joined onto ``const_embed_pose_folder[0]``.
            May be a missing value (``None`` or any NaN) for rows without a
            conformer.

    Returns:
        The first entry of ``Chem.SDMolSupplier`` for that file, or ``None``
        when ``sdfpath`` is missing.
    """
    # pd.isna recognises every NaN object as well as None, whereas the identity
    # test `is np.nan` only matches the np.nan singleton itself.
    if pd.isna(sdfpath):
        return None
    full_path = os.path.join(const_embed_pose_folder[0], sdfpath)
    return Chem.SDMolSupplier(full_path)[0]

In [17]:
# Load the constrained-embedded conformer of every row into a new 'embedded_mol' column.
df['embedded_mol'] = df.progress_apply(lambda x: get_embedded_confs(x['embedded_path']), axis=1)

  0%|          | 0/3 [00:00<?, ?it/s]

In [18]:
# (number of rows with an embedded conformer, total number of rows)
len(df[~df['embedded_mol'].isna()]), len(df)

(3, 3)

### Get poses respective to method (shape aligned or generated)

In [19]:
def get_shape_aligned_pose(sdffolder: str, mol_id: str) -> "Chem.Mol | None":
    """Load the shape-aligned pose of one molecule.

    Tries ``<mol_id>_pose.sdf`` first and falls back to ``<mol_id>_pose.mol``.

    Args:
        sdffolder: Folder containing the pose files.
        mol_id: Identifier used as the file-name prefix.

    Returns:
        The molecule read from the first file that loads without raising, or
        ``None`` when both attempts raise.
    """
    # `except Exception` keeps the best-effort fallback but, unlike a bare
    # `except:`, lets KeyboardInterrupt/SystemExit stop a long-running apply.
    try:
        molpath = os.path.join(sdffolder, f'{mol_id}_pose.sdf')
        return Chem.SDMolSupplier(molpath)[0]
    except Exception:
        pass  # fall through to the .mol file
    try:
        molpath = os.path.join(sdffolder, f'{mol_id}_pose.mol')
        return Chem.MolFromMolFile(molpath)
    except Exception:
        return None

def get_diff_gen_pose(folder: str, sdfpath: str) -> "Chem.Mol | None":
    """Load the generated pose referenced by ``sdfpath``.

    Only the file name of ``sdfpath`` (the part after the last ``/``) is used;
    the file is looked up in ``<folder>/<pdb>/`` where ``pdb`` is the
    notebook-level setting.

    Args:
        folder: Root folder of the generated conformers.
        sdfpath: Original path of the generated SDF; may be ``None`` or NaN.

    Returns:
        The first record of the SDF, or ``None`` when ``sdfpath`` is missing.
    """
    # pd.isna covers None and every NaN object, not only the np.nan singleton.
    if pd.isna(sdfpath):
        return None
    sdffile = sdfpath.split('/')[-1]
    relocated_path = os.path.join(folder, pdb, sdffile)
    return Chem.SDMolSupplier(relocated_path)[0]

In [20]:
# Start from the unfiltered table of valid molecules; only the pose loader
# differs between methods.
df_all = pd.read_csv(os.path.join(gen_folder, f'{pdb}_sampled_{method}_valid.csv'))
if method == 'difflinker':
    # DiffLinker: poses come directly from the generated conformer files.
    gen_folder_path = os.path.join(poses_root, 'difflinker_gen_confs')
    df_all['method_mol'] = df_all.progress_apply(lambda x: get_diff_gen_pose(gen_folder_path, x['ori_gen_ptc_filename']), axis=1)
else:
    # Other methods: poses come from the shape-alignment output folder.
    shape_align_path = os.path.join(gen_folder, f'{pdb}_{method}_aligned_poses')
    df_all['method_mol'] = df_all.progress_apply(lambda x: get_shape_aligned_pose(shape_align_path, x['ID']), axis=1)

# Columns of the filtered table that are carried over to the unfiltered one.
col_insert = [
    'E_torsion',
    'clashes_cutoff',
    'clashes_vdw',
    'embedded_mol',
    'embedded_path',
    'rmsd_anc',
    'rmsd_wrh',
    'sc_rdkit',
    'vinardo',
]
# Index the filtered table by ID once instead of once per column.
df_by_id = df.set_index('ID')
for col in col_insert:
    # transfer columns from df to df_all by mapping ID; IDs absent from df become NaN
    df_all[col] = df_all['ID'].map(df_by_id[col])
df = df_all

  0%|          | 0/3 [00:00<?, ?it/s]

In [21]:
# (total rows, rows with a method pose, rows with an embedded conformer)
len(df), len(df[~df['method_mol'].isna()]), len(df[~df['embedded_mol'].isna()])

(3, 3, 3)

### Get xtal poses

In [22]:
# Crystal-structure reference files live in data/xtal_poses/<pdb>/<pdb>_fragments/.
xtal_folder = 'data/xtal_poses'
pdb_folder = os.path.join(xtal_folder, pdb, f'{pdb}_fragments')
xtal_protein_path = os.path.join(pdb_folder, f'{pdb}_protein.pdb')
# Extended-linker reference is currently not loaded:
# xtal_ext_linker = Chem.MolFromMolFile(os.path.join(pdb_folder,f'{pdb}_linker_extended.sdf'))
# Reference poses of the full PROTAC and of its linker / anchor / warhead parts.
xtal_protac = Chem.MolFromMolFile(os.path.join(pdb_folder, f'{pdb}_protac.sdf'))
xtal_linker = Chem.MolFromMolFile(os.path.join(pdb_folder, f'{pdb}_linker.sdf'))
xtal_anchor = Chem.MolFromMolFile(os.path.join(pdb_folder, f'{pdb}_anchor.sdf'))
xtal_warhead = Chem.MolFromMolFile(os.path.join(pdb_folder, f'{pdb}_warhead.sdf'))

### Get linker only from poses

In [23]:
def linker_from_extlinker(aligned_ext_linker, protac_smi, wrh_smi, anc_smi, linker_smi):
    """Cut the bare linker out of a posed extended linker.

    The extended linker is the linker plus some atoms of the warhead and the
    anchor. The linker substructure match that is kept is the first one whose
    removal leaves two fragments with as many atoms as the extended linker
    shares with the warhead and with the anchor inside the full PROTAC.

    Args:
        aligned_ext_linker: Posed extended-linker molecule, or ``None``.
        protac_smi: SMILES of the full PROTAC.
        wrh_smi: SMILES of the warhead.
        anc_smi: SMILES of the anchor.
        linker_smi: SMILES of the linker.

    Returns:
        The linker cut from ``aligned_ext_linker``, with the non-linker atoms
        bonded to it handed to ``replace_atom_indices`` (presumably turned into
        attachment-point placeholders -- confirm in ``utils.chem_transforms``).
        ``None`` if the input pose is ``None``, the linker SMILES cannot be
        processed, or no suitable match exists (a message is printed then).

    Note:
        Atom map numbers of ``aligned_ext_linker`` are used as temporary
        labels and reset to 0 before returning.
    """
    if aligned_ext_linker is None:
        return None
    protac_mol = Chem.MolFromSmiles(protac_smi)
    wrh_mol = Chem.MolFromSmiles(wrh_smi)
    anc_mol = Chem.MolFromSmiles(anc_smi)
    linker_mol = Chem.MolFromSmiles(linker_smi)
    try:
        linker_mol = Chem.RemoveAllHs(linker_mol)
    except Exception:
        return None

    # Number of atoms the extended linker shares with warhead and anchor in the PROTAC.
    ext_link_atoms = set(protac_mol.GetSubstructMatch(aligned_ext_linker))
    wrh_atoms = set(protac_mol.GetSubstructMatch(wrh_mol))
    anc_atoms = set(protac_mol.GetSubstructMatch(anc_mol))
    num_frag_wrh = len(wrh_atoms & ext_link_atoms)
    num_frag_anc = len(anc_atoms & ext_link_atoms)
    # The two leftover fragments may come in either order.
    expected_sizes = {(num_frag_wrh, num_frag_anc), (num_frag_anc, num_frag_wrh)}

    true_match = None
    for match in aligned_ext_linker.GetSubstructMatches(linker_mol):
        # remove_atom_indices is called with indices in descending order
        match = sorted(match, reverse=True)
        try:
            leftovers = remove_atom_indices(aligned_ext_linker, match)
            ext_frags = Chem.GetMolFrags(leftovers, asMols=True)
            frag_sizes = (len(ext_frags[0].GetAtoms()), len(ext_frags[1].GetAtoms()))
        except Exception:
            # e.g. fewer than two fragments left: this match cannot be the linker
            continue
        if frag_sizes in expected_sizes:
            true_match = match
            break
    if true_match is None:
        print('No match found')
        return None

    # Sets give O(1) membership tests in the bond loop below.
    to_keep = set(true_match)
    to_remove = {a.GetIdx() for a in aligned_ext_linker.GetAtoms()} - to_keep
    ev_indices = []
    # Label atoms with their original index (+1, since 0 means "no map number")
    # so they can be found again after atoms have been removed.
    for a in aligned_ext_linker.GetAtoms():
        a.SetAtomMapNum(a.GetIdx()+1)
    # Non-linker atoms bonded to the linker are kept and remembered.
    for b in aligned_ext_linker.GetBonds():
        begin, end = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        if begin in to_remove and end in to_keep:
            ev_indices.append(begin)
            to_remove.discard(begin)
        elif end in to_remove and begin in to_keep:
            ev_indices.append(end)
            to_remove.discard(end)

    linker_pose_aligned = remove_atom_indices(aligned_ext_linker, sorted(to_remove, reverse=True))
    # Translate the remembered original indices into indices of the cut molecule.
    a_dict = {a.GetAtomMapNum()-1: a.GetIdx() for a in linker_pose_aligned.GetAtoms()}
    ev_indices = sorted((a_dict[i] for i in ev_indices), reverse=True)
    linker_pose_aligned = replace_atom_indices(linker_pose_aligned, ev_indices)
    # Clear the temporary labels on both the result and the caller's molecule.
    for a in linker_pose_aligned.GetAtoms():
        a.SetAtomMapNum(0)
    for a in aligned_ext_linker.GetAtoms():
        a.SetAtomMapNum(0)
    try:
        linker_pose_aligned = Chem.RemoveHs(linker_pose_aligned)
    except Exception:
        # best effort: return the linker with its hydrogens if they cannot be removed
        pass
    return linker_pose_aligned


In [24]:
# Shape-aligned methods: the method pose is an extended linker, so cut the bare linker out of it.
if method != 'difflinker':
    df['method_linker_mol'] = df.progress_apply(lambda x: linker_from_extlinker(x['method_mol'], x['protac_smiles'], x['warhead_smiles'], x['anchor_smiles'], x['linker_smiles']), axis=1)

In [25]:
def linker_from_fullpose(protac_pose, wrh_smi, anc_smi):
    """Cut the linker out of a posed full PROTAC.

    Atoms matched by the warhead or anchor SMILES are removed, except those
    directly bonded to a remaining atom; these are kept and handed to
    ``replace_atom_indices`` (presumably turned into attachment-point
    placeholders -- confirm in ``utils.chem_transforms``).

    Args:
        protac_pose: Posed PROTAC molecule, or a missing value (``None``/NaN).
        wrh_smi: SMILES of the warhead.
        anc_smi: SMILES of the anchor.

    Returns:
        The linker part of ``protac_pose``, or ``None`` when the pose is
        missing. Hydrogens are not stripped here, unlike in
        ``linker_from_extlinker``.

    Note:
        Atom map numbers of ``protac_pose`` are used as temporary labels and
        reset to 0 before returning.
    """
    # Missing poses arrive as None or as a float NaN; test the NaN by value
    # because not every NaN object is the np.nan singleton.
    if protac_pose is None or (isinstance(protac_pose, float) and np.isnan(protac_pose)):
        return None
    wrh_mol = Chem.MolFromSmiles(wrh_smi)
    anc_mol = Chem.MolFromSmiles(anc_smi)
    match_wrh = protac_pose.GetSubstructMatch(wrh_mol)
    match_anc = protac_pose.GetSubstructMatch(anc_mol)
    # NOTE(review): a failed match yields an empty tuple, in which case nothing
    # is removed for that fragment -- confirm this is the intended fallback.
    # Sets give O(1) membership tests in the bond loop below.
    to_remove = set(match_wrh) | set(match_anc)
    to_keep = {a.GetIdx() for a in protac_pose.GetAtoms()} - to_remove
    ev_indices = []
    # Label atoms with their original index (+1, since 0 means "no map number")
    # so they can be found again after atoms have been removed.
    for a in protac_pose.GetAtoms():
        a.SetAtomMapNum(a.GetIdx()+1)
    # Fragment atoms bonded to the linker are kept and remembered.
    for b in protac_pose.GetBonds():
        begin, end = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        if begin in to_remove and end in to_keep:
            ev_indices.append(begin)
            to_remove.discard(begin)
        elif end in to_remove and begin in to_keep:
            ev_indices.append(end)
            to_remove.discard(end)
    # remove_atom_indices is called with indices in descending order
    embedded_linker = remove_atom_indices(protac_pose, sorted(to_remove, reverse=True))
    # Translate the remembered original indices into indices of the cut molecule.
    a_dict = {a.GetAtomMapNum()-1: a.GetIdx() for a in embedded_linker.GetAtoms()}
    ev_indices = sorted((a_dict[i] for i in ev_indices), reverse=True)
    embedded_linker = replace_atom_indices(embedded_linker, ev_indices)
    # Clear the temporary labels on both the result and the caller's molecule.
    for a in embedded_linker.GetAtoms():
        a.SetAtomMapNum(0)
    for a in protac_pose.GetAtoms():
        a.SetAtomMapNum(0)
    return embedded_linker

In [26]:
# Linker part of every constrained-embedded PROTAC conformer.
df['embedded_linker'] = df.progress_apply(lambda x: linker_from_fullpose(x['embedded_mol'], x['warhead_smiles'], x['anchor_smiles']), axis=1)

  0%|          | 0/3 [00:00<?, ?it/s]

In [27]:
# DiffLinker: the method pose is a full PROTAC, so the linker is cut from the full pose.
if method == 'difflinker':
    df['method_linker_mol'] = df.progress_apply(lambda x: linker_from_fullpose(x['method_mol'], x['warhead_smiles'], x['anchor_smiles']), axis=1)

  0%|          | 0/3 [00:00<?, ?it/s]

## Load model

In [28]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the checkpoint path from `poses_root` (already rooted at the current
# user's home directory) instead of hard-coding one user's home folder.
model_path = os.path.join(poses_root, 'shape_align', 'protacdb_extlinker_model_align.pth')
# SECURITY NOTE: torch.load unpickles the whole model object and can execute
# arbitrary code; only load checkpoints from a trusted source.
model = torch.load(model_path, map_location=device)
model.to(device)
model.eval()

INFO - 2023-02-17 00:01:02,591 - instantiator - Created a temporary directory at /tmp/tmpm8zwc8ix
INFO - 2023-02-17 00:01:02,599 - instantiator - Writing /tmp/tmpm8zwc8ix/_remote_module_non_sriptable.py


PCRSingleMasked(
  (coarse): PCRBaseMasked(
    (lin_in): Linear(in_features=3, out_features=16, bias=True)
    (attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=16, out_features=16, bias=True)
    )
    (cross_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=16, out_features=16, bias=True)
    )
    (lin_out): Linear(in_features=16, out_features=3, bias=True)
  )
)

## Metrics

### Chamfer distance between new PROTAC conf and xtal conf
New PROTAC conf is embedded and/or directly generated (in case of DiffLinker)

Also used to calculate the similarity ratio after reverse min-max scaling over all samples.

In [29]:
# Chamfer distance of every constrained-embedded PROTAC conformer to the crystal PROTAC pose.
query_protac_xtal = MoleculeInfo.from_sdf(os.path.join(pdb_folder, f'{pdb}_protac.sdf'))
# isinstance(..., Chem.Mol) rejects both None and any NaN placeholder, whereas
# the identity test `is not np.nan` only recognises the np.nan singleton.
df['cd_protac_embed2xtal'] = df.progress_apply(
    lambda x: query_protac_xtal.get_chamfer_distance(
        MoleculeInfo.from_rdkit_mol(x['embedded_mol']), device=device)
    if isinstance(x['embedded_mol'], Chem.Mol) else None,
    axis=1)
print('Chamfer distance between xtal and const embed: ', df['cd_protac_embed2xtal'].mean(skipna=True))

  0%|          | 0/3 [00:00<?, ?it/s]

Chamfer distance between xtal and const embed:  1.2042349576950073


In [30]:
# Chamfer distance of every constrained-embedded linker to the crystal linker pose.
query_linker_xtal = MoleculeInfo.from_sdf(os.path.join(pdb_folder, f'{pdb}_linker.sdf'))
# isinstance(..., Chem.Mol) rejects both None and any NaN placeholder, whereas
# the identity test `is not np.nan` only recognises the np.nan singleton.
df['cd_linker_embed2xtal'] = df.progress_apply(
    lambda x: query_linker_xtal.get_chamfer_distance(
        MoleculeInfo.from_rdkit_mol(x['embedded_linker']), device=device)
    if isinstance(x['embedded_linker'], Chem.Mol) else None,
    axis=1)
print('Chamfer distance between xtal and const embed linker: ', df['cd_linker_embed2xtal'].mean(skipna=True))

  0%|          | 0/3 [00:00<?, ?it/s]

Chamfer distance between xtal and const embed linker:  2.3114216327667236


In [31]:
if method == 'difflinker':
    # Chamfer distance of the generated poses (full PROTAC, then linker only)
    # to the corresponding crystal poses. isinstance(..., Chem.Mol) rejects
    # both None and any NaN placeholder, not only the np.nan singleton.
    df['cd_protac_method2xtal'] = df.progress_apply(
        lambda x: query_protac_xtal.get_chamfer_distance(
            MoleculeInfo.from_rdkit_mol(x['method_mol']), device=device)
        if isinstance(x['method_mol'], Chem.Mol) else None,
        axis=1)
    print('Chamfer distance between xtal and generated pose: ', df['cd_protac_method2xtal'].mean(skipna=True))

    df['cd_linker_method2xtal'] = df.progress_apply(
        lambda x: query_linker_xtal.get_chamfer_distance(
            MoleculeInfo.from_rdkit_mol(x['method_linker_mol']), device=device)
        if isinstance(x['method_linker_mol'], Chem.Mol) else None,
        axis=1)
    print('Chamfer distance between xtal and generated linker pose: ', df['cd_linker_method2xtal'].mean(skipna=True))

  0%|          | 0/3 [00:00<?, ?it/s]

Chamfer distance between xtal and generated pose:  1.1533020337422688


  0%|          | 0/3 [00:00<?, ?it/s]

Chamfer distance between xtal and generated linker pose:  2.2213172117869058


### Tanimoto similarity between new PROTAC and xtal reference

In [32]:
# Fingerprint similarity (DataStructs.FingerprintSimilarity on Chem.RDKFingerprint
# fingerprints) between each generated PROTAC and the crystal-structure PROTAC.
PandasTools.AddMoleculeColumnToFrame(df, smilesCol='protac_smiles', molCol='mol_smi')
df.loc[:,'fp_ptc'] = df.mol_smi.progress_apply(Chem.RDKFingerprint)
ori_ptc_fp = Chem.RDKFingerprint(xtal_protac)
df['tanimoto_ptc'] = df.fp_ptc.progress_apply(lambda x: DataStructs.FingerprintSimilarity(x, ori_ptc_fp))

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

### Chamfer distance between methods
* Link-INVENT: between const. embed and shape aligned
* DiffLinker: between const. embed and generated pose

Based on: linker only

In [33]:
# Align each embedded linker onto the method's linker pose with the loaded model
# and store the resulting Chamfer distance; rows lacking either linker get NaN.
# isinstance(..., Chem.Mol) rejects both None and any NaN placeholder, whereas
# the identity test `is not np.nan` only recognises the np.nan singleton.
df['cd_aligned_linker_embed2aligned'] = df.progress_apply(
    lambda x: MoleculeInfo.from_rdkit_mol(x['embedded_linker']).align_to_molecules2(
        Molecules.from_molecule_info([MoleculeInfo.from_rdkit_mol(x['method_linker_mol'])]), model
    ).chamfer_distance
    if isinstance(x['embedded_linker'], Chem.Mol) and isinstance(x['method_linker_mol'], Chem.Mol)
    else np.nan,
    axis=1)
print(df['cd_aligned_linker_embed2aligned'].mean(skipna=True))

  0%|          | 0/3 [00:00<?, ?it/s]

3.851271947224935


In [34]:
def fraction2threshold(df, threshold, column, above=True):
    '''
    Returns the fraction of molecules in a dataframe with a score above/below a given threshold

    Rows whose score in `column` is missing are ignored, both for the selection
    and for the denominator.

    Args:
        df: dataframe holding the scores
        threshold: cut-off value; the comparison is inclusive
        column: name of the score column
        above: keep scores >= threshold if True, scores <= threshold otherwise

    Returns:
        (selected rows, percentage of scored rows that were selected).
        The percentage is NaN when no row has a score, instead of raising
        ZeroDivisionError.
    '''
    scored = df[df[column].notna()]
    if above:
        selected = scored[scored[column] >= threshold]
    else:
        selected = scored[scored[column] <= threshold]
    if len(scored) == 0:
        # a fraction of nothing is undefined
        return selected, float('nan')
    return selected, len(selected) / len(scored) * 100

In [35]:
# Percentage of scored rows whose aligned-linker Chamfer distance is <= 3.5.
df_scored_fil, cutoff_fraction = fraction2threshold(df, 3.5, 'cd_aligned_linker_embed2aligned', above=False)
cutoff_fraction

33.33333333333333

In [36]:
# Percentage of scored rows whose aligned-linker Chamfer distance is <= 2.0.
df_scored_fil, cutoff_fraction = fraction2threshold(df, 2.0, 'cd_aligned_linker_embed2aligned', above=False)
cutoff_fraction

0.0

In [37]:
# Percentage of scored rows whose aligned-linker Chamfer distance is <= 1.0.
df_scored_fil, cutoff_fraction = fraction2threshold(df, 1.0, 'cd_aligned_linker_embed2aligned', above=False)
cutoff_fraction

0.0

### Save data

In [42]:
# Inspect the available columns before choosing which ones to drop for saving.
df.columns

Index(['ID', 'reference', 'lig_id', 'protac_smiles', 'linker_smiles',
       'anchor_smiles', 'warhead_smiles', 'anchor_ev', 'warhead_ev', 'POI',
       'E3', 'gen_filename', 'frags', 'tanimoto', 'qed_linker', 'sa_linker',
       'num_rings_linker', 'num_rot_bonds_linker', 'branched', 'PAINS',
       'ring_arom', 'ori_E_torsion', 'ori_clashes_cutoff', 'ori_clashes_vdw',
       'ori_gen_ptc_filename', 'ori_sc_rdkit', 'to_3d', 'method_mol',
       'E_torsion', 'clashes_cutoff', 'clashes_vdw', 'embedded_mol',
       'embedded_path', 'rmsd_anc', 'rmsd_wrh', 'sc_rdkit', 'vinardo',
       'embedded_linker', 'method_linker_mol', 'cd_protac_embed2xtal',
       'cd_linker_embed2xtal', 'cd_protac_method2xtal',
       'cd_linker_method2xtal', 'mol_smi', 'fp_ptc', 'tanimoto_ptc',
       'cd_aligned_linker_embed2aligned'],
      dtype='object')

In [43]:
# Drop the in-memory molecule and fingerprint columns before writing to CSV.
df_save = df.drop(columns=['method_mol', 'embedded_mol', 'method_linker_mol', 'embedded_linker', 'mol_smi', 'fp_ptc'])

# NOTE: both writes overwrite the CSV files this notebook read its input from.
df_save.to_csv(os.path.join(gen_folder, f'{pdb}_sampled_{method}_valid.csv'), index=False)
# The filtered file keeps only the rows whose 'to_3d' flag is set.
df_fil_save = df_save[df_save['to_3d']]
df_fil_save.to_csv(os.path.join(gen_folder, f'{pdb}_sampled_{method}_valid_fil.csv'), index=False)