### Notebook for the pseudotime analysis of ACM monocytes using `cellrank`

#### Environment: Cellrank

- **Developed by:** Alexandra Cirnu
- **Modified by:** Carlos Talavera-López
- **Würzburg Institute for Systems Immunology & Julius-Maximilian-Universität Würzburg**
- **Date of creation:** 240514
- **Date of modification:** 240515

### Import required modules

In [1]:
import cellrank as cr
import scvi
import palantir
import anndata as ad
import muon as mu
from muon import atac as ac
from muon import prot as pt
import scipy
from scipy.sparse import csr_matrix
import numpy as np
import scanpy as sc
import pandas as pd
import seaborn as sns
import warnings
import matplotlib.pyplot as plt
import torch
import plotnine as p

ModuleNotFoundError: No module named 'scvi'

### Set up working environment

In [None]:
# scanpy logging at verbosity level 3, plus a dump of the installed package versions.
sc.settings.verbosity = 3
sc.logging.print_versions()
# Figure defaults: 180 dpi on screen, 300 dpi when saved, SVG output, 'RdPu' colour map.
sc.settings.set_figure_params(dpi = 180, color_map = 'RdPu', dpi_save = 300, vector_friendly = True, format = 'svg')
# cellrank logging at verbosity level 2.
cr.settings.verbosity = 2

In [None]:
# Silence every Python warning for the rest of the session.
warnings.simplefilter(action = 'ignore')
# Fixed scvi-tools seed.
scvi.settings.seed = 1712
# Inline figures: white face colour, 'retina' figure format (IPython magics).
%config InlineBackend.print_figure_kwargs = {'facecolor' : "w"}
%config InlineBackend.figure_format = 'retina'
# Set PyTorch float32 matmul precision to 'medium'.
torch.set_float32_matmul_precision('medium')

In [None]:
# Report the CUDA devices that PyTorch can see.
if torch.cuda.is_available():
    print("CUDA is available. Current device:", torch.cuda.current_device())
    print("Total GPUs:", torch.cuda.device_count())
    # Name every visible GPU rather than only index 0: the training cell of this
    # notebook selects its device by index, so all indices need to be identifiable.
    for gpu_index in range(torch.cuda.device_count()):
        print(f"GPU {gpu_index} Name:", torch.cuda.get_device_name(gpu_index))
else:
    print("CUDA is not available.")

In [None]:
# Model architecture keyword arguments collected in one mapping.
# NOTE(review): nothing in the visible notebook passes this mapping to a model — confirm it is still needed.
arches_params = {
    "use_layer_norm": "both",
    "use_batch_norm": "none",
    "encode_covariates": True,
    "dropout_rate": 0.2,
    "n_layers": 3,
}

### Read in data

In [None]:
# Input: the clustered ACM myeloid MuData object (see the file name below).
# The path is absolute and specific to the author's machine.

mdata = mu.read_h5mu('/home/acirnu/data/ACM_cardiac_leuco/5_Leiden_clustering_and_annotation/ACM_myeloids_clustered_muon_ac240502.raw.h5mu')     
mdata

In [None]:
# Work on the RNA modality and list the categories of the 'classification' annotation.
adata = mdata.mod['rna']
adata.obs['classification'].cat.categories

In [None]:
# Keep only monocyte and macrophage cells. A cell must pass BOTH filters:
# its cluster-level 'classification' and its 'C_scANVI' label.
monocyte_macrophage_clusters = [
    "Monocytes_6", 'Monocytes_11', 'Monocytes_13', 'Monocytes_17',
    'DOCK4+MØ_3', 'DOCK4+MØ_9',
    'LYVE1+MØ_1', 'LYVE1+MØ_2', 'LYVE1+MØ_4', 'LYVE1+MØ_8',
    'MØ_general_0', 'MØ_general_7', 'MØ_general_10',
]
monocyte_macrophage_labels = ["Monocytes", "DOCK4+MØ", "LYVE1+MØ", "MØ_general"]

keep_cells = (
    adata.obs['classification'].isin(monocyte_macrophage_clusters)
    & adata.obs['C_scANVI'].isin(monocyte_macrophage_labels)
)
# One combined mask instead of two chained slices; .copy() turns the view into an
# independent AnnData so the layers written in later cells do not go through a view.
adata = adata[keep_cells.to_numpy(), :].copy()
adata

In [None]:
def X_is_raw(adata):
    """Return True when every value stored in ``adata.X`` is a whole number.

    The check is done on the individual entries rather than on column sums, so
    fractional values that happen to add up to an integer are not mistaken for
    raw counts.

    Parameters
    ----------
    adata
        Object exposing a dense array or a SciPy sparse matrix as ``.X``.

    Returns
    -------
    bool
        True if all stored values are integer-valued, False otherwise
        (NaN counts as non-integer).
    """
    X = adata.X
    # Sparse input: only the explicitly stored entries can be non-integer
    values = X.tocsr().data if scipy.sparse.issparse(X) else np.asarray(X)
    return bool(np.all(np.mod(values, 1) == 0))
