In [1]:
import json
import pandas as pd
import faiss
import numpy as np
from tqdm import tqdm

import sys
sys.path.append("/homefs/home/fanga5/cdrcluster")
from evals.metrics import dihedral_distance
from evals.align_loops import kabsch_numpy

# Directory holding every precomputed embedding file loaded below. Its trailing slash yields a
# harmless "//" in the f"{EMBEDDING_DIR}/..." paths used throughout.
EMBEDDING_DIR = "/data2/fanga5/benchmarking_data/paratope_binning/"
# CDR loop under study: LOOP_TYPE[0] is used as the chain id and LOOP_TYPE[1] as the CDR number
# (presumably "H3" = heavy-chain CDR3).
LOOP_TYPE = "H3"

### Save Loop Type Sequences

In [2]:
# select_loops = []
# with open(f"/data/fanga5/data/test_loop_len_all_seed_42.jsonl", "r") as f:
#     for line in f:
#         item = json.loads(line)
#         if item['loop_id'].endswith(LOOP_TYPE):
#             select_loops.append(item)
# select_loops_ids = [int(item['loop_id'].split("_")[0]) for item in select_loops]
# print(f"Number of test {LOOP_TYPE} loops: {len(select_loops_ids)}")

In [3]:
# # For files with the resolution 3.5 suffix, sabdab_id is wrong because rows were numbered after filtering by resolution.
# # fix it by using ab_fname for id mapping
# raw_df = pd.read_parquet("/data/fanga5/preprocessed_data/sabdab_2025-05-06-paired.parquet")
# raw_df['sabdab_id'] = range(len(raw_df))
# ab_fname_to_id = {fname: sabdab_id for fname, sabdab_id in zip(raw_df['ab_fname'], raw_df['sabdab_id'])}

# data_df = pd.read_parquet("/data/fanga5/preprocessed_data/sabdab_2025-05-06-paired_chains_resolution_3.5.parquet")
# data_df['sabdab_id'] = data_df['ab_fname'].map(ab_fname_to_id)
# data_df = data_df[data_df['chain_id'] == LOOP_TYPE[0]]
# data_df = data_df[['sabdab_id', f'CDR{LOOP_TYPE[1]}_start', f'CDR{LOOP_TYPE[1]}_end', 'sequence']]
# data_df = data_df.rename(columns={f'CDR{LOOP_TYPE[1]}_start': 'start', f'CDR{LOOP_TYPE[1]}_end': 'end', 'sabdab_id': 'loop_id'})
# data_df['loop_sequence'] = data_df.apply(lambda row: row['sequence'][row['start']:row['end']], axis=1)
# data_df['test'] = data_df['loop_id'].isin(select_loops_ids)

# data_df.to_csv(f"/data/fanga5/data/sabdab_{LOOP_TYPE}_loops.csv", index=False)
# print(f"Saved {LOOP_TYPE} loops ({len(data_df)}) to /data/fanga5/data/sabdab_{LOOP_TYPE}_loops.csv")

### Add 3Di sequences

In [4]:
# fasta_3di = "/data2/fanga5/sabdab/sabdab_db_ss.fasta"
# seqeunces_3di = {}
# with open(fasta_3di, "r") as f:
#     for line in f:
#         if line.startswith(">"):
#             fname = line.strip().replace(">", "")
#             seqeunces_3di[fname] = ""
#         else:
#             seqeunces_3di[fname] += line.strip()

# raw_df = pd.read_parquet("/data/fanga5/preprocessed_data/sabdab_2025-05-06-paired.parquet")
# raw_df['sabdab_id'] = range(len(raw_df))
# ab_fname_to_id = {fname: sabdab_id for fname, sabdab_id in zip(raw_df['ab_fname'], raw_df['sabdab_id'])}
# id_to_ab_fname = {sabdab_id: fname for fname, sabdab_id in ab_fname_to_id.items()}

