# Run scGen

**TODO:** add documentation for the `scGenPerturbationAnalysis` class (imported from `scGen_class.py`).

In [1]:
# Importing necessary libraries
import sys

import pandas as pd
import pertpy as pp
import scanpy as sc

sys.path.append('./')
from scGen_class import scGenPerturbationAnalysis

In [2]:
# Loading the dataset: the subsampled Stephenson 2021 dataset provided by pertpy.
# To read a previously downloaded copy instead, set `local_h5ad_path`,
# e.g. "data/stephenson_2021_subsampled.h5ad". The default (None) uses pertpy.
local_h5ad_path = None
covid_data = (
    sc.read(local_h5ad_path)
    if local_h5ad_path
    else pp.dt.stephenson_2021_subsampled()
)

In [3]:
# --- Input parameters ---------------------------------------------------------
# obs column holding the condition to compare
# (use "disease" instead for a COVID vs normal comparison).
condition_col = "time_after_LPS"
# Control and stimulated labels within `condition_col`.
# NOTE(review): control cells appear to carry the literal string "nan" here,
# not a missing value — confirm against covid_data.obs[condition_col].
ctrl_key, stim_key = "nan", "10h"
# obs column with cell-type annotations, and the cell type held out of training
# (the training loop further down reassigns this for every cell type).
celltype_col = "author_cell_type"
celltype_to_predict = "B_naive"
# Number of scGen training epochs.
n_epochs = 20

In [4]:
# Filter the data set for a pairwise comparison: keep cells with disease == "normal"
# that belong to either the control or the stimulated condition.
# One combined mask gives the same cells as filtering twice in sequence.
# `.copy()` materialises an independent AnnData instead of a view of `covid_data`:
# - the in-place subsampling in the next cell modifies this object only;
# - it does not keep the full dataset referenced through the view.
is_normal = covid_data.obs["disease"] == "normal"
in_comparison = covid_data.obs[condition_col].isin([ctrl_key, stim_key])
LPSN_data = covid_data[is_normal & in_comparison].copy()

In [5]:
# Downsample the data set (in place) to a fixed number of cells.
# random_state=0 is scanpy's default, written out so the seed is explicit.
# The guard avoids the error sc.pp.subsample raises when asked for at least as many
# cells as are available (sampling is without replacement). It also makes
# re-running this cell a no-op once the data is already small enough.
n_subsample = 3000
if LPSN_data.n_obs > n_subsample:
    sc.pp.subsample(LPSN_data, n_obs=n_subsample, random_state=0)

In [11]:
### Get the cell types that can be held out and predicted.
# scGen predicts a held-out cell type's stimulated profile from its control cells,
# and the evaluation compares against its real stimulated cells. A cell type
# therefore needs cells in BOTH conditions. The order is by stimulated-cell count,
# descending, as returned by value_counts().
filtered_data = LPSN_data[LPSN_data.obs[condition_col] == stim_key]
# Cells per cell type in the stimulated condition.
celltype_counts = filtered_data.obs[celltype_col].value_counts()
# Cells per cell type in the control condition, aligned to the same index
# (types missing from the control condition get 0).
ctrl_counts = (
    LPSN_data.obs.loc[LPSN_data.obs[condition_col] == ctrl_key, celltype_col]
    .value_counts()
    .reindex(celltype_counts.index, fill_value=0)
)
# value_counts on a categorical column also lists empty categories, hence the > 0 tests.
present_in_both = (celltype_counts > 0) & (ctrl_counts > 0)
unique_celltypes = celltype_counts[present_in_both].index.tolist()
# Display the usable cell types and how many there are.
print(unique_celltypes)
len(unique_celltypes)

['B_naive', 'CD8.Naive', 'CD4.Naive', 'NK_16hi', 'B_immature', 'B_switched_memory', 'B_exhausted', 'NK_56hi', 'Platelets', 'CD8.EM', 'B_non-switched_memory', 'CD4.CM', 'MAIT', 'gdT', 'Plasma_cell_IgM', 'Plasmablast', 'CD8.TE', 'Plasma_cell_IgG', 'CD4.IL22']


19

