In [1]:
import pandas as pd

In [2]:
# Load the training data. The path is kept in one named constant so the
# notebook can be pointed at a different copy of train.csv without editing code.
DATA_PATH = '/kaggle/input/dataset-for-ml-project/train.csv'
ppi_clean = pd.read_csv(DATA_PATH)


In [3]:
# Remove columns that shouldn't be features
def remove_non_features(df: pd.DataFrame, non_feature_cols: list[str]) -> pd.DataFrame:
    """Return a new frame containing every column of ``df`` except ``non_feature_cols``.

    ``df`` itself is not modified. A ``KeyError`` is raised if any listed
    column is missing (default behaviour of ``DataFrame.drop``).
    """
    features_only = df.drop(columns=non_feature_cols)
    return features_only

In [4]:
# Split data into train/validation and test set, taking into account the protein groups
from sklearn.model_selection import GroupShuffleSplit

def split_data_by_group(X: pd.DataFrame, y: pd.Series, groups: pd.Series, **kwargs) -> tuple[pd.DataFrame, pd.DataFrame, pd.Series, pd.Series]:
    """Split ``X``/``y`` into train and test sets so that no group is on both sides.

    Parameters
    ----------
    X : pd.DataFrame
        Feature frame; rows are selected by position.
    y : pd.Series
        Target, aligned row-for-row with ``X``.
    groups : pd.Series
        Group label per row (e.g. uniprot ID), aligned row-for-row with ``X``.
    **kwargs
        Forwarded to ``GroupShuffleSplit`` (e.g. ``test_size``, ``random_state``,
        ``n_splits``).

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test)``.

    Raises
    ------
    AssertionError
        If any group label ends up in both the train and the test set.
    """
    gss = GroupShuffleSplit(**kwargs)

    # GroupShuffleSplit yields ``n_splits`` independent splits; like the original
    # loop, keep only the last one so seeded results stay the same.
    *_, (train_idx, test_idx) = gss.split(X, y, groups)

    X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
    # The split yields positions, so y must be indexed positionally as well:
    # ``y[idx]`` is label-based and selects the wrong rows (or raises KeyError)
    # whenever y's index is not a plain 0..n-1 range.
    y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]

    # Check leakage on the group labels themselves instead of ``X.uniprot_id``,
    # which does not exist once uniprot_id has been dropped as a non-feature.
    group_values = pd.Series(groups)
    train_groups = set(group_values.iloc[train_idx])
    test_groups = set(group_values.iloc[test_idx])
    assert train_groups.isdisjoint(test_groups), 'The same uniprot ID appears in both the train and test set!'
    return X_train, X_test, y_train, y_test

def select_window_size(df: pd.DataFrame, window_sizes) -> pd.DataFrame:
    """Drop the windowed feature columns whose window size is not in ``window_sizes``.

    Windowed columns are those whose name matches ``^\\d+_wm`` (``<size>_wm...``).
    A windowed column is kept only if it starts with ``f"{size}_wm"`` for some
    ``size`` in ``window_sizes``. Non-windowed columns are always kept, and
    column order is preserved.

    Parameters
    ----------
    df : pd.DataFrame
        Frame with string column names.
    window_sizes : iterable
        Window sizes to keep, e.g. ``[3, 7]``. An empty iterable removes every
        windowed column.

    Returns
    -------
    pd.DataFrame
        A new frame; ``df`` is not modified.
    """
    keep_prefixes = tuple(f"{size}_wm" for size in window_sizes)
    window_cols = df.columns[df.columns.str.match(r"^\d+_wm")]
    # str.startswith with an empty tuple is False, so no window sizes -> drop all windowed columns.
    cols_to_remove = [col for col in window_cols if not col.startswith(keep_prefixes)]
    return df.drop(columns=cols_to_remove)




In [5]:
from itertools import combinations
from sklearn.model_selection import GridSearchCV
from xgboost import XGBClassifier


In [6]:
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import GroupKFold
from sklearn.base import BaseEstimator
from sklearn.metrics import roc_auc_score, recall_score

