# Setup

First, we need to initialize the object (a Pertpy package internal dataset), and if the object isn't already preprocessed and/or clustered, perform those steps.

## Imports & Options

In [1]:
%load_ext autoreload
%autoreload 2

import crispr as cr 
from crispr import Crispr
import pertpy as pt
import scanpy as sc
import pandas as pd
import numpy as np

# Options
print("\nAnalysis Functions\n\t" + "\n\t".join(list(pd.Series(
    [np.nan if "__" in x else x for x in dir(cr.ax)]).dropna())))
file = "perturb-seq"
pd.options.display.max_columns = 100

#  Set Arguments
kwargs_init = dict(assay=None,
                   col_gene_symbols="gene_symbol",  
                   col_cell_type="celltype", 
                   col_perturbed="perturbation", 
                   col_guide_rna="perturbation", 
                   col_num_umis=None,
                   kws_process_guide_rna=None,
                   col_condition="target_gene_name", 
                   key_control="Control", 
                   key_treatment="KO")
file_path = pt.dt.papalexi_2021()


Analysis Functions
	analyze_composition
	cluster
	clustering
	composition
	compute_distance
	find_marker_genes
	perform_augur
	perform_celltypist
	perform_differential_prioritization
	perform_gsea
	perform_mixscape
	perturbations


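The listing above comes from filtering `dir(cr.ax)` for names that contain no double underscore. A plain list comprehension gives the same result without the `pd.Series`/`np.nan` round trip. The sketch below uses the standard-library `json` module as a stand-in for `cr.ax` so that it runs without the package installed; the helper name `names_without_dunder` is hypothetical.

```python
import json  # stand-in module; in the notebook this would be cr.ax


def names_without_dunder(module):
    """Return the names in dir(module) that contain no double underscore."""
    return [name for name in dir(module) if "__" not in name]


names = names_without_dunder(json)
print("\nAnalysis Functions\n\t" + "\n\t".join(names))
```

Like the original filter, this keeps names that start with a single underscore and drops only those containing `__`.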

## Object

This code instantiates the CRISPR object, which is the main way of interacting with this package as an end-user.

This is more code than you would need in real life; it just ensures that certain public datasets are loaded from the source for various reasons.

In [None]:
self = Crispr(file_path, **kwargs_init)

# Processing

We will just use defaults for the preprocessing arguments (very little filtering, consistent with the Pertpy tutorial for this object).

## Preprocess

In [None]:
self.preprocess(kws_hvg=True)  # preprocessing 

## Cluster

In [None]:
_ = self.cluster()

### CellTypist Annotations

Now let's detect cell types using `CellTypist`.

In [None]:
_ = self.annotate_clusters("ImmuneAllHigh.pkl")

# Inspection

First, you can easily print the `.adata` representation, the gene expression modality `.obs` preview, and columns and keys stored in the object's attributes using the `print()` method.

In [None]:
self.print()

## Set Up Arguments for Later


This code looks more complicated than it would be for an end user because it was written to generalize across several datasets that have different column names and that are large enough to require subsetting in order to run the vignettes in a reasonable period of time.

Basically, you won't need this code as an end user; this is just to choose random subsets of genes and perturbations, etc. that are available in a given example dataset.

In real use cases, you will know what genes and conditions are of interest, and you can manually specify them by simply stating them in the appropriate arguments (such as `target_gene_idents`) or (in many cases) by not specifying the argument (resulting in the code using all available genes, etc.).
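The cell that follows samples ten cells at random and takes their target genes, so the resulting list can contain repeats and changes between runs because no seed is set. A minimal sketch of a repeatable, duplicate-free draw is shown here with the standard library only; the labels in `cell_targets` are hypothetical stand-ins for the per-cell target gene column.

```python
import random

# Hypothetical per-cell target-gene labels (one entry per cell).
cell_targets = ["STAT1", "IRF1", "STAT1", "JAK2", "IRF1", "Control",
                "JAK2", "STAT1"]

# De-duplicate first so each target gene can be drawn at most once;
# sorting makes the draw independent of set iteration order.
unique_targets = sorted(set(cell_targets))

rng = random.Random(0)  # fixed seed for a repeatable draw
k = min(3, len(unique_targets))  # never ask for more than are available
target_subset = rng.sample(unique_targets, k)
print(target_subset)
```

The same idea carries over to pandas by calling `.unique()` before sampling and passing `random_state` to `.sample()`.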

In [None]:
genes = self.rna.var.reset_index()[self._columns["col_gene_symbols"]]
genes_subset = list(pd.Series(genes).sample(10))
target_gene_idents = list(self.obs[self._columns[
    "col_target_genes"]].sample(10))  # 10 random guide gene targets

## Explore Data Descriptives

You can plot and explore descriptives using the `describe()` method.

In [None]:
_ = self.describe()  # simple
# _ = self.describe(group_by=self._columns["col_target_genes"], plot=True)

# Plots

Create plots applicable to scRNA-seq data broadly (e.g., UMAP, dotplots) without having to do any perturbation-specific analyses.

## Basic Usage

You can create simple plots easily without having to remember a bunch of arguments to specify! 

The most useful is the `genes` argument, which allows you to subset the number of features plotted (useful for speed and layout/interpretability of plots).

In [None]:
# figs = self.plot(genes=["gene A", "gene B"...])  # to specify specific genes
figs = self.plot(genes=24)  # to specify random subset of this # of genes

## Advanced Usage

Use the `layers` argument to plot more layers (in this instance, all of them, including the scaled data) for certain plot types.
 
Use the `cell_types_circle` argument to create a UMAP with certain cell types circled in red.

Use the `genes_highlight` argument to highlight in gold the names of specified genes on the gene expression dot plot.

Use the `kws_clustering` argument to specify a dictionary of keyword arguments to pass to certain UMAP-based plots. For instance, specify `kws_clustering=dict(col_cell_type="leiden")` to use Leiden clusters instead of whatever is stored in `self._columns["col_cell_type"]`.

Use the `kws_gex_violin` argument to pass additional arguments to the violin plots of gene expression.

In [None]:
cct = "predicted_labels" if "predicted_labels" in self.rna.obs else \
    self._columns["col_cell_type"]  # fall back to the default column
clus = list(self.rna.obs[cct].sample(2))  # clusters to circle
figs = self.plot(genes=36, 
                 col_cell_type=cct,
                 cell_types_circle=clus,  # list cell types to circle on UMAP
                 genes_highlight=list(np.array(genes_subset)[1:3]), 
                 kws_gex_violin=dict(scale="area", height=10),
                 kws_umap=dict(col_cell_type=cct))

## Specific Plots

Use the package visualization sub-module function to make specific, customizable plots. 

Use arguments creatively to customize. For example, set the `col_cell_type` argument to the name of the perturbation condition column in the `plot_gex()` function to plot gene expression across perturbation conditions instead of cell types.

Use `kws_<kind>` arguments (replacing `<kind>` with the plot type, e.g., `kws_dot`) to specify a dictionary of keyword arguments to pass to the corresponding Scanpy plotting function (e.g., https://scanpy.readthedocs.io/en/stable/generated/scanpy.pl.dotplot.html) for further customization.
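The `kws_<kind>` convention is an ordinary Python pattern: a dictionary is accepted as one argument and later unpacked into the underlying plotting call. How this package routes the dictionaries internally is not shown in this vignette, so the function below is a hypothetical stand-in that only returns the keyword arguments it would forward; `standard_scale` and `cmap` are real `scanpy.pl.dotplot`/`scanpy.pl.heatmap` options used here purely as sample keys.

```python
def describe_plot(kind, **kws_by_kind):
    """Toy dispatcher: pick the kws_<kind> dictionary matching `kind`.

    Hypothetical stand-in for a plotting method; it returns the keyword
    arguments that would be forwarded instead of drawing anything.
    """
    defaults = {"dot": {"standard_scale": None}, "hm": {"cmap": "viridis"}}
    forwarded = dict(defaults.get(kind, {}))  # copy so defaults stay intact
    forwarded.update(kws_by_kind.get(f"kws_{kind}") or {})  # user overrides
    return forwarded


# Only the dictionary that matches the requested plot kind is used.
print(describe_plot("dot", kws_dot=dict(standard_scale="var"),
                    kws_hm=dict(cmap="magma")))
```

Keeping one dictionary per plot type lets a single call configure several plots without argument-name collisions.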

In [None]:
# Scanpy Figure Defaults
sc.set_figure_params(scanpy=True, dpi=100, dpi_save=200, frameon=True, 
                     fontsize=8, figsize=(20, 40), color_map=None, 
                     format="pdf", facecolor=None, transparent=False)

# Simple GEX Heatmap by Cell Type
fig = self.plot(kind="hm", genes=genes_subset, figsize=(60, 60))

# Argument to Subset Data to Target Genes of Interest
subset = list(self.rna.obs[self._columns["col_condition"]].isin(
    target_gene_idents))

# Dot Plot
fig = self.plot(subset=subset, kind="dot",  # specify dot plot
                col_cell_type=self._columns["col_condition"], 
                genes=target_gene_idents)

# Dot Plot with Genes Grouped & Labeled
marker_genes_dict = {"Monocyte": ["FCGR3A", "FCN1", "S100A4", "S100A6"],
                     "B": ["CD79A", "CD19"]
                     }  # labeled ~ cell type here, but labels can be anything
fig = self.plot(subset=subset, kind="dot",  # specify dot plot
                col_cell_type=self._columns["col_condition"], 
                marker_genes_dict=marker_genes_dict)

# Heatmap
fig = self.plot(subset=subset, kind="hm",  # specify heatmap plot
                col_cell_type=self._columns["col_condition"], 
                genes=target_gene_idents, figsize=(60, 60), layer="scaled")

# Matrix Plot
fig = self.plot(subset=subset, kind="matrix",  # specify matrix plot
                col_cell_type=self._columns["col_condition"], 
                genes=target_gene_idents, layer="scaled", figsize=(60, 60))

# Analyses

The following examples concern CRISPR or other perturbation design-specific analyses.

## Guide RNA Counts/Percentage

In [None]:
# Choose Subset of Target Genes (optional)
tgis = list(pd.Series(target_gene_idents).sample(3)) if len(
    target_gene_idents) > 3 else target_gene_idents  # smaller subset = faster
cct = "majority_voting" if "majority_voting" in self.rna.obs else None

# Guide RNA Counts
_ = self.get_guide_rna_counts(target_gene_idents=tgis)

# ...By Cell Type
_ = self.get_guide_rna_counts(target_gene_idents=tgis, group_by=cct)

# # ...By Target Gene Ultimately Assigned
# _ = self.get_guide_rna_counts(target_gene_idents=tgis, 
#                               group_by=self._columns["col_target_genes"])

# ...By Cell Type & Target Gene Ultimately Assigned
_ = self.get_guide_rna_counts(target_gene_idents=tgis, group_by=[
    cct, self._columns["col_target_genes"]], margin_titles=True)

## Augur: Perturbation Responses by Cell Type

**Which cell types are most affected by perturbations?** Quantify perturbation responses by cell type with Augur, which uses supervised machine learning classification of experimental condition labels (e.g., treated versus untreated). The more separable the condition among cells of a given type, the higher the perturbation effect score.
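To make "separability" concrete, the toy sketch below scores each cell type by the area under the ROC curve (AUC) for telling treated from control cells using a single made-up expression value per cell. This is not Augur itself, which trains a cross-validated classifier on many genes; the data and the `auc` helper are hypothetical and only illustrate why a cell type whose conditions separate cleanly receives a higher score.

```python
def auc(treated, control):
    """Probability that a random treated value exceeds a random control value
    (ties count one half); this equals the area under the ROC curve."""
    wins = sum((t > c) + 0.5 * (t == c) for t in treated for c in control)
    return wins / (len(treated) * len(control))


# Hypothetical single-gene expression values per cell type and condition.
toy = {
    "T cell": {"KO": [5.1, 4.8, 5.5, 5.0], "Control": [2.0, 2.2, 1.9, 2.4]},
    "B cell": {"KO": [3.0, 2.1, 2.6, 2.4], "Control": [2.5, 2.9, 2.0, 2.7]},
}

scores = {ct: auc(v["KO"], v["Control"]) for ct, v in toy.items()}
for cell_type, score in sorted(scores.items(), key=lambda kv: -kv[1]):
    print(f"{cell_type}: AUC = {score:.2f}")
```

An AUC near 1 means the conditions are almost perfectly separable in that cell type, while a value near 0.5 means they are indistinguishable.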

<u> __Features__ </u>  

- Quantify and visualize degree of perturbation response by cell type
- Identify the most important features (genes).

<u> __Input__ </u>  

There are no required arguments. 
* If you want to override defaults drawn from `._columns` and/or `._keys`, specify the appropriate argument (e.g., `col_cell_type`). 
* You can also specify a different `classifier` (default "random_forest_classifier") for the machine learning classification procedure used to calculate the AUCs/accuracy. 
* You may pass keyword arguments to the Augur predict method by specifying a dictionary in `kws_augur_predict`.
* Specify `select_variance_features` as True to run the original Augur implementation, which removes genes that don't vary much across cell types. If False, use the features selected by `scanpy.pp.highly_variable_genes()`, which is faster and recovers effects sensitively. However, because highly variable genes are already good at separating cell types, training on this reduced feature set may inflate Augur scores.
* Specify `n_folds` and/or `subsample_size` to choose the number and sample size of folds in cross-validation.
* Set an integer for `seed` to allow reproducibility across runs.

<u> __Output__ </u>  

Tuple, where the first element is the AnnData object created by the function, the second, the results dictionary, and the third, a dictionary of figures visualizing results. If copy is False (default), these outputs can also be found in `.results["augur"]` and `.figures["augur"]`.

<u> __Notes__ </u>  

- Sub-sample sizes equal across conditions; does not account for perturbation-induced compositional shifts (cell type abundance)
- Scores are for cell types (aggregated across cells, not individual cells)
- Two modes
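    - If `select_variance_features=True`, the original Augur implementation is run, which removes genes that don't vary much across cell types.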
    - If False, you also have to be sure that "highly_variable_features" is a variable in your data. This can be complicated if you have a separate layer for perturbation data.

In [None]:
cct = "majority_voting" if "majority_voting" in self.rna.obs else \
    self._columns["col_cell_type"]
_ = self.run_augur(
    col_cell_type=cct, 
    # ^ will be label in self._columns by default, but can override here
    col_perturbed=self._columns["col_perturbed"], 
    # ^ will be this by default if unspecified, but can override here
    key_treatment=self._keys["key_treatment"],  
    # ^ will be this by default if unspecified, but can override here
    select_variance_features=True,  # True = original Augur feature selection
    classifier="random_forest_classifier", n_folds=3, augur_mode="default", 
    kws_umap=dict(col_cell_type=cct),  # same cell type column as above
    subsample_size=5, kws_augur_predict=dict(span=0.7))

## Mixscape: Cell-Level Perturbation Classification & Scoring

**Is a perturbed cell detectably perturbed, and to what extent?** Mixscape first calculates the "**perturbation signature**" by determining which control cells most closely resemble each perturbed cell in terms of mRNA expression and then subtracting those control cells' expression from that of the perturbed cell (i.e., it centers perturbed cells' gene expression on their control neighbors).

Then, it **identifies** and removes perturbed **cells with no detectable perturbation** (i.e., it assigns them to predicted classes of perturbed versus not perturbed). You can then create visuals based on whether the cell is detectably perturbed, "non-perturbed" (not detectably perturbed), or control (no treatment). Optionally, you can visualize protein expression by this predicted class in certain multi-modal data.

**Are there perturbation-specific clusters?** Mixscape uses linear discriminant analysis (LDA) to cluster cells that resemble each other in terms of gene expression and perturbation condition. _(LDA reduces dimensionality and attempts to maximize the separability of classes. Unperturbed cells are removed from analysis.)_ 
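The perturbation signature described above can be sketched in a few lines: for each perturbed cell, find its nearest control cells and subtract their mean expression profile. This is a toy illustration with made-up values and plain Euclidean distance, not the Pertpy implementation, which works on the full dataset and may search for neighbors in a reduced space; the function name and `n_neighbors` parameter are hypothetical.

```python
import math


def perturbation_signature(perturbed, controls, n_neighbors=2):
    """Subtract from each perturbed cell the mean profile of its nearest
    control cells (Euclidean distance in expression space)."""
    signatures = []
    for cell in perturbed:
        nearest = sorted(controls,
                         key=lambda c: math.dist(cell, c))[:n_neighbors]
        mean_ctrl = [sum(vals) / len(nearest) for vals in zip(*nearest)]
        signatures.append([x - m for x, m in zip(cell, mean_ctrl)])
    return signatures


# Hypothetical log-expression profiles (rows = cells, columns = 3 genes).
controls = [[1.0, 2.0, 0.0], [1.0, 2.0, 2.0], [5.0, 5.0, 5.0]]
perturbed = [[1.0, 0.0, 1.0]]

print(perturbation_signature(perturbed, controls))
```

In this toy case the distant third control cell is ignored, and only the second gene, which is lower in the perturbed cell than in its control neighbors, has a non-zero signature value.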

<u> __Features__ </u>  

- Plot targeting efficiency.
- Remove confounds (e.g., cell cycle, batch effects)
- Classify cells as affected or unaffected (i.e., "escapees") by the perturbation
- Quantify and visualize degree of perturbation response

<u> __Input__ </u> 

See documentation, but the key arguments are listed here.

* **col_cell_type**: To run using a different cell classification column (e.g., CellTypist annotations that weren't used for the original `self._columns["col_cell_type"]`), pass `col_cell_type=<column name>`.
* **target_gene_idents**: A list of gene symbols to focus on in plots/analyses. Specify as True to include all.
* **Data layer**: The default layer of data used is "log1p". Remember that Mixscape centers cells on their control neighbors when considering whether to use centered and/or scaled data.

<u> __Output__ </u>  

Assuming your `Crispr` object is named "self":
- Targeting Efficiency: `self.figures["mixscape"]["targeting_efficiency"]`
- Differential Expression Ordered by Posterior Probabilities: `self.figures["mixscape"]["DEX_ordered_by_ppp_heat"]`
- Posterior Probabilities Violin Plot: `self.figures["mixscape"]["ppp_violin"]`
- Perturbation Scores: `self.figures["mixscape"]["perturbation_score"]`
- Perturbation Clusters (from LDA): `self.figures["mixscape"]["perturbation_clusters"]`

The above instructions are for accessing output via the object attributes. Assuming output is assigned to a variable `figs` (i.e., `figs = ` would replace the `_ = ` in the code below), replace `self.figures["mixscape"]` in the above code with `figs`.

<u> __Notes__ </u>  

- If `._columns["col_sample_id"]` is not None, perturbation scores will by default be calculated and/or plotted taking that into consideration (e.g., biological replicates) unless `col_split_by=False`. That argument can also be set to a different column name explicitly, in which case that specification will be used as the `col_split_by` argument in Pertpy Mixscape functions in place of sample ID.

#### Run Mixscape

In [None]:
import scanpy as sc
import matplotlib.pyplot as plt

adata = self.rna.copy()
col_target_genes = self._columns["col_target_genes"]
target_gene_idents = ["PDGFRA", "SP100"]

nrow, ncol = cr.pl.square_grid(len(target_gene_idents))

fig, axs = plt.subplots(nrow, ncol, figsize=(50, 50))
if "flatten" in dir(axs):
    axs = axs.flatten()
for i, x in enumerate(target_gene_idents):  # iterate target genes
    try:
        sc.pl.violin(
            adata[adata.obs[col_target_genes] == x], keys=x,
            groupby="mixscape_class_global", ax=axs[i])
    except Exception as err:
        print(f"{err}\n\nGene expression violin plot failed for {x}!")
        figs = err
        

In [None]:
cct = "predicted_labels" if "predicted_labels" in self.rna.obs else None
tgis = list(self.rna.obs[self._columns["col_target_genes"]].sample(3))
_ = self.run_mixscape(col_cell_type=cct, target_gene_idents=tgis)
# _ = self.run_mixscape(target_gene_idents=True)  # plot all target genes

#### Create Different Mixscape Plots

If you want to re-create Mixscape **plots with <u>different target genes and/or proteins of interest</u>** later, you can use `self.plot_mixscape(<ONE OR MORE TARGET GENES>)`. If you want a different color for the perturbation score curves, specify `color=` in that method.

In [None]:
tgis = list(pd.Series(list(self.rna.uns["mixscape"].keys())).sample(1))
_ = self.plot_mixscape(tgis, color="red")

## Distance

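Compute distances between perturbation conditions. The example below uses two metrics, maximum mean discrepancy (`"mmd"`) and energy distance (`"edistance"`), calculated on the PCA embedding (`"X_pca"`).

See `self.figures["distances"]` and `self.results["distances"]` for results.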

In [None]:
for x in ["mmd", "edistance"]:
    _ = self.compute_distance(distance_type=x, method="X_pca", kws_plot=dict(
        figsize=(15, 15), robust=True))
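For intuition about the `"edistance"` option, the energy distance between two groups of cells compares the average between-group distance with the average within-group distances. The sketch below is a standard-library illustration on made-up 2-D embeddings using plain Euclidean distance and including self-pairs in the within-group means; the package's implementation may differ in these details (for example, by using squared distances), so treat the numbers as illustrative only.

```python
import math
from itertools import product


def mean_pairwise(a, b):
    """Mean Euclidean distance over all pairs (x in a, y in b)."""
    return sum(math.dist(x, y) for x, y in product(a, b)) / (len(a) * len(b))


def energy_distance(a, b):
    """2 * between-group mean distance minus both within-group means."""
    return 2 * mean_pairwise(a, b) - mean_pairwise(a, a) - mean_pairwise(b, b)


# Hypothetical 2-D embeddings (e.g., the first two PCs) for two conditions.
control = [[0.0, 0.0], [0.0, 2.0]]
knockout = [[3.0, 0.0], [3.0, 2.0]]

print(energy_distance(control, knockout))  # shifted group: positive value
print(energy_distance(control, control))   # identical groups: zero
```

The larger the value, the further apart the two conditions sit in the embedding relative to their internal spread.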

## Multi-Cellular Programs (Dialogue)

In [None]:
cct = "majority_voting" if "majority_voting" in self.rna.obs else \
    self._columns["col_cell_type"]
fig_mcp = self.run_dialogue(n_programs=4, col_cell_type=cct, cmap="coolwarm")