In [None]:
import pickle
import seaborn as sns
import imageio as io
import pandas as pd
import os
import matplotlib.pyplot as plt
import numpy as np
import tifffile
from tqdm.notebook import tqdm
import pathlib
import json
import glob
import PIL
import scanpy as sc

In [None]:
# Destination folder for the Figure 1A exports. Built with os.path.join so the
# path is valid on any OS (a backslash-separated literal only works on Windows).
output_folder = os.path.join('data', '240719AnalysisDAM_TERM', 'fig_0925', 'Fig1A')

In [None]:
# Load the filtered AnnData object. `adata` is a working copy, so the
# annotations added below do not touch `adata_del2` (plotted again further down).
# The path is built with os.path.join so it also resolves outside Windows.
adata_del2 = sc.read(os.path.join('data', '240719AnalysisDAM_TERM', 'adata_del.h5ad'))
adata = adata_del2.copy()

# Extended Fig1a

In [None]:
# Global UMAP coloured by the coarse class annotation (`class1`), labels drawn on the embedding.
umap_style = dict(
    add_outline=False,
    legend_loc='on data',
    legend_fontsize=10,
    legend_fontoutline=2,
    frameon=False,
    size=0.5,
    vmax=1,
    show=False,
    use_raw=False,
    cmap='coolwarm',
)

sc.set_figure_params(figsize=(5, 5))
sc.pl.umap(adata, color=['class1'], **umap_style)

# To export this panel:
# plt.savefig(os.path.join(output_folder, 'ExtFig1A_global_umap_class.pdf'), bbox_inches='tight', dpi=300)
plt.show()

# Extended Fig1d

In [None]:
import random

# Extended Fig1d: UMAP of the fine clusters (`subclass_name1`) with one random hex
# colour per cluster. A dedicated, seeded generator makes the palette (and so the
# saved figure) identical on every run; drawing from the unseeded global `random`
# module gave a different figure after every kernel restart. Change PALETTE_SEED to
# try another palette. (A shuffled 'Accent' colormap used to be built here as well;
# it was never passed to any plot and only consumed the global numpy RNG, so it was dropped.)
PALETTE_SEED = 0
palette_rng = random.Random(PALETTE_SEED)

unique_categories = adata.obs['subclass_name1'].unique()
random_colors = {
    cat: '#' + ''.join(palette_rng.choice('0123456789ABCDEF') for _ in range(6))
    for cat in unique_categories
}

# Print the palette so the exact colours can be recorded / reused elsewhere.
print("Colormap for each cluster:")
for cat, color in random_colors.items():
    print(f"'{cat}': '{color}'")

sc.set_figure_params(figsize=(5, 5))
# NOTE(review): this section is headed "Extended Fig1d" but the file is named
# "ExtFig1B_..."; confirm the panel label. scanpy writes `save=` files into
# sc.settings.figdir, not into `output_folder`.
sc.pl.umap(
    adata,
    color=['subclass_name1'],
    palette=random_colors,
    add_outline=False,
    legend_fontsize=10,
    legend_fontoutline=2,
    frameon=False,
    size=0.5,
    vmax=1,
    show=False,
    use_raw=False,
    save='ExtFig1B_global_umap_subclass_name1_random.pdf'
)


# Extended Fig1b

In [None]:
# Extended Fig1b: marker genes per coarse class (`class1`), shown as a z-scored dot plot.
adata.obs['ddnres5'] = adata.obs['ddnres5'].astype('category')
# Per-gene z-scores go into a layer; adata.X itself is left unchanged here.
adata.layers['scaled'] = sc.pp.scale(adata, zero_center=True, copy=True).X
# Workaround: `base` can be missing from uns['log1p'] after an h5ad round-trip,
# and rank_genes_groups reads it (None presumably selects the natural-log
# back-transform). NOTE(review): assumes adata.X holds ln(1+x) values.
adata.uns['log1p']["base"] = None
sc.tl.rank_genes_groups(adata, 'class1', method='t-test')
result = adata.uns['rank_genes_groups']
groups = result['names'].dtype.names
# One column per (group, statistic), suffixed with the full statistic name.
# A one-letter `key[:1]` suffix mapped both 'pvals' and 'pvals_adj' to '_p',
# so the raw p-values were silently overwritten by the adjusted ones.
df_scdata_class1 = pd.DataFrame({
    f'{group}_{key}': result[key][group]
    for group in groups
    for key in ['names', 'logfoldchanges', 'pvals', 'pvals_adj']
})
sc.pl.rank_genes_groups_dotplot(
    adata,
    n_genes=4,
    layer='scaled',
    vmax=1,
    vmin=-1,
    cmap='coolwarm',
    show=True,
    save='ExtFig1B_dotplot_class1_zscore.pdf'
)

