In [None]:
import scanpy as sc
import numpy as np
import pandas as pd
import anndata as ad
import seaborn as sns
from scipy import stats
import matplotlib.pyplot as plt

from mesa import ecospatial as eco

In [None]:
# Load the annotated dataset; the bare name on the last line renders its summary.
H5AD_PATH = 'RelnAll_Annotated.h5ad'
adata = sc.read_h5ad(H5AD_PATH)
adata

In [None]:
# Keep only the cells whose 'STR' flag in adata.obs is True.
str_mask = adata.obs['STR'] == True
adata = adata[str_mask].copy()

In [None]:
# Scatter the cells at their 'spatial' coordinates, coloured by 'cell_type'.
sc.pl.embedding(
    adata,
    basis='spatial',
    color='cell_type',
    size=5,
)

In [None]:
# Cell types retained for the STR analysis (every other type is filtered out below).
KEEP_CELL_TYPE = [
    'Astro-TE',
    'CHOR',
    'Endo',
    'Epen',
    'Fibro',
    'Inh Lamp5',
    'Inh STR D1',
    'Inh STR D2',
    'Inh Sst',
    'Inh Vip',
    'Microglia',
    'Mural',
    'OPC',
    'Oligo',
]

In [None]:
# Drop every cell whose 'cell_type' is not listed in KEEP_CELL_TYPE.
keep_mask = adata.obs['cell_type'].isin(KEEP_CELL_TYPE)
adata = adata[keep_mask].copy()

In [None]:
# Convert units to microns (0.325 microns per input coordinate unit).
# The unscaled coordinates are stashed once under 'spatial_unscaled' and the
# scaled array is always derived from that copy, so re-running this cell
# cannot multiply the coordinates by the factor a second time.
MICRONS_PER_UNIT = 0.325
if 'spatial_unscaled' not in adata.obsm:
    adata.obsm['spatial_unscaled'] = adata.obsm['spatial'].copy()
adata.obsm['spatial'] = adata.obsm['spatial_unscaled'] * MICRONS_PER_UNIT

In [None]:
# Display the per-cell 'Sample' annotation (used as library_key in the calls below).
adata.obs['Sample']

In [None]:
# Library IDs analysed below: two replicates for each of the KO and WT prefixes.
library_ids = [f'{prefix}{replicate}' for prefix in ('KO', 'WT') for replicate in (1, 2)]

In [None]:
# Sequence of scales passed to calculate_MDI
scales = [1., 2., 4., 8., 16., 32., 64.]

# Collect the call arguments first so they can be inspected before the call.
mdi_kwargs = dict(
    spatial_data=adata,
    scales=scales,
    library_key='Sample',
    library_id=library_ids,
    spatial_key='spatial',
    cluster_key='cell_type',
    random_patch=False,
    plotfigs=False,
    savefigs=False,
    patch_kwargs={'random_seed': None, 'min_points':2},
    other_kwargs={'metric': 'Shannon Diversity'},
)
mdi_results = eco.calculate_MDI(**mdi_kwargs)

In [None]:
# Add 'Condition' and 'Sample_id' columns. A row's condition is whichever of
# 'WT' / 'KO' its index contains; rows matching neither keep the ' ' placeholder.
mdi_results['Condition'] = ' '
mdi_results['Sample_id'] = mdi_results.index
for label in ('WT', 'KO'):
    mdi_results.loc[mdi_results.index.str.contains(label), 'Condition'] = label
mdi_results.head()

In [None]:
# Long format: one row per (sample, scale) holding that scale's diversity value.
df_melted = mdi_results.melt(
    id_vars=['Sample_id', 'Condition'],
    value_vars=scales,
    var_name='Scale',
    value_name='Diversity Value',
).assign(sample='Tissue Sample')  # constant column, used as the line style key when plotting
df_melted

In [None]:
# Spatial extent (max - min of each coordinate) of every sample; the means are
# used further down to label the secondary x-axis.
xrange = []
yrange = []
for region in adata.obs['Sample'].unique():
    spatial_value = adata[adata.obs['Sample']==region].obsm['spatial']
    xrange.append(spatial_value.max(axis=0)[0] - spatial_value.min(axis=0)[0])
    yrange.append(spatial_value.max(axis=0)[1] - spatial_value.min(axis=0)[1])
mean_xrange = np.mean(xrange)
std_xrange = np.std(xrange)
mean_yrange = np.mean(yrange)
std_yrange = np.std(yrange)

# Per-scale mean and t-based 95% CI half-width. These feed the text labels only;
# the drawn error bars are computed by seaborn from errorbar=("ci", 95).
grouped = df_melted.groupby('Scale')
mean_values = grouped['Diversity Value'].mean()
conf_intervals = grouped['Diversity Value'].apply(lambda x: stats.sem(x) * stats.t.ppf((1 + 0.95) / 2., len(x)-1))

# Line plot of the mean diversity per scale with error bars
fig, ax = plt.subplots(figsize=(6, 4))
sns.lineplot(data=df_melted,
             x='Scale',
             y='Diversity Value',
             style='sample',
             markers=True,
             estimator='mean',
             err_style='bars',
             errorbar=("ci", 95),
             err_kws={"capsize":5.0},
             ax=ax)

# Annotate each scale with "mean±half-width". A non-finite half-width (e.g. only
# one observation at a scale) would put the label at y=NaN, so in that case the
# mean alone is written at the mean.
for scale, mean, ci in zip(mean_values.index, mean_values, conf_intervals):
    if np.isfinite(ci):
        ax.text(scale, mean + ci, f'{mean:.3f}±{ci:.3f}', color='black', ha='center', va='bottom')
    else:
        ax.text(scale, mean, f'{mean:.3f}', color='black', ha='center', va='bottom')

# Red dashed guides: horizontal at the median of the per-scale means, vertical
# at the scale whose mean diversity is largest.
mean_diversity_per_scale = mean_values
y_sep = mean_diversity_per_scale.median()
x_sep = mean_diversity_per_scale.idxmax()

ax.axhline(y_sep, color='red', linestyle='--')
ax.axvline(x_sep, color='red', linestyle='--')
ax.get_legend().remove()

ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.xlabel('', fontsize=0)
plt.xticks(fontsize=12)
plt.ylabel("GDI", fontsize=16)
plt.yticks(fontsize=12)

# Secondary x-axis labelled with mean_xrange/scale × mean_yrange/scale for the
# integer tick positions of the main axis. The first two and the last tick
# labels are skipped (presumably ticks at/below zero and beyond the data - TODO confirm).
# The tick values get their own name so the global `scales` list used by the
# MDI and melt cells is not overwritten.
xtick_labels = [tick.get_text() for tick in ax.get_xticklabels()][2:-1]
tick_scales = [int(label) for label in xtick_labels if label.strip() != '']
x_sizes = [mean_xrange / scale for scale in tick_scales]
y_sizes = [mean_yrange / scale for scale in tick_scales]
size_labels = [f"{int(x_size)}×{int(y_size)}" for x_size, y_size in zip(x_sizes, y_sizes)]
secax = ax.secondary_xaxis(location=-0.075)
secax.set_xticks(tick_scales)
secax.set_xticklabels(size_labels)
secax.tick_params('x', length=0)
secax.spines['bottom'].set_linewidth(0)
secax.set_xlabel('Scale \n (Area μm²)', fontsize=12)

plt.title('GDI per Scale with 95% Confidence Intervals')
plt.grid(False)
plt.show()

In [None]:
# Display the MDI table, including the 'Condition' and 'Sample_id' columns added above.
mdi_results

In [None]:
# calculate_GDI at scale 64 with hotspot=True for the four libraries.
gdi_kwargs = dict(
    spatial_data=adata,
    scale=64,
    library_key='Sample',
    library_id=library_ids,
    spatial_key='spatial',
    cluster_key='cell_type',
    hotspot=True,
    restricted=False,
    metric='Shannon Diversity',
)
gdi_results = eco.calculate_GDI(**gdi_kwargs)
gdi_results

In [None]:
# Write the GDI results to STR_GDI.csv in the working directory.
gdi_results.to_csv("STR_GDI.csv")

In [None]:
# Calculate DPI for hotspots (hotspot=True) at scale 64.0
dpi_kwargs = dict(
    spatial_data=adata,
    scale=64.0,
    library_key='Sample',
    library_id=library_ids,
    spatial_key='spatial',
    cluster_key='cell_type',
    hotspot=True,
    metric='Shannon Diversity',
)
dpi_results = eco.calculate_DPI(**dpi_kwargs)
dpi_results

In [None]:
# Write the DPI results to STR_DPI.csv in the working directory.
dpi_results.to_csv('STR_DPI.csv')

In [None]:
# spot_cellfreq with spots='global' returns a pair: the cell-frequency table
# and the co-occurrence table.
global_spot_kwargs = dict(
    spatial_data=adata,
    scale=64.0,
    library_key='Sample',
    library_id=library_ids,
    spatial_key='spatial',
    cluster_key='cell_type',
    spots='global',
    top=None,
    selected_comb=None,
    restricted=False,
    metric='Shannon Diversity',
)
global_cellfreq_df, global_co_occurrence_df = eco.spot_cellfreq(**global_spot_kwargs)

In [None]:
# Display the global cell-frequency table as returned by spot_cellfreq.
global_cellfreq_df

In [None]:
# Tag each row of the global frequency table with the condition found in its
# index ('KO' or 'WT'); rows matching neither keep the ' ' placeholder.
global_cellfreq_df['Condition'] = ' '
for label in ('KO', 'WT'):
    global_cellfreq_df.loc[global_cellfreq_df.index.str.contains(label), 'Condition'] = label

# Co-occurrence columns whose mean exceeds 0.05. This is computed before the
# string columns below are added to the frame.
global_co_occurrence_means = global_co_occurrence_df.mean()
global_co_occurrence_subcols = global_co_occurrence_means[global_co_occurrence_means > 0.05].index.tolist()

global_co_occurrence_df['Condition'] = ' '
global_co_occurrence_df['Patch'] = global_co_occurrence_df.index
for label in ('KO', 'WT'):
    global_co_occurrence_df.loc[global_co_occurrence_df.index.str.contains(label), 'Condition'] = label
# Keep the two annotation columns alongside the selected combinations.
global_co_occurrence_subcols += [('Condition',''),('Patch','')]

In [None]:
# Display the global cell-frequency table again, now with the 'Condition' column.
global_cellfreq_df

In [None]:
# Long format for plotting and testing: one row per (sample, cell type).
# The four resulting columns are renamed positionally.
global_cellfreq_df_melt = (
    global_cellfreq_df
    .reset_index()
    .melt(id_vars=['Sample', 'Condition'])
    .set_axis(['Sample', 'group', 'cell_type', 'Frequency'], axis=1)
)

global_cellfreq_df_melt

In [None]:
# Welch t-test (equal_var=False) of WT vs KO frequency for every cell type
selected_cell_types = sorted(adata.obs['cell_type'].unique())
selected_p_values = []
for ct in selected_cell_types:
    ct_rows = global_cellfreq_df_melt[global_cellfreq_df_melt['cell_type'] == ct]
    group1 = ct_rows.loc[ct_rows['group'] == 'WT', 'Frequency']
    group2 = ct_rows.loc[ct_rows['group'] == 'KO', 'Frequency']
    t_stat, p_value = stats.ttest_ind(group1, group2, equal_var=False)
    print(f"{ct} has p value of {p_value}")
    selected_p_values.append(p_value)

# Benjamini-Hochberg adjustment across all cell types
pvals_corrected = stats.false_discovery_control(selected_p_values, method='bh')
print('-'*42)
print(f"p-values after correction:")

# Box plot with the individual observations overlaid
fig, ax = plt.subplots(figsize=(30,10))
shared_plot_kwargs = dict(data=global_cellfreq_df_melt, x='cell_type', y='Frequency', hue='group',
                          dodge=True, order=selected_cell_types, ax=ax)
sns.boxplot(palette='muted', boxprops=dict(alpha=.3), **shared_plot_kwargs)
sns.swarmplot(palette='dark:black', size=2.0, edgecolor='auto', linewidth=0.5, **shared_plot_kwargs)
# Both plots add legend entries; keep only the first two.
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[:2], labels[:2], title="Groups", handletextpad=1, columnspacing=1, bbox_to_anchor=(1, 1), ncol=3, frameon=True)
plt.xticks(rotation=90)

# Write the adjusted p-value above each cell type and keep them by name.
p_vals_corrected_dict = {}
y_low, y_high = ax.get_ylim()
yrange = y_high - y_low
for i, (ct, p_adjusted) in enumerate(zip(selected_cell_types, pvals_corrected)):
    ax.text(i, yrange, f"p = {p_adjusted:.3f}", ha='center', fontsize=12, rotation=0)
    print(f"{ct} has p value = {p_adjusted:.3f}", flush=True)
    p_vals_corrected_dict[ct] = p_adjusted

# Thin dashed separators between neighbouring categories
for i in range(len(selected_cell_types) - 1):
    ax.axvline(i + 0.55, color='grey', linestyle='--', linewidth=0.5)

ax.set_ylabel("Frequency", fontsize=14)
ax.set_xlabel('')
plt.savefig("Global_Cell_Frequency_STR.pdf", dpi = 300)
plt.show()

In [None]:
# Same spot_cellfreq call as the global one above, but with spots='hot'.
hot_spot_kwargs = dict(
    spatial_data=adata,
    scale=64.0,
    library_key='Sample',
    library_id=library_ids,
    spatial_key='spatial',
    cluster_key='cell_type',
    spots='hot',
    top=None,
    selected_comb=None,
    restricted=False,
    metric='Shannon Diversity',
)
spot_cellfreq_df, spot_co_occurrence_df = eco.spot_cellfreq(**hot_spot_kwargs)

In [None]:
# Tag each row of the hot-spot frequency table with the condition found in its
# index ('KO' or 'WT'); rows matching neither keep the ' ' placeholder.
spot_cellfreq_df['Condition'] = ' '
for label in ('KO', 'WT'):
    spot_cellfreq_df.loc[spot_cellfreq_df.index.str.contains(label), 'Condition'] = label

