In [None]:
import matplotlib
import warnings
import sys
import os
# Get the current working directory
current_dir = os.getcwd()
# Add the parent directory to sys.path
sys.path.insert(0, os.path.dirname(current_dir))
from SpaMV.utils import compute_gene_topic_correlations

warnings.filterwarnings("ignore")
import numpy as np
import scanpy as sc
import squidpy as sq
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.gridspec import GridSpec
from matplotlib.patches import Patch

# Typography defaults. The matplotlib.rc('font', ...) call comes last, so its
# size (7) is the one left in effect for the 'font' group.
plt.rc('font', size=8)
plt.rc('axes', titlesize=10)
plt.rc('axes', labelsize=8)
plt.rc('xtick', labelsize=6)
font = dict(size=7)
matplotlib.rc('font', **font)

# Dataset identifier used to build the input paths of this cell.
dataset = 'ME13_1'
# load mouse embryo dataset
data_ac = sc.read_h5ad('../Dataset/' + dataset + '/adata_H3K27ac_ATAC.h5ad')

cluster_size = 20
dr_size = 50

# Figure size passed to plt.figure(figsize=...). The *_scale dicts hold
# top/bottom/left/right values that are handed to GridSpec.update(); vertical
# values are written as 1 - offset/height.
width = 13
height = 17
gt_scale = dict(top=1 - 1/height, bottom=1 - 4/height, left=.08, right=.19)
clustering_scale = dict(top=1 - 1.2/height, bottom=1 - 4/height, left=.28, right=.75)
score_scale = dict(top=1 - 1/height, bottom=1 - 4/height, left=.77, right=.85)
dr_scale = dict(top=1-4.3/height, bottom=1-10.3/height, left=.1, right=.94)
dr_colorbar_scale = dict(top=1-7.1/height, bottom=1-7.5/height, left=.96, right=.97)
heatmap_scale = dict(top=1-11/height, bottom=1-17/height, left=.2, right=.6)
heatmap_colorbar_scale = dict(top=1-13.8/height, bottom=1-14.2/height, left=.61, right=.62)
biomarker_scale = dict(top=1-11/height, bottom=1-17/height, left=.66, right=.93)
biomarker_colorbar_scale = dict(top=1-13.8/height, bottom=1-14.2/height, left=.94, right=.95)
dr_score_scale = dict(top=1-10.6/height, bottom=1-15/height, left=.85, right=1)


# Create the figure that every panel of this cell is drawn into.
fig = plt.figure(figsize=(width, height), dpi=500)

# Tissue panel: a single axes placed with gt_scale, showing the spatial plot
# of the H3K27ac AnnData without colouring by any variable.
spec_gt = GridSpec(1, 1)
spec_gt.update(**gt_scale)
gt = plt.subplot(spec_gt[0, 0])
d = sc.read_h5ad('../Dataset/' + dataset + '/adata_H3K27ac_ATAC.h5ad')
sc.pl.spatial(d, ax=gt, show=False, frameon=False)

fontsize = 8
# (label, arrow tip `xy`, text position `xytext`) for each annotated structure.
landmarks = [
    ('Liver', (190, 400), (0, -200)),
    ('Heart', (640, 600), (420, -200)),
    ('Jaw', (1100, 300), (960, -200)),
    ('Nose', (1600, 250), (2000, 290)),
    ('Tongue', (1400, 500), (2000, 540)),
    ('Basioccipital bone', (1400, 1000), (2000, 1040)),
    ('Medulla oblongata', (1550, 1500), (2000, 1540)),
    ('Spine', (600, 1250), (400, 2050)),
    ('Spinal cord', (1350, 1650), (930, 2050)),
]
for landmark_name, arrow_tip, text_position in landmarks:
    gt.annotate(landmark_name, xy=arrow_tip, xytext=text_position,
                arrowprops=dict(arrowstyle='->', color='black'),
                fontsize=fontsize, color='black')
###############################################################################################################################
# plot clustering results
# 2x4 grid: the first three columns hold one map per method, the fourth
# column (both rows) holds the shared legend.
spec_clustering = GridSpec(2, 4, wspace=0)
spec_clustering.update(**clustering_scale)

panel_positions = [
    ('SpatialGlue', 0, 0),
    ('CellCharter', 0, 1),
    ('COSMOS', 0, 2),
    ('SMOPCA', 1, 0),
    ('MISO', 1, 1),
    ('SpaMV', 1, 2),
]
f_dict = {}
for panel_name, grid_row, grid_col in panel_positions:
    f_dict[panel_name] = plt.subplot(spec_clustering[grid_row, grid_col])
legend = plt.subplot(spec_clustering[:, 3])

