In [1]:
from glob import glob
import os

import numpy as np
import pandas as pd

from sklearn.linear_model import RidgeClassifier

In [None]:
# Candidate peptides; one of them is selected for this run.
peptides = ["AVFDRKSDAK", "ELAGIGILTV", "FLCMKALLL", "GILGFVFTL", "LLWNGPMAV"]
peptide = peptides[4]
peptide_dir = os.path.join("../../", peptide)

# Maps each case (the name of an entry under `peptide_dir`) to a
# {energy term -> column-wise mean over all rows of its energy_terms.csv} dict.
data = dict()
# glob() returns entries in arbitrary, filesystem-dependent order; sorting keeps
# the row order of the assembled frame stable between runs.
for case_dir in sorted(glob(os.path.join(peptide_dir, "*"))):
    case = os.path.basename(case_dir)
    energy_csv_fp = os.path.join(case_dir, "energy_terms.csv")
    if not os.path.exists(energy_csv_fp):
        continue  # this case has no energy_terms.csv
    # Separate name for the per-case frame so it is not confused with the
    # combined `energy_df` built after the loop.
    case_energy_df = pd.read_csv(energy_csv_fp, index_col=0)
    if len(case_energy_df) == 0:
        continue  # no data rows, nothing to average
    avg_term_vector = case_energy_df.to_numpy().mean(axis=0)
    data[case] = dict(zip(case_energy_df.columns, avg_term_vector))

# One row per case, one column per energy term.
energy_df = pd.DataFrame.from_dict(data, orient="index")
# Take the feature columns from the assembled frame instead of from whichever
# CSV was read last; this is also defined (as an empty Index) when no case was
# loaded, so later cells do not fail with a NameError.
energy_term_cols = energy_df.columns
energy_df

Unnamed: 0,fa_atr,fa_rep,fa_sol,fa_intra_rep,fa_intra_sol_xover4,lk_ball_wtd,fa_elec,pro_close,hbond_sr_bb,hbond_lr_bb,hbond_bb_sc,hbond_sc,dslf_fa13,omega,fa_dun,p_aa_pp,yhh_planarity,ref,rama_prepro
CAVRDLMNRDDKIIF_CATKGGYSEAFF_GILGFVFTL,-2372.914676,1376.610697,1567.238082,1046.560724,77.644723,-2.819039,-771.849065,5.831626,-69.868141,-134.429073,-58.979061,-60.487377,-1.646729,-1.862811,2127.949596,-67.925014,12.212958,113.58507,202.834617
CAASGGGSQGNLIF_CASSKRSQEPQHF_GILGFVFTL,-2300.733201,2408.170557,1548.887418,994.521660,82.030327,-11.117040,-719.634295,14.348304,-70.354745,-138.320799,-59.563006,-52.487481,-1.508297,28.529571,1928.359982,-70.094984,10.201959,81.14559,153.975795
CAMSNNAGNMLTF_CSAKDWEAAYNEQFF_GILGFVFTL,-2312.189185,1824.029301,1564.503304,1048.519059,77.266052,-6.595848,-741.893746,8.029456,-68.677269,-138.352770,-57.992176,-54.529524,-1.081530,3.740411,2128.639549,-53.131986,10.284778,119.75880,221.457292
CAYRSAREDKIIF_CALGGWTGFMTEAFF_GILGFVFTL,-2402.477069,1864.454638,1599.006770,1009.754856,79.975103,-3.391743,-759.182671,5.625172,-73.680137,-136.499999,-57.320754,-70.084612,-1.943141,1.429029,2078.813128,-52.158709,11.330357,104.52053,216.190408
CAGAGSQGNLIF_CAGRGTDDYGYTF_GILGFVFTL,-2270.995804,1294.905575,1520.641919,980.450685,78.907773,-7.524495,-747.447615,5.472280,-69.863234,-132.424504,-64.083556,-60.163752,-1.602054,-3.969958,2088.844411,-62.280057,10.479709,112.01556,213.853207
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
CIELRARF_CASSIRSSYEQYF_GILGFVFTL,-2272.943396,1289.177784,1522.274811,998.698864,80.064751,-6.871608,-733.896186,6.029318,-68.773123,-129.004880,-62.158388,-64.530794,-1.496429,-3.976305,2216.471730,-48.198531,12.751513,100.99526,211.201932
CVVTETSYDKVIF_CASSIRSSYEQYF_GILGFVFTL,-2310.221950,1307.099793,1548.995543,1040.274470,81.335123,-4.873212,-777.484306,5.480721,-69.717060,-137.174442,-67.566949,-65.220146,-2.844226,-4.010339,2169.306763,-62.320970,12.746094,95.86163,158.159561
CAVLYGGSQGNLIF_CASSPSPRLASPLHF_GILGFVFTL,-2342.252268,1578.315368,1523.402688,1054.767106,74.165376,-2.853814,-724.638600,6.574438,-72.458810,-147.194374,-47.942787,-59.782110,-1.723518,-1.323276,2034.896309,-67.937542,8.656505,115.63612,243.961768
CAEDNNARLMF_CSARTGTGNSYTF_GILGFVFTL,-2277.352689,1313.299948,1549.573612,1003.066363,74.005992,-5.667996,-752.291861,3.730063,-72.148501,-139.278188,-58.535843,-57.674013,-1.200774,-3.611903,2132.782873,-64.157930,10.183666,134.01750,192.172881


