# Notebook to run differential expression analysis using ´MAST´

**Created by** : **Srivalli Kolla**

**Devloped on** : July 19,2024 

**Last Modified** : July 19,2024 

**Institute of Systems Immunology, University of Wurzburg**

# Importing packages

In [2]:
import os
import glob
import scanpy as sc
import pandas as pd
import sc_toolbox
import scipy.io
import scanpy as sc
import matplotlib.pyplot as plt
import scipy.sparse as sparse
from rpy2 import robjects
from rpy2.robjects import pandas2ri
from rpy2.robjects.packages import importr

pandas2ri.activate()

MAST = importr('MAST')


In [3]:
sc.settings.verbosity = 3
sc.logging.print_versions()
sc.settings.set_figure_params(dpi = 300, color_map = 'RdPu', dpi_save = 300, vector_friendly = True, format = 'svg')

-----
anndata     0.8.0
scanpy      1.9.3
-----
PIL                 8.2.0
adjustText          1.2.0
argcomplete         NA
attr                23.2.0
backcall            0.2.0
backports           NA
backports_abc       NA
beta_ufunc          NA
binom_ufunc         NA
cffi                1.15.1
cloudpickle         2.2.1
colorama            0.4.6
cycler              0.10.0
cython_runtime      NA
dask                2022.02.0
dateutil            2.9.0
debugpy             1.6.3
decorator           5.1.1
entrypoints         0.4
fsspec              2023.1.0
h5py                3.7.0
ipykernel           6.16.2
ipython_genutils    0.2.0
jedi                0.19.1
jinja2              3.1.4
joblib              1.3.2
kiwisolver          1.4.4
llvmlite            0.39.1
markupsafe          2.1.5
matplotlib          3.5.3
more_itertools      NA
mpl_toolkits        NA
natsort             8.4.0
nbinom_ufunc        NA
numba               0.56.3
numpy               1.21.6
packaging           23.2
panda

# MAST

## Data Preparation

##### Steps followed

1. Data loading and subsetting
2. Normalization and log transformation
3. Generating individual files for genes, barcodes, metadata and a matrix having log counts

In [3]:
adata = sc.read_h5ad('/home/skolla/Github/hofmann_dmd/hofmann_dmd/data/concatenated_CMC-16_07_2024.h5ad')

desired_cell_states = ['vCM1', 'vCM2', 'vCM3', 'vCM4']
adata = adata[adata.obs['cell_states'].isin(desired_cell_states)]


sc.pp.filter_genes(adata, min_counts=10)
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)


sparse_matrix = sparse.csr_matrix(adata.X.T)
scipy.io.mmwrite('/home/skolla/Github/hofmann_dmd/hofmann_dmd/DE/de_results/concat_matrix.mtx', sparse_matrix)
genes = pd.DataFrame(adata.var.index, columns=["gene"])
genes.to_csv('/home/skolla/Github/hofmann_dmd/hofmann_dmd/DE/de_results/concat_genes.tsv', sep='\t', index=False, header=False)
barcodes = pd.DataFrame(adata.obs.index, columns=["barcode"])
barcodes.to_csv('/home/skolla/Github/hofmann_dmd/hofmann_dmd/DE/de_results/concat_barcodes.tsv', sep='\t', index=False, header=False)
metadata = adata.obs
metadata.to_csv('/home/skolla/Github/hofmann_dmd/hofmann_dmd/DE/de_results/concat_metadata.tsv', sep='\t', index=True, header=True)

filtered out 1773 genes that are detected in less than 10 counts
normalizing counts per cell
    finished (0:00:00)


  adata.var['n_counts'] = number
  next(self.gen)


## Defining Functions

##### Steps followed in defining functions

*Create_single_cell_assay*

For creating a single cell assay object which can be used by MAST for differential gene expression

1. Read the matrix, genes, and barcodes files
2. Ensure the matrix has correct row and column names by checking with colnames and riownames of matrix
3. Convert sparse matrix to dense matrix
4. Ensure metadata matches the expression matrix
5. Create the SingleCellAssay object
6. Compute number of expressed genes per cell and store as a column in sca object
7. Store columns of interest as factors
8. Create the SingleCellAssay object

In [4]:
# Function to subset the data to only include the target cell state and genotype
def subset_data(adata, target_cell_state, target_genotype, groupby_cell="cell_states", groupby_genotype="genotype"):
    # Subset the AnnData object to only include cells from the target cell state and genotype
    adata_subset = adata[(adata.obs[groupby_cell] == target_cell_state) & (adata.obs[groupby_genotype] == target_genotype)].copy()
    return adata_subset


