# Shape alignment
Get the chamfer distance between the surfaces of the generated molecules and the crystal structure molecule used as query. This is mainly done to:
1. Assess how well the RL worked
2. Get the stereochemistry fitting the shape best

Recommendations:
* Use a GPU
* Environment to use: `shape_align`

In [None]:
import os
import sys

import torch
from tqdm.notebook import tqdm
import numpy as np
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem import AllChem, PandasTools
from rdkit import RDLogger
import copy
from unidip import UniDip

old_cwd = Path.cwd()
os.chdir(Path.cwd().parent)
from structural import models, molecule
from structural.molecule import Molecules, MoleculeInfo
os.chdir(old_cwd)

sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))
from utils.chem_transforms import set_stereo2query

In [None]:
# Raise RDKit's logger threshold to CRITICAL for the rest of the notebook.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

## Load Data

In [None]:
# Identifier of the crystal structure; used below to build input and output file names.
pdb = '6BOY'
# Tag of the sampling run; part of the generated-molecules CSV and pose-folder names.
method = 'shape'

### Load data to score

In [None]:
# Folder holding the sampled molecules; the CSV name is derived from `pdb` and `method`.
gen_folder = 'data/generated'
sampled_csv = os.path.join(gen_folder, f'{pdb}_sampled_{method}_valid.csv')
df = pd.read_csv(sampled_csv)
# preview of the first rows (cell output)
df.head()

In [None]:
# Keep one row per unique extended-linker SMILES, then pull out the SMILES and
# their IDs as two parallel lists.
df_dedupl = df.drop_duplicates(subset=['extended_linker_smiles'])
gen_smiles, gen_ids = (
    df_dedupl[column].values.tolist()
    for column in ('extended_linker_smiles', 'ID')
)

### Load model

In [None]:
root = os.path.dirname(os.path.dirname(os.getcwd()))
# Choose the device first so the checkpoint can be mapped onto it while loading;
# without map_location a checkpoint saved from a GPU cannot be loaded on a
# CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
model = torch.load(
    os.path.join(root, "models/protacdb_extlinker_model_align.pth"),
    map_location=device,
)
model.to(device)
# switch to inference mode (cell output is the model repr)
model.eval()

### Align and save pose
`Define paths and SMILES of query`

In [None]:
# Locations of the crystal-structure fragments for the selected PDB entry.
xtal_folder = 'data/xtal_poses'
pdb_folder = os.path.join(xtal_folder, pdb, f'{pdb}_fragments')
query_path = os.path.join(pdb_folder, f'{pdb}_linker_extended.sdf')

# Query for the alignment: first record of the extended-linker SDF, passed on as a mol block.
query_supplier = Chem.SDMolSupplier(query_path)
query_block = Chem.MolToMolBlock(query_supplier[0])
query = MoleculeInfo.from_molblock(query_block)

# RDKit molecules of the crystal extended linker, full PROTAC and bare linker.
xtal_ext_linker = Chem.MolFromMolFile(query_path)
xtal_protac, xtal_linker = (
    Chem.MolFromMolFile(os.path.join(pdb_folder, f'{pdb}_{part}.sdf'))
    for part in ('protac', 'linker')
)

