In [5]:
import numpy as np
import os
from tqdm.auto import tqdm

import torch

from data import get_data
from torch_geometric.loader import DataLoader
from utils.metrics import *
from utils.utils import *
from datetime import datetime
from utils.logging_utils import Logger
import sys
import argparse
from torch.utils.data import RandomSampler
import random
from torch_scatter import scatter_mean
from utils.metrics_to_tsb import metrics_runtime_no_prefix
from torch.utils.tensorboard import SummaryWriter
# from torch.nn.utils import clip_grad_norm_
from accelerate import Accelerator
from accelerate import DistributedDataParallelKwargs
from accelerate.utils import set_seed
import plotly.graph_objects as go


In [6]:
import argparse

def parse_arguments(args_list):
    parser = argparse.ArgumentParser(description='FABind model training.')
    parser.add_argument("-m", "--mode", type=int, default=0,
                    help="mode specify the model to use.")
    parser.add_argument("-d", "--data", type=str, default="0",
                        help="data specify the data to use. \
                        0 for re-docking, 1 for self-docking.")
    parser.add_argument('--seed', type=int, default=42,
                        help="seed to use.")
    parser.add_argument("--gs-tau", type=float, default=1,
                        help="Tau for the temperature-based softmax.")
    parser.add_argument("--gs-hard", action='store_true', default=False,
                        help="Hard mode for gumbel softmax.")
    parser.add_argument("--batch_size", type=int, default=8,
                        help="batch size.")

    parser.add_argument("--restart", type=str, default=None,
                        help="continue the training from the model we saved from scratch.")
    parser.add_argument("--reload", type=str, default=None,
                        help="continue the training from the model we saved.")
    parser.add_argument("--addNoise", type=str, default=None,
                        help="shift the location of the pocket center in each training sample \
                        such that the protein pocket encloses a slightly different space.")

    pair_interaction_mask = parser.add_mutually_exclusive_group()
    # use_equivalent_native_y_mask is probably a better choice.
    pair_interaction_mask.add_argument("--use_y_mask", action='store_true', default=False,
                        help="mask the pair interaction during pair interaction loss evaluation based on data.real_y_mask. \
                        real_y_mask=True if it's the native pocket that ligand binds to.")
    pair_interaction_mask.add_argument("--use_equivalent_native_y_mask", action='store_true', default=False,
                        help="mask the pair interaction during pair interaction loss evaluation based on data.equivalent_native_y_mask. \
                        real_y_mask=True if most of the native interaction between ligand and protein happen inside this pocket.")

    parser.add_argument("--use_affinity_mask", type=int, default=0,
                        help="mask affinity in loss evaluation based on data.real_affinity_mask")
    parser.add_argument("--affinity_loss_mode", type=int, default=1,
                        help="define which affinity loss function to use.")

    parser.add_argument("--pred_dis", type=int, default=1,
                        help="pred distance map or predict contact map.")
    parser.add_argument("--posweight", type=int, default=8,
                        help="pos weight in pair contact loss, not useful if args.pred_dis=1")

    parser.add_argument("--relative_k", type=float, default=0.01,
                        help="adjust the strength of the affinity loss head relative to the pair interaction loss.")
    parser.add_argument("-r", "--relative_k_mode", type=int, default=0,
                        help="define how the relative_k changes over epochs")

    parser.add_argument("--resultFolder", type=str, default="./result",
                        help="information you want to keep a record.")
    parser.add_argument("--label", type=str, default="",
                        help="information you want to keep a record.")

    parser.add_argument("--use-whole-protein", action='store_true', default=False,
                        help="currently not used.")

    parser.add_argument("--data-path", type=str, default="/PDBbind_data/pdbbind2020",
                        help="Data path.")
                        
    parser.add_argument("--exp-name", type=str, default="",
                        help="data path.")

    parser.add_argument("--tqdm-interval", type=float, default=0.1,
                        help="tqdm bar update interval")

    parser.add_argument("--lr", type=float, default=0.0001)

    parser.add_argument("--pocket-coord-huber-delta", type=float, default=3.0)

    parser.add_argument("--coord-loss-function", type=str, default='SmoothL1', choices=['MSE', 'SmoothL1'])

    parser.add_argument("--coord-loss-weight", type=float, default=1.0)
    parser.add_argument("--pair-distance-loss-weight", type=float, default=1.0)
    parser.add_argument("--pair-distance-distill-loss-weight", type=float, default=1.0)
    parser.add_argument("--pocket-cls-loss-weight", type=float, default=1.0)
    parser.add_argument("--pocket-distance-loss-weight", type=float, default=0.05)
    parser.add_argument("--pocket-cls-loss-func", type=str, default='bce')

    # parser.add_argument("--warm-mae-thr", type=float, default=5.0)

    parser.add_argument("--use-compound-com-cls", action='store_true', default=False,
                        help="only use real pocket to run pocket classification task")

    parser.add_argument("--compound-coords-init-mode", type=str, default="pocket_center_rdkit",
                        choices=['pocket_center_rdkit', 'pocket_center', 'compound_center', 'perturb_3A', 'perturb_4A', 'perturb_5A', 'random'])

    parser.add_argument('--trig-layers', type=int, default=1)

    parser.add_argument('--distmap-pred', type=str, default='mlp',
                        choices=['mlp', 'trig'])
    parser.add_argument('--mean-layers', type=int, default=3)
    parser.add_argument('--n-iter', type=int, default=5)
    parser.add_argument('--inter-cutoff', type=float, default=10.0)
    parser.add_argument('--intra-cutoff', type=float, default=8.0)
    parser.add_argument('--refine', type=str, default='refine_coord',
                        choices=['stack', 'refine_coord'])

    parser.add_argument('--coordinate-scale', type=float, default=5.0)
    parser.add_argument('--geometry-reg-step-size', type=float, default=0.001)
    parser.add_argument('--lr-scheduler', type=str, default="constant", choices=['constant', 'poly_decay', 'cosine_decay', 'cosine_decay_restart', 'exp_decay'])

    parser.add_argument('--add-attn-pair-bias', action='store_true', default=False)
    parser.add_argument('--explicit-pair-embed', action='store_true', default=False)
    parser.add_argument('--opm', action='store_true', default=False)

    parser.add_argument('--add-cross-attn-layer', action='store_true', default=False)
    parser.add_argument('--rm-layernorm', action='store_true', default=False)
    parser.add_argument('--keep-trig-attn', action='store_true', default=False)

    parser.add_argument('--pocket-radius', type=float, default=20.0)

    parser.add_argument('--rm-LAS-constrained-optim', action='store_true', default=False)
    parser.add_argument('--rm-F-norm', action='store_true', default=False)
    parser.add_argument('--norm-type', type=str, default="all_sample", choices=['per_sample', '4_sample', 'all_sample'])

    # parser.add_argument("--only-predicted-pocket-mae-thr", type=float, default=3.0)
    parser.add_argument('--noise-for-predicted-pocket', type=float, default=5.0)
    parser.add_argument('--test-random-rotation', action='store_true', default=False)

    parser.add_argument('--random-n-iter', action='store_true', default=False)
    parser.add_argument('--clip-grad', action='store_true', default=False)

    # one batch actually contains 20000 samples, not the size of training set
    parser.add_argument("--sample-n", type=int, default=0, help="number of samples in one epoch.")

    parser.add_argument('--fix-pocket', action='store_true', default=False)
    parser.add_argument('--pocket-idx-no-noise', action='store_true', default=False)
    parser.add_argument('--ablation-no-attention', action='store_true', default=False)
    parser.add_argument('--ablation-no-attention-with-cross-attn', action='store_true', default=False)

    parser.add_argument('--redocking', action='store_true', default=False)
    parser.add_argument('--redocking-no-rotate', action='store_true', default=False)

    parser.add_argument("--pocket-pred-layers", type=int, default=1, help="number of layers for pocket pred model.")
    parser.add_argument('--pocket-pred-n-iter', type=int, default=1, help="number of iterations for pocket pred model.")

    parser.add_argument('--use-esm2-feat', action='store_true', default=False)
    parser.add_argument("--center-dist-threshold", type=float, default=8.0)

    parser.add_argument("--mixed-precision", type=str, default='no', choices=['no', 'fp16'])
    parser.add_argument('--disable-tqdm', action='store_true', default=False)
    parser.add_argument('--log-interval', type=int, default=100)
    parser.add_argument('--optim', type=str, default='adam', choices=['adam', 'adamw'])
    parser.add_argument("--warmup-epochs", type=int, default=15,
                        help="used in combination with relative_k_mode.")
    parser.add_argument("--total-epochs", type=int, default=400,
                        help="option to switch training data after certain epochs.")
    parser.add_argument('--disable-validate', action='store_true', default=False)
    parser.add_argument('--disable-tensorboard', action='store_true', default=False)
    parser.add_argument("--hidden-size", type=int, default=256)
    parser.add_argument("--weight-decay", type=float, default=0.0)
    parser.add_argument("--stage-prob", type=float, default=0.5)
    parser.add_argument("--pocket-pred-hidden-size", type=int, default=128)

    parser.add_argument("--local-eval", action='store_true', default=False)
    parser.add_argument("--train-ligand-torsion-noise", action='store_true', default=False)
    parser.add_argument("--train-pred-pocket-noise", type=float, default=0.0)
    parser.add_argument('--esm2-concat-raw', action='store_true', default=False)

    parser.add_argument("--dis-map-thres", type=float, default=10.0)
    parser.add_argument("--onlydocking-from-scratch", action='store_true', default=False)

    
    return parser.parse_args(args_list)


