In [None]:
import os

import numpy as np
import pandas as pd
import scipy.stats
import matplotlib.pyplot as plt

import scanpy as sc

import statsmodels.stats.multitest

In [None]:
def count_zero_pairs(contact_mtx):
    '''Count the zero entries in the upper triangle of a square matrix.

    Only positions (row, col) with row <= col < contact_mtx.shape[0] are
    inspected, so the diagonal is included and every unordered pair is
    visited exactly once.
    '''
    size = contact_mtx.shape[0]
    return sum(
        1
        for row in range(size)
        for col in range(row, size)
        if contact_mtx[row, col] == 0
    )

def adjust_p_value_matrix_by_BH(p_val_mtx):
    '''Benjamini/Hochberg-adjust the p-values stored in a symmetric matrix.

    Only the upper triangle (diagonal included) is read, so each unordered
    pair enters the correction exactly once. The adjusted values are then
    written to both triangles of a new matrix, which is returned.
    '''
    n = p_val_mtx.shape[0]

    # Row-major positions of the upper triangle, diagonal included.
    upper_rows, upper_cols = np.triu_indices(n)
    p_vals_upper = p_val_mtx[upper_rows, upper_cols]

    # Element [1] of the multipletests result holds the corrected p-values.
    p_vals_upper_bh = statsmodels.stats.multitest.multipletests(p_vals_upper, method='fdr_bh')[1]

    # Mirror the corrected values so the output is symmetric again.
    adjusted_p_val_mtx = np.zeros((n, n))
    adjusted_p_val_mtx[upper_rows, upper_cols] = p_vals_upper_bh
    adjusted_p_val_mtx[upper_cols, upper_rows] = p_vals_upper_bh

    return adjusted_p_val_mtx

def get_data_frame_from_metrices(cell_types, mtx_dict):
    '''Flatten the upper triangles of several matrices into one long table.

    Each row of the returned DataFrame describes one unordered pair
    (cell_types[i], cell_types[j]) with i <= j. The columns are
    'cell_type1', 'cell_type2', followed by one column per key of mtx_dict
    holding the entry [i, j] of the corresponding matrix.
    '''
    n_types = len(cell_types)
    upper_pairs = [(r, c) for r in range(n_types) for c in range(r, n_types)]

    columns = {
        'cell_type1': [cell_types[r] for r, _ in upper_pairs],
        'cell_type2': [cell_types[c] for _, c in upper_pairs],
    }
    for name, mtx in mtx_dict.items():
        columns[name] = [mtx[r, c] for r, c in upper_pairs]

    return pd.DataFrame(columns)
    

def sort_cell_type_contact_p_values(p_val_mtx, cell_types):
    '''Return (cell_type1, cell_type2, p_value) tuples in ascending p-value order.

    Only the upper triangle of p_val_mtx (diagonal included) is listed, so
    each unordered pair of cell types appears once.
    '''
    n_types = p_val_mtx.shape[0]
    triples = [
        (cell_types[r], cell_types[c], p_val_mtx[r, c])
        for r in range(n_types)
        for c in range(r, n_types)
    ]
    # Stable sort: pairs with equal p-values keep their upper-triangle order.
    triples.sort(key=lambda triple: triple[2])
    return triples

In [None]:
import scipy.cluster
from scattermap import scattermap

def get_optimal_order_of_mtx(X):
    '''Return the leaf order of a Ward clustering of X after optimal leaf ordering.

    X is passed both to the Ward linkage and to the leaf-ordering step.
    '''
    hierarchy = scipy.cluster.hierarchy
    linkage = hierarchy.ward(X)
    reordered_linkage = hierarchy.optimal_leaf_ordering(linkage, X)
    return hierarchy.leaves_list(reordered_linkage)

def get_ordered_tick_labels(tick_labels):
    '''Return the indices that sort the labels by a '<last token> <label>' key.

    The last space-separated token of each label is prepended to the label
    itself, so labels sharing the same trailing token end up next to each
    other in the resulting order.
    '''
    sort_keys = []
    for label in tick_labels:
        trailing_token = label.split(' ')[-1]
        sort_keys.append(trailing_token + ' ' + label)
    return np.argsort(sort_keys)

def filter_pval_mtx(pval_mtx, tick_labels, allowed_pairs):
    '''Return a copy of pval_mtx with every non-allowed pair set to 1.

    An entry [i, j] is kept when (tick_labels[i], tick_labels[j]) is in
    allowed_pairs in either orientation; all other entries become 1.
    The input matrix is left untouched.
    '''
    filtered = pval_mtx.copy()
    n_rows, n_cols = pval_mtx.shape[0], pval_mtx.shape[1]

    for r in range(n_rows):
        row_label = tick_labels[r]
        for c in range(n_cols):
            col_label = tick_labels[c]
            is_allowed = ((row_label, col_label) in allowed_pairs
                          or (col_label, row_label) in allowed_pairs)
            if not is_allowed:
                filtered[r, c] = 1

    return filtered