data_methods = {}
for method in ['COSMOS', 'SpaMV', 'CellCharter', 'SpatialGlue', 'MISO', 'SMOPCA']:
    result = sc.read_h5ad('../Results/' + dataset + '/' + method + '.h5ad')
    # Spatial coordinates and image metadata are copied over from the
    # H3K27ac AnnData, restricted to the observations present in the result.
    matched = data_ac[result.obs_names]
    result.obsm['spatial'] = matched.obsm['spatial']
    result.uns['spatial'] = matched.uns['spatial']
    data_methods[method] = result
    sc.pl.spatial(result, color=method, ax=f_dict[method], show=False)

# Reuse the SpaMV panel's legend entries for the stand-alone legend axes,
# then strip the per-panel legends and axes.
handles, labels = f_dict['SpaMV'].get_legend_handles_labels()
legend.legend(handles, labels, loc='center left', frameon=False, ncol=2)
legend.axis('off')
for panel in f_dict.values():
    panel.get_legend().remove()
    panel.xaxis.set_visible(False)
    panel.yaxis.set_visible(False)

###############################################################################################################################
# plot clustering evaluation results
# Stacked bar chart: for every method the mean 'jaccard 1' (H3K27ac) and mean
# 'jaccard 2' (H3K27me3) scores are stacked on top of each other.
spec_score = GridSpec(2, 1, hspace=.8)
spec_score.update(**score_scale)
evaluation_spamv = pd.read_csv('../Results/' + dataset + '/Evaluation_SpaMV.csv')
evaluation_miso = pd.read_csv('../Results/' + dataset + '/Evaluation_MISO.csv')
evaluation_cosmos = pd.read_csv('../Results/' + dataset + '/Evaluation_COSMOS.csv')
evaluation_spatialglue = pd.read_csv('../Results/' + dataset + '/Evaluation_SpatialGlue.csv')
evaluation_cellcharter = pd.read_csv('../Results/' + dataset + '/Evaluation_CellCharter.csv')
evaluation_smopca = pd.read_csv('../Results/' + dataset + '/Evaluation_SMOPCA.csv')

# Outer-merge all per-method tables on their common columns into one frame.
melted_df = evaluation_spamv
for df in [evaluation_miso, evaluation_cosmos, evaluation_spatialglue, evaluation_cellcharter, evaluation_smopca]:
    melted_df = pd.merge(melted_df, df, how='outer')

unsupervised = plt.subplot(spec_score[0, 0])

melted_df = melted_df[melted_df['Dataset'] == dataset]
melted_df = melted_df.drop(['Dataset'], axis=1)

bar_width = 0.64  # Width of the bars
ms = ['MISO', 'CellCharter', 'COSMOS', 'SMOPCA', 'SpatialGlue', 'SpaMV']

# Score column read for each omics layer. The baselines are looked up under
# their own name; the 'SpaMV' bar uses the omics-specific SpaMV row
# ('SpaMV (<omics> related)').
score_column = {'H3K27ac': 'jaccard 1', 'H3K27me3': 'jaccard 2'}
values = {}
for omics, column in score_column.items():
    row_names = ms[:-1] + ['SpaMV (' + omics + ' related)']
    values[omics] = [
        melted_df.loc[melted_df['method'] == row_name, column].mean()
        for row_name in row_names
    ]

# Stack the two omics layers on top of each other.
bottom = np.zeros(len(ms))
for omics, heights in values.items():
    unsupervised.bar(ms, heights, bar_width, label=omics, bottom=bottom)
    bottom += heights

# Add labels and title
unsupervised.set_ylabel('Jaccard Similarity')
unsupervised.legend(bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0, frameon=False)
unsupervised.axes.set_xticklabels(labels=ms, rotation=45, ha='right', rotation_mode='anchor')
unsupervised.axes.xaxis.set_tick_params(pad=0)
unsupervised.set_ylim([0, .6])

# Remove the top and right spines
unsupervised.spines['top'].set_visible(False)
unsupervised.spines['right'].set_visible(False)

# Remove the x-axis label
unsupervised.set_xlabel('')  # Set the x-axis label to an empty string

# plot SpaMV variant comparison
# Grouped bars (mean with std error bars) for the four SpaMV rows of the
# evaluation tables, one group per Jaccard metric.
unsupervised_spamv = plt.subplot(spec_score[1, 0])
melted_df = evaluation_spamv
for df in [evaluation_miso, evaluation_cosmos, evaluation_spatialglue, evaluation_cellcharter, evaluation_smopca]:
    melted_df = pd.merge(melted_df, df, how='outer')

melted_df = melted_df[melted_df['Dataset'] == dataset]
melted_df = melted_df.drop(['Dataset'], axis=1)
methods = ['SpaMV', 'SpaMV (Shared)', 'SpaMV (H3K27ac related)', 'SpaMV (H3K27me3 related)']
mean_scores = melted_df.groupby('method').mean().reset_index()
std_scores = melted_df.groupby('method').std().reset_index()