# data_df = pd.read_csv(f"/data/fanga5/data/sabdab_{LOOP_TYPE}_loops.csv")
# data_df['ab_fname'] = data_df['loop_id'].map(id_to_ab_fname)
# data_df['ab_fname_chain'] = data_df['ab_fname'].str.replace(".pdb", f"_{LOOP_TYPE[0]}")
# data_df.loc[~data_df['ab_fname_chain'].isin(seqeunces_3di), 'ab_fname_chain'] = data_df['ab_fname'].str.replace(".pdb", "")
# assert data_df['ab_fname_chain'].isin(seqeunces_3di).all(), "Some ab_fname_chain are not in the 3di sequences"
# data_df['3di_sequence'] = data_df['ab_fname_chain'].map(seqeunces_3di)
# data_df.to_csv(f"/data/fanga5/data/sabdab_{LOOP_TYPE}_loops_with_3di.csv", index=False)
# print(f"Saved {LOOP_TYPE} loops with 3di sequences ({len(data_df)}) to /data/fanga5/data/sabdab_{LOOP_TYPE}_loops_with_3di.csv")

### Add dihedral angles and C-alpha coordinates

In [5]:
# loop_df = pd.read_parquet("/data/fanga5/preprocessed_data/sabdab_2025-05-06-paired_loops.parquet")
# loop_df = loop_df[loop_df['loop_type'] == LOOP_TYPE]

# data_df = pd.read_csv(f"/data/fanga5/data/sabdab_{LOOP_TYPE}_loops.csv")
# data_df_with_angles = data_df.merge(loop_df[['sabdab_id', 'phi', 'psi', 'omega', 'c_alpha_atoms', 'stem_c_alpha_atoms']], left_on='loop_id', right_on='sabdab_id', how='inner')
# data_df_with_angles.rename(columns={'c_alpha_atoms': 'loop_c_alpha_atoms'}, inplace=True)
# data_df_with_angles[['loop_id', 'loop_sequence', 'phi', 'psi', 'omega', 'loop_c_alpha_atoms', 'stem_c_alpha_atoms']].to_parquet(f"/data/fanga5/data/sabdab_{LOOP_TYPE}_loops_with_angles.parquet", index=False)
# print(f"Saved {LOOP_TYPE} loops with angles ({len(data_df_with_angles)}) to /data/fanga5/data/sabdab_{LOOP_TYPE}_loops_with_angles.parquet")

# Retrieval setup

In [6]:
loop_df = pd.read_parquet("/data/fanga5/preprocessed_data/sabdab_2025-05-06-paired_loops.parquet")
# .copy() makes the filtered frame independent, so the column assignment below does not write
# into a slice of the parquet frame (which raises pandas' SettingWithCopyWarning).
loop_df = loop_df[loop_df['loop_type'] == LOOP_TYPE].copy()

# Per-loop (n, 3) matrix with columns phi, psi, omega (n = number of angle entries for the loop).
loop_df['angles'] = loop_df.apply(lambda row: np.stack([row['phi'], row['psi'], row['omega']], axis=1), axis=1)
loop_id_to_angles = dict(zip(loop_df['sabdab_id'], loop_df['angles']))


def _stack_backbone(ca_atoms, c_atoms, n_atoms):
    """Stack all CA coordinates, then all C, then all N (blocks, not interleaved per residue)."""
    return np.stack(ca_atoms.tolist() + c_atoms.tolist() + n_atoms.tolist())


# NOTE: despite its name, each entry holds CA + C + N backbone coordinates, not only C-alpha.
# The name is kept because later cells look coordinates up through it.
loop_id_to_calpha = {
    loop_id: _stack_backbone(ca, c, n)
    for loop_id, ca, c, n in zip(loop_df['sabdab_id'], loop_df['c_alpha_atoms'], loop_df['c_atoms'], loop_df['n_atoms'])
}

In [7]:
# Loop table for LOOP_TYPE; loop_len is the residue span end - start of each loop.
data_df = (
    pd.read_csv(f"/data/fanga5/data/sabdab_{LOOP_TYPE}_loops.csv")
    .assign(loop_len=lambda frame: frame['end'] - frame['start'])
)
data_df