X_is_raw(adata)

In [None]:
adata.obs['C_scANVI'].value_counts()

### Select HVGs

In [None]:
# Keep a full copy of the object before it is subset to highly variable genes.
adata_raw= adata.copy()
# Store the current X as the 'counts' layer, and the square root of the
# normalize_total output as the 'sqrt_norm' layer.
adata.layers['counts'] = adata.X.copy()
adata.layers["sqrt_norm"] = np.sqrt(sc.pp.normalize_total(adata, inplace = False)["X"])

# Select the top 7000 highly variable genes (flavor 'seurat_v3') from the 'counts'
# layer with 'sample' as batch key; subset = True restricts adata to those genes.
sc.pp.highly_variable_genes(
    adata,
    flavor = "seurat_v3",
    n_top_genes = 7000,
    layer = "counts",
    batch_key = "sample",
    subset = True)
adata

### Batch correction with scVI

In [None]:
# Register the 'counts' layer and the 'sample' annotation with scVI.
# NOTE(review): 'sample' is registered twice, as batch_key and as a categorical
# covariate — confirm the duplication is intended.
scvi.model.SCVI.setup_anndata(adata,
                              batch_key = "sample", 
                              categorical_covariate_keys = ["sample"],  
                              layer = 'counts')

In [None]:
# scVI model: 150 latent dimensions, 3 layers, 'gene-batch' dispersion and a
# negative-binomial ('nb') gene likelihood.
scvi_model = scvi.model.SCVI(adata, 
                             n_latent = 150, 
                             n_layers = 3, 
                             dispersion = 'gene-batch', 
                             gene_likelihood = 'nb')

In [None]:
# Train for 13 epochs, validating after every epoch.
# The device is chosen from what is actually present instead of assuming a second GPU:
# GPU index 1 when several GPUs exist (the original choice), GPU 0 when there is
# exactly one, and the CPU otherwise.
n_gpus = torch.cuda.device_count()
if n_gpus > 1:
    train_accelerator, train_devices = "gpu", [1]
elif n_gpus == 1:
    train_accelerator, train_devices = "gpu", [0]
else:
    train_accelerator, train_devices = "cpu", 1

scvi_model.train(13,
                 check_val_every_n_epoch = 1,
                 enable_progress_bar = True,
                 accelerator = train_accelerator,
                 devices = train_devices)

### Evaluate model performance using the [_Svensson_](https://www.nxn.se/valent/2023/8/10/training-scvi-posterior-predictive-distributions-over-epochs) method

In [None]:
# Long-format table of the training and validation ELBO per epoch.
elbo_train = scvi_model.history['elbo_train'].astype(float)
elbo_validation = scvi_model.history['elbo_validation'].astype(float)
history_df = elbo_train.join(elbo_validation).reset_index().melt(id_vars = ['epoch'])

p.options.figure_size = (12, 6)

# Line + point plot of both curves, skipping epoch 0.
elbo_colors = {'elbo_train': 'black', 'elbo_validation': 'red'}
p_ = p.ggplot(p.aes(x = 'epoch', y = 'value', color = 'variable'), history_df.query('epoch > 0'))
p_ = p_ + p.geom_line()
p_ = p_ + p.geom_point()
p_ = p_ + p.scale_color_manual(elbo_colors)
p_ = p_ + p.theme_minimal()

p_.save('fig1.png', dpi = 300)

print(p_)

In [None]:
# Store the scVI latent representation of adata under obsm['X_scVI'].
adata.obsm["X_scVI"] = scvi_model.get_latent_representation(adata)

In [None]:
# Rebuild adata_raw from its own matrix and var table, taking obs from the current adata.
adata_raw = ad.AnnData(X = adata_raw.X, obs = adata.obs, var = adata_raw.var)          
adata_raw     

### Diffusion maps

In [None]:
# Palantir diffusion maps with 5 components, computed from the 'X_scVI' representation.
dm_res = palantir.utils.run_diffusion_maps(adata, n_components=5, pca_key= "X_scVI")

In [None]:
# Multiscale space derived by Palantir (presumably from the diffusion-map results above — confirm).
ms_data = palantir.utils.determine_multiscale_space(adata)

In [None]:
adata

### Visualization

In [None]:
# Neighbour graph on the scVI latent space (150 neighbours, Minkowski metric),
# then a seeded UMAP coloured by 'classification'.
sc.pp.neighbors(adata, use_rep = "X_scVI", n_neighbors = 150, metric = 'minkowski')
sc.tl.umap(adata, min_dist = 1.5, spread = 10, random_state = 1712)
sc.pl.umap(adata, frameon = False, color = ['classification'], size = 4)

### MAGIC imputation

In [None]:
# MAGIC imputation via Palantir. Later cells read a layer named 'MAGIC_imputed_data',
# presumably written by this call — confirm.
imputed_X = palantir.utils.run_magic_imputation(adata)

In [None]:
# Names of all genes flagged as highly variable.
is_hvg = adata.var['highly_variable'].to_numpy()
hvg_list = adata.var.index[is_hvg].tolist()
print(hvg_list)

