# GAABind Tutorial
This notebook is an example of how to use GAABind to predict the ligand binding conformation and binding affinity. 

## 1. Prepare the input dataset
First, prepare the input dataset as a directory with the following layout:
```bash
1e66/
├── ligand.mol2
├── ligand.sdf
├── pocket.txt
└── receptor.pdb
```
The ligand file can be in .sdf, .mol2, or .mol format, or you can provide the ligand's SMILES representation in a .txt file. The pocket.txt file lists the residues of the binding pocket, one per line. Each residue is written as the chain ID, residue number, and three-letter residue name, joined by underscores, as follows:
```bash
A_81_TRP
A_77_GLY
A_439_TYR
A_331_TYR
A_120_GLY
...
```

In [41]:
# Define the input dataset path: the directory that holds the ligand, receptor and pocket files.
input_path = './example_data/1e66'   # replace this with the path to your own dataset

In [42]:
# Find the corresponding paths for the input ligand, protein and pocket information.
from glob import glob
import os

# normpath drops a trailing '/', which would otherwise make basename() return ''.
name = os.path.basename(os.path.normpath(input_path))

# Any file named ligand.* is accepted; when several exist, glob's first match is used.
ligand_candidates = glob(f'{input_path}/ligand.*')
if not ligand_candidates:
    raise FileNotFoundError(f'no ligand.* file found in {input_path}')
mol_file = ligand_candidates[0]
pro_file = f'{input_path}/receptor.pdb'
pocket_file = f'{input_path}/pocket.txt'

# One residue identifier per line; the file is closed promptly and blank lines are skipped
# so that an empty string is never passed on as a residue name.
with open(pocket_file) as pocket_handle:
    poc_res = [line.strip() for line in pocket_handle if line.strip()]
print('using dataset path: ', mol_file, pro_file, pocket_file)
print('using the following reisudes as target pocket: ', ','.join(poc_res))

output_dir = './example_output' #replace the save path by your own


using dataset path:  ./example_data/1e66/ligand.sdf ./example_data/1e66/receptor.pdb ./example_data/1e66/pocket.txt
using the following reisudes as target pocket:  A_81_TRP,A_77_GLY,A_439_TYR,A_331_TYR,A_120_GLY,A_437_HIS,A_330_LEU,A_429_TRP,A_198_ALA,A_75_PHE,A_287_PHE,A_115_GLY,A_438_GLY,A_197_SER,A_118_TYR,A_328_PHE,A_72_PHE,A_285_PHE,A_441_ILE,A_124_LEU,A_114_GLY,A_117_PHE,A_327_PHE,A_69_ASP,A_82_ASN,A_119_SER,A_433_MET,A_436_ILE,A_78_SER,A_196_GLU,A_116_GLY,A_127_TYR


## 2. Preprocess the input dataset

In [43]:
# Load required libraries for data preprocess
import os
from tqdm import tqdm
import pickle
import argparse
from glob import glob
from multiprocessing import Pool

from data.feature_utils import get_ligand_info, get_protein_info, get_chem_feats, read_mol, get_coords

import warnings
warnings.filterwarnings('ignore')

In [44]:
## Generate the dataset features for model input and pickle them to output_dir

os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, f'{name}.pkl')

input_mol = read_mol(mol_file)
try:
    mol, smiles, coordinate_list = get_coords(input_mol)
except Exception as exc:
    # Everything below needs mol/smiles/coordinate_list, so stop with a clear error
    # instead of continuing into a NameError (or silently reusing values left over
    # from an earlier run of this cell).
    raise RuntimeError(f'generate input ligand coords failed for {name}') from exc

lig_atoms, lig_atom_feats, lig_edges, lig_bonds = get_ligand_info(mol)
poc_pos, poc_atoms, poc_atom_feats, poc_edges, poc_bonds = get_protein_info(pro_file, poc_res)

# Collect ligand and pocket features under the keys passed on to get_chem_feats.
new_data = {
    'atoms': lig_atoms,
    'coordinates': coordinate_list,
    'pocket_atoms': poc_atoms,
    'pocket_coordinates': poc_pos,
    'smi': smiles,
    'pocket': name,
    'lig_feats': lig_atom_feats,
    'lig_bonds': lig_edges,
    'lig_bonds_feats': lig_bonds,
    'poc_feats': poc_atom_feats,
    'poc_bonds': poc_edges,
    'poc_bonds_feats': poc_bonds,
    'mol': mol,
}

new_data = get_chem_feats(new_data)

# The context manager guarantees the pickle is flushed and closed even if dump() raises.
with open(output_path, 'wb') as f_out:
    pickle.dump(new_data, f_out)

## 3. Binding conformation prediction using GAABind

In [45]:
# load required libraries for prediction
import torch
import pandas as pd
from utils import set_global_seed
from data.graph_dataset import DockingTestDataset
from data.collator import collator_test_3d
from option import set_args
from models.DockingPoseModel import DockingPoseModel
from docking.docking_utils import (
    docking_data_pre,
    ensemble_iterations,
)
from torch.utils.data import DataLoader

### load model and model weights

In [46]:
# Load the model: build it from the default arguments and restore the saved weights.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
parser = set_args()
args = parser.parse_args(args=[])   # empty list: use defaults, ignore the notebook's own argv
set_global_seed(args.seed)
ckpt_path = './saved_model/best_epoch.pt'
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
state_dict = torch.load(ckpt_path, map_location='cpu')

