In [None]:
import pickle
import seaborn as sns
import imageio as io
import pandas as pd
import os
import matplotlib.pyplot as plt
import numpy as np
import tifffile
from tqdm.notebook import tqdm
import pathlib
from cellpose import models, core
import json
import glob
import PIL
import scanpy as sc

In [None]:
# Per-sample metadata table: space separated, no header row, first column used as index.
meta_columns = ['sample_name', 'genotype', 'Sex', 'Age',
                'adata', 'adata_conc', 'adata_name', 'adata_dir']
meta_data_file = pd.read_csv('240719AnalysisMeta.txt', sep=' ', header=None, index_col=0)
meta_data_file.columns = meta_columns
meta_data_reordered = meta_data_file.loc[:, ['adata_name', 'genotype', 'Sex', 'Age']]
# Index labels removed from the table. NOTE(review): presumably the sections whose
# f0_adata_<n>.h5ad files are not loaded below (7, 8, 11, 12, 17) -- confirm.
excluded_samples = [7, 8, 11, 12, 17]
meta_data_file_cleaned = meta_data_reordered.drop(index=excluded_samples)
meta_data_file_cleaned

In [None]:
# Folder (relative to the working directory) holding the per-section f0_adata_*.h5ad files.
data_path = str(pathlib.Path('MERFISH', 'data', '240719AnalysisDAM_TERM', 'figs'))


In [None]:
file_path = os.path.join(data_path, 'f0_adata_1.h5ad')
adata_1=sc.read(file_path)
adata_1.obs['batch']='E4_1'
adata_1

In [None]:
file_path = os.path.join(data_path, 'f0_adata_2.h5ad')
adata_2=sc.read(file_path)
adata_2.obs['batch']='E4_2'
adata_2

In [None]:
file_path = os.path.join(data_path, 'f0_adata_3.h5ad')
adata_3=sc.read(file_path)
adata_3.obs['batch']='TE4_1'
adata_3

In [None]:
file_path = os.path.join(data_path, 'f0_adata_4.h5ad')
adata_4=sc.read(file_path)
adata_4.obs['batch']='TE4_2'
adata_4

In [None]:
file_path = os.path.join(data_path, 'f0_adata_5.h5ad')
adata_5=sc.read(file_path)
adata_5.obs['batch']='APP_1'
adata_5

In [None]:
file_path = os.path.join(data_path, 'f0_adata_6.h5ad')
adata_6=sc.read(file_path)
adata_6.obs['batch']='APP_2'
adata_6

In [None]:
file_path = os.path.join(data_path, 'f0_adata_9.h5ad')
adata_9=sc.read(file_path)
adata_9.obs['batch']='TE4_3'
adata_9

In [None]:
file_path = os.path.join(data_path, 'f0_adata_10.h5ad')
adata_10=sc.read(file_path)
adata_10.obs['batch']='TE4_4'
adata_10

In [None]:
file_path = os.path.join(data_path, 'f0_adata_13.h5ad')
adata_13=sc.read(file_path)
adata_13.obs['batch']='APP_3'
adata_13

In [None]:
file_path = os.path.join(data_path, 'f0_adata_14.h5ad')
adata_14=sc.read(file_path)
adata_14.obs['batch']='APP_4'
adata_14

In [None]:
file_path = os.path.join(data_path, 'f0_adata_15.h5ad')
adata_15=sc.read(file_path)
adata_15.obs['batch']='WT_1'
adata_15

In [None]:
file_path = os.path.join(data_path, 'f0_adata_16.h5ad')
adata_16=sc.read(file_path)
adata_16.obs['batch']='WT_2'
adata_16

In [None]:

from anndata import AnnData

# Sections in concatenation order, and the batch label each one was loaded with.
sections = [adata_1, adata_2, adata_3, adata_4, adata_5, adata_6,
            adata_9, adata_10, adata_13, adata_14, adata_15, adata_16]