# Function to convert the AnnData subset to SingleCellAssay object (sca) for MAST analysis
def convert_to_sca(adata_subset):
    print("Converting AnnData subset to SingleCellAssay object...")
    
    robjects.r('''
    library(MAST)
    FromMatrix <- function(counts, cData, fData) {
        sca <- FromMatrix(counts, cData, fData)
        return(sca)
    }
    ''')

    counts_dense = adata_subset.X.todense()
    cData = adata_subset.obs
    fData = adata_subset.var
    
    # Convert data to the appropriate R data structures
    r_counts = pandas2ri.py2rpy(pd.DataFrame(counts_dense))
    r_cData = pandas2ri.py2rpy(cData)
    r_fData = pandas2ri.py2rpy(fData)
    
    sca = robjects.r['FromMatrix'](r_counts, r_cData, r_fData)
    return sca


##### Steps followed in defining functions

*Find_de_MAST*

For performing differential gene expression analysis

1. Define and fit the model by considering n_genes_per_cell and genotype and 
2. Perform likelihood-ratio test for the condition of interest
3. Extract datatable from summary with log-fold changes and p-values
4. Crea te a list to store results generated in next steps
5. For each gene in a cellstate , extract log fold change using natural basic logarithm and p-value and FDR and store temporarily in alist with defined column and cellstate
6. Convert the coefficients to log2 base
7. False discovery rate calculation (Multiple testing correction)
8. Call the R function 'find_de' for dsifferential expression analysis and run on sca object

In [5]:
# Function to find differential expression using MAST
def find_de_MAST(sca):
    print("Finding differentially expressed genes...")
    robjects.r('''
    library(MAST)
    library(data.table)
    
    find_de <- function(sca) {
    zlmCond <- zlm(~ n_genes_per_cell + cell_states + genotype, sca)
    summaryCond <- summary(zlmCond, doLRT = TRUE)
    summaryDt <- summaryCond$datatable
    
    results <- list()
    
    cell_states <- unique(summaryDt[component == 'H', contrast])
    
    for (cell_state in cell_states) {
        if (cell_state != 'n_genes_per_cell'){
            cell_state_lfc <- summaryDt[contrast == cell_state & component == 'logFC', .(primerid, coef)]
            cell_state_p <- summaryDt[contrast == cell_state & component == 'H', .(primerid, `Pr(>Chisq)`)]
            tmp <- merge(cell_state_lfc, cell_state_p, by='primerid')
            tmp$log_fold_change <- tmp$coef / log(2)
            tmp$FDR <- p.adjust(tmp$`Pr(>Chisq)`, 'fdr')
            tmp$cell_states <- cell_state
            
            # Remove extra columns if any before renaming
            tmp <- tmp[, .(primerid, log_fold_change, `Pr(>Chisq)`, FDR, cell_states)]
                
            colnames(tmp) <- c('gene_id', 'log_fold_change', 'p_value', 'FDR', 'cell_states')

            
            results[[cell_state]] <- tmp
        }
     }

        results <- lapply(results, na.omit)
        return(results)
    }
    ''')

    print("Differentially expressed genes found.")

    result = robjects.r['find_de'](sca)
    return result

In [6]:
# Loop through each genotype and each cell state
genotypes = adata.obs['genotype'].unique()
cell_states = adata.obs['cell_states'].unique()

all_results = []

for genotype in genotypes:
    for cell_state in cell_states:
        target = f"{genotype}_{cell_state}"
        print(f"Processing genotype: {genotype}, cell state: {cell_state}")
        
        adata_subset = subset_data(adata, cell_state, genotype)
        
        if adata_subset.shape[0] == 0:
            print(f"No data found for genotype: {genotype}, cell state: {cell_state}")
            continue
        
        sca = convert_to_sca(adata_subset)
        
        de_results = find_de_MAST(sca)
        
        if cell_state in de_results:
            de_results_df = pandas2ri.rpy2py_dataframe(de_results[cell_state])
            
            if not de_results_df.empty:
                output_file = f'/home/skolla/Github/hofmann_dmd/hofmann_dmd/DE/de_results/DE_{genotype}_{cell_state}.txt'
                de_results_df.to_csv(output_file, sep='\t', index=False)
                print(f"Saved DE results for {genotype}_{cell_state} to {output_file}")
                all_results.append(de_results_df)
            else:
                print(f"No differential expression results found for {genotype}_{cell_state}.")
        else:
            print(f"No results found for {genotype}_{cell_state}.")

Processing genotype: Mdx, cell state: vCM2
Converting AnnData subset to SingleCellAssay object...


: 

## Writing files

##### Steps followed 

