# Calculate scTM and RMSD between ground-truth (reference) and predicted structures

In [14]:
from tqdm import tqdm
import os
import csv
from Bio import BiopythonDeprecationWarning
from tmtools import tm_align
from analysis import metrics
from data import utils as du
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=BiopythonDeprecationWarning)

from tmtools.io import get_structure, get_residue_data
from Bio.PDB.Polypeptide import three_to_one
import numpy as np

In [15]:
def extract_coords_and_seq(pdb_path):
    """Return the C-alpha coordinates and one-letter sequence of the first chain.

    Only the first chain yielded by the parsed structure is read. A residue is
    kept when it has a ``CA`` atom and its residue name is either accepted by
    ``three_to_one`` or is ``MSE`` (recorded as ``"M"``). Every other residue
    is skipped, so the two returned values always have the same length.

    Args:
        pdb_path: Path of the structure file passed to ``get_structure``.

    Returns:
        A ``(coords, seq)`` tuple: ``coords`` is a NumPy array built from the
        kept ``CA`` atom coordinates and ``seq`` is the matching string of
        one-letter codes.

    Raises:
        StopIteration: If the structure contains no chain.
    """
    structure = get_structure(pdb_path)
    first_chain = next(structure.get_chains())

    coords = []
    seq = []

    for residue in first_chain:
        # Check for the CA atom before anything else. The previous version
        # indexed child_dict["CA"] unconditionally for MSE residues, so a
        # selenomethionine without a CA atom raised an uncaught KeyError
        # from inside the exception handler.
        if "CA" not in residue.child_dict:
            continue

        if residue.resname == "MSE":
            aa_code = "M"
        else:
            try:
                aa_code = three_to_one(residue.resname)
            except KeyError:
                # Residue name has no one-letter code (ligand, water, ...).
                continue

        coords.append(residue.child_dict["CA"].coord)
        seq.append(aa_code)

    return np.array(coords), "".join(seq)

In [16]:
def calculate_tm_alignment(pdb_path_1, pdb_path_2):
    """Align two structure files with ``tm_align`` and return one TM-score.

    Both files are reduced to C-alpha coordinates and a sequence by
    ``extract_coords_and_seq`` before alignment.

    Args:
        pdb_path_1: Path of the first structure file.
        pdb_path_2: Path of the second structure file.

    Returns:
        The ``tm_norm_chain2`` attribute of the ``tm_align`` result
        (presumably the TM-score normalised by the second structure's
        length -- confirm against the tmtools documentation).
    """
    first = extract_coords_and_seq(pdb_path_1)
    second = extract_coords_and_seq(pdb_path_2)

    # tm_align takes both coordinate arrays first, then both sequences.
    alignment = tm_align(first[0], second[0], first[1], second[1])
    return alignment.tm_norm_chain2

In [24]:

