## 1. Save results

In [1]:
import sys; sys.path.append("/gaozhangyang/experiments/DiffSDS")
from utils.data_tools import ReadPDB
import os
import numpy as np
from tqdm import tqdm
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_pdb(fname):
    """Read a PDB file and locate its masked span via the B-factor values.

    Args:
        fname: Path passed unchanged to ``ReadPDB.read_pdb``.

    Returns:
        A tuple ``(coords, left, right, max_idx, b_factor)`` where ``left`` and
        ``right`` are the first and last positions whose B-factor equals 100,
        and ``max_idx`` is the last position whose B-factor is positive.
        ``coords`` and ``b_factor`` are returned exactly as parsed.

    Note:
        Indexing with ``[0]`` / ``[-1]`` fails if no position has a B-factor of
        100 (or none is positive); no fallback is provided for that case.
    """
    # Only the coordinates and B-factors are needed; the other parsed fields are discarded.
    _, _, coords, _, _, b_factor = ReadPDB.read_pdb(fname)

    # NOTE(review): B-factor == 100 appears to flag the masked region — confirm
    # against the code that writes these PDB files.
    masked_positions = (b_factor == 100).nonzero()[0]
    left, right = masked_positions[0], masked_positions[-1]

    # presumably positions with B-factor <= 0 are padding, making this the last
    # real position — TODO confirm
    max_idx = (b_factor > 0).nonzero()[0][-1]

    return coords, left, right, max_idx, b_factor

In [3]:
def error(pred_coords, true_coords, left, right):
    """Sum the predicted-vs-reference deviation at the two mask boundaries.

    For each of the positions ``left`` and ``right`` the row-wise Euclidean
    distance between prediction and reference is computed, and the entry at
    row index 1 is used.

    Args:
        pred_coords: Predicted coordinates, indexable by position.
        true_coords: Reference coordinates with the same layout.
        left: Position of the left mask boundary.
        right: Position of the right mask boundary.

    Returns:
        The distance at ``left`` plus the distance at ``right``.
    """
    def _row1_distance(position):
        # assumes coords[position] is (atoms, 3) and row 1 is the CA atom — TODO confirm
        delta = pred_coords[position] - true_coords[position]
        return np.linalg.norm(delta, axis=1)[1]

    return _row1_distance(left) + _row1_distance(right)

def read_CA_coords(smcdiff_fname):
    """Collect the coordinates of all CA backbone atoms from a PDB file.

    Args:
        smcdiff_fname: Path to the PDB file; converted with ``str`` before reading.

    Returns:
        The per-atom ``coord`` arrays of every backbone atom named "CA",
        stacked vertically with ``np.vstack`` (presumably one row of x, y, z
        per CA atom — TODO confirm).
    """
    # Imported inside the function so biotite is only needed when this reader is called.
    import biotite.structure as struc
    from biotite.structure.io.pdb import PDBFile

    pdb_file = PDBFile.read(str(smcdiff_fname))
    # [0] takes the first entry of the parsed structure (the first model, per biotite's API).
    first_model = pdb_file.get_structure(extra_fields=["b_factor"])[0]
    backbone = first_model[struc.filter_backbone(first_model)]

    return np.vstack([atom.coord for atom in backbone if atom.atom_name == "CA"])

def get_connectiveness_error(pdb_name):
    """Compute the boundary error of DiffSDS and FoldingDiff for one protein.

    The prediction of each method and the reference structure are loaded with
    ``load_pdb``; the mask boundaries and the sequence extent come from the
    DiffSDS prediction file.

    Args:
        pdb_name: Protein identifier used to build the ``pred_<name>.pdb`` and
            ``raw_<name>.pdb`` file names.

    Returns:
        A tuple ``(mask_len, total_len, diffsds_error, foldingdiff_error)``
        where ``mask_len`` is ``right - left + 1`` and ``total_len`` is
        ``max_idx + 1``.
    """
    diffsds_fname = f"/gaozhangyang/experiments/DiffSDS/results/DiffSDS_sampling/pred_{pdb_name}.pdb"
    foldingdiff_fname = f"/gaozhangyang/experiments/DiffSDS/results/Ccfoldingdiff_sampling/pred_{pdb_name}.pdb"
    true_frame = f"/gaozhangyang/experiments/DiffSDS/results/DiffSDS_sampling/raw_{pdb_name}.pdb"

    # Baselines that are currently disabled (kept here for reference only):
    #   RFDesign: /gaozhangyang/experiments/RFDesign/inpainting/tests/out/test_{pdb_name}_0.pdb
    #             (coords read via ReadPDB.read_pdb)
    #   SMCDiff:  /gaozhangyang/experiments/ProreinBinder/results/inpaint_SMCDiff/pred_{pdb_name}.pdb
    #             (coords read via read_CA_coords)

    # Mask boundaries and last valid index are taken from the DiffSDS file only.
    diffsds_coords, left, right, max_idx, _ = load_pdb(diffsds_fname)
    foldingdiff_coords = load_pdb(foldingdiff_fname)[0]
    true_coords = load_pdb(true_frame)[0]

    diffsds_error = error(diffsds_coords, true_coords, left, right)
    foldingdiff_error = error(foldingdiff_coords, true_coords, left, right)

    mask_len = (right - left) + 1
    total_len = max_idx + 1
    return mask_len, total_len, diffsds_error, foldingdiff_error