In [None]:
class AlignCompare():
    '''
    Aligns generated extended linkers onto a query and scores the resulting poses.

    For every generated SMILES the query's ``align_to_multiconformer_smiles_fast2``
    is called, the returned pose is written to ``pose_folder`` and two values are
    recorded: the chamfer distance reported by the alignment and the RMSD between
    the atoms flanking the linker in the pose and the matching atoms in the
    crystal PROTAC.
    '''

    def __init__(self,
        df,
        model,
        query,
        xtal_protac,
        xtal_linker,
        xtal_ext_linker,
        pose_folder,
        device=None):
        '''
        :param df: dataframe with at least the columns ``ID`` and ``linker_smiles``
        :param model: model handed to the query's alignment method
        :param query: object providing ``align_to_multiconformer_smiles_fast2``
        :param xtal_protac: RDKit molecule of the crystal PROTAC
        :param xtal_linker: RDKit molecule of the crystal linker
        :param xtal_ext_linker: RDKit molecule of the crystal extended linker
        :param pose_folder: folder the aligned poses are written to
        :param device: device passed to the alignment; when None the
            notebook-level ``device`` variable is used
        '''
        self.df = df
        self.model = model
        self.query = query
        self.xtal_protac = xtal_protac
        self.xtal_linker = xtal_linker
        self.xtal_ext_linker = xtal_ext_linker
        self.pose_folder = pose_folder
        self.device = device


    def get_attachment_frags_protac(self, useChirality=True):
        '''
        Returns the all potential indices of the attachment fragments next to the linker in a given molecule.

        One index list is produced for every (extended linker match, linker match)
        pair found in the crystal PROTAC: the atoms of the extended-linker match
        that are not part of the linker match.
        '''
        match_indices_ext_linker = self.xtal_protac.GetSubstructMatches(self.xtal_ext_linker, useChirality=useChirality)
        match_indices_linker = self.xtal_protac.GetSubstructMatches(self.xtal_linker, useChirality=useChirality)
        assert match_indices_ext_linker, f'no match found for: {Chem.MolToSmiles(self.xtal_ext_linker)}'
        assert match_indices_linker, f'no match found for: {Chem.MolToSmiles(self.xtal_linker)}'
        # sets make the membership test below O(1) per atom
        linker_atom_sets = [set(match) for match in match_indices_linker]
        frag_indices_combo = []
        for match_idx_ext in match_indices_ext_linker:
            for linker_atoms in linker_atom_sets:
                frag_indices_combo.append([idx for idx in match_idx_ext if idx not in linker_atoms])

        return frag_indices_combo

    def get_attachment_frags_linker(self, mol, linker):
        '''
        Returns the attachment points of a linker in a given molecule

        For every substructure match of ``linker`` (hydrogens removed) in ``mol``
        the list of all atom indices of ``mol`` outside that match is returned.
        '''
        linker = Chem.RemoveAllHs(linker)
        match_indices = mol.GetSubstructMatches(linker, useChirality=True)
        assert match_indices, f'no match found for: {Chem.MolToSmiles(linker)}'
        all_atom_idx = [a.GetIdx() for a in mol.GetAtoms()]
        frag_indices_combo = []
        for match_idx in match_indices:
            matched = set(match_idx)
            frag_indices_combo.append([idx for idx in all_atom_idx if idx not in matched])

        return frag_indices_combo

    def get_frags(self, mol, indices):
        '''
        Returns a copy of ``mol`` reduced to the atoms listed in ``indices``.

        NOTE: ``mol`` itself is kekulized in place, so callers pass a copy.
        '''
        keep = set(indices)
        # remove from the highest index down so the remaining indices stay valid
        remove_atoms = sorted(
            (a.GetIdx() for a in mol.GetAtoms() if a.GetIdx() not in keep),
            reverse=True,
        )
        Chem.Kekulize(mol)
        mol_red = Chem.RWMol(mol)
        for idx in remove_atoms:
            mol_red.RemoveAtom(idx)
        return mol_red.GetMol()

    def get_correct_indices(self, ext_linker, indices_protac, indices_linker):
        '''
        Finds the first pair of candidate fragments (one from the generated
        extended linker, one from the crystal PROTAC) that are substructures of
        each other.

        Side effect: atom map numbers (atom index + 1) are set on ``self.xtal_protac``.

        :return: tuple ``(indices in ext_linker, matching atom indices in the
            crystal PROTAC)``, or None when no candidate pair matches
        '''
        # tag every crystal atom so its original index survives the fragment extraction
        for a in self.xtal_protac.GetAtoms():
            a.SetAtomMapNum(a.GetIdx()+1)
        for idx1 in indices_linker:
            frag_linker = self.get_frags(copy.deepcopy(ext_linker), idx1)
            for idx2 in indices_protac:
                frag_protac = self.get_frags(copy.deepcopy(self.xtal_protac), idx2)
                if frag_protac.HasSubstructMatch(frag_linker) and frag_linker.HasSubstructMatch(frag_protac):
                    # fragment atom index -> atom index in the full crystal PROTAC
                    frag_to_protac_idx = {
                        a.GetIdx(): a.GetAtomMapNum()-1 for a in frag_protac.GetAtoms()
                    }
                    matches = frag_protac.GetSubstructMatch(frag_linker)
                    # translate the matched fragment atoms back to crystal PROTAC indices
                    mapnums = [frag_to_protac_idx[idx] for idx in matches]
                    return idx1, mapnums
        return None

    def get_coords(self, mol, att_points):
        '''
        Returns the coordinates of the attachment points
        '''
        conformer = mol.GetConformer()
        return [conformer.GetAtomPosition(idx) for idx in att_points]

    def calc_rmsd(self, coords1, coords2):
        '''
        Calculates the RMSD between two sets of coordinates. 
        Need to be in matching order.

        Each coordinate must expose ``x``, ``y`` and ``z`` attributes.
        '''
        squared_sum = 0
        for i, point in enumerate(coords1):
            other = coords2[i]
            squared_sum += (point.x - other.x)**2 + (point.y - other.y)**2 + (point.z - other.z)**2
        rmsd = np.sqrt(squared_sum/len(coords1))

        return rmsd

    def calc_RMSD_att_frags(self, posepath, ID):
        '''
        Calculates the RMSD between the attachment points of a protac and an extended linker

        :param posepath: mol file of the aligned extended-linker pose
        :param ID: value of the ``ID`` column identifying the molecule in ``self.df``
        :return: RMSD, or NaN when the attachment atoms cannot be matched
        '''
        extended_linker_gen = Chem.MolFromMolFile(posepath)
        # use the dataframe handed to the constructor, not a notebook global
        linker_smi = self.df[self.df.ID == ID]['linker_smiles'].values[0]
        linker_gen = Chem.MolFromSmiles(linker_smi)
        try:
            # get attachment points
            gen_atts = self.get_attachment_frags_linker(extended_linker_gen, linker_gen)
            ori_atts = self.get_attachment_frags_protac(useChirality=True)
            matched = self.get_correct_indices(extended_linker_gen, ori_atts, gen_atts)
            if matched is None:
                return np.nan
            gen_att, ori_att = matched
            if len(gen_att) != len(ori_att):
                return np.nan
            # get Coordinate positions of attachment points
            gen_coords = self.get_coords(extended_linker_gen, gen_att)
            ori_coords = self.get_coords(self.xtal_protac, ori_att)
            # calc RMSD between attachment points
            rmsd = self.calc_rmsd(gen_coords, ori_coords)
        except Exception:
            # best effort: any RDKit/matching failure is reported as a missing RMSD
            rmsd = np.nan
        return rmsd

    def _align(self, smiles, target_device):
        '''
        Runs the alignment for one SMILES; on failure retries once with
        ``addhs_in_post=True``. Returns None when both attempts raise.
        '''
        for extra_kwargs in ({}, {'addhs_in_post': True}):
            try:
                return self.query.align_to_multiconformer_smiles_fast2(
                    smiles, self.model, device=target_device,
                    number_of_conformers=50, es_weight=0, **extra_kwargs)
            except Exception:
                # `except Exception` keeps Ctrl-C / kernel interrupt working in this long loop
                continue
        return None

    def align_and_save(self, gen_smiles, gen_ids, smiles_distances):
        '''
        Aligns every SMILES to the query, writes the pose to
        ``<pose_folder>/<id>_pose.mol`` and stores ``(chamfer distance, RMSD)``
        under the SMILES key.

        :param gen_smiles: SMILES to align
        :param gen_ids: IDs matching ``gen_smiles`` position by position
        :param smiles_distances: dict updated in place (existing entries for the
            given SMILES are overwritten)
        :return: the updated ``smiles_distances``
        '''
        # fall back to the notebook-level `device` when none was given to the constructor
        target_device = self.device if self.device is not None else device
        for smiles, gen_id in tqdm(zip(gen_smiles, gen_ids), total=len(gen_smiles)):
            alignment = self._align(smiles, target_device)
            if alignment is None:
                cmf_dist, rmsd = np.nan, np.nan
            else:
                cmf_dist = alignment.chamfer_distance
                pose = alignment.molecule_2
                posepath = os.path.join(self.pose_folder, f'{gen_id}_pose.mol')
                pose.write_to_file(posepath)
                rmsd = self.calc_RMSD_att_frags(posepath, gen_id)
            smiles_distances[smiles] = (cmf_dist, rmsd)
        return smiles_distances