In [None]:
# UMAP coloured by six marker genes, with values taken from the 'MAGIC_imputed_data' layer.
sc.pl.embedding(
    adata,
    basis="umap",
    layer="MAGIC_imputed_data",
    color=["Ly6c1", "Cd209a", "Ccr2", "Trem2", "Timd4", "Il1b"],
    frameon=False,
)
plt.show()

### Diffusion maps visualization

In [None]:
# Plot the diffusion components stored on adata.
palantir.plot.plot_diffusion_components(adata)
plt.show()

### Running Palantir

Find an appropriate start cell - Ly6C high monocyte

In [None]:
# Palantir start cell: the cell with the highest Ly6c1 value in adata.X.
if 'Ly6c1' not in adata.var_names:
    # Stop here with a clear message. Only printing would let the start_cell
    # assignment below fail later with an unrelated NameError.
    raise KeyError("Ly6c1 is not a valid gene name in this dataset.")

ly6c1_expression = adata[:, 'Ly6c1'].X

# Densify sparse storage (matrix or array flavour), then flatten to 1D
if scipy.sparse.issparse(ly6c1_expression):
    ly6c1_expression = ly6c1_expression.toarray()
ly6c1_expression = np.asarray(ly6c1_expression).flatten()

# Position of the maximum, and the matching cell ID
max_expression_index = ly6c1_expression.argmax()
cell_id_with_max_ly6c1 = adata.obs_names[max_expression_index]

print(f"Cell with highest Ly6c1 expression is {cell_id_with_max_ly6c1} at index {max_expression_index}.")

start_cell = cell_id_with_max_ly6c1 
start_cell

Find candidate terminal-state cells: the cells with the highest Timd4, Ccr2 and Trem2 expression

In [None]:
# Candidate terminal cell: the cell with the highest Timd4 value in adata.X.
# Gene-specific variable names are used so that the Ly6c1 start-cell variables
# from the previous cell are not overwritten.
if 'Timd4' in adata.var_names:
    timd4_expression = adata[:, 'Timd4'].X

    # Densify sparse storage, then flatten to 1D
    if scipy.sparse.issparse(timd4_expression):
        timd4_expression = timd4_expression.toarray()
    timd4_expression = np.asarray(timd4_expression).flatten()

    # Position of the maximum, and the matching cell ID
    timd4_max_index = timd4_expression.argmax()
    cell_id_with_max_timd4 = adata.obs_names[timd4_max_index]

    print(f"Cell with highest Timd4 expression is {cell_id_with_max_timd4} at index {timd4_max_index}.")
else:
    print("Timd4 is not a valid gene name in this dataset.")

In [None]:
# Candidate terminal cell: the cell with the highest Ccr2 value in adata.X.
# Gene-specific variable names are used so that the Ly6c1 start-cell variables
# from an earlier cell are not overwritten.
if 'Ccr2' in adata.var_names:
    ccr2_expression = adata[:, 'Ccr2'].X

    # Densify sparse storage, then flatten to 1D
    if scipy.sparse.issparse(ccr2_expression):
        ccr2_expression = ccr2_expression.toarray()
    ccr2_expression = np.asarray(ccr2_expression).flatten()

    # Position of the maximum, and the matching cell ID
    ccr2_max_index = ccr2_expression.argmax()
    cell_id_with_max_ccr2 = adata.obs_names[ccr2_max_index]

    print(f"Cell with highest Ccr2 expression is {cell_id_with_max_ccr2} at index {ccr2_max_index}.")
else:
    print("Ccr2 is not a valid gene name in this dataset.")

In [None]:
# Candidate terminal cell: the cell with the highest Trem2 value in adata.X.
# Gene-specific variable names are used so that the Ly6c1 start-cell variables
# from an earlier cell are not overwritten.
if 'Trem2' in adata.var_names:
    trem2_expression = adata[:, 'Trem2'].X

    # Densify sparse storage, then flatten to 1D
    if scipy.sparse.issparse(trem2_expression):
        trem2_expression = trem2_expression.toarray()
    trem2_expression = np.asarray(trem2_expression).flatten()

    # Position of the maximum, and the matching cell ID
    trem2_max_index = trem2_expression.argmax()
    cell_id_with_max_trem2 = adata.obs_names[trem2_max_index]

    print(f"Cell with highest Trem2 expression is {cell_id_with_max_trem2} at index {trem2_max_index}.")
else:
    print("Trem2 is not a valid gene name in this dataset.")

In [None]:
# Terminal states handed to Palantir, written as cell ID -> state label.
terminal_cells = {
    "GAGGCCTCAAACCATC-1-A4": "Tissue_resident",
    "TCGCTCACAAGCACCC-1-A1": "pro_inflam",
    "TCCGTGTCATTGAAGA-1-A1": "tissue_injury",
}
terminal_states = pd.Series(terminal_cells)

In [None]:
# Highlight the chosen terminal-state cells on the UMAP.
palantir.plot.highlight_cells_on_umap(adata, terminal_states)
plt.show()