# Extended Fig1f

In [None]:
# Extended Fig1f: marker genes per fine cluster (`subclass_name1`), z-scored dot plot.
sc.tl.dendrogram(adata, 'subclass_name1')
# Workaround: `base` can be missing from uns['log1p'] after an h5ad round-trip,
# and rank_genes_groups reads it. NOTE(review): assumes adata.X holds ln(1+x) values.
adata.uns['log1p']["base"] = None
sc.tl.rank_genes_groups(adata, 'subclass_name1', method='t-test')
result = adata.uns['rank_genes_groups']
groups = result['names'].dtype.names
# One column per (group, statistic). A one-letter `key[:1]` suffix mapped both
# 'pvals' and 'pvals_adj' to '_p', so the raw p-values were overwritten.
df_scdata_subclass_name1 = pd.DataFrame({
    f'{group}_{key}': result[key][group]
    for group in groups
    for key in ['names', 'logfoldchanges', 'pvals', 'pvals_adj']
})

from collections import OrderedDict

# NOTE(review): this z-scores adata.X IN PLACE. Downstream cells may read adata.X
# (e.g. via adata.to_df()), so the in-place update is kept. However, re-running this
# cell or the class-level marker cell afterwards makes rank_genes_groups test z-scores
# rather than expression (if adata.raw is absent).
sc.pp.scale(adata, zero_center=True, max_value=None)
# Store the same z-scores as the 'scaled' layer the dot plot reads, so this cell
# does not depend on a layer written by an earlier cell.
adata.layers['scaled'] = adata.X.copy()

# Genes dropped from the marker lists before plotting.
genes_to_exclude = {
    'App', 'Syp', 'Zfp36l2', 'Psap', 'Prkar1b', 'Gaa', 'Nefl', 'Cd47', 'Ctsb',
    'Vip', 'Dvl1', 'Gfap', 'Sqstm1', 'Cd3e', 'Cd8a', 'Clu', 'Hp', 'Pten', 'Cxcl1', 'Clta',
    'Gas7', 'Grin2b', 'Ppp3cb', 'Tubb3', 'Ppfia2', 'Lyst'
}

key = 'rank_genes_groups'
n_genes = 4
groups = adata.uns[key]['names'].dtype.names

# Top-n ranked genes per cluster, minus the excluded ones, in rank_genes_groups group order.
genes_dict = OrderedDict(
    (group, [gene for gene in adata.uns[key]['names'][group][:n_genes] if gene not in genes_to_exclude])
    for group in groups
)

# The y-axis follows the group order above. Every cluster keeps its category slot,
# even if all of its top genes were excluded.
cluster_order = list(genes_dict.keys())
adata.obs['subclass_name1'] = adata.obs['subclass_name1'].astype('category')
adata.obs['subclass_name1'] = adata.obs['subclass_name1'].cat.reorder_categories(cluster_order)

sc.pl.dotplot(
    adata,
    # Clusters left with no genes are not passed as empty gene groups.
    var_names={group: genes for group, genes in genes_dict.items() if genes},
    groupby='subclass_name1',
    layer='scaled',
    vmax=1,
    vmin=-1,
    cmap='coolwarm',
    show=True,
    save='ExtFig1D_dotplot_subclass_name1_zscore.pdf'
)


# Extended Fig1e

In [None]:
from scipy.optimize import linear_sum_assignment