In [None]:
# SMILES -> (chamfer distance, RMSD); filled by the aligner
smiles_distances = {}
smiles2id = dict(zip(gen_smiles, gen_ids))
pose_folder = os.path.join(gen_folder, f'{pdb}_{method}_aligned_poses')
os.makedirs(pose_folder, exist_ok=True)
aligner = AlignCompare(df, model, query, xtal_protac, xtal_linker, xtal_ext_linker, pose_folder)
smiles_distances = aligner.align_and_save(gen_smiles, gen_ids, smiles_distances)
rmsds = [smiles_distances[smiles][1] for smiles in gen_smiles]
# np.msort is deprecated (NumPy 1.24) and removed in NumPy 2.0; np.sort is
# equivalent for this 1-D input
rmsds_sorted = np.sort(rmsds)
# remove nans
rmsds_sorted = rmsds_sorted[~np.isnan(rmsds_sorted)]
# intervals found by UniDip on the sorted RMSDs, used as index pairs below
intervals = UniDip(rmsds_sorted).run()
try:
    # midpoint between the end of the first interval and the start of the last one
    split_point = (rmsds_sorted[intervals[0][1]] + rmsds_sorted[intervals[-1][0]]) / 2
except IndexError:
    # fall back to the midpoint of the first interval
    split_point = (rmsds_sorted[intervals[0][0]] + rmsds_sorted[intervals[0][1]]) / 2