def make_dotplot(pval_mtx, fold_change_mtx, tick_labels, title='', allowed_pairs=None):
    '''Draw a dot plot of pairwise fold changes, with dot size encoding p-values.

    Parameters
    ----------
    pval_mtx : array-like, square
        P-values for each pair of labels. Values below 1e-10 are clipped
        before the -log10 transform that sets the dot size.
    fold_change_mtx : array-like, same shape as pval_mtx
        Values mapped to the dot color (color scale fixed to [0, 2]).
    tick_labels : array-like of str
        Label of each row/column. Lists are accepted as well as arrays.
    title : str
        Title of the axes.
    allowed_pairs : container of (label, label) tuples, optional
        When given, p-values of all pairs not listed (in either
        orientation) are set to 1 before plotting.

    Returns
    -------
    The axes object returned by scattermap.
    '''
    # Coerce to arrays so that the fancy indexing below also works when the
    # caller passes plain lists (a list cannot be indexed by an index array).
    pval_mtx = np.asarray(pval_mtx)
    fold_change_mtx = np.asarray(fold_change_mtx)
    tick_labels = np.asarray(tick_labels)

    # Order rows/columns by the labels' trailing token rather than by clustering.
    #optimal_order = get_optimal_order_of_mtx(pval_mtx)
    optimal_order = get_ordered_tick_labels(tick_labels)

    pval_mtx = pval_mtx[optimal_order][:, optimal_order]
    fold_change_mtx = fold_change_mtx[optimal_order][:, optimal_order]
    tick_labels = tick_labels[optimal_order]

    if allowed_pairs is not None:
        pval_mtx = filter_pval_mtx(pval_mtx, tick_labels, allowed_pairs)

    # Clip at 1e-10 so the dot size is bounded (and log10(0) is avoided).
    mlog_pvals = - np.log10(np.maximum(pval_mtx, 1e-10))

    # The figure side grows with the number of labels. scattermap is called
    # without an explicit axes, so it presumably draws on this current
    # figure -- TODO confirm against the scattermap implementation.
    fig_len = len(tick_labels) * 0.1
    plt.figure(figsize=(fig_len, fig_len), dpi=300)

    # Dots with p < 0.05 get a thin black outline; all others get width 0.
    ax = scattermap(fold_change_mtx, marker_size=mlog_pvals + 0.5,
                square=True,
                cmap="coolwarm",
                linewidths=0.2 * (pval_mtx < 0.05).reshape(-1),
                linecolor='black', xticklabels=tick_labels, yticklabels=tick_labels,
                vmin=0, vmax=2,
                cbar_kws={'shrink':0.5, 'anchor':(0, 0.7)})
    ax.tick_params(axis='both', which='major', labelsize=4.5)
    # axes[1] is presumably the colorbar axes created by scattermap -- TODO confirm.
    ax.figure.axes[1].tick_params(axis="y", labelsize=7)
    ax.figure.axes[1].set_ylabel('fold change', fontsize=7)

    # Create a dot size legend using off-axis scatter calls and legend.
    # The sizes follow the marker_size formula above: -log10(p) + 0.5.
    ax.scatter(-1, -1, label='$10^{-10}$', marker="o", linewidths=0, c="grey", s=10.5)
    ax.scatter(-1, -1, label='$10^{-2}$', marker="o", linewidths=0, c="grey", s=2.5)
    ax.scatter(-1, -1, label='1', marker="o", linewidths=0, c="grey", s=0.5)
    leg = ax.legend(loc="upper left", bbox_to_anchor=(1, 0.2), fontsize=7)
    leg.set_title('adjusted p-val',prop={'size':7})

    ax.set_title(title)

    return ax

In [None]:
%%time

permutation_path = 'outputs_15um'

major_brain_regions = [
    'Cerebellum', 'Cortical_subplate', 'Fiber_tracts', 'Hippocampus',
    'Isocortex', 'Medulla', 'Olfactory', 'Pallidum',
    'Pons', 'Striatum', 'Thalamus', 'Ventricular_systems',
    'anterior_HY', 'anterior_MB', 'posterior_HY', 'posterior_MB',
]

# One table of significant close contacts per region, in region order.
result_dfs = []