In [12]:
# Number of cells per (cell type, condition) pair.
# The obs columns are categorical, so observed=False keeps every category combination,
# including zero counts (e.g. ASDC at 10h). Passing it explicitly keeps the current
# behaviour and silences pandas' FutureWarning about this default changing.
combined_counts = LPSN_data.obs.groupby([celltype_col, condition_col], observed=False).size()
combined_counts_df = combined_counts.reset_index(name="counts")
combined_counts_df

  combined_counts = LPSN_data.obs.groupby([celltype_col,condition_col]).size()


Unnamed: 0,author_cell_type,time_after_LPS,counts
0,ASDC,10h,0
1,ASDC,,1
2,B_exhausted,10h,11
3,B_exhausted,,8
4,B_immature,10h,16
...,...,...,...
83,Treg,,1
84,gdT,10h,5
85,gdT,,89
86,pDC,10h,0


In [None]:
# Train one scGen model per cell type in `unique_celltypes`. Each iteration:
# - holds that cell type's stimulated cells out of training;
# - predicts them and evaluates the prediction;
# - records the R^2 value and three distance scores.

# Metric name passed to compute_distance_metric -> column name in `results`.
# Each call overwrites `analysis.perturbation_score`, so the score is read right after it.
distance_metrics = {
    "edistance": "e_distance",
    "mmd": "mmd",  # NOTE(review): presumably maximum mean discrepancy — confirm in scGen_class
    "euclidean": "euclidean",
}
n_comps = 50  # NOTE(review): presumably the number of PCA components used — confirm in scGen_class

# Initialize an empty list to store the results (one dict per cell type).
results = []

for celltype_to_predict in unique_celltypes:
    # NOTE(review): a fresh copy per iteration guards against preprocess_data()
    # normalising the AnnData in place. Otherwise LPSN_data would be re-normalised
    # on every iteration. Confirm whether scGenPerturbationAnalysis already copies.
    analysis = scGenPerturbationAnalysis(LPSN_data.copy())
    analysis.preprocess_data()  # just normalisation
    # Drop this cell type's stimulated cells from the training set.
    analysis.prepare_training_set(condition_col, stim_key, celltype_col, celltype_to_predict)
    # Setting up AnnData for scGen
    analysis.setup_anndata(condition_col, celltype_col)
    analysis.train_model(max_epochs=n_epochs, batch_size=64)
    # Making predictions
    analysis.make_prediction(ctrl_key, stim_key, celltype_to_predict, condition_col)
    # Evaluating the predictions
    analysis.evaluate_prediction(celltype_col, celltype_to_predict, condition_col, ctrl_key, stim_key)
    # Identifying differentially expressed genes
    analysis.identify_diff_genes(celltype_col, celltype_to_predict, condition_col, stim_key)
    analysis.plot_mean_correlation(stim_key)
    r2_value = analysis.r2_value

    # Computing the distance metrics, in the same order as before (edistance, mmd, euclidean).
    distance_scores = {}
    for metric, column in distance_metrics.items():
        analysis.compute_distance_metric(n_comps, metric, condition_col, stim_key, ctrl_key)
        distance_scores[column] = analysis.perturbation_score

    results.append({
        "celltype_to_predict": celltype_to_predict,
        "r2_value": r2_value,
        **distance_scores,
    })

# Display once, after the loop: one row per cell type. `results` stays a list of dicts.
results_df = pd.DataFrame(results)
results_df

Preprocessing data for scGen...
Data normalized for scGen.

Removed B_naive &  10h  from the training set


Setting up AnnData for scGen...
AnnData set up for scGen.

Initializing and training scGen model...


  categorical_mapping = _make_column_categorical(


scGen model initialized.

[34mINFO    [0m Jax module moved to TFRT_CPU_0.Note: Pytorch lightning will show GPU is not being used for the Trainer.   


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/rosario/opt/anaconda3/envs/pertpy-env/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/rosario/opt/anaconda3/envs/pertpy-env/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training:   0%|          | 0/20 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


scGen model trained.

scGen model saved to 'model_perturbation_prediction.pt'.

Making predictions for  B_naive &  10h


  utils.warn_names_duplicates("obs")


[34mINFO    [0m Received view of anndata, making copy.                                                                    
[34mINFO    [0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             