batch_names = ['E4_1', 'E4_2', 'TE4_1', 'TE4_2', 'APP_1', 'APP_2',
               'TE4_3', 'TE4_4', 'APP_3', 'APP_4', 'WT_1', 'WT_2']

# The labels are typed twice: once at load time and once here. concatenate() fills
# obs['batch'] (batch_key) from batch_categories by position, so a typo or reordering
# would silently relabel a section. Fail loudly instead.
assert len(sections) == len(batch_names), "one batch name is needed per section"
for section, expected in zip(sections, batch_names):
    found = set(section.obs['batch'])
    assert found == {expected}, f"section tagged {sorted(found)} would be concatenated as {expected!r}"

# NOTE(review): AnnData.concatenate is deprecated in recent anndata releases in favour of
# anndata.concat, which merges .var/.uns differently. Switch deliberately, not blindly.
adata = AnnData.concatenate(*sections, batch_key='batch', batch_categories=batch_names)


In [None]:
sc.pl.violin(adata, ['n_genes_by_counts', 'total_counts', 'pct_counts_Blank'],
             jitter=0.4, multi_panel=True)

In [None]:
sc.pl.scatter(adata, x='total_counts', y='pct_counts_Blank')
sc.pl.scatter(adata, x='total_counts', y='n_genes_by_counts')

In [None]:
exp = adata.to_df().sum(axis=1)
exp.hist(bins=200)

In [None]:
np.median(exp)

In [None]:
min_expression = 10
keep_cells = exp[exp > min_expression].index.tolist()
adata = adata[keep_cells]
adata
exp = adata.to_df().sum(axis=1)
exp.hist(bins=200)

In [None]:
sc.pl.highest_expr_genes(adata, n_top=20 )

In [None]:
adata.write(r'f0_adata_concat_raw.h5ad')

In [None]:
adata.X

# Normalization


In [None]:
sc.pp.normalize_total(adata, target_sum=np.median(adata.obs["total_counts"]))
sc.pp.log1p(adata)


In [None]:
sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)
sc.pl.highly_variable_genes(adata)

In [None]:
hvg = adata.var[adata.var['highly_variable']].index.tolist()
print(len(hvg),list(np.unique(hvg)))

In [None]:
adata.raw=adata

In [None]:
adata.write(r'Z:\GV1Backup\AM\data\240719AnalysisDAM_TERM\f0_adata_concat_norm.h5ad')


# PCA

In [None]:
sc.pp.pca(adata)
sc.pp.neighbors(adata,n_neighbors=10, n_pcs=40)

In [None]:
sc.tl.umap(adata,random_state=9)
sc.pl.umap(adata, color='Tmem119')

In [None]:
sc.pl.umap(adata, color=['Tmem119','batch'])

In [None]:
adata.write(r'f0_adata_concat_pca.h5ad')

# Harmony and Yao_subclass



In [None]:
import scanpy.external as sce
sce.pp.harmony_integrate(adata, 'batch', max_iter_harmony = 10)


In [None]:
sc.pp.pca(adata)
sc.pp.neighbors(adata, use_rep='X_pca_harmony')
sc.tl.umap(adata,random_state=9,min_dist=0.3,init_pos='X_pca_harmony')

In [None]:
sc.pl.umap(adata, color=['Tmem119','Spp1','batch'])


In [None]:
sc.tl.leiden(adata,resolution=1, key_added='res1')
sc.tl.leiden(adata,resolution=1.5, key_added='res1p5')
sc.tl.leiden(adata,resolution=2, key_added='res2')
sc.tl.leiden(adata,resolution=3, key_added='res3')

# sc.tl.paga(adata, groups='res3')

In [None]:
sc.tl.paga(adata, groups='res3')

In [None]:
sc.pl.paga(adata, threshold=1,fontoutline=1,color=['Tmem119','Aqp4','Olig2'])

In [None]:
sc.tl.umap(adata, init_pos='paga', min_dist=0.2)

In [None]:
sc.pl.umap(adata, color=['Tmem119','Aqp4','Olig2','Slc17a7','Sst','Vtn'])