Unnamed: 0,loop_id,start,end,sequence,loop_sequence,test,loop_len
0,0,96,105,EVQLQQPGPELVKPGASVKVSCKASGYSFTDHNMYWVKQSHGKSLE...,YIGSFYFVY,False,9
1,1,96,105,EVQLQQPGPELVKPGASVKVSCKASGYSFTDHNMYWVKQSHGKSLE...,YIGSFYFVY,False,9
2,2,96,107,QVQLVQSGGGLVQPGGSLRLSCAASGFTFSSYAMSWVRQAPGKGLE...,ARDSGSGRFDP,False,11
3,3,96,107,QVQLVQSGGGLVQPGGSLRLSCAASGFTFSSYAMSWVRQAPGKGLE...,ARDSGSGRFDP,False,11
4,4,96,107,QVQLVQSGGGLVQPGGSLRLSCAASGFTFSSYAMSWVRQAPGKGLE...,ARDSGSGRFDP,False,11
...,...,...,...,...,...,...,...
15282,18823,95,98,QVQLRESGPSLVKPSQTLSLTCTASGLSLSDKAVGWVRRAPTKALE...,ATV,False,3
15283,18827,94,106,VQLVESGGGLVQPGGSLRLSCAASEFIVSANYMSWVRQAPGKGLEW...,ARFLPTYDYFDY,False,12
15284,18829,96,107,QVQFQQSGAELVKPGASVKLSCKASGYTFTSYLMHWIKQRPGRGLE...,ARYAYCRPMDY,False,11
15285,18830,96,107,QVQFQQSGAELVKPGASVKLSCKASGYTFTSYLMHWIKQRPGRGLE...,ARYAYCRPMDY,False,11


In [8]:
# A loop is usable only when dihedral angles exist for it and the number of angle rows equals
# the loop length computed from the CSV's start/end columns.
valid_loop_ids = [
    loop_id
    for loop_id, loop_len in zip(data_df['loop_id'], data_df['loop_len'])
    if loop_id in loop_id_to_angles and loop_len == loop_id_to_angles[loop_id].shape[0]
]
data_df['valid_loop'] = data_df['loop_id'].isin(valid_loop_ids)
len(valid_loop_ids)

15281

In [9]:
import h5py


def load_h5_embeddings(h5_path, loop_ids, default_dim=1024):
    """Read one embedding vector per loop id from an HDF5 file, aligned to ``loop_ids``.

    Datasets are looked up by ``str(loop_id)``. Ids absent from the file are reported and filled
    with a zero vector. Its length matches the vectors actually found; ``default_dim`` is used
    only when none of the requested ids exist. This assumes each stored dataset is a 1-D vector,
    so ``np.array`` yields a regular (len(loop_ids), dim) matrix.
    """
    vectors = []
    with h5py.File(h5_path, "r") as f:
        for loop_id in loop_ids:
            key = str(loop_id)
            if key in f:
                vectors.append(f[key][:])
            else:
                print(f"Loop ID {loop_id} not found in embeddings file.")
                vectors.append(None)  # placeholder; filled once the real width is known
    dim = next((len(v) for v in vectors if v is not None), default_dim)
    return np.array([np.zeros((dim,)) if v is None else v for v in vectors])


prostt5_embeddings = load_h5_embeddings(f"{EMBEDDING_DIR}/prostt5_embeddings_{LOOP_TYPE}.h5", data_df['loop_id'].values)
prostt5_3di_embeddings = load_h5_embeddings(f"{EMBEDDING_DIR}/prostt5_3di_embeddings_{LOOP_TYPE}.h5", data_df['loop_id'].values)

In [10]:
def _load_npy_embeddings(stem):
    """Load ``{EMBEDDING_DIR}/{stem}_{LOOP_TYPE}.npy`` (rows presumably follow data_df order — verify)."""
    return np.load(f"{EMBEDDING_DIR}/{stem}_{LOOP_TYPE}.npy")


