# Sample inference pipeline (& runtime evaluation)

In [2]:
# Import required dependencies
import os
import torch
import pandas as pd
import numpy as np
import pytorch_lightning as pl
from argparse import ArgumentParser
from pytorch_lightning import Trainer
import pytorch_lightning.callbacks as plc
from pytorch_lightning.loggers import TensorBoardLogger

## Graph generation dependencies
from create_graph.Ligand_graph import construct_ligand_graph, convert_to_pyg
from create_graph.Protein_graph import parallel_process_proteins

## Model dependecies
from model import MInterface
from data import DInterface
from utils import load_model_path_by_args, plot_rmsd_metrics

from data.dataset import prepare_data_binary, prepare_data_point, prepare_data_pose, prepare_data_rmsd
from sklearn.model_selection import train_test_split, KFold
from scipy import stats

from utils import ndcg_score


Using device: cuda


## Prepare the graph data

In [None]:
path = "dataset/raw_inference"
save_dir_protein = "dataset/protein_g_inference"
save_dir_ligand = "dataset/ligand_g_inference"

# Generate the index for protein/ligand pairs
index = os.listdir(path)
print(len(index)) # Datapoint count

30


In [None]:
# Protein graph
## Configuration for protein graph construction
from graphein.protein.config import ProteinGraphConfig
from graphein.protein.features.nodes.amino_acid import amino_acid_one_hot
from graphein.protein.edges.distance import (
    add_peptide_bonds,
    add_hydrogen_bond_interactions,
    add_disulfide_interactions,
    add_ionic_interactions,
    add_aromatic_interactions,
    add_aromatic_sulphur_interactions,
    add_cation_pi_interactions
)

config = ProteinGraphConfig(
    granularity="centroids",
    node_metadata_functions=[amino_acid_one_hot],
    edge_construction_functions=[
        add_peptide_bonds,
        add_aromatic_interactions,
        add_hydrogen_bond_interactions,
        add_disulfide_interactions,
        add_ionic_interactions,
        add_aromatic_sulphur_interactions,
        add_cation_pi_interactions,
    ]
)

## Construct the protein graph
parallel_process_proteins(index, config, save_dir_protein, num_workers=1, chunk_size=100, path=path)

Output()

Output()

Failed to process 1b0p_TPP_B_1236: CUDA out of memory. Tried to allocate 3.74 GiB. GPU 0 has a total capacity of 7.62 GiB of which 1.15 GiB is free. Including non-PyTorch memory, this process has 6.45 GiB memory in use. Of the allocated memory 6.26 GiB is allocated by PyTorch, and 78.25 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)


Output()

Successfully processed 1amr_PMP_A_413


Output()

Successfully processed 1awb_IPD_A_281


Successfully processed 1b3d_S27_B_401


Output()

Successfully processed 1ajs_PLA_A_415


Output()

In [None]:
# Ligand graph generation
## Configuration for ligand graph construction
import graphein.molecule as gm
from functools import partial

config = gm.MoleculeGraphConfig(
    node_metadata_functions=[
        gm.atom_type_one_hot,
        gm.atomic_mass,
        gm.degree,
        gm.total_degree,
        gm.total_valence,
        gm.explicit_valence,
        gm.implicit_valence,
        gm.num_explicit_h,
        gm.num_implicit_h,
        gm.total_num_h,
        gm.num_radical_electrons,
        gm.formal_charge,
        gm.is_aromatic,
        gm.is_isotope,
        gm.is_ring,
        partial(gm.is_ring_size, ring_size=5),
        partial(gm.is_ring_size, ring_size=7)
    ]
)

proteins = pd.Series(index)
graphs = proteins.apply(lambda p: construct_ligand_graph(p, path))
# Convert all graphs
pyg_graphs = [convert_to_pyg(g) for g in graphs]
# Assuming pyg_graphs is a list containing your graph data objects
proteins = proteins.to_list()
os.makedirs(save_dir_ligand, exist_ok=True)
for idx, pyg_graph in enumerate(pyg_graphs):
    protein_name = proteins[idx]
    file_name = f"{save_dir_ligand}/pyg_graph_{protein_name}.pt"
    torch.save(pyg_graph, file_name)
    print(f'{protein_name} saved')


100
1b0p_TPP_B_1236 saved
1amr_PMP_A_413 saved
1awb_IPD_A_281 saved
1b3d_S27_B_401 saved
1ajs_PLA_A_415 saved
1a49_ATP_C_1735 saved
1ami_MIC_A_755 saved
1aex_THP_A_151 saved
1aqx_GTD_B_2201 saved
1a5w_Y3_A_1 saved
1aqx_GTD_C_2301 saved
1aog_MAE_A_500 saved
1b0p_TPP_A_1236 saved
1amq_PMP_A_413 saved
1aj2_2PH_A_283 saved
1a49_ATP_D_2335 saved
1a3u_THP_A_151 saved
1a80_NDP_A_300 saved
1aog_FAD_A_492 saved
1aer_TIA_B_700 saved
1a49_OXL_F_4133 saved
1afq_0FG_B_304 saved
1aej_NVI_A_296 saved
1akd_CAM_A_420 saved
1axd_GGL_C_1 saved
1aiq_CB3_B_267 saved
1ac4_TMT_A_296 saved
1a0g_PMP_A_285 saved
1aiq_UMP_B_266 saved
1a96_PCP_B_301 saved
1afe_ASP_I_55 saved
1a49_OXL_H_5333 saved
1aog_FAD_B_492 saved
1a59_COA_A_380 saved
1anc_BEN_A_290 saved
1abn_NDP_A_351 saved
1ah4_NAP_A_318 saved
1a59_CIT_A_379 saved
1b7y_FYA_A_1002 saved
1axg_NAD_C_403 saved
1af7_SAH_A_287 saved
1a96_PCP_C_302 saved
1a0g_PMP_B_285 saved
1aer_TAD_A_700 saved
1aia_PMP_A_411 saved
1a49_OXL_E_3533 saved
1aj8_CIT_A_1000 saved
1a4i