args = parse_arguments(['--batch_size', '5', '-d', '0', '-m', '5', '--data-path', '~/workspace/data/fabind', '--label', 'baseline', '--addNoise', '5', '--resultFolder', './results', '--use-compound-com-cls', '--total-epochs', '500', '--exp-name', 'fabind-onlydocking-from-scratch-dismap15', '--coord-loss-weight', '1.0', '--pair-distance-loss-weight', '1.0', '--pair-distance-distill-loss-weight', '1.0', '--pocket-cls-loss-weight', '1.0', '--pocket-distance-loss-weight', '0.05', '--lr', '5e-05', '--lr-scheduler', 'poly_decay', '--distmap-pred', 'mlp', '--n-iter', '8', '--mean-layers', '4', '--refine', 'refine_coord', '--coordinate-scale', '5', '--hidden-size', '512', '--geometry-reg-step-size', '0.001', '--rm-layernorm', '--add-attn-pair-bias', '--explicit-pair-embed', '--add-cross-attn-layer', '--noise-for-predicted-pocket', '0', '--clip-grad', '--random-n-iter', '--pocket-idx-no-noise', '--pocket-cls-loss-func', 'bce', '--use-esm2-feat', '--disable-validate', '--dis-map-thres', '15', '--onlydocking-from-scratch'])

