In [86]:
# Import required libraries
import pandas as pd
import numpy as np
from pathlib import Path

# Import custom functions from miic_helper
import sys
sys.path.append('/Users/alichemkhi/Desktop/myProjects/miic_helper/src') # Adjust this path as needed
from process_anndata_object import create_anndata_from_dataframes, save_anndata_to_csv
from compute_MI_wrapper import run_miic_selection

In [115]:
# functions

def get_variable_names(wnt_summary_path, bmp_summary_path):
    # Read the WNT and BMP summary files
    wnt_df = pd.read_csv(wnt_summary_path, sep=' ')
    bmp_df = pd.read_csv(bmp_summary_path, sep=' ')

    # Extract variable names
    bmp_df_varnames = set(bmp_df['var_names'])
    wnt_df_varnames = set(wnt_df['var_names'])

    # Get union of the variable names
    var_names = bmp_df_varnames.union(wnt_df_varnames)

    return var_names


def extract_upper_triangle_mi(miic_results_df, meta_cols):
    
    cleaned_results = miic_results_df.drop(columns=meta_cols)
    
    # Check if the results are square matrix
    assert cleaned_results.shape[1] == cleaned_results.shape[0], "Results matrix must be square"
    
    # Get upper triangle indices (excluding diagonal)
    upper_triangle_indices = np.triu_indices(cleaned_results.shape[0], k=1)
    
    # Extract the upper triangle values
    upper_triangle_values = cleaned_results.values[upper_triangle_indices]
    
    # Extract corresponding variable name pairs
    upper_triangle_var_names = [
        (cleaned_results.index[i], cleaned_results.columns[j])
        for i, j in zip(*upper_triangle_indices)
    ]
    
    return upper_triangle_values, upper_triangle_var_names, cleaned_results


def plot_wnt_bmp_mi_scatter(upper_triangle_df, output_html_path):
    """
    Plots a scatter plot of WNT vs BMP Mutual Information values with distance-based coloring.
    
    Parameters:
    - upper_triangle_df (pd.DataFrame): DataFrame containing 'WNT_MI', 'BMP_MI', 'dist', and 'pair' columns.
    - output_html_path (str): Path to save the interactive HTML plot.
    """
    import plotly.express as px

    # Filter the dataset and add a color column
    filtered_df = upper_triangle_df[(upper_triangle_df['WNT_MI'] > 0) & (upper_triangle_df['BMP_MI'] > 0)].copy()
    filtered_df['color'] = filtered_df['dist'].apply(
        lambda x: 'High Distance (perpendicular >0.1)' if x > 0.1 else 'Low Distance (perpendicular ≤0.1)'
    )


    basename = output_html_path.name.strip('_mi_scatter_colored_unfiltered.html')   
    # Create the scatter plot
    fig = px.scatter(
        filtered_df, 
        x='WNT_MI', 
        y='BMP_MI', 
        color='color',
        color_discrete_map={
            'High Distance (perpendicular >0.1)': 'red',
            'Low Distance (perpendicular ≤0.1)': 'blue'
        },
        hover_data=['pair', 'dist'],
        title=f'{basename} \n Mutual Information (Distance-based Coloring)',
        labels={'WNT_MI': 'WNT Mutual Information', 'BMP_MI': 'BMP Mutual Information'}
    )

    # Add a diagonal line for reference
    fig.add_shape(
        type='line',
        x0=0, y0=0, x1=1.5, y1=1.5,
        line=dict(color='black', width=2, dash='dash'),
        name='Diagonal Line'
    )

    # Update marker properties and show the plot
    fig.update_traces(marker=dict(size=10, opacity=0.7))
    fig.show()

    # Save the plot as an HTML file
    fig.write_html(output_html_path)