In [None]:
sc.pl.umap(adata, color=['res3','res2','res1p5','res1'])

# Cell type Prediction

In [None]:
adata_yao=sc.read(r'X:\SC\2022NatureYaoCTXHPF\subsampled_to_500cellsPerCLuster.h5ad')

In [None]:
genes_common=np.intersect1d(adata.var.index.tolist(),adata_yao.var.index.tolist())

In [None]:
genes_common=np.intersect1d(adata.var.index.tolist(),adata_yao.var.index.tolist())
from sklearn.ensemble import RandomForestClassifier
X_train = adata_yao[:, genes_common].X  # Feature matrix for adata
y_train = adata_yao.obs['subclass_label']  # Labels for adata

X_test = adata[:, genes_common].X  # Feature matrix for adata2

clf = RandomForestClassifier(random_state=42)
clf.fit(X_train, y_train)

# Predict labels for adata
adata.obs['predicted_label_2'] = clf.predict(X_test)

In [None]:
sc.pl.umap(adata, color=['res1p5','res3','predicted_label'],legend_loc='on data')

In [None]:
adata.write(r'X:\GV1Backup\AM\data\240719AnalysisDAM_TERM\f0_adata_concat_harmpaga_predict_muspx.h5ad')

In [None]:
unique_batches = adata.obs.batch.cat.categories

In [None]:
new_spatial = np.zeros(np.shape(adata.obsm['X_spatial']))
#additions = np.array([[0, 0], [70000, 0], [140000,0], [210000, 0], [280000, 0], [350000, 0],[420000, 0],[490000, 0],[560000,0],[630000,0],[700000,0],[770000,0],[840000,0]])
additions = np.array([[0, 0], [70000, 0], [140000,0], [210000, 0], [0, 60000], [70000, 60000],[140000, 60000],[210000, 60000],[0, 120000], [70000, 120000],[140000, 120000],[210000, 120000]])

addition_ctr = 0
for batch in unique_batches:
    indices = np.where(adata.obs.batch.values == batch)[0]
    new_spatial[indices] = adata.obsm['X_spatial'][indices] + additions[addition_ctr]
    addition_ctr += 1

In [None]:
adata.obsm['X_multi_spatial'] = new_spatial

In [None]:
def plot_gene_expression(gene, adata,
                         bad_cells=None, vmax=None, vmax_perc=None, vmin=None, vmin_perc=None,
                         name=None, bad_fovs=None, transpose=-1, flipx=-1, flipy=-1, key='X_spatial'):
    """Scatter-plot one gene's values at each cell's coordinates on the current figure.

    Parameters
    ----------
    gene : str
        Entry of ``adata.var_names`` whose values colour the points.
    adata : AnnData
        Object providing ``X`` and the coordinates in ``obsm[key]``.
    bad_cells, bad_fovs :
        Accepted for backward compatibility; currently unused.
    vmax, vmin : float, optional
        Explicit colour limits; they take precedence over the percentile options.
    vmax_perc, vmin_perc : float, optional
        Percentiles (0-100) of the gene's values used as colour limits.
        Any limit not set through the options above defaults to 30 (upper) / 0 (lower).
    name : str, optional
        Label used in the colour-bar title; defaults to ``gene``.
    transpose : {1, -1}
        ``-1`` swaps the two coordinate columns before plotting, ``1`` keeps them.
    flipx, flipy : {1, -1}
        Multipliers for coordinate columns 0 and 1, applied before the optional swap.
    key : str
        ``obsm`` entry holding the coordinates.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was drawn on (``plt.gcf()``).
    """
    values = adata[:, gene].X
    # np.array() of a scipy sparse column yields a 0-d object array, so densify first.
    if hasattr(values, 'toarray'):
        values = values.toarray()
    scores = np.asarray(values).ravel()

    coords = np.asarray(adata.obsm[key]) * [flipx, flipy]
    x, y = coords[:, ::transpose].T

    # Defaults first, so every combination of options leaves both limits defined
    # (only giving vmax or vmax_perc used to leave the lower limit unset).
    vmax_, vmin_ = 30, 0
    if vmax_perc is not None:
        vmax_ = np.percentile(scores, vmax_perc)
    if vmin_perc is not None:
        vmin_ = np.percentile(scores, vmin_perc)
    if vmax is not None:
        vmax_ = vmax
    if vmin is not None:
        vmin_ = vmin

    plt.scatter(x, y, c=scores, s=1, cmap='coolwarm', vmax=vmax_, vmin=vmin_)
    cb = plt.colorbar(shrink=0.5)
    if name is None:
        name = gene
    # White colour-bar text; presumably meant for a dark figure background.
    cb.ax.set_title(f'{name}\ncounts', color='white', fontsize='x-large', fontweight='heavy')
    cb.ax.yaxis.set_tick_params(color='black')
    cb.outline.set_edgecolor('w')
    plt.setp(plt.getp(cb.ax.axes, 'yticklabels'), color='white', fontsize='x-large', fontweight='heavy')
    plt.axis('off')
    plt.axis('equal')
    plt.tight_layout()
    # Return the figure actually drawn on instead of an undefined/global `fig`.
    return plt.gcf()