In [7]:
args.stage_prob = 0.0
args.center_dist_threshold = 1000.0


ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs], mixed_precision=args.mixed_precision)
set_seed(args.seed)
# Seed_everything(seed=args.seed)
pre = f"{args.resultFolder}/{args.exp_name}"

if accelerator.is_main_process:
    os.system(f"mkdir -p {pre}/models")

    if not args.disable_tensorboard:
        tsb_runtime_dir = f"{pre}/tsb_runtime"
        os.system(f"mkdir -p {tsb_runtime_dir}")
        train_writer = SummaryWriter(log_dir=f'{tsb_runtime_dir}/train')
        valid_writer = SummaryWriter(log_dir=f'{tsb_runtime_dir}/valid')
        test_writer = SummaryWriter(log_dir=f'{tsb_runtime_dir}/test')
        test_writer_use_predicted_pocket = SummaryWriter(log_dir=f'{tsb_runtime_dir}/test_use_predicted_pocket')

accelerator.wait_for_everyone()

timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M")
logger = Logger(accelerator=accelerator, log_path=f'{pre}/{timestamp}.log')

logger.log_message(f"{' '.join(sys.argv)}")

# torch.set_num_threads(1)
# # ----------without this, I could get 'RuntimeError: received 0 items of ancdata'-----------
torch.multiprocessing.set_sharing_strategy('file_system')

