# Load model

In [1]:
# model path
import yaml
from abflow.utils.training import setup_model

config_path = '/spinning1/sharedby/hz362/AbFlow/config/sabdab.yaml'
checkpoint_path = '/scratch/hz362/datavol/model/seq_bb_epoch=199.ckpt'
device = "cuda:2"  # or "cpu"

config = yaml.load(open(config_path, 'r'), Loader=yaml.FullLoader)
model, datamodule = setup_model(config, checkpoint_path, load_optimizer=False)
model.to(device)
model.eval()
if model.training:
    print("The model is in training mode.")
else:
    print("The model is in evaluation mode.")

Number of pdbs in the full dataset: 10933
Number of clusters in the full dataset: 3145
Number of RAbD id: 60
Number of clusters and samples in training: 3069, 5490
Number of clusters and samples in validation: 10, 12
Number of clusters and samples in test: 60, 230
Number of structures in the train split: 5490
Number of pdbs in the full dataset: 10933
Number of clusters in the full dataset: 3145
Number of RAbD id: 60
Number of clusters and samples in training: 3069, 5490
Number of clusters and samples in validation: 10, 12
Number of clusters and samples in test: 60, 230
Number of structures in the val split: 12
Number of pdbs in the full dataset: 10933
Number of clusters in the full dataset: 3145
Number of RAbD id: 60
Number of clusters and samples in training: 3069, 5490
Number of clusters and samples in validation: 10, 12
Number of clusters and samples in test: 60, 230
Number of structures in the test split: 230


  checkpoint = torch.load(checkpoint_path)


The model is in evaluation mode.


In [2]:
import os
import torch
import pandas as pd
import numpy as np
import json
from collections import defaultdict
from glob import glob
from datetime import datetime
from abflow.constants import initialize_constants
from abflow.data.process_pdb import process_pdb_to_lmdb, process_lmdb_chain, add_features, fill_missing_atoms, output_to_pdb
from abflow.model.metrics import AbFlowMetrics
from abflow.model.utils import concat_dicts
from abflow.constants import chain_id_to_index, aa1_name_to_index


def process_pdb_to_data_dict(pdb_file, heavy_chain_id, light_chain_id, antigen_chain_ids, scheme):
    """Process PDB file into input data dictionary."""
    fixed_pdb_file = pdb_file.replace(".pdb", "_fixed.pdb")
    fill_missing_atoms(pdb_file, fixed_pdb_file)

    data = process_pdb_to_lmdb(
        fixed_pdb_file, model_id=0,
        heavy_chain_id=heavy_chain_id, light_chain_id=light_chain_id,
        antigen_chain_ids=antigen_chain_ids, scheme=scheme
    )
    data_dict = process_lmdb_chain(data)
    data_dict.update(add_features(data_dict))
    
    return data_dict, fixed_pdb_file

def generate_complexes(data_dict, num_designs, batch_size, seed):
    """Generate complexes from input data dictionary."""
    pred_data_dicts = []
    for _ in range(0, num_designs, batch_size):
        current_batch_size = min(batch_size, num_designs - len(pred_data_dicts))
        true_data_dict = datamodule.collate([data_dict.copy()] * current_batch_size)
        for key in true_data_dict:
            true_data_dict[key] = true_data_dict[key].to(device)

        pred_data_dict = model._generate_complexes(true_data_dict, seed=seed)
        pred_data_dicts.append(pred_data_dict)

    # Combine all predictions into one dictionary
    pred_data_dict = concat_dicts(pred_data_dicts)

    return pred_data_dict

def squeeze_data_dict(data_dict):
    """Squeeze the data dictionary to remove unnecessary dimensions."""
    for key in data_dict:
        if isinstance(data_dict[key], torch.Tensor):
            data_dict[key] = data_dict[key].squeeze(0)
        elif isinstance(data_dict[key], list):
            data_dict[key] = [item.squeeze(0) for item in data_dict[key]]
    return data_dict

