In [1]:
#sfisch: Version 0.4.0 is needed to use with our model
from autogluon.tabular import TabularPredictor
#sfisch6: we pull the test feature matrix from huggingface using the datasets module
from datasets import load_dataset
#sfisch6: we use this module to pull our trained model from huggingface
from huggingface_hub import snapshot_download

In [2]:
# sfisch6: download the hu.MAP3.0 dataset from HuggingFace, convert its "test"
# split (the test feature matrix) to a pandas DataFrame, and preview a few rows.
dataset = load_dataset('sfisch/hu.MAP3.0')
df_test = dataset["test"].to_pandas()
df_test.head(5)

Unnamed: 0,acc1,acc2,Ce_1111_poisson,Ce_1111_wcc,Ce_1111_apex,Ce_1111_pq_euc,Ce_6mg_1203_poisson,Ce_6mg_1203_wcc,Ce_6mg_1203_apex,Ce_6mg_1203_pq_euc,...,pair_count_bp3_293T_Z4,neg_ln_pval_bp3_HCT116_Z2,pair_count_bp3_HCT116_Z2,neg_ln_pval_bp3_HCT116_Z4,pair_count_bp3_HCT116_Z4,neg_ln_pval_bp3_293T_HCT116_Z2,pair_count_bp3_293T_HCT116_Z2,neg_ln_pval_bp3_293T_HCT116_Z4,pair_count_bp3_293T_HCT116_Z4,Label
0,Q9UPY3,Q14004,,,,,,,,,...,1.0,4.179461,1.0,4.23792,1.0,4.515594,3.0,7.212276,3.0,-1
1,Q9UPY3,P82664,,,,,,,,,...,7.0,26.672521,9.0,19.010974,6.0,38.192802,26.0,24.502485,12.0,-1
2,Q9UPY3,Q8NCN5,,,,,,,,,...,1.0,4.631487,3.0,1.81195,1.0,27.559512,22.0,3.943765,3.0,-1
3,O75486,Q9UPY3,,,,,,,,,...,,,,,,,,,,-1
4,Q9UPY3,Q9Y6K9,,,,,,,,,...,,,,,,,,,,-1


In [3]:
# Download the trained AutoGluon model snapshot from the HuggingFace Hub, then
# load the TabularPredictor saved in the dated run directory inside that snapshot.
model_dir = snapshot_download(repo_id="sfisch/hu.MAP3.0_AutoGluon")
predictor = TabularPredictor.load(
    model_dir + "/huMAP3_20230503_complexportal_subset10kNEG_notScaled_accuracy"
)

Fetching 29 files:   0%|          | 0/29 [00:00<?, ?it/s]

In [4]:
# Build the frame handed to the predictor:
#  - drop the protein accession IDs (acc1/acc2), which identify the pair and are not features;
#  - drop any 'pred_prob' column left on df_test by a later cell, so re-running this cell
#    after predictions were annotated cannot leak model output into the feature frame;
#  - rename 'Label' to the label name the predictor was trained with.
# Built as a new frame via chaining (no inplace mutation), so the cell is idempotent.
test_trim = (
    df_test
    .drop(columns=["acc1", "acc2"])
    .drop(columns="pred_prob", errors="ignore")
    .rename(columns={"Label": predictor.label})
)

In [5]:
# Separate the ground-truth labels (y_test) from the feature-only frame that is
# passed to the predictor, then preview the features.
y_test = test_trim.loc[:, predictor.label]
test_nolab = test_trim.drop(predictor.label, axis=1)
test_nolab.head()