# Reshape the DataFrames for plotting
df_means_melted = mean_scores.melt(id_vars='method', var_name='Metric', value_name='Mean')
df_errors_melted = std_scores.melt(id_vars='method', var_name='Metric', value_name='Std')

# Merge the dataframes to include errors
df_combined = pd.merge(df_means_melted, df_errors_melted, on=['method', 'Metric'])

# NOTE(review): x has two positions, so this assumes exactly two metric
# columns remain after dropping 'Dataset' — confirm against the CSV schema.
x = np.arange(2)  # the label locations

cat_width = .6
# Kept under its own name so the figure-level `width` is not overwritten.
variant_bar_width = cat_width / 4 * .8  # the width of the bars

# Legend text per evaluation row; any other method is labelled 'All'.
variant_labels = {
    'SpaMV (H3K27ac related)': 'Shared + H3K27ac private',
    'SpaMV (H3K27me3 related)': 'Shared + H3K27me3 private',
    'SpaMV (Shared)': 'Shared',
}

# Plot bars for each method
for i, method in enumerate(methods):
    offset = cat_width * i / 4 - (.65-variant_bar_width*4)/.8
    algo_data = df_combined[df_combined['method'] == method]
    label = variant_labels.get(method, 'All')
    unsupervised_spamv.bar(x + offset, algo_data['Mean'], yerr=algo_data['Std'], width=variant_bar_width, capsize=2, label=label)

# Customize the plot
unsupervised_spamv.set_ylabel('Jaccard Similarity')
unsupervised_spamv.set_xticks(x)
unsupervised_spamv.set_xticklabels(['H3K27ac', 'H3K27me3'])
unsupervised_spamv.set_ylim([0, .4])
# Remove spines
unsupervised_spamv.spines['right'].set_visible(False)
unsupervised_spamv.spines['top'].set_visible(False)
# Single legend call, placed outside the axes on the right.
unsupervised_spamv.legend(bbox_to_anchor=(1.02, 0.5), loc='center left', frameon=False)

###############################################################################################################################
# plot dimensionality reduction results
spec_dr = GridSpec(3, 6, hspace=0, wspace=.3)
spec_dr.update(**dr_scale)
z = pd.read_csv('../Results/' + dataset + '/SpaMV_z.csv', index_col=0)
w0 = pd.read_csv('../Results/' + dataset + '/SpaMV_w_H3K27ac.csv', index_col=0)
w1 = pd.read_csv('../Results/' + dataset + '/SpaMV_w_H3K27me3.csv', index_col=0)
# Discard topics by position in z. The names are taken from z *before* z itself
# is trimmed, so the three drops must stay in this order.
# NOTE(review): w1 only drops column 4; presumably column 17 of z does not
# exist in the H3K27me3 loadings — confirm.
w0 = w0.drop(columns=z.columns[[4, 17]])
w1 = w1.drop(columns=z.columns[[4]])
z = z.drop(columns=z.columns[[4, 17]])

# Renumber the remaining topics consecutively within each family: names
# containing 'Shared', names containing 'H3K27ac', and everything else. The
# trailing token of each name is replaced by the new running index.
col_dict = {}
next_index = {'Shared': 1, 'H3K27ac': 1, 'other': 1}
for topic in z.columns:
    if 'Shared' in topic:
        family = 'Shared'
    elif 'H3K27ac' in topic:
        family = 'H3K27ac'
    else:
        family = 'other'
    col_dict[topic] = topic.rsplit(' ', 1)[0] + ' ' + str(next_index[family])
    next_index[family] += 1
z = z.rename(columns=col_dict)
w0 = w0.rename(columns=col_dict)
w1 = w1.rename(columns=col_dict)

# Subsetting an AnnData yields a view; take an explicit copy before writing
# the topic columns into .obs instead of relying on implicit view
# materialisation.
data_ac = data_ac[z.index].copy()
data_ac.obs[z.columns] = z.values
def auto_break_title(title, max_length_per_line=15):
    """
    Greedily wrap a title onto multiple lines at word boundaries.

    Words are packed onto a line as long as the line (words plus the single
    spaces between them) stays within ``max_length_per_line``. Words are never
    split, so a single word longer than the limit occupies a line of its own.

    Args:
        title (str): The original title. Whitespace runs are treated as one
            separator (the title is split with ``str.split()``).
        max_length_per_line (int): Maximum characters per line before breaking

    Returns:
        str: The words of ``title`` joined with spaces within a line and with
        newline characters between lines. An empty or whitespace-only title
        yields an empty string.
    """
    lines = []
    current_line = []
    current_length = 0  # summed length of the words on the current line, spaces excluded

    for word in title.split():
        # len(current_line) is the number of separating spaces the line would
        # contain once `word` is appended.
        if current_length + len(word) + len(current_line) <= max_length_per_line:
            current_line.append(word)
            current_length += len(word)
            continue

        # The word does not fit. Flush the current line only when it has
        # content; otherwise an over-long first word would emit an empty line
        # and the result would start with a stray newline.
        if current_line:
            lines.append(' '.join(current_line))
        current_line = [word]
        current_length = len(word)

    # Add the last line
    if current_line:
        lines.append(' '.join(current_line))

    # Join lines with newline character
    return '\n'.join(lines)
