# Compute success rates for a custom test set
This notebook computes success rates for a custom test set whose predictions were generated with the predict.py script.

In [None]:
from biotite import structure as struc
from biotite.structure.io import pdb
from tqdm import tqdm
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
import pickle
from AF2Dock.utils.dockq import compute_metrics

In [None]:
# CSV listing the targets to evaluate; it must provide an "id" column.
input_csv = ""
predict_targets = pd.read_csv(input_csv)

# Tag embedded in the names of the CSV / pickle files written by later cells.
expt_name = ""

# How many targets (rows of predict_targets, taken in order) to evaluate.
num_targets = 1
# Total number of predicted samples per target.
num_samples = 40
# Samples produced by one predict.py run (each run writes one ipTM CSV).
num_samples_per_run = 40

# Ground-truth structures, one <id>.pdb per target.
gt_pdb_dir = ""
# Input receptor structures, one <id>_r_holo.pdb per target. Input structures
# must use the same chain names as the ground-truth structures.
input_pdb_dir = ""
# AF2Dock prediction results.
results_dir = ""

# Bootstrap settings: number of replicates and the confidence metrics used to
# rank samples. For each ranking metric, the flag is passed as `ascending` to
# sort_values; the first row after sorting is the top-ranked sample.
num_bootstrap = 10000
confidence_metrics_to_rank = ['ipTM_af']
confidence_metrics_assending = dict(ipTM_af=False)

In [None]:
# Score every predicted sample against its ground-truth complex.
# For each sample, compute_metrics is first called to find the best chain
# mapping. The sample is then re-scored after merging chains into two groups:
# the chains present in the input receptor structure, and all other mapped
# chains. Results are collected into metrics_chain (one row per chain pair)
# and metrics_global (one row per sample).
metrics_chain = []
metrics_global = []
for idx in tqdm(range(num_targets)):
    data_id = predict_targets.iloc[idx]["id"]
    gt_pdb_path = f'{gt_pdb_dir}/{data_id}.pdb'
    input_rec_pdb_path = f'{input_pdb_dir}/{data_id}_r_holo.pdb'

    # The receptor chain IDs depend only on the target, so parse the input
    # receptor structure once per target rather than once per sample.
    input_rec_ori = pdb.PDBFile.read(input_rec_pdb_path)
    input_rec_ori_atomarray = pdb.get_structure(input_rec_ori, model=1)
    input_rec_ori_chains = list(struc.get_chains(input_rec_ori_atomarray))
    input_rec_chain_set = set(input_rec_ori_chains)

    for sample_idx in range(num_samples):
        model_pdb_path = f'{results_dir}/{idx}_{data_id}/{data_id}_s{sample_idx}_ori_chain.pdb'
        model_name = f'{data_id}_s{sample_idx}'

        # find best chain mapping
        info = compute_metrics(model=model_pdb_path, native=gt_pdb_path, n_cpu=8)

        # Group native chains into receptor / ligand based on the best mapping,
        # and translate each group to the corresponding model chains.
        best_chain_map = info['best_mapping']
        ligand_chains = [chain for chain in best_chain_map.keys() if chain not in input_rec_chain_set]
        # Copy the receptor list so the per-target list is never shared with
        # (and possibly modified by) compute_metrics across samples.
        chains_to_merge = {'native': [list(input_rec_ori_chains), ligand_chains]}
        chains_to_merge['model'] = [[best_chain_map[chain] for chain in group] for group in chains_to_merge['native']]

        # compute metrics after merging
        info = compute_metrics(model=model_pdb_path, native=gt_pdb_path,
                               n_cpu=8, chains_to_merge=chains_to_merge)

        for chain_pair, pair_metrics in info['best_result'].items():
            # Copy instead of mutating the dict owned by `info`.
            metrics_chain.append({**pair_metrics, 'Model': model_name, 'chain_pair': chain_pair})
        metrics_global.append({
            'GlobalDockQ': info['GlobalDockQ'],
            'best_dockq': info['best_dockq'],
            'best_mapping': info['best_mapping_str'],
            'Model': model_name,
        })

metrics_chain = pd.DataFrame(metrics_chain)
metrics_global = pd.DataFrame(metrics_global)