def kfold_cv_by_group(X: pd.DataFrame, y: pd.Series, groups: pd.Series, model: BaseEstimator, **kwargs) -> pd.DataFrame:
    """Cross-validate ``model`` with ``GroupKFold`` and return per-fold metrics.

    Each fold fits a fresh clone of ``model``, so the passed-in estimator is
    never fitted or mutated. Folds are built so that no group (e.g. uniprot ID)
    appears in both the training and the validation part of a fold.

    NOTE(review): assumes a binary target and an estimator exposing
    ``predict_proba`` (column 1 taken as the positive-class score) — confirm for
    any model passed in.

    Parameters
    ----------
    X, y, groups
        Features, target and group labels, aligned row-for-row; rows are
        selected by position.
    model : BaseEstimator
        Unfitted sklearn-style classifier used as a template.
    **kwargs
        Forwarded to ``GroupKFold`` (e.g. ``n_splits``).

    Returns
    -------
    pd.DataFrame
        One row per fold with columns ``fold``, ``roc_auc`` and ``recall``.
    """
    from sklearn.base import clone

    group_kfold = GroupKFold(**kwargs)
    fold_records = []
    for fold, (train_idx, val_idx) in enumerate(group_kfold.split(X, y, groups)):
        fold_model = clone(model)
        fold_model.fit(X.iloc[train_idx], y.iloc[train_idx])

        X_val, y_val = X.iloc[val_idx], y.iloc[val_idx]
        val_scores = fold_model.predict_proba(X_val)[:, 1]
        val_preds = fold_model.predict(X_val)
        fold_records.append({
            'fold': fold,
            'roc_auc': roc_auc_score(y_val, val_scores),
            'recall': recall_score(y_val, val_preds),
        })
    return pd.DataFrame(fold_records)

In [7]:
from sklearn.ensemble import RandomForestClassifier


In [8]:
%%time

y = ppi_clean.p_interface
X_non_windowed = ppi_clean.drop('p_interface', axis=1)
non_feature_cols = ['domain', 'aa_ProtPosition', 'uniprot_id']
X_non_windowed = remove_non_features(X_non_windowed, non_feature_cols)
groups = ppi_clean.uniprot_id
gs_model = RandomForestClassifier(class_weight='balanced', criterion = 'gini', n_estimators= 200, max_features = 'log2')

gs_params = {
    'max_depth': [5, 10, 20, 25],  
    'min_samples_leaf': [ 4, 8, 16, 32, 64 ,128]
}
gridsearch = GridSearchCV(gs_model, param_grid=gs_params, n_jobs=-1, scoring='roc_auc', cv=GroupKFold(n_splits=3))
list_results = []
for size in [0,1,2,3,4]:
  for selection in combinations(set([3,5,7,9]), size):
    
    X = select_window_size(X_non_windowed, list(selection))
    gridsearch.fit(X, y, groups=groups)

    list_results.append([size, selection, gridsearch.best_params_, gridsearch.best_score_])

results_df = pd.DataFrame(list_results)


print(results_df)


    0             1                                          2         3
0   0            ()  {'max_depth': 5, 'min_samples_leaf': 128}  0.676247
1   1          (9,)  {'max_depth': 25, 'min_samples_leaf': 64}  0.695195
2   1          (3,)  {'max_depth': 25, 'min_samples_leaf': 32}  0.688265
3   1          (5,)  {'max_depth': 20, 'min_samples_leaf': 64}  0.688446
4   1          (7,)  {'max_depth': 20, 'min_samples_leaf': 64}  0.692870
5   2        (9, 3)  {'max_depth': 20, 'min_samples_leaf': 64}  0.695768
6   2        (9, 5)  {'max_depth': 25, 'min_samples_leaf': 64}  0.694720
7   2        (9, 7)  {'max_depth': 20, 'min_samples_leaf': 64}  0.694492
8   2        (3, 5)  {'max_depth': 25, 'min_samples_leaf': 32}  0.688954
9   2        (3, 7)  {'max_depth': 25, 'min_samples_leaf': 32}  0.693358
10  2        (5, 7)  {'max_depth': 25, 'min_samples_leaf': 64}  0.692341
11  3     (9, 3, 5)  {'max_depth': 25, 'min_samples_leaf': 32}  0.693375
12  3     (9, 3, 7)  {'max_depth': 20, 'min_samples