# PromptBindInference Use Case Notebook

## 1. Setup Environment

In [1]:
import numpy as np
import os
import torch
import sys
import argparse
import random
from datetime import datetime

from safetensors.torch import load_model
import yaml

In [2]:
import sys
# Put the "./promptbind" directory (relative to the current working directory)
# at the front of the module search path so it is consulted before any other
# location when resolving the imports below.
promptbind_src = "./promptbind"
sys.path.insert(0, promptbind_src)

# NOTE(review): the star imports bring an unlisted set of names into the
# notebook namespace, and a later star import can rebind a name from an earlier
# one. Later cells call get_model and evaluate_mean_pocket_cls_coord_multi_task
# without importing them explicitly, so they presumably come from these
# modules -- confirm before replacing the star imports with explicit ones.
from promptbind.data.data import get_data
from promptbind.utils.metrics import *
from promptbind.utils.utils import *
from promptbind.utils.logging_utils import Logger
from promptbind.models.model import *

  from .autonotebook import tqdm as notebook_tqdm


## 2. Initialize Configuration

In [3]:
# Prompt size for this run; the experiment name is derived from it here, and
# the configuration cell reuses it for the prompt sizes and checkpoint path.
prompt_nf = 8
exp_name = f"test_prompt_{prompt_nf}"

# Filesystem locations: dataset root and the folder that receives run outputs.
data_path = "data/pdbbind2020"
result_folder = "./results"

In [4]:
# Load the base test configuration. The YAML document is read as a mapping with
# two top-level sections, 'config' and 'args', which are flattened into a
# single dict; on duplicate keys the value from 'args' wins because it is
# unpacked last.
config_path = 'options/test_args.yml'
with open(config_path, 'r', encoding='utf-8') as f:
    args_dict = yaml.safe_load(f)
combined_args_dict = {**args_dict['config'], **args_dict['args']}

# Notebook-level overrides. The experiment name reuses `exp_name` from the
# previous cell rather than rebuilding the same string, so the two cannot
# drift apart when the naming scheme changes.
combined_args_dict.update(
    {
        'data_path': data_path,
        'pocket_prompt_nf': prompt_nf,
        'complex_prompt_nf': prompt_nf,
        'resultFolder': result_folder,
        'exp_name': exp_name,
        'ckpt': f"pretrained/prompt_{prompt_nf}/best/model.safetensors",
    }
)

# Convert to argparse Namespace so later cells can use attribute access
# (args.seed, args.ckpt, ...).
args = argparse.Namespace(**combined_args_dict)

# Seed every random number generator used in this notebook (PyTorch, the
# stdlib `random` module, NumPy) from the configured seed.
torch.manual_seed(args.seed)
random.seed(args.seed)
np.random.seed(args.seed)

In [5]:
# NOTE(review): everything in this cell is commented out and never runs.
# It sketches a Logger whose constructor takes only `log_path`, whereas the
# Logger this notebook actually uses is imported from
# promptbind.utils.logging_utils and is constructed with `accelerator=None`
# in addition to `log_path`. Kept for reference only.
# import os
# import logging

# class Logger:
#     def __init__(self, log_path):
#         # Create a custom logger
#         self.logger = logging.getLogger('MainLogger')
#         self.logger.setLevel(logging.INFO)

#         # Prevent adding multiple handlers if Logger is initialized multiple times
#         if not self.logger.handlers:
#             # Create handlers
#             c_handler = logging.StreamHandler()
#             f_handler = logging.FileHandler(log_path)

#             # Create formatters and add them to the handlers
#             c_format = logging.Formatter(
#                 "%(asctime)s - %(levelname)s - %(message)s",
#                 datefmt="%m/%d/%Y %H:%M:%S",
#             )
#             f_format = logging.Formatter('%(message)s')

#             c_handler.setFormatter(c_format)
#             f_handler.setFormatter(f_format)

#             # Add handlers to the logger
#             self.logger.addHandler(c_handler)
#             self.logger.addHandler(f_handler)

#         # Log the working directory
#         self.logger.info(f'Working directory is {os.getcwd()}')

#     def log_stats(self, stats, epoch, args, prefix=''):
#         msg_start = f'[{prefix}] Epoch {epoch} out of {args.total_epochs} | '
#         dict_msg = ' | '.join([f'{k.capitalize()} --> {v:.5f}' for k, v in stats.items()]) + ' | '