# Co-occurrence columns whose mean exceeds 0.05. This is computed before the
# string columns below are added to the frame.
spot_co_occurrence_means = spot_co_occurrence_df.mean()
spot_co_occurrence_subcols = spot_co_occurrence_means[spot_co_occurrence_means > 0.05].index.tolist()

spot_co_occurrence_df['Condition'] = ' '
spot_co_occurrence_df['Patch'] = spot_co_occurrence_df.index
for label in ('KO', 'WT'):
    spot_co_occurrence_df.loc[spot_co_occurrence_df.index.str.contains(label), 'Condition'] = label
# Keep the two annotation columns alongside the selected combinations.
spot_co_occurrence_subcols += [('Condition',''),('Patch','')]

In [None]:
# Expose the index as a 'Patch' column so it survives the melt as an id column.
spot_cellfreq_df['Patch'] = spot_cellfreq_df.index

# Long format: one row per (patch, cell type)
spot_cellfreq_df_melt = pd.melt(
    spot_cellfreq_df,
    id_vars=['Patch', 'Condition'],
    var_name='CellType',
    value_name='Frequency',
)

In [None]:
# Display the long-format hot-spot cell-frequency table.
spot_cellfreq_df_melt

In [None]:
# Hot-spot cell frequencies: WT vs KO per cell type.
# NOTE: reads p_vals_corrected_dict as filled by the global cell-frequency
# t-test cell, so this cell must run after that one and before the
# co-occurrence cell that reuses the same name.
selected_cell_types = sorted(spot_cellfreq_df_melt['CellType'].unique())
selected_p_values = []

# Perform Welch t-tests (equal_var=False)
print(f"p-value before correction:")
for ct in selected_cell_types:
    subset = spot_cellfreq_df_melt[spot_cellfreq_df_melt['CellType'] == ct]
    group1 = subset[subset['Condition'] == 'WT']['Frequency']
    group2 = subset[subset['Condition'] == 'KO']['Frequency']

    t_stat, p_value = stats.ttest_ind(group1, group2, equal_var=False)
    print(f"{ct} has p value = {p_value:.4f}")
    selected_p_values.append(p_value)

# Restrict the plotted rows to the tested cell types (currently all of them)
df_filtered = spot_cellfreq_df_melt[spot_cellfreq_df_melt['CellType'].isin(selected_cell_types)]

# Box plot with the individual observations overlaid
fig, ax = plt.subplots(figsize=(30,10))
sns.boxplot(data=df_filtered, x='CellType', y='Frequency', hue='Condition', palette='muted', boxprops=dict(alpha=.3), ax=ax, dodge=True,order=selected_cell_types)
sns.swarmplot(data=df_filtered, x='CellType', y='Frequency', hue='Condition', palette='dark:black', size=3.0, dodge=True, order=selected_cell_types, ax=ax, edgecolor='auto', linewidth=0.5)

# Both plots add legend entries; keep only the first two.
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[:2], labels[:2], title="Groups", handletextpad=1, columnspacing=1, bbox_to_anchor=(1, 1), ncol=3, frameon=True)

# Benjamini-Hochberg adjustment, keyed by cell type
spot_pvals_corrected = stats.false_discovery_control(selected_p_values, method='bh')
spot_pvals_corrected = dict(zip(selected_cell_types, spot_pvals_corrected))

print('-'*42)
print(f"p-values after correction: ")

yrange = ax.get_ylim()[1] - ax.get_ylim()[0]
for i, ct in enumerate(selected_cell_types):
    ax.text(i, yrange, f"p = {spot_pvals_corrected[ct]:.3f}", ha='center', fontsize=12, rotation=90)
    print(f"{ct} in hot spots has p value = {spot_pvals_corrected[ct]:.3f}", flush=True)
    # Report cell types significant in hot spots but not in the whole tissue.
    # ">=" is the exact complement of "< 0.05" and matches the co-occurrence
    # comparison later in the notebook (the original ">" skipped p == 0.05).
    if spot_pvals_corrected[ct] < 0.05 and p_vals_corrected_dict[ct] >= 0.05:
        print(f"{ct} in whole tissue has p value = {p_vals_corrected_dict[ct]:.3f}", flush=True)
        print('*'*42)

# Thin dashed separators between neighbouring categories
for i in range(len(selected_cell_types) - 1):
    ax.axvline(i + 0.55, color='grey', linestyle='--', linewidth=0.5)

ax.set_ylabel("Frequency", fontsize=14)
ax.set_xlabel('')
plt.xticks(rotation=90)
plt.yticks(rotation=90)
plt.savefig("HotColdSpot_Cell_Frequency_STR.pdf", dpi = 300)
plt.show()
plt.close(fig)

In [None]:
# Every column selected in either the global or the hot-spot co-occurrence table.
union_cols = set(global_co_occurrence_subcols) | set(spot_co_occurrence_subcols)

In [None]:
# Give both co-occurrence tables the same column set; columns a table lacks
# are created and filled with 0.
global_co_occurrence_df, spot_co_occurrence_df = (
    frame.reindex(columns=union_cols).fillna(0)
    for frame in (global_co_occurrence_df, spot_co_occurrence_df)
)

In [None]:
# Global Cell Co-Occurrence
# Flatten tuple (multi-index) column labels to plain strings. Labels containing
# "Patch" or "Condition" collapse to their first element; every other tuple is
# joined with '&'. Non-tuple labels pass through unchanged.
new_columns = [
    (
        '&'.join(map(str, col)).strip()
        if ("Patch" not in col and "Condition" not in col)
        else col[0]
    )
    if isinstance(col, tuple)
    else col
    for col in global_co_occurrence_df.columns
]

global_co_occurrence_df_single = global_co_occurrence_df.set_axis(new_columns, axis=1)
# Discard every flattened column whose name contains 'noid'.
kept_columns = [name for name in global_co_occurrence_df_single.columns if 'noid' not in name]
global_co_occurrence_df_single = global_co_occurrence_df_single[kept_columns]

# Long format: one row per (patch, cell combination)
global_co_occurrence_melted = pd.melt(
    global_co_occurrence_df_single,
    id_vars=['Patch', 'Condition'],
    var_name='Cell Combination',
    value_name='Frequency',
)
global_co_occurrence_melted

In [None]:
# Global Cell Co-Occurrence: WT vs KO per cell combination
selected_cell_types = sorted(global_co_occurrence_melted['Cell Combination'].unique())
selected_p_values = []

# Welch t-tests (equal_var=False)
print(f"p-value before correction:")
for ct in selected_cell_types:
    subset = global_co_occurrence_melted[global_co_occurrence_melted['Cell Combination'] == ct]
    is_wt = subset['Condition'] == 'WT'
    is_ko = subset['Condition'] == 'KO'
    group1 = subset.loc[is_wt, 'Frequency']
    group2 = subset.loc[is_ko, 'Frequency']

    t_stat, p_value = stats.ttest_ind(group1, group2, equal_var=False)
    print(f"{ct} has p value = {p_value:.3f}")
    selected_p_values.append(p_value)

# Restrict the plotted rows to the tested combinations (currently all of them)
df_filtered = global_co_occurrence_melted[global_co_occurrence_melted['Cell Combination'].isin(selected_cell_types)]

# Box plot with the individual observations overlaid
fig, ax = plt.subplots(figsize=(45,10))
shared_plot_kwargs = dict(data=df_filtered, x='Cell Combination', y='Frequency', hue='Condition',
                          dodge=True, order=selected_cell_types, ax=ax)
sns.boxplot(palette='muted', boxprops=dict(alpha=.3), **shared_plot_kwargs)
sns.swarmplot(palette='dark:black', size=1.0, edgecolor='gray', linewidth=0.5, **shared_plot_kwargs)

# Both plots add legend entries; keep only the first two.
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[:2], labels[:2], title="Groups", handletextpad=1, columnspacing=1, bbox_to_anchor=(1, 1), ncol=3, frameon=True)

# Benjamini-Hochberg adjustment
pvals_corrected = stats.false_discovery_control(selected_p_values, method='bh')

print('-'*42)
print(f"p-values after correction:")

# Annotate each combination and keep the adjusted p-values by name
# (this replaces the per-cell-type dictionary built earlier).
p_vals_corrected_dict = {}
y_low, y_high = ax.get_ylim()
yrange = y_high - y_low
for i, (ct, p_adjusted) in enumerate(zip(selected_cell_types, pvals_corrected)):
    ax.text(i, yrange, f"p = {p_adjusted:.3f}", ha='center', fontsize=8, rotation=0)
    print(f"{ct} has p value = {p_adjusted:.3f}", flush=True)
    p_vals_corrected_dict[ct] = p_adjusted

# Thin dashed separators between neighbouring categories
for i in range(len(selected_cell_types) - 1):
    ax.axvline(i + 0.55, color='grey', linestyle='--', linewidth=0.5)

ax.set_ylabel("Frequency", fontsize=14)
ax.set_xlabel('')
plt.xticks(rotation=90)
plt.savefig("Global_Cell_co_occurance_STR.pdf", dpi = 300)
plt.show()

In [None]:
# Spot Cell Co-Occurrence
# Flatten tuple (multi-index) column labels to plain strings. Labels containing
# "Patch" or "Condition" collapse to their first element; every other tuple is
# joined with '&'. Non-tuple labels pass through unchanged.
new_columns = [
    (
        '&'.join(map(str, col)).strip()
        if ("Patch" not in col and "Condition" not in col)
        else col[0]
    )
    if isinstance(col, tuple)
    else col
    for col in spot_co_occurrence_df.columns
]

spot_co_occurrence_df_single = spot_co_occurrence_df.set_axis(new_columns, axis=1)
# Discard every flattened column whose name contains 'noid'.
kept_columns = [name for name in spot_co_occurrence_df_single.columns if 'noid' not in name]
spot_co_occurrence_df_single = spot_co_occurrence_df_single[kept_columns]

# Long format: one row per (patch, cell combination)
spot_co_occurrence_melted = pd.melt(
    spot_co_occurrence_df_single,
    id_vars=['Patch', 'Condition'],
    var_name='Cell Combination',
    value_name='Frequency',
)
spot_co_occurrence_melted

In [None]:
# Spot Cell Co-Occurrence: WT vs KO per cell combination inside hot spots.
# NOTE: reads p_vals_corrected_dict as filled by the global co-occurrence
# t-test cell, so that cell must have run immediately before this one.
selected_cell_types = sorted(spot_co_occurrence_melted['Cell Combination'].unique())
selected_p_values = []

# Perform Welch t-tests (equal_var=False)
print(f"p-value before correction: ")
for ct in selected_cell_types:
    subset = spot_co_occurrence_melted[spot_co_occurrence_melted['Cell Combination'] == ct]
    group1 = subset[subset['Condition'] == 'WT']['Frequency']
    group2 = subset[subset['Condition'] == 'KO']['Frequency']

    t_stat, p_value = stats.ttest_ind(group1, group2, equal_var=False)
    print(f"{ct} has p value = {p_value:.3f}")
    selected_p_values.append(p_value)

# Restrict the plotted rows to the tested combinations (currently all of them)
df_filtered = spot_co_occurrence_melted[spot_co_occurrence_melted['Cell Combination'].isin(selected_cell_types)]

# Box plot with the individual observations overlaid
fig, ax = plt.subplots(figsize=(42,10))
sns.boxplot(data=df_filtered, x='Cell Combination', y='Frequency', hue='Condition', palette='muted', boxprops=dict(alpha=.3), ax=ax, dodge=True,order=selected_cell_types)
sns.swarmplot(data=df_filtered, x='Cell Combination', y='Frequency', hue='Condition', palette='dark:black', size=2.0, dodge=True,order=selected_cell_types, ax=ax, edgecolor='gray', linewidth=0.5)

# Both plots add legend entries; keep only the first two.
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[:2], labels[:2], title="Groups", handletextpad=1, columnspacing=1, bbox_to_anchor=(1, 1), ncol=3, frameon=True)

# Benjamini-Hochberg adjustment of the hot-spot p-values
hot_pvals_corrected = stats.false_discovery_control(selected_p_values, method='bh')

print('-'*42)
print(f"p-values after correction:")

highlighted_comb = []
yrange = ax.get_ylim()[1] - ax.get_ylim()[0]
for i, ct in enumerate(selected_cell_types):
    # Annotate with the hot-spot p-value of this figure. The original used
    # pvals_corrected, i.e. the whole-tissue values from the previous cell.
    ax.text(i, yrange, f"p = {hot_pvals_corrected[i]:.3f}", ha='center', fontsize=8, rotation=0)
    print(f"{ct} in hot spots has p value = {hot_pvals_corrected[i]:.3f}", flush=True)
    # Combinations significant in hot spots but not in the whole tissue are
    # collected as tuples of cell-type names for highlighting in the circos plot.
    if hot_pvals_corrected[i] < 0.05 and p_vals_corrected_dict[ct] >= 0.05:
        highlighted_comb.append(tuple(map(str.strip, ct.split('&'))))
        print(f"{ct} in whole tissue has p value = {p_vals_corrected_dict[ct]:.3f}", flush=True)
        print('*'*42)

# Thin dashed separators between neighbouring categories
for i in range(len(selected_cell_types) - 1):
    ax.axvline(i + 0.55, color='grey', linestyle='--', linewidth=0.5)

ax.set_ylabel("Frequency", fontsize=14)
ax.set_xlabel('')
plt.xticks(rotation=90)
plt.yticks(rotation=90)
plt.savefig("Spot_Cell_co_occurance_STR.pdf", dpi = 300)
plt.show()

In [None]:
# Mean global co-occurrence per condition: sort the columns, drop 'Patch' and
# every column containing 'noid', then average within each 'Condition'.
circoplot_df1 = (
    global_co_occurrence_df
    .sort_index(axis=1, level=[0,1])
    .drop(columns=['Patch'])
)
kept_columns = [col for col in circoplot_df1.columns if 'noid' not in col]
circoplot_df1 = (
    circoplot_df1[kept_columns]
    .groupby('Condition')
    .mean()
    .reset_index()
    .set_index('Condition')
)
circoplot_df1