In [21]:
# Read the per-peptide CSV for the currently selected peptide.
dataset_dir = "use_case_demo/peptide_datasets/"
ba_csv_fp = os.path.join(dataset_dir, f"{peptide}.csv")
ba_df = pd.read_csv(ba_csv_fp, index_col=0)

# Re-key the rows as "<CDR3a>_<CDR3b>_<peptide>".
case_key = ba_df["CDR3a"] + "_" + ba_df["CDR3b"] + "_" + ba_df["peptide"]
ba_df = ba_df.set_index(case_key)
ba_df

Unnamed: 0,CDR3a,CDR3b,MHC Sequence,peptide,TCR_A_sequence,TCR_B_sequence,label,split,cdr_full
AASFIIQGAQKLV_ASSLLGGWSEAF_GILGFVFTL,AASFIIQGAQKLV,ASSLLGGWSEAF,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,GILGFVFTL,GENVEQHPSTLSVQEGDSAVIKCTYSDSASNYFPWYKQELGKRPQL...,EAQVTQNPRYLITVTGKKLTVTCSQNMNHEYMSWYRQDPGLGLRQI...,1,train,AASFIIQGAQKLVASSLLGGWSEAF
AGAGSQGNLI_ASSSRSSYEQY_GILGFVFTL,AGAGSQGNLI,ASSSRSSYEQY,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,GILGFVFTL,LLEQSPQFLSIQEGENLTVYCNSSSVFSSLQWYRQEPGEGPVLLVT...,GGITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLIY...,1,train,AGAGSQGNLIASSSRSSYEQY
AFDTNAGKST_ASSIFGQREQY_GILGFVFTL,AFDTNAGKST,ASSIFGQREQY,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,GILGFVFTL,LNVEQSPQSLHVQEGDSTNFTCSFPSSNFYALHWYRWETAKSPEAL...,IGGITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLI...,1,train,AFDTNAGKSTASSIFGQREQY
CVVRDNNDMRF_CATSDFEVAGSSYNEQFF_GILGFVFTL,CVVRDNNDMRF,CATSDFEVAGSSYNEQFF,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,GILGFVFTL,RKEVEQDPGPFNVPEGATVAFNCTYSNSASQSFFWYRQDCRKEPKL...,DADVTQTPRNRITKTGKRIMLECSQTKGHDRMYWYRQDPGLGLRLI...,1,train,CVVRDNNDMRFCATSDFEVAGSSYNEQFF
CIAHGGGGADGLTF_CATSDGLAGGWANVLTF_GILGFVFTL,CIAHGGGGADGLTF,CATSDGLAGGWANVLTF,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,GILGFVFTL,DAKTTQPNSMESNEEEPVHLPCNHSTISGTDYIHWYRQLPSQGPEY...,DADVTQTPRNRITKTGKRIMLECSQTKGHDRMYWYRQDPGLGLRLI...,1,train,CIAHGGGGADGLTFCATSDGLAGGWANVLTF
...,...,...,...,...,...,...,...,...,...
CILRETSYDKVIF_CASSYHSNQPQHF_GILGFVFTL,CILRETSYDKVIF,CASSYHSNQPQHF,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,GILGFVFTL,DAKTTQPNSMESNEEEPVHLPCNHSTISGTDYIHWYRQLPSQGPEY...,DGGITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLI...,1,train,CILRETSYDKVIFCASSYHSNQPQHF
CGTGVVSGGGADGLTF_CASTVFQGGGQPQHF_GILGFVFTL,CGTGVVSGGGADGLTF,CASTVFQGGGQPQHF,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,GILGFVFTL,SQQPVQSPQAVILREGEDAVINCSSSKALYSVHWYRQKHGEAPVFL...,DGGITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLI...,1,test,CGTGVVSGGGADGLTFCASTVFQGGGQPQHF
CATARSSMDSNYQLIW_CATAPHSGNQPQHF_GILGFVFTL,CATARSSMDSNYQLIW,CATAPHSGNQPQHF,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,GILGFVFTL,SQQGEEDPQALSIQEGENATMNCSYKTSINNLQWYRQNSGRGLVHL...,DGGITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLI...,1,val,CATARSSMDSNYQLIWCATAPHSGNQPQHF
CAFLNAGKSTF_CASSDAGGGMAEAFF_GILGFVFTL,CAFLNAGKSTF,CASSDAGGGMAEAFF,GSHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRME...,GILGFVFTL,ILNVEQSPQSLHVQEGDSTNFTCSFPSSNFYALHWYRWETAKSPEA...,EADIYQTPRYLVIGTGKKITLECSQTMGHDKMYWYQQDPGMELHLI...,1,test,CAFLNAGKSTFCASSDAGGGMAEAFF