def unpad_data(data: dict) -> dict:
    """
    Reverse of pad_data — restores original unpadded shapes using 'valid_mask'.

    Handles:
    - 1D or 2D per-residue features: (N,)
    - 2D or 3D pairwise features: (N, N)
    - Non-tensor values are left unchanged.
    - 'valid_mask' itself is removed from the returned dict.

    :param data: Dictionary with padded tensors and 'valid_mask'
    :return: Dictionary with valid-only entries
    """
    valid_mask = data.get("valid_mask", None)
    if valid_mask is None:
        raise ValueError("'valid_mask' is required for unpadding.")

    valid_length = valid_mask.sum().item()
    unpadded_data = {}

    for key, value in data.items():
        if key == "valid_mask":
            continue

        if isinstance(value, torch.Tensor):
            if value.ndim >= 2 and value.shape[0] == value.shape[1]:
                unpadded_data[key] = value[:valid_length, :valid_length]
            else:
                unpadded_data[key] = value[:valid_length]
        else:
            unpadded_data[key] = value

    return unpadded_data

def copy_data_dict(data_dict, num_designs):

    true_data_dict = datamodule.collate([data_dict.copy()] * num_designs)
    for key in true_data_dict:
        true_data_dict[key] = true_data_dict[key].to(device)

    return true_data_dict

def compute_metrics(true_data_dict, pred_data_dict, model_pred = True):
    """Compute metrics for the generated complexes."""

    if model_pred:
        metrics = AbFlowMetrics()
    else:
        metrics = AbFlowMetrics(model_pred=False)
    metrics_dict = metrics(pred_data_dict, true_data_dict)

    aggregated_metrics = {k: v.mean().item() for k, v in metrics_dict.items()}
    return aggregated_metrics

def cleanup_fixed_file(fixed_pdb_file):
    """Remove the fixed PDB file to keep the directory clean."""
    if os.path.exists(fixed_pdb_file):
        os.remove(fixed_pdb_file)
        print(f"Temporary file removed: {fixed_pdb_file}")

def evaluate_single_pdb(pdb_file, heavy_chain_id, light_chain_id, antigen_chain_ids, scheme, num_designs, batch_size, seed):
    """Full pipeline to process PDB, generate complexes, compute metrics, and clean up."""

    try:
        data_dict, fixed_pdb_file = process_pdb_to_data_dict(pdb_file, heavy_chain_id, light_chain_id, antigen_chain_ids, scheme)
        pred_data_dict = generate_complexes(data_dict, num_designs, batch_size, seed)
        true_data_dict = copy_data_dict(data_dict, num_designs)
        metrics = compute_metrics(true_data_dict, pred_data_dict)
        return metrics
    finally:
        cleanup_fixed_file(fixed_pdb_file)

def design_single_pdb(pdb_file, heavy_chain_id, light_chain_id, antigen_chain_ids, scheme, batch_size, seed, output_dir):
    try:
        pdb_filename = os.path.basename(pdb_file)
        output_path = os.path.join(output_dir, pdb_filename)
        os.makedirs(output_dir, exist_ok=True)
        data_dict, fixed_pdb_file = process_pdb_to_data_dict(
            pdb_file, heavy_chain_id, light_chain_id, antigen_chain_ids, scheme
        )
        pred_data_dict = generate_complexes(data_dict, 1, batch_size, seed)
        pred_data_dict = squeeze_data_dict(pred_data_dict)
        pred_data_dict = unpad_data(pred_data_dict)

        output_to_pdb(pred_data_dict, path=output_path)

        print(f"Design saved to: {output_path}")
    finally:
        cleanup_fixed_file(fixed_pdb_file)


def evaluate_two_pdbs(pred_pdb_file, true_pdb_file, heavy_chain_id, light_chain_id, antigen_chain_ids, scheme):
    """
    Pipeline to evaluate two pdb files
    """
    try:
        data_dict_1, fixed_pdb_file_1 = process_pdb_to_data_dict(pred_pdb_file, heavy_chain_id, light_chain_id, antigen_chain_ids, scheme)
        data_dict_2, fixed_pdb_file_2 = process_pdb_to_data_dict(true_pdb_file, heavy_chain_id, light_chain_id, antigen_chain_ids, scheme)
        pred_data_dict = datamodule.collate([data_dict_1.copy()])
        true_data_dict = datamodule.collate([data_dict_2.copy()])
        metrics = compute_metrics(true_data_dict, pred_data_dict, model_pred=False)
        return metrics
    finally:
        cleanup_fixed_file(fixed_pdb_file_1)
        cleanup_fixed_file(fixed_pdb_file_2)



