### Notebook for the implementation of a MAST-like model for DEG using `PyTorch`

- **Developed by**: Carlos Talavera-López Ph.D
- **Institute of Computational Biology - Computational Health Centre - Helmholtz Munich**
- v230501

### Import required packages

In [1]:
import torch
import anndata
import numpy as np
import pandas as pd
import scanpy as sc

from tqdm import tqdm
import torch.nn as nn
from tqdm import trange

import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

### Set up working environment

In [2]:
np.random.seed(1769)

sc.settings.verbosity = 3
sc.logging.print_versions()
sc.settings.set_figure_params(dpi = 180, color_map = 'magma_r', dpi_save = 300, vector_friendly = True, format = 'svg')

-----
anndata     0.9.1
scanpy      1.9.3
-----
PIL                 9.4.0
appnope             0.1.3
asttokens           NA
backcall            0.2.0
beta_ufunc          NA
binom_ufunc         NA
cffi                1.15.1
colorama            0.4.6
comm                0.1.3
cycler              0.10.0
cython_runtime      NA
dateutil            2.8.2
debugpy             1.6.7
decorator           5.1.1
executing           1.2.0
gmpy2               2.1.2
h5py                3.8.0
hypergeom_ufunc     NA
igraph              0.10.4
invgauss_ufunc      NA
ipykernel           6.22.0
jedi                0.18.2
joblib              1.2.0
kiwisolver          1.4.4
leidenalg           0.9.1
llvmlite            0.39.1
matplotlib          3.7.1
mpl_toolkits        NA
mpmath              1.3.0
natsort             8.3.1
nbinom_ufunc        NA
ncf_ufunc           NA
nct_ufunc           NA
ncx2_ufunc          NA
numba               0.56.4
numpy               1.23.5
packaging           23.1
pandas          

### Read in dataset

In [3]:
adata = sc.read_h5ad('./data/Marburg_All_ctl230404_leiden_states.raw.h5ad')
adata

AnnData object with n_obs × n_vars = 97573 × 27208
    obs: 'sex', 'age', 'ethnicity', 'PaCO2', 'donor', 'infection', 'disease', 'SMK', 'illumina_stimunr', 'bd_rhapsody', 'n_genes', 'doublet_scores', 'predicted_doublets', 'batch', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'percent_mt2', 'n_counts', 'percent_chrY', 'XIST-counts', 'S_score', 'G2M_score', 'condition', 'sample_group', '_scvi_batch', '_scvi_labels', 'IAV_score', 'group', 'C_scANVI', 'cell_type', 'leiden', 'leiden_states'
    var: 'mt', 'ribo', 'n_cells_by_counts-V1', 'mean_counts-V1', 'pct_dropout_by_counts-V1', 'total_counts-V1', 'n_cells_by_counts-V2', 'mean_counts-V2', 'pct_dropout_by_counts-V2', 'total_counts-V2', 'n_cells_by_counts-V3', 'mean_counts-V3', 'pct_dropout_by_counts-V3', 'total_counts-V3', 'n_cells_by_counts-V4', 'mean_counts-V4', 'pct_dropout_by_counts-V4', 'total_counts-V4', 'n_cells_by_counts-V5', 'mean_counts-V5', 'pct_dropout_by_coun

### Calculate Highly Variable Genes

In [4]:
adata_raw = adata.copy()
adata.layers['counts'] = adata.X.copy()

sc.pp.highly_variable_genes(
    adata,
    flavor = "seurat_v3",
    n_top_genes = 1500,
    layer = "counts",
    batch_key = "donor",
    subset = True
)

adata

If you pass `n_top_genes`, all cutoffs are ignored.
extracting highly variable genes
--> added
    'highly_variable', boolean vector (adata.var)
    'highly_variable_rank', float vector (adata.var)
    'means', float vector (adata.var)
    'variances', float vector (adata.var)
    'variances_norm', float vector (adata.var)


AnnData object with n_obs × n_vars = 97573 × 1500
    obs: 'sex', 'age', 'ethnicity', 'PaCO2', 'donor', 'infection', 'disease', 'SMK', 'illumina_stimunr', 'bd_rhapsody', 'n_genes', 'doublet_scores', 'predicted_doublets', 'batch', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'percent_mt2', 'n_counts', 'percent_chrY', 'XIST-counts', 'S_score', 'G2M_score', 'condition', 'sample_group', '_scvi_batch', '_scvi_labels', 'IAV_score', 'group', 'C_scANVI', 'cell_type', 'leiden', 'leiden_states'
    var: 'mt', 'ribo', 'n_cells_by_counts-V1', 'mean_counts-V1', 'pct_dropout_by_counts-V1', 'total_counts-V1', 'n_cells_by_counts-V2', 'mean_counts-V2', 'pct_dropout_by_counts-V2', 'total_counts-V2', 'n_cells_by_counts-V3', 'mean_counts-V3', 'pct_dropout_by_counts-V3', 'total_counts-V3', 'n_cells_by_counts-V4', 'mean_counts-V4', 'pct_dropout_by_counts-V4', 'total_counts-V4', 'n_cells_by_counts-V5', 'mean_counts-V5', 'pct_dropout_by_count

