# Train a shape alignment model
Use the `shape_align` environment.

In [None]:
import sys
from pathlib import Path

import torch
from pytorch3d.loss import chamfer_distance
from structural.loss import chamfer_distance as cmf
from tqdm.notebook import tqdm
import numpy as np
from pytorch_lightning import Trainer
import pandas as pd
from unidip import UniDip
import unidip.dip as dip
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import RDLogger

In [None]:
# XXX add path to point_cloud_methods repository
# Make the local `structural` package (shipped in the point_cloud_methods
# repository) importable. The append must happen before the `structural`
# imports below.
# NOTE(review): the first cell already runs `from structural.loss import ...`
# before this path is added. That only works if `structural` is importable
# some other way (installed, or found from the working directory). Otherwise,
# move that import below this cell.
shape_align_path = '/path/to/point_cloud_methods/repository'
sys.path.append(shape_align_path)

from structural import models, molecule
from structural.models import PCRSingleMasked, PCRSepFeat
from structural.molecule import Molecules, MoleculeInfo

In [None]:
# Raise RDKit's log threshold to CRITICAL so RDKit warnings and non-critical
# errors are not printed while many SMILES/conformers are processed below.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

In [None]:
# IPython autoreload: reload imported modules (e.g. an edited `structural`
# package) before executing each cell.
%load_ext autoreload
%autoreload 2

In [None]:
# Inputs are read from ./data; everything this notebook writes goes to ./data/shape_align.
data_folder = Path('data')
output_folder = data_folder / 'shape_align'

## Process data

In [None]:
# PROTAC-DB extended linkers; the 'ext_linker_smiles' column is used below.
df_protacdb = pd.read_csv(data_folder / 'protacdb_extended_linkers.csv')
df_protacdb.head()

In [None]:
# PDB systems; the 'PDB' and 'linker_ext_smiles' columns are used below.
df_pdb = pd.read_csv(data_folder / 'pdb_systems_data.csv')
df_pdb.head()

### Make sure case studies are included

In [None]:
# Restrict df_pdb to the case-study systems.
investigated_sys = ['5T35', '7ZNT', '6HAY', '6HAX', '7S4E', '6BN7', '6BOY', '7JTP', '7Q2J', '7JTO']
df_pdb = df_pdb.loc[df_pdb['PDB'].isin(investigated_sys)]
# Case-study extended linkers that do not occur in PROTAC-DB (these get added
# to the SMILES pool in the next cell).
in_protacdb = df_pdb['linker_ext_smiles'].isin(df_protacdb['ext_linker_smiles'])
df_missing = df_pdb.loc[~in_protacdb]
df_missing

In [None]:
# Pool of candidate linkers: all PROTAC-DB extended linkers plus the case-study
# linkers missing from PROTAC-DB, without duplicates.
# dict.fromkeys de-duplicates while keeping first-seen order, so `smiles` (and
# every index into it, including the random draws made from it later) is the
# same on every run. The order of list(set(...)) depends on string hashing,
# which Python randomises per interpreter session.
smiles = df_protacdb['ext_linker_smiles'].tolist()
smiles.extend(df_missing['linker_ext_smiles'].tolist())
smiles = list(dict.fromkeys(smiles))
query_smiles = df_pdb.linker_ext_smiles.tolist()
# position of each case-study linker inside the pool
query_indices = [smiles.index(query_smile) for query_smile in query_smiles]
len(smiles), len(query_smiles)

In [None]:
# Draw the case-study linkers in a grid.
# `Draw` is an RDKit submodule that `from rdkit import Chem` does not load on
# its own, so import it explicitly rather than relying on `Chem.Draw` having
# been loaded as a side effect of some other import.
from rdkit.Chem import Draw

mols = [Chem.MolFromSmiles(smi) for smi in query_smiles]
Draw.MolsToGridImage(mols, molsPerRow=5, subImgSize=(300,200))

### Create training data