# 20-colour categorical palette; cluster c is drawn with cmap[int(c) % len(cmap)].
cmap = ["#e6194B", "#3cb44b", "#ffe119", "#4363d8", "#f58231", "#911eb4", "#42d4f4", "#f032e6", "#bfef45",
        "#fabed4", "#469990", "#dcbeff", "#9A6324", "#fffac8", "#800000", "#aaffc3", "#808000", "#ffd8b1",
        "#000075", "#a9a9a9"]
def plot_cluster_scdata(scdata, cmap, clusters=(1, 2), transpose=1, flipx=1, flipy=1, tag='cluster',
                        key='X_spatial', background_color='grey', background_size=2, cluster_size=4):
    """Draw every cell in a background colour and overlay the cells of selected clusters.

    Parameters
    ----------
    scdata : AnnData
        Object with coordinates in ``obsm[key]`` and cluster labels in ``obs[tag]``.
    cmap : sequence of colours
        Cluster ``c`` is drawn with ``cmap[int(c) % len(cmap)]``, so labels must be
        integer-like.
    clusters : iterable
        Cluster labels to highlight; each is compared to ``obs[tag]`` as ``str``.
    transpose : {1, -1}
        ``-1`` swaps the two coordinate columns, ``1`` keeps them.
    flipx, flipy : {1, -1}
        Multipliers for coordinate columns 0 and 1, applied before the optional swap.
    tag : str
        ``obs`` column holding the cluster labels.
    key : str
        ``obsm`` entry holding the coordinates.
    background_color, background_size, cluster_size :
        Styling of the background cells and of the highlighted cells.

    Returns
    -------
    matplotlib.figure.Figure
        The current figure, i.e. the one drawn on.
    """
    import matplotlib.pyplot as plt

    coords = np.asarray(scdata.obsm[key]) * [flipx, flipy]
    x, y = coords[:, ::transpose].T
    # All cells first, so the highlighted clusters are drawn on top.
    plt.scatter(x, y, c=background_color, s=background_size, marker='.')
    for cluster in clusters:
        label = str(cluster)
        in_cluster = (scdata.obs[tag] == label).to_numpy()
        color = cmap[int(cluster) % len(cmap)]
        plt.scatter(x[in_cluster], y[in_cluster], c=color, s=cluster_size, marker='.', label=label)

    plt.grid(False)
    plt.axis("off")
    plt.axis("equal")
    plt.legend()
    plt.tight_layout()
    # Return the figure actually drawn on rather than relying on a notebook-global `fig`.
    return plt.gcf()

# Dealing with Vizgen grids