# get all smiles with an RMSD above the split point
gen_smiles_above = [smiles for smiles in gen_smiles if smiles_distances[smiles][1] > split_point]
gen_ids_above = [smiles2id[smiles] for smiles in gen_smiles_above]
# Re-align the molecules above the split point until none is left above it
# (a molecule whose re-alignment yields NaN is dropped from the retries).
# NOTE(review): this loop only ends once every retried molecule falls below the
# split point or fails; consider capping the number of rounds.
while gen_ids_above:
    smiles_distances = aligner.align_and_save(gen_smiles_above, gen_ids_above, smiles_distances)
    gen_smiles_above = [smiles for smiles in gen_smiles_above if smiles_distances[smiles][1] > split_point and not np.isnan(smiles_distances[smiles][1])]
    gen_ids_above = [smiles2id[smiles] for smiles in gen_smiles_above]

In [None]:
# map the chamfer distances to the smiles in df
df['chamfer_distance'] = df['extended_linker_smiles'].map(lambda x: smiles_distances[x][0])
df['rmsd'] = df['extended_linker_smiles'].map(lambda x: smiles_distances[x][1])

### Set chiral tags based on aligned pose

In [None]:
from itertools import islice

# Count the PROTACs that have more than two enumerated stereoisomers and
# remember their IDs.
extra_stereo = 0
ids = []
for _, row in df.iterrows():
    m = Chem.MolFromSmiles(row['protac_smiles'])
    # Only "more than two?" matters, so stop the enumeration after the third
    # isomer instead of materialising all of them.
    n_isomers = sum(1 for _ in islice(Chem.EnumerateStereoisomers.EnumerateStereoisomers(m), 3))
    if n_isomers > 2:
        extra_stereo += 1
        ids.append(row['ID'])
len(df), extra_stereo

