In [1]:
import sys

import matplotlib.pyplot as plt
import multiprocessing as mp
import anndata as ad
import pandas as pd
import scanpy as sc
import argparse
import scipy.stats as stats

import numpy as np
import pickle
import sys
import os
import seaborn as sns
import time
from sklearn.neighbors import KNeighborsClassifier

# Load imports from git directory
import pairwise_functions as pf
import gc

In [2]:
import seaborn as sns

def plot_heatmap(passing_cells_df, title="", min_col_n=10, min_row_n=10, row_clust="cell_type_prediction", columns_clust=None, scale_by_col=True, min_value_for_label=5, sort_alpha_numerically=False, show_vert_lines=False, show_horizontal_lines=False, is_wide=False, annotate=False):
    """Plot a proportion heatmap of how ``row_clust`` labels co-occur with ``columns_clust`` labels.

    Parameters
    ----------
    passing_cells_df : pandas.DataFrame
        Long format (default): one row per observation with columns ``row_clust``
        and ``columns_clust``; rows are counted per (row label, column label) pair.
        Wide format (``is_wide=True``): an already-built count table whose index
        and columns are used directly.
    title : str
        Figure title.
    min_col_n, min_row_n : int
        Columns / rows whose total count is NOT strictly greater than the
        threshold are dropped. Columns are filtered first, then rows, so a
        column can lose mass after the row filter (a column left at total 0
        becomes NaN when scaling by column).
    row_clust, columns_clust : str
        Names of the columns used as heatmap rows / columns (long format only).
        ``columns_clust`` is required in long format.
    scale_by_col : bool
        True: each column is divided by its total and columns are ordered by the
        row label holding their maximum. False: rows are scaled and ordered by
        the column label holding their maximum.
    min_value_for_label : int
        With ``annotate=True``, cells whose raw count exceeds this show that count.
    sort_alpha_numerically : bool
        If True, rows are sorted by label (descending) instead of by argmax order.
    show_vert_lines, show_horizontal_lines : bool
        Draw thin separator lines, cycling through a small colour palette.
    is_wide : bool
        Treat ``passing_cells_df`` as a precomputed count table.
    annotate : bool
        Write raw counts into the cells. Defaults to False (no annotations).

    Returns
    -------
    pandas.DataFrame
        The scaled (proportion) table that was plotted.

    Raises
    ------
    ValueError
        If ``columns_clust`` is missing in long format, or no rows/columns
        survive the count thresholds.
    """
    if is_wide:
        counts = passing_cells_df
    else:
        if columns_clust is None:
            raise ValueError("columns_clust must be provided unless is_wide=True")
        # crosstab counts observations per label pair. pivot_table(aggfunc=len) on a frame
        # holding only the two key columns has no value column to aggregate, which,
        # depending on the pandas version, can yield an empty table.
        counts = pd.crosstab(passing_cells_df[row_clust], passing_cells_df[columns_clust])

    # Keep only columns, then rows, whose total is strictly greater than the threshold.
    counts = counts.loc[:, counts.sum(axis=0) > min_col_n]
    counts = counts.loc[counts.sum(axis=1) > min_row_n, :]
    if counts.empty:
        raise ValueError(
            f"No rows/columns left after filtering (min_row_n={min_row_n}, min_col_n={min_col_n})"
        )
    counts = counts.sort_index(axis=1, ascending=True)

    if sort_alpha_numerically:
        counts = counts.sort_index(axis=0, ascending=False)
    elif scale_by_col:
        # Group columns that peak in the same row by sorting on each column's argmax row label.
        counts = counts[counts.idxmax(axis=0).sort_values().index]
    else:
        # Same idea along the other axis: order rows by their argmax column label.
        counts = counts.loc[counts.idxmax(axis=1).sort_values().index]

    # Raw counts as text, blanked at or below the threshold; built before scaling to proportions.
    annot = np.where(counts > min_value_for_label, counts.astype(str), "") if annotate else False

    if scale_by_col:
        fractions = counts.div(counts.sum(axis=0), axis=1)
    else:
        fractions = counts.div(counts.sum(axis=1), axis=0)

    fig, ax = plt.subplots(figsize=(30, 20))
    sns.heatmap(fractions, cmap="viridis", xticklabels=True, yticklabels=True, cbar=True,
                annot=annot, fmt="", ax=ax)
    ax.set_title(title)
    plt.setp(ax.get_xticklabels(), ha="center", fontsize="large")
    plt.setp(ax.get_yticklabels(), fontsize="large")

    # Cycling colours make adjacent separator lines distinguishable.
    line_colors = ["white", "orange", "yellow", "pink"]
    if show_vert_lines:
        for i in range(len(fractions.columns)):
            ax.axvline(i, color=line_colors[i % len(line_colors)], linewidth=0.5)
    if show_horizontal_lines:
        for i in range(len(fractions.index)):
            ax.axhline(i, color=line_colors[i % len(line_colors)], linewidth=0.5)

    plt.show()
    return fractions

