### Notebook to build a Gradient Boosting Machine (GBM) classifier for infected, bystander and control cells

- **Developed by**: Carlos Talavera-López Ph.D
- **Würzburg Institute for Systems Immunology & Faculty of Medicine, Julius-Maximilians-Universität Würzburg**
- v230813

### Import required modules

In [None]:
import anndata
import numpy as np
import pandas as pd
import scanpy as sc
import xgboost as xgb
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

### Set up working environment

In [None]:
# Scanpy logging verbosity, plus a printout of the installed package versions.
sc.settings.verbosity = 3
sc.logging.print_versions()

# Figure defaults: display/save resolution, default colormap and SVG output.
sc.settings.set_figure_params(
    dpi=180,
    dpi_save=300,
    color_map='RdPu',
    format='svg',
    vector_friendly=True,
)

### Read in Healthy-CTRL dataset

In [None]:
# Load the Healthy-CTRL AnnData object and display its summary.
h5ad_path = '../data/Marburg_cell_states_locked_scANVI_ctl230813.raw.h5ad'
adata = sc.read_h5ad(h5ad_path)
adata

### Label cells as infected, bystander/unknown or control (non-infected)

In [None]:
def classify_cells(row, var_names=None, prefix='NC_', max_bystander_genes=2):
    """Label one cell from how many viral (IAV) genes it expresses.

    Viral genes are the features whose name starts with ``prefix``
    (``'NC_'`` by default); a gene counts as expressed when its value in
    ``row`` is greater than 0.

    Parameters
    ----------
    row : array-like
        Expression values of a single cell, aligned with ``var_names``.
    var_names : pandas.Index or sequence of str, optional
        Feature names matching ``row``. Defaults to the global
        ``adata.var_names``.
    prefix : str
        Name prefix that identifies viral genes.
    max_bystander_genes : int
        Largest number of expressed viral genes still labelled 'bystander'.

    Returns
    -------
    str
        'infected' if every viral gene is expressed (and at least one viral
        gene exists), 'bystander' if between 1 and ``max_bystander_genes``
        viral genes are expressed, otherwise 'control'.
    """
    if var_names is None:
        var_names = adata.var_names
    # Boolean mask over features selecting the viral genes.
    iav_mask = np.asarray(pd.Index(var_names).str.startswith(prefix), dtype=bool)
    n_iav_genes = int(iav_mask.sum())
    values = np.asarray(row).ravel()
    n_expressed = int(np.count_nonzero(values[iav_mask] > 0))

    # Require at least one viral gene: otherwise 0 == 0 would label every
    # cell as 'infected'.
    if n_iav_genes > 0 and n_expressed == n_iav_genes:
        return 'infected'
    if 0 < n_expressed <= max_bystander_genes:
        return 'bystander'
    # NOTE(review): cells expressing more than max_bystander_genes but not all
    # viral genes also fall through to 'control' -- confirm this is intended.
    return 'control'

In [None]:
# Label every cell with classify_cells. adata.X is densified a slice of rows
# at a time: a single .toarray() on the whole matrix allocates a dense
# n_obs x n_vars array, which can exhaust memory on whole-transcriptome data.
row_chunk_size = 10_000
cell_labels = []
for start in range(0, adata.n_obs, row_chunk_size):
    row_block = adata.X[start:start + row_chunk_size]
    # Sparse matrices need .toarray(); a dense X is used as-is.
    if hasattr(row_block, 'toarray'):
        row_block = row_block.toarray()
    else:
        row_block = np.asarray(row_block)
    cell_labels.extend(classify_cells(row) for row in row_block)
adata.obs['classification'] = cell_labels

### Split data into training and test sets

In [None]:
# Features: the full expression matrix; target: the per-cell label.
# NOTE(review): the labels are computed from the NC_ (viral) genes, which are
# still columns of X, so the model can recover the label directly from those
# features (target leakage). Consider excluding them if host-gene signal is
# the goal.
X = adata.X
y = adata.obs['classification'].values

# LabelEncoder assigns integer codes in sorted label order:
# label_encoder.classes_[i] is the label encoded as i.
label_encoder = LabelEncoder()
y_int = label_encoder.fit_transform(y)

# Stratify so that each class keeps its proportion in both splits; the
# per-class ROC curves need every class present in the test set.
X_train, X_test, y_train_int, y_test_int = train_test_split(
    X, y_int, test_size=0.2, random_state=1712, stratify=y_int
)

In [None]:
# Number of cells per label, to check class balance.
class_counts = adata.obs['classification'].value_counts()
class_counts

### Train GBM classifier

In [None]:
# Wrap each split, together with its integer labels, in an XGBoost DMatrix.
dtrain = xgb.DMatrix(data=X_train, label=y_train_int)
dtest = xgb.DMatrix(data=X_test, label=y_test_int)

