In [174]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import time
from joblib import dump, load
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, average_precision_score
from sklearn.inspection import permutation_importance

### Predicting gene essentiality from transposon sequencing data

As a first pass, I will use the TraDIS transposon sequencing data as the ground truth

In [37]:
data = pd.read_csv('tnseq_features_essentiality.csv')
data = data.set_index('Gene', drop=True)

In [55]:
tradis_essentiality = data[data['TraDIS'].notnull()]['TraDIS']
keio_essentiality = data[data['Keio'].notnull()]['Keio']

The TraDIS data flags genes as nonessential, essential or unclear. For the purposes of this analysis, I will consider genes called unclear as nonessential.

Essential: 1, 
Nonessential/Unclear: 0

In [58]:
label_dict = {'Nonessential': 0, 'Unclear': 0, 'Essential': 1}
tradis_label = tradis_essentiality.replace(label_dict)

Extracting the features only for the genes for which we have a mapping to the TraDIS dataset

In [121]:
feature_cols = data.columns[3:-2]

In [122]:
tradis_data = data.loc[data['TraDIS'].notnull(), feature_cols]

In [123]:
tradis_data

Unnamed: 0_level_0,Mean_counts_5p_10pct,Mean_counts_3p_25pct,Mean_counts_interior,Insertion_index,Fraction_zeros,Fraction_above_thresh,Median_counts,Upper25,Lower25,Zeros_interval
Gene,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
0,33.000000,74.200000,73.368421,0.045847,0.105263,0.859649,46.0,105.00,20.00,0.017544
1,44.000000,34.800000,63.647059,0.105063,0.058824,0.764706,28.0,58.00,6.00,0.058824
2,7.500000,39.636364,27.950000,0.033437,0.350000,0.550000,8.0,32.50,0.00,0.100000
3,148.000000,44.000000,34.777778,0.180758,0.111111,0.666667,41.0,47.00,2.00,0.111111
4,80.600000,85.909091,90.812500,0.180041,0.062500,0.937500,55.0,76.50,18.50,0.062500
...,...,...,...,...,...,...,...,...,...,...
4012,14.333333,56.500000,70.461538,0.229180,0.076923,0.923077,72.0,92.00,39.00,0.076923
4013,51.000000,105.100000,52.142857,0.116429,0.000000,0.928571,41.5,55.75,21.00,0.000000
4014,24.250000,38.437500,113.300000,0.122407,0.100000,0.875000,95.0,164.50,33.75,0.025000
4015,0.000000,0.000000,2.500000,0.005372,0.833333,0.083333,0.0,0.00,0.00,0.416667


Great! This is the main dataframe that will be used for the machine learning models.

### Defining functions for training and plotting results

In [164]:
def optimize_grid(X, y, model, grid_parameters, num_cv=5, randomState=42):
    #we will pick the best model with 5-fold cross validation.
#     model = RandomForestClassifier(class_weight={0:1, 1:2})
    grid_search = GridSearchCV(model, grid_parameters, cv = num_cv, n_jobs=-1, verbose=True)
    start = time.time()
    grid_search.fit(X, y)    #fit the model with all the grid search parameters
    end = time.time()
    best_grid_clf = grid_search.best_estimator_   #pick the model with the best performance
    #this is the score on the training/validation set
    y_pred = best_grid_clf.predict(X)    
    y_pred_score = best_grid_clf.predict_proba(X)[:,1]    #get the prediction score (so probability)
    print(f"Time taken for hyperparameter optimization is: {end-start}")
    print("The optimal model performance during the GridSearchCV is")
    print(confusion_matrix(y,y_pred))
    print(classification_report(y,y_pred))
    print(accuracy_score(y, y_pred))
    
    return best_grid_clf

Train/val (for convenience referred to as train) - test split

In [196]:
testSize = 0.3    #keeping 30% of the genes
X_train, X_test, y_train, y_test = train_test_split(tradis_data, tradis_label, test_size=testSize, random_state=42, stratify=tradis_label)