In [None]:
def _batches_with_retry(query_id, rest, label, *, batch_num, max_retries, batch_size=16):
    '''
    Build alignment batches of ``smiles[query_id]`` against ``smiles[i]`` for i in ``rest``.

    Conformer generation is stochastic, so a RuntimeError is retried. After more
    than ``max_retries`` consecutive failures the batch size is halved; once it
    drops below 1 the pair is skipped. A ValueError skips the pair immediately.

    Returns whatever ``get_training_batches`` returns (it is appended with
    ``+=``), or an empty list when the pair is skipped.
    '''
    failures = 0
    while True:
        try:
            return MoleculeInfo.from_smiles(smiles[query_id]).get_training_batches(
                [smiles[i] for i in rest], batch_num=batch_num, batch_size=int(batch_size))
        except ValueError as err:
            print(f'{label} of index {query_id}: skipping ({err})')
            return []
        except RuntimeError:  # retrying can fix as dependent on conformer generation (stochastic)
            failures += 1
            if failures > max_retries:
                batch_size = batch_size / 2
                print(f'{label} of index {query_id}: Reducing batch size to {batch_size}')
                failures = 0
                if batch_size < 1:
                    print(f'{label} of index {query_id}: Batch size too small, skipping')
                    return []


training_batches = []
# self-alignment data: each query against 5 copies of itself
for query_id in tqdm(query_indices, total=len(query_indices), desc='Self align queries'):
    for _ in tqdm(range(10), desc='Self align subsets'):
        training_batches += _batches_with_retry(
            query_id, [query_id] * 5, 'Self', batch_num=2, max_retries=10)

# query-vs-others data: each query against 5 pool molecules drawn at random (with replacement)
for query_id in tqdm(query_indices, total=len(query_indices), desc='Others align queries'):
    for _ in tqdm(range(10), desc='Others align subsets'):
        rest = np.random.choice(range(len(smiles)), 5)
        training_batches += _batches_with_retry(
            query_id, rest, 'Others', batch_num=2, max_retries=20)

validation_batches = []
# validation data: each query against a single randomly drawn pool molecule
for query_id in tqdm(query_indices, total=len(query_indices), desc='Validation queries'):
    for _ in tqdm(range(10), desc='Validation subsets'):
        rest = np.random.choice(range(len(smiles)), 1)
        validation_batches += _batches_with_retry(
            query_id, rest, 'Val', batch_num=1, max_retries=10)

In [None]:
# Persist the generated batches so training can be rerun without regenerating them.
batch_filepath = output_folder / 'shape_align_batches.pth'
output_folder.mkdir(parents=True, exist_ok=True)  # torch.save does not create missing directories
torch.save((training_batches, validation_batches), batch_filepath)

In [None]:
# Sanity check: length of the first element of the first training batch
# (presumably the number of samples in the batch; confirm against get_training_batches).
len(training_batches[0][0])

In [None]:
# Wrap the batch lists for Lightning and configure a GPU trainer running 50 epochs.
td, vd = models.DataLoader(training_batches), models.DataLoader(validation_batches)

trainer = Trainer(max_epochs=50, accelerator='gpu')

## Train model

In [None]:
# NOTE(review): the meaning of the positional `3` is defined by PCRSingleMasked in structural.models.
model = PCRSingleMasked(
    3,
    validation_data=validation_batches,
    coarse_nheads=8,
    coarse_attention_dim=16,
)
# RANSAC alignment scores on the validation batches, as a baseline for the model
print("Average RANSAC distance:", model.validation_ransac_distance)
# "improvement over ransac" for validation should be above 1 as an indicator that it's performing well
trainer.fit(model, td, vd)

In [None]:
# Save the whole trained model object (pickled by torch.save).
model_filepath = output_folder / 'protacdb_extlinker_model_align.pth'
output_folder.mkdir(parents=True, exist_ok=True)  # torch.save does not create missing directories
torch.save(model, model_filepath)

## Evaluate model