Palantir generates the following results

- Pseudotime: Pseudo time ordering of each cell
- Terminal state probabilities: Matrix of cells X terminal states. Each entry represents the probability of the corresponding cell reaching the respective terminal state
- Entropy: A quantitative measure of the differentiation potential of each cell, computed as the entropy of the multinomial terminal state probabilities

In [None]:
# Run Palantir from start_cell with 500 waypoints and the terminal states defined above.
pr_res = palantir.core.run_palantir(
    adata, start_cell, num_waypoints=500, terminal_states=terminal_states)

### Visualizing Palantir results

In [None]:
# Plot the Palantir results stored on adata (point size 3).
palantir.plot.plot_palantir_results(adata, s=3)
plt.show()

In [None]:
adata.obs_names

In [None]:
# Two example cells whose terminal-state probabilities are plotted.
cells = ['TTCTCTCTCTGCGAGC-1-A3', 'TTTCACACACATATCG-1-B2'] 

palantir.plot.plot_terminal_state_probs(adata, cells)
fig = plt.gcf()
# Show the x axis on both panels and rotate their tick labels by 90 degrees
for ax in (fig.axes[1], fig.axes[0]):
    ax.get_xaxis().set_visible(True)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
plt.show()

In [None]:
# Show where the two example cells sit on the UMAP.
palantir.plot.highlight_cells_on_umap(adata, cells)
plt.show()

### Gene Expression Trends

Selecting cells of a specific trend

In [None]:
# Branch-selection masks computed by Palantir with eps=0.
masks = palantir.presults.select_branch_cells(adata, eps=0)

Visualizing the branch selection

In [None]:
# Plot the branch selection stored on adata.
palantir.plot.plot_branch_selection(adata)
plt.show()

In [None]:
# Trajectory plot for the identifier "ACTGATGCATGTGGTT-1-B2".
# NOTE(review): this looks like a cell barcode — confirm it is the value this argument expects.
palantir.plot.plot_trajectory(adata, "ACTGATGCATGTGGTT-1-B2")

In [None]:
# Same trajectory, with cells coloured by 'palantir_entropy' (viridis) and 10 red arrows.
palantir.plot.plot_trajectory(
    adata,
    "ACTGATGCATGTGGTT-1-B2",
    cell_color="palantir_entropy",
    n_arrows=10,
    color="red",
    scanpy_kwargs=dict(cmap="viridis"),
    arrowprops=dict(arrowstyle="-|>,head_length=.5,head_width=.5"),
)

### Compute a transition matrix with cellrank

In [None]:
# cellrank pseudotime kernel driven by 'palantir_pseudotime', and its transition matrix.
pk = cr.kernels.PseudotimeKernel(adata, time_key="palantir_pseudotime")
pk.compute_transition_matrix()

print(pk)

In [None]:
# Project the kernel onto the UMAP, recomputing the projection.
pk.plot_projection(basis="umap", recompute=True)

In [None]:
# Gene expression trends computed from the 'MAGIC_imputed_data' layer.
gene_trends = palantir.presults.compute_gene_trends(
    adata,
    expression_key="MAGIC_imputed_data",
)

In [None]:
# Plot the trends of five selected genes.
genes = ["Ly6c1", "Cd209a", "Gng11", "Ccl6", "Ms4a4c"]
palantir.plot.plot_gene_trends(adata, genes)
plt.show()

### Modify object to plot canonical marker genes

In [None]:
adata

In [None]:
# Plotting object: square root of the normalize_total output of adata_raw as X,
# with var from adata_raw and obs / obsm from adata.
# AnnData is reached through the `ad` alias; the bare name `anndata` is never
# imported in this notebook, so using it raises a NameError.
sqrt_normalised = np.sqrt(sc.pp.normalize_total(adata_raw, inplace = False)["X"])
adata_toplot = ad.AnnData(X = sqrt_normalised, var = adata_raw.var, obs = adata.obs, obsm = adata.obsm)
adata_toplot

In [None]:
# UMAP of adata_toplot coloured by 'C_scANVI' and eight marker genes, three panels per row.
# NOTE(review): the marker lists below describe lymphoid populations, while this notebook
# filters for monocytes/macrophages earlier — confirm these are the intended genes.
sc.pl.umap(adata_toplot, frameon = False, color = ['C_scANVI', 'Cd79a', 'Nkg7', 'Klrb1', 'Ccr7', 'Itga1', 'Cx3cr1', 'Gzmk', 'Klrg1'], size = 5, legend_fontsize = 5, ncols = 3, cmap = 'RdPu')

#Cd79a, Ms4a1, Cd19                                             //B cells
#Nkg7, Gnly, Cd8a                                               //NKT
#Gnly, Xcl2, Nkg7                                               //NK
#Klrb1, Aqp3, Itgb1, Ccr7, Sell, Cd4                            //CD4 T cells
#Itga1, Itgae, Cxcr6, Cx3cr1, Gzmb, Gnly, Gzmk, Cd8a, Ccl5      //CD8 T cells
#Klrg1                                                          //MCMV-spec T cells