# One spatial map per topic, laid out row-major on the six-column spec_dr grid.
f_dict = {}
for i, topic in enumerate(z.columns):
    grid_row, grid_col = divmod(i, 6)
    topic_ax = plt.subplot(spec_dr[grid_row, grid_col])
    f_dict[topic] = topic_ax
    # Names without 'Shared' are broken after their first word.
    if 'Shared' in topic:
        panel_title = topic
    else:
        name_parts = topic.split(maxsplit=1)
        panel_title = name_parts[0] + '\n' + name_parts[1]
    sq.pl.spatial_scatter(data_ac, color=topic, ax=topic_ax, title=panel_title, frameon=False)
    topic_ax.get_xaxis().set_visible(False)
    topic_ax.get_yaxis().set_visible(False)
    # Keep a handle on the scatter and drop its per-panel colorbar; after the
    # loop `sm` refers to the scatter of the last topic drawn.
    sm = topic_ax.collections[0]
    sm.colorbar.remove()

# plot common colorbar
spec_colorbar = GridSpec(1, 1)
spec_colorbar.update(**dr_colorbar_scale)
colorbar = plt.subplot(spec_colorbar[0, 0])
cb = plt.colorbar(sm, cax=colorbar)
cb.outline.set_visible(False)

# Four evenly spaced ticks between fixed bounds; only the two outer ticks
# carry text, so the bar reads qualitatively as Low -> High.
vmin, vmax = .05, .4
num_ticks = 4  # Total number of ticks
ticks = np.linspace(vmin, vmax, num=num_ticks)
cb.set_ticks(ticks)
labels = ['Low', *([''] * (num_ticks - 2)), 'High']
cb.ax.set_yticklabels(labels)
cb.ax.tick_params(size=0)

# Caption drawn on the colorbar axes.
colorbar.text(x=0, y=.65, s='Topic\nAbundance', va='center', ha='left', rotation=0)

###############################################################################################################################
# plot heatmaps
# Two stacked heatmaps share one GridSpec (height ratio 17:15).
spec_heatmap = GridSpec(2, 1, hspace=.01, height_ratios=[17, 15])
spec_heatmap.update(**heatmap_scale)
topk=10

# Per-topic gene ranks: within every column the largest loading gets rank 1
# (descending order).
rankings_ac = w0.rank(ascending=False)
rankings_me3 = w1.rank(ascending=False)
# Select, per H3K27ac topic, the genes that
#   * rank better than `topk` for that topic in H3K27ac,
#   * rank better than `topk` for at least one topic in H3K27me3, and
#   * have their best H3K27me3 rank in a topic with a different name.
# Each gene is assigned to the first topic that selects it. The previous
# `gene not in genes.values()` test compared a gene name against whole lists
# and therefore never filtered anything; `assigned` does the intended check.
# The de-duplicated list `gs` keeps the same contents and order as before.
genes = {}
assigned = set()
common_genes = rankings_ac.index.intersection(rankings_me3.index)
for topic in rankings_ac.columns:
    genes[topic] = []
    if topic not in z.columns:
        continue
    # Iterate genes from best to worst H3K27ac rank for this topic.
    for gene in rankings_ac.loc[common_genes, topic].sort_values().index:
        if gene in assigned:
            continue
        me3_ranks = rankings_me3.loc[gene, :]
        if (rankings_ac.loc[gene, topic] < topk
                and me3_ranks.min() < topk
                and rankings_me3.columns[me3_ranks.argmin()] != topic):
            genes[topic].append(gene)
            assigned.add(gene)

# Flatten in topic order; uniqueness is already guaranteed by `assigned`.
gs = [gene_name for topic_genes in genes.values() for gene_name in topic_genes]
            
# H3K27ac rank matrix: built with one column per topic and one row per
# selected gene, then transposed so topics become rows.
heatmap_ac = pd.DataFrame({
    topic_name: [rankings_ac.loc[gene_name, topic_name] for gene_name in gs]
    for topic_name in genes.keys()
})
heatmap_ac.index = gs
heatmap_ac = heatmap_ac.T