# train, valid, test: only native pocket. train_after_warm_up, all_pocket_test include all other pockets(protein center and P2rank result)
if args.redocking:
    args.compound_coords_init_mode = "redocking"
elif args.redocking_no_rotate:
    args.redocking = True
    args.compound_coords_init_mode = "redocking_no_rotate"

train, valid, test= get_data(args, logger, addNoise=args.addNoise, use_whole_protein=args.use_whole_protein, compound_coords_init_mode=args.compound_coords_init_mode, pre=args.data_path)


11/02/2023 07:46:45 - INFO - Main - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: no

11/02/2023 07:46:45 - INFO - Main - Working directory is /home/t-kaiyuangao/workspace/fabind/fabind
11/02/2023 07:46:45 - INFO - Main - /opt/conda/envs/py38/lib/python3.8/site-packages/ipykernel_launcher.py --ip=127.0.0.1 --stdin=9003 --control=9001 --hb=9000 --Session.signature_scheme="hmac-sha256" --Session.key=b"0d84db60-ad4a-45c5-9da7-2dc09eade036" --shell=9002 --transport="tcp" --iopub=9004 --f=/home/t-kaiyuangao/.local/share/jupyter/runtime/kernel-v2-19550wtuonF6nms4h.json
11/02/2023 07:46:45 - INFO - Main - Loading dataset
11/02/2023 07:46:45 - INFO - Main - compound feature based on torchdrug
11/02/2023 07:46:45 - INFO - Main - protein feature based on esm2


['/home/t-kaiyuangao/workspace/data/fabind/dataset/processed/data.pt', '/home/t-kaiyuangao/workspace/data/fabind/dataset/processed/protein_1d_3d.lmdb', '/home/t-kaiyuangao/workspace/data/fabind/dataset/processed/compound_LAS_edge_index.lmdb', '/home/t-kaiyuangao/workspace/data/fabind/dataset/processed/compound_rdkit_coords.pt', '/home/t-kaiyuangao/workspace/data/fabind/dataset/processed/esm2_t33_650M_UR50D.lmdb']


In [9]:
num_pockets = []
num_proteins = []
num_ligands = []
for data in tqdm(train):
    num_pockets.append(len(data.node_xyz))
    num_proteins.append(len(data.node_xyz_whole))
    num_ligands.append(len(data.coords))

num_pockets = np.array(num_pockets)
num_proteins = np.array(num_proteins)
num_ligands = np.array(num_ligands)

fig = go.Figure()
fig.add_trace(go.Histogram(x=num_pockets, name="num_pockets", opacity=0.6))
fig.add_trace(go.Histogram(x=num_proteins, name="num_proteins", opacity=0.6))
fig.add_trace(go.Histogram(x=num_ligands, name="num_ligands", opacity=0.6))

# Update layout for better visualization
fig.update_layout(title="Element Frequency",
                  xaxis_title="Element", 
                  yaxis_title="Frequency",
                  barmode='overlay')

fig.show()

  0%|          | 0/17299 [00:00<?, ?it/s]

In [8]:

num_pockets = []
num_proteins = []
num_ligands = []
for data in tqdm(test):
    num_pockets.append(len(data.node_xyz))
    num_proteins.append(len(data.node_xyz_whole))
    num_ligands.append(len(data.coords))