def evaluate_mutated_pdb(pdb_file, parent_info, mutated_info, heavy_chain_id, light_chain_id, antigen_chain_ids, scheme, results_dir, num_designs, batch_size, seed):
    """
    Evaluates a PDB file with mutated sequences provided in mutated_info and saves metrics to a CSV file,
    ensuring only mutations with the same length as the parent sequence are processed.

    :param pdb_file: Path to the PDB file.
    :param parent_info: Path to the CSV file containing parent sequence information.
    :param mutated_info: Path to the CSV file containing mutated sequences and metadata.
    :param heavy_chain_id: Chain ID for the heavy chain.
    :param light_chain_id: Chain ID for the light chain.
    :param antigen_chain_ids: List of chain IDs for antigens.
    :param scheme: Antibody numbering scheme.
    :param results_dir: Directory to save the results CSV file.

    :return: None
    """

    parent_df = pd.read_csv(parent_info)
    parent_heavy_sequence = parent_df["Heavy"].iloc[0]
    parent_light_sequence = parent_df["Light"].iloc[0]
    parent_hcdr3_length = len(parent_df["HCDR3"].iloc[0])

    mutated_df = pd.read_csv(mutated_info)

    mutated_df = mutated_df[mutated_df["HCDR3"].str.len() == parent_hcdr3_length]

    data_dict, fixed_pdb_file = process_pdb_to_data_dict(
        pdb_file, heavy_chain_id, light_chain_id, antigen_chain_ids, scheme
    )

    pred_data_dict = generate_complexes(data_dict, num_designs, batch_size, seed)

    mutated_df["likelihood"] = None

    try:
        for idx, row in mutated_df.iterrows():
            mutated_data_dict = data_dict.copy()

            mutated_heavy_sequence = parent_heavy_sequence.replace(parent_df["HCDR3"].iloc[0], row["HCDR3"])
            mutated_light_sequence = parent_light_sequence

            heavy_indices = (data_dict["chain_type"] == chain_id_to_index["heavy"])
            light_indices = (data_dict["chain_type"] == chain_id_to_index["light_lambda"]) | (data_dict["chain_type"] == chain_id_to_index["light_kappa"])

            mutated_data_dict["res_type"][heavy_indices] = torch.tensor(
                [aa1_name_to_index[aa] for aa in mutated_heavy_sequence],
                dtype=torch.long,
                device=data_dict["res_type"].device,
            )
            mutated_data_dict["res_type"][light_indices] = torch.tensor(
                [aa1_name_to_index[aa] for aa in mutated_light_sequence],
                dtype=torch.long,
                device=data_dict["res_type"].device,
            )

            mutated_data_dict = copy_data_dict(mutated_data_dict, num_designs)

            metrics = compute_metrics(mutated_data_dict, pred_data_dict)

            likelihood = metrics.get("likelihood/redesign", float("nan"))

            mutated_df.at[idx, "likelihood"] = likelihood

        os.makedirs(results_dir, exist_ok=True)
        output_csv_path = os.path.join(results_dir, "absci_her2_zs_likelihood.csv")
        mutated_df.to_csv(output_csv_path, index=False)
        print(f"Filtered results saved to: {output_csv_path}")

    finally:
        cleanup_fixed_file(fixed_pdb_file)

def get_avg_metrics(metrics_list):
    avg_dict = defaultdict(float)
    for m in metrics_list:
        for k, v in m.items():
            avg_dict[k] += v
    for k in avg_dict:
        avg_dict[k] /= len(metrics_list)
    return dict(avg_dict)

