In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
# Read data/ppi_clean.csv into a DataFrame and show it as the cell output
ppi_clean = pd.read_csv(filepath_or_buffer='data/ppi_clean.csv')
ppi_clean

Unnamed: 0,domain,aa_ProtPosition,uniprot_id,hydrophobicity_scores,Rlength,normalized_length,normalized_abs_surf_acc,normalized_hydropathy_index,rel_surf_acc,prob_sheet,...,9_wm_pssm_K,9_wm_pssm_M,9_wm_pssm_F,9_wm_pssm_P,9_wm_pssm_S,9_wm_pssm_T,9_wm_pssm_W,9_wm_pssm_Y,9_wm_pssm_V,p_interface
0,0,1,A0A024RAV5,0.64,188,0.238095,0.784319,0.711111,0.803,0.003,...,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0
1,0,2,A0A024RAV5,-0.05,188,0.238095,0.359207,0.422222,0.530,0.047,...,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0
2,0,3,A0A024RAV5,-0.74,188,0.238095,0.395387,0.111111,0.464,0.043,...,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0
3,0,4,A0A024RAV5,0.26,188,0.238095,0.401655,0.355556,0.385,0.084,...,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0
4,1,5,A0A024RAV5,-1.50,188,0.238095,0.288103,0.066667,0.287,0.084,...,0.302858,0.436017,0.238759,0.103980,0.105653,0.161544,0.060391,0.140326,0.481904,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
65145,1,34,P0A427,1.38,38,0.014881,0.210794,1.000000,0.233,0.086,...,0.286923,0.374128,0.415754,0.109332,0.351726,0.330383,0.241807,0.334863,0.381730,1
65146,0,35,P0A427,-0.74,38,0.014881,0.292872,0.111111,0.343,0.086,...,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,1
65147,0,36,P0A427,0.48,38,0.014881,0.141193,0.455556,0.367,0.043,...,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,1
65148,0,37,P0A427,-0.74,38,0.014881,0.406473,0.111111,0.477,0.043,...,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,1


In [None]:
# Remove columns that shouldn't be features
def remove_non_features(df: pd.DataFrame, non_feature_cols: list[str]) -> pd.DataFrame:
    '''Return a copy of ``df`` without the given non-feature columns.

    Args:
        df: Input table; it is not modified.
        non_feature_cols: Labels of the columns to drop.

    Returns:
        A new DataFrame holding every column of ``df`` except the listed ones.

    Raises:
        KeyError: If a label in ``non_feature_cols`` is not a column of ``df``.
    '''
    # `columns=` is required here: a bare `df.drop(labels)` targets the row
    # index (axis=0) and would fail to find the column names there.
    return df.drop(columns=non_feature_cols)

In [None]:
# Split data into train/validation and test set, taking into account the protein groups
from sklearn.model_selection import GroupShuffleSplit

def split_data_by_group(X: pd.DataFrame, y: pd.Series, groups: pd.Series, **kwargs) -> tuple[pd.DataFrame, pd.DataFrame, pd.Series, pd.Series]:
    '''Split ``X`` and ``y`` into two partitions that share no group.

    Args:
        X: Feature table, one row per sample.
        y: Target values, row-aligned with ``X``.
        groups: Group label of each row (row-aligned with ``X``); rows with the
            same label always end up in the same partition.
        **kwargs: Forwarded unchanged to ``GroupShuffleSplit`` (e.g.
            ``train_size``, ``test_size``, ``random_state``, ``n_splits``).

    Returns:
        ``(X_train, X_test, y_train, y_test)``. If the splitter produces more
        than one split, only the last one is returned.

    Raises:
        AssertionError: If a group label occurs in both partitions.
    '''
    gss = GroupShuffleSplit(**kwargs)

    # Exhaust the generator and keep the final split, as the previous loop did.
    *_, (train_idx, test_idx) = gss.split(X, y, groups)

    # The splitter yields *positional* indices, so everything must be selected
    # with .iloc. Plain `y[idx]` is label-based and breaks (KeyError or wrong
    # rows) as soon as the index is not the default 0..n-1 range.
    X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
    y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]

    # Verify the split on the grouping actually used, rather than on a
    # hard-coded column of X that may have been dropped already.
    group_labels = pd.Series(groups)
    train_groups = set(group_labels.iloc[train_idx])
    test_groups = set(group_labels.iloc[test_idx])
    assert train_groups.isdisjoint(test_groups), 'The same uniprot ID appears in both the train and test set!'
    return X_train, X_test, y_train, y_test