num_pockets = np.array(num_pockets)
num_proteins = np.array(num_proteins)
num_ligands = np.array(num_ligands)

# Create histogram using plotly
# fig = go.Figure(data=[go.Histogram(x=num_proteins, nbinsx=len(np.unique(num_proteins)))])
# fig = go.Figure(data=[go.Histogram(x=num_pockets, nbinsx=len(np.unique(num_pockets)))])
# fig.update_layout(title="Element Frequency", xaxis_title="Element", yaxis_title="Frequency")
# Add histograms for each array
fig = go.Figure()
fig.add_trace(go.Histogram(x=num_pockets, name="num_pockets", opacity=0.6))
fig.add_trace(go.Histogram(x=num_proteins, name="num_proteins", opacity=0.6))
fig.add_trace(go.Histogram(x=num_ligands, name="num_ligands", opacity=0.6))

# Update layout for better visualization
fig.update_layout(title="Element Frequency",
                  xaxis_title="Element", 
                  yaxis_title="Frequency",
                  barmode='overlay')

fig.show()

  0%|          | 0/363 [00:00<?, ?it/s]

In [14]:
import plotly.graph_objects as go
import numpy as np

pocket_mae = np.load('../pocket_mae.npy')
rmsd = np.load('../rmsd.npy')

cnt = 0
for mae, sd in zip(pocket_mae, rmsd):
    if mae < 5:
        cnt += 1
cnt / len(rmsd)

0.8292011019283747

In [11]:
import plotly.graph_objects as go
import numpy as np

pocket_mae = np.load('../pocket_mae.npy')
rmsd = np.load('../rmsd.npy')

fig = go.Figure(data=go.Scatter(x=pocket_mae, y=rmsd, mode='markers'))
fig.update_layout(shapes=[dict(type='line', y0=2, y1=2, x0=min(pocket_mae), x1=max(pocket_mae), line=dict(color='Red'))])
fig.update_layout(title='Relationship between Accuracy on pocket mae vs rmsd',
                  xaxis_title='pocket mae',
                  yaxis_title='rmsd',
                  xaxis=dict(range=[0, 40]),
                  yaxis=dict(range=[0, 80]))
fig.show()

In [3]:
import numpy as np
import pandas as pd
import plotly.express as px

pocket_mae = np.load('../pocket_mae.npy')
rmsd = np.load('../rmsd.npy')
ligand_num = np.load('../ligand_num.npy')
protein_num = np.load('../protein_num.npy')

# Combine the data into a dataframe for easier plotting
df = pd.DataFrame({'pocket_mae': pocket_mae, 'rmsd': rmsd, 'ligand_num': ligand_num, 'protein_num': protein_num})

# Create scatter matrix using plotly express
fig = px.scatter_matrix(df)

# Update layout for better visualization
fig.update_layout(title="Relationships between metrics pocket_mae, rmsd, ligand_num, and protein_num")

fig.show()

In [13]:
import plotly.graph_objects as go
import numpy as np

ratios = []
rmsds = []
with open('../split_pdb_id/test_index', 'r') as f:
    pdbs = f.readlines()
    for pdb in pdbs:
        pdb = pdb.strip()
        gt_keepnode = np.load(f'../keepnodes/{pdb}_gt.npy').squeeze()
        pp_keepnode = np.load(f'../keepnodes/{pdb}_pp.npy').squeeze()
        rmsd = np.load(f'../keepnodes/{pdb}_rmsd.npy')
        set1 = set(gt_keepnode)
        set2 = set(pp_keepnode)
        common_elem_num = len(set1.intersection(set2))
        ratio = common_elem_num / len(gt_keepnode)
        ratios.append(ratio)
        rmsds.append(rmsd)
        