### Compute `scib-metrics`

In [None]:
adata

In [None]:
adata.obs["seed_labels"].value_counts()

In [None]:
# Cast X to float; it is cast back to int after the benchmark cells.
adata.X = adata.X.astype(float)

In [None]:
# Benchmark three embeddings with 'donor' as batch and 'seed_labels' as label.
# NOTE(review): `Benchmarker` is not imported anywhere in the visible notebook; the section
# heading mentions scib-metrics — add the matching import before running.
# NOTE(review): only 'X_scVI' is computed in this notebook; 'X_scANVI' and 'X_pca' must
# already exist in adata.obsm — confirm.
bm = Benchmarker(
    adata,
    batch_key="donor",
    label_key="seed_labels",                                       # take C_scANVI or seed_labels containing 'Unknown'??
    embedding_obsm_keys=["X_scVI", "X_scANVI", "X_pca"],
    n_jobs=3,
)
bm.benchmark()

In [None]:
# Results table without min-max scaling.
bm.plot_results_table(min_max_scale=False)

In [None]:
# Unscaled benchmark results as a DataFrame, shown transposed.
df = bm.get_results(min_max_scale=False)
df.transpose()

In [None]:
# Cast X back to int after the float cast done before the benchmark.
adata.X = adata.X.astype(int)

In [None]:
X_is_raw(adata)

### Visualise proportions

In [None]:
def split_umap(adata, split_by, ncol=2, nrow=None, **kwargs):
    """Draw one UMAP panel per category of ``adata.obs[split_by]``.

    Parameters
    ----------
    adata
        Annotated data object; ``adata.obs[split_by]`` must be categorical.
    split_by
        Name of the ``obs`` column whose categories define the panels.
    ncol
        Number of panel columns in the grid.
    nrow
        Number of panel rows; computed from the number of categories when None.
    **kwargs
        Passed unchanged to ``sc.pl.umap`` for every panel.

    Raises
    ------
    ValueError
        If there are no categories, or the grid has fewer panels than categories.
    """
    categories = adata.obs[split_by].cat.categories
    if len(categories) == 0:
        raise ValueError(f"adata.obs[{split_by!r}] has no categories to plot.")
    if nrow is None:
        nrow = int(np.ceil(len(categories) / ncol))
    if nrow * ncol < len(categories):
        raise ValueError(
            f"A {nrow}x{ncol} grid cannot hold {len(categories)} categories."
        )
    # squeeze=False always returns a 2D array of axes, so a 1x1 grid flattens as well
    fig, axs = plt.subplots(nrow, ncol, figsize=(5*ncol, 4*nrow), squeeze=False)
    axs = axs.flatten()
    for ax, cat in zip(axs, categories):
        sc.pl.umap(adata[adata.obs[split_by] == cat], ax=ax, show=False, title=cat, **kwargs)
    # Hide the grid cells that received no category
    for spare_ax in axs[len(categories):]:
        spare_ax.set_visible(False)
    plt.tight_layout()

In [None]:
# Print the hex codes of the 20 'tab20' colours, then a list of 101 'plasma' colours.
# plt.get_cmap and an explicit rgb2hex import replace `from pylab import *` with
# cm.get_cmap: the star import floods the notebook namespace, and
# matplotlib.cm.get_cmap was removed in matplotlib 3.9.
from matplotlib.colors import rgb2hex

cmap = plt.get_cmap('tab20', 20)  # matplotlib color palette name, n colors
for i in range(cmap.N):
    rgb = cmap(i)[:3]  # cmap(i) returns RGBA; keep only the first three values (RGB)
    print(rgb2hex(rgb))

cmap = plt.get_cmap('plasma', 101)
color_list = [rgb2hex(cmap(i)[:3]) for i in range(cmap.N)]
print(color_list)

In [None]:
adata.obs["C_scANVI"].cat.categories

In [None]:
# UMAP coloured by 'C_scANVI' with a fixed colour per label.
# NOTE(review): the palette keys are lymphoid labels (B, CD4, CD8, NK, NKT), while
# C_scANVI was filtered earlier to Monocytes / DOCK4+MØ / LYVE1+MØ / MØ_general — confirm.
sc.pl.umap(adata, 
           frameon = False, 
           color = ['C_scANVI'], 
           size = 5, 
           legend_fontsize = 6, 
           ncols = 1, 
           palette = {
                'B':                    '#FF7F0E',  # Orange
                'CD4':                  '#945943',  # Brown
                'CD8':                  '#F9D090',  # Beige-Yellow
                'NK':                   '#C70039',  # Crimson-Red
                'NKT':                  '#B58A58',  # Teak (brown)
           })