# The original code dropped the first 7 characters of every key, presumably the
# 'module.' prefix added by a DataParallel/DistributedDataParallel wrapper (TODO confirm).
# Strip it only when it is actually present so an unwrapped checkpoint loads too.
wrapper_prefix = 'module.'
new_state_dict = {
    (key[len(wrapper_prefix):] if key.startswith(wrapper_prefix) else key): value
    for key, value in state_dict.items()
}

model = DockingPoseModel(args).to(device)
model.load_state_dict(new_state_dict)

<All keys matched successfully>

### load dataset and run inference

In [47]:
inference_save_path = os.path.join(output_dir, 'example_inference.pkl')   #define the save path of inference
# NOTE(review): the dataset is read from output_dir, which is also where this cell and the
# next one write their results -- confirm DockingTestDataset ignores those files on re-run.
test_dataset = DockingTestDataset(output_dir, args.conf_size)
test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False, collate_fn=collator_test_3d)

# CUDA autocast is only meaningful on a CUDA device; disable it explicitly on CPU.
use_amp = device.type == 'cuda'

outputs = []
model.eval()
with torch.no_grad():
    for batch in tqdm(test_dataloader):
        # batch[0] (ligand) and batch[1] (pocket) are dicts of tensors: move them to the device.
        for dicts in batch[:2]:
            for key in dicts.keys():
                dicts[key] = dicts[key].to(device)

        with torch.cuda.amp.autocast(enabled=use_amp):
            pred = model(batch)

        mol_token_atoms = batch[0]['x'][:,:,0]
        poc_token_atoms = batch[1]['x'][:,:,0]
        poc_coords = batch[1]['pos']

        # Keep everything on the CPU so the pickle can be loaded without a GPU.
        logging_output = {}

        logging_output["smi_name"] = batch[2]['smi_list']
        logging_output["pocket_name"] = batch[2]['pocket_list']
        logging_output['mol'] = batch[2]['mol']
        logging_output["cross_distance_predict"] = pred[0].detach().cpu().permute(0, 2, 1)
        logging_output["holo_distance_predict"] = pred[1].detach().cpu()
        logging_output["atoms"] = mol_token_atoms.detach().cpu()
        logging_output["pocket_atoms"] = poc_token_atoms.detach().cpu()
        logging_output["holo_center_coordinates"] = batch[2]['holo_center_list']
        logging_output["pocket_coordinates"] = poc_coords.detach().cpu()
        logging_output['pred_affinity'] = pred[-1].detach().cpu()
        outputs.append(logging_output)

# Close the file deterministically so the next cell always reads a complete pickle.
with open(inference_save_path, "wb") as inference_file:
    pickle.dump(outputs, inference_file)

100%|██████████| 10/10 [00:01<00:00,  7.31it/s]


### Get ligand binding pose by using the inference result

In [48]:
import subprocess
import sys

# Unpack the saved inference results and build the per-ligand docking inputs.
mol_list, smi_list, pocket_list, pocket_coords_list, distance_predict_list, holo_distance_predict_list,\
        holo_center_coords_list, pred_affi_list = docking_data_pre(inference_save_path)
iterations = ensemble_iterations(mol_list, smi_list, pocket_list, pocket_coords_list, distance_predict_list,\
                                     holo_distance_predict_list, holo_center_coords_list, pred_affi_list)

cache_dir = os.path.join(output_dir, "cache")
os.makedirs(cache_dir, exist_ok=True)
cache_file = os.path.join(cache_dir, f'{name}.pkl')

# Only the first item of the iterator is used: this tutorial processes a single complex.
pd.to_pickle(next(iterations), cache_file)

output_ligand_path = os.path.join(output_dir, name)
# An argument list (no shell) keeps paths containing spaces intact, sys.executable runs the
# script with the same interpreter as this kernel, and check=True raises if the script fails
# instead of reporting success regardless.
cmd = [sys.executable, "docking/coordinate_model.py", "--input", cache_file, "--output-path", output_ligand_path]
subprocess.run(cmd, check=True)
print(f'Prediction fininshed !!!!! You can find the result in the {output_ligand_path} directory')

Prediction fininshed !!!!! You can find the result in the ./example_output/1e66 directory


## 4. Visualize the prediction result

In [52]:
import nglview

# Ligand file expected inside the prediction output directory.
predicted_file = os.path.join(output_ligand_path, 'ligand.sdf')

# Receptor: white cartoon, with nglview's default representations switched off.
receptor_structure = nglview.FileStructure(pro_file)
view = nglview.show_file(receptor_structure, default=False)
view.add_representation('cartoon', selection='protein', color='white')

# Predicted ligand in red.
predicted_structure = nglview.FileStructure(predicted_file)
pred_lig = view.add_component(predicted_structure, default=False)
pred_lig.add_ball_and_stick(color='red')

# Input ligand file in yellow (hydrogens hidden) for comparison.
native_structure = nglview.FileStructure(mol_file)
native = view.add_component(native_structure, default=False)
native.add_ball_and_stick(color='yellow', selection='not hydrogen')

# Last expression of the cell: renders the widget.
view

NGLWidget()