def _pearson_between_rows(a, b):
    """Pearson correlation of every row of `a` with every row of `b`.

    Returns an array of shape (len(a), len(b)). A constant row yields NaN,
    as pandas ``Series.corr`` does.
    """
    a_centered = a - a.mean(axis=1, keepdims=True)
    b_centered = b - b.mean(axis=1, keepdims=True)
    norms = np.outer(np.linalg.norm(a_centered, axis=1), np.linalg.norm(b_centered, axis=1))
    with np.errstate(invalid='ignore', divide='ignore'):
        return (a_centered @ b_centered.T) / norms


# Expression basis: per-gene z-scores (sample SD; constant genes are left at 0,
# matching sc.pp.scale's convention). In a top-to-bottom run adata.X has already
# been z-scored in place by an earlier cell. Z-scoring explicitly here gives the
# same numbers without depending on that cell (re-scaling z-scores leaves them unchanged).
expr = adata.to_df()
expr_z = (expr - expr.mean()) / expr.std(ddof=1).replace(0, 1)

# Mean profile per ABC-atlas label (`subclass_name`, columns) and per MERFISH cluster
# (`subclass_name1`, rows). observed=True skips unused categories, which would
# otherwise become all-NaN rows.
abc_means = expr_z.groupby(adata.obs['subclass_name'], observed=True).mean()
merfish_means = expr_z.groupby(adata.obs['subclass_name1'], observed=True).mean()

correlation_matrix = pd.DataFrame(
    _pearson_between_rows(merfish_means.to_numpy(dtype=float), abc_means.to_numpy(dtype=float)),
    index=merfish_means.index,
    columns=abc_means.index,
)

# Keep each MERFISH cluster's TOP_K best-correlated ABC labels. The columns are the
# union of those labels; every other cell is drawn at -1 (bottom of the colour scale).
TOP_K = 3
is_top_k = correlation_matrix.rank(axis=1, ascending=False, method='first') <= TOP_K
heatmap_data = (
    correlation_matrix.where(is_top_k)
    .loc[:, is_top_k.any(axis=0).to_numpy()]
    .fillna(-1.0)
)
# Columns ordered by their mean value, highest first.
heatmap_data = heatmap_data.loc[:, heatmap_data.mean(axis=0).sort_values(ascending=False).index]

# Pair rows with columns to maximise the total correlation, which moves strong
# matches onto the diagonal. On a non-square matrix linear_sum_assignment returns
# only min(n_rows, n_cols) pairs. Keeping just those silently dropped the remaining
# clusters from the figure, so the unmatched rows/columns are appended after the
# matched ones, in their current order.
row_ind, col_ind = linear_sum_assignment(-heatmap_data.to_numpy())
n_rows, n_cols = heatmap_data.shape
matched_rows, matched_cols = set(row_ind), set(col_ind)
row_order = list(row_ind) + [i for i in range(n_rows) if i not in matched_rows]
col_order = list(col_ind) + [j for j in range(n_cols) if j not in matched_cols]
ordered_heatmap_data = heatmap_data.iloc[row_order, col_order]

fig, ax = plt.subplots(figsize=(15, 12))
sns.heatmap(
    ordered_heatmap_data,
    annot=False,
    fmt=".2f",
    cmap="coolwarm",
    linewidths=0.5,
    linecolor='gray',
    cbar_kws={'label': 'Pearson Correlation'},
    vmin=ordered_heatmap_data.min().min(),
    vmax=ordered_heatmap_data.max().max(),
    ax=ax,
)
ax.set_xlabel('ABC Atlas Clusters')
ax.set_ylabel('MERFISH Clusters')
ax.set_title('Pearson Correlations between MERFISH Clusters and ABC Atlas Clusters')
fig.tight_layout()
# To export: fig.savefig(os.path.join(output_folder, 'ExtFig1E_pearson_corr_ABCA.pdf'), bbox_inches='tight', dpi=300)
plt.show()

# Cluster mapping