In [None]:
# Display the global cell-frequency table before it is averaged per condition below.
global_cellfreq_df

In [None]:
# Mean global cell frequency per condition, indexed by 'Condition'.
circoplot_df2 = (
    global_cellfreq_df
    .copy()
    .groupby('Condition')
    .mean()
    .reset_index()
    .set_index('Condition')
)
circoplot_df2

In [None]:
# Circos plot for the WT row of the per-condition mean tables; written to
# WT_STR_circoplot.pdf. No edges are highlighted.
patient_group = 'WT'
group_rows = [patient_group]
eco.create_circos_plot(
    circoplot_df1.loc[group_rows],
    cell_type_colors_hex=None,
    cell_abundance=circoplot_df2.loc[group_rows],
    threshold=0.05,
    edge_weights_scaler=10,
    highlighted_edges=None,
    node_weights_scaler=5000,
    figure_size=(8,8),
    save_path='WT_STR_circoplot.pdf',
)

In [None]:
# Mean hot-spot co-occurrence per condition: sort the columns, drop 'Patch' and
# every column containing 'noid', then average within each 'Condition'.
# This reuses (and replaces) the circoplot_df1 built from the global table.
circoplot_df1 = (
    spot_co_occurrence_df
    .sort_index(axis=1, level=[0,1])
    .drop(columns=['Patch'])
)
kept_columns = [col for col in circoplot_df1.columns if 'noid' not in col]
circoplot_df1 = (
    circoplot_df1[kept_columns]
    .groupby('Condition')
    .mean()
    .reset_index()
    .set_index('Condition')
)
circoplot_df1

In [None]:
# Mean hot-spot cell frequency per condition ('Patch' removed first), indexed by 'Condition'.
circoplot_df2 = (
    spot_cellfreq_df
    .drop(columns=['Patch'])
    .groupby('Condition')
    .mean()
    .reset_index()
    .set_index('Condition')
)
circoplot_df2

In [None]:
# Circos plot for the KO row of the per-condition mean tables, highlighting the
# combinations collected in highlighted_comb; written to KO_STR_circoplot.pdf.
patient_group = 'KO'
group_rows = [patient_group]
eco.create_circos_plot(
    circoplot_df1.loc[group_rows],
    cell_type_colors_hex=None,
    cell_abundance=circoplot_df2.loc[group_rows],
    threshold=0.05,
    edge_weights_scaler=10,
    highlighted_edges=highlighted_comb,
    node_weights_scaler=5000,
    figure_size=(8,8),
    save_path='KO_STR_circoplot.pdf',
)

In [None]:
import scanpy as sc
import numpy as np
import pandas as pd
import anndata as ad
import seaborn as sns
from scipy import stats
import matplotlib.pyplot as plt

from mesa import ecospatial as eco

In [None]:
# Reload the annotated dataset from disk; the bare name on the last line renders its summary.
H5AD_PATH = 'RelnAll_Annotated.h5ad'
adata = sc.read_h5ad(H5AD_PATH)
adata

In [None]:
# Keep only the cells whose 'CB' flag in adata.obs is True.
cb_mask = adata.obs['CB'] == True
adata = adata[cb_mask].copy()

In [None]:
# Scatter the cells at their 'spatial' coordinates, coloured by 'cell_type'.
sc.pl.embedding(
    adata,
    basis='spatial',
    color='cell_type',
    size=5,
)

In [None]:
# Cell types retained for the CB analysis (every other type is filtered out below).
KEEP_CELL_TYPE = [
    'Astro-CB',
    'CHOR',
    'Endo',
    'Epen',
    'Ext CB',
    'Ext UBC',
    'Fibro',
    'Inh CB',
    'Inh CB Purkinje',
    'Microglia',
    'Mural',
    'OPC',
    'Oligo',
]

In [None]:
# Drop every cell whose 'cell_type' is not listed in KEEP_CELL_TYPE.
keep_mask = adata.obs['cell_type'].isin(KEEP_CELL_TYPE)
adata = adata[keep_mask].copy()

In [None]:
# Convert units to microns (0.325 microns per input coordinate unit).
# The unscaled coordinates are stashed once under 'spatial_unscaled' and the
# scaled array is always derived from that copy, so re-running this cell
# cannot multiply the coordinates by the factor a second time.
MICRONS_PER_UNIT = 0.325
if 'spatial_unscaled' not in adata.obsm:
    adata.obsm['spatial_unscaled'] = adata.obsm['spatial'].copy()
adata.obsm['spatial'] = adata.obsm['spatial_unscaled'] * MICRONS_PER_UNIT

In [None]:
# Display the per-cell 'Sample' annotation (used as library_key in the calls below).
adata.obs['Sample']

## 先单独运行一个样本的MESA，看看scale factor选多少合适

In [None]:
# Keep a copy of the full CB dataset; `adata` is narrowed to one sample next
# and restored from this copy afterwards.
adata_all = adata.copy()

In [None]:
# Work on the 'WT2' sample only for the single-sample scale exploration.
wt2_mask = adata_all.obs['Sample'] == 'WT2'
adata = adata_all[wt2_mask].copy()

In [None]:
# Constant lowercase 'sample' column: every cell gets library id '1'
# (this column is used as library_key in the single-sample MDI call below).
adata.obs['sample'] = '1'

In [None]:
# Library ids present in the lowercase 'sample' column
library_ids = adata.obs['sample'].unique().tolist()

# Sequence of scales passed to calculate_MDI
scales = [2., 4., 8., 16., 24., 32., 48., 64., 72.]

# Collect the call arguments first so they can be inspected before the call.
mdi_kwargs = dict(
    spatial_data=adata,
    scales=scales,
    library_key='sample',
    library_id=library_ids,
    spatial_key='spatial',
    cluster_key='cell_type',
    selecting_scale=True,
    random_patch=False,
    plotfigs=False,
    savefigs=False,
    patch_kwargs={'random_seed': None, 'min_points':2},
    other_kwargs={'metric': 'Shannon Diversity'},
)
mdi_results = eco.calculate_MDI(**mdi_kwargs)

In [None]:
# Display the single-sample MDI table as returned by calculate_MDI.
mdi_results

In [None]:
# Add constant 'Condition' and 'Sample_id' columns for the single sample.
mdi_results = mdi_results.assign(Condition='WT', Sample_id='1')
mdi_results

In [None]:
# Long format: one row per (sample, scale) holding that scale's diversity value.
df_melted = mdi_results.melt(
    id_vars=['Sample_id', 'Condition'],
    value_vars=scales,
    var_name='Scale',
    value_name='Diversity Value',
).assign(sample='Tissue Sample')  # constant column, used as the line style key when plotting
df_melted

In [None]:
# Spatial extent (max - min of each coordinate) of every library in the
# lowercase 'sample' column; the means are used further down to label the
# secondary x-axis.
xrange = []
yrange = []
for region in adata.obs['sample'].unique():
    spatial_value = adata[adata.obs['sample']==region].obsm['spatial']
    xrange.append(spatial_value.max(axis=0)[0] - spatial_value.min(axis=0)[0])
    yrange.append(spatial_value.max(axis=0)[1] - spatial_value.min(axis=0)[1])
mean_xrange = np.mean(xrange)
std_xrange = np.std(xrange)
mean_yrange = np.mean(yrange)
std_yrange = np.std(yrange)

# Per-scale mean and t-based 95% CI half-width. These feed the text labels only;
# the drawn error bars are computed by seaborn from errorbar=("ci", 95).
grouped = df_melted.groupby('Scale')
mean_values = grouped['Diversity Value'].mean()
conf_intervals = grouped['Diversity Value'].apply(lambda x: stats.sem(x) * stats.t.ppf((1 + 0.95) / 2., len(x)-1))

# Line plot of the mean diversity per scale with error bars
fig, ax = plt.subplots(figsize=(6, 4))
sns.lineplot(data=df_melted,
             x='Scale',
             y='Diversity Value',
             style='sample',
             markers=True,
             estimator='mean',
             err_style='bars',
             errorbar=("ci", 95),
             err_kws={"capsize":5.0},
             ax=ax)

# Annotate each scale with "mean±half-width". With a single observation per
# scale (one sample) the half-width is NaN, which would put the label at y=NaN
# and print "±nan"; in that case the mean alone is written at the mean.
for scale, mean, ci in zip(mean_values.index, mean_values, conf_intervals):
    if np.isfinite(ci):
        ax.text(scale, mean + ci, f'{mean:.3f}±{ci:.3f}', color='black', ha='center', va='bottom')
    else:
        ax.text(scale, mean, f'{mean:.3f}', color='black', ha='center', va='bottom')

# Red dashed guides: horizontal at the median of the per-scale means, vertical
# at the scale whose mean diversity is largest.
mean_diversity_per_scale = mean_values
y_sep = mean_diversity_per_scale.median()
x_sep = mean_diversity_per_scale.idxmax()

ax.axhline(y_sep, color='red', linestyle='--')
ax.axvline(x_sep, color='red', linestyle='--')
ax.get_legend().remove()

ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.xlabel('', fontsize=0)
plt.xticks(fontsize=12)
plt.ylabel("GDI", fontsize=16)
plt.yticks(fontsize=12)

# Secondary x-axis labelled with mean_xrange/scale × mean_yrange/scale for the
# integer tick positions of the main axis. The first two and the last tick
# labels are skipped (presumably ticks at/below zero and beyond the data - TODO confirm).
# The tick values get their own name so the global `scales` list used by the
# MDI and melt cells is not overwritten.
xtick_labels = [tick.get_text() for tick in ax.get_xticklabels()][2:-1]
tick_scales = [int(label) for label in xtick_labels if label.strip() != '']
x_sizes = [mean_xrange / scale for scale in tick_scales]
y_sizes = [mean_yrange / scale for scale in tick_scales]
size_labels = [f"{int(x_size)}×{int(y_size)}" for x_size, y_size in zip(x_sizes, y_sizes)]
secax = ax.secondary_xaxis(location=-0.075)
secax.set_xticks(tick_scales)
secax.set_xticklabels(size_labels)
secax.tick_params('x', length=0)
secax.spines['bottom'].set_linewidth(0)
secax.set_xlabel('Scale \n (Area μm²)', fontsize=12)

plt.title('GDI per Scale with 95% Confidence Intervals')
plt.grid(False)
plt.show()

In [None]:
# Restore the full CB dataset (all samples) from the copy saved before the single-sample run.
adata = adata_all.copy()

In [None]:
# Library IDs analysed below: two replicates for each of the KO and WT prefixes.
library_ids = [f'{prefix}{replicate}' for prefix in ('KO', 'WT') for replicate in (1, 2)]

In [None]:
# Sequence of scales passed to calculate_MDI
scales = [1., 2., 4., 8., 16., 32., 64.]

# Collect the call arguments first so they can be inspected before the call.
mdi_kwargs = dict(
    spatial_data=adata,
    scales=scales,
    library_key='Sample',
    library_id=library_ids,
    spatial_key='spatial',
    cluster_key='cell_type',
    random_patch=False,
    plotfigs=False,
    savefigs=False,
    patch_kwargs={'random_seed': None, 'min_points':2},
    other_kwargs={'metric': 'Shannon Diversity'},
)
mdi_results = eco.calculate_MDI(**mdi_kwargs)

In [None]:
# Add 'Condition' and 'Sample_id' columns. A row's condition is whichever of
# 'WT' / 'KO' its index contains; rows matching neither keep the ' ' placeholder.
mdi_results['Condition'] = ' '
mdi_results['Sample_id'] = mdi_results.index
for label in ('WT', 'KO'):
    mdi_results.loc[mdi_results.index.str.contains(label), 'Condition'] = label
mdi_results.head()

In [None]:
# Long format: one row per (sample, scale) holding that scale's diversity value.
df_melted = mdi_results.melt(
    id_vars=['Sample_id', 'Condition'],
    value_vars=scales,
    var_name='Scale',
    value_name='Diversity Value',
).assign(sample='Tissue Sample')  # constant column, used as the line style key when plotting
df_melted

In [None]:
# Spatial extent (max - min of each coordinate) of every sample; the means are
# used further down to label the secondary x-axis.
xrange = []
yrange = []
for region in adata.obs['Sample'].unique():
    spatial_value = adata[adata.obs['Sample']==region].obsm['spatial']
    xrange.append(spatial_value.max(axis=0)[0] - spatial_value.min(axis=0)[0])
    yrange.append(spatial_value.max(axis=0)[1] - spatial_value.min(axis=0)[1])
mean_xrange = np.mean(xrange)
std_xrange = np.std(xrange)
mean_yrange = np.mean(yrange)
std_yrange = np.std(yrange)

# Per-scale mean and t-based 95% CI half-width. These feed the text labels only;
# the drawn error bars are computed by seaborn from errorbar=("ci", 95).
grouped = df_melted.groupby('Scale')
mean_values = grouped['Diversity Value'].mean()
conf_intervals = grouped['Diversity Value'].apply(lambda x: stats.sem(x) * stats.t.ppf((1 + 0.95) / 2., len(x)-1))

# Line plot of the mean diversity per scale with error bars
fig, ax = plt.subplots(figsize=(6, 4))
sns.lineplot(data=df_melted,
             x='Scale',
             y='Diversity Value',
             style='sample',
             markers=True,
             estimator='mean',
             err_style='bars',
             errorbar=("ci", 95),
             err_kws={"capsize":5.0},
             ax=ax)

# Annotate each scale with "mean±half-width". A non-finite half-width (e.g. only
# one observation at a scale) would put the label at y=NaN, so in that case the
# mean alone is written at the mean.
for scale, mean, ci in zip(mean_values.index, mean_values, conf_intervals):
    if np.isfinite(ci):
        ax.text(scale, mean + ci, f'{mean:.3f}±{ci:.3f}', color='black', ha='center', va='bottom')
    else:
        ax.text(scale, mean, f'{mean:.3f}', color='black', ha='center', va='bottom')