#         msg = msg_start + dict_msg

#         self.log_message(msg)

#     def log_message(self, msg):
#         self.logger.info(msg)

In [5]:
# All outputs of this run live under <resultFolder>/<exp_name>; create the
# directory if it is not there yet.
pre = f"{args.resultFolder}/{args.exp_name}"
os.makedirs(pre, exist_ok=True)

# The notebook does not use an accelerator object, so None is passed for it.
log_file = f'{pre}/test.log'
logger = Logger(accelerator=None, log_path=log_file)

# Record the command line of the running process as the first log entry.
launch_command = ' '.join(sys.argv)
logger.log_message(launch_command)

10/25/2024 16:34:49 - INFO - MainLogger - Working directory is /home/kevinb/protein/PromptBind-draft
10/25/2024 16:34:49 - INFO - MainLogger - /home/kevinb/miniconda3/envs/ppi-toolkit/lib/python3.10/site-packages/ipykernel_launcher.py --f=/home/kevinb/.local/share/jupyter/runtime/kernel-v3aef9cf15854ed1e8c07ab778e006194ae0ad6aa9.json


## 3. Initialize and Run Inference

In [6]:
import torch.nn as nn

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Build the model from the configuration, load the checkpoint weights into it,
# and move it to the selected device.
model = get_model(args, logger, device)
load_model(model, args.ckpt)
model.to(device)

# Primary loss: MSE when args.pred_dis is set, otherwise BCE-with-logits
# weighted by args.posweight.
if args.pred_dis:
    criterion = nn.MSELoss()
    # Kept so the notebook namespace is unchanged for any code that reads this
    # global; the evaluation cell passes args.pred_dis instead.
    pred_dis = True
else:
    criterion = nn.BCEWithLogitsLoss(
        pos_weight=torch.tensor(args.posweight, device=device)
    )

# Compound-coordinate loss. An unrecognised option is rejected here; without
# the check the criterion would stay undefined and the failure would only
# appear later as a NameError in the evaluation cell.
if args.coord_loss_function == 'MSE':
    com_coord_criterion = nn.MSELoss()
elif args.coord_loss_function == 'SmoothL1':
    com_coord_criterion = nn.SmoothL1Loss()
else:
    raise ValueError(
        f"Unsupported coord_loss_function: {args.coord_loss_function!r} "
        "(expected 'MSE' or 'SmoothL1')"
    )

# Pocket classification loss; 'bce' is the only option handled in this cell.
if args.pocket_cls_loss_func == 'bce':
    pocket_cls_criterion = nn.BCEWithLogitsLoss(reduction='mean')
else:
    raise ValueError(
        f"Unsupported pocket_cls_loss_func: {args.pocket_cls_loss_func!r} "
        "(expected 'bce')"
    )

# Pocket-coordinate loss: Huber loss with the configured delta.
pocket_coord_criterion = nn.HuberLoss(delta=args.pocket_coord_huber_delta)

10/25/2024 16:34:51 - INFO - MainLogger - PromptBind


Using device: cuda


In [7]:
# Map the two redocking flags onto the coordinate-initialisation mode. When
# neither flag is set, the mode already present in the configuration is kept.
if args.redocking:
    args.compound_coords_init_mode = "redocking"
elif args.redocking_no_rotate:
    args.redocking = True
    args.compound_coords_init_mode = "redocking_no_rotate"

# Load the three dataset splits from the configured data path.
train, valid, test = get_data(
    args, logger,
    addNoise=args.addNoise,
    use_whole_protein=args.use_whole_protein,
    compound_coords_init_mode=args.compound_coords_init_mode,
    pre=args.data_path,
)
split_sizes = f"Data points - Train: {len(train)}, Valid: {len(valid)}, Test: {len(test)}"
logger.log_message(split_sizes)

# Batches are loaded in the main process (no worker processes).
num_workers = 0

from torch_geometric.loader import DataLoader


def _build_eval_loader(dataset):
    """Wrap *dataset* in a non-shuffling loader with the configured batch size."""
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        follow_batch=['x', 'compound_pair'],
        shuffle=False,
        pin_memory=False,
        num_workers=num_workers,
    )


test_loader = _build_eval_loader(test)

# One PDB identifier per line. The variable name must stay as it is because
# the query string below refers to it through pandas' `@name` syntax.
with open('split_pdb_id/unseen_test_index') as f:
    test_unseen_pdb_list = [entry.strip() for entry in f]