def evaluate_designs(design_dir, true_pdb_dir, scheme = "chothia"):
    pdb_files = sorted(glob(os.path.join(design_dir, "*.pdb")))
    all_results = {}
    all_metrics_list = []

    print(f"🧬 Found {len(pdb_files)} design PDBs.")

    for pdb_file in pdb_files:
        name = os.path.splitext(os.path.basename(pdb_file))[0]
        true_pdb_file = os.path.join(true_pdb_dir, f"{name}.pdb")

        if not os.path.exists(true_pdb_file):
            print(f"❌ Missing true PDB for {name}, skipping...")
            continue

        parts = name.split("_")
        heavy_chain_id = parts[1]
        light_chain_id = parts[2]
        antigen_chain_ids = parts[3:]

        try:
            metrics = evaluate_two_pdbs(
                pred_pdb_file=pdb_file,
                true_pdb_file=true_pdb_file,
                heavy_chain_id=heavy_chain_id,
                light_chain_id=light_chain_id,
                antigen_chain_ids=antigen_chain_ids,
                scheme=scheme
            )
            all_results[name] = metrics
            all_metrics_list.append(metrics)
            print(f"✅ Evaluated {name}")
        except Exception as e:
            print(f"⚠️ Failed evaluation for {name}: {e}")

    print("\n📊 Averaging metrics across all designs...")
    avg_metrics = get_avg_metrics(all_metrics_list)
    print(json.dumps(avg_metrics, indent=2))

    return all_results

def average_all_structures(all_results):
    combined_metrics = defaultdict(list)
    for metrics in all_results.values():
        for key, value in metrics.items():
            combined_metrics[key].append(value)

    final_avg = {key: float(np.mean(values)) for key, values in combined_metrics.items()}
    return final_avg



# RABD Benchmark

#### Evaluate RABD benchmark using AbFlow

In [None]:
batch_size = 1
num_designs = 1
pdb_dir = "/spinning1/sharedby/hz362/AbFlow/data/rabd/pdb"
results_dir = "/spinning1/sharedby/hz362/AbFlow/results"
scheme = "chothia"
seed = 2025

initialize_constants(device=device)

all_metrics = []
for pdb_file in os.listdir(pdb_dir):
    if pdb_file.endswith(".pdb"):
        pdb_path = os.path.join(pdb_dir, pdb_file)
        base_name = os.path.basename(pdb_file)
        parts = base_name.split(".")[0].split("_")
        heavy_chain_id, light_chain_id, *antigen_chain_ids = parts[1:]
        metrics = evaluate_single_pdb(pdb_path, heavy_chain_id, light_chain_id, antigen_chain_ids, scheme, num_designs, batch_size, seed)
        all_metrics.append(metrics)
        print(f"PDB file processed: {pdb_file}")
        break


aggregated_metrics = {}
for key in all_metrics[0]:
    valid_values = [d[key] for d in all_metrics]
    aggregated_metrics[key] = sum(valid_values) / len(valid_values)

metrics_df = pd.DataFrame([aggregated_metrics], index=[datetime.now().strftime("%Y%m%d_%H%M%S")])

os.makedirs(results_dir, exist_ok=True)
output_path = os.path.join(results_dir, f"rabd_metrics.csv")
metrics_df.to_csv(output_path)
print(f"\nResults saved to: {output_path}")

#### Design single PDB using AbFlow

In [4]:
batch_size = 1
pdb_dir = "/spinning1/sharedby/hz362/AbFlow/data/rabd/pdb"
results_dir = "/spinning1/sharedby/hz362/AbFlow/data/baseline/abflow"
scheme = "chothia"
seed = 2025

initialize_constants(device=device)

os.makedirs(results_dir, exist_ok=True)

for pdb_file in os.listdir(pdb_dir):
    if pdb_file.endswith(".pdb"):
        pdb_path = os.path.join(pdb_dir, pdb_file)

        base_name = os.path.basename(pdb_file)
        parts = base_name.split(".")[0].split("_")
        heavy_chain_id, light_chain_id, *antigen_chain_ids = parts[1:]

        design_single_pdb(
            pdb_file=pdb_path,
            heavy_chain_id=heavy_chain_id,
            light_chain_id=light_chain_id,
            antigen_chain_ids=antigen_chain_ids,
            scheme=scheme,
            batch_size=batch_size,
            seed=seed,
            output_dir=results_dir,
        )

        print(f"Designed complex saved for: {pdb_file}")
    break


Fixed PDB file saved to: /spinning1/sharedby/hz362/AbFlow/data/rabd/pdb/4dtg_H_L_K_fixed.pdb
Design saved to: /spinning1/sharedby/hz362/AbFlow/data/baseline/abflow/4dtg_H_L_K.pdb
Temporary file removed: /spinning1/sharedby/hz362/AbFlow/data/rabd/pdb/4dtg_H_L_K_fixed.pdb
Designed complex saved for: 4dtg_H_L_K.pdb