# Hand-picked display order (positional) for the gene columns and topic rows.
new_column_order = [4, 5, 6, 7, 8, 0, 2, 3, 1, 28, 19, 17, 20, 16, 18, 9, 10, 12, 11, 21, 22, 14, 13, 15, 24, 23, 25, 27, 26, 29, 30, 31, 32]
new_column = heatmap_ac.columns[new_column_order]
new_row = heatmap_ac.index[[1, 0, 13, 7, 8, 2, 3, 6, 15, 9, 4, 5, 11, 10, 12, 14, 16]]
heatmap_ac = heatmap_ac.loc[new_row, new_column].astype(int)

import seaborn as sns
hm_ac = sns.heatmap(heatmap_ac, annot=False, fmt='d', cmap='coolwarm_r', ax=plt.subplot(spec_heatmap[0, 0]))
# Keep the mesh for the stand-alone colorbar drawn later and remove the
# colorbar seaborn attached to this axes.
hm_cb = hm_ac.collections[0]
hm_cb.colorbar.remove()

# Manually annotate the heatmap: only ranks of 10 or better get a number.
for (row_pos, col_pos), rank_value in np.ndenumerate(heatmap_ac.values):
    if rank_value <= 10:
        hm_ac.text(col_pos + 0.5, row_pos + 0.5, f'{rank_value:d}', ha='center', va='center', color='white')

# Move x-axis labels to the top
hm_ac.xaxis.tick_top()
hm_ac.xaxis.set_label_position('top')

# Row labels containing 'H3K27ac' are broken after their first word.
row_labels = []
for row_name in heatmap_ac.index:
    if 'H3K27ac' in row_name:
        name_parts = row_name.split(maxsplit=1)
        row_labels.append(name_parts[0] + '\n' + name_parts[1])
    else:
        row_labels.append(row_name)
hm_ac.set_yticklabels(row_labels)
# The padded label text sets the width of the grey box drawn behind it.
hm_ac.set_ylabel('                                    H3K27ac                                     ', bbox=dict(facecolor='grey', alpha=0.3, edgecolor='none'), labelpad=20)
# Rotate the gene labels on the top axis by 45 degrees.
plt.xticks(rotation=45, ha='left', rotation_mode='anchor')

# H3K27me3 topics to show: every topic in which at least one displayed gene
# ranks better than `topk`, in order of first appearance.
topic_me3 = []
for gene in heatmap_ac.columns:
    top_ranked_here = rankings_me3.columns[rankings_me3.loc[gene, :] < topk]
    for topic in top_ranked_here:
        if topic not in topic_me3:
            topic_me3.append(topic)

# Rank matrix with the same gene columns (and order) as the H3K27ac heatmap.
heatmap_me3 = pd.DataFrame({
    gene_name: [rankings_me3.loc[gene_name, topic_name] for topic_name in topic_me3]
    for gene_name in heatmap_ac.columns
})
heatmap_me3.index = topic_me3
# Hand-picked row order: the row at position 2 is moved to the end.
new_row = heatmap_me3.index[[0, 1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 2]]
heatmap_me3 = heatmap_me3.loc[new_row, :].astype(int)

import seaborn as sns
hm_me3 = sns.heatmap(heatmap_me3, annot=False, fmt='d', cmap='coolwarm_r', cbar=False, ax=plt.subplot(spec_heatmap[1, 0]))
# Manually annotate the heatmap: only ranks of 10 or better get a number.
for (row_pos, col_pos), rank_value in np.ndenumerate(heatmap_me3.values):
    if rank_value <= 10:
        hm_me3.text(col_pos + 0.5, row_pos + 0.5, f'{rank_value:d}', ha='center', va='center', color='white')

# This heatmap draws no x ticks or x label of its own.
plt.xticks([])
hm_me3.set_xlabel('')

# Row labels containing 'H3K27me3' are broken after their first word.
row_labels = []
for row_name in heatmap_me3.index:
    if 'H3K27me3' in row_name:
        name_parts = row_name.split(maxsplit=1)
        row_labels.append(name_parts[0] + '\n' + name_parts[1])
    else:
        row_labels.append(row_name)
hm_me3.set_yticklabels(row_labels)
hm_me3.set_ylabel('                             H3K27me3                            ', bbox=dict(facecolor='grey', alpha=0.3, edgecolor='none'), labelpad=20)
# Stand-alone colorbar for the rank heatmaps, driven by the H3K27ac mesh.
spec_heatmap_colorbar = GridSpec(1, 1)
spec_heatmap_colorbar.update(**heatmap_colorbar_scale)
heatmap_colorbar = plt.subplot(spec_heatmap_colorbar[0, 0])
cb = plt.colorbar(hm_cb, cax=heatmap_colorbar)
# Flip the bar so that small rank values sit at the top.
heatmap_colorbar.invert_yaxis()
cb.outline.set_visible(False)