# Rows of the dataset table that belong to the test group and whose PDB id
# is in the list read above.
unseen_rows = test.data.query(
    "(group =='test') and (pdb in @test_unseen_pdb_list)"
)
test_unseen_index = unseen_rows.index.values

# Translate each table index into its position inside `test._indices`, taking
# the first match, so that index_select can pick the matching samples.
unseen_positions = []
for dataset_index in test_unseen_index:
    matches = np.where(test._indices == dataset_index)[0]
    unseen_positions.append(matches[0])
test_unseen_index_for_select = np.array(unseen_positions)

test_unseen = test.index_select(test_unseen_index_for_select)
test_unseen_loader = _build_eval_loader(test_unseen)

10/25/2024 16:34:54 - INFO - MainLogger - Loading dataset
10/25/2024 16:34:54 - INFO - MainLogger - compound feature based on torchdrug
10/25/2024 16:34:54 - INFO - MainLogger - protein feature based on esm2
  self.data = torch.load(self.processed_paths[0])


['/home/kevinb/nas/protein/data/pdbbind2020/dataset/processed/data.pt', '/home/kevinb/nas/protein/data/pdbbind2020/dataset/processed/protein_1d_3d.lmdb', '/home/kevinb/nas/protein/data/pdbbind2020/dataset/processed/compound_LAS_edge_index.lmdb', '/home/kevinb/nas/protein/data/pdbbind2020/dataset/processed/compound_rdkit_coords.pt', '/home/kevinb/nas/protein/data/pdbbind2020/dataset/processed/esm2_t33_650M_UR50D.lmdb']


  self.compound_rdkit_coords = torch.load(self.processed_paths[3])
10/25/2024 16:34:55 - INFO - MainLogger - Data points - Train: 17299, Valid: 968, Test: 363


In [8]:
# Put the model in evaluation mode before scoring the unseen test subset.
model.eval()
logger.log_message("Begin testing")

# Loss callables built in the model-initialisation cell, keyed by the keyword
# argument each one is passed as.
loss_functions = {
    "criterion": criterion,
    "com_coord_criterion": com_coord_criterion,
    "pocket_cls_criterion": pocket_cls_criterion,
    "pocket_coord_criterion": pocket_coord_criterion,
}

# The evaluation helper returns three values; only the first (the metrics
# passed to log_stats below) is used here.
metrics, _, _ = evaluate_mean_pocket_cls_coord_multi_task(
    accelerator=None,
    args=args,
    data_loader=test_unseen_loader,
    model=model,
    relative_k=args.relative_k,
    device=device,
    pred_dis=args.pred_dis,
    use_y_mask=False,
    stage=2,
    **loss_functions,
)

# Epoch number 0 is passed because this is a single evaluation pass.
logger.log_stats(metrics, 0, args, prefix="Test_unseen")

10/25/2024 16:34:59 - INFO - MainLogger - Begin testing
100%|██████████| 36/36 [02:46<00:00,  4.63s/it]
10/25/2024 16:37:46 - INFO - MainLogger - [Test_unseen] Epoch 0 out of 400 | Samples --> 144.00000 | Skip_samples --> 1.00000 | Keepnode < 5 --> 0.00000 | Contact_loss --> 1.13186 | Contact_by_pred_loss --> 1.19398 | Com_coord_huber_loss --> 4.27379 | Rmsd --> 7.95665 | Rmsd < 2a --> 0.21528 | Rmsd < 5a --> 0.57639 | Rmsd 25% --> 2.12651 | Rmsd 50% --> 3.74071 | Rmsd 75% --> 9.29415 | Centroid_dis --> 6.21516 | Centroid_dis < 2a --> 0.54167 | Centroid_dis < 5a --> 0.74306 | Centroid_dis 25% --> 0.98511 | Centroid_dis 50% --> 1.66359 | Centroid_dis 75% --> 5.20561 | Pocket_cls_bce_loss --> 0.57726 | Pocket_coord_mse_loss --> 0.45735 | Pocket_cls_accuracy --> 0.82094 | Pocket_pearson --> 0.65176 | Pocket_rmse --> 7.00585 | Pocket_mae --> 4.20523 | Pocket_center_avg_dist --> 8.34821 | Pocket_center_dcc --> 34.72222 | 