In [None]:
sc.tl.pca(adata, svd_solver="arpack",random_state=9,use_highly_variable=True)
sc.pl.pca(adata, color=['Gfap','Hexb','Aqp4','Wipf3','Sst','Syp'])
sc.pp.neighbors(adata, n_neighbors=10, n_pcs=50)
sc.tl.umap(adata,random_state=9,min_dist=0.2,init_pos='X_pca_harmony')
sc.pl.umap(adata, color=['Gfap','Hexb','Aqp4','Wipf3','Sst','Syp'])
sc.tl.leiden(adata,resolution=3, key_added='nres3')

In [None]:
sc.tl.paga(adata, groups='nres3')
sc.pl.paga(adata, color=['Gfap','Hexb','Aqp4','Wipf3','Sst','Syp'])
sc.tl.umap(adata, init_pos='paga', min_dist=0.2)
sc.pl.umap(adata, color=['Gfap','Hexb','Aqp4','Wipf3','Sst','Syp'])

In [None]:
# NOTE(review): presumably a workaround for uns['log1p'] lacking a 'base' entry, which
# rank_genes_groups reads. None matches the default (natural-log) sc.pp.log1p call used
# on this data.
adata.uns['log1p']["base"] = None
sc.tl.rank_genes_groups(adata, 'nres3', method='t-test')
result = adata.uns['rank_genes_groups']
groups = result['names'].dtype.names
# Column suffix per statistic. Using key[:1] mapped both 'pvals' and 'pvals_adj' to '_p',
# so the raw p-values were silently overwritten. '_p' keeps the adjusted p-values it
# already held, and the raw ones are added as '_praw'.
column_suffix = {'names': 'n', 'logfoldchanges': 'l', 'pvals_adj': 'p', 'pvals': 'praw'}
df_scdata = pd.DataFrame({f'{group}_{suffix}': result[key][group]
                          for group in groups
                          for key, suffix in column_suffix.items()})
df_scdata.to_csv('f9_adata_nres3_rgg.csv')

In [None]:
del_cluster=['30','60','61']
adata_del1=adata[~adata.obs['nres3'].isin(del_cluster)]
adata_del1

In [None]:
sc.tl.pca(adata_del1, svd_solver="arpack",random_state=9,use_highly_variable=True)
sc.pp.neighbors(adata_del1, n_neighbors=10, n_pcs=50)
sc.tl.umap(adata_del1,random_state=9,min_dist=0.2)
sc.tl.leiden(adata_del1,resolution=4, key_added='dnres4')
sc.pl.umap(adata_del1, color=['Gfap','Hexb','Aqp4','Wipf3','Sst','Syp'])

In [None]:
# Save one spatial map per dnres4 cluster (that cluster highlighted over all cells).
clusters = adata_del1.obs['dnres4'].unique()
# Built with os.path.join so the path is valid on any OS (same string on Windows).
output_folder = os.path.join('data', '240719AnalysisDAM_TERM', 'f9_dnres4_single_cluster_spxMap')
# savefig does not create missing directories.
os.makedirs(output_folder, exist_ok=True)
for cluster in clusters:
    fig = plt.figure(figsize=(12, 8), facecolor="black")
    plot_cluster_scdata(adata_del1, cmap, clusters=[cluster], transpose=1, flipx=1, flipy=-1, tag='dnres4', key='X_multi_spatial')
    out_file = os.path.join(output_folder, f'f9_cluster_{cluster}_dnres4.jpg')
    plt.savefig(out_file, format='jpg', dpi=300, bbox_inches='tight')
    # Close each figure so memory does not grow with the number of clusters.
    plt.close(fig)

In [None]:
del_cluster=['43','76']
adata_del2=adata_del1[~adata_del1.obs['dnres4'].isin(del_cluster)]
adata_del2

In [None]:
sc.tl.pca(adata_del2, svd_solver="arpack",random_state=9,use_highly_variable=True)
sc.pp.neighbors(adata_del2, n_neighbors=10, n_pcs=50)
sc.tl.umap(adata_del2,random_state=9,min_dist=0.2)
sc.tl.leiden(adata_del2,resolution=5, key_added='ddnres5')
sc.pl.umap(adata_del2, color=['Gfap','Hexb','Aqp4','Wipf3','Sst','Syp'])