# Tick range: the data range of the H3K27ac matrix pulled inwards by fixed
# offsets (+40 at the low end, -60 at the high end).
vmin = heatmap_ac.min().min() + 40
vmax = heatmap_ac.max().max() - 60
num_ticks = 4  # Total number of ticks
ticks = np.linspace(vmin, vmax, num=num_ticks)
cb.set_ticks(ticks)

# Only the outer ticks are labelled: the lowest value is 'High', the highest 'Low'.
labels = ['High', *([''] * (num_ticks - 2)), 'Low']
cb.ax.set_yticklabels(labels)
cb.ax.tick_params(size=0)
cb.ax.set_title('Gene\nRanking', fontsize=7, loc='left')  # Set the label position to the top

# plot biomarkers
# 4x2 grid: one row per gene, H3K27ac on the left and H3K27me3 on the right.
spec_biomarker = GridSpec(4, 2, wspace=-.6, hspace=0)
spec_biomarker.update(**biomarker_scale)
f_dict = {}
# Both modalities are re-read from disk and log1p-transformed in place.
# (A percentile-masked copy of data_me3 used to be built here; it was never
# read afterwards and has been removed.)
data_ac = sc.read_h5ad('../Dataset/' + dataset + '/adata_H3K27ac_ATAC.h5ad')
sc.pp.log1p(data_ac)
data_me3 = sc.read_h5ad('../Dataset/' + dataset + '/adata_H3K27me3_ATAC.h5ad')
sc.pp.log1p(data_me3)

grey_box = dict(facecolor='grey', alpha=0.3, edgecolor='none')
# (leading, trailing) padding around the gene name in each row's y label; the
# padding sets the width of the grey box drawn behind the label.
ylabel_padding = [
    ('            ', '              '),
    ('            ', '             '),
    ('          ', '             '),
    ('            ', '              '),
]

for i, gene in enumerate(['Hoxc4', 'Hand2', 'Gpr37l1', 'Csrp3']):
    # Left column: H3K27ac signal, colour range from the 0th to 99th percentile.
    ac_ax = plt.subplot(spec_biomarker[i, 0])
    f_dict[gene + '_ac'] = ac_ax
    sc.pl.spatial(data_ac, color=gene, ax=ac_ax, colorbar_loc=None, vmin='p0', vmax='p99', show=False, title='', cmap='coolwarm')
    if i == 0:
        # Set the title with a grey background
        ac_ax.set_title('       H3K27ac       ', bbox=dict(grey_box))
    leading, trailing = ylabel_padding[i]
    ac_ax.set_ylabel(leading + gene + trailing, bbox=dict(grey_box))
    # Remove the box (spines)
    for side in ('top', 'right', 'left', 'bottom'):
        ac_ax.spines[side].set_visible(False)
    ac_ax.set_xlabel('')

    # Right column: H3K27me3 signal; the first row uses a higher lower bound
    # (90th percentile) than the remaining rows (70th percentile).
    me3_ax = plt.subplot(spec_biomarker[i, 1])
    f_dict[gene + '_me3'] = me3_ax
    sc.pl.spatial(data_me3, color=gene, ax=me3_ax, vmin='p90' if i == 0 else 'p70', vmax='p99', frameon=False, show=False, title='', cmap='coolwarm')
    if i == 0:
        # Set the title with a grey background
        me3_ax.set_title('     H3K27me3      ', bbox=dict(grey_box))
    # Keep the scatter for the shared colorbar and drop the per-panel one;
    # after the loop `sm` and `gene` refer to the last row.
    sm = me3_ax.collections[0]
    sm.colorbar.remove()

# Shared colorbar for the biomarker maps. It is driven by `sm`, the scatter
# left over from the last gene drawn in the loop above, and `gene` is that
# loop's final value.
spec_biomarker_colorbar = GridSpec(1, 1)
spec_biomarker_colorbar.update(**biomarker_colorbar_scale)
biomarker_colorbar = plt.subplot(spec_biomarker_colorbar[0, 0])
cb = plt.colorbar(sm, cax=biomarker_colorbar)
cb.outline.set_visible(False)

# Tick range: min/max of that gene in the log-transformed H3K27me3 matrix,
# pulled inwards by fixed offsets (+.45 / -.7).
last_gene_values = data_me3[:, gene].X
vmin = last_gene_values.min() + .45
vmax = last_gene_values.max() - .7
num_ticks = 4  # Total number of ticks
ticks = np.linspace(vmin, vmax, num=num_ticks)
cb.set_ticks(ticks)

# Only the outer ticks are labelled.
labels = ['Low', *([''] * (num_ticks - 2)), 'High']
cb.ax.set_yticklabels(labels)
cb.ax.tick_params(size=0)

# Caption drawn on the colorbar axes.
biomarker_colorbar.text(x=0, y=2.3, s='Normalized\nExpression', va='center', ha='left', rotation=0)

plt.tight_layout()

