In [8]:
"""
Classical models for pathway data

Adham Beyki
PRaDA, AA2I2 - Deakin University
2018-11-19
"""

import numpy as np
import pandas as pd
from sklearn import metrics
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

In [2]:
def print_results(y_true, y_pred):
    """Print evaluation summaries for a set of predictions.

    Writes the scikit-learn classification report to stdout, followed by
    the confusion matrix, both computed from the same label pair.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels, aligned with ``y_true``.
    """
    # Order matters for the printed output: report first, then the matrix.
    for summarize in (metrics.classification_report, metrics.confusion_matrix):
        print(summarize(y_true, y_pred))

In [3]:
# read data
PATHWAY_DATA_PATH = "../data/pathway_data.pkl"

# NOTE(review): unpickling can execute arbitrary code -- only load this file
# when it comes from a trusted source.
data = pd.read_pickle(PATHWAY_DATA_PATH)

# The pickle holds a dict with the full table plus the train/test row labels.
pathway_df = data['pathway_df']
train_idxs = data['train_idxs']
test_idxs = data['test_idxs']

# Feature columns are everything except the last two columns of the table.
# NOTE(review): presumably the trailing two columns are label/metadata columns
# (one of them being 'PAM50') -- confirm against the code that builds the pickle.
cols = pathway_df.columns[:-2]
assert 'PAM50' not in cols, "target column 'PAM50' leaked into the feature columns"

# Select rows and columns in a single .loc call instead of chaining
# .loc[rows][cols], which materialises an intermediate frame first.
X_train = pathway_df.loc[train_idxs, cols].values
y_train = pathway_df.loc[train_idxs, 'PAM50']
X_test = pathway_df.loc[test_idxs, cols].values
y_test = pathway_df.loc[test_idxs, 'PAM50']

In [7]:
# Logistic Regression
print('Logistic Regression')

# Two-stage pipeline: standardise the features, then fit a class-balanced
# logistic regression on the scaled values.
scaling_step = ('standardize', StandardScaler())
model_step = ('logistic_regression', LogisticRegression(class_weight='balanced'))
estimators = [scaling_step, model_step]

# Pipeline.fit returns the fitted pipeline itself, so construction and
# fitting can be chained.
clf = Pipeline(estimators).fit(X_train, y_train)

# Evaluate on the held-out test rows.
y_pred = clf.predict(X_test)
print_results(y_test, y_pred)

Logistic Regression




              precision    recall  f1-score   support

        LumA       0.93      0.88      0.90       176
        LumB       0.72      0.82      0.77        67

   micro avg       0.86      0.86      0.86       243
   macro avg       0.83      0.85      0.84       243
weighted avg       0.87      0.86      0.87       243

[[155  21]
 [ 12  55]]


In [13]:
# Random Forest
print('Random Forest')

# Random forests are stochastic; without a fixed seed every run of this cell
# yields slightly different metrics. Pin the seed so Restart & Run All is
# repeatable. The output recorded below was produced by an unseeded run, so
# the numbers may shift slightly.
RF_RANDOM_STATE = 0

estimators = [
    # Scaling is kept so the pipeline mirrors the logistic-regression cell.
    ('standardize', StandardScaler()),
    ('random_forest', RandomForestClassifier(
        n_estimators=500,
        class_weight='balanced',
        random_state=RF_RANDOM_STATE,
    ))
]
clf = Pipeline(estimators)
clf.fit(X_train, y_train)

# Evaluate on the held-out test rows.
y_pred = clf.predict(X_test)
print_results(y_test, y_pred)

Random Forest
              precision    recall  f1-score   support

        LumA       0.87      0.96      0.91       176
        LumB       0.86      0.63      0.72        67

   micro avg       0.87      0.87      0.87       243
   macro avg       0.86      0.79      0.82       243
weighted avg       0.87      0.87      0.86       243

[[169   7]
 [ 25  42]]
