# Run scGen

TODO: add documentation for the `scGenPerturbationAnalysis` class.

In [None]:
# Importing necessary libraries
import pertpy as pp
import scanpy as sc
import sys
sys.path.append('./')
from scGen_class import scGenPerturbationAnalysis

In [None]:
# Loading the dataset
# Uses pertpy's bundled loader; the returned object is filtered through its
# `.obs` table in the cells below.
covid_data = pp.dt.stephenson_2021_subsampled()
# Alternative: read the data from a local .h5ad file instead of the loader
# (presumably the same dataset, judging by the file name — confirm).
#covid_data = sc.read('data/stephenson_2021_subsampled.h5ad')

In [None]:
# input parameters
# `.obs` column holding the condition label that separates control from
# stimulated cells.
condition_col= "time_after_LPS" #or "disease" in case COVID vs normal
# Labels within `condition_col` for the control and the stimulated group.
# NOTE(review): "nan" is the literal string, not a float NaN — presumably the
# label of cells without an LPS time point; confirm against the data.
ctrl_key = "nan"
stim_key = "10h"
# `.obs` column holding the cell-type annotation.
celltype_col = "author_cell_type"
# Cell type to predict; reassigned for every cell type by the loop further down.
celltype_to_predict = "B_naive"
# Passed as `max_epochs` to the model training call.
n_epochs = 20

In [None]:
# Filter the data set to a pairwise comparison: keep only cells whose
# "disease" is "normal" and whose condition is either the control or the
# stimulated label. Both criteria are combined into a single boolean mask and
# the selection is copied, so LPSN_data is a standalone object rather than a
# view of a view of covid_data; the in-place steps that follow then operate on
# real data instead of triggering implicit copies of a view.
is_normal = covid_data.obs["disease"] == "normal"
in_comparison = covid_data.obs[condition_col].isin([ctrl_key, stim_key])
LPSN_data = covid_data[is_normal & in_comparison].copy()

In [None]:
# Downsample the data set (in place) to at most 3000 cells. The target is
# clamped to the number of available cells: subsampling is done without
# replacement, so requesting more observations than exist would raise.
sc.pp.subsample(LPSN_data, n_obs=min(3000, LPSN_data.n_obs))

In [None]:
# Show the distinct cell-type labels present after filtering and downsampling;
# the loop below iterates over exactly these values.
LPSN_data.obs[celltype_col].unique()

In [None]:
# Run the complete scGen workflow once per cell type and collect one summary
# row per cell type: the R2 value plus three distance-based perturbation scores.
results = []

# Value passed as `n_comps` to every distance computation.
n_comps = 50
# Result-row key -> metric name handed to compute_distance_metric.
distance_metrics = {
    "e_distance": "edistance",
    "mmd": "mmd",  # maximum mean discrepancy
    "euclidean": "euclidean",
}

# NOTE(review): every iteration hands the same LPSN_data object to the class.
# If preprocess_data() normalises its input in place, later cell types would
# be trained on repeatedly normalised data — confirm the class works on a copy.
for celltype_to_predict in LPSN_data.obs[celltype_col].unique():
    # Instantiating the PerturbationAnalysis class with the loaded data
    analysis = scGenPerturbationAnalysis(LPSN_data)

    # Preprocessing the data (just normalisation)
    analysis.preprocess_data()
    analysis.prepare_training_set(
        condition_col,
        stim_key,
        celltype_col,
        celltype_to_predict,
    )

    # Setting up AnnData for scGen and training the model
    analysis.setup_anndata(condition_col, celltype_col)
    analysis.train_model(max_epochs=n_epochs, batch_size=64)

    # Making predictions
    analysis.make_prediction(
        ctrl_key,
        stim_key,
        celltype_to_predict,
        condition_col,
    )

    # Evaluating the predictions
    analysis.evaluate_prediction(
        celltype_col,
        celltype_to_predict,
        condition_col,
        ctrl_key,
        stim_key,
    )

    # Identifying differentially expressed genes
    analysis.identify_diff_genes(
        celltype_col,
        celltype_to_predict,
        condition_col,
        stim_key,
    )
    analysis.plot_mean_correlation(stim_key)

    # Read the R2 value before the distance computations, exactly where the
    # workflow has just produced it.
    r2_value = analysis.r2_value
    row = {
        "celltype_to_predict": celltype_to_predict,
        "r2_value": r2_value,
    }

    # Each call stores its outcome in analysis.perturbation_score, which is
    # overwritten by the next call, so it is read immediately after each one.
    for result_key, metric in distance_metrics.items():
        analysis.compute_distance_metric(
            n_comps,
            metric,
            condition_col,
            stim_key,
            ctrl_key,
        )
        row[result_key] = analysis.perturbation_score

    results.append(row)

# Display the collected rows. This must be the last top-level expression of the
# cell: inside the loop body it would be evaluated and silently discarded.
results