# Train a shape alignment model
This notebook is an example workflow for training a shape alignment model and then using it to align molecules.
Run it in the `shape_align` environment.

In [None]:
import os
from pathlib import Path

import torch
from pytorch3d.loss import chamfer_distance
from tqdm.notebook import tqdm
import numpy as np
from pytorch_lightning import Trainer
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import RDLogger

In [None]:
# Temporarily switch to the parent directory before importing the local
# `shape_alignment` package (assumes the package lives there — TODO confirm).
# `old_cwd` is kept so that a later cell can switch back with `os.chdir(old_cwd)`.
old_cwd = Path.cwd()
os.chdir(Path.cwd().parent)
from shape_alignment import models, molecule
from shape_alignment.models import PCRSingleMasked, PCRSepFeat
from shape_alignment.molecule import Molecules, MoleculeInfo
# Imported under the alias `cmf` so it does not shadow pytorch3d's `chamfer_distance`.
from shape_alignment.loss import chamfer_distance as cmf

In [None]:
# Set the RDKit logger level to CRITICAL to keep lower-severity RDKit messages out of the output.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads modules before each cell is executed,
# so edits to the project sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Return to the notebook's original working directory (saved in `old_cwd`).
os.chdir(old_cwd)

# Input data is read from <parent of cwd>/data.
data_folder = Path.cwd().parent / 'data'

# Results are written to <parent of cwd>/results/toy_shape_align.
output_folder = Path.cwd().parent / 'results' / 'toy_shape_align'
# Create the output folder (and any missing parents). `exist_ok=True` makes
# re-running the cell a no-op and avoids a separate check-then-create step.
output_folder.mkdir(parents=True, exist_ok=True)

## Process data

In [None]:
# Read the extended-linker CSV from the data folder and preview the first rows.
protacdb_csv = data_folder / 'protacdb_extended_linkers.csv'
df_protacdb = pd.read_csv(protacdb_csv)
df_protacdb.head()

In [None]:
# Values of the `linker_ext_smiles` column (presumably one SMILES string per row — confirm against the CSV).
smiles = df_protacdb['linker_ext_smiles'].values
# Index into `smiles` of the molecule used as the alignment query throughout the notebook.
query_id = 0

### Create training data

In [None]:
# Example of how to sample and batch pairs of molecules for training.
# Only a single query molecule is paired against the others here, so the
# resulting model will be good at aligning that query molecule only. Include
# more query molecules to train a more general model.

skipped_rounds = 0  # sampling rounds dropped because batch creation failed

training_batches = []

for _ in tqdm(range(10)): # make data to learn self alignment
    # Pair the query with itself. Use `query_id` rather than a hard-coded 0 so
    # this remains a self alignment when a different query is selected.
    rest = [query_id]*5
    try:
        training_batches += MoleculeInfo.from_smiles(smiles[query_id]).get_training_batches([smiles[i] for i in rest], batch_num=2, batch_size=16)
    except ValueError:
        skipped_rounds += 1
        continue

for _ in tqdm(range(10)): # make data for query vs others alignments
    rest = np.random.choice(range(len(smiles)), 5)
    try:
        training_batches += MoleculeInfo.from_smiles(smiles[query_id]).get_training_batches([smiles[i] for i in rest], batch_num=2, batch_size=16)
    except Exception:
        # Best effort: skip partners that cannot be processed, but do not
        # swallow KeyboardInterrupt/SystemExit as a bare `except:` would.
        skipped_rounds += 1
        continue

validation_batches = []

for _ in tqdm(range(10)): # make some validation batches (self vs others)
    rest = np.random.choice(range(len(smiles)), 1)
    try:
        validation_batches += MoleculeInfo.from_smiles(smiles[query_id]).get_training_batches([smiles[i] for i in rest], batch_num=1, batch_size=16)
    except Exception:
        # Same best-effort policy as for the training batches above.
        skipped_rounds += 1
        continue

# Make dropped rounds visible instead of failing silently.
if skipped_rounds:
    print(f"Skipped {skipped_rounds} sampling round(s) because batch creation failed.")

In [None]:
# Persist both batch lists together as a single (training, validation) tuple.
batch_filepath = output_folder / 'shape_align_batches.pth'
torch.save((training_batches, validation_batches), batch_filepath)

In [None]:
# Wrap the pre-built batches in the data loader class exposed by `shape_alignment.models`.
td = models.DataLoader(training_batches)
vd = models.DataLoader(validation_batches)

# Train on a GPU when CUDA is available; otherwise fall back to the CPU so the
# notebook does not strictly require a GPU.
accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
trainer = Trainer(accelerator=accelerator, max_epochs=50)

## Train model

In [None]:
# Build the model; the validation batches are handed to the constructor as `validation_data`.
model = PCRSingleMasked(
    3,
    coarse_attention_dim=16,
    coarse_nheads=8,
    validation_data=validation_batches,
)
# Shows the RANSAC alignment scores for the validation data.
print("Average RANSAC distance:", model.validation_ransac_distance)
# During validation, "improvement over ransac" should be above 1 as an
# indicator that the model is performing well.
trainer.fit(model, td, vd)

In [None]:
# Save the whole trained model object to the output folder.
model_filepath = output_folder / 'model_align_toy.pth'
torch.save(model, model_filepath)

## Use model to align molecules

In [None]:
# Reload the saved model object and switch it to evaluation mode.
model_filepath = output_folder / 'model_align_toy.pth'
# Use CUDA when available, otherwise the CPU. `map_location` remaps the stored
# tensors so a model saved from a GPU can also be loaded on a CPU-only machine.
# NOTE(review): downstream alignment helpers may still assume CUDA — confirm.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): `torch.load` unpickles a full model object here, so only load
# files you trust. Recent PyTorch releases default to `weights_only=True`, in
# which case loading a whole pickled module needs `weights_only=False`.
model = torch.load(model_filepath, map_location=device)
model.to(device)
model.eval()

In [None]:
# Get the pose of the query molecule: look up its SMILES string by `query_id`
# and build a `MoleculeInfo` from it.
query_smile = smiles[query_id]
query_pose = MoleculeInfo.from_smiles(query_smile)

In [None]:
# Pick a random entry of `smiles`, align it to the query molecule and save the pose.
random_int = np.random.randint(0, len(smiles))
alignment = query_pose.align_to_multiconformer_smiles_fast2(
    smiles[random_int],
    model,
    number_of_conformers=50,
)
# Chamfer distance reported for this alignment.
cmf_dist = alignment.chamfer_distance
# `molecule_2` is presumably the aligned pose of the sampled molecule — confirm in `shape_alignment.molecule`.
pose = alignment.molecule_2
pose_path = output_folder / 'toy_pose.mol'
pose.write_to_file(pose_path.as_posix())

In [None]:
# A known conformer read from an SDF file can be used as the query instead.
# Replace the placeholder path below with a real file.
sdf_filepath = 'path/to/query.sdf'
query_pose = MoleculeInfo.from_sdf(sdf_filepath)
# Reuses `random_int` from the previous cell, so that cell must be run first.
alignment = query_pose.align_to_multiconformer_smiles_fast2(
    smiles[random_int],
    model,
    number_of_conformers=50,
)
# Chamfer distance reported for this alignment.
cmf_dist = alignment.chamfer_distance
pose = alignment.molecule_2
# NOTE: this is the same output path as in the previous cell, so that file is overwritten.
pose_path = output_folder / 'toy_pose.mol'
pose.write_to_file(pose_path.as_posix())