In [None]:
# Transfer the cluster labels from `adata` onto `adata_whole` by cell name.
# Cells of `adata_whole` that are absent from `adata` are labelled 'unknown'.
# NOTE(review): `adata_whole` is not created anywhere in this notebook; it must be
# loaded beforehand (hidden kernel state). Add an explicit load cell.
for label_col in ['class1', 'class_name1', 'subclass_name1']:
    cluster_mapping = adata.obs[label_col].to_dict()
    # Assign the filled result back rather than calling
    # `obs[col].fillna(..., inplace=True)`. That is chained assignment, which
    # pandas Copy-on-Write does not propagate to the parent DataFrame.
    adata_whole.obs[label_col] = adata_whole.obs_names.map(cluster_mapping).fillna('unknown')

In [None]:
# One batch each (APP_1, TE4_3, WT_1, E4_2) as views of the whole object (`*_wh`)
# and of the clustered object `adata` (`*_mg`).
def _batch_view(ad, batch_id):
    """Return the view of `ad` restricted to cells whose obs['batch'] equals `batch_id`."""
    return ad[ad.obs['batch'] == batch_id]


adata_app_wh, adata_te4_wh, adata_wt_wh, adata_e4_wh = (
    _batch_view(adata_whole, batch_id) for batch_id in ('APP_1', 'TE4_3', 'WT_1', 'E4_2')
)
adata_app_mg, adata_te4_mg, adata_wt_mg, adata_e4_mg = (
    _batch_view(adata, batch_id) for batch_id in ('APP_1', 'TE4_3', 'WT_1', 'E4_2')
)

# Plotting

In [None]:

# Global UMAP of `adata_del2` coloured by the broad `class` annotation with a fixed palette.
custom_colors = {
    'OPC': '#f58231',        # orange
    'Oligo': '#3cb44b',      # green
    'Macro': '#ff1493',      # deep pink
    'Astro': '#4363d8',      # blue
    'Microglia': '#ffe119',  # yellow
    'ExN': '#e6194B',        # bright red
    'InN': '#bfef45',        # lime
    'VLMC': '#42d4f4',       # sky blue
    'Endo': '#fabed4',       # light pink
    # Was '#3cb44b', identical to Oligo, which made the two classes indistinguishable.
    # '#808000' is the olive the original comment named.
    'SMC_Peri': '#808000',   # olive
}

# Colours are assigned in category order; classes without an entry are drawn black.
unique_classes = adata_del2.obs['class'].cat.categories
adata_del2.uns['class_colors'] = [custom_colors.get(cls, '#000000') for cls in unique_classes]

sc.set_figure_params(figsize=(5, 5))
sc.pl.umap(
    adata_del2,
    color='class',
    add_outline=True,
    legend_loc='on data',
    legend_fontsize=10,
    legend_fontoutline=2,
    frameon=False,
    size=1,
    show=False,
    use_raw=False
)

# NOTE(review): this rebinds the notebook-wide `output_folder` to a hard-coded absolute
# drive path, so every later cell that saves via `output_folder` writes here. It is kept
# unchanged so existing outputs do not move; consider a cell-local name instead.
output_folder = r'X:\GV1Backup\AM\data\240719AnalysisDAM_TERM\fig_0904'
output_path = os.path.join(output_folder, 'global_umap_class_legendondata.tif')
plt.savefig(output_path, bbox_inches='tight', dpi=300)

plt.show()


# Fig1a

In [None]:
 # Fig1a: spatial map of the WT_1 batch with selected neuronal subclasses highlighted.
custom_colors = {
    'L2/3 IT CTX Glut': '#e6194B',
    'CA1-ProS Glut': '#3cb44b',
    'L5 ET CTX Glut': '#ffe119',
    'L6 IT CTX Glut': '#4363d8',
    'CA3 Glut': '#f58231',
    'L4/5 IT CTX Glut': '#911eb4',
    'L6b CTX Glut': '#42d4f4',
    'DG Glut': '#f032e6',
    'CNU-HYa Glut': '#bfef45',
    'TH Prkcd Grin2c Glut': '#fabed4',
    'MB Glut': '#469990',
    'Sncg Gaba': '#dcbeff',
    'Vip Gaba': '#9A6324',
    'Pvalb Gaba': '#fffac8',
    'Sst Gaba': '#800000'
}