#### Absci affinity benchmark

In [None]:
# Evaluate Absci HER2 ZS using AbFlow
pdb_file = "/spinning1/sharedby/hz362/AbFlow/data/absci_her2_zs/absci_her2_zs.pdb"
parent_info = "/spinning1/sharedby/hz362/AbFlow/data/absci_her2_zs/absci_her2_zs_parent.csv"
mutated_info = "/spinning1/sharedby/hz362/AbFlow/data/absci_her2_zs/absci_her2_zs.csv"
results_dir = "/spinning1/sharedby/hz362/AbFlow/results"
heavy_chain_id = "B"
light_chain_id = "A"
antigen_chain_ids = ["C"]
scheme = "chothia"
num_designs = 1
batch_size = 1
seed = 2025

evaluate_mutated_pdb(
    pdb_file, parent_info, mutated_info, heavy_chain_id, light_chain_id, antigen_chain_ids, scheme, results_dir, num_designs, batch_size, seed
)

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import kendalltau, spearmanr

def plot_correlation_scatter(
    log_likelihood: np.ndarray, log_kd: np.ndarray, save_path: str = None
):
    """
    Create a scatter plot with a density contour and display Kendall and Spearman correlations.

    :param log_likelihood: Array of log-likelihood values.
    :param log_kd: Array of -log(KD) values.
    :param save_path: Optional path to save the plot. If None, the plot will be displayed.
    """

    plt.figure(figsize=(7, 5))
    plt.scatter(log_likelihood, log_kd, color="red", s=12)
    sns.kdeplot(x=log_likelihood, y=log_kd, fill=True, color="red", alpha=0.3, levels=5)

    kendall_tau, _ = kendalltau(log_likelihood, log_kd)
    spearman_rho, _ = spearmanr(log_likelihood, log_kd)

    plt.text(
        0.05,
        0.95,
        f"Kendall τ: {kendall_tau:.2f}\nSpearman ρ: {spearman_rho:.2f}",
        transform=plt.gca().transAxes,
        fontsize=14,
        verticalalignment="top",
        bbox=dict(boxstyle="round", alpha=0.1),
    )
    plt.xlabel("Log-likelihood", fontsize=16)
    plt.ylabel("-log(KD)", fontsize=16)
    plt.tick_params(axis="x", labelsize=14)
    plt.tick_params(axis="y", labelsize=14)
    plt.grid(True)

    if save_path:
        plt.savefig(save_path, bbox_inches="tight")
    else:
        plt.show()

    plt.close()
    
# Load the saved data
data_path = "/spinning1/sharedby/hz362/AbFlow/results/absci_her2_zs_likelihood.csv"
data = pd.read_csv(data_path)

# Extract log-likelihood and -log(KD) as numpy arrays
log_likelihood = data["likelihood"].to_numpy()
log_kd = data["-log(KD (M))"].to_numpy()

# Plot the correlation
plot_correlation_scatter(log_likelihood, log_kd)

# Evaluation pipeline

#### test for two pdb compare

In [None]:
import os
pred_pdb_file = "/spinning1/sharedby/hz362/AbFlow/data/rabd/pdb/1a14_H_L_N.pdb"
true_pdb_file = "/spinning1/sharedby/hz362/AbFlow/data/rabd/pdb/1a14_H_L_N.pdb"
base_name = os.path.basename(true_pdb_file)
parts = base_name.split(".")[0].split("_")
heavy_chain_id = parts[1]
light_chain_id = parts[2]
antigen_chain_ids = parts[3:]
scheme = "chothia"

evaluate_two_pdbs(
	pred_pdb_file=pred_pdb_file,
	true_pdb_file=true_pdb_file,
	heavy_chain_id=heavy_chain_id,
	light_chain_id=light_chain_id,
	antigen_chain_ids=antigen_chain_ids,
	scheme=scheme
)

#### Baseline models

In [None]:
# load the datamodule only

from abflow.model.datamodule import AntibodyAntigenDataModule
import yaml

config_path = "/spinning1/sharedby/hz362/AbFlow/config/test.yaml"
with open(config_path, "r") as f:
	config = yaml.load(f, Loader=yaml.FullLoader)

