### Notebook for the identification of infected and bystander cells using a logistic regression classifier

- **Developed by**: Carlos Talavera-López Ph.D
- **Würzburg Institute for Systems Immunology & Faculty of Medicine, Julius-Maximilians-Universität Würzburg**
- v231212

### Import required modules

In [1]:
import anndata
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
from pywaffle import Waffle
from gprofiler import GProfiler
import matplotlib.pyplot as plt

from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import label_binarize
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

### Set up working environment

In [2]:
# Scanpy logging level 3 ('hint'), then record package versions for provenance
sc.settings.verbosity = 3
sc.logging.print_versions()
# Figure defaults: on-screen dpi and colormap; saved figures use SVG with dpi_save = 300
sc.settings.set_figure_params(
    dpi = 180,
    color_map = 'RdPu',
    dpi_save = 300,
    vector_friendly = True,
    format = 'svg',
)

-----
anndata     0.10.2
scanpy      1.9.5
-----
PIL                 10.0.0
asttokens           NA
backcall            0.2.0
certifi             2023.07.22
charset_normalizer  3.2.0
colorama            0.4.6
comm                0.1.4
cycler              0.10.0
cython_runtime      NA
dateutil            2.8.2
debugpy             1.6.7.post1
decorator           5.1.1
exceptiongroup      1.1.3
executing           1.2.0
gprofiler           1.0.0
h5py                3.9.0
idna                3.4
igraph              0.10.8
importlib_resources NA
ipykernel           6.25.1
ipywidgets          8.1.0
jedi                0.19.0
joblib              1.3.2
kiwisolver          1.4.5
leidenalg           0.10.1
llvmlite            0.40.1
louvain             0.8.1
matplotlib          3.7.2
mpl_toolkits        NA
natsort             8.4.0
numba               0.57.1
numexpr             2.8.7
numpy               1.24.4
packaging           23.1
pandas              2.1.1
parso               0.8.3
patsy     

### Read in dataset

In [3]:
# Load the locked cell-state object (".raw" in the file name suggests X holds raw counts — TODO confirm)
adata = sc.read_h5ad(filename = '../data/Marburg_cell_states_locked_ctl230901.raw.h5ad')
adata

### Identify viral genes and compute per-cell viral transcript counts

In [None]:
# Viral genes are the features whose name contains 'NC_' (presumably RefSeq accessions
# of the viral genome — confirm against the reference used for alignment).
viral_genes = [gene for gene in adata.var_names if 'NC_' in gene]
# Total viral transcripts per cell. With a sparse X, .sum(axis = 1) returns an
# (n_cells, 1) np.matrix; np.asarray(...).ravel() flattens it to 1-D (a no-op for a
# dense 1-D result) so obs receives a plain column.
adata.obs['viral_counts'] = np.asarray(adata[:, viral_genes].X.sum(axis = 1)).ravel()

In [None]:
# Mean viral transcript load in each experimental group
group_means = adata.obs.groupby('group')['viral_counts'].mean()

fig, ax = plt.subplots(figsize = (8, 8))
sns.barplot(x = group_means.index, y = group_means.values, palette = 'Dark2', ax = ax)
ax.set_xlabel('Group')
ax.set_ylabel('Average Viral Transcript Counts')
ax.set_title('Average Viral Transcripts per Group')
fig.tight_layout()
plt.show()

In [None]:
# Per-group density of total viral counts, each group normalised separately
fig, ax = plt.subplots(figsize = (12, 6))
sns.kdeplot(data = adata.obs, x = 'viral_counts', hue = 'group', common_norm = False, fill = True, ax = ax)
ax.set(
    title = 'Density of Viral Transcript Counts per Group',
    xlabel = 'Total Viral Transcript Counts',
    ylabel = 'Density',
    xlim = (0, 4000),
)
plt.show()

### Define infected cells based on viral counts

In [None]:
# Provisional infection call: strictly more than 300 viral transcripts in the cell
adata.obs['infected_status'] = adata.obs['viral_counts'].gt(300)
adata.obs['infected_status'].value_counts()

In [None]:
pd.crosstab(index = adata.obs['infected_status'], columns = adata.obs['group'], dropna = False)  # infection calls per group

### Identify infection signature using a logistic regression model

In [None]:
# Features: the full expression matrix; target: the boolean viral-count infection call.
# NOTE(review): X still contains the viral genes that define y (label leakage) —
# consider fitting on host genes only.
X, y = adata.X, adata.obs['infected_status']

In [None]:
# 70/30 split with a fixed seed for reproducibility (not stratified by label)
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size = 0.3,
    random_state = 1712,
)

