In [None]:
from tqdm import tqdm
import numpy as np
import os
import pandas as pd
cwd = os.getcwd()
from scipy.stats import mannwhitneyu
import torch

import pickle
import matplotlib.pyplot as plt

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Tissue classes in label order: the position of a name in the list is the
# integer label stored for it in tissue_dict.
tissue_type_name = ['normalLiver', 'core', 'rim']
tissue_dict = {name: label for label, name in enumerate(tissue_type_name)}

from itertools import compress

# Directory that receives the figures written by this notebook.
save_path = os.path.join(f'{cwd}', 'Plots_paper')
# Create the directory in-process rather than through a shell command: this
# also works when the working directory contains spaces or shell
# metacharacters, is portable, and raises if the directory cannot be created.
os.makedirs(save_path, exist_ok=True)
from collections import Counter
# Set font size
plt.rcParams.update({'font.size': 36})

In [None]:
# Some connections were observed only very rarely (fewer than 10 times). The
# helpers below return NaN for such connections so that they are removed from
# the downstream calculation.
def _reduce_with_threshold(series, threshold, reducer):
    """Apply ``reducer`` to ``series`` if it holds enough non-null values.

    Parameters
    ----------
    series : pandas.Series
        Values to aggregate; nulls do not count towards the threshold.
    threshold : int or float
        Minimum number of non-null values required.
    reducer : callable
        Function mapping the series to a scalar.

    Returns
    -------
    scalar
        ``reducer(series)`` when ``series.count() >= threshold``, else NaN.
    """
    if series.count() >= threshold:
        return reducer(series)
    return np.nan


def mean_with_threshold(series, threshold):
    """Mean of ``series``, or NaN when it has fewer than ``threshold`` non-null values."""
    return _reduce_with_threshold(series, threshold, lambda s: s.mean())


def median_with_threshold(series, threshold):
    """Median of ``series``, or NaN when it has fewer than ``threshold`` non-null values."""
    return _reduce_with_threshold(series, threshold, lambda s: s.median())


def std_with_threshold(series, threshold):
    """Sample standard deviation of ``series``, or NaN below the threshold."""
    return _reduce_with_threshold(series, threshold, lambda s: s.std())


def mad_with_threshold(series, threshold):
    """Mean absolute deviation (around the mean) of ``series``, or NaN below the threshold.

    ``Series.mad()`` was removed in pandas 2.0, so the same quantity is
    computed explicitly; nulls are skipped as before.
    """
    return _reduce_with_threshold(
        series, threshold, lambda s: (s - s.mean()).abs().mean()
    )


def sum_with_threshold(series, threshold):
    """Sum of ``series``, or NaN when it has fewer than ``threshold`` non-null values."""
    return _reduce_with_threshold(series, threshold, lambda s: s.sum())

In [None]:
def top_c2c_interaction(df, number_top):
    """Return the ``number_top`` largest entries of a matrix as a long table.

    NaN entries are treated as 0 before ranking.

    Parameters
    ----------
    df : pandas.DataFrame
        Matrix whose entries are ranked.
    number_top : int
        Number of entries to return. Values larger than the number of entries
        are clamped; values <= 0 yield an empty table.

    Returns
    -------
    pandas.DataFrame
        Columns ``'Column'`` (column label of the entry), ``'Row'`` (row label
        of the entry) and ``'Value'``, sorted by ``'Value'`` in descending
        order.
    """
    flat = np.nan_to_num(df.values.flatten())

    # Never ask argpartition for more elements than exist.
    number_top = min(int(number_top), flat.size)
    if number_top <= 0:
        return pd.DataFrame({'Column': [], 'Row': [], 'Value': []})

    indices = np.argpartition(flat, -number_top)[-number_top:]

    # unravel_index yields (row positions, column positions) for a 2-D shape.
    rows, cols = np.unravel_index(indices, df.shape)

    top = pd.DataFrame({
        'Column': df.columns[cols],
        'Row': df.index[rows],
        'Value': flat[indices]
    })
    return top.sort_values(by='Value', ascending=False)

In [None]:
# Run configuration. Several of these values (Layer_1, droup_out_rate,
# radius_neibourhood, comment, minimum_number_cells, comment_att, data_set)
# are interpolated into the name of the results file that is loaded below, so
# they select which evaluated run is analysed.

# Free-text tags
comment = ''
comment_norm = ''

# Graph / sampling settings
radius_neibourhood = 50
minimum_number_cells = 50
number_steps_region_subsampleing = 100
similarity_measure = 'euclide'

# Model / training settings
input_dim = 27
Layer_1 = 27
final_layer = 3
f_final = 3
droup_out_rate = 0.4
number_attention_heads = 1
concat_attentionheads = False
batch_size = 450
learning_rate = 1e-2
data_set = 'train'