Scaling the data

In [197]:
scaler = StandardScaler()
X_train_scaled = pd.DataFrame(scaler.fit_transform(X_train), index = X_train.index, columns = X_train.columns)
X_test_scaled = pd.DataFrame(scaler.transform(X_test), index = X_test.index, columns = X_test.columns)

I will try out a few different models for predicting gene essentiality:

#### 1. Random forest classifier

In [198]:
grid_parameters_rf = {
    'criterion': ['gini', 'entropy'],
    'bootstrap': [True],
    'oob_score': [True],
    'max_depth': [3,5,7,9],
    'max_samples': [0.25, 0.5, 0.75, 0.9],
    'max_features': ['sqrt',  0.5, 0.75],
    'n_estimators': [100, 150, 200, 400]
}

In [199]:
rf_model = RandomForestClassifier(class_weight='balanced')
best_rf_model = optimize_grid(X_train_scaled, y_train, rf_model, grid_parameters_rf)

Fitting 5 folds for each of 384 candidates, totalling 1920 fits
Time taken for hyperparameter optimization is: 144.1589229106903
The optimal model performance during the GridSearchCV is
[[2265   34]
 [  33  200]]
              precision    recall  f1-score   support

           0       0.99      0.99      0.99      2299
           1       0.85      0.86      0.86       233

    accuracy                           0.97      2532
   macro avg       0.92      0.92      0.92      2532
weighted avg       0.97      0.97      0.97      2532

0.9735387045813586


Now let's run the model on the test data

In [200]:
rf_y_pred = best_rf_model.predict(X_test_scaled)
print("model performance on unseen test set")
print(confusion_matrix(y_test,rf_y_pred))
print(classification_report(y_test,rf_y_pred))
print(accuracy_score(y_test, rf_y_pred))

model performance on unseen test set
[[969  17]
 [ 15  85]]
              precision    recall  f1-score   support

           0       0.98      0.98      0.98       986
           1       0.83      0.85      0.84       100

    accuracy                           0.97      1086
   macro avg       0.91      0.92      0.91      1086
weighted avg       0.97      0.97      0.97      1086

0.9705340699815838


In [202]:
tn, fp, fn, tp = confusion_matrix(y_test, rf_y_pred).ravel()
specificity = tn / (tn+fp)
sensitivity = tp / (tp+fn)
print(sensitivity, specificity)

0.85 0.9827586206896551


#### 2. Logistic regression

In [203]:
grid_parameters_lr = {
    'penalty': ['l2'],
    'C': [0.1, 0.5, 1, 2, 10],
    'solver': ['lbfgs','liblinear']
}
log_reg = LogisticRegression(class_weight='balanced')
best_lr_model = optimize_grid(X_train_scaled, y_train, log_reg, grid_parameters_lr)

Fitting 5 folds for each of 10 candidates, totalling 50 fits
Time taken for hyperparameter optimization is: 0.5479259490966797
The optimal model performance during the GridSearchCV is
[[2148  151]
 [  13  220]]
              precision    recall  f1-score   support

           0       0.99      0.93      0.96      2299
           1       0.59      0.94      0.73       233

    accuracy                           0.94      2532
   macro avg       0.79      0.94      0.85      2532
weighted avg       0.96      0.94      0.94      2532

0.9352290679304898


In [204]:
lr_y_pred = best_lr_model.predict(X_test_scaled)
print("model performance on unseen test set")
print(confusion_matrix(y_test,lr_y_pred))
print(classification_report(y_test,lr_y_pred))
print(accuracy_score(y_test, lr_y_pred))

model performance on unseen test set
[[919  67]
 [  5  95]]
              precision    recall  f1-score   support

           0       0.99      0.93      0.96       986
           1       0.59      0.95      0.73       100

    accuracy                           0.93      1086
   macro avg       0.79      0.94      0.84      1086
weighted avg       0.96      0.93      0.94      1086

0.9337016574585635


