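# Step 5b: Perturbation Simulation

Simulates a knockout (expression set to 0) of the TF of interest (`FT_OF_INTEREST`) with SCENIC+ gene-expression models, and visualises the predicted effect on PCA / UMAP / WNN-UMAP embeddings of the eRegulon AUC scores.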

## Dependencies

In [None]:
#%pip install velocyto

In [None]:
import os
import sys
import warnings

import adjustText
import anndata
import matplotlib
import matplotlib.pyplot as plt
import mudata
import numpy as np
import pandas as pd
import scanpy as sc
import sys

In [None]:
from scenicplus.simulation import (
    train_gene_expression_models,
    simulate_perturbation,
    plot_perturbation_effect_in_embedding
)

In [None]:
# Determine the folder in which the code is executed
WORKING_DIR = os.getcwd()
sys.path.append(os.path.abspath( WORKING_DIR))

# Run the params codes
%run -i ../../globalParams.py #GlobalParams
%run -i ../../sampleParams.py #sampleParams
%run -i ./analysisParams.py #AnalysisParams

In [None]:
# Render matplotlib figures inline in the notebook output
%matplotlib inline

In [None]:
# Load the SCENIC+ MuData object from the 04d output folder
scplus_mdata_path = os.path.join(PATH_TO_THE_04d_OUPUT_FOLDER, "outs/scplusmdata.h5mu")
scplus_mdata = mudata.read(scplus_mdata_path)

In [None]:
# Quick look at the loaded MuData object
scplus_mdata

In [None]:
# Combine the direct and extended gene-based eRegulon AUC modalities into one
# AnnData by concatenating along the variable axis (axis=1).
eRegulon_gene_AUC = anndata.concat(
    [scplus_mdata["direct_gene_based_AUC"], scplus_mdata["extended_gene_based_AUC"]],
    axis = 1,
)
# Attach the cell metadata aligned by barcode rather than by position: assigning
# .obs replaces obs_names with the frame's index, so an order mismatch would
# silently relabel cells. .copy() also stops later in-place edits (e.g. scanpy
# turning string columns into categoricals when plotting) from leaking back into
# scplus_mdata.obs. Raises KeyError if a cell is missing from scplus_mdata.obs.
eRegulon_gene_AUC.obs = scplus_mdata.obs.loc[eRegulon_gene_AUC.obs_names].copy()


In [None]:
# Display the concatenated direct + extended gene-based AUC AnnData
eRegulon_gene_AUC

## PCA Visualization

In [None]:
# PCA of the eRegulon AUC matrix; the PCA plot below reads the coordinates
# from eRegulon_gene_AUC.obsm["X_pca"].
sc.pp.pca(eRegulon_gene_AUC)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors  # Correct import for color conversion

# Assign one colour per cell-type label, in order of first appearance.
# plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9, so use the
# colormap registry; for <= 10 labels this gives exactly the colours of
# get_cmap('tab10', n). tab10 has only 10 distinct colours, so with more labels the
# old code silently reused colours - fall back to tab20, then to a continuous map.
unique_values = pd.unique(eRegulon_gene_AUC.obs["scRNA_counts:" + CELL_TYPE_COLNAME])
n_labels = len(unique_values)
if n_labels <= 10:
    cmap_name = "tab10"
elif n_labels <= 20:
    cmap_name = "tab20"
else:
    cmap_name = "turbo"
colormap = matplotlib.colormaps[cmap_name].resampled(max(n_labels, 1))

# Create the color dictionary (label -> hex colour)
color_dict_line = {unique_value: mcolors.rgb2hex(colormap(i)) for i, unique_value in enumerate(unique_values)}

# Display the color dictionary
print(color_dict_line)

In [None]:
def plot_mm_line_pca(ax):
    """Scatter cells on the first two PCs (eRegulon_gene_AUC.obsm["X_pca"]),
    coloured by cell type via color_dict_line, with one bold label per cell type
    at its centroid (label positions de-overlapped by adjustText)."""
    label_col = "scRNA_counts:" + CELL_TYPE_COLNAME
    coords = eRegulon_gene_AUC.obsm["X_pca"]
    cell_labels = eRegulon_gene_AUC.obs[label_col]

    ax.scatter(
        coords[:, 0],
        coords[:, 1],
        color = [color_dict_line[lbl] for lbl in cell_labels]
    )

    annotations = []
    for lbl in set(cell_labels):
        centroid_x, centroid_y = coords[(cell_labels == lbl).to_numpy(), 0:2].mean(0)
        annotations.append(ax.text(centroid_x, centroid_y, lbl, fontweight = "bold"))
    adjustText.adjust_text(annotations)

fig, ax = plt.subplots()
plot_mm_line_pca(ax)

In [None]:
# Map each Gene to the list of TFs paired with it in the direct or extended
# eRegulon metadata (unique Gene/TF pairs only).
gene_tf_pairs = [
    scplus_mdata.uns[metadata_key][["Gene", "TF"]].drop_duplicates()
    for metadata_key in ("direct_e_regulon_metadata", "extended_e_regulon_metadata")
]
gene_tf_direct_extended = pd.concat(gene_tf_pairs).drop_duplicates()
gene_to_TF = gene_tf_direct_extended.groupby("Gene")["TF"].apply(list).to_dict()

## UMAP Visualization

In [None]:
# kNN graph built directly on the AUC matrix (.X), then UMAP, coloured by cell type
label_key = "scRNA_counts:" + CELL_TYPE_COLNAME
sc.pp.neighbors(eRegulon_gene_AUC, use_rep="X")
sc.tl.umap(eRegulon_gene_AUC)
sc.pl.umap(eRegulon_gene_AUC, color=label_key)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors  # Correct import for color conversion

# Assign one colour per cell-type label, in order of first appearance.
# plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9, so use the
# colormap registry; for <= 10 labels this gives exactly the colours of
# get_cmap('tab10', n). tab10 has only 10 distinct colours, so with more labels the
# old code silently reused colours - fall back to tab20, then to a continuous map.
unique_values = pd.unique(eRegulon_gene_AUC.obs["scRNA_counts:" + CELL_TYPE_COLNAME])
n_labels = len(unique_values)
if n_labels <= 10:
    cmap_name = "tab10"
elif n_labels <= 20:
    cmap_name = "tab20"
else:
    cmap_name = "turbo"
colormap = matplotlib.colormaps[cmap_name].resampled(max(n_labels, 1))

# Create the color dictionary (label -> hex colour)
color_dict_line = {unique_value: mcolors.rgb2hex(colormap(i)) for i, unique_value in enumerate(unique_values)}

# Display the color dictionary
print(color_dict_line)

In [None]:
def plot_mm_line_umap(ax):
    """Scatter cells on eRegulon_gene_AUC.obsm["X_umap"], coloured by cell type via
    color_dict_line, with one bold label per cell type at its centroid
    (label positions de-overlapped by adjustText)."""
    label_col = "scRNA_counts:" + CELL_TYPE_COLNAME
    coords = eRegulon_gene_AUC.obsm["X_umap"]
    cell_labels = eRegulon_gene_AUC.obs[label_col]

    ax.scatter(
        coords[:, 0],
        coords[:, 1],
        color = [color_dict_line[lbl] for lbl in cell_labels]
    )

    annotations = []
    for lbl in set(cell_labels):
        centroid_x, centroid_y = coords[(cell_labels == lbl).to_numpy(), 0:2].mean(0)
        annotations.append(ax.text(centroid_x, centroid_y, lbl, fontweight = "bold"))
    adjustText.adjust_text(annotations)

fig, ax = plt.subplots()
plot_mm_line_umap(ax)

In [None]:
# Map each Gene to the list of TFs paired with it in the direct or extended
# eRegulon metadata (unique Gene/TF pairs only).
gene_tf_pairs = [
    scplus_mdata.uns[metadata_key][["Gene", "TF"]].drop_duplicates()
    for metadata_key in ("direct_e_regulon_metadata", "extended_e_regulon_metadata")
]
gene_tf_direct_extended = pd.concat(gene_tf_pairs).drop_duplicates()
gene_to_TF = gene_tf_direct_extended.groupby("Gene")["TF"].apply(list).to_dict()

## WNN_UMAP Visualization

In [None]:
# File name of the embedding table read in the next cell (not defined in this
# notebook - presumably set by one of the %run params scripts)
TABLE_EMBEDDING_OF_INTEREST

In [None]:
# Inject the WNN_UMAP values in place of UMAP.
# Read the CSV of the WNN_UMAP embedding (first column: cell names, next two: coordinates).
WNN_UMAP_TABLE = pd.read_csv(os.path.join(PATH_EXPERIMENT_OUTPUT, ANALYSIS_04a0_STEP_NAME, "Embeddings", TABLE_EMBEDDING_OF_INTEREST))

# Step 1: rename the first three columns positionally, whatever the CSV calls them
column_mapping = {
    WNN_UMAP_TABLE.columns[0]: 'cell_names',
    WNN_UMAP_TABLE.columns[1]: 'WNNUMAP_1',
    WNN_UMAP_TABLE.columns[2]: 'WNNUMAP_2'
}
WNN_UMAP_TABLE = WNN_UMAP_TABLE.rename(columns=column_mapping)


def _to_scplus_cell_name(name):
    """Convert "<sample>_<rest>" to "<rest with '_' -> '-'>-<sample>___<sample>",
    the naming format of eRegulon_gene_AUC.obs_names."""
    sample, *rest = name.split('_')
    return '-'.join(rest) + '-' + sample + '___' + sample


# Step 2: convert the names in 'cell_names' to the eRegulon_gene_AUC.obs_names format
WNN_UMAP_TABLE['cell_names'] = WNN_UMAP_TABLE['cell_names'].apply(_to_scplus_cell_name)

# Step 3: keep only rows whose cell is present in eRegulon_gene_AUC
WNN_UMAP_TABLE_filtered = WNN_UMAP_TABLE[WNN_UMAP_TABLE['cell_names'].isin(eRegulon_gene_AUC.obs_names)]

# Step 4: reorder to match eRegulon_gene_AUC.obs_names. Cells absent from the table
# get NaN rows here (and NaN positions once injected into X_umap), so report them
# instead of letting them pass silently.
WNN_UMAP_TABLE_ordered = WNN_UMAP_TABLE_filtered.set_index('cell_names').reindex(eRegulon_gene_AUC.obs_names).reset_index()
n_missing = int(WNN_UMAP_TABLE_ordered[['WNNUMAP_1', 'WNNUMAP_2']].isna().any(axis=1).sum())
if n_missing:
    warnings.warn(
        f"{n_missing} of {eRegulon_gene_AUC.n_obs} cells have no WNN_UMAP coordinates "
        "(cell-name conversion mismatch or cells absent from the embedding table); "
        "their X_umap rows will be NaN."
    )

# Check the reordered table
print(WNN_UMAP_TABLE_ordered)

In [None]:
# Overwrite eRegulon_gene_AUC.obsm["X_umap"] with the WNN_UMAP coordinates, which
# the previous cell row-aligned to eRegulon_gene_AUC.obs_names; every UMAP plot
# below therefore shows the WNN embedding.
wnn_umap_values = WNN_UMAP_TABLE_ordered.loc[:, ['WNNUMAP_1', 'WNNUMAP_2']].to_numpy()
eRegulon_gene_AUC.obsm["X_umap"] = wnn_umap_values

# Confirm the injected coordinates
print(eRegulon_gene_AUC.obsm["X_umap"])

In [None]:
# Define the visualisation function (now drawing the injected WNN_UMAP coordinates)
def plot_mm_line_umap(ax):
    """Scatter cells on eRegulon_gene_AUC.obsm["X_umap"], coloured by cell type via
    color_dict_line, with one bold label per cell type at its centroid
    (label positions de-overlapped by adjustText)."""
    label_col = "scRNA_counts:" + CELL_TYPE_COLNAME
    coords = eRegulon_gene_AUC.obsm["X_umap"]
    cell_labels = eRegulon_gene_AUC.obs[label_col]

    ax.scatter(
        coords[:, 0],
        coords[:, 1],
        color = [color_dict_line[lbl] for lbl in cell_labels]
    )

    annotations = []
    for lbl in set(cell_labels):
        centroid_x, centroid_y = coords[(cell_labels == lbl).to_numpy(), 0:2].mean(0)
        annotations.append(ax.text(centroid_x, centroid_y, lbl, fontweight = "bold"))
    adjustText.adjust_text(annotations)

# Have a look at the figures
fig, ax = plt.subplots()
plot_mm_line_umap(ax)

In [None]:
# Improved dataviz
def plot_mm_line_umap(ax, eRegulon_gene_AUC, CELL_TYPE_COLNAME, color_dict_line, point_size=1, legend_loc='center left'):
    """Scatter cells on eRegulon_gene_AUC.obsm["X_umap"], coloured by cell type.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    eRegulon_gene_AUC : AnnData with obsm["X_umap"] and obs["scRNA_counts:" + CELL_TYPE_COLNAME].
    CELL_TYPE_COLNAME : cell-type column name without the "scRNA_counts:" prefix; also the legend title.
    color_dict_line : mapping label -> colour; every label present needs an entry (KeyError otherwise).
    point_size : marker size passed to ``ax.scatter(s=...)``.
    legend_loc : legend ``loc``; the legend is anchored just right of the axes.

    Returns
    -------
    The same ``ax``.
    """
    labels = eRegulon_gene_AUC.obs["scRNA_counts:" + CELL_TYPE_COLNAME]
    coords = eRegulon_gene_AUC.obsm["X_umap"]
    # Labels in order of first appearance: unlike set(), this is stable across
    # sessions (string hashing is randomised per process), so the legend is reproducible.
    unique_lines = list(pd.unique(labels))

    ax.scatter(
        coords[:, 0],
        coords[:, 1],
        c=[color_dict_line[line] for line in labels],
        s=point_size,
        alpha=0.7  # some transparency against overplotting
    )

    # One bold label at each cell type's centroid
    texts = []
    for line in unique_lines:
        line_bc_idc = np.where(labels == line)[0]
        avg_x, avg_y = coords[line_bc_idc, 0:2].mean(0)
        texts.append(ax.text(avg_x, avg_y, line, fontweight="bold", fontsize=8))

    # De-overlap the labels on *this* axes (adjust_text otherwise uses plt.gca())
    adjustText.adjust_text(texts, ax=ax, arrowprops=dict(arrowstyle='->', color='black'))

    # Legend with one marker per label
    legend_elements = [plt.Line2D([0], [0], marker='o', color='w',
                                  markerfacecolor=color_dict_line[line], markersize=8, label=line)
                       for line in unique_lines]
    ax.legend(handles=legend_elements, loc=legend_loc, bbox_to_anchor=(1.05, 0.5),
              title=CELL_TYPE_COLNAME, title_fontsize='large')

    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    ax.set_title('UMAP Plot')

    # Tighten the figure that owns ``ax`` (plt.tight_layout() acts on the current figure)
    ax.figure.tight_layout()

    return ax

# Usage
fig, ax = plt.subplots(figsize=(12, 8))  # wide figure leaves room for the outside legend
plot_mm_line_umap(ax, eRegulon_gene_AUC, CELL_TYPE_COLNAME, color_dict_line)
plt.show()

In [None]:
# Show the current cell-type label -> hex colour mapping
color_dict_line

In [None]:
# Peek at labels 5-10 (order of first appearance), e.g. when building a custom palette
group_column = "scRNA_counts:" + CELL_TYPE_COLNAME
the_values_to_see = pd.unique(eRegulon_gene_AUC.obs[group_column])
print(the_values_to_see[5:11])

In [None]:
# Create color dictionary from the custom palette
# (cluster_names / color_codes are not defined in this notebook - presumably they
# come from one of the %run params scripts)
color_dict_line = dict(zip(cluster_names, color_codes))

def plot_mm_line_umap(ax, eRegulon_gene_AUC, CELL_TYPE_COLNAME, color_dict_line, point_size= 1.5, legend_loc='center left'):
    """Scatter cells on eRegulon_gene_AUC.obsm["X_umap"] using a possibly partial palette.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    eRegulon_gene_AUC : AnnData with obsm["X_umap"] and obs["scRNA_counts:" + CELL_TYPE_COLNAME].
    CELL_TYPE_COLNAME : cell-type column name without the "scRNA_counts:" prefix; also the legend title.
    color_dict_line : mapping label -> colour. Labels without an entry are drawn black
        ("#000000") and left out of the legend; the legend follows this dict's order.
    point_size : marker size passed to ``ax.scatter(s=...)``.
    legend_loc : legend ``loc``; the legend is anchored just right of the axes.

    Returns
    -------
    The same ``ax``.
    """
    labels = eRegulon_gene_AUC.obs["scRNA_counts:" + CELL_TYPE_COLNAME]
    coords = eRegulon_gene_AUC.obsm["X_umap"]
    # Built once; the legend filter previously rebuilt set(labels) for every palette entry
    present_labels = set(labels)

    ax.scatter(
        coords[:, 0],
        coords[:, 1],
        c=[color_dict_line.get(line, "#000000") for line in labels],
        s=point_size,
        alpha=0.7
    )

    # One bold label at each cell type's centroid; order of first appearance keeps
    # the drawing order reproducible across sessions (set order is not).
    texts = []
    for line in pd.unique(labels):
        line_bc_idc = np.where(labels == line)[0]
        avg_x, avg_y = coords[line_bc_idc, 0:2].mean(0)
        texts.append(ax.text(avg_x, avg_y, line, fontweight="bold", fontsize=8))

    # De-overlap the labels on *this* axes (adjust_text otherwise uses plt.gca())
    adjustText.adjust_text(texts, ax=ax, arrowprops=dict(arrowstyle='->', color='black'))

    # Legend: palette entries that actually occur in the data, in palette order
    legend_elements = [plt.Line2D([0], [0], marker='o', color='w',
                                  markerfacecolor=color_dict_line[line], markersize=8, label=line)
                       for line in color_dict_line if line in present_labels]
    ax.legend(handles=legend_elements, loc=legend_loc, bbox_to_anchor=(1.05, 0.5),
              title=CELL_TYPE_COLNAME, title_fontsize='large')

    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    ax.set_title('UMAP Plot')

    # Tighten the figure that owns ``ax`` (plt.tight_layout() acts on the current figure)
    ax.figure.tight_layout()

    return ax

# Usage
fig, ax = plt.subplots(figsize=(12, 8))  # wide figure leaves room for the outside legend
plot_mm_line_umap(ax, eRegulon_gene_AUC, CELL_TYPE_COLNAME, color_dict_line)
plt.show()

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors  # Correct import for color conversion

# Assign one colour per cell-type label, in order of first appearance.
# plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9, so use the
# colormap registry; for <= 10 labels this gives exactly the colours of
# get_cmap('tab10', n). tab10 has only 10 distinct colours, so with more labels the
# old code silently reused colours - fall back to tab20, then to a continuous map.
unique_values = pd.unique(eRegulon_gene_AUC.obs["scRNA_counts:" + CELL_TYPE_COLNAME])
n_labels = len(unique_values)
if n_labels <= 10:
    cmap_name = "tab10"
elif n_labels <= 20:
    cmap_name = "tab20"
else:
    cmap_name = "turbo"
colormap = matplotlib.colormaps[cmap_name].resampled(max(n_labels, 1))

# Create the color dictionary (label -> hex colour)
color_dict_line = {unique_value: mcolors.rgb2hex(colormap(i)) for i, unique_value in enumerate(unique_values)}

# Display the color dictionary
print(color_dict_line)

In [None]:
#Old version

#def plot_mm_line_umap(ax):
#    texts = []
#    # Plot UMAP
#    ax.scatter(
#        eRegulon_gene_AUC.obsm["X_umap"][:, 0],
#        eRegulon_gene_AUC.obsm["X_umap"][:, 1],
#        color = [color_dict_line[line] for line in eRegulon_gene_AUC.obs["scRNA_counts:"+ CELL_TYPE_COLNAME]]
#    )
    # Plot labels
#    for line in set(eRegulon_gene_AUC.obs["scRNA_counts:"+ CELL_TYPE_COLNAME]):
#        line_bc_idc = np.arange(len(eRegulon_gene_AUC.obs_names))[eRegulon_gene_AUC.obs["scRNA_counts:"+ CELL_TYPE_COLNAME] == line]
#        avg_x, avg_y = eRegulon_gene_AUC.obsm["X_umap"][line_bc_idc, 0:2].mean(0)
#        texts.append(
#            ax.text(
#                avg_x,
#                avg_y,
#                line,
#                fontweight = "bold"
#            )
#        )
#    adjustText.adjust_text(texts)

#fig, ax = plt.subplots()
#plot_mm_line_umap(ax)

In [None]:
# Map each Gene to the list of TFs paired with it in the direct or extended
# eRegulon metadata (unique Gene/TF pairs only).
gene_tf_pairs = [
    scplus_mdata.uns[metadata_key][["Gene", "TF"]].drop_duplicates()
    for metadata_key in ("direct_e_regulon_metadata", "extended_e_regulon_metadata")
]
gene_tf_direct_extended = pd.concat(gene_tf_pairs).drop_duplicates()
gene_to_TF = gene_tf_direct_extended.groupby("Gene")["TF"].apply(list).to_dict()


## Plot with TF-driven modifications

## Perturbation analysis

In [None]:
# gene_to_TF is keyed by Gene, so this lists the TFs paired with FT_OF_INTEREST
# as a *target* Gene in the eRegulon metadata (KeyError if it is not a Gene there).
gene_to_TF[FT_OF_INTEREST]

In [None]:
# Use a subset of genes, just so the notebook runs fast: the genes of the 5,000
# direct-eRegulon rows with the lowest triplet_rank, de-duplicated (so at most
# 5,000 genes).
N_TOP_TRIPLETS = 5_000
direct_metadata = scplus_mdata.uns["direct_e_regulon_metadata"]
ranked_genes = direct_metadata.sort_values("triplet_rank")["Gene"]
genes_to_use = ranked_genes.iloc[:N_TOP_TRIPLETS].drop_duplicates()

In [None]:
# Train gene-expression models for genes_to_use on the scRNA-seq counts, with
# gene_to_TF presumably supplying each gene's candidate TF predictors.
expression_df = scplus_mdata["scRNA_counts"].to_df()
regressors = train_gene_expression_models(
    df_EXP=expression_df,
    gene_to_TF=gene_to_TF,
    genes=genes_to_use,
)

In [None]:
# Simulate setting FT_OF_INTEREST's expression to 0 and propagating the change
# through the trained regressors for 5 iterations; keep_intermediate=True keeps a
# result per iteration so later cells can index it by iteration number.
knockout = {FT_OF_INTEREST: 0}
perturbation_over_iter = simulate_perturbation(
    df_EXP=scplus_mdata["scRNA_counts"].to_df(),
    perturbation=knockout,
    keep_intermediate=True,
    n_iter=5,
    regressors=regressors,
)


In [None]:
# Plot the predicted log2 fold change (vs. iteration 0, the unperturbed matrix) of
# selected genes, averaged over the cells of GROUP_OF_INTEREST, per iteration.

# Set the file path for saving the plot (and make sure its folder exists)
output_path = os.path.join(PATH_ANALYSIS_OUTPUT, "Donwstream_Genes_After_disruption", f"{FT_OF_INTEREST}_perturbation_effect_in_{GROUP_OF_INTEREST}cluster.pdf")
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# NOTE(review): gene_to_TF maps Gene -> TFs, so these are the TFs regulating
# FT_OF_INTEREST, not its downstream targets as the folder name suggests - confirm intent.
genes_to_show = gene_to_TF[FT_OF_INTEREST]
cell_line = GROUP_OF_INTEREST
cell_groups = eRegulon_gene_AUC.obs["scRNA_counts:" + CELL_TYPE_COLNAME]

# Assumes perturbation_over_iter is indexed 0..n_iter (0 = unperturbed input), so the
# iteration count follows the simulation instead of being hard-coded to 5.
n_iterations = len(perturbation_over_iter) - 1
iterations = np.arange(n_iterations) + 1

# Group means of the shown genes in `cell_line`, computed once per iteration
# (previously recomputed for every gene x iteration pair).
baseline = perturbation_over_iter[0].groupby(cell_groups).mean().loc[cell_line, genes_to_show]
perturbed_means = [
    perturbation_over_iter[i].groupby(cell_groups).mean().loc[cell_line, genes_to_show]
    for i in iterations
]

fig, ax = plt.subplots()
for gene in genes_to_show:
    ax.plot(
        iterations,
        [np.log2(means[gene] / baseline[gene]) for means in perturbed_means],
        label = gene
    )
ax.set_ylabel("Predicted $log{_2}FC$")
ax.set_xlabel("Iteration")
ax.set_title(f"{FT_OF_INTEREST} perturbation in {cell_line}")
ax.legend()
ax.axhline(y = 0, color = "black")
# grid()'s first positional argument is `visible`, so ax.grid("gray") only switched
# the grid on and ignored the colour; pass the intended colour explicitly.
ax.grid(True, color = "gray")
ax.set_axisbelow(True)

# Save this figure (not whichever figure happens to be current)
fig.savefig(output_path, format='pdf')

In [None]:
# Create color dictionary from the custom palette
# (cluster_names / color_codes are not defined in this notebook - presumably they
# come from one of the %run params scripts)
color_dict_line = dict(zip(cluster_names, color_codes))

# Set the file path for saving the plot (and make sure its folder exists)
output_path = os.path.join(PATH_ANALYSIS_OUTPUT, "Disrupted_heatmap", f"{FT_OF_INTEREST}_perturbation_effect.pdf")
os.makedirs(os.path.dirname(output_path), exist_ok=True)


def plot_mm_line_umap(ax, eRegulon_gene_AUC, CELL_TYPE_COLNAME, color_dict_line, point_size= 1.5, legend_loc='center left'):
    """Scatter cells on eRegulon_gene_AUC.obsm["X_umap"] using a possibly partial palette.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    eRegulon_gene_AUC : AnnData with obsm["X_umap"] and obs["scRNA_counts:" + CELL_TYPE_COLNAME].
    CELL_TYPE_COLNAME : cell-type column name without the "scRNA_counts:" prefix; also the legend title.
    color_dict_line : mapping label -> colour. Labels without an entry are drawn black
        ("#000000") and left out of the legend; the legend follows this dict's order.
    point_size : marker size passed to ``ax.scatter(s=...)``.
    legend_loc : legend ``loc``; the legend is anchored just right of the axes.

    Returns
    -------
    The same ``ax``.
    """
    labels = eRegulon_gene_AUC.obs["scRNA_counts:" + CELL_TYPE_COLNAME]
    coords = eRegulon_gene_AUC.obsm["X_umap"]
    # Built once; the legend filter previously rebuilt set(labels) for every palette entry
    present_labels = set(labels)

    ax.scatter(
        coords[:, 0],
        coords[:, 1],
        c=[color_dict_line.get(line, "#000000") for line in labels],
        s=point_size,
        alpha=0.7
    )

    # One bold label at each cell type's centroid; order of first appearance keeps
    # the drawing order reproducible across sessions (set order is not).
    texts = []
    for line in pd.unique(labels):
        line_bc_idc = np.where(labels == line)[0]
        avg_x, avg_y = coords[line_bc_idc, 0:2].mean(0)
        texts.append(ax.text(avg_x, avg_y, line, fontweight="bold", fontsize=8))

    # De-overlap the labels on *this* axes (adjust_text otherwise uses plt.gca())
    adjustText.adjust_text(texts, ax=ax, arrowprops=dict(arrowstyle='->', color='black'))

    # Legend: palette entries that actually occur in the data, in palette order
    legend_elements = [plt.Line2D([0], [0], marker='o', color='w',
                                  markerfacecolor=color_dict_line[line], markersize=8, label=line)
                       for line in color_dict_line if line in present_labels]
    ax.legend(handles=legend_elements, loc=legend_loc, bbox_to_anchor=(1.05, 0.5),
              title=CELL_TYPE_COLNAME, title_fontsize='large')

    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    ax.set_title('UMAP Plot')

    # Tighten the figure that owns ``ax`` (plt.tight_layout() acts on the current figure)
    ax.figure.tight_layout()

    return ax

# Usage
fig, ax = plt.subplots(figsize=(12, 8))  # wide figure leaves room for the outside legend
plot_mm_line_umap(ax, eRegulon_gene_AUC, CELL_TYPE_COLNAME, color_dict_line)

# Overlay the predicted perturbation effect (last iteration vs. the unperturbed
# matrix) on the same UMAP axes. Assumes perturbation_over_iter is indexed
# 0..n_iter (0 = unperturbed), so the last iteration is derived, not hard-coded to 5.
last_iteration = len(perturbation_over_iter) - 1
plot_perturbation_effect_in_embedding(
    perturbed_matrix = perturbation_over_iter[last_iteration],
    original_matrix = perturbation_over_iter[0],
    embedding = eRegulon_gene_AUC.obsm["X_umap"][:, 0:2],
    AUC_kwargs = {},
    ax = ax,
    eRegulons = pd.concat(
        [
            scplus_mdata.uns["direct_e_regulon_metadata"],
            scplus_mdata.uns["extended_e_regulon_metadata"]
        ]
    ),
    n_cpu = NUMBER_CPU
)

# Save this figure (not whichever figure happens to be current)
fig.savefig(output_path, format='pdf')