In [None]:
# Parameters for XGBoost
param = {
    'max_depth': 6,
    'objective': 'multi:softprob',  # output one probability per class
    # Take the class count from the fitted encoder instead of hard-coding 3,
    # so it always matches the labels actually present.
    'num_class': len(label_encoder.classes_),
    # Set explicitly because the loss-curve cell reads
    # evals_result[...]['mlogloss'].
    'eval_metric': 'mlogloss',
}
num_round = 20  # number of boosting rounds

In [None]:
# Evaluation sets reported after every boosting round. The test set is
# reported as 'eval'. This list is the one actually passed to xgb.train; it
# previously disagreed with the inline list and was never used.
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
# Record per-round metrics as well, so the loss history is available without
# retraining.
evals_result = {}
bst = xgb.train(param, dtrain, num_round, evals=watchlist,
                evals_result=evals_result, verbose_eval=True)
preds = bst.predict(dtest)  # class probabilities for the test set

In [None]:
# Class probabilities for the test set (one column per encoded class).
preds_prob = bst.predict(dtest)

# LabelEncoder assigned code i to label_encoder.classes_[i] (labels in sorted
# order), so take curve names from the encoder. A hand-written list such as
# ['infected', 'bystander', 'control'] mislabels the curves.
classes = list(label_encoder.classes_)
class_ids = list(range(len(classes)))

# One-vs-rest indicator matrix for the ROC curves.
y_test_bin = label_binarize(y_test_int, classes=class_ids)
if y_test_bin.shape[1] == 1:
    # label_binarize returns a single column for two classes; expand it to
    # one column per class so it lines up with preds_prob.
    y_test_bin = np.hstack([1 - y_test_bin, y_test_bin])

# Compute the ROC curve and AUC for each class
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in class_ids:
    fpr[i], tpr[i], _ = roc_curve(y_test_bin[:, i], preds_prob[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Plot the ROC curves; colours repeat if there are more classes than colours.
colors = ['blue', 'red', 'green']
for i in class_ids:
    plt.plot(fpr[i], tpr[i], color=colors[i % len(colors)], lw=2,
             label=f'{classes[i]} (area = {roc_auc[i]:0.2f})')

plt.plot([0, 1], [0, 1], 'k--', lw=2)  # chance diagonal
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC curve for cell classification')
plt.legend(loc="lower right")
plt.show()


In [None]:
# Train with the same settings as before, this time recording the per-round
# evaluation metrics into evals_result.
# NOTE(review): this repeats the earlier xgb.train call with identical inputs.
evals_result = {}
bst = xgb.train(
    param,
    dtrain,
    num_boost_round=num_round,
    evals=[(dtest, 'eval'), (dtrain, 'train')],
    evals_result=evals_result,
    verbose_eval=True,
)

In [None]:
# Show which metrics were recorded for the training set.
train_history = evals_result['train']
print(train_history.keys())

In [None]:
# Per-round multi-class log loss on the training and test ('eval') sets.
train_loss = evals_result['train']['mlogloss']
test_loss = evals_result['eval']['mlogloss']
epochs = len(train_loss)
x_axis = range(0, epochs)

# Use the Axes API throughout instead of mixing it with pyplot state calls.
fig, ax = plt.subplots(figsize=(5, 5))
ax.plot(x_axis, train_loss, label='Train')
ax.plot(x_axis, test_loss, label='Test')
ax.set_xlabel('Boosting round')
ax.set_ylabel('Multi Class Log Loss')
ax.set_title('XGBoost Multi Class Log Loss')
ax.legend()
plt.show()

### Visualise gene importance

In [None]:
# 'weight' importance per feature, then (feature, score) pairs ordered from
# highest to lowest score.
importance = bst.get_score(importance_type='weight')
sorted_importance = sorted(
    importance.items(),
    key=lambda feature_score: feature_score[1],
    reverse=True,
)

In [None]:
def _feature_to_gene(feature, gene_names):
    """Map an XGBoost feature key to its gene name.

    Keys of the form 'f<column index>' (XGBoost's names when none are given)
    map to gene_names[index]. Any other key is returned unchanged, on the
    assumption that it is already a feature name.
    """
    index = feature[1:]
    if feature.startswith('f') and index.isdigit():
        return gene_names[int(index)]
    return feature


gene_names = adata.var_names.tolist()
sorted_importance_with_names = [
    (_feature_to_gene(key, gene_names), value) for key, value in sorted_importance
]

In [None]:
N = 20  # number of top features to display, change as needed
top_genes = sorted_importance_with_names[:N]
names, values = zip(*top_genes)

# Horizontal bar chart; the y axis is inverted so the top-ranked gene is first.
importance_fig, importance_ax = plt.subplots(figsize=(8, 8))
importance_ax.barh(names, values)
importance_ax.invert_yaxis()
importance_ax.set_xlabel('Importance Score')
importance_ax.set_title(f'Top {N} Gene Importance')
plt.show()


In [None]:
# UMAP coloured by metadata and by the new classification label, four panels per row.
umap_color_keys = ['group', 'disease', 'infection', 'classification', 'cell_states']
sc.pl.umap(
    adata,
    color=umap_color_keys,
    frameon=False,
    size=0.4,
    legend_fontsize=5,
    ncols=4,
)