In [None]:
# Reload the pickled model object, move it to the GPU and switch to inference mode.
# Unpickling needs the `structural` package to be importable.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# rejects whole-module pickles. Pass weights_only=False there if this fails.
model_filepath = output_folder / 'protacdb_extlinker_model_align.pth'
model = torch.load(model_filepath)
model.to("cuda")
model.eval()

In [None]:
# Draw the case-study linkers in a grid.
# `Draw` is an RDKit submodule that `from rdkit import Chem` does not load on
# its own, so import it explicitly rather than relying on `Chem.Draw`.
from rdkit.Chem import Draw

mols = [Chem.MolFromSmiles(smi) for smi in query_smiles]
Draw.MolsToGridImage(mols, molsPerRow=5, subImgSize=(300,200))

In [None]:
# For each query linker SMILES, record the PDB code of the first df_pdb row
# using it and the path of that system's `{PDB}_linker_extended.sdf` file.
pdb_folderpaths = data_folder / 'protac_dataset' / 'dataset'

query_pdb_ids = {}
for query_smile in query_smiles:
    matching_pdbs = df_pdb.loc[df_pdb['linker_ext_smiles'] == query_smile, 'PDB']
    PDB_id = matching_pdbs.values[0]
    fragment_folder = pdb_folderpaths / PDB_id / f'{PDB_id}_fragments'
    sdf_filepath = fragment_folder / f'{PDB_id}_linker_extended.sdf'
    query_pdb_ids[query_smile] = (PDB_id, sdf_filepath)

In [None]:
def get_coords(mol, atom_idxs):
    '''
    Returns the positions of the given atoms, in the order of ``atom_idxs``.

    Positions are read from ``mol.GetConformer()`` (the molecule's default
    conformer), which is fetched once rather than once per atom.
    '''
    conformer = mol.GetConformer()
    return [conformer.GetAtomPosition(idx) for idx in atom_idxs]

def calc_rmsd(coords1, coords2):
    '''
    Calculates the RMSD between two sets of coordinates.

    Points are paired by position (coords1[i] with coords2[i]). No superposition
    or symmetry handling is done here. Each point must expose ``.x``, ``.y`` and
    ``.z`` attributes (as RDKit ``Point3D`` objects do).

    Raises
    ------
    ValueError
        If the two sets differ in length or are empty.
    '''
    if len(coords1) != len(coords2):
        raise ValueError(
            f'coordinate sets differ in length: {len(coords1)} vs {len(coords2)}')
    if len(coords1) == 0:
        raise ValueError('cannot compute the RMSD of empty coordinate sets')
    a = np.array([(p.x, p.y, p.z) for p in coords1], dtype=float)
    b = np.array([(p.x, p.y, p.z) for p in coords2], dtype=float)
    # mean over points of the squared Euclidean distance, then square root
    return np.sqrt(np.mean(np.sum((a - b) ** 2, axis=1)))

def get_rmsd(pose1_path, pose2_path):
    '''
    In-place (no superposition) RMSD between two poses read from MOL/SDF files.

    ``pose2`` is matched as a substructure of ``pose1``. Every atom mapping is
    tried (``uniquify=False``) and the smallest RMSD is returned. This way, poses
    that differ only by swapping symmetry-equivalent atoms are not penalised.
    Files are read with RDKit's default ``MolFromMolFile`` options.

    Raises
    ------
    ValueError
        If either file cannot be parsed or ``pose2`` does not match ``pose1``.
    '''
    pose1 = Chem.MolFromMolFile(pose1_path)
    pose2 = Chem.MolFromMolFile(pose2_path)
    if pose1 is None or pose2 is None:
        bad = pose1_path if pose1 is None else pose2_path
        raise ValueError(f'RDKit could not read a molecule from {bad}')
    matches = pose1.GetSubstructMatches(pose2, uniquify=False)
    if not matches:
        raise ValueError(f'{pose2_path} is not a substructure of {pose1_path}')
    # each match lists, for pose2 atom i, the index of the matching pose1 atom
    pose2_indices = [a.GetIdx() for a in pose2.GetAtoms()]
    pose2_coords = get_coords(pose2, pose2_indices)
    return min(calc_rmsd(get_coords(pose1, match), pose2_coords) for match in matches)