# Red dashed guides: horizontal at the median of the per-scale means, vertical
# at the scale whose mean diversity is largest.
mean_diversity_per_scale = mean_values
y_sep = mean_diversity_per_scale.median()
x_sep = mean_diversity_per_scale.idxmax()

ax.axhline(y_sep, color='red', linestyle='--')
ax.axvline(x_sep, color='red', linestyle='--')
ax.get_legend().remove()

ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.xlabel('', fontsize=0)
plt.xticks(fontsize=12)
plt.ylabel("GDI", fontsize=16)
plt.yticks(fontsize=12)

# Secondary x-axis labelled with mean_xrange/scale × mean_yrange/scale for the
# integer tick positions of the main axis. The first two and the last tick
# labels are skipped (presumably ticks at/below zero and beyond the data - TODO confirm).
# The tick values get their own name so the global `scales` list used by the
# MDI and melt cells is not overwritten.
xtick_labels = [tick.get_text() for tick in ax.get_xticklabels()][2:-1]
tick_scales = [int(label) for label in xtick_labels if label.strip() != '']
x_sizes = [mean_xrange / scale for scale in tick_scales]
y_sizes = [mean_yrange / scale for scale in tick_scales]
size_labels = [f"{int(x_size)}×{int(y_size)}" for x_size, y_size in zip(x_sizes, y_sizes)]
secax = ax.secondary_xaxis(location=-0.075)
secax.set_xticks(tick_scales)
secax.set_xticklabels(size_labels)
secax.tick_params('x', length=0)
secax.spines['bottom'].set_linewidth(0)
secax.set_xlabel('Scale \n (Area μm²)', fontsize=12)

plt.title('GDI per Scale with 95% Confidence Intervals')
plt.grid(False)
plt.show()

In [None]:
# Display the MDI table, including the 'Condition' and 'Sample_id' columns added above.
mdi_results

In [None]:
# calculate_GDI at scale 64 with hotspot=True for the four libraries.
gdi_kwargs = dict(
    spatial_data=adata,
    scale=64,
    library_key='Sample',
    library_id=library_ids,
    spatial_key='spatial',
    cluster_key='cell_type',
    hotspot=True,
    restricted=False,
    metric='Shannon Diversity',
)
gdi_results = eco.calculate_GDI(**gdi_kwargs)
gdi_results

In [None]:
# Write the GDI table to CB_GDI.csv in the current working directory.
gdi_results.to_csv("CB_GDI.csv")

In [None]:
# DPI on hotspots: same samples, keys and metric as the GDI call, scale 64.0.
dpi_kwargs = dict(
    spatial_data=adata,
    scale=64.0,
    library_key='Sample',
    library_id=library_ids,
    spatial_key='spatial',
    cluster_key='cell_type',
    hotspot=True,
    metric='Shannon Diversity',
)
dpi_results = eco.calculate_DPI(**dpi_kwargs)
dpi_results

In [None]:
# Write the DPI table to CB_DPI.csv in the current working directory.
dpi_results.to_csv('CB_DPI.csv')

In [None]:
# Cell-type frequencies and pairwise co-occurrence with spots='global'
# (no top-N cut, no preselected combinations, unrestricted).
global_freq_kwargs = dict(
    spatial_data=adata,
    scale=64.0,
    library_key='Sample',
    library_id=library_ids,
    spatial_key='spatial',
    cluster_key='cell_type',
    spots='global',
    top=None,
    selected_comb=None,
    restricted=False,
    metric='Shannon Diversity',
)
global_cellfreq_df, global_co_occurrence_df = eco.spot_cellfreq(**global_freq_kwargs)

In [None]:
# Display the global cell-frequency table as returned by eco.spot_cellfreq.
global_cellfreq_df

In [None]:
# Tag each row with its genotype from the index label. 'KO' is applied before 'WT',
# so a label containing both substrings ends up 'WT'; rows matching neither keep ' '.
global_cellfreq_df['Condition'] = ' '
for genotype in ('KO', 'WT'):
    genotype_rows = global_cellfreq_df.index.str.contains(genotype)
    global_cellfreq_df.loc[genotype_rows, 'Condition'] = genotype

# Select co-occurrence columns whose mean exceeds 0.05. This is done BEFORE the
# string-valued 'Condition'/'Patch' columns are added to the frame.
frequent_pairs = global_co_occurrence_df.mean() > 0.05
global_co_occurrence_subcols = global_co_occurrence_df.loc[:, frequent_pairs].columns.tolist()
global_co_occurrence_df['Condition'] = ' '
global_co_occurrence_df['Patch'] = global_co_occurrence_df.index
for genotype in ('KO', 'WT'):
    genotype_rows = global_co_occurrence_df.index.str.contains(genotype)
    global_co_occurrence_df.loc[genotype_rows, 'Condition'] = genotype
# The two metadata columns are always kept alongside the selected pairs.
global_co_occurrence_subcols += [('Condition', ''), ('Patch', '')]

In [None]:
# Display the global cell-frequency table again, now carrying the 'Condition' column.
global_cellfreq_df

In [None]:
# Reshape to long format for plotting and testing. The four resulting columns are
# renamed by position to Sample / group / cell_type / Frequency.
global_cellfreq_df_melt = (
    global_cellfreq_df
    .reset_index()
    .melt(id_vars=['Sample', 'Condition'])
    .set_axis(['Sample', 'group', 'cell_type', 'Frequency'], axis=1)
)

global_cellfreq_df_melt

In [None]:
# Welch t-test (unequal variances) of WT vs KO frequency for every cell type in adata.
selected_cell_types = sorted(adata.obs['cell_type'].unique())
selected_p_values = []
wt_rows = global_cellfreq_df_melt['group'] == 'WT'
ko_rows = global_cellfreq_df_melt['group'] == 'KO'
for cell_type in selected_cell_types:
    type_rows = global_cellfreq_df_melt['cell_type'] == cell_type
    wt_freq = global_cellfreq_df_melt.loc[type_rows & wt_rows, 'Frequency']
    ko_freq = global_cellfreq_df_melt.loc[type_rows & ko_rows, 'Frequency']
    welch = stats.ttest_ind(wt_freq, ko_freq, equal_var=False)
    print(f"{cell_type} has p value of {welch.pvalue}")
    selected_p_values.append(welch.pvalue)

# Benjamini-Hochberg adjustment over all tested cell types.
pvals_corrected = stats.false_discovery_control(selected_p_values, method='bh')
print('-'*42)
print(f"p-values after correction:")

# Box plot with the individual observations overlaid as a swarm, split by group.
fig, ax = plt.subplots(figsize=(30,10))
shared_plot_args = dict(data=global_cellfreq_df_melt, x='cell_type', y='Frequency', hue='group',
                        dodge=True, order=selected_cell_types, ax=ax)
sns.boxplot(palette='muted', boxprops=dict(alpha=.3), **shared_plot_args)
sns.swarmplot(palette='dark:black', size=2.0, edgecolor='auto', linewidth=0.5, **shared_plot_args)
# Both layers add legend entries; keep only the first two handles.
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[:2], labels[:2], title="Groups", handletextpad=1, columnspacing=1, bbox_to_anchor=(1, 1), ncol=3, frameon=True)
plt.xticks(rotation=90)

# Write the adjusted p-value above each category and remember it per cell type
# (a later hot-spot cell looks these up through p_vals_corrected_dict).
y_low, y_high = ax.get_ylim()
yrange = y_high - y_low
p_vals_corrected_dict = {}
for x_pos, (cell_type, p_adj) in enumerate(zip(selected_cell_types, pvals_corrected)):
    ax.text(x_pos, yrange, f"p = {p_adj:.3f}", ha='center', fontsize=12, rotation=0)
    print(f"{cell_type} has p value = {p_adj:.3f}", flush=True)
    p_vals_corrected_dict[cell_type] = p_adj

# Thin dashed separators between neighbouring categories.
for boundary in range(len(selected_cell_types) - 1):
    ax.axvline(boundary + 0.55, color='grey', linestyle='--', linewidth=0.5)

ax.set_ylabel("Frequency", fontsize=14)
ax.set_xlabel('')
plt.savefig("Global_Cell_Frequency_CB.pdf", dpi = 300)
plt.show()

In [None]:
# Same computation as the global call, but with spots='hot'.
hot_freq_kwargs = dict(
    spatial_data=adata,
    scale=64.0,
    library_key='Sample',
    library_id=library_ids,
    spatial_key='spatial',
    cluster_key='cell_type',
    spots='hot',
    top=None,
    selected_comb=None,
    restricted=False,
    metric='Shannon Diversity',
)
spot_cellfreq_df, spot_co_occurrence_df = eco.spot_cellfreq(**hot_freq_kwargs)

In [None]:
# Tag each hot-spot row with its genotype from the index label ('KO' first, then 'WT',
# so 'WT' wins if both substrings occur); rows matching neither keep ' '.
spot_cellfreq_df['Condition'] = ' '
for genotype in ('KO', 'WT'):
    genotype_rows = spot_cellfreq_df.index.str.contains(genotype)
    spot_cellfreq_df.loc[genotype_rows, 'Condition'] = genotype

# Select co-occurrence columns whose mean exceeds 0.05, before the string-valued
# 'Condition'/'Patch' columns are added.
frequent_pairs = spot_co_occurrence_df.mean() > 0.05
spot_co_occurrence_subcols = spot_co_occurrence_df.loc[:, frequent_pairs].columns.tolist()
spot_co_occurrence_df['Condition'] = ' '
spot_co_occurrence_df['Patch'] = spot_co_occurrence_df.index
for genotype in ('KO', 'WT'):
    genotype_rows = spot_co_occurrence_df.index.str.contains(genotype)
    spot_co_occurrence_df.loc[genotype_rows, 'Condition'] = genotype
# The two metadata columns are always kept alongside the selected pairs.
spot_co_occurrence_subcols += [('Condition', ''), ('Patch', '')]

In [None]:
# Expose the row label as a regular 'Patch' column so it survives the reshape.
spot_cellfreq_df['Patch'] = spot_cellfreq_df.index

# Long format: one row per (Patch, Condition, CellType) with its Frequency.
spot_cellfreq_df_melt = pd.melt(
    spot_cellfreq_df,
    id_vars=['Patch', 'Condition'],
    var_name='CellType',
    value_name='Frequency',
)

In [None]:
# Display the long-format hot-spot frequency table.
spot_cellfreq_df_melt

In [None]:
selected_cell_types = sorted(spot_cellfreq_df_melt['CellType'].unique())
selected_p_values = []

# Welch t-test (unequal variances) of WT vs KO hot-spot frequency per cell type.
print(f"p-value before correction:")
for cell_type in selected_cell_types:
    type_rows = spot_cellfreq_df_melt[spot_cellfreq_df_melt['CellType'] == cell_type]
    wt_freq = type_rows.loc[type_rows['Condition'] == 'WT', 'Frequency']
    ko_freq = type_rows.loc[type_rows['Condition'] == 'KO', 'Frequency']

    welch = stats.ttest_ind(wt_freq, ko_freq, equal_var=False)
    print(f"{cell_type} has p value = {welch.pvalue:.4f}")
    selected_p_values.append(welch.pvalue)

# Restrict the plotting frame to the tested cell types.
df_filtered = spot_cellfreq_df_melt[spot_cellfreq_df_melt['CellType'].isin(selected_cell_types)]

# Box plot with the individual observations overlaid as a swarm, split by condition.
fig, ax = plt.subplots(figsize=(30,10))
shared_plot_args = dict(data=df_filtered, x='CellType', y='Frequency', hue='Condition',
                        dodge=True, order=selected_cell_types, ax=ax)
sns.boxplot(palette='muted', boxprops=dict(alpha=.3), **shared_plot_args)
sns.swarmplot(palette='dark:black', size=3.0, edgecolor='auto', linewidth=0.5, **shared_plot_args)

# Both layers add legend entries; keep only the first two handles.
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[:2], labels[:2], title="Groups", handletextpad=1, columnspacing=1, bbox_to_anchor=(1, 1), ncol=3, frameon=True)

# Benjamini-Hochberg adjustment, stored as {cell type: adjusted p}.
spot_pvals_corrected = stats.false_discovery_control(selected_p_values, method='bh')
spot_pvals_corrected = dict(zip(selected_cell_types, spot_pvals_corrected))

print('-'*42)
print(f"p-values after correction: ")

y_low, y_high = ax.get_ylim()
yrange = y_high - y_low
for x_pos, cell_type in enumerate(selected_cell_types):
    hot_p = spot_pvals_corrected[cell_type]
    ax.text(x_pos, yrange, f"p = {hot_p:.3f}", ha='center', fontsize=12, rotation=90)
    print(f"{cell_type} in hot spots has p value = {hot_p:.3f}", flush=True)
    # Flag cell types significant in hot spots but not in the whole-tissue test
    # (p_vals_corrected_dict is filled by the global frequency cell).
    if hot_p < 0.05 and p_vals_corrected_dict[cell_type] > 0.05:
        print(f"{cell_type} in whole tissue has p value = {p_vals_corrected_dict[cell_type]:.3f}", flush=True)
        print('*'*42)

# Thin dashed separators between neighbouring categories.
for boundary in range(len(selected_cell_types) - 1):
    ax.axvline(boundary + 0.55, color='grey', linestyle='--', linewidth=0.5)

ax.set_ylabel("Frequency", fontsize=14)
ax.set_xlabel('')
plt.xticks(rotation=90)
plt.yticks(rotation=90)
plt.savefig("HotColdSpot_Cell_Frequency_CB.pdf", dpi = 300)
plt.show()
plt.close(fig)

In [None]:
# Union of the columns selected in the global and hot-spot co-occurrence tables.
# Stored as a sorted list rather than a bare set: set iteration order for string
# tuples can change between kernel sessions, which made the column order of the
# reindexed frames non-reproducible.
union_cols = sorted(set(global_co_occurrence_subcols).union(spot_co_occurrence_subcols))