In [None]:
def get_linker_stereo(linker_ext: Chem.Mol, linker_unassigned: Chem.Mol) -> Chem.Mol:
    '''
    Returns a linker with stereocenters set based on query.

    For each substructure match of the unassigned linker in the extended-linker
    pose, the atoms outside the match are stripped. Atoms outside the match that
    are single-bonded to a matched atom are kept, replaced by hydrogens and then
    removed with ``Chem.RemoveHs``. The first match for which this succeeds is
    returned.

    :param linker_ext: extended linker pose
    :param linker_unassigned: unassigned linker molecule
    :return: linker with stereocenters set, or None if no match could be processed
    '''
    atom_idx_keep_all = linker_ext.GetSubstructMatches(linker_unassigned)
    for atom_idx_keep in atom_idx_keep_all:
        linker_ext_copy = copy.deepcopy(linker_ext)
        keep = set(atom_idx_keep)
        outside_match = [atom.GetIdx() for atom in linker_ext_copy.GetAtoms() if atom.GetIdx() not in keep]

        # atoms outside the match that are single-bonded to a matched atom
        onestep_idx = set()
        for a_idx in outside_match:
            for b_idx in atom_idx_keep:
                bond = linker_ext_copy.GetBondBetweenAtoms(a_idx, b_idx)
                if bond is not None and bond.GetBondType() == Chem.rdchem.BondType.SINGLE:
                    onestep_idx.add(a_idx)

        # delete from the highest index down so the remaining indices stay valid
        atom_idx_remove = sorted(set(outside_match) - onestep_idx, reverse=True)
        linker_ext_copy = Chem.RWMol(linker_ext_copy)

        # remember the original index of every atom; indices shift once atoms are removed
        for a in linker_ext_copy.GetAtoms():
            a.SetAtomMapNum(a.GetIdx())
        for extra_atom_idx in atom_idx_remove:
            linker_ext_copy.RemoveAtom(extra_atom_idx)
        # turn the kept neighbour atoms into hydrogens
        for a in linker_ext_copy.GetAtoms():
            mapnum = a.GetAtomMapNum()
            if mapnum in onestep_idx:
                linker_ext_copy.ReplaceAtom(a.GetIdx(), Chem.Atom(1))
        try:
            linker_ext_copy = Chem.RemoveHs(linker_ext_copy)
        except Exception:
            # this match could not be cleaned up; try the next one
            continue
        for a in linker_ext_copy.GetAtoms():
            a.SetAtomMapNum(0)
        return linker_ext_copy
    return None

In [None]:
# extended-linker SMILES -> linker / PROTAC SMILES with stereo taken from the pose;
# the SMILES already in df are used whenever the stereo cannot be derived
stereo_linkers = {}
stereo_protacs = {}
unsuccesful = []
for ext_lin_smi, idx in tqdm(zip(gen_smiles, gen_ids), total=len(gen_smiles)):
    filepath = os.path.join(pose_folder, f'{idx}_pose.mol')
    row = df[df.ID == idx].iloc[0]
    if not os.path.isfile(filepath):
        print(f'no file found: {idx}')
        stereo_linkers[ext_lin_smi] = row['linker_smiles']
        stereo_protacs[ext_lin_smi] = row['protac_smiles']
        continue

    mol = Chem.MolFromMolFile(filepath)
    try:
        # necessary step to extract bond stereochemistry
        pdb_path = os.path.join(pose_folder, f'{idx}_pose.pdb')
        Chem.MolToPDBFile(mol, pdb_path)
        mol = Chem.MolFromPDBFile(pdb_path)
        Chem.rdmolops.DetectBondStereoChemistry(mol, mol.GetConformer())
        Chem.rdmolops.AssignStereochemistry(mol, cleanIt=True, force=True)
        # save sdf (for torsional strain input)
        sd_writer = Chem.SDWriter(os.path.join(pose_folder, f'{idx}_pose.sdf'))
        try:
            sd_writer.write(mol)
        finally:
            # close explicitly so the file is flushed and the handle released
            sd_writer.close()
    except Exception:
        unsuccesful.append(idx)
        print(f'Could not get stereo for {idx}')
        stereo_linkers[ext_lin_smi] = row['linker_smiles']
        stereo_protacs[ext_lin_smi] = row['protac_smiles']
    else:
        try:
            stereo_linker_only = get_linker_stereo(mol, Chem.RemoveAllHs(Chem.MolFromSmiles(row['linker_smiles'])))
            stereo_linkers[ext_lin_smi] = Chem.MolToSmiles(stereo_linker_only)
        except Exception:
            stereo_linkers[ext_lin_smi] = row['linker_smiles']
        try:
            protac_mol = Chem.MolFromSmiles(row['protac_smiles'])
            stereo_protac = set_stereo2query(protac_mol, mol)
            stereo_protacs[ext_lin_smi] = Chem.MolToSmiles(stereo_protac)
        except Exception:
            stereo_protacs[ext_lin_smi] = row['protac_smiles']