In [None]:
# One UMAP panel per 'infection' category, coloured by 'C_scANVI'.
# NOTE(review): split_umap's grid parameter is `ncol`; `ncols = 4` is forwarded to
# sc.pl.umap instead, so the grid keeps its default of 2 columns — confirm the intent.
# NOTE(review): the palette keys are lymphoid labels — confirm they match C_scANVI here.
split_umap(adata, color = ['C_scANVI'], split_by = 'infection', frameon = False, size = 12, legend_fontsize = 5, ncols = 4,            
           palette = {
                'B':                    '#FF7F0E',  # Orange
                'CD4':                  '#945943',  # Brown
                'CD8':                  '#F9D090',  # Beige-Yellow
                'NK':                   '#C70039',  # Crimson-Red
                'NKT':                  '#B58A58',  # Teak (brown)
           })

In [None]:
# One UMAP panel per 'condition' category, coloured by 'C_scANVI'.
# NOTE(review): split_umap's grid parameter is `ncol`; `ncols = 4` is forwarded to
# sc.pl.umap instead, so the grid keeps its default of 2 columns — confirm the intent.
# NOTE(review): the palette keys are lymphoid labels — confirm they match C_scANVI here.
split_umap(adata, color = ['C_scANVI'], split_by = 'condition', frameon = False, size = 20, legend_fontsize = 5, ncols = 4, palette = {
                'B':                    '#FF7F0E',  # Orange
                'CD4':                  '#945943',  # Brown
                'CD8':                  '#F9D090',  # Beige-Yellow
                'NK':                   '#C70039',  # Crimson-Red
                'NKT':                  '#B58A58',  # Teak (brown)
           })

In [None]:
# One UMAP panel per 'genotype' category, coloured by 'C_scANVI'.
# NOTE(review): split_umap's grid parameter is `ncol`; `ncols = 4` is forwarded to
# sc.pl.umap instead, so the grid keeps its default of 2 columns — confirm the intent.
# NOTE(review): the palette keys are lymphoid labels — confirm they match C_scANVI here.
split_umap(adata, color = ['C_scANVI'], split_by = 'genotype', frameon = False, size = 17, legend_fontsize = 5, ncols = 4, palette = {
                'B':                    '#FF7F0E',  # Orange
                'CD4':                  '#945943',  # Brown
                'CD8':                  '#F9D090',  # Beige-Yellow
                'NK':                   '#C70039',  # Crimson-Red
                'NKT':                  '#B58A58',  # Teak (brown)
           })

In [None]:
# One UMAP panel per 'model' category, coloured by 'C_scANVI'.
# NOTE(review): split_umap's grid parameter is `ncol`; `ncols = 4` is forwarded to
# sc.pl.umap instead, so the grid keeps its default of 2 columns — confirm the intent.
# NOTE(review): the palette keys are lymphoid labels — confirm they match C_scANVI here.
split_umap(adata, color = ['C_scANVI'], split_by = 'model', frameon = False, size = 10, legend_fontsize = 5, ncols = 4, palette = {
                'B':                    '#FF7F0E',  # Orange
                'CD4':                  '#945943',  # Brown
                'CD8':                  '#F9D090',  # Beige-Yellow
                'NK':                   '#C70039',  # Crimson-Red
                'NKT':                  '#B58A58',  # Teak (brown)
           })

In [None]:
# Cell counts per (condition, C_scANVI label).
df = adata_toplot.obs.groupby(['condition', 'C_scANVI']).size().reset_index(name = 'counts')

# Share of each label within its condition, in percent. transform('sum') returns the
# per-condition total aligned to df's own index, so the percentages do not depend on
# the row order of an intermediate result.
condition_totals = df.groupby('condition')['counts'].transform('sum')
df['proportions'] = df['counts'] / condition_totals * 100

# Ten waffle squares per percentage point. A condition without any cell gives
# 0/0 = NaN, which cannot be cast to int; count it as zero squares.
df['waffle_counts'] = (df['proportions'].fillna(0) * 10).astype(int)

In [None]:
# Five-colour listed colormap, used by the waffle plots in the next cell.
from matplotlib.colors import ListedColormap
cmap_ac = ListedColormap(['#FF7F0E', '#945943', '#F9D090', '#C70039', '#B58A58']) 
cmap_ac

In [None]:
# One waffle chart per condition, showing the C_scANVI label proportions.
# NOTE(review): `Waffle` is not imported anywhere in the visible notebook — add its
# import before running this cell.
tab20_palette = cmap_ac 


for group in df['condition'].unique():
    temp_df = df[df['condition'] == group]
    
    # label -> number of squares, plus one colour per label taken from the colormap by position
    data = dict(zip(temp_df['C_scANVI'], temp_df['waffle_counts']))
    colors = [tab20_palette(i) for i in range(len(temp_df['C_scANVI']))]
    fig = plt.figure(
        FigureClass = Waffle, 
        rows = 10, 
        values = data, 
        title = {'label': f'Condition {group}', 'loc': 'left', 'fontsize': 14},
        labels = [f"{k} ({v}%)" for k, v in zip(temp_df['C_scANVI'], temp_df['proportions'].round(2))],
        #legend = {'loc': 'lower left', 'bbox_to_anchor': (0, -0.4), 'ncol': len(data), 'framealpha': 0},
        legend = {'loc': 'upper left', 'bbox_to_anchor': (0, 0), 'ncol': 8, 'framealpha': 0, 'fontsize': 14},
        figsize = (40, 4),
        colors = colors
    )
    plt.show()