In [None]:
# Give both co-occurrence tables exactly the columns in `union_cols`;
# columns absent from a table are created and filled with 0.
global_co_occurrence_df, spot_co_occurrence_df = (
    table.reindex(columns=union_cols).fillna(0)
    for table in (global_co_occurrence_df, spot_co_occurrence_df)
)

In [None]:
# Global Cell Co-Occurrence
# Flatten tuple column labels to plain strings:
#   - tuples containing 'Patch' or 'Condition' collapse to their first element;
#   - every other tuple becomes its elements joined with '&';
#   - non-tuple labels are kept unchanged.
new_columns = [
    col if not isinstance(col, tuple)
    else col[0] if ("Patch" in col or "Condition" in col)
    else '&'.join(map(str, col)).strip()
    for col in global_co_occurrence_df.columns
]

global_co_occurrence_df_single = global_co_occurrence_df.copy()
global_co_occurrence_df_single.columns = new_columns
# Drop every flattened column whose name contains 'noid'.
kept_columns = [col for col in global_co_occurrence_df_single.columns if 'noid' not in col]
global_co_occurrence_df_single = global_co_occurrence_df_single[kept_columns]

# Long format: one row per (Patch, Condition, Cell Combination) with its Frequency.
global_co_occurrence_melted = pd.melt(
    global_co_occurrence_df_single,
    id_vars=['Patch', 'Condition'],
    var_name='Cell Combination',
    value_name='Frequency',
)
global_co_occurrence_melted

In [None]:
# Global Cell Co-Occurrence
selected_cell_types = sorted(global_co_occurrence_melted['Cell Combination'].unique())
selected_p_values = []

# Welch t-test (unequal variances) of WT vs KO frequency per cell combination.
print(f"p-value before correction:")
for combination in selected_cell_types:
    comb_rows = global_co_occurrence_melted[global_co_occurrence_melted['Cell Combination'] == combination]
    wt_freq = comb_rows.loc[comb_rows['Condition'] == 'WT', 'Frequency']
    ko_freq = comb_rows.loc[comb_rows['Condition'] == 'KO', 'Frequency']

    welch = stats.ttest_ind(wt_freq, ko_freq, equal_var=False)
    print(f"{combination} has p value = {welch.pvalue:.3f}")
    selected_p_values.append(welch.pvalue)

# Restrict the plotting frame to the tested combinations.
df_filtered = global_co_occurrence_melted[global_co_occurrence_melted['Cell Combination'].isin(selected_cell_types)]

# Box plot with the individual observations overlaid as a swarm, split by condition.
fig, ax = plt.subplots(figsize=(45,10))
shared_plot_args = dict(data=df_filtered, x='Cell Combination', y='Frequency', hue='Condition',
                        dodge=True, order=selected_cell_types, ax=ax)
sns.boxplot(palette='muted', boxprops=dict(alpha=.3), **shared_plot_args)
sns.swarmplot(palette='dark:black', size=1.0, edgecolor='gray', linewidth=0.5, **shared_plot_args)

# Both layers add legend entries; keep only the first two handles.
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[:2], labels[:2], title="Groups", handletextpad=1, columnspacing=1, bbox_to_anchor=(1, 1), ncol=3, frameon=True)

# Benjamini-Hochberg adjustment over all tested combinations.
pvals_corrected = stats.false_discovery_control(selected_p_values, method='bh')

print('-'*42)
print(f"p-values after correction:")

# Annotate the plot and remember the adjusted p-value per combination
# (the hot-spot co-occurrence cell looks these up through p_vals_corrected_dict).
y_low, y_high = ax.get_ylim()
yrange = y_high - y_low
p_vals_corrected_dict = {}
for x_pos, (combination, p_adj) in enumerate(zip(selected_cell_types, pvals_corrected)):
    ax.text(x_pos, yrange, f"p = {p_adj:.3f}", ha='center', fontsize=8, rotation=0)
    print(f"{combination} has p value = {p_adj:.3f}", flush=True)
    p_vals_corrected_dict[combination] = p_adj

# Thin dashed separators between neighbouring categories.
for boundary in range(len(selected_cell_types) - 1):
    ax.axvline(boundary + 0.55, color='grey', linestyle='--', linewidth=0.5)

ax.set_ylabel("Frequency", fontsize=14)
ax.set_xlabel('')
plt.xticks(rotation=90)
plt.savefig("Global_Cell_co_occurance_CB.pdf", dpi = 300)
plt.show()

In [None]:
# Spot Cell Co-Occurrence
# Flatten tuple column labels to plain strings:
#   - tuples containing 'Patch' or 'Condition' collapse to their first element;
#   - every other tuple becomes its elements joined with '&';
#   - non-tuple labels are kept unchanged.
new_columns = [
    col if not isinstance(col, tuple)
    else col[0] if ("Patch" in col or "Condition" in col)
    else '&'.join(map(str, col)).strip()
    for col in spot_co_occurrence_df.columns
]

spot_co_occurrence_df_single = spot_co_occurrence_df.copy()
spot_co_occurrence_df_single.columns = new_columns
# Drop every flattened column whose name contains 'noid'.
kept_columns = [col for col in spot_co_occurrence_df_single.columns if 'noid' not in col]
spot_co_occurrence_df_single = spot_co_occurrence_df_single[kept_columns]

# Long format: one row per (Patch, Condition, Cell Combination) with its Frequency.
spot_co_occurrence_melted = pd.melt(
    spot_co_occurrence_df_single,
    id_vars=['Patch', 'Condition'],
    var_name='Cell Combination',
    value_name='Frequency',
)
spot_co_occurrence_melted

In [None]:
# Spot Cell Co-Occurrence
selected_cell_types = sorted(spot_co_occurrence_melted['Cell Combination'].unique())
selected_p_values = []

# Welch t-test (unequal variances) of WT vs KO hot-spot frequency per cell combination.
print(f"p-value before correction: ")
for ct in selected_cell_types:
    subset = spot_co_occurrence_melted[spot_co_occurrence_melted['Cell Combination'] == ct]
    group1 = subset[subset['Condition'] == 'WT']['Frequency']
    group2 = subset[subset['Condition'] == 'KO']['Frequency']

    t_stat, p_value = stats.ttest_ind(group1, group2, equal_var=False)
    print(f"{ct} has p value = {p_value:.3f}")
    selected_p_values.append(p_value)

# Restrict the plotting frame to the tested combinations.
df_filtered = spot_co_occurrence_melted[spot_co_occurrence_melted['Cell Combination'].isin(selected_cell_types)]

# Box plot with the individual observations overlaid as a swarm, split by condition.
fig, ax = plt.subplots(figsize=(42,10))
sns.boxplot(data=df_filtered, x='Cell Combination', y='Frequency', hue='Condition', palette='muted', boxprops=dict(alpha=.3), ax=ax, dodge=True,order=selected_cell_types)
sns.swarmplot(data=df_filtered, x='Cell Combination', y='Frequency', hue='Condition', palette='dark:black', size=2.0, dodge=True,order=selected_cell_types, ax=ax, edgecolor='gray', linewidth=0.5)

# Both layers add legend entries; keep only the first two handles.
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[:2], labels[:2], title="Groups", handletextpad=1, columnspacing=1, bbox_to_anchor=(1, 1), ncol=3, frameon=True)

# Benjamini-Hochberg adjustment over the hot-spot tests.
hot_pvals_corrected = stats.false_discovery_control(selected_p_values, method='bh')

print('-'*42)
print(f"p-values after correction:")

# Combinations significant in hot spots but not in the whole-tissue test; passed
# later to the circos plot as highlighted edges.
highlighted_comb = []
yrange = ax.get_ylim()[1] - ax.get_ylim()[0]
for i, ct in enumerate(selected_cell_types):
    # Annotate with the HOT-SPOT adjusted p-value. The previous code wrote
    # `pvals_corrected[i]` here, i.e. the whole-tissue values left over from the
    # global co-occurrence cell, so the figure disagreed with the printed values.
    ax.text(i, yrange, f"p = {hot_pvals_corrected[i]:.3f}", ha='center', fontsize=8, rotation=0)
    print(f"{ct} in hot spots has p value = {hot_pvals_corrected[i]:.3f}", flush=True)
    if hot_pvals_corrected[i] < 0.05 and p_vals_corrected_dict[ct] >= 0.05:
        highlighted_comb.append(tuple(map(str.strip, ct.split('&'))))
        print(f"{ct} in whole tissue has p value = {p_vals_corrected_dict[ct]:.3f}", flush=True)
        print('*'*42)

# Thin dashed separators between neighbouring categories.
for i in range(len(selected_cell_types) - 1):
    ax.axvline(i + 0.55, color='grey', linestyle='--', linewidth=0.5)

ax.set_ylabel("Frequency", fontsize=14)
ax.set_xlabel('')
plt.xticks(rotation=90)
plt.yticks(rotation=90)
plt.savefig("Spot_Cell_co_occurance_CB.pdf", dpi = 300)
plt.show()

In [None]:
# Edge weights for the circos plot: per-condition mean of every co-occurrence
# column of the global table (columns sorted, 'Patch' removed, 'noid' entries dropped).
global_pairs_sorted = global_co_occurrence_df.sort_index(axis=1, level=[0,1]).drop(columns=['Patch'])
kept_columns = [col for col in global_pairs_sorted.columns if 'noid' not in col]
circoplot_df1 = (
    global_pairs_sorted[kept_columns]
    .groupby('Condition')
    .mean()
    .reset_index()
    .set_index('Condition')
)
circoplot_df1

In [None]:
# Display the global cell-frequency table before it is averaged per condition.
global_cellfreq_df

In [None]:
# Node weights for the circos plot: per-condition mean of every cell-type frequency.
# (An earlier variant dropped a 'noid' column first; that step is disabled.)
circoplot_df2 = (
    global_cellfreq_df
    .copy()
    .groupby('Condition')
    .mean()
    .reset_index()
    .set_index('Condition')
)
circoplot_df2

In [None]:
patient_group = 'WT'
# Single-row frames for the chosen condition (list indexer keeps the DataFrame shape).
co_occurrence_row = circoplot_df1.loc[[patient_group]]
abundance_row = circoplot_df2.loc[[patient_group]]
# No highlighted edges for this plot; figure is written to WT_CB_circoplot.pdf.
eco.create_circos_plot(co_occurrence_row,
                       cell_type_colors_hex=None,
                       cell_abundance=abundance_row,
                       threshold=0.05,
                       edge_weights_scaler=10,
                       highlighted_edges=None,
                       node_weights_scaler=5000,
                       figure_size=(8,8),
                       save_path='WT_CB_circoplot.pdf')

In [None]:
# Edge weights for the circos plot: per-condition mean of every co-occurrence
# column of the hot-spot table (columns sorted, 'Patch' removed, 'noid' entries dropped).
# This rebinds circoplot_df1, replacing the global-table version built earlier.
spot_pairs_sorted = spot_co_occurrence_df.sort_index(axis=1, level=[0,1]).drop(columns=['Patch'])
kept_columns = [col for col in spot_pairs_sorted.columns if 'noid' not in col]
circoplot_df1 = (
    spot_pairs_sorted[kept_columns]
    .groupby('Condition')
    .mean()
    .reset_index()
    .set_index('Condition')
)
circoplot_df1

In [None]:
# Node weights for the circos plot: per-condition mean of every hot-spot cell-type
# frequency ('Patch' is a label column and is removed before averaging).
circoplot_df2 = (
    spot_cellfreq_df
    .drop(columns=['Patch'])
    .groupby('Condition')
    .mean()
    .reset_index()
    .set_index('Condition')
)
circoplot_df2

In [None]:
patient_group = 'KO'
# Single-row frames for the chosen condition (list indexer keeps the DataFrame shape).
co_occurrence_row = circoplot_df1.loc[[patient_group]]
abundance_row = circoplot_df2.loc[[patient_group]]
# Edges collected in `highlighted_comb` by the hot-spot co-occurrence test are highlighted;
# figure is written to KO_CB_circoplot.pdf.
eco.create_circos_plot(co_occurrence_row,
                       cell_type_colors_hex=None,
                       cell_abundance=abundance_row,
                       threshold=0.05,
                       edge_weights_scaler=10,
                       highlighted_edges=highlighted_comb,
                       node_weights_scaler=5000,
                       figure_size=(8,8),
                       save_path='KO_CB_circoplot.pdf')

In [None]:
import scanpy as sc
import numpy as np
import pandas as pd
import anndata as ad
import seaborn as sns
from scipy import stats
import matplotlib.pyplot as plt

from mesa import ecospatial as eco

In [None]:
# Reload the annotated dataset from disk; this rebinds `adata`, discarding whatever
# subset/rescaled object the name referred to in the cells above.
adata = sc.read_h5ad('RelnAll_Annotated.h5ad')
adata

In [None]:
from shapely import Polygon
import numpy as np

In [None]:
# Directory with one sub-folder per sample, each holding a CTX.csv ROI export.
# NOTE(review): absolute, machine-specific path — consider moving it to a config cell.
ROI_DIR = "/media/duan/DuanLab_Data/openFISH/Reln/Reeler_reAnnotation_ROIs"
# ROI coordinates are divided by this factor; it is the same 0.325 that a later cell
# multiplies the spatial coordinates by to convert them to microns
# (so the polygons presumably end up in the unscaled coordinate units — confirm).
ROI_SCALE = 0.325


def load_ctx_polygon(sample_id):
    """Build the cortex ROI polygon for one sample.

    Reads ``ROI_DIR/<sample_id>/CTX.csv`` (skipping the first two lines), keeps the
    X/Y vertices of the rows whose 'Selection' is 'Selection 1', divides them by
    ROI_SCALE and returns them as a shapely Polygon.
    """
    roi = pd.read_csv(f"{ROI_DIR}/{sample_id}/CTX.csv", skiprows=2)
    vertices = roi.loc[roi['Selection'] == 'Selection 1', ['X', 'Y']]
    return Polygon((vertices / ROI_SCALE).to_numpy())