def v1_in_v2(v1, v2):
    """Return a boolean array flagging which elements of ``v1`` appear in ``v2``.

    ``v1`` must support ``.shape`` and positional ``v1[i]`` indexing (e.g. a 1-D
    numpy array); ``v2`` is any iterable of hashable values. Membership is tested
    against ``set(v2)``, so lookups are O(1) per element.
    """
    lookup = set(v2)
    n_items = v1.shape[0]
    flags = (v1[idx] in lookup for idx in range(n_items))
    return np.fromiter(flags, dtype=np.bool_, count=n_items)

In [11]:


def compute_and_save_markers(base_chunked_dir, cell_type_1, cell_type_2, n_genes, out_dir_base, marker_comp_method="nonzero", valid_markers_set=None):
    """Find the top marker genes separating two cell types and save them.

    Loads ``<base_chunked_dir>/<cell_type_1>.h5ad`` and
    ``<base_chunked_dir>/<cell_type_2>.h5ad``, optionally restricts both to the
    genes in ``valid_markers_set``, ranks genes in each direction, and stores the
    result with ``pf.save_pairwise_model(..., "markers")``.

    Parameters
    ----------
    base_chunked_dir : str
        Directory holding one ``.h5ad`` file per cell type.
    cell_type_1, cell_type_2 : str
        Cell-type names (file stems); must differ.
    n_genes : int
        Number of markers kept per direction.
    out_dir_base : str
        Output location passed through to ``pf.save_pairwise_model``.
    marker_comp_method : {"nonzero", "mean", "balanced_mean"}
        "nonzero": per-gene fraction of cells with ``X > 0``;
        "mean": per-gene mean of ``X``;
        for both, markers are the ``n_genes`` largest differences in each direction.
        "balanced_mean": delegated to ``pf.compute_balanced_mean_markers``
        (its exact return type is defined in pairwise_functions).
    valid_markers_set : collection of str, optional
        If given, only genes whose names are in it are considered.

    Returns
    -------
    dict
        ``{cell_type_1: markers up in type 1, cell_type_2: markers up in type 2}``,
        the same object that is saved.

    Raises
    ------
    ValueError
        Unknown ``marker_comp_method`` or identical cell types.
    AssertionError
        Missing input file or mismatched gene sets between the two files.
    """
    valid_marker_methods = ["nonzero", "mean", "balanced_mean"]
    if marker_comp_method not in valid_marker_methods:
        raise ValueError(f"marker_comp_method: {marker_comp_method} is not in valid methods: {valid_marker_methods}")
    if cell_type_1 == cell_type_2:
        # The result dict is keyed by cell type, so identical names would silently drop one side.
        raise ValueError(f"cell_type_1 and cell_type_2 must differ, got {cell_type_1!r} twice")

    cell_type_1_path = os.path.join(base_chunked_dir, f"{cell_type_1}.h5ad")
    cell_type_2_path = os.path.join(base_chunked_dir, f"{cell_type_2}.h5ad")
    assert os.path.exists(cell_type_1_path), f"{cell_type_1_path} DNE"
    assert os.path.exists(cell_type_2_path), f"{cell_type_2_path} DNE"

    cell_type_1_adata = ad.read_h5ad(cell_type_1_path)
    cell_type_2_adata = ad.read_h5ad(cell_type_2_path)

    # Index.equals also copes with differing lengths (elementwise == would raise instead).
    assert cell_type_1_adata.var_names.equals(cell_type_2_adata.var_names), \
        f"gene sets differ between {cell_type_1_path} and {cell_type_2_path}"

    if valid_markers_set is not None:
        is_valid_feature = cell_type_1_adata.var_names.isin(valid_markers_set)
        cell_type_1_adata = cell_type_1_adata[:, is_valid_feature]
        cell_type_2_adata = cell_type_2_adata[:, is_valid_feature]

    # Read gene names AFTER filtering so they line up with the columns of X.
    gene_names = cell_type_1_adata.var_names

    if marker_comp_method == "balanced_mean":
        group_1_markers, group_2_markers = pf.compute_balanced_mean_markers(cell_type_1_adata, cell_type_2_adata, n_genes_dir=n_genes)
    else:
        def _per_gene_stat(adata):
            # np.mean dispatches to .mean for sparse X; np.asarray(...).ravel() flattens the
            # 1 x n_genes result (np.matrix for sparse input) to 1-D.
            if marker_comp_method == "nonzero":
                stat = np.mean(adata.X > 0, axis=0)
            else:
                stat = np.mean(adata.X, axis=0)
            return pd.Series(np.asarray(stat).ravel(), index=gene_names)

        cell_type_1_values_series = _per_gene_stat(cell_type_1_adata)
        cell_type_2_values_series = _per_gene_stat(cell_type_2_adata)

        group_1_pos_diff = cell_type_1_values_series - cell_type_2_values_series
        group_1_markers = group_1_pos_diff.sort_values(ascending=False).iloc[:n_genes]

        group_2_pos_diff = cell_type_2_values_series - cell_type_1_values_series
        group_2_markers = group_2_pos_diff.sort_values(ascending=False).iloc[:n_genes]

    res = {cell_type_1: group_1_markers, cell_type_2: group_2_markers}
    pf.save_pairwise_model(cell_type_1, cell_type_2, out_dir_base, res, "markers")
    return res