def plot_cluster_scdata_str_diff(scdata, cmap, clusters=('DAM_1', 'DAM_2'), transpose=1, flipx=1, flipy=1,
                                 tag='cluster', key='X_spatial'):
    """Scatter every cell in light grey, then overlay the cells of selected clusters in colour.

    Parameters
    ----------
    scdata : AnnData
        ``obsm[key]`` must have two coordinate columns; ``obs[tag]`` holds the labels.
    cmap : dict
        Label -> colour. Labels missing from it are drawn 'grey'. (Previously this
        argument was ignored and the global ``custom_colors`` was read instead.)
    clusters : iterable of str
        Labels to highlight. They are drawn and listed in the legend in sorted label order.
    transpose : {1, -1}
        1 keeps the coordinate columns as (x, y); -1 swaps them.
    flipx, flipy : {1, -1}
        Signs applied to the first / second coordinate column before any swap.
    tag, key : str
        ``obs`` column holding the labels and ``obsm`` key holding the coordinates.

    Returns
    -------
    matplotlib.figure.Figure
        The current figure (``plt.gcf()``).
    """
    coords = np.array(scdata.obsm[key]) * [flipx, flipy]
    x, y = coords[:, ::transpose].T
    plt.scatter(x, y, c='#E0E0E0', s=2, marker='.')  # background: all cells

    labels = scdata.obs[tag]
    for cluster in np.unique(labels):
        if cluster not in clusters:
            continue
        name = str(cluster)
        mask = (labels == name).to_numpy()
        # Every highlighted cluster is drawn with s=16. The former per-cluster size
        # (1 for 'outline', else 40) was computed but never used, so it was removed.
        plt.scatter(x[mask], y[mask], c=cmap.get(name, 'grey'), s=16, marker='.', label=name)

    plt.grid(False)
    plt.axis("off")
    plt.axis("equal")
    # Anchor the legend to the left of the axes.
    plt.legend(loc='center right', bbox_to_anchor=(-0.1, 0.5), prop={'size': 12})
    plt.tight_layout()
    return plt.gcf()


neuron_subclasses = [
    'L2/3 IT CTX Glut',
    'CA1-ProS Glut',
    # 'L2/3 IT PIR-ENTl Glut',
    'L5 ET CTX Glut',
    # 'L2/3 IT RSP Glut',
    'L6 IT CTX Glut',
    'CA3 Glut',
    'L4/5 IT CTX Glut',
    'L6b CTX Glut',
    # 'L6 CT CTX Glut',
    'DG Glut',
    # 'MH Tac2 Glut',
    'CNU-HYa Glut',
    # 'LA-BLA-BMA-PA Glut',
    # 'PF Fzd5 Glut',
    'TH Prkcd Grin2c Glut',
    'MB Glut',
    # 'CLA-EPd-CTX Car3 Glut',
    # 'STN-PSTN Pitx2 Glut',
    # 'TRS-BAC Sln Glut',
    # 'VMH Nr5a1 Glut',
    'Sncg Gaba',
    # 'HY/MB/CNU-Hya GABA',
    # 'STR D2 Gaba',
    'Vip Gaba',
    'Pvalb Gaba',
    'Sst Gaba',
    # 'MB Dopa'
]

fig = plt.figure(figsize=(10, 8), facecolor="white")
fig = plot_cluster_scdata_str_diff(adata_wt_wh, custom_colors, clusters=neuron_subclasses,
                                   transpose=1, flipx=1, flipy=1, tag='subclass_name1', key='X_spatial')
plt.savefig(os.path.join(output_folder, 'global_spatial_neuros.tif'), format='tif', dpi=300, bbox_inches='tight')
# plt.savefig(os.path.join(output_folder, 'global_spatial_neuros.pdf'), format='pdf', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
 