datamodule = AntibodyAntigenDataModule(config["datamodule"])

In [9]:
print("\n========== diffab DN mode ==========")
design_pdb_dir = "/spinning1/sharedby/hz362/AbFlow/data/baseline/diffab/DN"
true_pdb_dir = "/spinning1/sharedby/hz362/AbFlow/data/rabd/pdb"

results = evaluate_designs(design_pdb_dir, true_pdb_dir, scheme="chothia")

final_avg = average_all_structures(results)
for k, v in sorted(final_avg.items()):
    print(f"{k}: {v:.4f}")


🧬 Found 55 design PDBs.
Fixed PDB file saved to: /spinning1/sharedby/hz362/AbFlow/data/baseline/diffab/DN/1a14_H_L_N_fixed.pdb
Fixed PDB file saved to: /spinning1/sharedby/hz362/AbFlow/data/rabd/pdb/1a14_H_L_N_fixed.pdb
Temporary file removed: /spinning1/sharedby/hz362/AbFlow/data/baseline/diffab/DN/1a14_H_L_N_fixed.pdb
Temporary file removed: /spinning1/sharedby/hz362/AbFlow/data/rabd/pdb/1a14_H_L_N_fixed.pdb
✅ Evaluated 1a14_H_L_N
Fixed PDB file saved to: /spinning1/sharedby/hz362/AbFlow/data/baseline/diffab/DN/1a2y_B_A_C_fixed.pdb
Fixed PDB file saved to: /spinning1/sharedby/hz362/AbFlow/data/rabd/pdb/1a2y_B_A_C_fixed.pdb
Temporary file removed: /spinning1/sharedby/hz362/AbFlow/data/baseline/diffab/DN/1a2y_B_A_C_fixed.pdb
Temporary file removed: /spinning1/sharedby/hz362/AbFlow/data/rabd/pdb/1a2y_B_A_C_fixed.pdb
✅ Evaluated 1a2y_B_A_C
Fixed PDB file saved to: /spinning1/sharedby/hz362/AbFlow/data/baseline/diffab/DN/1fe8_H_L_A_fixed.pdb
Fixed PDB file saved to: /spinning1/sharedby/h

In [10]:
print("\n========== MEAN DN mode ==========")
design_pdb_dir = "/spinning1/sharedby/hz362/AbFlow/data/baseline/mean"
true_pdb_dir = "/spinning1/sharedby/hz362/AbFlow/data/rabd/pdb"

results = evaluate_designs(design_pdb_dir, true_pdb_dir, scheme="chothia")

final_avg = average_all_structures(results)
for k, v in sorted(final_avg.items()):
    print(f"{k}: {v:.4f}")


🧬 Found 55 design PDBs.
Fixed PDB file saved to: /spinning1/sharedby/hz362/AbFlow/data/baseline/mean/1a14_H_L_N_fixed.pdb
Fixed PDB file saved to: /spinning1/sharedby/hz362/AbFlow/data/rabd/pdb/1a14_H_L_N_fixed.pdb
Temporary file removed: /spinning1/sharedby/hz362/AbFlow/data/baseline/mean/1a14_H_L_N_fixed.pdb
Temporary file removed: /spinning1/sharedby/hz362/AbFlow/data/rabd/pdb/1a14_H_L_N_fixed.pdb
✅ Evaluated 1a14_H_L_N
Fixed PDB file saved to: /spinning1/sharedby/hz362/AbFlow/data/baseline/mean/1a2y_B_A_C_fixed.pdb
Fixed PDB file saved to: /spinning1/sharedby/hz362/AbFlow/data/rabd/pdb/1a2y_B_A_C_fixed.pdb
Temporary file removed: /spinning1/sharedby/hz362/AbFlow/data/baseline/mean/1a2y_B_A_C_fixed.pdb
Temporary file removed: /spinning1/sharedby/hz362/AbFlow/data/rabd/pdb/1a2y_B_A_C_fixed.pdb
✅ Evaluated 1a2y_B_A_C
Fixed PDB file saved to: /spinning1/sharedby/hz362/AbFlow/data/baseline/mean/1fe8_H_L_A_fixed.pdb
Fixed PDB file saved to: /spinning1/sharedby/hz362/AbFlow/data/rabd/pdb