In [5]:
sc.pp.normalize_total(adata, exclude_highly_expressed = True, target_sum = 1e6)
sc.pp.log1p(adata)

normalizing counts per cell The following highly-expressed genes are not considered during normalization factor computation:
['AC105402.3', 'ACTG1', 'ADGRF5', 'AGBL4', 'AGR2', 'AKR1B1', 'ALDH1A1', 'ALDH3A1', 'ANXA1', 'ANXA2', 'AQP3', 'AREG', 'ASCL3', 'ATF3', 'ATP12A', 'AZGP1', 'B2M', 'BIRC3', 'BPIFA1', 'BPIFB1', 'C15orf48', 'C3', 'CA12', 'CAMK2N1', 'CAV1', 'CAVIN1', 'CCL20', 'CCND1', 'CD74', 'CDC20B', 'CDH6', 'CEACAM6', 'CLDN4', 'CLU', 'COL1A1', 'CRYM', 'CSTA', 'CXCL1', 'CXCL10', 'CXCL11', 'CXCL14', 'CXCL17', 'CXCL2', 'CXCL6', 'CXCL8', 'CYB5A', 'DST', 'ELF3', 'FABP5', 'FGFBP1', 'FN1', 'GADD45B', 'GAPDH', 'GDF15', 'GJB2', 'GRIP1', 'GSN', 'GSTA2', 'H1F0', 'HLA-B', 'HMOX1', 'HSP90AA1', 'IFI6', 'IFIT1', 'IFIT2', 'IFIT3', 'IGFBP3', 'IGFBP5', 'IGFBP7', 'IL17C', 'IL1B', 'IRF1', 'ISG15', 'ITGB1', 'KLK10', 'KLK7', 'KRT13', 'KRT15', 'KRT17', 'KRT4', 'KRT5', 'KRT6A', 'KRT6B', 'KRT7', 'LAMA3', 'LAMB3', 'LAMC2', 'LCN2', 'LTF', 'MALAT1', 'MAP1B', 'MMP10', 'MMP13', 'MMP7', 'MMP9', 'MSMB', 'MT1G', 'MT

### Prepare data for analysis

- One-hot encode cell types, conditions, and batches

In [6]:

cell_types = pd.get_dummies(adata.obs['leiden_states'], prefix = 'leiden_states')
conditions = pd.get_dummies(adata.obs['group'], prefix = 'group')
batches = pd.get_dummies(adata.obs['batch'], prefix = 'batch')

-  Concatenate the one-hot encoded data

In [7]:
metadata = pd.concat([cell_types, conditions, batches], axis = 1)

-  Convert the metadata to a tensor

In [8]:
metadata_tensor = torch.tensor(metadata.values, dtype = torch.float32)

-  Combine the expression data with the metadata

In [9]:
X_dense = adata.X.toarray() if adata.X.__class__.__name__ == 'csr_matrix' else adata.X
X = torch.tensor(X_dense, dtype = torch.float32)
X_combined = torch.cat([X, metadata_tensor], dim = 1)

### Define the model and optimizer

In [10]:
class HurdleModel(nn.Module):
    def __init__(self, n_genes, n_metadata, n_outputs):
        super(HurdleModel, self).__init__()
        self.logistic_regression = nn.Sequential(
            nn.Linear(n_genes + n_metadata, n_outputs),
            nn.Sigmoid()
        )
        self.linear_regression = nn.Sequential(
            nn.Linear(n_genes + n_metadata, n_outputs)
        )

    def forward(self, x):
        zero_inflation = self.logistic_regression(x)
        positive_expression = self.linear_regression(x)
        return zero_inflation, positive_expression

In [19]:
class MAST(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MAST, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        zero_inflation = torch.sigmoid(self.fc3(x))
        positive_expression = F.relu(self.fc3(x))
        return zero_inflation, positive_expression
        

In [12]:
n_genes = X.shape[1]
n_metadata = metadata_tensor.shape[1]
n_outputs = len(adata.obs['leiden_states'].cat.categories) * len(adata.obs['group'].cat.categories)
model = HurdleModel(n_genes, n_metadata, n_outputs)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epochs = 50

### Train the model on the combined input data

In [13]:
def train(model, optimizer, X_combined, targets_zero_inflation, epochs, batch_size):
    for epoch in range(epochs):
        permutation = torch.randperm(X_combined.size()[0])
        
        for i in range(0, X_combined.size()[0], batch_size):
            indices = permutation[i:i+batch_size]
            batch_x, batch_y = X_combined[indices], targets_zero_inflation[indices]
            
            optimizer.zero_grad()
            zero_inflation, positive_expression = model(batch_x)
            
            loss_zero_inflation = nn.BCELoss()(zero_inflation, batch_y)
            loss_positive_expression = nn.MSELoss()(positive_expression, batch_y)
            
            loss = loss_zero_inflation + loss_positive_expression
            loss.backward()
            optimizer.step()


In [14]:
def calculate_zero_inflation_targets(adata, threshold=0.5):
    zero_inflation_targets = []

    for i, row in tqdm(adata.obs.iterrows(), total=adata.obs.shape[0], desc="Calculating zero inflation targets", bar_format="{l_bar}%s{bar}%s{r_bar}" % ('\033[1;34m', '\033[0m')):
        cell_type = row['leiden_states']
        condition = row['group']

        # Filter data based on cell type and condition
        subset = adata[(adata.obs['leiden_states'] == cell_type) & (adata.obs['group'] == condition), :]

        # Calculate the proportion of zero values for each gene
        proportions = (subset.X.A != 0).mean(axis = 0)

        # Determine if a gene is considered "zero-inflated" based on the threshold
        zero_inflated = (1 - proportions > threshold).astype(float)

        zero_inflation_targets.append(zero_inflated)

    # Stack the zero-inflation targets into a single tensor
    zero_inflation_targets = torch.tensor(np.vstack(zero_inflation_targets)[:, :adata.shape[1]], dtype = torch.float32)

    return zero_inflation_targets

In [15]:
targets_zero_inflation = calculate_zero_inflation_targets(adata, threshold = 0.5)

Calculating zero inflation targets: 100%|[1;34m██████████[0m| 97573/97573 [17:55<00:00, 90.69it/s] 


### Definie training loop

In [16]:
def train(model, optimizer, X_combined, targets_zero_inflation, epochs, batch_size):
    for epoch in trange(epochs, desc="Training", bar_format="{l_bar}%s{bar}%s{r_bar}" % ('\033[1;35m', '\033[0m')):
        permutation = torch.randperm(X_combined.size()[0])

        for i in range(0, X_combined.size()[0], batch_size):
            indices = permutation[i:i+batch_size]
            batch_x, batch_y = X_combined[indices], targets_zero_inflation[indices]

            optimizer.zero_grad()
            zero_inflation, positive_expression = model(batch_x)

            loss_zero_inflation = nn.BCELoss()(zero_inflation, batch_y)
            loss_positive_expression = nn.MSELoss()(positive_expression, batch_y)

            loss = loss_zero_inflation + loss_positive_expression
            loss.backward()
            optimizer.step()

### Train model

In [17]:
n_genes = adata.shape[1]
print(n_genes)

1500


In [21]:
n_input = X_combined.size()[1]
model = MASTModel(n_input, n_genes)

NameError: name 'MASTModel' is not defined

In [None]:
epochs = 100
batch_size = 64
train(model, optimizer, X_combined, targets_zero_inflation, epochs, batch_size)

### Evaluate model performance

In [None]:
X_train, X_test, y_train, y_test = train_test_split(X_combined, targets_zero_inflation, test_size = 0.3, random_state = 1712)

In [None]:
def train(model, optimizer, X_train, y_train, epochs, batch_size):
    loss_history = []
    
    for epoch in trange(epochs, desc="Training", bar_format="{l_bar}%s{bar}%s{r_bar}" % ('\033[1;34m', '\033[0m')):
        epoch_loss = 0
        permutation = torch.randperm(X_train.size()[0])

        for i in range(0, X_train.size()[0], batch_size):
            indices = permutation[i:i+batch_size]
            batch_x, batch_y = X_train[indices], y_train[indices]

            optimizer.zero_grad()
            zero_inflation, positive_expression = model(batch_x)

            loss_zero_inflation = nn.BCELoss()(zero_inflation, batch_y)
            loss_positive_expression = nn.MSELoss()(positive_expression, batch_y)

            loss = loss_zero_inflation + loss_positive_expression
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()

        epoch_loss /= (X_train.size()[0] / batch_size)
        loss_history.append(epoch_loss)

    return loss_history


In [None]:
input_size = X_combined.shape[1]
hidden_size = 128
output_size = y_train.shape[1]

model = MAST(input_size, hidden_size, output_size)

In [None]:
epochs = 100
batch_size = 64
loss_history = train(model, optimizer, X_train, y_train, epochs, batch_size)

In [None]:
plt.plot(loss_history)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.show()

In [None]:
def predict(model, X):
    with torch.no_grad():
        zero_inflation, positive_expression = model(X)
    return zero_inflation, positive_expression

- Get predictions for the test set

In [None]:
zero_inflation_preds, positive_expression_preds = predict(model, X_test)

- Convert the predicted probabilities to binary values using a threshold of 0.5

In [None]:
threshold = 0.5
y_test_bin = (y_test.numpy() > threshold).astype(int)
zero_inflation_preds_bin = (zero_inflation_preds.numpy() > threshold).astype(int)

- Calculate performance metrics

In [None]:
accuracy = accuracy_score(y_test_bin, zero_inflation_preds_bin)
precision = precision_score(y_test_bin, zero_inflation_preds_bin, average='macro', zero_division=0)
recall = recall_score(y_test_bin, zero_inflation_preds_bin, average='macro', zero_division=0)
f1 = f1_score(y_test_bin, zero_inflation_preds_bin, average='macro', zero_division=0)

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)