def align_and_save(indices, scores, rmsds, pose_folder, query_pose, model, PDB_id, sdf_filepath, query_smile=None):
    '''
    Self-align a query linker once per repeat index, save each pose and score it.

    For every index i in ``indices``:
    - ``query_pose.align_to_multiconformer_smiles_fast2`` is called with
      ``query_smile`` and ``model`` (16 conformers, es_weight=0);
    - ``alignment.molecule_2`` (presumably the aligned pose; confirm in
      structural.molecule) is written to ``pose_folder/{PDB_id}_{i}_pose.mol``;
    - its chamfer distance is stored in ``scores[i]``;
    - its RMSD to the pose in ``sdf_filepath`` (via get_rmsd) is stored in
      ``rmsds[i]``.

    ``scores`` and ``rmsds`` are modified in place and returned as ``(scores, rmsds)``.

    ``query_smile`` falls back to the notebook-global variable of the same name
    when omitted, so existing calls keep working. Pass it explicitly to avoid
    depending on whatever that global currently holds.
    '''
    if query_smile is None:
        try:
            query_smile = globals()['query_smile']
        except KeyError:
            raise NameError(
                "align_and_save: query_smile not given and no global 'query_smile' is defined"
            ) from None
    for i in tqdm(indices, desc='Self align repeats'):
        alignment = query_pose.align_to_multiconformer_smiles_fast2(query_smile, model, number_of_conformers=16, es_weight=0)
        scores[i] = alignment.chamfer_distance
        pose = alignment.molecule_2
        pose_path = Path(pose_folder) / f'{PDB_id}_{i}_pose.mol'
        pose.write_to_file(pose_path.as_posix())
        rmsds[i] = get_rmsd(sdf_filepath.as_posix(), pose_path.as_posix())
    return scores, rmsds

In [None]:
# Align each query linker to its own reference pose n times. Poses whose RMSD
# to the reference lies above a split point derived from UniDip's modes of the
# RMSD distribution are treated as failed alignments and re-aligned.
n = 32
max_realign_rounds = 100  # safety cap: re-alignment is stochastic and may never clear split_point
pose_folder = output_folder / 'poses_model_validation'
pose_folder.mkdir(parents=True, exist_ok=True)
query_self_alignments = {}
for query_smile in tqdm(query_smiles, total=len(query_smiles), desc='Self align queries'):
    PDB_id, sdf_filepath = query_pdb_ids[query_smile]
    query_pose = MoleculeInfo.from_sdf(sdf_filepath.as_posix())
    rpt_indices = list(range(n))
    scores = [np.nan] * n
    rmsds = [np.nan] * n
    scores, rmsds = align_and_save(rpt_indices, scores, rmsds, pose_folder, query_pose, model, PDB_id, sdf_filepath)
    # np.sort replaces np.msort (deprecated in NumPy 1.24, removed in 2.0); identical for 1-D data
    rmsds_sorted = np.sort(rmsds)
    intervals = UniDip(rmsds_sorted).run()
    # Interval bounds are clamped to the array in case UniDip reports an
    # exclusive end equal to len(rmsds_sorted).
    # NOTE(review): confirm UniDip's interval convention.
    last = len(rmsds_sorted) - 1
    if len(intervals) >= 2:
        # midpoint of the gap between the first two modes
        split_point = (rmsds_sorted[min(intervals[0][1], last)] + rmsds_sorted[min(intervals[1][0], last)]) / 2
    elif len(intervals) == 1:
        # single mode: midpoint of its range.
        # NOTE(review): this deems the upper part of a unimodal distribution "failed".
        split_point = (rmsds_sorted[min(intervals[0][0], last)] + rmsds_sorted[min(intervals[0][1], last)]) / 2
    else:
        print(f'{PDB_id}: UniDip found no modes, keeping all poses')
        split_point = np.inf
    indices_fail = [i for i, rmsd in enumerate(rmsds) if rmsd > split_point]
    rounds = 0
    while indices_fail and rounds < max_realign_rounds:
        scores, rmsds = align_and_save(indices_fail, scores, rmsds, pose_folder, query_pose, model, PDB_id, sdf_filepath)
        indices_fail = [i for i, rmsd in enumerate(rmsds) if rmsd > split_point]
        rounds += 1
    if indices_fail:
        print(f'{PDB_id}: {len(indices_fail)} poses still above RMSD split point {split_point:.3f} '
              f'after {max_realign_rounds} re-alignment rounds')
    query_self_alignments[PDB_id] = (scores, rmsds)