igloo_embeddings = _load_npy_embeddings("IgLoo_sabdab")
ablang2_embeddings = _load_npy_embeddings("ablang2_embeddings")
esm2_embeddings = _load_npy_embeddings("esm2_embeddings")
emc_embeddings = _load_npy_embeddings("esmc_embeddings")  # file is "esmc"; variable name kept as "emc"
igbert_embeddings = _load_npy_embeddings("igbert_embeddings")
saprot_embeddings = _load_npy_embeddings("saprot_3di_embeddings")
foldseek3di_embeddings = _load_npy_embeddings("foldseek3di_embeddings")

In [11]:
# IgLoo embeddings are stored per loop id and some ids are missing, so they are re-aligned to
# data_df's row order, with zero vectors for the missing ids.

def get_igloo_angle_embeddings(fname, loop_ids=None, embedding_dim=None):
    """Load IgLoo embeddings and align them to ``loop_ids``.

    Args:
        fname: a ``.jsonl`` path (records with ``id`` and ``encoded`` keys); any other suffix
            is read as parquet with ``loop_id`` and ``encoded`` columns.
        loop_ids: ids that fix the output row order. Defaults to ``data_df['loop_id']``,
            looked up at call time.
        embedding_dim: length of the zero vector used for ids without a stored embedding.
            When None it is taken from the first stored embedding, falling back to 128 only
            if the file holds none.

    Returns:
        Array with one row per entry of ``loop_ids``. It is 2-D assuming each stored
        embedding is 1-D.
    """
    if fname.endswith(".jsonl"):
        with open(fname, "r") as f:
            records = (json.loads(line) for line in f if line.strip())  # tolerate blank lines
            raw_embeddings = {item['id']: item['encoded'] for item in records}
    else:
        igloo_embeddings_df = pd.read_parquet(fname)
        raw_embeddings = dict(zip(igloo_embeddings_df['loop_id'], igloo_embeddings_df['encoded']))

    if loop_ids is None:
        loop_ids = data_df['loop_id']
    if embedding_dim is None:
        # Match the real embedding width so np.stack below never sees mismatched shapes.
        embedding_dim = len(next(iter(raw_embeddings.values()))) if raw_embeddings else 128

    rows = []
    for loop_id in loop_ids:
        if loop_id in raw_embeddings:
            rows.append(np.array(raw_embeddings[loop_id]))
        else:
            rows.append(np.zeros(embedding_dim))
            print(f"Missing igloo angle embedding for loop_id: {loop_id}")
    return np.stack(rows)

igloo_angle_embeddings = get_igloo_angle_embeddings(f"{EMBEDDING_DIR}/IgLoo_sabdab_{LOOP_TYPE}_with_angles.jsonl")
igloo_no_dihedral_loss_embeddings = get_igloo_angle_embeddings(f"{EMBEDDING_DIR}/Igloo_ablation_no_dihedral_loss_sabdab_{LOOP_TYPE}.parquet")
igloo_no_sequence_embeddings = get_igloo_angle_embeddings(f"{EMBEDDING_DIR}/Igloo_ablation_no_sequence_sabdab_{LOOP_TYPE}.parquet")
igloo_no_dihedrals_embeddings = get_igloo_angle_embeddings(f"{EMBEDDING_DIR}/Igloo_ablation_no_dihedrals_sabdab_{LOOP_TYPE}.parquet")
igloo_tol1_embeddings = get_igloo_angle_embeddings(f"{EMBEDDING_DIR}/Igloo_ablation_tol1_v222_epoch30_sabdab_{LOOP_TYPE}.parquet")
igloo_no_dihedral_threshold_embeddings = get_igloo_angle_embeddings(f"{EMBEDDING_DIR}/Igloo_ablation_no_dihedral_threshold_sabdab_{LOOP_TYPE}.parquet")

Missing igloo angle embedding for loop_id: 15905
Missing igloo angle embedding for loop_id: 17275
Missing igloo angle embedding for loop_id: 15905
Missing igloo angle embedding for loop_id: 17275
Missing igloo angle embedding for loop_id: 15905
Missing igloo angle embedding for loop_id: 17275
Missing igloo angle embedding for loop_id: 15905
Missing igloo angle embedding for loop_id: 17275
Missing igloo angle embedding for loop_id: 15905
Missing igloo angle embedding for loop_id: 17275
Missing igloo angle embedding for loop_id: 15905
Missing igloo angle embedding for loop_id: 17275