# Suffix that records whether the run used attributes.
attr_bool = False
comment_att = '_attr' if attr_bool else '_Noattr'

# Cell types iterated over in the analysis and the number of strongest
# connections kept per source cell type.
Top_interaction_number = 20
cell_types = [
    'B cells CD38+',
    'B cells CD45RA',
    'B cells PD-L1+',
    'Granulocytes CD38+',
    'Granulocytes CD38-',
    'Kupffer cells',
    'M2 Macrophages PD-L1+',
    'M2 Macrophages PD-L1-',
    'MAITs',
    'MHCII APCs',
    'Mixed Immune CD45+',
    'NK Cells CD16',
    'NK Cells CD56',
    'T cells CD4',
    'T cells CD4 PD-L1+',
    'T cells CD4 naïve',
    'T cells CD57',
    'T cells CD8 PD-1high',
    'T cells CD8 PD-1low',
    'T cells CD8 PD-L1+',
    'Tregs',
]

In [None]:
# Load the phenotype-to-phenotype results dictionary of the selected run.

# The file name is assembled from the run configuration; every '.' in it
# (e.g. from the dropout rate) is replaced by '_' before '.pkl' is appended.
dict_name = f"HL1_{Layer_1}_dp_{droup_out_rate}_r_{radius_neibourhood}_noSelfATT_{comment}_{minimum_number_cells}{comment_att}_eval_{data_set}".replace('.', '_')
print(dict_name)

# Location of the pickled results relative to the working directory.
save_path = os.path.join(f'{cwd}','eval','baseline',f'{dict_name}.pkl')

# NOTE(review): unpickling runs code stored in the file, so only load result
# files that come from a trusted source.
with open(save_path, 'rb') as f:
    dict_all = pickle.load(f)

# Entries converted to numpy arrays.
# NOTE(review): 'correct_predicted' presumably holds prediction results per
# sample - confirm against the code that wrote the pickle.
true_labels_train = np.array(dict_all['true_labels_train'])
correct_predicted = np.array(dict_all['correct_predicted'])
ids_all = np.array(dict_all['ids_list'])

# Entries used as stored: the normalised cell-to-cell attention matrices
# (iterated as DataFrames further down) and the edge counts.
cell_to_cell_mat_std = dict_all['normed_c2c_Att']
edge_count = dict_all['number_edges_c2c']


In [None]:
# Keep only the wanted cell types: remove the excluded ones from both the rows
# and the columns of every attention matrix.
excluded_cell_types = ['T cells CD57', 'Kupffer cells']
shortend_dfs = [
    df.drop(index=excluded_cell_types, columns=excluded_cell_types)
    for df in tqdm(cell_to_cell_mat_std)
]

In [None]:
# Normal liver: aggregate the attention matrices of all normal-liver samples.

tissue_id = tissue_dict['normalLiver']

# Boolean mask over the samples, True where the label is normal liver.
sample_id_list = true_labels_train == tissue_id

# Matrices of the selected samples only.
all_dfs = [df for df, keep in zip(shortend_dfs, sample_id_list) if keep]
print(len(all_dfs))

# A connection must have non-null values amounting to at least 1 % of the
# selected samples to be aggregated.
threshould_liver = len(all_dfs)*0.01

# Stack all matrices; rows that share an index label then belong to one group.
mean_cell_att_liver = pd.concat(all_dfs)

# Thresholded median per (row label, column), with rows and columns sorted.
liver_groups = mean_cell_att_liver.groupby(mean_cell_att_liver.index)
liver_medians = liver_groups.agg(lambda x: median_with_threshold(x, threshould_liver))
df_liver = liver_medians.sort_index()[sorted(mean_cell_att_liver.columns)]


In [None]:
# Top connections per source cell type in normal liver.
#
# For every source cell type (a column of df_liver) the destination cell types
# (rows) are ranked by their aggregated value, and for the
# Top_interaction_number highest-ranked ones the individual values are
# collected from mean_cell_att_liver.
#   top_connections_liver: source -> list of [destination, Series of values]
#   name_list_liver:       flat list of [source, destination] pairs
top_connections_liver = {}
name_list_liver = []

# Loop over each cell type
for scr_cell in cell_types:

    # cell_types can contain types that have no column in df_liver (types that
    # were removed from the matrices earlier). Record them as having no
    # connections instead of failing with a KeyError.
    if scr_cell not in df_liver.columns:
        top_connections_liver[scr_cell] = []
        continue

    # Destinations with a defined aggregate, strongest first
    top_dst_cells = (
        df_liver[scr_cell]
        .dropna()
        .sort_values(ascending=False)
        .index[:Top_interaction_number]
    )

    values = []

    # Loop over each top connection
    for dst_cell in top_dst_cells:

        # List-based .loc always returns a Series, even when the destination
        # label occurs only once in the stacked frame.
        values.append([dst_cell, mean_cell_att_liver.loc[[dst_cell], scr_cell]])

        # Append the source cell type and destination cell type to the name list
        name_list_liver.append([scr_cell, dst_cell])

    # Add the source cell type and its top connections to the dictionary
    top_connections_liver[scr_cell] = values