In [None]:
# map the stereo smiles to the smiles in df
df['linker_smiles'] = df['extended_linker_smiles'].map(lambda x: stereo_linkers[x])
df['protac_smiles'] = df['extended_linker_smiles'].map(lambda x: stereo_protacs[x])

In [None]:
from itertools import islice

# Re-count the PROTACs that still have more than two enumerated stereoisomers
# after the stereo assignment, and list their IDs.
extra_stereo = 0
ids = []
for _, row in df.iterrows():
    m = Chem.MolFromSmiles(row['protac_smiles'])
    # Only "more than two?" matters, so stop the enumeration after the third
    # isomer instead of materialising all of them.
    n_isomers = sum(1 for _ in islice(Chem.EnumerateStereoisomers.EnumerateStereoisomers(m), 3))
    if n_isomers > 2:
        extra_stereo += 1
        ids.append(row['ID'])
# print separately: inside the result tuple print() only contributed a stray None
print(ids)
len(df), extra_stereo

In [None]:
# Write the annotated table back to the CSV it was loaded from (the file is overwritten).
df.to_csv(os.path.join(gen_folder, f'{pdb}_sampled_{method}_valid.csv'), index=False)

## Analyze results

In [None]:
# Reload the scored table from disk; describe() is the cell's displayed output.
df = pd.read_csv(os.path.join(gen_folder, f'{pdb}_sampled_{method}_valid.csv'))
df.describe()

### Chamfer distance

In [None]:
# Histogram of the chamfer distances (100 bins).
df.chamfer_distance.hist(bins=100)

In [None]:
def fraction2threshold(df, threshold, column, above=True):
    '''
    Returns the rows of a dataframe with a score above/below a given threshold,
    together with the percentage of all rows they represent.

    The comparison is inclusive (``>=`` / ``<=``). Rows whose score is NaN never
    pass the comparison but still count towards the total.

    :param df: dataframe to filter
    :param threshold: score cutoff
    :param column: name of the score column
    :param above: keep scores >= threshold if True, scores <= threshold otherwise
    :return: tuple (filtered dataframe, percentage of rows kept); the percentage
        is 0.0 for an empty dataframe
    '''
    scores = df[column]
    selected = df[scores >= threshold] if above else df[scores <= threshold]
    total = len(df)
    if total == 0:
        # avoid dividing by zero when there is nothing to score
        return selected, 0.0
    metric = len(selected)/total*100
    return selected, metric

In [None]:
# Percentage of molecules with a chamfer distance <= 3.5
df_scored_fil, cutoff_fraction = fraction2threshold(df, 3.5, 'chamfer_distance', above=False)
cutoff_fraction

In [None]:
# Percentage of molecules with a chamfer distance <= 3.0
df_scored_fil, cutoff_fraction = fraction2threshold(df, 3.0, 'chamfer_distance', above=False)
cutoff_fraction

In [None]:
# Percentage of molecules with a chamfer distance <= 2.0
df_scored_fil, cutoff_fraction = fraction2threshold(df, 2.0, 'chamfer_distance', above=False)
cutoff_fraction

### RMSD

In [None]:
# Histogram of the attachment-fragment RMSDs (100 bins).
df.rmsd.hist(bins=100)

In [None]:
# Percentage of molecules with an RMSD <= 1.0
df_scored_fil, cutoff_fraction = fraction2threshold(df, 1.0, 'rmsd', above=False)
cutoff_fraction

In [None]:
# Percentage of molecules with an RMSD <= 2.0
df_scored_fil, cutoff_fraction = fraction2threshold(df, 2.0, 'rmsd', above=False)
cutoff_fraction

### Correlation chamfer distance to RMSD

In [None]:
# plot correlation between rmsd and chamfer distance
plt.scatter(df.rmsd, df.chamfer_distance, s=1);
plt.xlabel('RMSD');
plt.ylabel('Chamfer distance');