# Get 3D metrics for the constrained embedded poses

In [None]:
import os
import glob
import sys
from typing import Union, Dict, Tuple, Optional, List

import numpy as np
import pandas as pd
from rdkit import Chem, RDLogger

sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))
from utils.calc_sc_rdkit import calc_SC_RDKit_score
from utils.metrics_3d import lig_protein_clash_dist, lig_protein_clash_vdw, calc_torsion_energy, mcs_rmsd

In [None]:
# Register tqdm's notebook progress bar with pandas; this is what makes
# `DataFrame.progress_apply` available further down.
# `tqdm._tqdm_notebook` is a deprecated private alias that warns on import;
# `tqdm.notebook` is the public module exporting the same `tqdm_notebook`.
from tqdm.notebook import tqdm_notebook
tqdm_notebook.pandas()

In [None]:
# Raise the RDKit logger threshold to CRITICAL so that lower-severity
# messages are not printed while the SD files are read.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

In [None]:
# Run configuration: both values are interpolated into the file and folder
# names used by the cells below.
pdb = '7ZNT'  # structure identifier (presumably a PDB ID — confirm)
method = 'base'  # label of the sampling run whose output is analysed

## Load Data

### Generated data

In [None]:
# Locate and load the CSV of generated molecules for the configured
# pdb / method combination.
gen_folder = 'data/generated'
gen_csv_path = os.path.join(gen_folder, f'{pdb}_sampled_{method}_valid_fil.csv')
# glob returns an empty list when the file is absent; the list form is kept
# because it is printed for inspection.
filepath = glob.glob(gen_csv_path)
print(filepath)
if not filepath:
    # Fail with the expected path instead of a bare IndexError on filepath[0].
    raise FileNotFoundError(f'generated data not found: {gen_csv_path}')
df = pd.read_csv(filepath[0])
df.head()

#### Add constrained embedded pose to the data

In [None]:
# Folder holding the constrained-embedding results of this pdb / method run,
# located under the current user's home directory.
const_embed_path = os.path.join(
    os.path.expanduser('~'),
    'Documents',
    'data',
    'protacs',
    'preprint_data',
    'const_embed',
    f'{pdb}_{method}_embed',
)
# glob doubles as an existence check: the list is empty when the
# 'selected_min_confs' sub-folder is missing.
selected_confs_pattern = os.path.join(const_embed_path, 'selected_min_confs')
const_embed_pose_folder = glob.glob(selected_confs_pattern)
print(const_embed_pose_folder)

In [None]:
def get_embedded_confs(mol_id: str, folderpath: str) -> pd.Series:
    """Look up the constrained-embedding pose of one molecule.

    Searches ``folderpath`` for an SD file whose name ends in
    ``__{mol_id}.sdf`` and reads its first record.

    Args:
        mol_id: Identifier used as the file-name suffix.
        folderpath: Directory containing the pose files.

    Returns:
        A Series with the keys ``vinardo`` (the ``minimizedAffinity``
        property as float, NaN when it cannot be read), ``embedded_mol``
        (the molecule read from the file, or a missing value) and
        ``embedded_path`` (base name of the matched file, or a missing
        value when no file matched).
    """
    pattern = os.path.join(folderpath, f'*__{mol_id}.sdf')
    # Sorted so the chosen file is reproducible should several files match.
    matches = sorted(glob.glob(pattern))
    if not matches:
        return pd.Series({'vinardo': np.nan, 'embedded_mol': None, 'embedded_path': None})

    sdf_file = matches[0]
    supplier = Chem.SDMolSupplier(sdf_file)
    # An empty file has no first record, and a record that cannot be parsed
    # is returned as None; both are treated as "no pose" instead of raising.
    mol = supplier[0] if len(supplier) > 0 else None
    if mol is not None and mol.HasProp('minimizedAffinity'):
        vinardo = float(mol.GetProp('minimizedAffinity'))
    else:
        vinardo = np.nan
    return pd.Series({
        'vinardo': vinardo,
        'embedded_mol': mol,
        'embedded_path': os.path.basename(sdf_file),
    })

In [None]:
# Look up the embedded pose of every row (by its ID) and store the score,
# molecule and file name in three columns; progress_apply shows a progress bar.
pose_cols = ['vinardo', 'embedded_mol', 'embedded_path']
df[pose_cols] = df.progress_apply(
    lambda row: get_embedded_confs(row['ID'], const_embed_pose_folder[0]),
    axis=1,
)