In [7]:
# Directory with the DiffSDS sampling outputs; predictions are named "pred_<name>.pdb".
pred_dir = "/gaozhangyang/experiments/DiffSDS/results/DiffSDS_sampling/"
pred_prefix, pdb_suffix = "pred_", ".pdb"

# Strip exactly the "pred_" prefix and ".pdb" suffix. Splitting on "_" and keeping
# the last piece would truncate names that contain underscores themselves
# (e.g. "pred_1abc_A.pdb" -> "A"), which then point at non-existent files.
# Sorting makes the row order reproducible, because os.listdir returns entries
# in arbitrary order.
pdb_name_list = sorted(
    fname[len(pred_prefix):-len(pdb_suffix)]
    for fname in os.listdir(pred_dir)
    if fname.startswith(pred_prefix) and fname.endswith(pdb_suffix)
)

error_columns = ["mask_len", "len", "diffsds_error", "foldingdiff_error", "SMC_error", "RFDesign_error"]

# Gather one record per protein and build the frame once, instead of growing it
# cell by cell with .loc (which is slow and leaves every column as object dtype).
error_records = []
for pdb_name in tqdm(pdb_name_list):
    mask_len, all_len, diffsds_error, foldingdiff_error = get_connectiveness_error(pdb_name)
    error_records.append(
        {
            "mask_len": mask_len,
            "len": all_len,
            "diffsds_error": diffsds_error,
            "foldingdiff_error": foldingdiff_error,
        }
    )

# SMC_error and RFDesign_error receive no values here; listing them in `columns`
# keeps them as empty (NaN) columns so the CSV layout stays the same.
error_table = pd.DataFrame(error_records, columns=error_columns)

# Make sure the output directory exists before writing.
os.makedirs("./results", exist_ok=True)
error_table.to_csv("./results/connectiveness.csv")

100%|██████████| 378/378 [00:13<00:00, 27.94it/s]


## 2. Statistics

In [8]:
# Helper column of ones: summing it per group yields the number of rows in that group.
error_table.loc[:,"one"] = 1

# Per-mask-length average error = (sum of errors in the group) / (rows in the group).
by_mask_len = error_table.groupby('mask_len')
rows_per_length = by_mask_len["one"].sum()
diffsds_error_per_length = by_mask_len["diffsds_error"].sum() / rows_per_length
foldingdiff_error_per_length = by_mask_len["foldingdiff_error"].sum() / rows_per_length

In [10]:
# Mean FoldingDiff error over three consecutive slices of the per-mask-length
# series: entries [0, 11), [11, 26) and [26, end).
# NOTE(review): integer slices on a Series are positional, so these buckets are
# "the first 11 distinct mask lengths", not "mask_len < 11" — confirm this is intended.
foldingdiff_buckets = (
    foldingdiff_error_per_length[:11],
    foldingdiff_error_per_length[11:26],
    foldingdiff_error_per_length[26:],
)
print(*(bucket.mean() for bucket in foldingdiff_buckets))

10.847035554070306 18.936291400102164 27.97886862073626


In [11]:
# Mean DiffSDS error over three consecutive slices of the per-mask-length
# series: entries [0, 11), [11, 26) and [26, end).
# NOTE(review): integer slices on a Series are positional, so these buckets are
# "the first 11 distinct mask lengths", not "mask_len < 11" — confirm this is intended.
diffsds_buckets = (
    diffsds_error_per_length[:11],
    diffsds_error_per_length[11:26],
    diffsds_error_per_length[26:],
)
print(*(bucket.mean() for bucket in diffsds_buckets))

6.770061509793061 9.276604485444057 7.731335335686093


In [11]:
# Summary table with one row per mask length. The RFDesign_error and smc_error
# columns are declared but receive no values here, so they are written out empty.
statistics_columns = ["mask_len", "RFDesign_error", "foldingdiff_error", "diffsds_error", "smc_error"]
statistics = pd.DataFrame(columns=statistics_columns).assign(
    mask_len=foldingdiff_error_per_length.index,
    foldingdiff_error=foldingdiff_error_per_length.values,
    # assumes both series share the same mask_len index (they come from the same groupby key) — verify
    diffsds_error=diffsds_error_per_length.values,
)
statistics.to_csv("/gaozhangyang/experiments/ProreinBinder/evaluate/results/statistics_connectiveness.csv")