def process_mi(adata, cell_types_filter, outdir,metadata_df, var_names):
    # Subset the object to get only t37 and t36
    adata_filtered = adata[adata.obs['time'].isin(['t37', 't36'])]
    # Subset the object to get only the cells with condition "BMP"
    bmp_cells = adata_filtered[adata_filtered.obs['condition'] == 'BMP']
    # Subset the object to get only the cells with condition "WNT"
    wnt_cells = adata_filtered[adata_filtered.obs['condition'] == 'WNT']

    # Subset by cell type
    if cell_types_filter['BMP'] is not None:
        bmp_cells = bmp_cells[bmp_cells.obs['celltype_grouped'].isin(cell_types_filter['BMP'])]
    if cell_types_filter['WNT'] is not None:
        wnt_cells = wnt_cells[wnt_cells.obs['celltype_grouped'].isin(cell_types_filter['WNT'])]

    # Save BMP cells - both counts and metadata
    bmp_count_path, bmp_meta_path = save_anndata_to_csv(
        bmp_cells, 
        outdir / f"bmp_data_{cell_types_filter['BMP'] if cell_types_filter['BMP'] else 'all'}",
        save_counts=True,
        save_metadata=True,
        transpose_counts=False
    )

    # Save WNT cells - both counts and metadata
    wnt_count_path, wnt_meta_path = save_anndata_to_csv(
        wnt_cells,
        outdir / f"wnt_data_{cell_types_filter['WNT'] if cell_types_filter['WNT'] else 'all'}",
        save_counts=True,
        save_metadata=True,
        transpose_counts=False
    )

    # Run MIIC selection for BMP and WNT
    out_path_BMP = str(outdir / f"bmp_data_{cell_types_filter['BMP'] if cell_types_filter['BMP'] else 'all'}_miic_PW_MI_results.csv")
    result_path_BMP = run_miic_selection(
        bmp_count_path,
        bmp_meta_path,
        var_names,
        var_names,  # Selection pool is same as variables of interest
        out_path_BMP
    )

    out_path_WNT = str(outdir / f"wnt_data_{cell_types_filter['WNT'] if cell_types_filter['WNT'] else 'all'}_miic_PW_MI_results.csv")
    result_path_WNT = run_miic_selection(
        wnt_count_path,
        wnt_meta_path,
        var_names,
        var_names,  # Selection pool is same as variables of interest
        out_path_WNT
    )

    # Remove metadata columns from the results
    meta_cols = [var for var in var_names if var in metadata_df.columns]
    print(f"Metadata columns to be removed: {meta_cols}")

    # Load results
    miic_BMP_results = pd.read_csv(result_path_BMP, index_col=0)
    upper_triangle_values_BMP, upper_triangle_var_names_BMP, cleaned_results_BMP = extract_upper_triangle_mi(miic_BMP_results, meta_cols)

    miic_WNT_results = pd.read_csv(result_path_WNT, index_col=0)
    upper_triangle_values_WNT, upper_triangle_var_names_WNT, cleaned_results_WNT = extract_upper_triangle_mi(miic_WNT_results, meta_cols)

    assert upper_triangle_var_names_WNT == upper_triangle_var_names_BMP

    # Create a DataFrame for the upper triangle values
    upper_triangle_df = pd.DataFrame({
        'pair': [pair for pair in upper_triangle_var_names_WNT],
        'WNT_MI': upper_triangle_values_WNT,
        'BMP_MI': upper_triangle_values_BMP
    })

    # Correct formula for distance to diagonal line y=x
    upper_triangle_df['dist'] = np.abs(upper_triangle_df['WNT_MI'] - upper_triangle_df['BMP_MI']) / np.sqrt(2)

    # check 
    print("bmp_cells len",len(bmp_cells))
    print("wnt cells len",len(wnt_cells))
    print(cell_types_filter)

    return upper_triangle_df


In [83]:
# All cells 
metadata_mtx="/Users/alichemkhi/Desktop/myProjects/2D_gast/output/v0.5/metadata_grouped.csv"
count_mtx="/Users/alichemkhi/Desktop/myProjects/2D_gast/output/v0.5/raw_counts_grouped.csv"
# Read the dataframes first
metadata_df = pd.read_csv(metadata_mtx, index_col=0)
count_df = pd.read_csv(count_mtx, index_col=0)
# small fix : replace the "-" in cell names with "."
metadata_df.index = metadata_df.index.str.replace('-', '.', regex=False)

# Call the updated function with dataframes
adata = create_anndata_from_dataframes(count_df, metadata_df)
print(f"\nAnnData object loaded:")
print(f"Shape: {adata.shape}")
print(f"Observations (cells): {adata.n_obs}")
print(f"Variables (genes): {adata.n_vars}")

Count matrix shape: (23356, 900)
Metadata shape: (900, 23)
Transposing count matrix to cells x genes format

Found 900 common indices

AnnData object created successfully!
Shape: (900, 23356) (cells x genes)
Available metadata columns: ['orig.ident', 'nCount_RNA', 'nFeature_RNA', 'cells', 'time', 'condition', 'percent.mito', 'RNA_snn_res.0.1', 'seurat_clusters', 'RNA_snn_res.0.2', 'RNA_snn_res.0.3', 'RNA_snn_res.0.4', 'RNA_snn_res.0.5', 'RNA_snn_res.0.6', 'RNA_snn_res.0.7', 'RNA_snn_res.0.8', 'RNA_snn_res.0.9', 'RNA_snn_res.1', 'celltype', 'mtMean', 'rpsMean', 'rplMean', 'celltype_grouped']

AnnData object loaded:
Shape: (900, 23356)
Observations (cells): 900
Variables (genes): 23356


In [None]:
# Inputs

wnt_summary='/Users/alichemkhi/Desktop/myProjects/2D_gast/output/v0.2/2D_gast_v0.2_2pass_wnt.278.st.txt'
bmp_summary='/Users/alichemkhi/Desktop/myProjects/2D_gast/output/v0.2/2D_gast_v0.2_test_2pass_bmp.246.st.txt'
cell_types_filter={
    "BMP" : None,
    "WNT" : None
}
outdir = Path('/Users/alichemkhi/Desktop/myProjects/2D_gast/output/v0.5/')
var_names = get_variable_names(
    wnt_summary,
    bmp_summary
)
#----------------------------------
# Process Mutual Information for BMP and WNT
#----------------------------------