In [None]:
# Keep only the molecules for which an embedded pose was loaded.  The explicit
# copy makes df_fil independent of df, so the in-place edits applied to df_fil
# further down do not operate on a slice of df.
df_fil = df[~df['embedded_mol'].isna()].copy()
n_total = len(df)
n_failed = n_total - len(df_fil)
# Guard the percentage against an empty input frame (reported as 0 %).
failed_pct = n_failed / n_total * 100 if n_total else 0.0
print(f'failed embedding: {failed_pct:.4f}%')

In [None]:
# Number of molecules that have an embedded pose.
len(df_fil)

### Xtal references

In [None]:
# Crystal-structure reference files of the configured pdb entry.
xtal_folder = 'data/xtal_poses'
pdb_folder = os.path.join(xtal_folder, pdb, f'{pdb}_fragments')
xtal_protein_path = os.path.join(pdb_folder, f'{pdb}_protein.pdb')
# Reference molecules are read in the order protac, anchor, warhead.  The
# linker files ('{pdb}_linker.sdf', '{pdb}_linker_extended.sdf') are
# deliberately not loaded.
# NOTE(review): verify that none of the three molecules is None before use.
xtal_protac, xtal_anchor, xtal_warhead = (
    Chem.MolFromMolFile(os.path.join(pdb_folder, f'{pdb}_{part}.sdf'))
    for part in ('protac', 'anchor', 'warhead')
)

## Get 3D metrics

In [None]:
# Correct the misspelled metric columns ('ramds_*' -> 'rmsd_*') in place,
# first in the filtered frame and then in the full frame.
typo_fixes = {'ramds_anc': 'rmsd_anc', 'ramds_wrh': 'rmsd_wrh'}
for frame in (df_fil, df):
    frame.rename(columns=typo_fixes, inplace=True)

In [None]:
# Only calculate metrics for rows that do not have them yet.
# Both RMSD columns must exist before filtering on 'rmsd_wrh'; checking only
# 'rmsd_anc' raised a KeyError when 'rmsd_wrh' was absent.
rmsd_cols = ('rmsd_anc', 'rmsd_wrh')
if all(col_name in df_fil.columns for col_name in rmsd_cols):
    # Copy so the metric columns assigned below are written to an independent
    # frame instead of a slice of df_fil.
    df_calc = df_fil[df_fil['rmsd_wrh'].isna()].copy()
else:
    df_calc = df_fil.copy()
len(df_calc)

### RMSD

In [None]:
# RMSD of each embedded pose against the crystal anchor and warhead
# (computed by the project helper `mcs_rmsd`).
# Applying over the 'embedded_mol' column instead of row-wise over the frame
# keeps the result a Series even when df_calc is empty (nothing left to
# calculate); a row-wise apply on an empty frame returns a DataFrame, which
# cannot be assigned to a single column.
df_calc['rmsd_anc'] = df_calc['embedded_mol'].apply(lambda mol: mcs_rmsd(mol, xtal_anchor))
df_calc['rmsd_wrh'] = df_calc['embedded_mol'].apply(lambda mol: mcs_rmsd(mol, xtal_warhead))
print(f'average anchor RMSD: {df_calc["rmsd_anc"].mean(skipna=True):.4f}')
print(f'average warhead RMSD: {df_calc["rmsd_wrh"].mean(skipna=True):.4f}')

### SC RDKit

In [None]:
# SC RDKit score of each embedded pose against the crystal PROTAC
# (computed by the project helper `calc_SC_RDKit_score`).
# Column-wise apply so that an empty df_calc yields an empty Series rather
# than a DataFrame that cannot be assigned to one column.
df_calc['sc_rdkit'] = df_calc['embedded_mol'].apply(
    lambda mol: calc_SC_RDKit_score(mol, xtal_protac)
)
print(df_calc['sc_rdkit'].describe())

In [None]:
# Mean SC RDKit score of the rows calculated in this run (NaN values skipped).
print(f'average SC_RDKIT: {df_calc["sc_rdkit"].mean(skipna=True):.4f}')

### Clashes with protein