In [None]:
# Box plots of the top connections of every source cell type in normal liver.
#
# Significance levels for the Mann-Whitney U test: one star marks
# alpha_plus <= p < alpha, two stars mark p < alpha_plus.
alpha = 0.05
alpha_plus = 0.005

# Only source cell types with at least one stored connection get a panel.
plotted_cells = [cell for cell in cell_types if top_connections_liver.get(cell)]
number_cells = len(plotted_cells)

# Two rows of panels. The column count is rounded up so that an odd number of
# cell types does not silently lose its last entry; squeeze=False keeps axs
# two-dimensional even for a single column.
number_columns = max(1, (number_cells + 1) // 2)
fig, axs = plt.subplots(nrows=2, ncols=number_columns, figsize=(130,60),
                        sharey=True, layout='constrained', squeeze=False)

# Set the line width and star size for the plot
line_width = 5
size_star = 2000

# axs.flat walks the grid row by row, matching the order of plotted_cells
for panel_idx, ax in enumerate(axs.flat):

    # Hide the spare panel that appears when the number of cell types is odd
    if panel_idx >= number_cells:
        ax.set_visible(False)
        continue

    scr_cell = plotted_cells[panel_idx]

    # Collect labels and values together so that box positions, tick labels
    # and test inputs always refer to the same destination. Connections with
    # at most two stored values are left out.
    names = []
    values = []
    for dst_cell, attention in top_connections_liver[scr_cell]:
        samples = np.atleast_1d(np.asarray(attention, dtype=float))
        if samples.size > 2:
            names.append(dst_cell)
            values.append(samples[~np.isnan(samples)])

    # Aggregated values of the source cell type (NaN -> 0). Used both as the
    # reference sample of the test and for the median / quartile band.
    mat = np.nan_to_num(df_liver[scr_cell].to_numpy())

    if values:
        # Create a boxplot of the values
        box = ax.boxplot(values)

        # Set the line width for the boxplot
        for element in ['boxes', 'whiskers', 'caps', 'medians', 'fliers']:
            plt.setp(box[element], linewidth=line_width)

        # Mann-Whitney U test of every destination against the reference
        for position, samples in enumerate(values, start=1):
            if samples.size == 0:
                continue  # nothing to test

            _, p = mannwhitneyu(mat, np.nan_to_num(samples))

            if p < alpha_plus:
                ax.scatter(position + 0.1, 1.05, marker='*', color='black', s=size_star)
                ax.scatter(position - 0.1, 1.05, marker='*', color='black', s=size_star)
            elif p < alpha:
                ax.scatter(position, 1.05, marker='*', color='black', s=size_star)

        # One tick per drawn box; using a fixed count would raise when a
        # source has fewer connections than Top_interaction_number.
        ax.set_xticks(np.arange(1, len(names) + 1), names, rotation=90)

    # Thicken subplot axis spines
    for spine in ax.spines.values():
        spine.set_linewidth(line_width)

    # No grid lines
    ax.grid(False)

    # Highlight the interquartile range and plot the median of the reference
    median = np.median(mat)
    lower_quantile = np.quantile(mat, 0.25)
    upper_quantile = np.quantile(mat, 0.75)
    ax.axhspan(lower_quantile, upper_quantile, facecolor='green', alpha=0.2)
    ax.axhline(y=median, color='red', linestyle='--', linewidth=line_width)

# Save the figure
save_path = os.path.join(f'{cwd}', 'Plots_paper', 'liver_interactions.pdf')
fig.savefig(save_path, dpi=300)

## Core

In [None]:
# Core: aggregate the attention matrices of all core samples.

# Set the tissue ID for the core
tissue_id = tissue_dict['core']

# Boolean mask over the samples, True where the label is core
sample_id_list = true_labels_train == tissue_id

# Filter the list of dataframes to include only those from the core
all_dfs = list(compress(shortend_dfs, sample_id_list))

# Print the number of dataframes from the core
print(len(all_dfs))

# A connection must have non-null values amounting to at least 1 % of the
# selected samples. The value is deliberately not truncated to an integer:
# truncation would admit connections below 1 % coverage and would differ from
# the threshold used for normal liver.
threshould_core = len(all_dfs)*0.01

# Concatenate all the dataframes from the core
mean_cell_att_core = pd.concat(all_dfs)

# Group the concatenated dataframe by index, calculate the thresholded median
# of each group, sort the rows by index and order the columns alphabetically
df_core = mean_cell_att_core.groupby(mean_cell_att_core.index).agg(lambda x: median_with_threshold(x, threshould_core)).sort_index()[sorted(mean_cell_att_core.columns)]

In [None]:
# Top connections per source cell type in the core.
#
# For every source cell type (a column of df_core) the destination cell types
# (rows) are ranked by their aggregated value, and for the
# Top_interaction_number highest-ranked ones the individual values are
# collected from mean_cell_att_core.
#   top_connections_core: source -> list of [destination, Series of values],
#                         or [np.nan] when the source has no connections
#   name_list_core:       flat list of [source, destination] pairs
top_connections_core = {}
name_list_core = []

# Loop over each cell type
for scr_cell in cell_types:

    # cell_types can contain types that have no column in df_core (types that
    # were removed from the matrices earlier); treat them like a source
    # without connections instead of failing with a KeyError.
    if scr_cell in df_core.columns:
        top_dst_cells = (
            df_core[scr_cell]
            .dropna()
            .sort_values(ascending=False)
            .index[:Top_interaction_number]
        )
    else:
        top_dst_cells = []

    # If there are no top connections, store the NaN marker
    if len(top_dst_cells) == 0:
        top_connections_core[scr_cell] = [np.nan]
        continue

    values = []

    # Loop over each top connection
    for dst_cell in top_dst_cells:

        # Append the source cell type and destination cell type to the name list
        name_list_core.append([scr_cell, dst_cell])

        # List-based .loc always returns a Series, even when the destination
        # label occurs only once in the stacked frame.
        values.append([dst_cell, mean_cell_att_core.loc[[dst_cell], scr_cell]])

    # Add the source cell type and its top connections to the dictionary
    top_connections_core[scr_cell] = values

In [None]:
# Box plots of the top connections of every source cell type in the core.

# Real connections are [destination, values] pairs. The [np.nan] marker stored
# for sources without connections is filtered out by type, so a source with
# exactly one real connection is still plotted.
core_connections = {
    cell: [entry for entry in top_connections_core.get(cell, [])
           if isinstance(entry, (list, tuple)) and len(entry) == 2]
    for cell in cell_types
}
plotted_cells = [cell for cell in cell_types if core_connections[cell]]
number_cells = len(plotted_cells)

# Two rows of panels. The column count is rounded up so that an odd number of
# cell types does not silently lose its last entry; squeeze=False keeps axs
# two-dimensional even for a single column.
number_columns = max(1, (number_cells + 1) // 2)
fig, axs = plt.subplots(nrows=2, ncols=number_columns, figsize=(130,60),
                        sharey=True, layout='constrained', squeeze=False)

# Set the line width and star size for the plot
line_width = 5
size_star = 2000

# axs.flat walks the grid row by row, matching the order of plotted_cells
for panel_idx, ax in enumerate(axs.flat):

    # Hide the spare panel that appears when the number of cell types is odd
    if panel_idx >= number_cells:
        ax.set_visible(False)
        continue

    scr_cell = plotted_cells[panel_idx]

    # One box per connection; connections with at most two stored values get
    # a NaN placeholder so that boxes and tick labels stay aligned.
    names = []
    values = []
    for dst_cell, attention in core_connections[scr_cell]:
        samples = np.atleast_1d(np.asarray(attention, dtype=float))
        names.append(dst_cell)
        if samples.size > 2:
            values.append(samples[~np.isnan(samples)])
        else:
            values.append([np.nan])

    # Create a boxplot of the values
    box = ax.boxplot(values)

    # Set the line width for the boxplot
    for element in ['boxes', 'whiskers', 'caps', 'medians', 'fliers']:
        plt.setp(box[element], linewidth=line_width)

    # Set the x-ticks for the subplot
    ax.set_xticks(np.arange(1, len(box['boxes']) + 1), names, rotation=90)

    # No grid lines
    ax.grid(False)

    # Aggregated values of the source cell type (NaN -> 0): median and quartiles
    mat = np.nan_to_num(df_core[scr_cell].to_numpy())
    median = np.median(mat)
    lower_quantile = np.quantile(mat, 0.25)
    upper_quantile = np.quantile(mat, 0.75)

    # Highlight the interquartile range and plot the median
    ax.axhspan(lower_quantile, upper_quantile, facecolor='green', alpha=0.2)
    ax.axhline(y=median, color='red', linestyle='--', linewidth=line_width)

# Save the figure
save_path = os.path.join(f'{cwd}','Plots_paper','core_interactions.pdf')
fig.savefig(save_path,dpi=300)