rmsd = np.array(rmsds)
ratio = np.array(ratios)
fig = go.Figure(data=go.Scatter(x=ratio, y=rmsd, mode='markers'))
fig.update_layout(shapes=[dict(type='line', y0=2, y1=2, x0=min(ratio), x1=max(ratio), line=dict(color='Red'))])
fig.update_layout(title='Relationship between Accuracy on pocket mae vs rmsd',
                  xaxis_title='ratio',
                  yaxis_title='rmsd',
                  xaxis=dict(range=[0, 1]),
                  yaxis=dict(range=[0, max(rmsds)]))
fig.show()

In [11]:
import plotly.graph_objects as go
import numpy as np

ratios = []
rmsds = []
with open('../split_pdb_id/test_index', 'r') as f:
    pdbs = f.readlines()
    for pdb in pdbs:
        pdb = pdb.strip()
        ratio = np.load(f'../keepnodes/{pdb}_cover_ratio.npy')
        rmsd = np.load(f'../keepnodes/{pdb}_rmsd.npy')
        ratios.append(ratio)
        rmsds.append(rmsd)
        
rmsd = np.array(rmsds)
ratio = np.array(ratios)
fig = go.Figure(data=go.Scatter(x=ratio, y=rmsd, mode='markers'))
fig.update_layout(shapes=[dict(type='line', y0=2, y1=2, x0=min(ratio), x1=max(ratio), line=dict(color='Red'))])
fig.update_layout(title='Relationship between Accuracy on pocket mae vs rmsd',
                  xaxis_title='ratio',
                  yaxis_title='rmsd',
                  xaxis=dict(range=[0, 1]),
                  yaxis=dict(range=[0, max(rmsds)]))
fig.show()

In [16]:
import plotly.graph_objects as go
import numpy as np

ratios = []
rmsds = []
with open('../split_pdb_id/test_index', 'r') as f:
    pdbs = f.readlines()
    for pdb in pdbs:
        pdb = pdb.strip()
        ratio = np.load(f'../keepnodes/{pdb}_max_ligand_radius.npy')
        rmsd = np.load(f'../keepnodes/{pdb}_rmsd.npy')
        ratios.append(ratio)
        rmsds.append(rmsd)
        
rmsd = np.array(rmsds)
ratio = np.array(ratios)
fig = go.Figure(data=go.Scatter(x=ratio, y=rmsd, mode='markers'))
fig.update_layout(shapes=[dict(type='line', y0=2, y1=2, x0=min(ratio), x1=max(ratio), line=dict(color='Red'))])
fig.update_layout(title='Relationship between Accuracy on pocket mae vs rmsd',
                  xaxis_title='max ligand radius',
                  yaxis_title='rmsd',
                  xaxis=dict(range=[0, max(ratio)]),
                  yaxis=dict(range=[0, max(rmsds)]))
fig.show()

In [30]:
import plotly.graph_objects as go
import numpy as np

ratios = []
rmsds = []
with open('../split_pdb_id/test_index', 'r') as f:
    pdbs = f.readlines()
    for pdb in pdbs:
        pdb = pdb.strip()
        ratio = np.load(f'../keepnodes/{pdb}_rdkit_max_ligand_radius.npy')
        rmsd = np.load(f'../keepnodes/{pdb}_max_ligand_radius.npy')
        ratios.append(ratio)
        rmsds.append(rmsd)
        
rmsd = np.array(rmsds)
ratio = np.array(ratios)
fig = go.Figure(data=go.Scatter(x=ratio, y=rmsd, mode='markers'))
fig.update_layout(shapes=[dict(type='line', y0=2, y1=2, x0=min(ratio), x1=max(ratio), line=dict(color='Red'))], width=800, height=800,)
fig.update_layout(title='Relationship between Accuracy on pocket mae vs rmsd',
                  xaxis_title='rdkit_max_ligand_radius',
                  yaxis_title='max_ligand_radius',
                  xaxis=dict(range=[0, 40]),
                  yaxis=dict(range=[0, 40]))
fig.show()