def select_window_size(df: pd.DataFrame, window_size: int) -> pd.DataFrame:
    '''Drop the feature columns that do not belong to the requested window size.

    Windowed columns are those named ``<number>_wm...``. With ``window_size == 1``
    all of them are dropped. With any other value, every windowed column that
    does not start with ``<window_size>_wm`` is dropped, together with the
    columns starting with ``pssm``/``prob`` and the three un-windowed
    surface-accessibility / hydropathy columns.

    Prints the set of column names that remain and returns the reduced frame;
    ``df`` itself is left untouched.
    '''
    windowed = df.columns[df.columns.str.match(r"^\d+_wm")].tolist()  # e.g. "9_wm_pssm_A"

    if window_size == 1:
        # Only the windowed variants have to go.
        doomed = windowed
    else:
        per_residue = [name for name in df.columns if name.startswith(('pssm', 'prob'))]
        per_residue += ['rel_surf_acc', 'normalized_abs_surf_acc', 'normalized_hydropathy_index']
        wanted = {name for name in df.columns if name.startswith(f"{window_size}_wm")}
        doomed = list(set(windowed).union(per_residue) - wanted)

    reduced = df.drop(columns=doomed)
    # Columns present both before and after the drop, i.e. the ones that were kept.
    print('Intersection: ', frozenset(df.columns).intersection(frozenset(reduced.columns)))

    return reduced

ppi_clean_9 = select_window_size(ppi_clean, 9)

X = ppi_clean_9.drop('p_interface', axis=1)
y = ppi_clean_9.p_interface
# Take the groups from the same frame as X and y so all three stay row-aligned.
groups = ppi_clean_9.uniprot_id

# A fixed random_state makes the shuffle repeatable after a kernel restart.
# NOTE(review): train_size=0.2 puts only ~20% of the proteins in the
# train/validation part and the rest in the test set — confirm this is
# intended rather than test_size=0.2.
X_tv, X_test, y_tv, y_test = split_data_by_group(X, y, groups, train_size=0.2, random_state=42)

Intersection:  frozenset({'9_wm_pssm_D', '9_wm_pssm_L', '9_wm_pssm_G', '9_wm_pssm_S', 'uniprot_id', '9_wm_pssm_P', '9_wm_pssm_W', '9_wm_normalized_hydropathy_index', 'p_interface', '9_wm_pssm_C', 'hydrophobicity_scores', '9_wm_pssm_Q', '9_wm_pssm_Y', 'Rlength', '9_wm_pssm_F', '9_wm_prob_helix', '9_wm_pssm_N', 'aa_ProtPosition', '9_wm_pssm_A', '9_wm_pssm_M', '9_wm_pssm_V', '9_wm_pssm_T', '9_wm_pssm_E', '9_wm_normalized_abs_surf_acc', '9_wm_prob_sheet', '9_wm_pssm_H', '9_wm_pssm_K', '9_wm_prob_coil', 'normalized_length', '9_wm_rel_surf_acc', '9_wm_pssm_R', '9_wm_pssm_I', 'domain'})


Unnamed: 0,domain,aa_ProtPosition,uniprot_id,hydrophobicity_scores,Rlength,normalized_length,normalized_abs_surf_acc,normalized_hydropathy_index,rel_surf_acc,prob_sheet,...,9_wm_pssm_L,9_wm_pssm_K,9_wm_pssm_M,9_wm_pssm_F,9_wm_pssm_P,9_wm_pssm_S,9_wm_pssm_T,9_wm_pssm_W,9_wm_pssm_Y,9_wm_pssm_V
1559,0,1,P9WFX5,0.64,370,0.508929,0.789109,0.711111,0.808,0.003,...,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111
1560,0,2,P9WFX5,0.62,370,0.508929,0.304174,0.700000,0.565,0.017,...,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111
1561,0,3,P9WFX5,1.06,370,0.508929,0.296944,0.922222,0.332,0.017,...,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111
1562,0,4,P9WFX5,-0.18,370,0.508929,0.297579,0.411111,0.520,0.016,...,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111
1563,0,5,P9WFX5,0.62,370,0.508929,0.244137,0.700000,0.454,0.016,...,0.297797,0.287011,0.310314,0.170403,0.595197,0.621541,0.618270,0.053566,0.127863,0.296047
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
64332,0,86,Q89VT6,0.12,90,0.092262,0.162932,0.322222,0.235,0.017,...,0.245301,0.536919,0.272496,0.040481,0.260189,0.577020,0.321720,0.016734,0.073474,0.234516
64333,0,87,Q89VT6,0.62,90,0.092262,0.203090,0.700000,0.378,0.019,...,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111
64334,0,88,Q89VT6,-0.05,90,0.092262,0.310999,0.422222,0.459,0.019,...,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111
64335,0,89,Q89VT6,0.62,90,0.092262,0.266193,0.700000,0.495,0.019,...,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111,0.111111


In [None]:
from sklearn.tree import DecisionTreeClassifier

# Re-bind X / y / groups to the train+validation partition for the second split.
X = X_tv
y = y_tv
groups = X.uniprot_id

# A fixed random_state makes the shuffle repeatable after a kernel restart.
# NOTE(review): train_size=0.2 keeps only ~20% of the remaining proteins for
# training and uses the rest for validation — confirm this is intended.
X_train, X_valid, y_train, y_valid = split_data_by_group(X, y, groups, train_size=0.2, random_state=42)