def calculate_protein_metrics(ref_folder, esmf_folder, csv_path, csv_path_error):
    """Score every predicted structure against its reference and write CSVs.

    A predicted file is paired with a reference file when its name starts
    with the reference file name up to the first ``"."``. For each pair the
    function records the ``tm_norm_chain2`` score from ``tm_align`` (the same
    value ``calculate_tm_alignment`` returns) and a distance value computed
    after ``du.rigid_transform_3D``.

    Args:
        ref_folder: Directory holding the reference structure files.
        esmf_folder: Directory holding the predicted structure files.
        csv_path: Output CSV for successful pairs, with columns
            "Reference PDB", "ESMF PDB", "scTM Value", "RMSD Value".
        csv_path_error: Output CSV for failed pairs, with columns
            "Reference PDB", "ESMF PDB", "error".

    Returns:
        None. Results are written to the two CSV files.
    """
    # os.listdir returns entries in arbitrary order; sorting makes the row
    # order of both CSV files reproducible between runs.
    ref_filenames = sorted(os.listdir(ref_folder))
    esmf_filenames = sorted(os.listdir(esmf_folder))

    # Create the output directories up front so that a missing directory
    # fails before the long computation rather than after it.
    for out_path in (csv_path, csv_path_error):
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    results = []
    error_file = []

    for ref_filename in tqdm(ref_filenames):
        pdb_name = ref_filename.split(".")[0]
        if not pdb_name:
            # Entries such as ".ipynb_checkpoints" or ".DS_Store" have an
            # empty stem, and startswith("") would match every prediction.
            continue
        ref_pdb_path = os.path.join(ref_folder, ref_filename)

        for esmf_filename in esmf_filenames:
            if not esmf_filename.startswith(pdb_name):
                continue
            esmf_pdb_path = os.path.join(esmf_folder, esmf_filename)

            try:
                # Parse each structure once and reuse the coordinates for
                # both metrics (previously every file was parsed twice).
                coords_A, seq_A = extract_coords_and_seq(ref_pdb_path)
                coords_B, seq_B = extract_coords_and_seq(esmf_pdb_path)

                scTM_value = tm_align(coords_A, coords_B, seq_A, seq_B).tm_norm_chain2

                # NOTE(review): presumably element 0 of rigid_transform_3D is
                # coords_A superimposed onto coords_B -- confirm in data.utils.
                aligned_pos_1 = du.rigid_transform_3D(coords_A, coords_B)[0]
                # This is the mean per-residue distance, not a root-mean-square;
                # it is kept as-is so previously reported values stay comparable.
                rmsd_value = np.mean(np.linalg.norm(aligned_pos_1 - coords_B, axis=-1))
            except Exception as e:
                # Best-effort batch run: record the failure and move on.
                # repr() keeps the exception type, so exceptions raised
                # without a message do not produce an empty cell.
                error_file.append((ref_filename, esmf_filename, repr(e)))
            else:
                results.append((ref_filename, esmf_filename, scTM_value, rmsd_value))

    with open(csv_path, "w", newline="") as csv_file:
        csv_writer = csv.writer(csv_file)
        csv_writer.writerow(["Reference PDB", "ESMF PDB", "scTM Value", "RMSD Value"])
        csv_writer.writerows(results)

    with open(csv_path_error, "w", newline="") as csv_file:
        csv_writer = csv.writer(csv_file)
        csv_writer.writerow(["Reference PDB", "ESMF PDB", "error"])
        csv_writer.writerows(error_file)
    print("All scTM and RMSD calculations completed and saved to CSV.")

# CATH4.2 PiFold

In [25]:
# Inputs: reference structures and the predicted structures to score
# (presumably ESMFold predictions of the PiFold designs -- confirm).
ref_folder = "/home/zhengsun/code/protein/ProteinInvBench/results/reference_pdb"
esmf_folder = "/home/zhengsun/code/protein/ProteinInvBench/results/esm_fold_pdb/cath42pifold"

# Outputs: one CSV for scored pairs and one for pairs that raised an error.
csv_path = "/home/zhengsun/code/protein/ProteinInvBench/results/scTM_RMSD/cath42pifold_2.csv"
csv_path_error = "/home/zhengsun/code/protein/ProteinInvBench/results/scTM_RMSD/cath42pifold_error.csv"

calculate_protein_metrics(
    ref_folder=ref_folder,
    esmf_folder=esmf_folder,
    csv_path=csv_path,
    csv_path_error=csv_path_error,
)

100%|██████████| 2647/2647 [05:51<00:00,  7.52it/s]

All scTM and RMSD calculations completed and saved to CSV.





# CATH4.2 ProteinMPNN

In [26]:
# Inputs: reference structures and the predicted structures to score
# (presumably ESMFold predictions of the ProteinMPNN designs -- confirm).
ref_folder = "/home/zhengsun/code/protein/ProteinInvBench/results/reference_pdb"
esmf_folder = "/home/zhengsun/code/protein/ProteinInvBench/results/esm_fold_pdb/cath42mpnn"

# Outputs: one CSV for scored pairs and one for pairs that raised an error.
csv_path = "/home/zhengsun/code/protein/ProteinInvBench/results/scTM_RMSD/cath42mpnn_3.csv"
csv_path_error = "/home/zhengsun/code/protein/ProteinInvBench/results/scTM_RMSD/cath42mpnn_error.csv"

calculate_protein_metrics(
    ref_folder=ref_folder,
    esmf_folder=esmf_folder,
    csv_path=csv_path,
    csv_path_error=csv_path_error,
)

100%|██████████| 2647/2647 [05:39<00:00,  7.80it/s]

All scTM and RMSD calculations completed and saved to CSV.