In [None]:
# Collect the ipTM confidence of every sample. Each predict.py run of
# num_samples_per_run samples writes <id>_s<first>_<last>_iptm.csv, which must
# contain 'sample_idx' and 'iptm' columns. 'iptm' is renamed to 'ipTM_af', and
# a 'Model' key matching the DockQ cell ('<id>_s<sample_idx>') is added.
if num_samples % num_samples_per_run != 0:
    # Otherwise the trailing samples would get no ipTM score and silently
    # become NaN after the merge with the DockQ metrics.
    raise ValueError(
        f'num_samples ({num_samples}) must be a multiple of '
        f'num_samples_per_run ({num_samples_per_run})'
    )

confidence_metrics = []
for idx in tqdm(range(num_targets)):
    # Use the same "id" column as the DockQ cell so that result paths and
    # 'Model' keys agree when the two tables are merged.
    data_id = predict_targets.iloc[idx]["id"]
    # load ipTM scores from csv files
    iptm_all = []
    for run_idx in range(num_samples // num_samples_per_run):
        starting_idx = run_idx * num_samples_per_run
        ending_idx = starting_idx + num_samples_per_run - 1
        iptm_all.append(pd.read_csv(f'{results_dir}/{idx}_{data_id}/{data_id}_s{starting_idx}_{ending_idx}_iptm.csv'))
    confidence_metrics_i = pd.concat(iptm_all, ignore_index=True).rename(columns={'iptm': 'ipTM_af'})
    confidence_metrics_i['Model'] = confidence_metrics_i['sample_idx'].apply(lambda x: f'{data_id}_s{x}')
    confidence_metrics.append(confidence_metrics_i)

confidence_metrics = pd.concat(confidence_metrics)

In [None]:
# Attach each sample's confidence scores to its DockQ metrics. validate ensures
# every 'Model' occurs at most once in confidence_metrics, so the merge cannot
# silently duplicate rows. The bootstrap cell relies on a fixed number of
# consecutive rows per target.
combined_metrics = pd.merge(metrics_chain, confidence_metrics, on='Model', how='left',
                            validate='many_to_one').reset_index(drop=True)

In [None]:
# Save the merged per-sample metrics so later cells can reload them without recomputing.
combined_metrics_csv_path = f"combined_metrics_{expt_name}.csv"
combined_metrics.to_csv(combined_metrics_csv_path, index=False)

In [None]:
# Reload the merged per-sample metrics from disk (the file written above).
combined_metrics = pd.read_csv(filepath_or_buffer=f"combined_metrics_{expt_name}.csv")

In [None]:
# Bootstrap the success rates. In every replicate, each target's num_samples
# predictions are resampled with replacement. For each target, we then record
# whether the sample chosen by each scheme reaches each DockQ quality level:
#   top1_<metric>: the sample ranked first by the confidence metric
#   top5_<metric>: best DockQ among the 5 first-ranked samples
#   oracle:        best DockQ among all resampled samples
# Afterwards, success_count_indi_all[scheme][level] is a boolean array of shape
# (num_bootstrap, num_targets).
# Assumes combined_metrics has a default RangeIndex and holds exactly
# num_samples consecutive rows per target, in target order (one chain pair per
# model). TODO confirm if models can ever yield several chain pairs.
dockq_thresholds = {'high': 0.8, 'medium': 0.49, 'acceptable': 0.23}

success_count_indi_all = {}
for confidence_metric_name in confidence_metrics_to_rank:
    success_count_indi_all[f'top1_{confidence_metric_name}'] = {'high': [], 'medium': [], 'acceptable': []}
    success_count_indi_all[f'top5_{confidence_metric_name}'] = {'high': [], 'medium': [], 'acceptable': []}
success_count_indi_all['oracle'] = {'high': [], 'medium': [], 'acceptable': []}
np.random.seed(42)

# First row of each target's block, repeated once per drawn sample (the same for every replicate).
sample_base_idx = np.repeat(np.arange(num_targets) * num_samples, num_samples)

for bs_idx in tqdm(range(num_bootstrap)):
    per_sample_rand_idx = np.random.randint(num_samples, size=num_samples * num_targets)
    combined_metrics_bs_i = combined_metrics.iloc[per_sample_rand_idx + sample_base_idx].reset_index(drop=True)

    # DockQ of the sample picked by each scheme, one entry per target.
    selected_dockq = {key: np.empty(num_targets) for key in success_count_indi_all}
    for idx in range(num_targets):
        sample_i = combined_metrics_bs_i.iloc[idx * num_samples: (idx + 1) * num_samples]
        # Series.max skips NaN, like "sort by DockQ descending, take the first row".
        selected_dockq['oracle'][idx] = sample_i['DockQ'].max()
        for confidence_metric_name in confidence_metrics_to_rank:
            # Sort once and reuse the ranking for both top-1 and top-5.
            ranked_dockq = sample_i.sort_values(
                confidence_metric_name,
                ascending=confidence_metrics_assending[confidence_metric_name],
            )['DockQ']
            selected_dockq[f'top1_{confidence_metric_name}'][idx] = ranked_dockq.iloc[0]
            # Evaluated per target, so this stays aligned even when num_samples < 5.
            selected_dockq[f'top5_{confidence_metric_name}'][idx] = ranked_dockq.head(5).max()

    for key, dockq in selected_dockq.items():
        for level, threshold in dockq_thresholds.items():
            success_count_indi_all[key][level].append(dockq >= threshold)

for key in success_count_indi_all:
    for level in success_count_indi_all[key]:
        success_count_indi_all[key][level] = np.array(success_count_indi_all[key][level])

In [None]:
# Summarise the bootstrap replicates. For every selection scheme and quality
# level, turn the per-replicate success counts into a mean success rate and the
# 2.5th / 97.5th percentile bounds, both divided by num_targets. The result is
# pickled for the plotting cell.
success_rate_CI95_mean = {}
success_rate_CI95_err = {}
for scheme, hits_by_level in success_count_indi_all.items():
    mean_by_level = {}
    bounds_by_level = {}
    for level, hits in hits_by_level.items():
        # hits: one row per bootstrap replicate, one column per target.
        counts = hits.sum(axis=1)
        lower_bound = np.percentile(counts, 2.5)
        upper_bound = np.percentile(counts, 97.5)
        mean_by_level[level] = np.mean(counts) / num_targets
        # Despite the "err" name, these are the interval bounds, not half-widths.
        bounds_by_level[level] = np.array([lower_bound, upper_bound]) / num_targets
    success_rate_CI95_mean[scheme] = mean_by_level
    success_rate_CI95_err[scheme] = bounds_by_level

success_rate_CI95_mean = pd.DataFrame(success_rate_CI95_mean)
success_rate_CI95_err = pd.DataFrame(success_rate_CI95_err)
success_rate_CI95 = dict(mean=success_rate_CI95_mean, err=success_rate_CI95_err)

with open(f'success_rate_CI95_10000_{expt_name}.pkl', 'wb') as pkl_out:
    pickle.dump(success_rate_CI95, pkl_out)

In [None]:
# Reload the success-rate summary pickled by the previous cell.
# Only unpickle files this notebook wrote itself; pickle can execute arbitrary code.
ci_summary_path = f'success_rate_CI95_10000_{expt_name}.pkl'
with open(ci_summary_path, 'rb') as pkl_in:
    success_rate_CI95 = pickle.load(pkl_in)

In [None]:
# Grouped bar chart of the bootstrap success rates (top1 / top5 / oracle).
# Within each scheme, the high / medium / acceptable rates are drawn as
# overlapping bars, with 'high' in front. Each bar gets a gray error bar
# spanning its 95% bootstrap interval.
fig, ax = plt.subplots(1, 1, figsize=(3.0, 3.5))

# 'mean' holds success rates and 'err' holds [lower, upper] interval bounds
# (both as fractions). Each is indexed by quality level.
combined_results = {stat: {
    "AF2Dock": {
        'top1': success_rate_CI95[stat]['top1_ipTM_af'],
        'top5': success_rate_CI95[stat]['top5_ipTM_af'],
        'oracle': success_rate_CI95[stat]['oracle'],
    }} for stat in ['mean', 'err']
}

cate_larger = ['AF2Dock']
metrics = ['top1', 'top5', 'oracle']
n_groups = len(cate_larger)
n_metrics = len(metrics)

# Set up positions for grouped bars (x-axis data units)
bar_width = 0.2  # Width of individual bars
bar_spacing = 0.05  # Space between bars within a group
group_spacing = 0.2  # Space between groups
bar_step = bar_width + bar_spacing  # Distance between neighbouring bars in a group
group_width = n_metrics * bar_width + (n_metrics - 1) * bar_spacing
group_positions = np.arange(n_groups) * (group_width + group_spacing)

colors = {
    'acceptable': '#F3C6AC',
    'medium': '#E8674C',
    'high': '#761314'
}
accu_levels = ['high', 'medium', 'acceptable']
# Draw order: higher z in front, so the shorter 'high' bar stays visible.
z_orders = {'high': 3, 'medium': 2, 'acceptable': 1}
# Shift the three error bars of one bar sideways so they do not overlap.
ebar_pos_adjust = {'high': bar_width / 2, 'medium': 0, 'acceptable': -bar_width / 2}

for i, metric in enumerate(metrics):
    metric_positions = group_positions + i * bar_step

    # Height of the last label placed for each category. Each new label is
    # kept at least 2 percentage points above the previous one.
    previous_heights = [None] * len(cate_larger)
    for accu_cate in accu_levels:
        accu_counts = [combined_results['mean'][cat][metric][accu_cate] * 100 for cat in cate_larger]
        ci_bounds = [combined_results['err'][cat][metric][accu_cate] for cat in cate_larger]
        # Symmetric error bar around the interval midpoint with half the
        # interval width, so it spans exactly [lower, upper].
        err_center = [np.mean(bounds) * 100 for bounds in ci_bounds]
        err_half_width = [(bounds[1] - bounds[0]) / 2 * 100 for bounds in ci_bounds]

        # Only the first scheme's bars get legend labels, so each level appears once.
        label_kwargs = {'label': accu_cate} if i == 0 else {}
        ax.bar(metric_positions, accu_counts, bar_width,
               color=colors[accu_cate], zorder=z_orders[accu_cate], **label_kwargs)
        ebar = ax.errorbar(metric_positions + ebar_pos_adjust[accu_cate], err_center, yerr=err_half_width,
                           fmt="none", color="gray", capsize=2.0)
        # ebar.lines is (data line, cap lines, bar line collections). Raise the
        # caps and vertical lines above the bars.
        for artist in (*ebar.lines[1], *ebar.lines[2]):
            artist.set_zorder(5)

        # Add percentage labels on top of bars
        for j, (pos, val) in enumerate(zip(metric_positions, accu_counts)):
            previous_height = previous_heights[j]
            if previous_height is not None and val < previous_height + 2:
                height = previous_height + 2
            else:
                height = val
            previous_heights[j] = height
            ax.text(pos, height, f'{val:.1f}%',
                    ha='center', va='bottom', fontsize=7, rotation=0)

ax.legend()
ax.set_ylim([0, 100])
ax.set_ylabel('Success Rate (%)')

# Fit the x-range to the span of all bars, plus a 0.2 margin on each side.
first_bar_start = group_positions[0] - bar_width / 2
last_bar_end = group_positions[-1] + (n_metrics - 1) * bar_step + bar_width / 2
ax.set_xlim([first_bar_start - 0.2, last_bar_end + 0.2])

# Hide the primary x ticks; the scheme names are shown through a twin x-axis instead.
ax.set_xticklabels([])
ax.set_xticks([])

minor_ticks = [group_pos + j * bar_step for group_pos in group_positions for j in range(n_metrics)]
minor_labels = [metric for _ in group_positions for metric in metrics]

ax2 = ax.twiny()
ax2.set_xlim(ax.get_xlim())
ax2.set_xticks(minor_ticks)
ax2.set_xticklabels(minor_labels, rotation=0, ha='center', fontsize=8)
# Place the twin axis' ticks and labels at the bottom of the plot.
ax2.tick_params(axis='x', which='major', top=False, bottom=True, labeltop=False, labelbottom=True, pad=5)

fig.tight_layout()
plt.show()