In [None]:
# Chamfer distance of each reference pose to itself, as a per-system baseline.
query_self_pose_align = {}
for query_smile in tqdm(query_smiles, total=len(query_smiles), desc='Self align query poses'):
    PDB_id, sdf_filepath = query_pdb_ids[query_smile]
    reference_pose = MoleculeInfo.from_sdf(sdf_filepath.as_posix())
    query_self_pose_align[PDB_id] = reference_pose.get_chamfer_distance(reference_pose)

In [None]:
# Flatten the per-system self-alignment results into one long table with one
# row per (PDB, repeat): chamfer distance and RMSD of that repeat.
frames = []
for query_id, (scores, rmsds) in query_self_alignments.items():
    frames.append(pd.DataFrame({
        'PDB': [query_id] * len(scores),
        'pose_id': list(range(len(scores))),
        'chamfer_distance': scores,
        'RMSD': rmsds,
    }))
# Concatenate once. Growing a frame inside the loop is quadratic, and seeding
# it with an empty frame triggers pandas' empty-entry dtype deprecation.
if frames:
    df_scored = pd.concat(frames, ignore_index=True)
else:
    df_scored = pd.DataFrame(columns=['PDB', 'pose_id', 'chamfer_distance', 'RMSD'])

In [None]:
# Repeat each system's reference self-distance on every one of its rows.
own_distances = [query_self_pose_align[pdb_code] for pdb_code in df_scored['PDB']]
df_scored['own_distance'] = own_distances

In [None]:
df_scored.head(5)  # preview the first rows

In [None]:
df_scored.to_csv(output_folder / 'align_model_val_self_alignments.csv', index=False)  # table without the index column

In [None]:
# Per-system summary (mean, min, max) of the self-alignment chamfer distances.
df_scored.groupby('PDB')['chamfer_distance'].agg(['mean', 'min', 'max'])

In [None]:
# Parallel lists (same system order) for plotting.
all_scores = list(query_self_alignments.values())
all_chamfer_distances = [pair[0] for pair in all_scores]
all_rmsds = [pair[1] for pair in all_scores]
all_PDB_ids = list(query_self_alignments)

In [None]:
# One colour per system, sampled at evenly spaced positions of the jet colormap.
n_systems = len(all_PDB_ids)
color_dict = {pdb_code: cm.jet(k / n_systems) for k, pdb_code in enumerate(all_PDB_ids)}

In [None]:
# Violin plot of the self-alignment chamfer distances per system, with each
# reference pose's distance to itself overlaid as a dark red dot.
fig, ax = plt.subplots(figsize=(12, 6))
violin_parts = None
for i, (PDB_id, chamfer_sub) in enumerate(zip(all_PDB_ids, all_chamfer_distances)):
    violin_parts = ax.violinplot(chamfer_sub, showmeans=True, showmedians=True,
        widths=0.7, positions=[i], showextrema=False)
    # median: dotted black line; mean: solid black line of the same width
    violin_parts['cmedians'].set_color('black')
    violin_parts['cmedians'].set_linewidth(2)
    violin_parts['cmedians'].set_linestyle((0, (1, 1)))
    violin_parts['cmeans'].set_color('black')
    violin_parts['cmeans'].set_linewidth(2)
    for pc in violin_parts['bodies']:
        pc.set_facecolor(color_dict[PDB_id])
        pc.set_edgecolor('black')
        pc.set_alpha(0.8)