In [None]:
pd.crosstab(adata.obs['C_scANVI'], adata.obs['condition'])

### Visualize n_genes by condition to be able to set a threshold for trimming out low-gene-clusters

In [None]:
adata.obs['C_scANVI'].cat.categories

In [None]:
# Histogram of n_genes_by_counts, split by C_scANVI label.
sns.set(style="whitegrid")
covariate_to_visualize = 'n_genes_by_counts'

plt.figure(figsize=(10, 6))
hist_ax = sns.histplot(data=adata.obs, x=covariate_to_visualize, hue='C_scANVI', stat='count', common_norm=False)
plt.xlabel(covariate_to_visualize)
plt.ylabel('Abundance')
# The hue variable is the cell-type label, not the condition
plt.title(f'Abundance Plot of {covariate_to_visualize} by Cell Type')
# Keep the legend seaborn builds from the hue levels and only move and title it.
# A hand-written label list can disagree with the hue categories and mislabel the bars.
sns.move_legend(hist_ax, 'upper right', title='Cell_Type')

plt.show()

In [None]:
# One histogram of n_genes_by_counts per C_scANVI label, three facets per row.
sns.set(style="whitegrid")
covariate_to_visualize = 'n_genes_by_counts'

# One 'Set1' colour per distinct condition.
# NOTE(review): the grid is built without a hue variable, so this palette may have no effect — confirm.
sample_names = adata.obs['condition'].unique()
num_samples = len(sample_names)
color_palette = sns.color_palette("Set1", n_colors=num_samples)

g = sns.FacetGrid(adata.obs, col="C_scANVI", col_wrap=3, height=5, palette=color_palette)
g.map_dataframe(sns.histplot, x=covariate_to_visualize, stat='count', common_norm=False)

g.set_axis_labels(covariate_to_visualize, 'Abundance')
g.set_titles(col_template="{col_name}")
#g.add_legend(['B_cells', 'CD4+T', 'CD8+T', 'DC', 'Hematopoetic', 'Macrophages', 'Mast_cells', 'Monocytes', 'NK', 'NKT', 'Neutrophils', 'Plasma_cells', 'Platelets', 'Treg', 'pDC'], title='Cell_Type', loc='upper right')

plt.tight_layout()
plt.show()

Set threshold at 500 genes?

### Export annotated sample object 

In [None]:
adata.obs.index

In [None]:
# Protein modality of the loaded MuData object and its cell IDs.
prot = mdata.mod['prot']
prot.obs_names

In [None]:
adata.obs['C_scANVI'].cat.categories

In [None]:
adata.obs['C_scANVI'].value_counts()

### Export annotated object with raw counts

In [None]:
adata

In [None]:
adata_raw

In [None]:
# Export object: the matrix and var table of adata_raw with the annotations of adata,
# plus copies of the scVI, UMAP and scANVI embeddings.
# AnnData is reached through the `ad` alias; the bare name `anndata` is never
# imported in this notebook, so using it raises a NameError.
adata_export = ad.AnnData(X = adata_raw.X, obs = adata.obs, var = adata_raw.var)
for embedding_key in ('X_scVI', 'X_umap', 'X_scANVI'):
    adata_export.obsm[embedding_key] = adata.obsm[embedding_key].copy()
adata_export

#### Update the mdata object

In [None]:
# Replace the RNA modality of mdata with the export object.
mdata.mod['rna'] = adata_export
mdata

In [None]:
mdata.mod['rna'].obs_names

In [None]:
mdata.mod['prot'].obs_names

In [None]:
# Restrict the protein modality to the cells present in the RNA modality,
# then rebuild the MuData object from both.
rna_mod = mdata.mod['rna']
prot_mod = mdata.mod['prot']
rna_cells_export = set(rna_mod.obs_names)
mask = prot_mod.obs_names.isin(rna_cells_export)
filtered_prot_export = prot_mod[mask]
mdata = mu.MuData({"rna": rna_mod, "prot": filtered_prot_export})
mdata

In [None]:
# Write the updated MuData object to disk.
# NOTE(review): the target is a *lymphoid* scANVI file in the '4_Seed_labeling_with_scANVI'
# folder, while this notebook loads and filters myeloid/monocyte data — confirm the path
# before running, since an existing file there may be overwritten.
mdata.write('/home/acirnu/data/ACM_cardiac_leuco/4_Seed_labeling_with_scANVI/ACM_lymphoids_scANVI_general_celltypes_from_HCA_lymphoids_muon_ac240506.raw.h5mu')

#### Create stacked barplots in addition to waffle plots