In [None]:
# Save one spatial map per ddnres5 cluster (that cluster highlighted over all cells).
clusters = adata_del2.obs['ddnres5'].unique()
# Built with os.path.join so the path is valid on any OS (same string on Windows).
output_folder = os.path.join('data', '240719AnalysisDAM_TERM', 'f9_ddnres5_single_cluster_spxMap')
# savefig does not create missing directories.
os.makedirs(output_folder, exist_ok=True)
for cluster in clusters:
    fig = plt.figure(figsize=(12, 8), facecolor="black")
    plot_cluster_scdata(adata_del2, cmap, clusters=[cluster], transpose=1, flipx=1, flipy=-1, tag='ddnres5', key='X_multi_spatial')
    out_file = os.path.join(output_folder, f'f9_cluster_{cluster}_ddnres5.jpg')
    plt.savefig(out_file, format='jpg', dpi=300, bbox_inches='tight')
    # Close each figure so memory does not grow with the number of clusters.
    plt.close(fig)

In [None]:
# 20-colour categorical palette; cluster c is drawn with cmap[int(c) % len(cmap)].
cmap = ["#e6194B", "#3cb44b", "#ffe119", "#4363d8", "#f58231", "#911eb4", "#42d4f4", "#f032e6", "#bfef45",
        "#fabed4", "#469990", "#dcbeff", "#9A6324", "#fffac8", "#800000", "#aaffc3", "#808000", "#ffd8b1",
        "#000075", "#a9a9a9"]
# NOTE(review): this redefines plot_cluster_scdata (an earlier cell defines it with a grey
# background); calls made after this cell use the lighter background and smaller points.
def plot_cluster_scdata(scdata, cmap, clusters=(1, 2), transpose=1, flipx=1, flipy=1, tag='cluster',
                        key='X_spatial', background_color='#E0E0E0', background_size=4, cluster_size=1):
    """Draw every cell in a background colour and overlay the cells of selected clusters.

    Parameters
    ----------
    scdata : AnnData
        Object with coordinates in ``obsm[key]`` and cluster labels in ``obs[tag]``.
    cmap : sequence of colours
        Cluster ``c`` is drawn with ``cmap[int(c) % len(cmap)]``, so labels must be
        integer-like.
    clusters : iterable
        Cluster labels to highlight; each is compared to ``obs[tag]`` as ``str``.
    transpose : {1, -1}
        ``-1`` swaps the two coordinate columns, ``1`` keeps them.
    flipx, flipy : {1, -1}
        Multipliers for coordinate columns 0 and 1, applied before the optional swap.
    tag : str
        ``obs`` column holding the cluster labels.
    key : str
        ``obsm`` entry holding the coordinates.
    background_color, background_size, cluster_size :
        Styling of the background cells and of the highlighted cells.

    Returns
    -------
    matplotlib.figure.Figure
        The current figure, i.e. the one drawn on.
    """
    import matplotlib.pyplot as plt

    coords = np.asarray(scdata.obsm[key]) * [flipx, flipy]
    x, y = coords[:, ::transpose].T
    # All cells first, so the highlighted clusters are drawn on top.
    plt.scatter(x, y, c=background_color, s=background_size, marker='.')
    for cluster in clusters:
        label = str(cluster)
        in_cluster = (scdata.obs[tag] == label).to_numpy()
        color = cmap[int(cluster) % len(cmap)]
        plt.scatter(x[in_cluster], y[in_cluster], c=color, s=cluster_size, marker='.', label=label)

    plt.grid(False)
    plt.axis("off")
    plt.axis("equal")
    plt.legend()
    plt.tight_layout()
    # Return the figure actually drawn on rather than relying on a notebook-global `fig`.
    return plt.gcf()