# Spatial map of the WT_1 batch with a second set of subclasses highlighted
# (x-coordinate sign flipped via flipx=-1). Reuses `plot_cluster_scdata_str_diff`
# defined in the Fig1a cell instead of re-declaring an identical copy.
custom_colors = {
    'VLMC': '#FFEA00',
    'Endo': '#f032e6',
    'Macro': '#4363d8',
    'SMC': '#469990',
    'Astroependymal': '#3cb44b',
    'Peri': '#bfef45',
    'Microglia': 'black',
    'OPC': '#911eb4',
    'Ependymal': '#FF4500',
    'Oligo': '#42d4f4',
    'Astro': '#fabed4'
}

highlighted_subclasses = [
    'VLMC',
    'Endo',
    'Macro',
    'SMC',
    'Astro',
    'Astroependymal',
    'Peri',
    'Microglia',
    'OPC',
    'Ependymal',
    'Oligo'
]

fig = plt.figure(figsize=(10, 8), facecolor="white")
fig = plot_cluster_scdata_str_diff(adata_wt_wh, custom_colors, clusters=highlighted_subclasses,
                                   transpose=1, flipx=-1, flipy=1, tag='subclass_name1', key='X_spatial')
plt.savefig(os.path.join(output_folder, 'global_spatial_imm.tif'), format='tif', dpi=300, bbox_inches='tight')
plt.savefig(os.path.join(output_folder, 'global_spatial_imm.pdf'), format='pdf', dpi=300, bbox_inches='tight')
plt.show()


# Extended Fig1c

In [None]:

# Extended Fig1c: share of each coarse class (`class1`) among all cells of a batch,
# summarised per genotype as mean ± SD over that genotype's batches.
batch_to_genotype = {
    'APP_1': 'APP', 'APP_2': 'APP', 'APP_3': 'APP', 'APP_4': 'APP',
    'E4_1': 'E4', 'E4_2': 'E4',
    'TE4_1': 'TE4', 'TE4_2': 'TE4', 'TE4_3': 'TE4', 'TE4_4': 'TE4',
    'WT_1': 'WT', 'WT_2': 'WT'
}

adata.obs['geno_type'] = adata.obs['batch'].map(batch_to_genotype)

# Cells per batch: the denominator for every class, so it is computed once.
total_counts_batch = adata.obs['batch'].value_counts()

all_data = []
for cell_class in adata.obs['class1'].unique():
    # Counted directly from obs. The former per-class AnnData view also had
    # `geno_type` assigned on its .obs; that forced an implicit copy and was never read.
    temp_counts_batch = adata.obs.loc[adata.obs['class1'] == cell_class, 'batch'].value_counts()
    counts_df_batch = pd.DataFrame({
        'Total': total_counts_batch,
        'Temp': temp_counts_batch
    }).fillna(0)

    counts_df_batch['Percentage'] = (counts_df_batch['Temp'] / counts_df_batch['Total']) * 100
    counts_df_batch['geno_type'] = counts_df_batch.index.map(batch_to_genotype)
    genotype_stats = counts_df_batch.groupby('geno_type').agg(
        Mean_Percentage=('Percentage', 'mean'),
        Std_Percentage=('Percentage', 'std')
    )

    genotype_stats['class1'] = cell_class
    all_data.append(genotype_stats)
all_data_df = pd.concat(all_data).reset_index()

custom_order = ['WT', 'APP', 'E4', 'TE4']
all_data_df['geno_type'] = pd.Categorical(all_data_df['geno_type'], categories=custom_order, ordered=True)
all_data_df = all_data_df.sort_values(by=['class1', 'geno_type'])

color_map = {
    'WT': 'skyblue',
    'E4': '#bfef45',
    'APP': '#ec008c',
    'TE4': '#fabed4'
}
all_data_df['color'] = all_data_df['geno_type'].map(color_map)

class_names = all_data_df['class1'].unique()
num_genotypes = len(custom_order)
spacing = 1.5
bar_width = 0.8
group_width = num_genotypes + spacing