In [11]:
print("\n========== ABX DN mode ==========")
design_pdb_dir = "/spinning1/sharedby/hz362/AbFlow/data/baseline/abx"
true_pdb_dir = "/spinning1/sharedby/hz362/AbFlow/data/rabd/pdb"

results = evaluate_designs(design_pdb_dir, true_pdb_dir, scheme="chothia")

final_avg = average_all_structures(results)
for k, v in sorted(final_avg.items()):
    print(f"{k}: {v:.4f}")


🧬 Found 55 design PDBs.
Fixed PDB file saved to: /spinning1/sharedby/hz362/AbFlow/data/baseline/abx/1a14_H_L_N_fixed.pdb
Fixed PDB file saved to: /spinning1/sharedby/hz362/AbFlow/data/rabd/pdb/1a14_H_L_N_fixed.pdb
Temporary file removed: /spinning1/sharedby/hz362/AbFlow/data/baseline/abx/1a14_H_L_N_fixed.pdb
Temporary file removed: /spinning1/sharedby/hz362/AbFlow/data/rabd/pdb/1a14_H_L_N_fixed.pdb
✅ Evaluated 1a14_H_L_N
Fixed PDB file saved to: /spinning1/sharedby/hz362/AbFlow/data/baseline/abx/1a2y_B_A_C_fixed.pdb
Fixed PDB file saved to: /spinning1/sharedby/hz362/AbFlow/data/rabd/pdb/1a2y_B_A_C_fixed.pdb
Temporary file removed: /spinning1/sharedby/hz362/AbFlow/data/baseline/abx/1a2y_B_A_C_fixed.pdb
Temporary file removed: /spinning1/sharedby/hz362/AbFlow/data/rabd/pdb/1a2y_B_A_C_fixed.pdb
✅ Evaluated 1a2y_B_A_C
Fixed PDB file saved to: /spinning1/sharedby/hz362/AbFlow/data/baseline/abx/1fe8_H_L_A_fixed.pdb
Fixed PDB file saved to: /spinning1/sharedby/hz362/AbFlow/data/rabd/pdb/1fe8

In [None]:
print("\n========== DYMEAN DN mode ==========")
design_pdb_dir = "/spinning1/sharedby/hz362/AbFlow/data/baseline/dymean"
true_pdb_dir = "/spinning1/sharedby/hz362/AbFlow/data/rabd/pdb"

results = evaluate_designs(design_pdb_dir, true_pdb_dir, scheme="chothia")

final_avg = average_all_structures(results)
for k, v in sorted(final_avg.items()):
    print(f"{k}: {v:.4f}")

In [None]:
print("\n========== rabd reference comparison ==========")
design_pdb_dir = "/spinning1/sharedby/hz362/AbFlow/data/rabd/pdb"
true_pdb_dir = "/spinning1/sharedby/hz362/AbFlow/data/rabd/pdb"

results = evaluate_designs(design_pdb_dir, true_pdb_dir, scheme="chothia")

final_avg = average_all_structures(results)
for k, v in sorted(final_avg.items()):
    print(f"{k}: {v:.4f}")

# Energy ddG

In [None]:
# dg stats

import json
import statistics

def calculate_dg_stats(json_filepath):
    ddg_values = []

    with open(json_filepath, 'r') as file:
        for line in file:
            if line.strip():  # Skip empty lines
                entry = json.loads(line)
                if 'interface_dG_separated' in entry:
                    ddg_values.append(entry['interface_dG_separated'])

    if not ddg_values:
        raise ValueError("No 'interface_dG_separated' values found in the data")

    average_ddg = sum(ddg_values) / len(ddg_values)
    median_ddg = statistics.median(ddg_values)

    print(f"\nFile: {json_filepath}")
    print(f"Number of entries: {len(ddg_values)}")
    print(f"Average ddG: {average_ddg:.4f}")
    print(f"Median ddG: {median_ddg:.4f}")

# Example usage:
calculate_dg_stats("/spinning1/sharedby/hz362/AbFlow/data/talip_interface_scores_2/abx/score.json")
calculate_dg_stats("/spinning1/sharedby/hz362/AbFlow/data/talip_interface_scores_2/diffab/score.json")
calculate_dg_stats("/spinning1/sharedby/hz362/AbFlow/data/talip_interface_scores_2/dymean/score.json")
calculate_dg_stats("/spinning1/sharedby/hz362/AbFlow/data/talip_interface_scores_2/mean/score.json")
calculate_dg_stats("/spinning1/sharedby/hz362/AbFlow/data/talip_interface_scores_2/rabd/score.json")