# Names kept as polygon1..polygon4 because the membership loop below refers to them.
polygon1 = load_ctx_polygon('KO1')
polygon2 = load_ctx_polygon('KO2')
polygon3 = load_ctx_polygon('WT1')
polygon4 = load_ctx_polygon('WT2')

In [None]:
# One membership flag per observation in `adata`, all initially False.
CTX = [False] * len(adata)

In [None]:
from shapely import Polygon,Point

In [None]:
# ROI polygon to test against for each sample; rows from any other sample stay False.
sample_polygons = {'KO1': polygon1, 'KO2': polygon2, 'WT1': polygon3, 'WT2': polygon4}

# Flag every observation whose (x, y) lies strictly inside its sample's polygon.
for row_pos, (_, row) in enumerate(adata.obs.iterrows()):
    roi = sample_polygons.get(row['Sample'])
    if roi is not None and roi.contains(Point([row['x'],row['y']])):
        CTX[row_pos] = True

In [None]:
# Store the polygon-membership flags as a per-observation column.
adata.obs['CTX'] = CTX

In [None]:
# Keep only the observations flagged as inside a CTX polygon (copied, not a view).
adata = adata[adata.obs['CTX'] == True].copy()

In [None]:
# Visual check of the CTX subset: spatial coordinates coloured by cell type.
sc.pl.embedding(adata, basis = 'spatial', color = 'cell_type', size = 5)

In [None]:
# Cell types retained for the cortex (CTX) analysis.
KEEP_CELL_TYPE = ['Astro-TE', 'Endo', 'Ext L2/3', 'Ext L2/3 PIR-ENTl', 'Ext L4/5', 'Ext L6', 'Fibro', 'Inh Lamp5', 'Inh Sst', 'Inh Vip', 'Microglia',
                 'Mural', 'Oligo', 'OPC']

# Apply the whitelist. Previously the list was defined but never used in this section,
# whereas the first (STR) section filters `adata` with its KEEP_CELL_TYPE right after
# defining it; without this line every cell type inside the polygons stays in the analysis.
# NOTE(review): this changes downstream numbers — confirm the omission was unintentional.
adata = adata[adata.obs['cell_type'].isin(KEEP_CELL_TYPE)].copy()

In [None]:
# NOTE: not idempotent — re-running this cell multiplies the coordinates by 0.325 again.
adata.obsm['spatial'] = adata.obsm['spatial'] * 0.325# Convert units to microns

In [None]:
# Display the per-observation sample labels.
adata.obs['Sample']

In [None]:
# Sample identifiers passed as `library_id` to the eco.* calls below.
library_ids = ['KO1', 'KO2', 'WT1', 'WT2']

In [None]:
# Scales to evaluate: powers of two from 1 to 64, as floats.
scales = [2.0 ** exponent for exponent in range(7)]

# Arguments for eco.calculate_MDI: no random patches, no figures drawn or saved,
# patches need at least 2 points, Shannon Diversity as the metric.
mdi_kwargs = dict(
    spatial_data=adata,
    scales=scales,
    library_key='Sample',
    library_id=library_ids,
    spatial_key='spatial',
    cluster_key='cell_type',
    random_patch=False,
    plotfigs=False,
    savefigs=False,
    patch_kwargs={'random_seed': None, 'min_points':2},
    other_kwargs={'metric': 'Shannon Diversity'},
)
mdi_results = eco.calculate_MDI(**mdi_kwargs)

In [None]:
# Add 'Condition' and 'Sample_id' columns. Genotype is read from the index label;
# 'WT' is applied before 'KO', so 'KO' wins if a label contains both substrings.
mdi_results['Condition'] = ' '
mdi_results['Sample_id'] = mdi_results.index
for genotype in ('WT', 'KO'):
    genotype_rows = mdi_results.index.str.contains(genotype)
    mdi_results.loc[genotype_rows, 'Condition'] = genotype
mdi_results.head()

In [None]:
# Long format: one row per (sample, scale) with its diversity value, plus a constant
# 'sample' column that the line plot uses as its `style` variable.
df_melted = mdi_results.melt(
    id_vars=['Sample_id', 'Condition'],
    value_vars=scales,
    var_name='Scale',
    value_name='Diversity Value',
).assign(sample='Tissue Sample')
df_melted

In [None]:
# Spatial extent (max - min along each coordinate) of every sample, then the
# mean/std across samples; the means feed the secondary-axis size labels below.
xrange = []
yrange = []
for region in adata.obs['Sample'].unique():
    spatial_value = adata[adata.obs['Sample']==region].obsm['spatial']
    xrange.append(spatial_value.max(axis=0)[0] - spatial_value.min(axis=0)[0])
    yrange.append(spatial_value.max(axis=0)[1] - spatial_value.min(axis=0)[1])
mean_xrange = np.mean(xrange)
std_xrange = np.std(xrange)
mean_yrange = np.mean(yrange)
std_yrange = np.std(yrange)

# Per-scale mean and the half-width of a two-sided 95% t-interval (SEM × t quantile).
grouped = df_melted.groupby('Scale')
mean_values = grouped['Diversity Value'].mean()
conf_intervals = grouped['Diversity Value'].apply(lambda x: stats.sem(x) * stats.t.ppf((1 + 0.95) / 2., len(x)-1))

# Mean diversity per scale with seaborn's 95% CI drawn as capped error bars.
plt.figure(figsize=(6, 4))
ax = sns.lineplot(data=df_melted,
                  x='Scale',
                  y='Diversity Value',
                  style='sample',
                  markers=True,
                  estimator='mean',
                  err_style='bars',
                  errorbar=("ci", 95),
                  err_kws={"capsize":5.0}
                 )

# Annotate every scale with "mean±half-width", placed just above mean + half-width.
# NOTE(review): the annotated half-width is the t-based value computed above, while the
# bars drawn by seaborn use a bootstrapped CI, so text and bars may differ slightly.
for scale, mean, ci in zip(mean_values.index, mean_values, conf_intervals):
    ax.text(scale, mean + ci, f'{mean:.3f}±{ci:.3f}', color='black', ha='center', va='bottom')

# Red dashed guides: horizontal line at the median of the per-scale mean diversity,
# vertical line at the scale whose mean diversity is the highest.
mean_diversity_per_scale = df_melted.groupby('Scale')['Diversity Value'].mean()
y_sep = mean_diversity_per_scale.median()
x_sep = mean_diversity_per_scale.idxmax()

ax.axhline(y_sep, color='red', linestyle='--')
ax.axvline(x_sep, color='red', linestyle='--')
ax.get_legend().remove()

ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.xlabel('', fontsize=0)
plt.xticks(fontsize=12)
plt.ylabel(f"GDI", fontsize=16)
plt.yticks(fontsize=12)

# Add secondary x-axis showing the mean patch size (x-extent × y-extent divided by the tick value).
# The parsed tick values are kept in `tick_scales` so the notebook-level `scales` list
# (input of calculate_MDI and of the melt) is NOT overwritten by this plotting cell.
# The [2:-1] slice drops the first two and the last auto-generated tick labels.
xtick_labels = [tick.get_text() for tick in ax.get_xticklabels()][2:-1]
tick_scales = [int(label) for label in xtick_labels if label.strip() != '']
x_sizes = [mean_xrange / tick_scale for tick_scale in tick_scales]
y_sizes = [mean_yrange / tick_scale for tick_scale in tick_scales]
size_labels = [f"{int(x_size)}×{int(y_size)}" for x_size, y_size in zip(x_sizes, y_sizes)]
secax = ax.secondary_xaxis(location=-0.075)
secax.set_xticks(tick_scales)
secax.set_xticklabels(size_labels)
secax.tick_params('x', length=0)
secax.spines['bottom'].set_linewidth(0)
secax.set_xlabel(f'Scale \n (Area μm²)', fontsize=12)

plt.title(f'GDI per Scale with 95% Confidence Intervals')
plt.grid(False)
fig = plt.gcf()
plt.show()

In [None]:
# Display the multiscale results table, including the added Condition/Sample_id columns.
mdi_results

In [None]:
# Arguments for eco.calculate_GDI: every sample in `library_ids`, scale 64,
# hotspot mode on, unrestricted, Shannon Diversity as the metric.
gdi_kwargs = dict(
    spatial_data=adata,
    scale=64,
    library_key='Sample',
    library_id=library_ids,
    spatial_key='spatial',
    cluster_key='cell_type',
    hotspot=True,
    restricted=False,
    metric='Shannon Diversity',
)
gdi_results = eco.calculate_GDI(**gdi_kwargs)
gdi_results

In [None]:
# Write the GDI table to CTX_GDI.csv in the current working directory.
gdi_results.to_csv("CTX_GDI.csv")

In [None]:
# DPI on hotspots: same samples, keys and metric as the GDI call, scale 64.0.
dpi_kwargs = dict(
    spatial_data=adata,
    scale=64.0,
    library_key='Sample',
    library_id=library_ids,
    spatial_key='spatial',
    cluster_key='cell_type',
    hotspot=True,
    metric='Shannon Diversity',
)
dpi_results = eco.calculate_DPI(**dpi_kwargs)
dpi_results

In [None]:
# Write the DPI table to CTX_DPI.csv in the current working directory.
dpi_results.to_csv('CTX_DPI.csv')

In [None]:
# Cell-type frequencies and pairwise co-occurrence with spots='global'
# (no top-N cut, no preselected combinations, unrestricted).
global_freq_kwargs = dict(
    spatial_data=adata,
    scale=64.0,
    library_key='Sample',
    library_id=library_ids,
    spatial_key='spatial',
    cluster_key='cell_type',
    spots='global',
    top=None,
    selected_comb=None,
    restricted=False,
    metric='Shannon Diversity',
)
global_cellfreq_df, global_co_occurrence_df = eco.spot_cellfreq(**global_freq_kwargs)

In [None]:
# Display the global cell-frequency table as returned by eco.spot_cellfreq.
global_cellfreq_df

In [None]:
# Tag each row with its genotype from the index label. 'KO' is applied before 'WT',
# so a label containing both substrings ends up 'WT'; rows matching neither keep ' '.
global_cellfreq_df['Condition'] = ' '
for genotype in ('KO', 'WT'):
    genotype_rows = global_cellfreq_df.index.str.contains(genotype)
    global_cellfreq_df.loc[genotype_rows, 'Condition'] = genotype

# Select co-occurrence columns whose mean exceeds 0.05. This is done BEFORE the
# string-valued 'Condition'/'Patch' columns are added to the frame.
frequent_pairs = global_co_occurrence_df.mean() > 0.05
global_co_occurrence_subcols = global_co_occurrence_df.loc[:, frequent_pairs].columns.tolist()
global_co_occurrence_df['Condition'] = ' '
global_co_occurrence_df['Patch'] = global_co_occurrence_df.index
for genotype in ('KO', 'WT'):
    genotype_rows = global_co_occurrence_df.index.str.contains(genotype)
    global_co_occurrence_df.loc[genotype_rows, 'Condition'] = genotype
# The two metadata columns are always kept alongside the selected pairs.
global_co_occurrence_subcols += [('Condition', ''), ('Patch', '')]

In [None]:
# Display the global cell-frequency table again, now carrying the 'Condition' column.
global_cellfreq_df

In [None]:
# Reshape to long format for plotting and testing. The four resulting columns are
# renamed by position to Sample / group / cell_type / Frequency.
global_cellfreq_df_melt = (
    global_cellfreq_df
    .reset_index()
    .melt(id_vars=['Sample', 'Condition'])
    .set_axis(['Sample', 'group', 'cell_type', 'Frequency'], axis=1)
)

global_cellfreq_df_melt

In [None]:
# Welch t-test (unequal variances) of WT vs KO frequency for every cell type in adata.
selected_cell_types = sorted(adata.obs['cell_type'].unique())
selected_p_values = []
wt_rows = global_cellfreq_df_melt['group'] == 'WT'
ko_rows = global_cellfreq_df_melt['group'] == 'KO'
for cell_type in selected_cell_types:
    type_rows = global_cellfreq_df_melt['cell_type'] == cell_type
    wt_freq = global_cellfreq_df_melt.loc[type_rows & wt_rows, 'Frequency']
    ko_freq = global_cellfreq_df_melt.loc[type_rows & ko_rows, 'Frequency']
    welch = stats.ttest_ind(wt_freq, ko_freq, equal_var=False)
    print(f"{cell_type} has p value of {welch.pvalue}")
    selected_p_values.append(welch.pvalue)

# Benjamini-Hochberg adjustment over all tested cell types.
pvals_corrected = stats.false_discovery_control(selected_p_values, method='bh')
print('-'*42)
print(f"p-values after correction:")

# Box plot with the individual observations overlaid as a swarm, split by group.
fig, ax = plt.subplots(figsize=(30,10))
shared_plot_args = dict(data=global_cellfreq_df_melt, x='cell_type', y='Frequency', hue='group',
                        dodge=True, order=selected_cell_types, ax=ax)
sns.boxplot(palette='muted', boxprops=dict(alpha=.3), **shared_plot_args)
sns.swarmplot(palette='dark:black', size=2.0, edgecolor='auto', linewidth=0.5, **shared_plot_args)
# Both layers add legend entries; keep only the first two handles.
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[:2], labels[:2], title="Groups", handletextpad=1, columnspacing=1, bbox_to_anchor=(1, 1), ncol=3, frameon=True)
plt.xticks(rotation=90)

# Write the adjusted p-value above each category and remember it per cell type
# (a later hot-spot cell looks these up through p_vals_corrected_dict).
y_low, y_high = ax.get_ylim()
yrange = y_high - y_low
p_vals_corrected_dict = {}
for x_pos, (cell_type, p_adj) in enumerate(zip(selected_cell_types, pvals_corrected)):
    ax.text(x_pos, yrange, f"p = {p_adj:.3f}", ha='center', fontsize=12, rotation=0)
    print(f"{cell_type} has p value = {p_adj:.3f}", flush=True)
    p_vals_corrected_dict[cell_type] = p_adj

