In [None]:
from glob import glob
import os

import pandas as pd
import plotly.express as px
from scipy.stats import spearmanr, pearsonr
from sklearn.cross_decomposition import PLSRegression

In [2]:
ensemble_dirs = glob("../../stcrdab_STEGG_complex_DB_filtered/*")
pdb_ids = [x.split("/")[-2 if x.endswith("/") else -1] for x in ensemble_dirs]

energy_terms = list()
geo_terms = list()
feature_dfs = list()
for (ensemble_dir, pdb_id) in zip(ensemble_dirs, pdb_ids):
    energy_term_df = pd.read_csv(os.path.join(ensemble_dir, "energy_terms.csv"))
    energy_term_df.set_index(pdb_id + energy_term_df["Unnamed: 0"].str[:-4], inplace=True)
    energy_term_df.drop("Unnamed: 0", axis=1, inplace=True)
    energy_terms = energy_term_df.columns
    geo_df = pd.read_csv(os.path.join(ensemble_dir, "geo.csv"))
    geo_df.set_index(pdb_id + geo_df.pdb_name, inplace=True)
    geo_df.drop("pdb_name", axis=1, inplace=True)
    geo_terms = geo_df.columns
    feature_df = pd.concat([energy_term_df, geo_df], axis=1)
    feature_dfs.append(feature_df)

feature_df = pd.concat(feature_dfs)

In [3]:
score_df = pd.read_csv("dockq_scores_filtered.csv")
score_df.set_index(score_df.pdb_id.str.upper() + score_df.model.str[:-4], inplace=True)
score_df.drop("Unnamed: 0", axis=1, inplace=True)
score_df.drop_duplicates(inplace=True)

In [4]:
feature_df["dockq"] = score_df.DockQ.loc[score_df.index.intersection(feature_df.index)]
feature_df.dropna(subset="dockq", inplace=True)
feature_df

Unnamed: 0,fa_atr,fa_rep,fa_sol,fa_intra_rep,fa_intra_sol_xover4,lk_ball_wtd,fa_elec,pro_close,hbond_sr_bb,hbond_lr_bb,...,angle_113,angle_114,angle_115,angle_116,angle_117,angle_118,angle_119,angle_120,angle_121,dockq
5WKH_complex_109_roi,-2297.893161,1888.265239,1564.991798,934.664599,76.357929,-10.735848,-743.896985,3.949098,-69.663255,-139.424447,...,1.800239,0.787730,1.262214,2.486522,1.014418,2.353863,1.570796,1.570796,0.0,0.198170
5WKH_complex_18_roi,-2249.853687,1288.896912,1535.221834,915.861171,73.556175,-8.018389,-705.689029,3.236660,-68.002538,-129.096977,...,1.330873,0.817635,1.079131,2.244975,0.898941,2.323958,1.570796,1.570796,0.0,0.164260
5WKH_complex_137_roi,-2272.115849,1128.960982,1545.025827,908.137889,74.064624,-6.478713,-742.474657,4.329981,-69.114188,-139.093090,...,1.596152,0.744642,1.317535,2.431369,0.924873,2.396951,1.570796,1.570796,0.0,0.216224
5WKH_complex_156_roi,-2266.466280,1071.642157,1516.420796,941.020237,75.144636,-7.486938,-744.134891,3.326101,-73.447718,-145.419772,...,1.648431,0.781156,1.375354,2.238761,0.706427,2.360437,1.570796,1.570796,0.0,0.118275
5WKH_complex_186_roi,-2305.834670,1314.103322,1558.734334,933.966606,75.041948,-7.876167,-750.452864,3.600953,-70.272424,-144.769397,...,1.347618,0.775101,1.204063,2.570438,1.154281,2.366491,1.570796,1.570796,0.0,0.098293
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1MWA_complex_7_roi,-2256.430291,1090.135440,1514.755540,953.670599,75.700341,-9.453312,-728.630794,4.416992,-63.801069,-134.179684,...,1.398376,0.973132,1.040204,2.065293,0.766709,2.168461,1.570796,1.570796,0.0,0.264965
1MWA_complex_146_roi,-2189.924460,1050.876805,1465.518119,943.205039,73.534845,-10.541635,-708.175185,3.370048,-69.944914,-128.958010,...,1.735394,0.830880,1.142535,2.286647,0.889287,2.310712,1.570796,1.570796,0.0,0.146098
1MWA_complex_145_roi,-2277.869391,1162.773184,1495.454361,1868.064674,72.367358,-5.118086,-737.509342,3.791187,-72.225749,-144.215567,...,1.915077,1.053854,1.008842,2.028257,0.764385,2.087738,1.570796,1.570796,0.0,0.166095
1MWA_complex_79_roi,-2212.487057,1227.993697,1494.406611,943.814377,72.253879,-13.749705,-704.243935,3.365065,-60.942304,-128.709922,...,2.109436,1.128855,0.942475,1.863448,0.713877,2.012738,1.570796,1.570796,0.0,0.135017