1. Convert the results to a pandas DataFrame
2. Savinng separate files for each cell state and a siungle file withh all results
3. Adding MAST results to anndata and writing a new anndata file

In [None]:
de_results_df = {cell_state: pandas2ri.rpy2py_dataframe(df) for cell_state, df in de_results.items()}

for cell_state, df in de_results_df.items():
    file_name = f"/home/skolla/Github/hofmann_dmd/hofmann_dmd/DE/de_results/DE_mdx_{cell_state}.txt"
    df.to_csv(file_name, sep='\t', index=False)
    print(f"Saved DE results for {cell_state} to {file_name}")
all_results = pd.concat(de_results_df.values(), keys=de_results_df.keys())
all_results.to_csv('/home/skolla/Github/hofmann_dmd/hofmann_dmd/DE/de_results/DE_mdx-all.txt', sep='\t', index=False)

adata.uns['MAST_results'] = de_results_df

try:
    sc_toolbox.tools.de_res_to_anndata(
        adata,
        all_results,
        groupby="cell_states",
        gene_id_col='gene_id',
        score_col='log_fold_change',
        pval_col='p_value',
        pval_adj_col="FDR",
        lfc_col='log_fold_change',
        key_added='MAST_results'
    )
except ValueError as e:
    print(f"Error updating AnnData: {e}")

print("Updated AnnData object:")
print(adata)
adata.write_h5ad('/home/skolla/Github/hofmann_dmd/hofmann_dmd/DE/de_results/DE_wt_adata.h5ad')


##### Steps followed 

1. Set thresholds
2. Define a function
3. Extracting MAST results into a dataframe and ensuring it has required columns
4. Set index to gene names
5. Filter based on thresholds and extract genes
6. Create a output directory if doesnt exist
7. Plot heatmap for each cellstate by subsetting to desired cellstate

## Data visualization

In [None]:
FDR = 0.01
LOG_FOLD_CHANGE = 1.5
TOP_N_GENES = 50

def plot_heatmap(adata, group_key, groupby="cell_states", output_dir="plots"):
    if group_key not in adata.uns:
        raise ValueError(f"Group key '{group_key}' not found in adata.uns")
    
    de_results_dict = adata.uns[group_key]
    
    # Check if de_results is a dictionary of DataFrames
    if isinstance(de_results_dict, dict) and all(isinstance(v, pd.DataFrame) for v in de_results_dict.values()):
        # Concatenate all DataFrames in the dictionary
        res = pd.concat(de_results_dict.values(), keys=de_results_dict.keys(), names=['cell_states', 'index'])
    else:
        raise ValueError(f"Unexpected format for differential expression results in adata.uns['{group_key}']")
    
    required_columns = ['gene_id', 'log_fold_change', 'p_value', 'FDR']
    for col in required_columns:
        if col not in res.columns:
            raise KeyError(f"Expected column '{col}' not found in differential expression results.")
    
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    unique_groups = adata.obs[groupby].unique()

    for group in unique_groups:
        group_key_full = f'cell_states{group}'
        
        if group_key_full not in de_results_dict:
            print(f"No results found for group: {group}")
            continue

        print(f"Plotting heatmap for group: {group}")
        
        # Filter results for the current group
        group_res = de_results_dict[group_key_full]
        
        filtered_res = group_res[
            (group_res["FDR"] < FDR) & (abs(group_res["log_fold_change"]) > LOG_FOLD_CHANGE)
        ].sort_values(by=["log_fold_change"], ascending=False)
        
        # Get top genes for the current cell state
        top_genes = filtered_res.head(TOP_N_GENES)
        
        markers = list(top_genes['gene_id'].unique())

        if len(markers) == 0:
            print(f"No significant genes found for group: {group}")
            continue
        
        adata_group = adata[adata.obs[groupby] == group].copy()
        adata_group = adata_group[:, adata_group.var_names.isin(markers)]
        
        # Ensure that markers are displayed correctly
        filename = f"{group}_mdx_heatmap.png"
        filepath = os.path.join(output_dir, filename)

        sc.pl.heatmap(
            adata_group,
            var_names=markers,
            groupby=groupby,
            swap_axes=True,
            show=True,
        )
        
        plt.title(f"Heatmap - {group}_mdx")
        plt.ylabel("Gene Names") 
        plt.savefig(filepath)
        print(f"Heatmap saved to {filepath}")
        plt.close()

# Call the function
plot_heatmap(adata, group_key="MAST_results")

In [None]:
FDR = 0.01
LOG_FOLD_CHANGE = 1.0
TOP_N_GENES = 50

