### Notebook implementing a differential-expression (DEG) model for multi-condition experiments

- **Developed by**: Carlos Talavera-López Ph.D
- **Institute of Computational Biology - Computational Health Centre - Helmholtz Munich**
- v230501

### Load required modules

In [None]:
import jax
import anndata
import numpyro
import numpy as np
import scanpy as sc
from jax import random
import jax.numpy as jnp
from numpyro import handlers
import numpyro.optim as optim
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO, autoguide
from numpyro.infer.initialization import init_to_median

### Set up the working environment

In [None]:
# Configure scanpy for this session: verbose logging, a printout of the
# package versions in use (for reproducibility), and figure defaults.
sc.settings.verbosity = 3
sc.logging.print_versions()
sc.settings.set_figure_params(
    dpi=180,
    dpi_save=300,
    color_map='magma_r',
    vector_friendly=True,
    format='svg',
)

### Read in the dataset

In [None]:
# Load the AnnData object stored on disk; leaving the bare name as the last
# expression makes the notebook display its summary.
adata = sc.read_h5ad(filename='./data/hca_heart_immune_download.h5ad')
adata

In [None]:
# Fail early if the cell metadata used later (encoding and model covariates)
# is absent, naming every missing column at once. An explicit raise is used
# instead of `assert` so the check still runs when Python is started with -O.
required_obs_columns = ('cell_states', 'donor', 'cell_source')
missing_obs_columns = [column for column in required_obs_columns if column not in adata.obs.columns]
if missing_obs_columns:
    raise KeyError(f"adata.obs is missing required column(s): {missing_obs_columns}")

In [None]:
sc.pp.normalize_total(adata, target_sum = 1e6, exclude_highly_expressed = True)
sc.pp.log1p(adata)

### Calculate HVGs

In [None]:
adata_raw = adata.copy()
adata.layers['counts'] = adata.X.copy()

sc.pp.highly_variable_genes(
    adata,
    flavor = "seurat_v3",
    n_top_genes = 1500,
    layer = "counts",
    batch_key = "donor",
    subset = True
)

adata

### Prepare data for a `NumPyro` model

- Encode categorical variables as integers

In [None]:
# Integer-encode the categorical covariates so they can index the per-level
# effect vectors in the NumPyro model, and record how many levels each has.
# Each column is cast to 'category' once and reused for both the codes and the
# level counts, so the two always agree. This works whether or not the
# column is already categorical.
cell_states_cat = adata.obs['cell_states'].astype('category')
donor_cat = adata.obs['donor'].astype('category')
cell_source_cat = adata.obs['cell_source'].astype('category')

adata.obs['cell_states_encoded'] = cell_states_cat.cat.codes
adata.obs['donor_encoded'] = donor_cat.cat.codes
adata.obs['cell_source_encoded'] = cell_source_cat.cat.codes

# pandas encodes missing values as code -1. JAX indexing would silently map
# -1 to the last level, so refuse to continue instead.
for encoded_column in ('cell_states_encoded', 'donor_encoded', 'cell_source_encoded'):
    if (adata.obs[encoded_column] < 0).any():
        raise ValueError(f"adata.obs[{encoded_column!r}] contains missing values (code -1)")

n_genes = adata.shape[1]
n_donors = len(donor_cat.cat.categories)
n_protocols = len(cell_source_cat.cat.categories)
n_cell_states = len(cell_states_cat.cat.categories)

- Define a linear regression model that accounts for donor and protocol effects

In [None]:
def linear_regression_model(X, cell_states, donor, protocol, y = None):
    """NumPyro linear model of a per-cell response with gene, cell-state, donor and protocol terms.

    The mean for cell ``c`` is::

        beta0 + X[c] @ beta_genes
              + beta_cell_states[cell_states[c]]
              + beta_donors[donor[c]]
              + beta_protocols[protocol[c]]

    and the response is Normal around that mean with a shared ``sigma``.

    Args:
        X: matrix with one row per cell and ``n_genes`` columns (it is
            multiplied by the length-``n_genes`` vector ``beta_genes``).
        cell_states: integer cell-state code per cell, in ``[0, n_cell_states)``.
        donor: integer donor code per cell, in ``[0, n_donors)``.
        protocol: integer protocol code per cell, in ``[0, n_protocols)``.
        y: observed response per cell. When ``None``, the "obs" site is not
            conditioned on data.

    Reads the module-level globals ``n_genes``, ``n_cell_states``,
    ``n_donors`` and ``n_protocols``.
    """
    sigma = numpyro.sample("sigma", dist.Exponential(1.))
    beta0 = numpyro.sample("beta0", dist.Normal(0., 1.))
    beta_genes = numpyro.sample("beta_genes", dist.Normal(0., 1.), sample_shape = (n_genes,))

    # Cell-state effects. These are the coefficients from which cell-state
    # differences are read; without this term `cell_states` would be unused.
    with numpyro.plate("plate_cell_states", n_cell_states):
        beta_cell_states = numpyro.sample("beta_cell_states", dist.Normal(0., 1.))

    with numpyro.plate("plate_donors", n_donors):
        beta_donors = numpyro.sample("beta_donors", dist.Normal(0., 1.))

    with numpyro.plate("plate_protocols", n_protocols):
        beta_protocols = numpyro.sample("beta_protocols", dist.Normal(0., 1.))

    mean_expression = (
        beta0
        + jnp.matmul(X, beta_genes)
        + beta_cell_states[cell_states]
        + beta_donors[donor]
        + beta_protocols[protocol]
    )

    # Cells are conditionally independent observations given the parameters.
    with numpyro.plate("plate_cells", mean_expression.shape[0]):
        numpyro.sample("obs", dist.Normal(mean_expression, sigma), obs=y)