This is super interesting: logistic regression can pick out nearly all the essential genes but performs also misclassifies many genes as essential when they're not. This is definitely not a good model choice.

#### 3. Gradient boosted classifier

In [205]:
grid_parameters_gb = {
    'criterion': ['friedman_mse', 'squared_error'],
    'learning_rate': [0.1, 0.2, 0.5],
    'max_depth': [3,5,7,9],
    'max_features': ['log2', 'sqrt', 0.25],
    'n_estimators': [100, 150, 200, 400]
}

gb_model = GradientBoostingClassifier()
best_gb_model = optimize_grid(X_train_scaled, y_train, gb_model, grid_parameters_gb)

Fitting 5 folds for each of 288 candidates, totalling 1440 fits
Time taken for hyperparameter optimization is: 122.91658616065979
The optimal model performance during the GridSearchCV is
[[2287   12]
 [   0  233]]
              precision    recall  f1-score   support

           0       1.00      0.99      1.00      2299
           1       0.95      1.00      0.97       233

    accuracy                           1.00      2532
   macro avg       0.98      1.00      0.99      2532
weighted avg       1.00      1.00      1.00      2532

0.995260663507109


In [206]:
gb_y_pred = best_gb_model.predict(X_test_scaled)
print("model performance on unseen test set")
print(confusion_matrix(y_test,gb_y_pred))
print(classification_report(y_test,gb_y_pred))
print(accuracy_score(y_test, gb_y_pred))

model performance on unseen test set
[[968  18]
 [ 15  85]]
              precision    recall  f1-score   support

           0       0.98      0.98      0.98       986
           1       0.83      0.85      0.84       100

    accuracy                           0.97      1086
   macro avg       0.90      0.92      0.91      1086
weighted avg       0.97      0.97      0.97      1086

0.9696132596685083


#### Let's compare this to a naive classification rule that I used in the LTEE TnSeq paper:

if fraction_above_threshold < 0.1, call as essential:

In [207]:
naive_pred = X_test['Fraction_above_thresh']<0.1

In [208]:
print("model performance on validation set")
print(confusion_matrix(y_test,naive_pred))
print(classification_report(y_test,naive_pred))
print(accuracy_score(y_test, naive_pred))

model performance on validation set
[[964  22]
 [ 13  87]]
              precision    recall  f1-score   support

           0       0.99      0.98      0.98       986
           1       0.80      0.87      0.83       100

    accuracy                           0.97      1086
   macro avg       0.89      0.92      0.91      1086
weighted avg       0.97      0.97      0.97      1086

0.9677716390423573


This is quite interesting: the machine learning models seem to do no better than a naive rule that I defined. 

Now, let's compare on the entire dataset (not that this is necessarily the fairest comparison)

In [186]:
naive_pred = tradis_data['Fraction_above_thresh']<0.1

In [188]:
print("model performance on entire dataset")
print(confusion_matrix(tradis_label,naive_pred))
print(classification_report(tradis_label,naive_pred))
print(accuracy_score(tradis_label, naive_pred))

model performance on entire dataset
[[3211   74]
 [  56  277]]
              precision    recall  f1-score   support

           0       0.98      0.98      0.98      3285
           1       0.79      0.83      0.81       333

    accuracy                           0.96      3618
   macro avg       0.89      0.90      0.90      3618
weighted avg       0.97      0.96      0.96      3618

0.9640685461580984


Looking at feature importance now:

In [183]:
result = permutation_importance(best_rf_model, X_test_scaled, y_test, n_repeats=10, random_state=0, n_jobs=-1)

In [184]:
imp = pd.DataFrame(result['importances_mean'], index=feature_cols)

In [185]:
imp

Unnamed: 0,0
Mean_counts_5p_10pct,-0.000552
Mean_counts_3p_25pct,-0.00267
Mean_counts_interior,0.004052
Insertion_index,0.014917
Fraction_zeros,0.026519
Fraction_above_thresh,-0.004328
Median_counts,0.000552
Upper25,0.002026
Lower25,0.000368
Zeros_interval,-0.000921