def plot_dotplot(adata, group_key, genotype_key, groupby="cell_states", output_dir="figures"):
    if group_key not in adata.uns:
        raise ValueError(f"Group key '{group_key}' not found in adata.uns")
    
    if genotype_key not in adata.obs:
        raise ValueError(f"Genotype key '{genotype_key}' not found in adata.obs")
    
    de_results_dict = adata.uns[group_key]

    if isinstance(de_results_dict, dict) and all(isinstance(v, pd.DataFrame) for v in de_results_dict.values()):
        res = pd.concat(de_results_dict.values(), keys=de_results_dict.keys(), names=['cell_states', 'index'])
    else:
        raise ValueError(f"Unexpected format for differential expression results in adata.uns['{group_key}']")
    
    required_columns = ['gene_id', 'log_fold_change', 'p_value', 'FDR']
    for col in required_columns:
        if col not in res.columns:
            raise KeyError(f"Expected column '{col}' not found in differential expression results.")
    
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    unique_groups = adata.obs[groupby].unique()

    for group in unique_groups:
        group_key_full = f'cell_states{group}'
        
        if group_key_full not in de_results_dict:
            print(f"No results found for group: {group}")
            continue

        print(f"Plotting dot plot for group: {group}")
        
        group_res = de_results_dict[group_key_full]
        
        filtered_res = group_res[
            (group_res["FDR"] < FDR) & (abs(group_res["log_fold_change"]) > LOG_FOLD_CHANGE)
        ].sort_values(by=["log_fold_change"], ascending=False)
        
        top_genes = filtered_res.head(TOP_N_GENES)
        markers = list(top_genes['gene_id'].unique())
        if len(markers) == 0:
            print(f"No significant genes found for group: {group}")
            continue
        
        adata_group = adata[adata.obs[groupby] == group].copy()
        
        genotypes = adata_group.obs[genotype_key].unique()
        if len(genotypes) > 0:
            genotype = genotypes[0]
            title = f"Dot Plot for {group} - Genotype: {genotype}"
        else:
            title = f"Dot Plot for {group}"
        
        filename = f"{group}_dotplot.png"
        filepath = os.path.join(output_dir, filename)

        sc.pl.dotplot(
            adata_group,
            var_names=markers,
            groupby=groupby,
            cmap="RdYlBu_r",
            title=title,
            show=True,
            save=filepath
        )
        
        print(f"Dot plot saved to {filepath}")
        plt.close()


plot_dotplot(adata, group_key="MAST_results", genotype_key="genotype")


In [14]:
adata = sc.read_h5ad('/home/skolla/Github/hofmann_dmd/hofmann_dmd/data/concatenated_CMC-16_07_2024.h5ad')


desired_states = ['vCM1', 'vCM2', 'vCM3', 'vCM4']
filtered_adata = adata[adata.obs['cell_states'].isin(desired_states)]

filtered_adata

View of AnnData object with n_obs × n_vars = 8257 × 16060
    obs: 'cell_source', 'cell_type', 'donor', 'cell_states', 'genotype', 'compartment', 'object', 'samples', 'n_counts', 'batch'
    var: 'gene_ids-CMC'

In [28]:
import anndata
import pandas as pd
import numpy as np
from upsetplot import UpSet
import matplotlib.pyplot as plti

In [29]:
# Extract gene counts for a specific gene (adjust index as needed)
gene_of_interest = 0  # Adjust this index or use gene names if available
if hasattr(filtered_adata.X, 'toarray'):
    gene_counts = filtered_adata.X[:, gene_of_interest].toarray().flatten()
else:
    gene_counts = filtered_adata.X[:, gene_of_interest].flatten()

# Create DataFrame
df = pd.DataFrame({
    'GeneCounts': gene_counts,
    'Genotype': filtered_adata.obs['genotype'].values,
    'CellStates': filtered_adata.obs['cell_states'].values
})

# Create a DataFrame where each row represents the presence of a combination
# Using cross-tabulation to represent combinations
combination_df = pd.crosstab(index=df['Genotype'], columns=df['CellStates'])

# Create a MultiIndex with all combinations of Genotype and CellStates
all_combinations = pd.MultiIndex.from_product([combination_df.index, combination_df.columns], names=['Genotype', 'CellStates'])

# Reindex the combination_df to include all possible combinations
combination_df = combination_df.reindex(all_combinations, fill_value=0)

# Convert this DataFrame into a Series with multi-index for UpSet
combinations = combination_df.stack()
combinations.name = 'count'

# Create the UpSet plot
upset = UpSet(combinations)
upset.plot()
plt.title('Upset Plot of Gene Counts by Genotype and Cell States')
plt.show()

ValueError: The DataFrame has values in its index that are not boolean