Unnamed: 0,Ce_1111_poisson,Ce_1111_wcc,Ce_1111_apex,Ce_1111_pq_euc,Ce_6mg_1203_poisson,Ce_6mg_1203_wcc,Ce_6mg_1203_apex,Ce_6mg_1203_pq_euc,Ce_BNF_wan_60_1209_poisson,Ce_BNF_wan_60_1209_wcc,...,neg_ln_pval_bp3_293T_Z4,pair_count_bp3_293T_Z4,neg_ln_pval_bp3_HCT116_Z2,pair_count_bp3_HCT116_Z2,neg_ln_pval_bp3_HCT116_Z4,pair_count_bp3_HCT116_Z4,neg_ln_pval_bp3_293T_HCT116_Z2,pair_count_bp3_293T_HCT116_Z2,neg_ln_pval_bp3_293T_HCT116_Z4,pair_count_bp3_293T_HCT116_Z4
0,,,,,,,,,,,...,2.004226,1.0,4.179461,1.0,4.23792,1.0,4.515594,3.0,7.212276,3.0
1,,,,,,,,,,,...,12.886566,7.0,26.672521,9.0,19.010974,6.0,38.192802,26.0,24.502485,12.0
2,,,,,,,,,,,...,1.705391,1.0,4.631487,3.0,1.81195,1.0,27.559512,22.0,3.943765,3.0
3,,,,,,,,,,,...,,,,,,,,,,
4,,,,,,,,,,,...,,,,,,,,,,


In [6]:
# Generate hard class predictions for every test pair (the classes seen in this
# dataset are -1 and 1).
y_pred = predictor.predict(test_nolab)
print("Predictions:  \n", y_pred)

# Evaluate predictions against the true labels. AutoGluon documents the return value
# as a dict of metric scores (the predictor's eval metric plus auxiliary metrics).
# End the cell with it so the scores are actually shown instead of being computed and
# silently discarded.
perf = predictor.evaluate_predictions(y_true=y_test, y_pred=y_pred, auxiliary_metrics=True)
perf

Predictions:  
 0        -1
1        -1
2        -1
3        -1
4        -1
         ..
248164   -1
248165   -1
248166   -1
248167   -1
248168   -1
Name: IntAct_train, Length: 248169, dtype: int64


In [7]:
# Per-pair class probabilities: one row per test pair, one column per class label
# (columns -1 and 1 here), indexed like test_nolab.
pred_test_probs = predictor.predict_proba(
    test_nolab
)

In [8]:
# Display the full class-probability table ordered by column 1 (probability that the
# pair interacts), lowest first, to inspect both ends of the score distribution.
pred_test_probs.sort_values(by=1, ascending=True)

Unnamed: 0,-1,1
30149,0.999924,0.000076
63020,0.999921,0.000079
17315,0.999920,0.000080
73109,0.999919,0.000081
46101,0.999916,0.000084
...,...,...
166178,0.000026,0.999974
160200,0.000026,0.999974
4184,0.000026,0.999974
189057,0.000025,0.999975


In [9]:
# Annotate each test pair in the original frame with the probability of class 1
# (interaction). Column assignment aligns on the index, so first verify that the
# probability rows line up one-to-one with df_test's rows; a mismatch would otherwise
# silently fill pred_prob with NaN.
assert pred_test_probs.index.equals(df_test.index), "prediction rows do not align with df_test"
df_test['pred_prob'] = pred_test_probs[1]

In [10]:
# Show the 70 highest-scoring test pairs alongside their true Label.
df_test[['acc1', 'acc2', 'Label', 'pred_prob']].sort_values(by="pred_prob", ascending=False).head(70)

Unnamed: 0,acc1,acc2,Label,pred_prob
161518,Q9UJX5,Q9UJX4,1,0.999975
189057,Q13868,Q5RKV6,1,0.999975
4184,Q13868,Q9NQT4,1,0.999974
160200,Q9H1A4,Q9UJX4,1,0.999974
166178,Q9H1A4,Q13042,1,0.999974
...,...,...,...,...
31227,Q9NZN8,Q9UKZ1,1,0.999926
24594,Q969G3,Q8TAQ2,1,0.999924
149179,Q9H410,Q9HBM1,1,0.999924
42005,Q9UP83,Q8WTW3,1,0.999924


In [11]:
# Save the pairwise predicted probabilities for downstream analysis (e.g. the
# precision-recall analysis in protein_complex_maps/evaluation/plots/prcurve.py).
# Output: tab-separated acc1, acc2, pred_prob rows, highest probability first,
# with no header row and no index column.
(
    df_test[['acc1', 'acc2', 'pred_prob']]
    .sort_values(by="pred_prob", ascending=False)
    .to_csv("humap3_test_20230503.pairsWprob", sep='\t', header=False, index=False)
)