# Bold panel letters, positioned in figure-fraction coordinates.
fs = 17
panel_letters = [
    ('a', .07, .92),
    ('b', .27, .94),
    ('c', .73, .95),
    ('d', .73, .84),
    ('e', .07, .75),
    ('f', .07, .37),
    ('g', .65, .37),
]
for letter, x_pos, y_pos in panel_letters:
    fig.text(x_pos, y_pos, letter, fontsize=fs, fontweight='bold')

# Write the figure to disk, then display it.
plt.savefig('../Figures/visualisation_4.pdf', bbox_inches='tight', transparent=True)
plt.show()

In [None]:
import pandas as pd
import scanpy as sc
import sys
import os
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
# Get the current working directory
current_dir = os.getcwd()
# Add the parent directory to sys.path
sys.path.insert(0, os.path.dirname(os.path.dirname(current_dir)))

# Gallery figure: one row per H3K27ac topic, showing the topic's spatial map
# followed by the maps of its ten highest-loading genes.
# NOTE(review): the first cell of this notebook uses dataset = 'ME13_1' while
# this cell uses '7_ME13_1' — confirm which directory name is intended.
dataset = '7_ME13_1'
z = pd.read_csv('../Results/' + dataset + '/SpaMV_z.csv', index_col=0)
w = pd.read_csv('../Results/' + dataset + '/SpaMV_w_H3K27ac.csv', index_col=0)
# Discard topics by position in z; w must be trimmed before z because the
# names are looked up in the untrimmed z.
w = w.drop(columns=z.columns[[4, 17]])
z = z.drop(columns=z.columns[[4, 17]])

# Renumber the remaining topics consecutively within each family: names
# containing 'Shared', names containing 'H3K27ac', and everything else.
col_dict = {}
next_index = {'Shared': 1, 'H3K27ac': 1, 'other': 1}
for topic in z.columns:
    if 'Shared' in topic:
        family = 'Shared'
    elif 'H3K27ac' in topic:
        family = 'H3K27ac'
    else:
        family = 'other'
    col_dict[topic] = topic.rsplit(' ', 1)[0] + ' ' + str(next_index[family])
    next_index[family] += 1
z = z.rename(columns=col_dict)
w = w.rename(columns=col_dict)

rows = w.shape[1]
columns = 11  # one topic map plus ten gene maps per row
d = sc.read_h5ad('../Dataset/' + dataset + '/adata_H3K27ac_ATAC.h5ad')
# Subsetting yields an AnnData view; copy it before writing topic columns
# into .obs instead of relying on implicit view materialisation.
d = d[z.index].copy()
d.obs[z.columns] = z.values

sf = 2  # inches per grid cell
space = 0.2
fig = plt.figure(figsize=(columns * sf, rows * sf), dpi=50)
spec_dr = GridSpec(rows, columns, hspace=space + .2, wspace=space-.2)
spec_dr.update(left=0, right=.98, top=.98, bottom=0.02)
f_dict = {}
for i in range(rows):
    topic_name = w.columns[i]
    # Genes with the largest loadings for this topic, best first.
    genes = w.nlargest(columns - 1, topic_name).index
    for j in range(columns):
        cell_ax = plt.subplot(spec_dr[i, j])
        f_dict[topic_name + str(j)] = cell_ax
        if j == 0:
            # First column: the topic itself (read from d.obs).
            sc.pl.spatial(d, color=topic_name, ax=cell_ax, show=False, frameon=True, legend_loc='none', vmax='p100', legend_fontsize='xx-small')
            label = topic_name
            if 'H3K27ac' in label:
                label = label.split(maxsplit=1)[0] + '\n' + label.split(maxsplit=1)[1]
            cell_ax.set_title(label, fontsize=14, pad=5)
        else:
            # Remaining columns: the (j-1)-th top gene of this topic.
            sc.pl.spatial(d, color=genes[j - 1], ax=cell_ax, show=False, frameon=True, legend_loc='none', vmax='p99', cmap='coolwarm')
            cell_ax.set_title(genes[j - 1], fontsize=14, pad=2)
        cell_ax.axes.get_xaxis().set_visible(False)
        cell_ax.axes.get_yaxis().set_visible(False)

plt.savefig('../Figures/visualisation_4_7_ME13_1_H3K27ac.pdf')
plt.show()
        

In [None]:
# Gallery figure: one row per H3K27me3 topic, showing the topic's spatial map
# followed by the maps of its ten highest-loading genes.
# This cell has no imports and no `dataset` assignment of its own; it uses
# `pd`, `sc`, `plt`, `GridSpec` and `dataset` as left by the preceding cell.
z = pd.read_csv('../Results/' + dataset + '/SpaMV_z.csv', index_col=0)
w = pd.read_csv('../Results/' + dataset + '/SpaMV_w_H3K27me3.csv', index_col=0)
# Discard topics by position in z; w must be trimmed before z because the
# names are looked up in the untrimmed z.
# NOTE(review): only column 4 is dropped from w; presumably column 17 of z
# does not exist in the H3K27me3 loadings — confirm.
w = w.drop(columns=z.columns[[4]])
z = z.drop(columns=z.columns[[4, 17]])