for region in major_brain_regions:

    # Cell type labels of this region. The sorted unique subclass names are
    # used as row/column labels of the matrices loaded below -- this assumes
    # the saved matrices follow the same order (TODO confirm).
    df_ct_labels = pd.read_csv(os.path.join('cells_by_regions', f'{region}.csv'), index_col=0)
    subclass_types = np.unique(df_ct_labels['subclass_label_transfer'])

    # Observed contact counts and the mean/std of the local-permutation null.
    cell_contact_counts = np.load(os.path.join(permutation_path, f'{region}_no_permutation.npy'))
    local_null_means = np.load(os.path.join(permutation_path, f'{region}_local_permutation_mean.npy'))
    local_null_stds = np.load(os.path.join(permutation_path, f'{region}_local_permutation_std.npy'))

    # Floor the stds at the minimal observable std value so the z-scores
    # below never divide by zero.
    local_null_stds = np.maximum(local_null_stds, np.sqrt(1 / 1000))

    # One-sided test: upper-tail probability of the z-score under a normal null.
    local_z_scores = (cell_contact_counts - local_null_means) / local_null_stds
    local_p_values = scipy.stats.norm.sf(local_z_scores)
    adjusted_local_p_values = adjust_p_value_matrix_by_BH(local_p_values)

    # The small offset keeps the ratio finite when the null mean is zero.
    fold_changes = cell_contact_counts / (local_null_means + 1e-4)

    # Plot the adjusted p-values. (Views tried during exploration: the raw
    # p-values, and an 'L-R filtered' variant using allowed_pairs.)
    make_dotplot(adjusted_local_p_values, fold_changes, subclass_types, title=region)

    # Gather all results into a long table, strongest enrichment first, and
    # keep only significant pairs with more than 50 observed contacts.
    contact_result_df = (
        get_data_frame_from_metrices(
            subclass_types,
            {
                'pval-adjusted': adjusted_local_p_values,
                'pval': local_p_values,
                'z_score': local_z_scores,
                'contact_count': cell_contact_counts,
                'permutation_mean': local_null_means,
                'permutation_std': local_null_stds,
            },
        )
        .sort_values('z_score', ascending=False)
        .loc[lambda df: (df['pval-adjusted'] < 0.05) & (df['contact_count'] > 50)]
    )
    contact_result_df.to_csv(os.path.join(permutation_path, f'{region}_close_contacts.csv'))

    result_dfs.append(contact_result_df)

In [None]:
# Stack the per-region tables row-wise; each table keeps its own row labels.
combined_results = pd.concat(result_dfs, axis=0)
combined_results

In [None]:
# Pair of cell types to look up. Other pairs that were inspected:
#ct1, ct2 = 'SCs Dmbx1 Gaba', 'SCig Foxb1 Glut'
#ct1, ct2 = 'Vip Gaba', 'L2/3 IT CTX Glut'
ct1, ct2 = 'BAM NN', 'Endo NN'

# Match the pair in either orientation.
is_forward_match = (combined_results['cell_type1'] == ct1) & (combined_results['cell_type2'] == ct2)
is_reverse_match = (combined_results['cell_type1'] == ct2) & (combined_results['cell_type2'] == ct1)

combined_results[is_forward_match | is_reverse_match]

In [None]:
# Keep only the first row of each unordered cell-type pair.
all_pairs = combined_results[['cell_type1', 'cell_type2']].values
selected_pairs = []   # unordered pairs, in order of first appearance
selected_ids = []     # positional row index of each first appearance
seen_pairs = set()    # O(1) membership test instead of scanning selected_pairs

for i, pair in enumerate(all_pairs):
    # Sorting the two names makes (A, B) and (B, A) the same key.
    p = tuple(sorted(pair))

    if p in seen_pairs:
        continue

    seen_pairs.add(p)
    selected_pairs.append(p)
    selected_ids.append(i)

# Optional export of the non-redundant table:
#combined_results.iloc[selected_ids].to_csv(os.path.join(permutation_path,
#    'subclass_close_contacts_non_redundant.csv'))
combined_results.iloc[selected_ids]

In [None]:
# Number of distinct unordered cell-type pairs (self-pairs included).
len({tuple(sorted(pair))
     for pair in zip(combined_results['cell_type1'], combined_results['cell_type2'])})

In [None]:
# Number of distinct unordered pairs made of two different cell types.
len({tuple(sorted(pair))
     for pair in zip(combined_results['cell_type1'], combined_results['cell_type2'])
     if pair[0] != pair[1]})

In [None]:
# Number of distinct cell types that appear paired with themselves.
len({tuple(sorted(pair))
     for pair in zip(combined_results['cell_type1'], combined_results['cell_type2'])
     if pair[0] == pair[1]})

In [None]:
# The distinct unordered cell-type pairs themselves.
{tuple(sorted(pair))
 for pair in zip(combined_results['cell_type1'], combined_results['cell_type2'])}