File: /spinning1/sharedby/hz362/AbFlow/data/talip_interface_scores_2/abx/score.json
Number of entries: 55
Average ddG: 723.9601
Median ddG: 183.8750

File: /spinning1/sharedby/hz362/AbFlow/data/talip_interface_scores_2/diffab/score.json
Number of entries: 55
Average ddG: 1584.7742
Median ddG: 330.3818

File: /spinning1/sharedby/hz362/AbFlow/data/talip_interface_scores_2/dymean/score.json
Number of entries: 55
Average ddG: 37998.7547
Median ddG: 14973.7461

File: /spinning1/sharedby/hz362/AbFlow/data/talip_interface_scores_2/mean/score.json
Number of entries: 55
Average ddG: 442.3900
Median ddG: 8.8841

File: /spinning1/sharedby/hz362/AbFlow/data/talip_interface_scores_2/rabd/score.json
Number of entries: 55
Average ddG: -31.2850
Median ddG: -29.8292


In [5]:
import json
import statistics

def load_ddg_dict(json_filepath):
    ddg_dict = {}
    with open(json_filepath, 'r') as file:
        for line in file:
            if line.strip():
                entry = json.loads(line)
                decoy = entry.get("decoy")
                ddg = entry.get("interface_dG_separated")
                if decoy is not None and ddg is not None:
                    ddg_dict[decoy] = ddg
    return ddg_dict

def calculate_ddg_vs_reference(design_file, rabd_file):
    design_ddg = load_ddg_dict(design_file)
    rabd_ddg = load_ddg_dict(rabd_file)

    ddg_deltas = []
    for decoy in design_ddg:
        if decoy in rabd_ddg:
            delta = design_ddg[decoy] - rabd_ddg[decoy]
            ddg_deltas.append(delta)

    if not ddg_deltas:
        raise ValueError("No matching decoys between design and rabd")

    avg_ddg = sum(ddg_deltas) / len(ddg_deltas)
    median_ddg = statistics.median(ddg_deltas)

    print(f"\nFile: {design_file}")
    print(f"Matched decoys: {len(ddg_deltas)}")
    print(f"Average ddG (vs RabD): {avg_ddg:.4f}")
    print(f"Median ddG (vs RabD): {median_ddg:.4f}")

# RabD reference file path
rabd_path = "/spinning1/sharedby/hz362/AbFlow/data/talip_interface_scores_2/rabd/score.json"

# Designs to compare against RabD
design_paths = [
    "/spinning1/sharedby/hz362/AbFlow/data/talip_interface_scores_2/abx/score.json",
    "/spinning1/sharedby/hz362/AbFlow/data/talip_interface_scores_2/diffab/score.json",
    "/spinning1/sharedby/hz362/AbFlow/data/talip_interface_scores_2/dymean/score.json",
    "/spinning1/sharedby/hz362/AbFlow/data/talip_interface_scores_2/mean/score.json"
]

for design_file in design_paths:
    calculate_ddg_vs_reference(design_file, rabd_path)



File: /spinning1/sharedby/hz362/AbFlow/data/talip_interface_scores_2/abx/score.json
Matched decoys: 55
Average ddG (vs RabD): 755.2451
Median ddG (vs RabD): 228.7647

File: /spinning1/sharedby/hz362/AbFlow/data/talip_interface_scores_2/diffab/score.json
Matched decoys: 55
Average ddG (vs RabD): 1616.0592
Median ddG (vs RabD): 363.0035

File: /spinning1/sharedby/hz362/AbFlow/data/talip_interface_scores_2/dymean/score.json
Matched decoys: 55
Average ddG (vs RabD): 38030.0397
Median ddG (vs RabD): 15004.3978

File: /spinning1/sharedby/hz362/AbFlow/data/talip_interface_scores_2/mean/score.json
Matched decoys: 55
Average ddG (vs RabD): 473.6750
Median ddG (vs RabD): 40.4227