In [None]:
X = feature_df[[col for col in feature_df if col != "dockq"]].values
y = feature_df.dockq.values

pls = PLSRegression(n_components=2)
pls.fit(X, y)

y_hat = pls.predict(X)
px.scatter(x=y, y=y_hat).show()
print(spearmanr(y, y_hat))
print(pearsonr(y, y_hat))

SignificanceResult(statistic=0.7670899001734434, pvalue=0.0)


In [6]:
X = feature_df[energy_terms].values
y = feature_df.dockq.values

pls = PLSRegression(n_components=2)
pls.fit(X, y)

y_hat = pls.predict(X)
px.scatter(x=y, y=y_hat).show()
print(spearmanr(y, y_hat))

SignificanceResult(statistic=0.24431486130596594, pvalue=0.0)


In [7]:
X = feature_df[geo_terms].values
y = feature_df.dockq.values

pls = PLSRegression(n_components=2)
pls.fit(X, y)

y_hat = pls.predict(X)
px.scatter(x=y, y=y_hat).show()
print(spearmanr(y, y_hat))

SignificanceResult(statistic=0.7511732983705518, pvalue=0.0)


In [12]:
feature_df["dockq_rank"] = feature_df.dockq.rank().astype(int)
feature_df["dockq_hat"] = y_hat
feature_df["dockq_hat_rank"] = feature_df.dockq_hat.rank().astype(int)
feature_df

Unnamed: 0,fa_atr,fa_rep,fa_sol,fa_intra_rep,fa_intra_sol_xover4,lk_ball_wtd,fa_elec,pro_close,hbond_sr_bb,hbond_lr_bb,...,angle_116,angle_117,angle_118,angle_119,angle_120,angle_121,dockq,dockq_rank,dockq_hat,dockq_hat_rank
5WKH_complex_109_roi,-2297.893161,1888.265239,1564.991798,934.664599,76.357929,-10.735848,-743.896985,3.949098,-69.663255,-139.424447,...,2.486522,1.014418,2.353863,1.570796,1.570796,0.0,0.198170,19020,0.135061,8068
5WKH_complex_18_roi,-2249.853687,1288.896912,1535.221834,915.861171,73.556175,-8.018389,-705.689029,3.236660,-68.002538,-129.096977,...,2.244975,0.898941,2.323958,1.570796,1.570796,0.0,0.164260,14928,0.275546,19134
5WKH_complex_137_roi,-2272.115849,1128.960982,1545.025827,908.137889,74.064624,-6.478713,-742.474657,4.329981,-69.114188,-139.093090,...,2.431369,0.924873,2.396951,1.570796,1.570796,0.0,0.216224,20491,0.310878,22489
5WKH_complex_156_roi,-2266.466280,1071.642157,1516.420796,941.020237,75.144636,-7.486938,-744.134891,3.326101,-73.447718,-145.419772,...,2.238761,0.706427,2.360437,1.570796,1.570796,0.0,0.118275,7239,0.161730,12487
5WKH_complex_186_roi,-2305.834670,1314.103322,1558.734334,933.966606,75.041948,-7.876167,-750.452864,3.600953,-70.272424,-144.769397,...,2.570438,1.154281,2.366491,1.570796,1.570796,0.0,0.098293,4377,0.281206,19568
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1MWA_complex_7_roi,-2256.430291,1090.135440,1514.755540,953.670599,75.700341,-9.453312,-728.630794,4.416992,-63.801069,-134.179684,...,2.065293,0.766709,2.168461,1.570796,1.570796,0.0,0.264965,23661,0.321461,23965
1MWA_complex_146_roi,-2189.924460,1050.876805,1465.518119,943.205039,73.534845,-10.541635,-708.175185,3.370048,-69.944914,-128.958010,...,2.286647,0.889287,2.310712,1.570796,1.570796,0.0,0.146098,12003,0.185001,14619
1MWA_complex_145_roi,-2277.869391,1162.773184,1495.454361,1868.064674,72.367358,-5.118086,-737.509342,3.791187,-72.225749,-144.215567,...,2.028257,0.764385,2.087738,1.570796,1.570796,0.0,0.166095,15209,0.101725,2548
1MWA_complex_79_roi,-2212.487057,1227.993697,1494.406611,943.814377,72.253879,-13.749705,-704.243935,3.365065,-60.942304,-128.709922,...,1.863448,0.713877,2.012738,1.570796,1.570796,0.0,0.135017,10103,0.139359,8867


In [13]:
px.scatter(feature_df, x="dockq_rank", y="dockq_hat_rank").show()

In [14]:
feature_df.to_csv("all_models.csv")