In [1]:
from glob import glob
import os

import numpy as np
import pandas as pd

from sklearn.linear_model import RidgeClassifier

In [2]:
peptides = ["AVFDRKSDAK", "ELAGIGILTV", "FLCMKALLL", "GILGFVFTL", "LLWNGPMAV"]
peptide = peptides[0]

# One seed drives both the row shuffle and the train/val/test draw, so that
# rebuilding the cached table from scratch reproduces the same split.
SPLIT_SEED = 1047

# The merged feature/label table is cached next to the notebook.
cache_csv_fp = f"{peptide}.csv"

if not os.path.exists(cache_csv_fp):

    # --- Features: per-case column means of energy_terms.csv ----------------
    peptide_dir = os.path.join("../../", peptide)

    data = dict()
    # sorted() keeps the row order independent of filesystem enumeration order.
    for case_dir in sorted(glob(os.path.join(peptide_dir, "*"))):
        case = os.path.basename(case_dir)
        energy_csv_fp = os.path.join(case_dir, "energy_terms.csv")
        if not os.path.exists(energy_csv_fp):
            continue
        case_df = pd.read_csv(energy_csv_fp, index_col=0)
        if len(case_df) == 0:
            continue
        # Collapse all rows of this case into one mean value per column.
        data[case] = dict(zip(case_df.columns, case_df.values.mean(axis=0)))

    if not data:
        # Fail before writing the cache: an empty cached CSV would otherwise
        # be picked up silently by every later run.
        raise FileNotFoundError(
            f"no non-empty energy_terms.csv found under {peptide_dir!r}"
        )

    energy_df = pd.DataFrame.from_dict(data, orient="index")

    # --- Labels: re-keyed as "<CDR3a>_<CDR3b>_<peptide>" for the index join --
    ba_csv_fp = os.path.join("use_case_demo/peptide_datasets/", f"{peptide}.csv")
    ba_df = pd.read_csv(ba_csv_fp, index_col=0)
    ba_df = ba_df.set_index(ba_df.CDR3a + "_" + ba_df.CDR3b + "_" + ba_df.peptide)

    if "split" not in ba_df.columns:
        ba_df = ba_df.sample(frac=1, random_state=SPLIT_SEED)
        # Seeded generator instead of the global np.random state, so the
        # assignment is identical on every fresh run.
        rand = np.random.default_rng(SPLIT_SEED).random(len(ba_df))
        # Thresholds give roughly 75% train / 12.5% val / 12.5% test in expectation.
        ba_df["split"] = np.where(rand < 0.75, "train", np.where(rand < 0.875, "val", "test"))

    # Inner join: keep only cases that have both energy terms and a label.
    df = energy_df.merge(ba_df[["label", "split"]], how="inner", left_index=True, right_index=True)
    df.to_csv(cache_csv_fp)

else:

    df = pd.read_csv(cache_csv_fp, index_col=0)

# Feature columns are everything except the two annotation columns; deriving
# them from `df` keeps the fresh-build and cached paths consistent.
energy_term_cols = df.drop(["label", "split"], axis=1).columns

df