In [12]:
# Embeddings to benchmark, as (name, matrix) pairs. Rows are expected to follow data_df's row
# order (TODO confirm for the .npy files). List order sets the order of records in `results`.
embedding_items = [
    # IgLoo ablation variants
    ('igloo_no_dihedral_loss_embeddings', igloo_no_dihedral_loss_embeddings),
    ('igloo_no_sequence_embeddings', igloo_no_sequence_embeddings),
    ('igloo_no_dihedrals_embeddings', igloo_no_dihedrals_embeddings),
    ('igloo_tol1_embeddings', igloo_tol1_embeddings),
    ('igloo_no_dihedral_threshold_embeddings', igloo_no_dihedral_threshold_embeddings),
    # baseline models
    ('ablang2', ablang2_embeddings),
    ('esm2', esm2_embeddings),
    ('emc', emc_embeddings),
    ('igbert', igbert_embeddings),
    ('prostt5', prostt5_embeddings),
    ('igloo_angle', igloo_angle_embeddings),
    ('prostt5_3di', prostt5_3di_embeddings),
    ('saprot_3di', saprot_embeddings),
    ('foldseek', foldseek3di_embeddings),
]
embeddings = dict(embedding_items)

# Run retrieval

In [13]:
# Dihedral distance below which a dataset loop counts as structurally similar to a query.
DIHEDRAL_MATCH_THRESHOLD = 0.47
# RMSD below which a retrieved loop counts as a hit for `rmsd_precision`.
RMSD_HIT_THRESHOLD = 1.0
# Neighbourhood sizes to evaluate (other candidates: 1, 5, 10).
KNN_VALUES = [20]


def _unit_rows_float32(x):
    """Scale each row of ``x`` to unit L2 norm and return a C-contiguous float32 copy for faiss.

    All-zero rows (placeholders for missing embeddings) stay zero instead of turning into NaN
    through a 0/0 division, so they simply score 0 under inner-product search.
    """
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return np.ascontiguousarray(x / np.where(norms == 0, 1, norms), dtype=np.float32)


def _retrieval_scores(match, retrieved, valid):
    """Return (precision@k, recall@k, hit-rate@k) averaged over the queries flagged in ``valid``.

    match:     boolean (n_queries, n_dataset) ground-truth matrix.
    retrieved: (n_queries, k) dataset column indices from the faiss search.
    valid:     boolean (n_queries,) mask of queries with at least one true match.
    """
    retrieved_match = np.take_along_axis(match, retrieved, axis=1)
    n_found = retrieved_match.sum(axis=1)
    precision = retrieved_match[valid].mean()
    recall = np.mean(n_found[valid] / match.sum(axis=1)[valid])
    hits = np.mean((n_found > 0)[valid])
    return precision, recall, hits


def _rmsd_to_neighbours(query_rows, dataset_rows, retrieved, loop_ids, coords_by_id):
    """Return RMSDs between each query loop and each of its retrieved neighbours.

    ``query_rows`` and ``dataset_rows`` are positional row numbers into ``loop_ids``.
    ``retrieved[i]`` holds column positions into ``dataset_rows`` for query i. Both coordinate
    sets are centred on their own centroid before ``kabsch_numpy``, whose third return value is
    taken as the RMSD.
    """
    rmsds = []
    for query_row, neighbour_cols in zip(query_rows, retrieved):
        query_xyz = coords_by_id[loop_ids[query_row]]
        query_xyz = query_xyz - query_xyz.mean(axis=0)
        for col in neighbour_cols:
            hit_xyz = coords_by_id[loop_ids[dataset_rows[col]]]
            _, _, rmsd = kabsch_numpy(query_xyz, hit_xyz - hit_xyz.mean(axis=0))
            rmsds.append(rmsd)
    return rmsds


records = []
# Positional lookup: np.flatnonzero returns row positions, not index labels.
loop_ids = data_df['loop_id'].to_numpy()
is_test = data_df['test']
# Evaluate only loop lengths that occur in the test split (most frequent first).
test_loop_lens = data_df.loc[is_test, 'loop_len'].value_counts().index.tolist()