In [22]:
# Fallback for datasets that ship without a "split" column: shuffle the rows and
# assign roughly 75% / 12.5% / 12.5% of them to train / val / test.
if "split" not in ba_df.columns:
    split_seed = 1047
    ba_df = ba_df.sample(frac=1, random_state=split_seed)
    # Draw from a seeded generator: the global np.random state is not seeded in
    # this cell, so np.random.rand would give a different split on every run
    # even though the shuffle above is seeded.
    rand = np.random.default_rng(split_seed).random(len(ba_df))
    ba_df["split"] = np.where(rand < 0.75, "train", np.where(rand < 0.875, "val", "test"))

# Inner join on the index: only cases present in both frames are kept, each
# energy-term row gaining its "label" and "split".
df = energy_df.merge(ba_df[["label", "split"]], how="inner", left_index=True, right_index=True)
# NOTE: a notebook cell renders only its last expression by default, so this
# class-balance count is computed but not displayed.
df.label.value_counts()
df

Unnamed: 0,fa_atr,fa_rep,fa_sol,fa_intra_rep,fa_intra_sol_xover4,lk_ball_wtd,fa_elec,pro_close,hbond_sr_bb,hbond_lr_bb,...,hbond_sc,dslf_fa13,omega,fa_dun,p_aa_pp,yhh_planarity,ref,rama_prepro,label,split
CAVRDLMNRDDKIIF_CATKGGYSEAFF_GILGFVFTL,-2372.914676,1376.610697,1567.238082,1046.560724,77.644723,-2.819039,-771.849065,5.831626,-69.868141,-134.429073,...,-60.487377,-1.646729,-1.862811,2127.949596,-67.925014,12.212958,113.58507,202.834617,0,train
CAASGGGSQGNLIF_CASSKRSQEPQHF_GILGFVFTL,-2300.733201,2408.170557,1548.887418,994.521660,82.030327,-11.117040,-719.634295,14.348304,-70.354745,-138.320799,...,-52.487481,-1.508297,28.529571,1928.359982,-70.094984,10.201959,81.14559,153.975795,1,train
CAMSNNAGNMLTF_CSAKDWEAAYNEQFF_GILGFVFTL,-2312.189185,1824.029301,1564.503304,1048.519059,77.266052,-6.595848,-741.893746,8.029456,-68.677269,-138.352770,...,-54.529524,-1.081530,3.740411,2128.639549,-53.131986,10.284778,119.75880,221.457292,0,train
CAYRSAREDKIIF_CALGGWTGFMTEAFF_GILGFVFTL,-2402.477069,1864.454638,1599.006770,1009.754856,79.975103,-3.391743,-759.182671,5.625172,-73.680137,-136.499999,...,-70.084612,-1.943141,1.429029,2078.813128,-52.158709,11.330357,104.52053,216.190408,0,train
CAGAGSQGNLIF_CAGRGTDDYGYTF_GILGFVFTL,-2270.995804,1294.905575,1520.641919,980.450685,78.907773,-7.524495,-747.447615,5.472280,-69.863234,-132.424504,...,-60.163752,-1.602054,-3.969958,2088.844411,-62.280057,10.479709,112.01556,213.853207,1,train
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
CIELRARF_CASSIRSSYEQYF_GILGFVFTL,-2272.943396,1289.177784,1522.274811,998.698864,80.064751,-6.871608,-733.896186,6.029318,-68.773123,-129.004880,...,-64.530794,-1.496429,-3.976305,2216.471730,-48.198531,12.751513,100.99526,211.201932,1,train
CVVTETSYDKVIF_CASSIRSSYEQYF_GILGFVFTL,-2310.221950,1307.099793,1548.995543,1040.274470,81.335123,-4.873212,-777.484306,5.480721,-69.717060,-137.174442,...,-65.220146,-2.844226,-4.010339,2169.306763,-62.320970,12.746094,95.86163,158.159561,1,train
CAVLYGGSQGNLIF_CASSPSPRLASPLHF_GILGFVFTL,-2342.252268,1578.315368,1523.402688,1054.767106,74.165376,-2.853814,-724.638600,6.574438,-72.458810,-147.194374,...,-59.782110,-1.723518,-1.323276,2034.896309,-67.937542,8.656505,115.63612,243.961768,0,train
CAEDNNARLMF_CSARTGTGNSYTF_GILGFVFTL,-2277.352689,1313.299948,1549.573612,1003.066363,74.005992,-5.667996,-752.291861,3.730063,-72.148501,-139.278188,...,-57.674013,-1.200774,-3.611903,2132.782873,-64.157930,10.183666,134.01750,192.172881,0,train


In [23]:
# Save the merged frame to a CSV named after the peptide, in the current
# working directory.
merged_csv_fp = f"{peptide}.csv"
df.to_csv(merged_csv_fp)

In [24]:
# Fit a ridge classifier on every row and score it on the same rows, so the
# printed value is accuracy on the training data itself.
X = df.loc[:, energy_term_cols].to_numpy()
y = df["label"].to_numpy()

model = RidgeClassifier().fit(X, y)

acc = model.score(X, y)
print(acc)

0.7347826086956522


In [25]:
# Number of rows in each split.
split_labels = df["split"]
print(split_labels.value_counts())

# Fit on the "train" rows only.
train_rows = df[split_labels == "train"]
X_train = train_rows[energy_term_cols].to_numpy()
y_train = train_rows["label"].to_numpy()

model = RidgeClassifier().fit(X_train, y_train)

# "test" and "val" rows are pooled into one evaluation set.
eval_rows = df[split_labels.isin(["test", "val"])]
X_test = eval_rows[energy_term_cols].to_numpy()
y_test = eval_rows["label"].to_numpy()

acc = model.score(X_test, y_test)
print(acc)

split
train    167
test      45
val       18
Name: count, dtype: int64
0.6825396825396826