# Thin dashed separators between neighbouring categories.
for boundary in range(len(selected_cell_types) - 1):
    ax.axvline(boundary + 0.55, color='grey', linestyle='--', linewidth=0.5)

ax.set_ylabel("Frequency", fontsize=14)
ax.set_xlabel('')
plt.savefig("Global_Cell_Frequency_CTX.pdf", dpi = 300)
plt.show()

In [None]:
# Same computation as the global call, but with spots='hot'.
hot_freq_kwargs = dict(
    spatial_data=adata,
    scale=64.0,
    library_key='Sample',
    library_id=library_ids,
    spatial_key='spatial',
    cluster_key='cell_type',
    spots='hot',
    top=None,
    selected_comb=None,
    restricted=False,
    metric='Shannon Diversity',
)
spot_cellfreq_df, spot_co_occurrence_df = eco.spot_cellfreq(**hot_freq_kwargs)

In [None]:
# Tag each hot-spot row with its genotype from the index label ('KO' first, then 'WT',
# so 'WT' wins if both substrings occur); rows matching neither keep ' '.
spot_cellfreq_df['Condition'] = ' '
for genotype in ('KO', 'WT'):
    genotype_rows = spot_cellfreq_df.index.str.contains(genotype)
    spot_cellfreq_df.loc[genotype_rows, 'Condition'] = genotype

# Select co-occurrence columns whose mean exceeds 0.05, before the string-valued
# 'Condition'/'Patch' columns are added.
frequent_pairs = spot_co_occurrence_df.mean() > 0.05
spot_co_occurrence_subcols = spot_co_occurrence_df.loc[:, frequent_pairs].columns.tolist()
spot_co_occurrence_df['Condition'] = ' '
spot_co_occurrence_df['Patch'] = spot_co_occurrence_df.index
for genotype in ('KO', 'WT'):
    genotype_rows = spot_co_occurrence_df.index.str.contains(genotype)
    spot_co_occurrence_df.loc[genotype_rows, 'Condition'] = genotype
# The two metadata columns are always kept alongside the selected pairs.
spot_co_occurrence_subcols += [('Condition', ''), ('Patch', '')]

In [None]:
# Expose the row label as a regular 'Patch' column so it survives the reshape.
spot_cellfreq_df['Patch'] = spot_cellfreq_df.index

# Long format: one row per (Patch, Condition, CellType) with its Frequency.
spot_cellfreq_df_melt = pd.melt(
    spot_cellfreq_df,
    id_vars=['Patch', 'Condition'],
    var_name='CellType',
    value_name='Frequency',
)

In [None]:
# Display the long-format hot-spot frequency table.
spot_cellfreq_df_melt

In [None]:
selected_cell_types = sorted(spot_cellfreq_df_melt['CellType'].unique())
selected_p_values = []

# Welch t-test (unequal variances) of WT vs KO hot-spot frequency per cell type.
print(f"p-value before correction:")
for cell_type in selected_cell_types:
    type_rows = spot_cellfreq_df_melt[spot_cellfreq_df_melt['CellType'] == cell_type]
    wt_freq = type_rows.loc[type_rows['Condition'] == 'WT', 'Frequency']
    ko_freq = type_rows.loc[type_rows['Condition'] == 'KO', 'Frequency']

    welch = stats.ttest_ind(wt_freq, ko_freq, equal_var=False)
    print(f"{cell_type} has p value = {welch.pvalue:.4f}")
    selected_p_values.append(welch.pvalue)

# Restrict the plotting frame to the tested cell types.
df_filtered = spot_cellfreq_df_melt[spot_cellfreq_df_melt['CellType'].isin(selected_cell_types)]

# Box plot with the individual observations overlaid as a swarm, split by condition.
fig, ax = plt.subplots(figsize=(30,10))
shared_plot_args = dict(data=df_filtered, x='CellType', y='Frequency', hue='Condition',
                        dodge=True, order=selected_cell_types, ax=ax)
sns.boxplot(palette='muted', boxprops=dict(alpha=.3), **shared_plot_args)
sns.swarmplot(palette='dark:black', size=3.0, edgecolor='auto', linewidth=0.5, **shared_plot_args)

# Both layers add legend entries; keep only the first two handles.
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[:2], labels[:2], title="Groups", handletextpad=1, columnspacing=1, bbox_to_anchor=(1, 1), ncol=3, frameon=True)

# Benjamini-Hochberg adjustment, stored as {cell type: adjusted p}.
spot_pvals_corrected = stats.false_discovery_control(selected_p_values, method='bh')
spot_pvals_corrected = dict(zip(selected_cell_types, spot_pvals_corrected))

print('-'*42)
print(f"p-values after correction: ")

y_low, y_high = ax.get_ylim()
yrange = y_high - y_low
for x_pos, cell_type in enumerate(selected_cell_types):
    hot_p = spot_pvals_corrected[cell_type]
    ax.text(x_pos, yrange, f"p = {hot_p:.3f}", ha='center', fontsize=12, rotation=90)
    print(f"{cell_type} in hot spots has p value = {hot_p:.3f}", flush=True)
    # Flag cell types significant in hot spots but not in the whole-tissue test
    # (p_vals_corrected_dict is filled by the global frequency cell).
    if hot_p < 0.05 and p_vals_corrected_dict[cell_type] > 0.05:
        print(f"{cell_type} in whole tissue has p value = {p_vals_corrected_dict[cell_type]:.3f}", flush=True)
        print('*'*42)

# Thin dashed separators between neighbouring categories.
for boundary in range(len(selected_cell_types) - 1):
    ax.axvline(boundary + 0.55, color='grey', linestyle='--', linewidth=0.5)

ax.set_ylabel("Frequency", fontsize=14)
ax.set_xlabel('')
plt.xticks(rotation=90)
plt.yticks(rotation=90)
plt.savefig("HotColdSpot_Cell_Frequency_CTX.pdf", dpi = 300)
plt.show()
plt.close(fig)

In [None]:
# Union of the column labels kept for the global and the hot/cold-spot co-occurrence tables.
# Built as an order-preserving, de-duplicated list (global labels first) instead of a set, so
# that any reindex driven by it yields the same column order on every run.
union_cols = list(dict.fromkeys([*global_co_occurrence_subcols, *spot_co_occurrence_subcols]))

In [None]:
# Make them have the same set of columns
global_co_occurrence_df = global_co_occurrence_df.reindex(columns=union_cols).fillna(0)
spot_co_occurrence_df = spot_co_occurrence_df.reindex(columns=union_cols).fillna(0)

In [None]:
# Global cell co-occurrence: flatten the tuple column labels into single strings.
new_columns = []
for col in global_co_occurrence_df.columns:
    if not isinstance(col, tuple):
        # Already a flat label: keep unchanged.
        new_columns.append(col)
    elif "Patch" in col or "Condition" in col:
        # Metadata columns ('Patch', 'Condition') keep only their first level.
        new_columns.append(col[0])
    else:
        # Cell-type pair -> "TypeA&TypeB"
        new_columns.append('&'.join(map(str, col)).strip())

global_co_occurrence_df_single = global_co_occurrence_df.copy()
global_co_occurrence_df_single.columns = new_columns
# Discard every column whose flattened name contains 'noid'.
kept_columns = [name for name in global_co_occurrence_df_single.columns if 'noid' not in name]
global_co_occurrence_df_single = global_co_occurrence_df_single[kept_columns]

# Wide -> long: one row per (Patch, Condition, Cell Combination).
global_co_occurrence_melted = pd.melt(
    global_co_occurrence_df_single,
    id_vars=['Patch', 'Condition'],
    var_name='Cell Combination',
    value_name='Frequency',
)
global_co_occurrence_melted

In [None]:
# Global Cell Co-Occurrence
# WT vs KO comparison per cell-type combination, Benjamini-Hochberg correction, and a
# box + swarm plot annotated with the corrected p-values.
selected_cell_types = sorted(global_co_occurrence_melted['Cell Combination'].unique())
selected_p_values = []

# Welch's t-test (equal_var=False) per cell combination, WT vs KO
print(f"p-value before correction:")
for ct in selected_cell_types:
    subset = global_co_occurrence_melted[global_co_occurrence_melted['Cell Combination'] == ct]
    group1 = subset[subset['Condition'] == 'WT']['Frequency']
    group2 = subset[subset['Condition'] == 'KO']['Frequency']

    t_stat, p_value = stats.ttest_ind(group1, group2, equal_var=False)
    print(f"{ct} has p value = {p_value:.3f}")
    selected_p_values.append(p_value)

# Filter the dataframe based on selected Cell Combinations
# (selected_cell_types holds every combination present, so this keeps all rows)
df_filtered = global_co_occurrence_melted[global_co_occurrence_melted['Cell Combination'].isin(selected_cell_types)]

# Plot the filtered data
fig, ax = plt.subplots(figsize=(45,10))
sns.boxplot(data=df_filtered, x='Cell Combination', y='Frequency', hue='Condition', palette='muted', boxprops=dict(alpha=.3), ax=ax, dodge=True,order=selected_cell_types)
sns.swarmplot(data=df_filtered, x='Cell Combination', y='Frequency', hue='Condition', palette='dark:black', size=1.0, dodge=True,order=selected_cell_types, ax=ax, edgecolor='gray', linewidth=0.5)

# Both plot calls add legend entries; keep only the first two handles
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[:2], labels[:2], title="Groups", handletextpad=1, columnspacing=1, bbox_to_anchor=(1, 1), ncol=3, frameon=True)

# Benjamini-Hochberg FDR correction; result is positionally aligned with selected_cell_types
pvals_corrected = stats.false_discovery_control(selected_p_values, method='bh')

print('-'*42)
print(f"p-values after correction:")

# p_vals_corrected_dict (combination -> corrected whole-tissue p-value) is read again by the
# hot/cold-spot co-occurrence cell.
p_vals_corrected_dict = {}
# Labels are drawn at y = (upper - lower axis limit), i.e. the axis top only when the lower limit is 0.
yrange = ax.get_ylim()[1] - ax.get_ylim()[0]
for i, ct in enumerate(selected_cell_types):
    ax.text(i, yrange, f"p = {pvals_corrected[i]:.3f}", ha='center', fontsize=8, rotation=0)
    print(f"{ct} has p value = {pvals_corrected[i]:.3f}", flush=True)
    p_vals_corrected_dict[ct] = pvals_corrected[i]

# Dashed vertical separators between neighbouring combinations
for i in range(len(selected_cell_types) - 1):
    ax.axvline(i + 0.55, color='grey', linestyle='--', linewidth=0.5)

ax.set_ylabel("Frequency", fontsize=14)
ax.set_xlabel('')
plt.xticks(rotation=90)
# NOTE(review): the file name says CTX; confirm it matches the region analysed in this notebook.
plt.savefig("Global_Cell_co_occurance_CTX.pdf", dpi = 300)
plt.show()

In [None]:
# Spot cell co-occurrence: flatten the tuple column labels into single strings.
new_columns = []
for col in spot_co_occurrence_df.columns:
    if not isinstance(col, tuple):
        # Already a flat label: keep unchanged.
        new_columns.append(col)
    elif "Patch" in col or "Condition" in col:
        # Metadata columns ('Patch', 'Condition') keep only their first level.
        new_columns.append(col[0])
    else:
        # Cell-type pair -> "TypeA&TypeB"
        new_columns.append('&'.join(map(str, col)).strip())

spot_co_occurrence_df_single = spot_co_occurrence_df.copy()
spot_co_occurrence_df_single.columns = new_columns
# Discard every column whose flattened name contains 'noid'.
kept_columns = [name for name in spot_co_occurrence_df_single.columns if 'noid' not in name]
spot_co_occurrence_df_single = spot_co_occurrence_df_single[kept_columns]

# Wide -> long: one row per (Patch, Condition, Cell Combination).
spot_co_occurrence_melted = pd.melt(
    spot_co_occurrence_df_single,
    id_vars=['Patch', 'Condition'],
    var_name='Cell Combination',
    value_name='Frequency',
)
spot_co_occurrence_melted

In [None]:
# Spot Cell Co-Occurrence
# WT vs KO comparison per cell-type combination inside the hot/cold spots,
# Benjamini-Hochberg correction, and a box + swarm plot annotated with the corrected p-values.
selected_cell_types = sorted(spot_co_occurrence_melted['Cell Combination'].unique())
selected_p_values = []

# Welch's t-test (equal_var=False) per cell combination, WT vs KO
print(f"p-value before correction: ")
for ct in selected_cell_types:
    subset = spot_co_occurrence_melted[spot_co_occurrence_melted['Cell Combination'] == ct]
    group1 = subset[subset['Condition'] == 'WT']['Frequency']
    group2 = subset[subset['Condition'] == 'KO']['Frequency']

    t_stat, p_value = stats.ttest_ind(group1, group2, equal_var=False)
    print(f"{ct} has p value = {p_value:.3f}")
    selected_p_values.append(p_value)

# selected_cell_types holds every combination present, so this keeps all rows
df_filtered = spot_co_occurrence_melted[spot_co_occurrence_melted['Cell Combination'].isin(selected_cell_types)]

# Plot the filtered data
fig, ax = plt.subplots(figsize=(42,10))
sns.boxplot(data=df_filtered, x='Cell Combination', y='Frequency', hue='Condition', palette='muted', boxprops=dict(alpha=.3), ax=ax, dodge=True,order=selected_cell_types)
sns.swarmplot(data=df_filtered, x='Cell Combination', y='Frequency', hue='Condition', palette='dark:black', size=2.0, dodge=True,order=selected_cell_types, ax=ax, edgecolor='gray', linewidth=0.5)

# Both plot calls add legend entries; keep only the first two handles
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[:2], labels[:2], title="Groups", handletextpad=1, columnspacing=1, bbox_to_anchor=(1, 1), ncol=3, frameon=True)

# Benjamini-Hochberg FDR correction; result is positionally aligned with selected_cell_types
hot_pvals_corrected = stats.false_discovery_control(selected_p_values, method='bh')