upper_triangle_df=process_mi(adata, cell_types_filter, outdir,metadata_df, var_names)
plot_path=outdir / f"bmp_wnt_{cell_types_filter['BMP'] if cell_types_filter['BMP'] else 'all'}_{cell_types_filter['WNT'] if cell_types_filter['WNT'] else 'all'}_mi_scatter_colored_unfiltered.html"
plot_wnt_bmp_mi_scatter(
    upper_triangle_df, 
    plot_path
)


# #check 
# # Total number of unique pairs in a set of 340 variables
# sum_=340-1
# for i in range(2,340):
#     #print(i,sum_)

#     sum_=sum_+(340-i)
# print(f"Total number of unique pairs: {sum_}")


Count matrix saved to: /Users/alichemkhi/Desktop/myProjects/2D_gast/output/v0.5/bmp_data_all_counts.csv
Shape: (321, 23356) (cells x genes)
Metadata saved to: /Users/alichemkhi/Desktop/myProjects/2D_gast/output/v0.5/bmp_data_all_metadata.csv
Shape: (321, 23) (cells x metadata_columns)
Metadata columns: ['orig.ident', 'nCount_RNA', 'nFeature_RNA', 'cells', 'time', 'condition', 'percent.mito', 'RNA_snn_res.0.1', 'seurat_clusters', 'RNA_snn_res.0.2', 'RNA_snn_res.0.3', 'RNA_snn_res.0.4', 'RNA_snn_res.0.5', 'RNA_snn_res.0.6', 'RNA_snn_res.0.7', 'RNA_snn_res.0.8', 'RNA_snn_res.0.9', 'RNA_snn_res.1', 'celltype', 'mtMean', 'rpsMean', 'rplMean', 'celltype_grouped']
Count matrix saved to: /Users/alichemkhi/Desktop/myProjects/2D_gast/output/v0.5/wnt_data_all_counts.csv
Shape: (278, 23356) (cells x genes)
Metadata saved to: /Users/alichemkhi/Desktop/myProjects/2D_gast/output/v0.5/wnt_data_all_metadata.csv
Shape: (278, 23) (cells x metadata_columns)
Metadata columns: ['orig.ident', 'nCount_RNA', '

# BMP(NesMes) - WNT(NesMes)


In [114]:
# Inputs

wnt_summary='/Users/alichemkhi/Desktop/myProjects/2D_gast/output/miic_state_orders/2D_gast_grouped_v0.3_BMP_NasMes.333.st.txt'
bmp_summary='/Users/alichemkhi/Desktop/myProjects/2D_gast/output/miic_state_orders/2D_gast_grouped_v0.3_BMP_NasMes.333.st.txt'
cell_types_filter={
    "BMP" : ["NasMes"],
    "WNT" : ["NasMes"]
}
outdir = Path('/Users/alichemkhi/Desktop/myProjects/2D_gast/output/v0.5/')
var_names = get_variable_names(
    wnt_summary,
    bmp_summary
)
#----------------------------------
# Process Mutual Information for BMP and WNT
#----------------------------------

upper_triangle_df=process_mi(adata, cell_types_filter, outdir,metadata_df, var_names)
plot_path=outdir / f"bmp_wnt_{cell_types_filter['BMP'] if cell_types_filter['BMP'] else 'all'}_{cell_types_filter['WNT'] if cell_types_filter['WNT'] else 'all'}_mi_scatter_colored_unfiltered.html"
plot_wnt_bmp_mi_scatter(
    upper_triangle_df, 
    plot_path
)



Count matrix saved to: /Users/alichemkhi/Desktop/myProjects/2D_gast/output/v0.5/bmp_data_['NasMes']_counts.csv
Shape: (112, 23356) (cells x genes)
Metadata saved to: /Users/alichemkhi/Desktop/myProjects/2D_gast/output/v0.5/bmp_data_['NasMes']_metadata.csv
Shape: (112, 23) (cells x metadata_columns)
Metadata columns: ['orig.ident', 'nCount_RNA', 'nFeature_RNA', 'cells', 'time', 'condition', 'percent.mito', 'RNA_snn_res.0.1', 'seurat_clusters', 'RNA_snn_res.0.2', 'RNA_snn_res.0.3', 'RNA_snn_res.0.4', 'RNA_snn_res.0.5', 'RNA_snn_res.0.6', 'RNA_snn_res.0.7', 'RNA_snn_res.0.8', 'RNA_snn_res.0.9', 'RNA_snn_res.1', 'celltype', 'mtMean', 'rpsMean', 'rplMean', 'celltype_grouped']
Count matrix saved to: /Users/alichemkhi/Desktop/myProjects/2D_gast/output/v0.5/wnt_data_['NasMes']_counts.csv
Shape: (75, 23356) (cells x genes)
Metadata saved to: /Users/alichemkhi/Desktop/myProjects/2D_gast/output/v0.5/wnt_data_['NasMes']_metadata.csv
Shape: (75, 23) (cells x metadata_columns)
Metadata columns: ['or

Total number of unique pairs: 57630