In [53]:
# from scipy import sparse

# weighting_cell_type_1 = np.ones(cell_type_1_adata.shape[0])
# weighting_cell_type_1[30] = 2

# sparse_matrix_csc = cell_type_1_adata.X.tocsc()
# result2 = sparse_matrix_csc.multiply(weighting_cell_type_1.reshape(-1, 1)).tocsr()

# # cell_type_1_adata.X.tocsc()

In [None]:

valid_markers_path = "/broad/macosko/jsilverm/pknn_cell_type_preds/shared_features.pkl"
# Feature names allowed as markers; used with Index.isin inside compute_and_save_markers.
# Presumably a set of gene names -- verify against whatever wrote the pickle.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
with open(valid_markers_path, "rb") as fh:  # `with` guarantees the file handle is closed
    valid_markers_set = pickle.load(fh)

In [12]:
import scipy


# Inputs for one pairwise marker computation.
# `valid_markers_set` is NOT set here: it comes from the cell that loads
# shared_features.pkl, which must be run before the call below.
base_chunked_dir = "/broad/macosko/jsilverm/pknn_cell_type_preds/HIMBA/chunked_group_label"  # holds <cell_type>.h5ad files

cell_type_1 = "STRd_D1_Matrix"
cell_type_2 = "STRd_D1D2_Hybrid"
marker_comp_method = "balanced_mean"  # one of "nonzero", "mean", "balanced_mean"
n_genes = 25  # markers kept per direction
out_dir_base = "/broad/macosko/jsilverm/pknn_cell_type_preds/HIMBA/test"




In [13]:
# Run the marker computation for the configured pair and write it under out_dir_base.
compute_and_save_markers(
    base_chunked_dir,
    cell_type_1,
    cell_type_2,
    n_genes,
    out_dir_base,
    marker_comp_method=marker_comp_method,
    valid_markers_set=valid_markers_set,
)

  cell_type_1_adata.layers["counts"] = cell_type_1_adata.X
  cell_type_2_adata.layers["counts"] = cell_type_2_adata.X