In [None]:
# Session settings for the barplot section: the same scanpy / warnings / scvi / torch
# settings as at the top of the notebook, here with 'magma_r' as the colour map.
sc.settings.verbosity = 3
sc.logging.print_versions()
sc.settings.set_figure_params(dpi = 180, color_map = 'magma_r', dpi_save = 300, vector_friendly = True, format = 'svg')
warnings.simplefilter(action = 'ignore')
scvi.settings.seed = 1712
%config InlineBackend.print_figure_kwargs = {'facecolor' : "w"}
%config InlineBackend.figure_format = 'retina'
torch.set_float32_matmul_precision('medium')
# Model architecture keyword arguments collected in one mapping.
arches_params = dict(
    use_layer_norm = "both",
    use_batch_norm = "none",
    encode_covariates = True,
    dropout_rate = 0.2,
    n_layers = 3,
)

In [None]:
# Reload the MuData object from the same path the export cell writes to.
mdata = mu.read_h5mu('/home/acirnu/data/ACM_cardiac_leuco/4_Seed_labeling_with_scANVI/ACM_lymphoids_scANVI_general_celltypes_from_HCA_lymphoids_muon_ac240506.raw.h5mu')     

In [None]:
# Work on the RNA modality of the reloaded object.
adata = mdata.mod['rna']
adata

In [None]:
# Cell counts per (condition, C_scANVI label).
df = adata.obs.groupby(['condition', 'C_scANVI']).size().reset_index(name = 'counts')

# Share of each label within its condition, in percent. transform('sum') returns the
# per-condition total aligned to df's own index, so the percentages do not depend on
# the row order of an intermediate result.
condition_totals = df.groupby('condition')['counts'].transform('sum')
df['proportions'] = df['counts'] / condition_totals * 100
df.head(10)

In [None]:
# Keep the four Pkp2 conditions and order the rows as listed here.
Pkp2_conditions = ["Pkp2_Ctr_noninf", "Pkp2_HetKO_noninf","Pkp2_Ctr_MCMV", "Pkp2_HetKO_MCMV"]
df_Pkp2 = df[df['condition'].isin(Pkp2_conditions)]
# .loc with the list reorders the rows (and raises if a listed condition is absent)
df_Pkp2 = df_Pkp2.set_index('condition').loc[Pkp2_conditions].reset_index()

In [None]:
# Colours for the stacked bars, assigned to the C_scANVI columns in pivot order
cmap_ac = ['#FF7F0E', '#945943', '#F9D090', '#C70039', '#B58A58']  

# One row per condition, one column per C_scANVI label. Pivoting can return the
# rows in sorted order, so reindex restores the order given in Pkp2_conditions.
pivot_df = (
    df_Pkp2.pivot(index='condition', columns='C_scANVI', values='proportions')
    .reindex(Pkp2_conditions)
    .fillna(0)
)

# Draw on an explicitly created 7x7 inch axes. Calling plt.figure() and then
# DataFrame.plot without `ax` opens a second figure and leaves the first one empty.
fig, ax = plt.subplots(figsize=(7, 7))
pivot_df.plot(kind='bar', stacked=True, color=cmap_ac, edgecolor='none', ax=ax)

# Remove the grid
ax.grid(False)

# Axis labels, title and legend
plt.xlabel('')
plt.ylabel('Proportion (%)')
plt.title('')
plt.legend(title='Cluster', bbox_to_anchor=(1.05, 1), loc='upper left')

# Save and show the plot
plt.tight_layout()
plt.savefig('stacked_barplot_Pkp2.png', dpi=300)
plt.show()


In [None]:
# Keep the four Ttn conditions and order the rows as listed here.
Ttn_conditions = ["Ttn_Ctr_noninf", "Ttn_HetKO_noninf","Ttn_Ctr_MCMV", "Ttn_HetKO_MCMV"]
df_Ttn = df[df['condition'].isin(Ttn_conditions)]
# .loc with the list reorders the rows (and raises if a listed condition is absent)
df_Ttn = df_Ttn.set_index('condition').loc[Ttn_conditions].reset_index()

In [None]:
# Colours for the stacked bars, assigned to the C_scANVI columns in pivot order
cmap_ac = ['#FF7F0E', '#945943', '#F9D090', '#C70039', '#B58A58']  

# One row per condition, one column per C_scANVI label. Pivoting can return the
# rows in sorted order, so reindex restores the order given in Ttn_conditions.
pivot_df = (
    df_Ttn.pivot(index='condition', columns='C_scANVI', values='proportions')
    .reindex(Ttn_conditions)
    .fillna(0)
)

# Draw on an explicitly created 7x7 inch axes. Calling plt.figure() and then
# DataFrame.plot without `ax` opens a second figure and leaves the first one empty.
fig, ax = plt.subplots(figsize=(7, 7))
pivot_df.plot(kind='bar', stacked=True, color=cmap_ac, edgecolor='none', ax=ax)

# Remove the grid
ax.grid(False)

# Axis labels, title and legend
plt.xlabel('')
plt.ylabel('Proportion (%)')
plt.title('')
plt.legend(title='Cluster', bbox_to_anchor=(1.05, 1), loc='upper left')

# Save and show the plot
plt.tight_layout()
plt.savefig('stacked_barplot_Ttn.png', dpi=300)
plt.show()