- Prepare data for inference

In [None]:
# Model inputs as plain arrays. adata.X is densified because JAX cannot
# consume scipy sparse matrices (and .h5ad files often store X sparse).
import scipy.sparse

X = adata.X.toarray() if scipy.sparse.issparse(adata.X) else np.asarray(adata.X)
cell_states = adata.obs['cell_states_encoded'].values
donor = adata.obs['donor_encoded'].values
protocol = adata.obs['cell_source_encoded'].values

- Create the guide and SVI objects

In [None]:
# Automatic Normal variational guide over the model's latent sites, with
# locations initialised by `init_to_median`.
guide = autoguide.AutoNormal(linear_regression_model, init_loc_fn = init_to_median)
# `Adam` on its own is never imported in this notebook (only `numpyro.optim as
# optim` is), so reach it through the module to avoid a NameError.
svi = SVI(model=linear_regression_model, guide=guide, optim=optim.Adam(0.01), loss=Trace_ELBO())


- Train the model

In [None]:
# SVI.update is functional: it takes the current SVI state and returns
# (new_state, loss). The state must first be created with SVI.init. The step
# is jit-compiled once so each epoch does not re-trace the model.
num_epochs = 1000
svi_state = svi.init(random.PRNGKey(0), X, cell_states, donor, protocol)
svi_update = jax.jit(svi.update)

for epoch in range(num_epochs):
    # NOTE(review): no `y` is passed, so the model's "obs" site is not
    # conditioned on data and appears to be treated as a latent by the
    # AutoNormal guide — confirm this is intended.
    svi_state, loss = svi_update(svi_state, X, cell_states, donor, protocol)
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, loss: {float(loss):.2f}")

# Fitted variational parameters, for downstream inspection.
params = svi.get_params(svi_state)

- Compile the model

In [None]:
# NumPyro models are plain Python functions, so there is no Keras-style
# `compile` step. The optimizer and loss are attached to the model by the
# `SVI` object. `tf` is not imported in this notebook, and
# `linear_regression_model` has no `.compile` method, so only a NumPyro
# optimizer is kept here, e.g. for reuse as `SVI(..., optim=optimizer, ...)`.
optimizer = optim.Adam(step_size = 0.01)


### Train the model and find differentially expressed genes

- Train the model for each gene and store the learned weights for cell types

In [None]:
def create_linear_regression_model(num_features):
    """Build and compile a single-output linear Keras regressor.

    The network is an input layer of width ``num_features`` feeding one
    Dense unit with no activation (an affine map ``x @ W + b``). It is
    compiled with the 'adam' optimizer and mean-squared-error loss.

    NOTE(review): relies on a module-level ``tf`` (TensorFlow); no cell of
    this notebook imports it.
    """
    regressor = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=(num_features,)),
            tf.keras.layers.Dense(1, activation=None),
        ]
    )
    regressor.compile(loss='mean_squared_error', optimizer='adam')
    return regressor

In [None]:
# Fit one linear regressor per gene and keep the coefficients of the first
# `num_cell_types` input features.
# NOTE(review): `gene_expression_matrix` (one sparse row per gene),
# `cell_types` and `feature_matrix` are not defined in any cell of this
# notebook. This cell also assumes the leading columns of feature_matrix are
# the cell-type indicators, in the order of cell_types.columns — confirm.
cell_type_weights = []
num_genes = gene_expression_matrix.shape[0]
num_cell_types = cell_types.shape[1]
num_pcs = 50  # currently unused

for gene_index in range(num_genes):
    gene_expression = gene_expression_matrix[gene_index].toarray().reshape(-1, 1)

    # A separate name is used so the NumPyro `linear_regression_model`
    # defined earlier is not overwritten by a Keras model.
    gene_regressor = create_linear_regression_model(feature_matrix.shape[1])
    gene_regressor.fit(feature_matrix, gene_expression, epochs=10, verbose=0)

    # The Dense kernel has shape (num_features, 1). Take its single output
    # column so cell_type_weights becomes (num_genes, num_cell_types) rather
    # than (num_genes, num_cell_types, 1).
    learned_weights = gene_regressor.get_weights()[0]
    cell_type_weights.append(learned_weights[:num_cell_types, 0])

cell_type_weights = np.array(cell_type_weights)

In [None]:
# Find differentially expressed genes for each cell type: rank genes by the
# magnitude of their fitted cell-type coefficient and keep the top ones.
# NOTE(review): assumes the rows of cell_type_weights follow adata.var_names
# order — confirm.
n_top_de_genes = 10  # adjust the number of top genes reported per cell type

differentially_expressed_genes = {}
for cell_type_index, cell_type_name in enumerate(cell_types.columns[:num_cell_types]):
    # ravel() tolerates a trailing singleton axis (weights stored as
    # (num_genes, num_cell_types, 1)). Without it, argsort would sort along
    # that size-1 axis and return all-zero indices.
    cell_type_specific_weights = np.ravel(cell_type_weights[:, cell_type_index])
    top_gene_indices = np.argsort(-np.abs(cell_type_specific_weights))[:n_top_de_genes]
    top_gene_names = adata.var_names[top_gene_indices]
    differentially_expressed_genes[cell_type_name] = top_gene_names

# Print the differentially expressed genes for each cell type
for cell_type, genes in differentially_expressed_genes.items():
    print(f"{cell_type}: {', '.join(genes)}")