for knn in KNN_VALUES:
    for loop_len in test_loop_lens:
        same_len_valid = (data_df['loop_len'] == loop_len) & data_df['valid_loop']
        test_indices = np.flatnonzero(same_len_valid & is_test)
        original_indices = np.flatnonzero(same_len_valid & ~is_test)

        # faiss pads results with -1 when the index holds fewer than k vectors. A -1 would
        # silently select the last dataset column, so such lengths are skipped. The check must
        # count rows, not sum their positions.
        if len(original_indices) <= knn:
            print(f"Not enough loops for knn={knn} and loop_len={loop_len}. Skipping...")
            continue
        if len(test_indices) == 0:
            print(f"No valid test loops for loop_len={loop_len}. Skipping...")
            continue

        test_angles = np.stack([loop_id_to_angles[loop_ids[i]] for i in test_indices])
        dataset_angles = np.stack([loop_id_to_angles[loop_ids[i]] for i in original_indices])
        # Ground truth: True where the dataset loop is dihedrally similar to the query.
        structural_match = dihedral_distance(test_angles, dataset_angles) < DIHEDRAL_MATCH_THRESHOLD
        # Queries with no similar loop in the dataset are left out of all ranking metrics.
        valid_queries = structural_match.sum(axis=1) != 0

        for embedding_name, embedding_matrix in embeddings.items():
            query_vectors = _unit_rows_float32(embedding_matrix[test_indices])
            dataset_vectors = _unit_rows_float32(embedding_matrix[original_indices])

            # Inner product on unit vectors == cosine similarity.
            # (For Euclidean retrieval use faiss.IndexFlatL2 on the un-normalised vectors.)
            index = faiss.IndexFlatIP(dataset_vectors.shape[1])
            index.add(dataset_vectors)
            _, retrieved_indices = index.search(query_vectors, knn)

            precision, recall, hits = _retrieval_scores(structural_match, retrieved_indices, valid_queries)
            rmsd_values = _rmsd_to_neighbours(test_indices, original_indices, retrieved_indices, loop_ids, loop_id_to_calpha)

            records.append({
                'embedding': embedding_name,
                'precision': precision,
                'recall': recall,
                'hits': hits,
                'knn': knn,
                'loop_len': loop_len,
                'rmsd': np.mean(rmsd_values),
                # fraction of retrieved neighbours with RMSD below RMSD_HIT_THRESHOLD
                'rmsd_precision': np.mean(np.array(rmsd_values) < RMSD_HIT_THRESHOLD),
            })

results = pd.DataFrame(records)

In [14]:
# Average every metric across loop lengths. Each (knn, loop_len, embedding) row counts once, so
# lengths are weighted equally regardless of how many test loops they contain. The loop_len
# column therefore only shows the mean length evaluated.
results.set_index(['knn', 'embedding']).groupby(level=['knn', 'embedding']).mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,precision,recall,hits,loop_len,rmsd,rmsd_precision
knn,embedding,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
20,ablang2,0.222329,0.22365,0.697239,15.5,2.676202,0.17329
20,emc,0.208441,0.207271,0.677094,15.5,2.650888,0.18996
20,esm2,0.236538,0.181721,0.645844,15.5,2.616186,0.206248
20,foldseek,0.36177,0.313618,0.857018,15.5,2.267806,0.2811
20,igbert,0.21586,0.218832,0.704361,15.5,2.639347,0.181773
20,igloo_angle,0.401525,0.33607,0.883748,15.5,2.450542,0.277825
20,igloo_no_dihedral_loss_embeddings,0.334637,0.208014,0.788578,15.5,2.558965,0.241513
20,igloo_no_dihedral_threshold_embeddings,0.416734,0.29402,0.856581,15.5,2.463137,0.279402
20,igloo_no_dihedrals_embeddings,0.216782,0.194406,0.665094,15.5,2.787291,0.1934
20,igloo_no_sequence_embeddings,0.356375,0.283101,0.851591,15.5,2.618416,0.245031