In [None]:
# The default lbfgs budget (max_iter = 100) may stop before convergence on thousands of
# unscaled count features, leaving the coefficients used for gene ranking unreliable.
log_reg = LogisticRegression(max_iter = 1000)
log_reg.fit(X_train, y_train)

# Evaluate on the held-out split, which is otherwise never used.
y_pred = log_reg.predict(X_test)
print(f'Held-out accuracy: {accuracy_score(y_test, y_pred):.3f}')
print(classification_report(y_test, y_pred))

In [None]:
# Coefficients toward the positive class (classes_ are [False, True], so larger = more "infected").
gene_coefficients = pd.Series(log_reg.coef_[0], index = adata.var_names)
# Drop the viral genes: they define the label, so they dominate the ranking and would
# make the host "infection signature" circular.
host_coefficients = gene_coefficients[~gene_coefficients.index.isin(viral_genes)]
# Top 1500 host genes, most positive coefficient first.
important_genes = host_coefficients.sort_values(ascending = False).index[:1500]
important_genes_indices = adata.var_names.get_indexer(important_genes)
important_genes

### Score learned signature in all cells

In [None]:
sc.tl.score_genes(adata, gene_list = important_genes.tolist(), score_name = 'infection_signature_score')
# The score is continuous, so a crosstab would emit one row per distinct value;
# summarise its distribution per group instead.
adata.obs.groupby('group')['infection_signature_score'].describe()

In [None]:
# Per-group density of the learned infection signature score
fig, ax = plt.subplots(figsize = (12, 6))
sns.kdeplot(data = adata.obs, x = 'infection_signature_score', hue = 'group', common_norm = False, fill = True, ax = ax)
ax.set_title('Density of Infection Signature Score per Group')
ax.set_xlabel('Infection Signature Score')
ax.set_ylabel('Density')
ax.set_xlim(0, 50)
plt.show()

### Categorise cells based on viral counts and signature score

In [None]:
def classify_cells(row, high = None, low = None, infected_min_viral = 500, uninfected_max_viral = 300):
    """Assign a cell to 'Infected', 'Bystander' or 'Uninfected'.

    Rules, checked in order (score = infection_signature_score):
      1. viral_counts > infected_min_viral   and score > high  -> 'Infected'
      2. viral_counts < infected_min_viral   and score > high  -> 'Bystander'
      3. viral_counts < uninfected_max_viral and score <= low  -> 'Uninfected'
      4. anything else                                         -> 'Bystander'

    Parameters
    ----------
    row : mapping or pandas.Series
        Must provide 'viral_counts' and 'infection_signature_score'
        (e.g. a row of ``adata.obs`` when used with ``DataFrame.apply(axis=1)``).
    high, low : float, optional
        Signature-score cut-offs. When omitted, the notebook globals
        ``threshold_high`` / ``threshold_low`` are used (the original behaviour),
        so those globals must be defined before the function is called.
    infected_min_viral, uninfected_max_viral : float, optional
        Viral-count cut-offs; the defaults reproduce the original 500 / 300.

    Returns
    -------
    str
        One of 'Infected', 'Bystander', 'Uninfected'.
    """
    if high is None:
        high = threshold_high
    if low is None:
        low = threshold_low

    viral = row['viral_counts']
    score = row['infection_signature_score']

    if viral > infected_min_viral and score > high:
        return 'Infected'
    # NOTE(review): cells with exactly `infected_min_viral` counts, or with more than
    # that but score <= high, skip rules 1-2 and fall into the catch-all 'Bystander'
    # below — confirm this is the intended biology.
    if viral < infected_min_viral and score > high:
        return 'Bystander'
    if viral < uninfected_max_viral and score <= low:
        return 'Uninfected'
    return 'Bystander'

In [None]:
# Signature-score cut-offs, read as globals by classify_cells
threshold_high = 40
threshold_low = 10

adata.obs['final_classification'] = adata.obs.apply(classify_cells, axis = 'columns')
pd.crosstab(index = adata.obs['final_classification'], columns = adata.obs['group'], dropna = False)

In [None]:
# Rows (x-axis) are groups; columns (stacked segments / legend entries) are classifications.
crosstab_df = pd.crosstab(adata.obs['group'], adata.obs['final_classification'], dropna = False)

ax = crosstab_df.plot(kind = 'bar', stacked = True, figsize = (10, 6))
ax.set_title('Stacked Bar Plot of Classifications by Group')
# The x-axis label and legend title were previously swapped.
ax.set_xlabel('Group')
ax.set_ylabel('Count')
ax.legend(title = 'Classification')
plt.show()