In [1]:
import os 
import pickle
from tqdm.auto import tqdm
import math
import random
import numpy as np
import torch
import matplotlib.pyplot as plt

from utils_eval import dict_load, dict_save



In [12]:
def ms_stat(v_dict, threshold = 2.0, idx = -2):
    tms_list = []
    rmsd_list = []

    stat_dict = {}
    
    for name in v_dict:
        target = '_'.join(name.split('_')[:idx])
        att = name.split('_')[-1]
        
        if target not in stat_dict:
            stat_dict[target] = [0, 0]
        stat_dict[target][0] += 1

        rmsd_min = float('inf')
        for pred in v_dict[name]:
            tms, rmsd = v_dict[name][pred]
            tms_list.append(tms)
            rmsd_list.append(rmsd)
            rmsd_min = min(rmsd_min, rmsd)

        if rmsd_min <= threshold:
            stat_dict[target][1] += 1

    ### statistics
    print('tmscore = %f; rmsd = %f' % (np.mean(tms_list), np.mean(rmsd_list)))
    suc_num = 0
    suc_ratio = []
    for target in stat_dict:
        sr = stat_dict[target][1] / stat_dict[target][0]
        suc_ratio.append(sr)
        if stat_dict[target][1] > 0:
            suc_num += 1

    print('%d successful cases out of %d. suc_ratio=%f' % (suc_num, len(stat_dict), np.mean(suc_ratio)))

## Baselines

In [9]:
rfdiff_dict = dict_load('../../Results/Baseline_motif/RFdiffusion/designability_esmfold_TMscore_dict.pkl')

In [13]:
ms_stat(rfdiff_dict, threshold = 2.0, idx = -1)

tmscore = 0.792488; rmsd = 1.496438
20 successful cases out of 20. suc_ratio=0.895500


## Distilled Model

In [2]:
PATH = '../../Results/jointDiff-distill_motif/'

In [4]:
val_dict = {}

for model in os.listdir(PATH):
    dict_path = os.path.join(PATH, model, 'designability_esmfold_TMscore_dict.pkl')

    if not os.path.exists(dict_path):
        continue

    dset = model.split('distill-')[1].split('_')[0]
    kind = model.split('_')[-1]
    name = '%s_%s' % (dset, kind)
    if 'rm' in model.split('_')[-2]:
        name = '%s_%s' % (name, model.split('_')[-2])
    if '_sc_' in model:
        name = '%s_%s' % ('sc', name)

    print(name)
    val_dict[name] = dict_load(dict_path)

sc_merge_cold_rm-conse
merge_cold_rm-conse
sc_merge_cold_rm
nature_warm
gen_warm
merge_cold
merge_warm
merge_cold_rm
sc_merge_warm_rm
merge_warm_rm-conse
gen_cold
merge_warm_rm
sc_merge_warm_rm-conse
nature_cold


In [8]:
for name in val_dict:
    print(name)
    ms_stat(val_dict[name], threshold = 2.0)
    print()

sc_merge_cold_rm-conse
tmscore = 0.407787; rmsd = 3.133257
23 successful cases out of 24. suc_ratio=0.230417

merge_cold_rm-conse
tmscore = 0.417274; rmsd = 3.040197
24 successful cases out of 24. suc_ratio=0.277917

sc_merge_cold_rm
tmscore = 0.459720; rmsd = 2.804802
24 successful cases out of 24. suc_ratio=0.396250

nature_warm
tmscore = 0.384668; rmsd = 2.987067
24 successful cases out of 24. suc_ratio=0.334167

gen_warm
tmscore = 0.397189; rmsd = 2.953354
24 successful cases out of 24. suc_ratio=0.353333

merge_cold
tmscore = 0.371469; rmsd = 3.123261
24 successful cases out of 24. suc_ratio=0.290000

merge_warm
tmscore = 0.364606; rmsd = 3.083555
24 successful cases out of 24. suc_ratio=0.301250

merge_cold_rm
tmscore = 0.451006; rmsd = 2.927745
23 successful cases out of 24. suc_ratio=0.326250

sc_merge_warm_rm
tmscore = 0.479490; rmsd = 2.788219
24 successful cases out of 24. suc_ratio=0.392500

merge_warm_rm-conse
tmscore = 0.445896; rmsd = 2.972333
23 successful cases out of 