In [None]:
# Ligand–protein clash metrics for every pose file, evaluated against the
# crystal protein (presumably distance-cutoff based and van-der-Waals based
# respectively, judging by the helper names — confirm in utils.metrics_3d).
pose_dir = const_embed_pose_folder[0]
# Build each full pose path once and reuse it for both metrics.  Working on
# the column keeps the result a Series even when df_calc is empty.
pose_files = df_calc['embedded_path'].apply(lambda name: os.path.join(pose_dir, name))
df_calc['clashes_cutoff'] = pose_files.apply(
    lambda sdf: lig_protein_clash_dist(xtal_protein_path, sdf)
)
df_calc['clashes_vdw'] = pose_files.apply(
    lambda sdf: lig_protein_clash_vdw(xtal_protein_path, sdf)
)
print(f'average clashes_cutoff: {df_calc["clashes_cutoff"].mean(skipna=True):.4f}')
print(f'average clashes_vdw: {df_calc["clashes_vdw"].mean(skipna=True):.4f}')

### Torsion energy

In [None]:
# Torsion energy of every pose file (computed by the project helper
# `calc_torsion_energy` from the SD file path).
# Column-wise apply so that an empty df_calc yields an empty Series rather
# than a DataFrame that cannot be assigned to one column.
pose_dir = const_embed_pose_folder[0]
df_calc['E_torsion'] = df_calc['embedded_path'].apply(
    lambda name: calc_torsion_energy(os.path.join(pose_dir, name))
)
print(f'average E_torsion: {df_calc["E_torsion"].mean(skipna=True):.4f}')

## Format output and combine results

In [None]:
# Replace the respective rows in df_fil with df_calc, matched on the row index.
# Assigning a frame through .loc aligns it to df_fil's existing columns, so
# metric columns that exist only in df_calc would be dropped silently; create
# them first so the freshly calculated values are actually carried over.
for new_col in [c for c in df_calc.columns if c not in df_fil.columns]:
    df_fil[new_col] = np.nan
df_fil.loc[df_calc.index, df_calc.columns] = df_calc

In [None]:
# Sanity check: row counts of the filtered, full and newly calculated frames.
len(df_fil), len(df), len(df_calc)

In [None]:
# Show the columns currently present in df_fil.
df_fil.columns
# use df_fil if something was already calculated before! (it must contain those rows so that they are replaced by df_calc)

In [None]:
# Columns produced by this notebook that are copied back onto df.
add_cols = ['E_torsion',
  'clashes_cutoff',
  'clashes_vdw',
  'embedded_mol',
  'embedded_path',
  'rmsd_anc',
  'rmsd_wrh',
  'sc_rdkit',
  'vinardo']
# Copy the values onto df by matching on ID.  df_calc only holds the rows
# calculated in this run, so the mapped values are combined with what df
# already contains; a plain overwrite would replace results of an earlier run
# with NaN for every row that was not recalculated.
calc_by_id = df_calc.set_index('ID')  # built once instead of once per column
for col in add_cols:
    if col == 'embedded_mol':
        # Removed from df right below, so there is nothing to copy.
        continue
    mapped = df['ID'].map(calc_by_id[col])
    df[col] = mapped.combine_first(df[col]) if col in df.columns else mapped
# drop embedded_mol column so the molecule objects are not part of the saved table
df = df.drop(columns=['embedded_mol'])

In [None]:
# Preview the combined table.
df.head()

In [None]:
# (number of rows without an anchor RMSD, total number of rows)
len(df[df.rmsd_anc.isnull()]), len(df)

In [None]:
# Write the combined table back to the generated-data folder, using the same
# file name the input was read from (the input file is overwritten).
out_csv_path = os.path.join(gen_folder, f'{pdb}_sampled_{method}_valid_fil.csv')
df.to_csv(out_csv_path, index=False)

### Summary of metrics

In [None]:
# Report the mean of each metric over the full table, skipping NaN values.
summary_metrics = (
    ('anchor RMSD', 'rmsd_anc'),
    ('warhead RMSD', 'rmsd_wrh'),
    ('SC_RDKIT', 'sc_rdkit'),
    ('clashes_vdw', 'clashes_vdw'),
    ('E_torsion', 'E_torsion'),
)
for metric_label, metric_col in summary_metrics:
    print(f'average {metric_label}: {df[metric_col].mean(skipna=True):.4f}')