# Each bar is positioned by its (class, genotype) slot rather than by a running row
# count. If a genotype is missing for some class, its slot is left empty instead of
# shifting the remaining bars (and the tick labels) out of place. When every genotype
# is present, the positions are unchanged.
class_idx = pd.Categorical(all_data_df['class1'], categories=class_names).codes
geno_idx = all_data_df['geno_type'].cat.codes.to_numpy()
x_positions = class_idx * group_width + geno_idx

plt.figure(figsize=(12, 6))
plt.bar(x_positions, all_data_df['Mean_Percentage'],
        yerr=all_data_df['Std_Percentage'], color=all_data_df['color'], capsize=4, width=bar_width)

plt.title('Total Percentage of Each Class by Genotype')
plt.xlabel('Main Cell Types')
plt.ylabel('Percentage (%)')

# One tick at the centre of each class's group of genotype slots.
class_ticks = np.arange(len(class_names)) * group_width + (num_genotypes - 1) / 2
plt.xticks(class_ticks, class_names, rotation=0)
legend_labels = [plt.Line2D([0], [0], color=color_map[geno], lw=4) for geno in custom_order]
plt.legend(legend_labels, custom_order, title="Genotype", loc='upper left')
plt.tight_layout()

output_path = os.path.join(output_folder, 'ExtFig1c_mainCellType.pdf')
plt.savefig(output_path, format='pdf', dpi=300, bbox_inches='tight')
plt.show()


In [None]:
from scipy.stats import chi2_contingency, fisher_exact
from scipy.stats.contingency import expected_freq

# Extended Fig1c statistics. For each coarse class, compare the fraction of cells
# falling in that class between paired genotypes, using the 2x2 table
# [[in class, not in class] for genotype1, [...] for genotype2].
# Reads obs['geno_type'], which is written by the percentage bar-chart cell.
# NOTE(review): the p-values are not corrected for multiple testing across classes/pairs.
genotype_pairs = [('WT', 'APP'), ('E4', 'TE4')]

cell_obs = adata.obs
significance_results = []

for cell_class in cell_obs['class1'].unique():
    in_class = (cell_obs['class1'] == cell_class).to_numpy()
    for genotype1, genotype2 in genotype_pairs:
        is_genotype1 = (cell_obs['geno_type'] == genotype1).to_numpy()
        is_genotype2 = (cell_obs['geno_type'] == genotype2).to_numpy()

        total_cells_genotype1 = int(is_genotype1.sum())
        total_cells_genotype2 = int(is_genotype2.sum())
        cells_in_class_genotype1 = int((is_genotype1 & in_class).sum())
        cells_in_class_genotype2 = int((is_genotype2 & in_class).sum())

        contingency_table = np.array([
            [cells_in_class_genotype1, total_cells_genotype1 - cells_in_class_genotype1],
            [cells_in_class_genotype2, total_cells_genotype2 - cells_in_class_genotype2]
        ])

        # Cochran's rule: the chi-square approximation needs every EXPECTED count >= 5
        # (observed counts were tested before). Otherwise, and for degenerate tables
        # whose expected counts are undefined, use Fisher's exact test.
        with np.errstate(invalid='ignore', divide='ignore'):
            expected = expected_freq(contingency_table)
        if not np.all(expected >= 5):
            _, pvalue = fisher_exact(contingency_table)
            test_used = 'Fisher\'s Exact Test'
        else:
            _, pvalue, _, _ = chi2_contingency(contingency_table)
            test_used = 'Chi-square Test'

        significance_results.append({
            'class1': cell_class,
            'genotype1': genotype1,
            'genotype2': genotype2,
            'count_genotype1_in_class': cells_in_class_genotype1,
            'total_genotype1': total_cells_genotype1,
            'count_genotype2_in_class': cells_in_class_genotype2,
            'total_genotype2': total_cells_genotype2,
            'p-value': pvalue,
            'test_used': test_used
        })

significance_df = pd.DataFrame(significance_results)

output_csv_path = os.path.join(output_folder, 'ExtFig1c_mainCellType_significant.csv')
significance_df.to_csv(output_csv_path, index=False)
print(significance_df)