print('-'*42)
print(f"p-values after correction:")

# Combinations significant in the spots but not in the whole tissue; handed to the circos plot
# as highlighted edges.
highlighted_comb = []
# Labels are drawn at y = (upper - lower axis limit), i.e. the axis top only when the lower limit is 0.
yrange = ax.get_ylim()[1] - ax.get_ylim()[0]
for i, ct in enumerate(selected_cell_types):
    # Annotate with the spot-level corrected p-value (the same value that is printed below),
    # not with the whole-tissue `pvals_corrected` array.
    ax.text(i, yrange, f"p = {hot_pvals_corrected[i]:.3f}", ha='center', fontsize=8, rotation=0)
    print(f"{ct} in hot spots has p value = {hot_pvals_corrected[i]:.3f}", flush=True)
    if hot_pvals_corrected[i] < 0.05 and p_vals_corrected_dict[ct] >= 0.05:
        # "TypeA&TypeB" -> ('TypeA', 'TypeB')
        highlighted_comb.append(tuple(map(str.strip, ct.split('&'))))
        print(f"{ct} in whole tissue has p value = {p_vals_corrected_dict[ct]:.3f}", flush=True)
        print('*'*42)

# Dashed vertical separators between neighbouring combinations
for i in range(len(selected_cell_types) - 1):
    ax.axvline(i + 0.55, color='grey', linestyle='--', linewidth=0.5)

ax.set_ylabel("Frequency", fontsize=14)
ax.set_xlabel('')
plt.xticks(rotation=90)
plt.yticks(rotation=90)
plt.savefig("Spot_Cell_co_occurance_CTX.pdf", dpi = 300)
plt.show()

In [None]:
# Mean whole-tissue co-occurrence per condition (rows: Condition, columns: cell-type pairs).
circoplot_df1 = (
    global_co_occurrence_df
    .sort_index(axis=1, level=[0,1])
    .drop(columns=['Patch'])
    # drop every column label that contains 'noid'
    .pipe(lambda frame: frame[[col for col in frame.columns if 'noid' not in col]])
    .groupby('Condition').mean()
    .reset_index()
    .set_index('Condition')
)
circoplot_df1

In [None]:
# Inspect the whole-tissue cell-frequency table that is averaged per condition in the next cell.
global_cellfreq_df

In [None]:
# Mean whole-tissue cell-type frequency per condition. The table is used as is
# (an earlier variant dropped a 'noid' column first).
circoplot_df2 = (
    global_cellfreq_df.copy()
    .groupby('Condition').mean()
    .reset_index()
    .set_index('Condition')
)
circoplot_df2

In [None]:
# Circos plot for the WT row of the current circoplot_df1 (co-occurrence) and
# circoplot_df2 (cell abundance) tables; no edges are highlighted.
patient_group = 'WT'
eco.create_circos_plot(circoplot_df1.loc[[patient_group]],
                       cell_type_colors_hex=None,
                       cell_abundance=circoplot_df2.loc[[patient_group]],
                       threshold=0.05,
                       edge_weights_scaler=10,
                       highlighted_edges=None,
                       node_weights_scaler=5000,
                       figure_size=(8,8),
                       save_path='WT_CTX_circoplot.pdf')

In [None]:
# Mean hot/cold-spot co-occurrence per condition (rows: Condition, columns: cell-type pairs).
circoplot_df1 = (
    spot_co_occurrence_df
    .sort_index(axis=1, level=[0,1])
    .drop(columns=['Patch'])
    # drop every column label that contains 'noid'
    .pipe(lambda frame: frame[[col for col in frame.columns if 'noid' not in col]])
    .groupby('Condition').mean()
    .reset_index()
    .set_index('Condition')
)
circoplot_df1

In [None]:
# Mean hot/cold-spot cell-type frequency per condition; the 'Patch' id column is removed first.
circoplot_df2 = (
    spot_cellfreq_df.drop(columns=['Patch'])
    .groupby('Condition').mean()
    .reset_index()
    .set_index('Condition')
)
circoplot_df2

In [None]:
# Circos plot for the KO row of the current circoplot_df1 (co-occurrence) and
# circoplot_df2 (cell abundance) tables, highlighting the combinations in highlighted_comb.
patient_group = 'KO'
eco.create_circos_plot(circoplot_df1.loc[[patient_group]],
                       cell_type_colors_hex=None,
                       cell_abundance=circoplot_df2.loc[[patient_group]],
                       threshold=0.05,
                       edge_weights_scaler=10,
                       highlighted_edges=highlighted_comb,
                       node_weights_scaler=5000,
                       figure_size=(8,8),
                       save_path='KO_CTX_circoplot.pdf')

In [None]:
import scanpy as sc

In [None]:
# Reload the annotated dataset from disk, replacing whatever `adata` held before.
adata = sc.read_h5ad("RelnAll_Annotated.h5ad")

In [None]:
# Keep only the cells annotated as 'Ext DG'.
is_ext_dg = adata.obs['cell_type'].isin(['Ext DG'])
DG = adata[is_ext_dg].copy()

In [None]:
# Coordinates (obsm['spatial_raw']) of the 'Ext DG' cells belonging to sample WT1.
in_wt1 = DG.obs['Sample'] == 'WT1'
WT1_points = DG[in_wt1].obsm['spatial_raw'].copy()

In [None]:
# Inspect the WT1 coordinate array.
WT1_points

In [None]:
import networkx as nx
import numpy as np
from itertools import combinations

def points_to_network(points, threshold):
    """Build a proximity graph: one node per point, an edge between nearby points.

    Parameters
    ----------
    points : array-like of shape (n_points, n_dims)
        Finite point coordinates. Row ``i`` becomes node ``i`` and is stored as
        that node's ``pos`` attribute.
    threshold : float
        Two points are linked when their Euclidean distance is at most this
        value (same units as `points`).

    Returns
    -------
    networkx.Graph
        Undirected graph with nodes ``0 .. n_points - 1``.
    """
    # Local import: the notebook's top-level imports only bring in `scipy.stats`.
    from scipy.spatial import cKDTree

    G = nx.Graph()
    n = len(points)

    for i in range(n):
        G.add_node(i, pos=points[i])

    # A KD-tree lookup replaces the pure-Python loop over all n*(n-1)/2 pairs.
    # Fewer than two points, or a negative/NaN threshold, can never produce an edge.
    if n >= 2 and threshold >= 0:
        coords = np.asarray(points, dtype=float).reshape(n, -1)
        pairs = cKDTree(coords).query_pairs(r=threshold, output_type='ndarray')
        # .tolist() yields plain Python ints, matching the node ids added above
        G.add_edges_from(map(tuple, pairs.tolist()))

    return G
    

In [None]:
# Link cells at most this far apart; units are those of obsm['spatial_raw'] (TODO confirm).
threshold = 100

G = points_to_network(WT1_points, threshold)

# Average clustering coefficient of the WT1 proximity graph. It is shown as the cell
# output because `G` and `clustering_coefficient` are reassigned for the other samples.
clustering_coefficient = nx.average_clustering(G)
clustering_coefficient

In [None]:
# Coordinates of the 'Ext DG' cells of sample WT2. Taken from the `DG` subset, like WT1,
# so every sample's proximity graph is built from the same cell population.
WT2_points = DG[(DG.obs['Sample'] == 'WT2')].obsm['spatial_raw'].copy()

In [None]:
# Inspect the WT2 coordinate array.
WT2_points

In [None]:
# Link cells at most this far apart; units are those of obsm['spatial_raw'] (TODO confirm).
threshold = 100

G = points_to_network(WT2_points, threshold)

# Compute the average clustering coefficient of the WT2 proximity graph and show it as the
# cell output, since `G` and `clustering_coefficient` are reused for every sample.
clustering_coefficient = nx.average_clustering(G)
clustering_coefficient

In [None]:
# Coordinates of the 'Ext DG' cells of sample KO1. Taken from the `DG` subset, like WT1,
# so every sample's proximity graph is built from the same cell population.
KO1_points = DG[(DG.obs['Sample'] == 'KO1')].obsm['spatial_raw'].copy()

In [None]:
# Inspect the KO1 coordinate array.
KO1_points

In [None]:
# Link cells at most this far apart; units are those of obsm['spatial_raw'] (TODO confirm).
threshold = 100

G = points_to_network(KO1_points, threshold)

# Compute the average clustering coefficient of the KO1 proximity graph and show it as the
# cell output, since `G` and `clustering_coefficient` are reused for every sample.
clustering_coefficient = nx.average_clustering(G)
clustering_coefficient

In [None]:
# Coordinates of the 'Ext DG' cells of sample KO2. Taken from the `DG` subset, like WT1,
# so every sample's proximity graph is built from the same cell population.
KO2_points = DG[(DG.obs['Sample'] == 'KO2')].obsm['spatial_raw'].copy()

In [None]:
# Inspect the KO2 coordinate array.
KO2_points

In [None]:
# Link cells at most this far apart; units are those of obsm['spatial_raw'] (TODO confirm).
threshold = 100

G = points_to_network(KO2_points, threshold)

# Compute the average clustering coefficient of the KO2 proximity graph and show it as the
# cell output, since `G` and `clustering_coefficient` are reused for every sample.
clustering_coefficient = nx.average_clustering(G)
clustering_coefficient

In [None]:
import scanpy as sc

In [None]:
# Reload the annotated dataset from disk, replacing whatever `adata` held before,
# and display its summary.
adata = sc.read_h5ad("RelnAll_Annotated.h5ad")
adata

In [None]:
# Violin plot of the `n_genes_by_counts` obs column, one violin per Sample, with an inner
# box plot and no strip plot; the figure is also saved via `save`.
sc.pl.violin(adata, keys = ['n_genes_by_counts'], groupby = 'Sample', inner = 'box', stripplot=False, save = 'WTKO_n_genes_by_counts.pdf')

In [None]:
# Violin plot of the `total_counts` obs column, one violin per Sample, with an inner
# box plot and no strip plot; the figure is also saved via `save`.
sc.pl.violin(adata, keys = ['total_counts'], groupby = 'Sample', inner = 'box', stripplot=False, save = 'WTKO_totalcounts.pdf')

In [None]:
# Palette for the 'Sample' categories: two reds followed by two blues.
# NOTE(review): presumably red = KO1/KO2 and blue = WT1/WT2 if the categories are ordered
# KO1, KO2, WT1, WT2 — confirm. It is set after the violin plots above, so it only affects
# plots drawn later.
adata.uns['Sample_colors'] = ['#FF0000', '#FF0000', '#0000FF', '#0000FF']

In [None]:
# One marker gene per annotated cell type (cell type label -> list of gene names).
# The dot plot cell looks up every dendrogram category in this dict, so each
# 'cell_type' category needs an entry here.
marker_dict = {
    'Astro-CB': ['Slc1a3'],
    'Astro-NT': ['Agt'],
    'Astro-TE': ['Aqp4'],
    'CHOR': ['Ecrg4'],
    'Endo': ['Cldn5'],
    'Epen': ['Tmem212'],
    'Ext CA1': ['Rasgrp1'],
    'Ext CA2': ['Adcy1'],
    'Ext CA3': ['Cpne4'],
    'Ext CB': ['Cbln1'],
    'Ext DG': ['Prox1'],
    'Ext IC': ['Tcf7l2'],
    'Ext L2/3': ['Satb2'],
    'Ext L2/3 PIR-ENTl': ['Hs3st2'],
    'Ext L4/5': ['Cck'],
    'Ext L6': ['Rprm'],
    'Ext MB-HY': ['Slc17a6'],
    'Ext NLOT': ['Synpr'],
    'Ext OB': ['Grm8'],
    'Ext TH': ['Synpo2'],
    'Ext UBC': ['Nfib'],
    'Fibro': ['Igf2'],
    'Inh CB': ['Pvalb'],
    'Inh CB Purkinje': ['Calb1'],
    'Inh Lamp5': ['Lamp5'],
    'Inh MB-HY': ['Gad2'],
    'Inh OB-STR-CTX': ['Meis2'],
    'Inh RT': ['Trh'],
    'Inh STR D1': ['Drd1'],
    'Inh STR D2': ['Drd2'],
    'Inh Sst': ['Sst'],
    'Inh Vip': ['Vip'],
    'Microglia': ['P2ry12'],
    'Mural': ['Pdgfrb'],
    'OPC': ['Pdgfra'],
    'Oligo': ['Mog']
}

In [None]:
import cosg

In [None]:
# Run COSG marker detection per cell type; results are stored in adata.uns['cosg_cell_type'].
groupby = 'cell_type'
cosg.cosg(
    adata,
    key_added='cosg_cell_type',
    # to run on a specific layer instead, pass e.g. use_raw=False, layer='log1p'
    mu=1,
    expressed_pct=0.1,
    remove_lowly_expressed=True,
    n_genes_user=10,
    groupby=groupby,
)

import pandas as pd
# Export the marker gene names table to CSV.
cosg_marker_names = pd.DataFrame(adata.uns["cosg_cell_type"]["names"])
cosg_marker_names.to_csv('cell_type_cosg.csv')

In [None]:
# obs column used to group cells in the dot plot below.
groupby='cell_type'

In [None]:
# Compute the cell-type dendrogram; the dot plot reads its 'categories_ordered' from
# adata.uns['dendrogram_cell_type'].
sc.tl.dendrogram(adata, groupby = 'cell_type')

In [None]:
# Dot plot of the marker genes, with the marker groups listed in dendrogram order.
# Every category in 'categories_ordered' must be a key of marker_dict, otherwise the
# dict comprehension raises KeyError.
sc.pl.dotplot(adata, {x:marker_dict[x] for x in adata.uns['dendrogram_cell_type']['categories_ordered']},
             groupby=groupby,
             dendrogram=True,
              swap_axes=False,
             standard_scale='var',
             cmap='RdYlBu_r',
             mean_only_expressed = True,
             linewidth = 0, marker = "o", figsize=(15,8),
             save = 'All_cell_type_marker.pdf'
             )