## Run inference

### New Inference Functionality

The project now includes a comprehensive inference pipeline that allows you to run algorithm selection predictions on custom data. Here's what's available:

**Key Features:**
- **Multiple Model Support**: Works with binary, point, and rank models
- **Pre-generated Graph Support**: Uses existing protein and ligand graph files
- **Input Validation**: Validates graph file existence and format
- **Comprehensive Output**: Generates algorithm selection predictions with confidence scores
- **CSV Export**: Saves results in structured format with rankings and statistics

**Command Line Usage:**
```bash
python main.py --inference \
    --model_name [binary|point|rank] \
    --ckpt_path /path/to/checkpoint.ckpt \
    --protein_graph_dir dataset/protein_g_inference \
    --ligand_graph_dir dataset/ligand_g_inference \
    --inference_output results.csv
```

**Input Requirements:**
- Pre-generated protein graphs in PyTorch Geometric format
- Pre-generated ligand graphs in PyTorch Geometric format  
- Graph files should be named: `pyg_graph_{protein_ligand_id}.pt`
- Trained model checkpoint

**Output Format:**
- CSV file with algorithm scores, selected algorithm, confidence, and rankings
- Summary statistics including algorithm distribution and mean confidence

In [None]:
# Run inference using the command line interface
# Example: Using a rank model checkpoint
import subprocess
import os

# Set the checkpoint path - change this to your desired model
checkpoint_path = "lightning_logs/version_90/checkpoints/last.ckpt"

# Check if checkpoint exists
if os.path.exists(checkpoint_path):
    print(f"Using checkpoint: {checkpoint_path}")
    
    # Run inference using the new inference functionality
    cmd = [
        "python", "main.py",
        "--model_name", "rank",  # or "binary", "point" depending on your model
        "--max_epochs", "500",
        "--loss", "bce",
        "--devices", "0",
        "--seed", "42",
        "--no_ndcg_loss",
        "--no_logistic_loss", 
        "--inference",
        "--ckpt_path", checkpoint_path,
        "--protein_graph_dir", "dataset/protein_g_inference",
        "--ligand_graph_dir", "dataset/ligand_g_inference",
        "--inference_output", "notebook_inference_results.csv"
    ]
    
    # Execute the inference
    result = subprocess.run(cmd, capture_output=True, text=True)
    
    # Print the output
    print("STDOUT:")
    print(result.stdout)
    
    if result.stderr:
        print("STDERR:")
        print(result.stderr)
    
    print(f"Return code: {result.returncode}")
    
    # Load and display results if successful
    if result.returncode == 0 and os.path.exists("notebook_inference_results.csv"):
        import pandas as pd
        results_df = pd.read_csv("notebook_inference_results.csv")
        print("\n=== Inference Results Preview ===")
        print(results_df.head(10))
        
        print(f"\n=== Summary ===")
        print(f"Total predictions: {len(results_df)}")
        print(f"Algorithm selection distribution:")
        print(results_df['selected_algorithm'].value_counts())
        
        print(f"\nMean confidence: {results_df['confidence'].mean():.4f}")
        print(f"Results saved to: notebook_inference_results.csv")
    
else:
    print(f"Checkpoint not found: {checkpoint_path}")
    print("Please check available checkpoints:")
    # List available checkpoints
    for root, dirs, files in os.walk("lightning_logs"):
        for file in files:
            if file.endswith('.ckpt'):
                print(f"  {os.path.join(root, file)}")

Using checkpoint: lightning_logs/version_90/checkpoints/last.ckpt
STDOUT:
Loading inference data from:
  Protein graphs: dataset/protein_g_inference
  Ligand graphs: dataset/ligand_g_inference
Found 30 protein-ligand pairs for inference.
Loaded model from: lightning_logs/version_90/checkpoints/last.ckpt
Model criteria: rank
Number of classes: 8
Running inference...

Predicting: |          | 0/? [00:00<?, ?it/s]
Predicting:   0%|          | 0/4 [00:00<?, ?it/s]
Predicting DataLoader 0:   0%|          | 0/4 [00:00<?, ?it/s]
Predicting DataLoader 0:  25%|██▌       | 1/4 [00:00<00:00,  3.49it/s]
Predicting DataLoader 0:  50%|█████     | 2/4 [00:00<00:00,  6.16it/s]
Predicting DataLoader 0:  75%|███████▌  | 3/4 [00:00<00:00,  8.34it/s]
Predicting DataLoader 0: 100%|██████████| 4/4 [00:00<00:00, 10.27it/s]
Predicting DataLoader 0: 100%|██████████| 4/4 [00:00<00:00, 10.25it/s]
Generated predictions for 30 protein-ligand pairs.
Predictions array shape: (240,)
Model criteria: rank
Number of cla