# label only the last violin's lines so the legend gets a single entry for each
if violin_parts is not None:
    violin_parts['cmedians'].set_label('median of random conformers')
    violin_parts['cmeans'].set_label('mean of random conformers')
ax.set_xticks(range(len(all_PDB_ids)))  # one tick per plotted system
ax.set_xticklabels(all_PDB_ids, rotation=45)
ax.set_xlabel('corresponding PDB', fontsize=14)
ax.set_ylabel('chamfer distance', fontsize=14)
# add own distance
own_distances = [query_self_pose_align[PDB_id] for PDB_id in all_PDB_ids]
ax.scatter(range(len(all_chamfer_distances)), own_distances, color='darkred', label='conformer equal to query', s=5)
ax.legend()
plt.savefig(Path.joinpath(output_folder, 'shape_ailgn_val_self_align_violin.pdf'), bbox_inches='tight');

In [None]:
# show correlation between chamfer distance and RMSD (one colour per system)
fig, ax = plt.subplots(figsize=(10, 6))
for PDB_id in all_PDB_ids:
    df_scored_sub = df_scored[df_scored['PDB'] == PDB_id]
    ax.scatter(df_scored_sub['chamfer_distance'], df_scored_sub['RMSD'], label=PDB_id, color=color_dict[PDB_id], s=20)
# axis labels and legend only need to be set once, after all points are drawn
ax.set_xlabel('chamfer distance', fontsize=14)
ax.set_ylabel('RMSD', fontsize=14)
ax.legend(loc='lower right')
plt.savefig(Path.joinpath(output_folder, 'shape_ailgn_val_self_align2rmsd_scatter.pdf'), bbox_inches='tight');

In [None]:
# get number of rotatable bonds per PDB
def get_nROT(df, PDB):
    '''
    Return the number of rotatable bonds in the extended linker of system ``PDB``.

    The linker SMILES is taken from column 'linker_ext_smiles' of the first
    row of ``df`` whose 'PDB' column equals ``PDB``.

    Raises
    ------
    KeyError
        If no row of ``df`` has this PDB code.
    ValueError
        If RDKit cannot parse the linker SMILES.
    '''
    # Import the descriptor module explicitly instead of relying on another
    # import having loaded `Chem.rdMolDescriptors` as a side effect.
    from rdkit.Chem import rdMolDescriptors

    linker_smiles = df.loc[df['PDB'] == PDB, 'linker_ext_smiles']
    if linker_smiles.empty:
        raise KeyError(f'no row with PDB {PDB!r}')
    mol = Chem.MolFromSmiles(linker_smiles.values[0])
    if mol is None:
        raise ValueError(f'could not parse linker SMILES for {PDB}: {linker_smiles.values[0]!r}')
    return rdMolDescriptors.CalcNumRotatableBonds(mol)
# nROT is a per-system property: compute it once per PDB and broadcast it to that PDB's rows.
nrot_per_pdb = {pdb_code: get_nROT(df_pdb, pdb_code) for pdb_code in df_scored['PDB'].unique()}
df_scored['nROT'] = df_scored['PDB'].map(nrot_per_pdb)

In [None]:
# plot nROT vs. mean chamfer distance (one point per system)
fig, ax = plt.subplots(figsize=(6, 5))
for PDB_id in all_PDB_ids:
    df_scored_sub = df_scored[df_scored['PDB'] == PDB_id]
    mean_chamfer = df_scored_sub['chamfer_distance'].mean()
    nROT = df_scored_sub['nROT'].values[0]
    ax.scatter(nROT, mean_chamfer, label=PDB_id, color=color_dict[PDB_id], s=30)
# axis labels and legend only need to be set once, after all points are drawn
ax.set_xlabel('nROT', fontsize=14)
ax.set_ylabel('average chamfer distance', fontsize=14)
ax.legend()
plt.savefig(Path.joinpath(output_folder, 'shape_ailgn_val_self_align_nROT_vs_mean_chamfer.pdf'), bbox_inches='tight');