# Renumber the remaining topics consecutively within each family: names
# containing 'Shared', names containing 'H3K27ac', and everything else.
col_dict = {}
next_index = {'Shared': 1, 'H3K27ac': 1, 'other': 1}
for topic in z.columns:
    if 'Shared' in topic:
        family = 'Shared'
    elif 'H3K27ac' in topic:
        family = 'H3K27ac'
    else:
        family = 'other'
    col_dict[topic] = topic.rsplit(' ', 1)[0] + ' ' + str(next_index[family])
    next_index[family] += 1
z = z.rename(columns=col_dict)
w = w.rename(columns=col_dict)

rows = w.shape[1]
columns = 11  # one topic map plus ten gene maps per row
d = sc.read_h5ad('../Dataset/' + dataset + '/adata_H3K27me3_ATAC.h5ad')
# Subsetting yields an AnnData view; copy it before writing topic columns
# into .obs instead of relying on implicit view materialisation.
d = d[z.index].copy()
d.obs[z.columns] = z.values

sf = 2  # inches per grid cell
space = 0.2
fig = plt.figure(figsize=(columns * sf, rows * sf), dpi=50)
spec_dr = GridSpec(rows, columns, hspace=space + .2, wspace=space-.2)
spec_dr.update(left=0, right=.98, top=.98, bottom=0.02)
f_dict = {}
for i in range(rows):
    topic_name = w.columns[i]
    # Genes with the largest loadings for this topic, best first.
    genes = w.nlargest(columns - 1, topic_name).index
    for j in range(columns):
        cell_ax = plt.subplot(spec_dr[i, j])
        f_dict[topic_name + str(j)] = cell_ax
        if j == 0:
            # First column: the topic itself (read from d.obs).
            sc.pl.spatial(d, color=topic_name, ax=cell_ax, show=False, frameon=True, legend_loc='none', vmax='p100', legend_fontsize='xx-small')
            label = topic_name
            if 'H3K27me3' in label:
                label = label.split(maxsplit=1)[0] + '\n' + label.split(maxsplit=1)[1]
            cell_ax.set_title(label, fontsize=14, pad=5)
        else:
            # Remaining columns: the (j-1)-th top gene of this topic.
            sc.pl.spatial(d, color=genes[j - 1], ax=cell_ax, show=False, frameon=True, legend_loc='none', vmax='p99', cmap='coolwarm')
            cell_ax.set_title(genes[j - 1], fontsize=14, pad=5)
        cell_ax.axes.get_xaxis().set_visible(False)
        cell_ax.axes.get_yaxis().set_visible(False)

plt.savefig('../Figures/visualisation_4_7_ME13_1_H3K27me3.pdf')
plt.show()

In [None]:
import sys
import os
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
# Get the current working directory
current_dir = os.getcwd()
# Add the parent directory to sys.path
sys.path.insert(0, os.path.dirname(os.path.dirname(current_dir)))
from Methods.SpaMV_copy.utils import plot_top_positive_correlations_boxplot

# Correlation box plots: one figure per omics layer, saved as a PDF.
dataset = '7_ME13_1'
z = pd.read_csv('../Results/' + dataset + '/SpaMV_z.csv', index_col=0)
z = z.drop(columns=z.columns[[4, 17]])

# Renumber the remaining topics consecutively within each family: names
# containing 'Shared', names containing 'H3K27ac', and everything else.
col_dict = {}
next_index = {'Shared': 1, 'H3K27ac': 1, 'other': 1}
for topic in z.columns:
    if 'Shared' in topic:
        family = 'Shared'
    elif 'H3K27ac' in topic:
        family = 'H3K27ac'
    else:
        family = 'other'
    col_dict[topic] = topic.rsplit(' ', 1)[0] + ' ' + str(next_index[family])
    next_index[family] += 1
z = z.rename(columns=col_dict)

# The two lists are parallel: d[k] is the AnnData for omics_names[k].
omics_names = ['H3K27ac', 'H3K27me3']
d = [
    sc.read_h5ad('../Dataset/' + dataset + '/adata_H3K27ac_ATAC.h5ad'),
    sc.read_h5ad('../Dataset/' + dataset + '/adata_H3K27me3_ATAC.h5ad'),
]
for adata, omics_name in zip(d, omics_names):
    plot_top_positive_correlations_boxplot(adata, z, omics_name=omics_name)
    # Saves whatever figure is current after the helper returns.
    plt.savefig('../Figures/visualisation_4_7_ME13_1_' + omics_name + '_pcc.pdf')