Unnamed: 0,fa_atr,fa_rep,fa_sol,fa_intra_rep,fa_intra_sol_xover4,lk_ball_wtd,fa_elec,pro_close,hbond_sr_bb,hbond_lr_bb,...,hbond_sc,dslf_fa13,omega,fa_dun,p_aa_pp,yhh_planarity,ref,rama_prepro,label,split
CVVNFTGNQFYF_CASSLNPSFNYGYTF_AVFDRKSDAK,-2318.996980,1443.078973,1615.014392,1014.847193,87.820212,-2.756752,-789.009204,8.323094,-72.972140,-140.380526,...,-73.363003,-2.479235,-0.338421,2158.593385,-61.771531,10.604072,65.46191,195.770030,0,train
CAVDLNARLMF_CASSAGAEQFF_AVFDRKSDAK,-2318.291697,1440.537080,1566.092463,934.109860,82.658718,-0.037132,-770.650670,6.745605,-69.532897,-144.194578,...,-66.094223,-2.698921,-0.518694,1944.153885,-51.930994,8.922323,87.91081,210.138849,1,train
CLVGDILYDKIIF_CASSQLHLAGGLDEQFF_AVFDRKSDAK,-2385.013076,1327.395998,1592.917437,968.542450,83.096640,0.639738,-770.307651,4.886481,-72.551520,-135.860620,...,-73.830384,-1.931548,-4.232436,2048.064366,-61.119395,10.520965,87.54151,220.943156,1,test
CAVRRERRDDKIIF_CAIGPGGSGTQYF_AVFDRKSDAK,-2326.640199,1264.141148,1634.635779,954.958547,80.632232,3.952717,-786.500004,4.369369,-68.538171,-145.157756,...,-74.258974,-1.109777,-5.832560,2185.219715,-64.038928,9.266431,92.26966,229.312503,0,test
CAVRDTDSNYQLIW_CASASSITGYGYTF_AVFDRKSDAK,-2350.343093,1414.693169,1619.673798,969.620127,81.186406,-0.697552,-797.754293,4.192594,-72.123265,-144.826876,...,-72.225482,-2.039139,-4.012635,2110.941144,-63.798877,9.719029,95.62630,237.128761,1,train
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
CVVNGYNFNKFYF_CASSGTGVYEQYF_AVFDRKSDAK,-2310.226487,1816.403082,1588.874306,966.324877,83.239305,-1.815110,-770.749394,7.581206,-74.534849,-142.450021,...,-75.236437,-2.306217,1.831565,1950.995681,-69.498922,12.004045,69.14792,185.733975,1,val
CAFMPYNDMRF_CASRRTGKGFF_AVFDRKSDAK,-2420.590629,1534.108454,1650.421880,1034.814477,86.408835,3.970022,-808.968088,4.209854,-71.089683,-134.413282,...,-73.465464,-1.670051,-2.299971,2170.302874,-35.827254,10.746161,68.31409,261.814632,0,train
CGTEEQETSGSRLTF_CASSQWGTGQPQHF_AVFDRKSDAK,-2366.704667,1456.673683,1693.639228,942.427380,84.775310,2.204854,-821.330562,3.985421,-72.977713,-145.600271,...,-73.317142,-1.908780,-5.905412,2192.987863,-58.503603,9.228405,53.42957,229.665097,1,val
CAYRMRF_CASSPGVGLSYEQYF_AVFDRKSDAK,-2331.754744,1722.156773,1632.399286,1099.702922,85.732594,1.004800,-786.419719,8.368239,-68.597461,-131.424942,...,-71.111377,-1.503927,-0.291503,2016.534650,-30.748662,11.875999,66.69269,217.871679,1,test


In [3]:
# Baseline: fit a ridge classifier on every row and score it on those same
# rows, so the printed number is accuracy on the data the model was fit to.
X = df[energy_term_cols].to_numpy()
y = df["label"].to_numpy()

model = RidgeClassifier().fit(X, y)  # fit() returns the estimator itself

acc = model.score(X, y)
print(acc)

0.581081081081081


In [4]:
# Fit on the "train" rows only, then report accuracy on the combined
# "test" and "val" rows.
print(df["split"].value_counts())

is_train = df["split"] == "train"
is_holdout = df["split"].isin(["test", "val"])

X_train = df.loc[is_train, energy_term_cols].to_numpy()
y_train = df.loc[is_train, "label"].to_numpy()

model = RidgeClassifier().fit(X_train, y_train)  # fit() returns the estimator itself

X_test = df.loc[is_holdout, energy_term_cols].to_numpy()
y_test = df.loc[is_holdout, "label"].to_numpy()

acc = model.score(X_test, y_test)
print(acc)

split
train    142
test      52
val       28
Name: count, dtype: int64
0.4125
