In [None]:
import os

import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors

import scanpy as sc
# Number of parallel jobs scanpy may use. The value is hard-coded;
# NOTE(review): presumably sized for the original workstation - adjust to the local machine.
sc.settings.n_jobs = 56
# Default figure parameters (resolution, size, font size, background) for the plots in this notebook.
sc.settings.set_figure_params(dpi=180, dpi_save=300, frameon=False, figsize=(4, 4), fontsize=8, facecolor='white')

In [None]:
import re

def sort_slice_name(slice_names):
    """Sort slice names of the form '<prefix>_slice<number>'.

    Names are ordered first by their prefix (as text) and then by the slice
    number compared numerically, so 'x_slice2' comes before 'x_slice10'.

    Parameters
    ----------
    slice_names : iterable of str
        Slice names; each must contain '<prefix>_slice<digits>'.

    Returns
    -------
    list
        The input names in sorted order (an empty input gives an empty list).

    Raises
    ------
    ValueError
        If a name does not match the expected pattern.
    """
    def sort_key(slice_name):
        m = re.search('(.+)_slice([0-9]+)', slice_name)
        if m is None:
            # Fail with the offending name instead of an opaque
            # "'NoneType' object has no attribute 'group'" error.
            raise ValueError(
                "Slice name %r does not match '<prefix>_slice<number>'" % (slice_name,))
        return (m.group(1), int(m.group(2)))

    # sorted() is stable, so names with equal keys keep their input order.
    return sorted(slice_names, key=sort_key)

def merge_clusters(obs_df, cluster_col, clusters_to_merge):
    """Merge several cluster labels of a categorical column into one, in place.

    Every row whose `cluster_col` value is in `clusters_to_merge` is relabelled
    with the first element of `sorted(clusters_to_merge)`. For string labels
    this ordering is lexicographic, so e.g. '11' sorts before '2'. Categories
    that are left without rows are then removed from the column.

    Parameters
    ----------
    obs_df : pandas.DataFrame
        Table that is modified in place.
    cluster_col : str
        Name of a categorical column of `obs_df`.
    clusters_to_merge : iterable
        Labels to merge into one. An empty collection leaves `obs_df` untouched.

    Returns
    -------
    None
    """
    clusters_to_merge = list(clusters_to_merge)
    if not clusters_to_merge:
        # Nothing to merge; avoid indexing into an empty sorted list.
        return

    merged_cluster_id = sorted(clusters_to_merge)[0]

    # One .loc assignment instead of chained indexing (obs_df[col][mask] = ...):
    # the chained form writes to a temporary object, emits SettingWithCopyWarning
    # and has no effect when pandas copy-on-write is enabled.
    rows_to_relabel = obs_df[cluster_col].isin(clusters_to_merge)
    obs_df.loc[rows_to_relabel, cluster_col] = merged_cluster_id

    obs_df[cluster_col] = obs_df[cluster_col].cat.remove_unused_categories()

In [None]:
# Cell metadata table holding the CCF registration results; the first CSV
# column becomes the index. The per-module cells below read its
# 'ccfx_2', 'ccfy_2' and 'ccfz_2' columns.
ccf_registration_df = pd.read_csv(
    '/home/xingjiepan/data/whole_brain/CCF_registration/20230614/wb_cell_metadata_230614.csv',
    index_col=0,
    low_memory=False,
)

In [None]:
# Names of the level-1 spatial modules; each one has its own cell below.
sm1_names = [
    'SM_MOB outer',
    'SM_MOB inner',
    'SM_CTX',
    'SM_STR',
    'SM_PAL/STR',
    'SM_OLF/HPF',
    'SM_HY',
    'SM_TH',
    'SM_MB',
    'SM_HB',
    'SM_CB',
]

In [None]:
sm1n = 'SM_MOB inner'

# Load the AnnData of this spatial module ('/' in the module name is replaced
# by '-' to form the file name).
adata = sc.read_h5ad(os.path.join('spatial_module_2_adatas', sm1n.replace('/', '-') + '.h5ad'))

# Attach the registered coordinates; the columns are assigned as pandas Series,
# so values are matched to cells by index label.
adata.obs['ccfx'] = ccf_registration_df['ccfx_2']
adata.obs['ccfy'] = ccf_registration_df['ccfy_2']
adata.obs['ccfz'] = ccf_registration_df['ccfz_2']

sc.pl.umap(adata, color='leiden', legend_loc='on data', title=sm1n)
# Only slices whose id starts with 'co2' are drawn in the spatial panels.
adata_display = adata[adata.obs['slice_id'].str.startswith('co2')]

# Color the cells: one color per leiden category, read from
# adata.uns['leiden_colors']; cells matching no category stay black.
color_vec = np.array(['#000000'] * adata.shape[0])
for i, cluster in enumerate(adata.obs['leiden'].cat.categories):
    color_vec[adata.obs['leiden'] == cluster] = adata.uns['leiden_colors'][i]
# Stored once after the loop rather than on every iteration.
adata.obs['color_leiden'] = color_vec

# Spatial plots
slices_to_show = sort_slice_name(np.unique(adata_display.obs['slice_id']))

n_cols = 5
# Row count follows n_cols; at least one row so the grid is always valid.
n_rows = max(1, int(np.ceil(len(slices_to_show) / n_cols)))

# squeeze=False keeps `axes` two-dimensional even when there is a single row.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), dpi=150,
                         squeeze=False)

# axes.ravel() walks the grid row by row, i.e. panel index = row * n_cols + col.
for panel_idx, ax in enumerate(axes.ravel()):
    # Hide axis decorations on every panel, including unused trailing ones.
    ax.set_axis_off()
    if panel_idx >= len(slices_to_show):
        continue
    adata_slice = adata_display[adata_display.obs['slice_id'] == slices_to_show[panel_idx]]

    # Both coordinates are negated, which flips the panel horizontally and vertically.
    ax.scatter(-adata_slice.obs['ccfz'], -adata_slice.obs['ccfy'],
               c=adata_slice.obs['color_leiden'],
               edgecolor='none', s=4)
    ax.set_aspect('equal')
    ax.set_title(slices_to_show[panel_idx])

# No merge is applied for this module; the call below is commented out.
#merge_clusters(adata.obs, 'leiden', ['0', '2', '3'])
# Save under the module name with a '_manual_merged' suffix.
adata.write_h5ad(os.path.join('spatial_module_2_adatas',
                              sm1n.replace('/', '-') + '_manual_merged.h5ad'))

In [None]:
sm1n = 'SM_MOB outer'

# Load the AnnData of this spatial module ('/' in the module name is replaced
# by '-' to form the file name).
adata = sc.read_h5ad(os.path.join('spatial_module_2_adatas', sm1n.replace('/', '-') + '.h5ad'))

# Attach the registered coordinates; the columns are assigned as pandas Series,
# so values are matched to cells by index label.
adata.obs['ccfx'] = ccf_registration_df['ccfx_2']
adata.obs['ccfy'] = ccf_registration_df['ccfy_2']
adata.obs['ccfz'] = ccf_registration_df['ccfz_2']

sc.pl.umap(adata, color='leiden', legend_loc='on data', title=sm1n)
# Only slices whose id starts with 'co2' are drawn in the spatial panels.
adata_display = adata[adata.obs['slice_id'].str.startswith('co2')]

# Color the cells: one color per leiden category, read from
# adata.uns['leiden_colors']; cells matching no category stay black.
color_vec = np.array(['#000000'] * adata.shape[0])
for i, cluster in enumerate(adata.obs['leiden'].cat.categories):
    color_vec[adata.obs['leiden'] == cluster] = adata.uns['leiden_colors'][i]
# Stored once after the loop rather than on every iteration.
adata.obs['color_leiden'] = color_vec

# Spatial plots
slices_to_show = sort_slice_name(np.unique(adata_display.obs['slice_id']))

n_cols = 5
# Row count follows n_cols; at least one row so the grid is always valid.
n_rows = max(1, int(np.ceil(len(slices_to_show) / n_cols)))

# squeeze=False keeps `axes` two-dimensional even when there is a single row.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), dpi=150,
                         squeeze=False)

# axes.ravel() walks the grid row by row, i.e. panel index = row * n_cols + col.
for panel_idx, ax in enumerate(axes.ravel()):
    # Hide axis decorations on every panel, including unused trailing ones.
    ax.set_axis_off()
    if panel_idx >= len(slices_to_show):
        continue
    adata_slice = adata_display[adata_display.obs['slice_id'] == slices_to_show[panel_idx]]

    # Both coordinates are negated, which flips the panel horizontally and vertically.
    ax.scatter(-adata_slice.obs['ccfz'], -adata_slice.obs['ccfy'],
               c=adata_slice.obs['color_leiden'],
               edgecolor='none', s=4)
    ax.set_aspect('equal')
    ax.set_title(slices_to_show[panel_idx])

# No merge is applied for this module; the call below is commented out.
#merge_clusters(adata.obs, 'leiden', ['0', '3'])
# Save under the module name with a '_manual_merged' suffix.
adata.write_h5ad(os.path.join('spatial_module_2_adatas',
                              sm1n.replace('/', '-') + '_manual_merged.h5ad'))

In [None]:
sm1n = 'SM_PAL/STR'

# Load the AnnData of this spatial module ('/' in the module name is replaced
# by '-' to form the file name).
adata = sc.read_h5ad(os.path.join('spatial_module_2_adatas', sm1n.replace('/', '-') + '.h5ad'))

# Attach the registered coordinates; the columns are assigned as pandas Series,
# so values are matched to cells by index label.
adata.obs['ccfx'] = ccf_registration_df['ccfx_2']
adata.obs['ccfy'] = ccf_registration_df['ccfy_2']
adata.obs['ccfz'] = ccf_registration_df['ccfz_2']

sc.pl.umap(adata, color='leiden', legend_loc='on data', title=sm1n)
# Only slices whose id starts with 'co2' are drawn in the spatial panels.
adata_display = adata[adata.obs['slice_id'].str.startswith('co2')]

# Color the cells: one color per leiden category, read from
# adata.uns['leiden_colors']; cells matching no category stay black.
color_vec = np.array(['#000000'] * adata.shape[0])
for i, cluster in enumerate(adata.obs['leiden'].cat.categories):
    color_vec[adata.obs['leiden'] == cluster] = adata.uns['leiden_colors'][i]
# Stored once after the loop rather than on every iteration.
adata.obs['color_leiden'] = color_vec

# Spatial plots
slices_to_show = sort_slice_name(np.unique(adata_display.obs['slice_id']))

n_cols = 5
# Row count follows n_cols; at least one row so the grid is always valid.
n_rows = max(1, int(np.ceil(len(slices_to_show) / n_cols)))

# squeeze=False keeps `axes` two-dimensional even when there is a single row.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), dpi=150,
                         squeeze=False)

# axes.ravel() walks the grid row by row, i.e. panel index = row * n_cols + col.
for panel_idx, ax in enumerate(axes.ravel()):
    # Hide axis decorations on every panel, including unused trailing ones.
    ax.set_axis_off()
    if panel_idx >= len(slices_to_show):
        continue
    adata_slice = adata_display[adata_display.obs['slice_id'] == slices_to_show[panel_idx]]

    # Both coordinates are negated, which flips the panel horizontally and vertically.
    ax.scatter(-adata_slice.obs['ccfz'], -adata_slice.obs['ccfy'],
               c=adata_slice.obs['color_leiden'],
               edgecolor='none', s=4)
    ax.set_aspect('equal')
    ax.set_title(slices_to_show[panel_idx])

# Merge the selected groups of leiden clusters (done after the plots above,
# so the figures show the clusters before merging).
merge_clusters(adata.obs, 'leiden', ['2', '11'])
merge_clusters(adata.obs, 'leiden', ['0', '9'])
merge_clusters(adata.obs, 'leiden', ['4', '8'])

# Save under the module name with a '_manual_merged' suffix.
adata.write_h5ad(os.path.join('spatial_module_2_adatas',
                              sm1n.replace('/', '-') + '_manual_merged.h5ad'))

In [None]:
sm1n = 'SM_STR'

# Load the AnnData of this spatial module ('/' in the module name is replaced
# by '-' to form the file name).
adata = sc.read_h5ad(os.path.join('spatial_module_2_adatas', sm1n.replace('/', '-') + '.h5ad'))

# Attach the registered coordinates; the columns are assigned as pandas Series,
# so values are matched to cells by index label.
adata.obs['ccfx'] = ccf_registration_df['ccfx_2']
adata.obs['ccfy'] = ccf_registration_df['ccfy_2']
adata.obs['ccfz'] = ccf_registration_df['ccfz_2']

sc.pl.umap(adata, color='leiden', legend_loc='on data', title=sm1n)
# Only slices whose id starts with 'co2' are drawn in the spatial panels.
adata_display = adata[adata.obs['slice_id'].str.startswith('co2')]

# Color the cells: one color per leiden category, read from
# adata.uns['leiden_colors']; cells matching no category stay black.
color_vec = np.array(['#000000'] * adata.shape[0])
for i, cluster in enumerate(adata.obs['leiden'].cat.categories):
    color_vec[adata.obs['leiden'] == cluster] = adata.uns['leiden_colors'][i]
# Stored once after the loop rather than on every iteration.
adata.obs['color_leiden'] = color_vec

# Spatial plots
slices_to_show = sort_slice_name(np.unique(adata_display.obs['slice_id']))

n_cols = 5
# Row count follows n_cols; at least one row so the grid is always valid.
n_rows = max(1, int(np.ceil(len(slices_to_show) / n_cols)))

# squeeze=False keeps `axes` two-dimensional even when there is a single row.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), dpi=150,
                         squeeze=False)

# axes.ravel() walks the grid row by row, i.e. panel index = row * n_cols + col.
for panel_idx, ax in enumerate(axes.ravel()):
    # Hide axis decorations on every panel, including unused trailing ones.
    ax.set_axis_off()
    if panel_idx >= len(slices_to_show):
        continue
    adata_slice = adata_display[adata_display.obs['slice_id'] == slices_to_show[panel_idx]]

    # Both coordinates are negated, which flips the panel horizontally and vertically.
    ax.scatter(-adata_slice.obs['ccfz'], -adata_slice.obs['ccfy'],
               c=adata_slice.obs['color_leiden'],
               edgecolor='none', s=4)
    ax.set_aspect('equal')
    ax.set_title(slices_to_show[panel_idx])

# Merge the selected groups of leiden clusters (done after the plots above,
# so the figures show the clusters before merging).
merge_clusters(adata.obs, 'leiden', ['2', '3'])
merge_clusters(adata.obs, 'leiden', ['5', '9'])
merge_clusters(adata.obs, 'leiden', ['7', '11'])

# Save under the module name with a '_manual_merged' suffix.
adata.write_h5ad(os.path.join('spatial_module_2_adatas',
                              sm1n.replace('/', '-') + '_manual_merged.h5ad'))

In [None]:
sm1n = 'SM_CTX'

# Load the AnnData of this spatial module ('/' in the module name is replaced
# by '-' to form the file name).
adata = sc.read_h5ad(os.path.join('spatial_module_2_adatas', sm1n.replace('/', '-') + '.h5ad'))

# Attach the registered coordinates; the columns are assigned as pandas Series,
# so values are matched to cells by index label.
adata.obs['ccfx'] = ccf_registration_df['ccfx_2']
adata.obs['ccfy'] = ccf_registration_df['ccfy_2']
adata.obs['ccfz'] = ccf_registration_df['ccfz_2']

sc.pl.umap(adata, color='leiden', legend_loc='on data', title=sm1n)
# Only slices whose id starts with 'co2' are drawn in the spatial panels.
adata_display = adata[adata.obs['slice_id'].str.startswith('co2')]

# Color the cells: one color per leiden category, read from
# adata.uns['leiden_colors']; cells matching no category stay black.
color_vec = np.array(['#000000'] * adata.shape[0])
for i, cluster in enumerate(adata.obs['leiden'].cat.categories):
    color_vec[adata.obs['leiden'] == cluster] = adata.uns['leiden_colors'][i]
# Stored once after the loop rather than on every iteration.
adata.obs['color_leiden'] = color_vec

# Spatial plots
slices_to_show = sort_slice_name(np.unique(adata_display.obs['slice_id']))

n_cols = 5
# Row count follows n_cols; at least one row so the grid is always valid.
n_rows = max(1, int(np.ceil(len(slices_to_show) / n_cols)))

# squeeze=False keeps `axes` two-dimensional even when there is a single row.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), dpi=150,
                         squeeze=False)

# axes.ravel() walks the grid row by row, i.e. panel index = row * n_cols + col.
for panel_idx, ax in enumerate(axes.ravel()):
    # Hide axis decorations on every panel, including unused trailing ones.
    ax.set_axis_off()
    if panel_idx >= len(slices_to_show):
        continue
    adata_slice = adata_display[adata_display.obs['slice_id'] == slices_to_show[panel_idx]]

    # Both coordinates are negated, which flips the panel horizontally and vertically.
    ax.scatter(-adata_slice.obs['ccfz'], -adata_slice.obs['ccfy'],
               c=adata_slice.obs['color_leiden'],
               edgecolor='none', s=4)
    ax.set_aspect('equal')
    ax.set_title(slices_to_show[panel_idx])

# Merge the selected group of leiden clusters (done after the plots above,
# so the figures show the clusters before merging).
merge_clusters(adata.obs, 'leiden', ['1', '4', '6', '8'])

# Save under the module name with a '_manual_merged' suffix.
adata.write_h5ad(os.path.join('spatial_module_2_adatas',
                              sm1n.replace('/', '-') + '_manual_merged.h5ad'))

In [None]:
sm1n = 'SM_OLF/HPF'

# Load the AnnData of this spatial module ('/' in the module name is replaced
# by '-' to form the file name).
adata = sc.read_h5ad(os.path.join('spatial_module_2_adatas', sm1n.replace('/', '-') + '.h5ad'))

# Attach the registered coordinates; the columns are assigned as pandas Series,
# so values are matched to cells by index label.
adata.obs['ccfx'] = ccf_registration_df['ccfx_2']
adata.obs['ccfy'] = ccf_registration_df['ccfy_2']
adata.obs['ccfz'] = ccf_registration_df['ccfz_2']

sc.pl.umap(adata, color='leiden', legend_loc='on data', title=sm1n)
# Only slices whose id starts with 'co2' are drawn in the spatial panels.
adata_display = adata[adata.obs['slice_id'].str.startswith('co2')]

# Color the cells: one color per leiden category, read from
# adata.uns['leiden_colors']; cells matching no category stay black.
color_vec = np.array(['#000000'] * adata.shape[0])
for i, cluster in enumerate(adata.obs['leiden'].cat.categories):
    color_vec[adata.obs['leiden'] == cluster] = adata.uns['leiden_colors'][i]
# Stored once after the loop rather than on every iteration.
adata.obs['color_leiden'] = color_vec

# Spatial plots
slices_to_show = sort_slice_name(np.unique(adata_display.obs['slice_id']))

n_cols = 5
# Row count follows n_cols; at least one row so the grid is always valid.
n_rows = max(1, int(np.ceil(len(slices_to_show) / n_cols)))

# squeeze=False keeps `axes` two-dimensional even when there is a single row.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), dpi=150,
                         squeeze=False)

# axes.ravel() walks the grid row by row, i.e. panel index = row * n_cols + col.
for panel_idx, ax in enumerate(axes.ravel()):
    # Hide axis decorations on every panel, including unused trailing ones.
    ax.set_axis_off()
    if panel_idx >= len(slices_to_show):
        continue
    adata_slice = adata_display[adata_display.obs['slice_id'] == slices_to_show[panel_idx]]

    # Both coordinates are negated, which flips the panel horizontally and vertically.
    ax.scatter(-adata_slice.obs['ccfz'], -adata_slice.obs['ccfy'],
               c=adata_slice.obs['color_leiden'],
               edgecolor='none', s=4)
    ax.set_aspect('equal')
    ax.set_title(slices_to_show[panel_idx])

# Merge the selected groups of leiden clusters (done after the plots above,
# so the figures show the clusters before merging).
merge_clusters(adata.obs, 'leiden', ['5', '16'])
merge_clusters(adata.obs, 'leiden', ['3', '12', '13', '14', '24', '31', '33'])

# Save under the module name with a '_manual_merged' suffix.
adata.write_h5ad(os.path.join('spatial_module_2_adatas',
                              sm1n.replace('/', '-') + '_manual_merged.h5ad'))

In [None]:
sm1n = 'SM_HY'

# Load the AnnData of this spatial module ('/' in the module name is replaced
# by '-' to form the file name).
adata = sc.read_h5ad(os.path.join('spatial_module_2_adatas', sm1n.replace('/', '-') + '.h5ad'))

# Attach the registered coordinates; the columns are assigned as pandas Series,
# so values are matched to cells by index label.
adata.obs['ccfx'] = ccf_registration_df['ccfx_2']
adata.obs['ccfy'] = ccf_registration_df['ccfy_2']
adata.obs['ccfz'] = ccf_registration_df['ccfz_2']

sc.pl.umap(adata, color='leiden', legend_loc='on data', title=sm1n)
# Only slices whose id starts with 'co2' are drawn in the spatial panels.
adata_display = adata[adata.obs['slice_id'].str.startswith('co2')]

# Color the cells: one color per leiden category, read from
# adata.uns['leiden_colors']; cells matching no category stay black.
color_vec = np.array(['#000000'] * adata.shape[0])
for i, cluster in enumerate(adata.obs['leiden'].cat.categories):
    color_vec[adata.obs['leiden'] == cluster] = adata.uns['leiden_colors'][i]
# Stored once after the loop rather than on every iteration.
adata.obs['color_leiden'] = color_vec

# Spatial plots
slices_to_show = sort_slice_name(np.unique(adata_display.obs['slice_id']))

n_cols = 5
# Row count follows n_cols; at least one row so the grid is always valid.
n_rows = max(1, int(np.ceil(len(slices_to_show) / n_cols)))

# squeeze=False keeps `axes` two-dimensional even when there is a single row.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), dpi=150,
                         squeeze=False)

# axes.ravel() walks the grid row by row, i.e. panel index = row * n_cols + col.
for panel_idx, ax in enumerate(axes.ravel()):
    # Hide axis decorations on every panel, including unused trailing ones.
    ax.set_axis_off()
    if panel_idx >= len(slices_to_show):
        continue
    adata_slice = adata_display[adata_display.obs['slice_id'] == slices_to_show[panel_idx]]

    # Both coordinates are negated, which flips the panel horizontally and vertically.
    ax.scatter(-adata_slice.obs['ccfz'], -adata_slice.obs['ccfy'],
               c=adata_slice.obs['color_leiden'],
               edgecolor='none', s=4)
    ax.set_aspect('equal')
    ax.set_title(slices_to_show[panel_idx])

# Merge the selected groups of leiden clusters (done after the plots above,
# so the figures show the clusters before merging).
merge_clusters(adata.obs, 'leiden', ['1', '17'])
merge_clusters(adata.obs, 'leiden', ['3', '15'])
merge_clusters(adata.obs, 'leiden', ['4', '18'])

# Save under the module name with a '_manual_merged' suffix.
adata.write_h5ad(os.path.join('spatial_module_2_adatas',
                              sm1n.replace('/', '-') + '_manual_merged.h5ad'))

In [None]:
sm1n = 'SM_TH'

# Load the AnnData of this spatial module ('/' in the module name is replaced
# by '-' to form the file name).
adata = sc.read_h5ad(os.path.join('spatial_module_2_adatas', sm1n.replace('/', '-') + '.h5ad'))

# Attach the registered coordinates; the columns are assigned as pandas Series,
# so values are matched to cells by index label.
adata.obs['ccfx'] = ccf_registration_df['ccfx_2']
adata.obs['ccfy'] = ccf_registration_df['ccfy_2']
adata.obs['ccfz'] = ccf_registration_df['ccfz_2']

sc.pl.umap(adata, color='leiden', legend_loc='on data', title=sm1n)
# Only slices whose id starts with 'co2' are drawn in the spatial panels.
adata_display = adata[adata.obs['slice_id'].str.startswith('co2')]

# Color the cells: one color per leiden category, read from
# adata.uns['leiden_colors']; cells matching no category stay black.
color_vec = np.array(['#000000'] * adata.shape[0])
for i, cluster in enumerate(adata.obs['leiden'].cat.categories):
    color_vec[adata.obs['leiden'] == cluster] = adata.uns['leiden_colors'][i]
# Stored once after the loop rather than on every iteration.
adata.obs['color_leiden'] = color_vec

# Spatial plots
slices_to_show = sort_slice_name(np.unique(adata_display.obs['slice_id']))

n_cols = 5
# Row count follows n_cols; at least one row so the grid is always valid.
n_rows = max(1, int(np.ceil(len(slices_to_show) / n_cols)))

# squeeze=False keeps `axes` two-dimensional even when there is a single row.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), dpi=150,
                         squeeze=False)

# axes.ravel() walks the grid row by row, i.e. panel index = row * n_cols + col.
for panel_idx, ax in enumerate(axes.ravel()):
    # Hide axis decorations on every panel, including unused trailing ones.
    ax.set_axis_off()
    if panel_idx >= len(slices_to_show):
        continue
    adata_slice = adata_display[adata_display.obs['slice_id'] == slices_to_show[panel_idx]]

    # Both coordinates are negated, which flips the panel horizontally and vertically.
    ax.scatter(-adata_slice.obs['ccfz'], -adata_slice.obs['ccfy'],
               c=adata_slice.obs['color_leiden'],
               edgecolor='none', s=4)
    ax.set_aspect('equal')
    ax.set_title(slices_to_show[panel_idx])

# Merge the selected groups of leiden clusters (done after the plots above,
# so the figures show the clusters before merging).
merge_clusters(adata.obs, 'leiden', ['0', '16'])
merge_clusters(adata.obs, 'leiden', ['1', '3', '18'])

# Save under the module name with a '_manual_merged' suffix.
adata.write_h5ad(os.path.join('spatial_module_2_adatas',
                              sm1n.replace('/', '-') + '_manual_merged.h5ad'))

In [None]:
sm1n = 'SM_MB'

# Load the AnnData of this spatial module ('/' in the module name is replaced
# by '-' to form the file name).
adata = sc.read_h5ad(os.path.join('spatial_module_2_adatas', sm1n.replace('/', '-') + '.h5ad'))

# Attach the registered coordinates; the columns are assigned as pandas Series,
# so values are matched to cells by index label.
adata.obs['ccfx'] = ccf_registration_df['ccfx_2']
adata.obs['ccfy'] = ccf_registration_df['ccfy_2']
adata.obs['ccfz'] = ccf_registration_df['ccfz_2']

sc.pl.umap(adata, color='leiden', legend_loc='on data', title=sm1n)
# Only slices whose id starts with 'co2' are drawn in the spatial panels.
adata_display = adata[adata.obs['slice_id'].str.startswith('co2')]

# Color the cells: one color per leiden category, read from
# adata.uns['leiden_colors']; cells matching no category stay black.
color_vec = np.array(['#000000'] * adata.shape[0])
for i, cluster in enumerate(adata.obs['leiden'].cat.categories):
    color_vec[adata.obs['leiden'] == cluster] = adata.uns['leiden_colors'][i]
# Stored once after the loop rather than on every iteration.
adata.obs['color_leiden'] = color_vec

# Spatial plots
slices_to_show = sort_slice_name(np.unique(adata_display.obs['slice_id']))

n_cols = 5
# Row count follows n_cols; at least one row so the grid is always valid.
n_rows = max(1, int(np.ceil(len(slices_to_show) / n_cols)))

# squeeze=False keeps `axes` two-dimensional even when there is a single row.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), dpi=150,
                         squeeze=False)

# axes.ravel() walks the grid row by row, i.e. panel index = row * n_cols + col.
for panel_idx, ax in enumerate(axes.ravel()):
    # Hide axis decorations on every panel, including unused trailing ones.
    ax.set_axis_off()
    if panel_idx >= len(slices_to_show):
        continue
    adata_slice = adata_display[adata_display.obs['slice_id'] == slices_to_show[panel_idx]]

    # Both coordinates are negated, which flips the panel horizontally and vertically.
    ax.scatter(-adata_slice.obs['ccfz'], -adata_slice.obs['ccfy'],
               c=adata_slice.obs['color_leiden'],
               edgecolor='none', s=4)
    ax.set_aspect('equal')
    ax.set_title(slices_to_show[panel_idx])

# Merge the selected group of leiden clusters (done after the plots above,
# so the figures show the clusters before merging).
merge_clusters(adata.obs, 'leiden', ['1', '5'])

# Save under the module name with a '_manual_merged' suffix.
adata.write_h5ad(os.path.join('spatial_module_2_adatas',
                              sm1n.replace('/', '-') + '_manual_merged.h5ad'))

In [None]:
sm1n = 'SM_HB'

# Load the AnnData of this spatial module ('/' in the module name is replaced
# by '-' to form the file name).
adata = sc.read_h5ad(os.path.join('spatial_module_2_adatas', sm1n.replace('/', '-') + '.h5ad'))

# Attach the registered coordinates; the columns are assigned as pandas Series,
# so values are matched to cells by index label.
adata.obs['ccfx'] = ccf_registration_df['ccfx_2']
adata.obs['ccfy'] = ccf_registration_df['ccfy_2']
adata.obs['ccfz'] = ccf_registration_df['ccfz_2']

# Unlike the other module cells, this one passes an explicit 'tab20' palette.
sc.pl.umap(adata, color='leiden', legend_loc='on data', title=sm1n, palette='tab20')
# Only slices whose id starts with 'co2' are drawn in the spatial panels.
adata_display = adata[adata.obs['slice_id'].str.startswith('co2')]

# Color the cells: one color per leiden category, read from
# adata.uns['leiden_colors']; cells matching no category stay black.
color_vec = np.array(['#000000'] * adata.shape[0])
for i, cluster in enumerate(adata.obs['leiden'].cat.categories):
    color_vec[adata.obs['leiden'] == cluster] = adata.uns['leiden_colors'][i]
# Stored once after the loop rather than on every iteration.
adata.obs['color_leiden'] = color_vec

# Spatial plots
slices_to_show = sort_slice_name(np.unique(adata_display.obs['slice_id']))

n_cols = 5
# Row count follows n_cols; at least one row so the grid is always valid.
n_rows = max(1, int(np.ceil(len(slices_to_show) / n_cols)))

# squeeze=False keeps `axes` two-dimensional even when there is a single row.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), dpi=150,
                         squeeze=False)

# axes.ravel() walks the grid row by row, i.e. panel index = row * n_cols + col.
for panel_idx, ax in enumerate(axes.ravel()):
    # Hide axis decorations on every panel, including unused trailing ones.
    ax.set_axis_off()
    if panel_idx >= len(slices_to_show):
        continue
    adata_slice = adata_display[adata_display.obs['slice_id'] == slices_to_show[panel_idx]]

    # Both coordinates are negated, which flips the panel horizontally and vertically.
    ax.scatter(-adata_slice.obs['ccfz'], -adata_slice.obs['ccfy'],
               c=adata_slice.obs['color_leiden'],
               edgecolor='none', s=4)
    ax.set_aspect('equal')
    ax.set_title(slices_to_show[panel_idx])

# Merge the selected groups of leiden clusters (done after the plots above,
# so the figures show the clusters before merging).
merge_clusters(adata.obs, 'leiden', ['2', '21'])
merge_clusters(adata.obs, 'leiden', ['8', '9', '19'])
merge_clusters(adata.obs, 'leiden', ['17', '18'])

# Save under the module name with a '_manual_merged' suffix.
adata.write_h5ad(os.path.join('spatial_module_2_adatas',
                              sm1n.replace('/', '-') + '_manual_merged.h5ad'))

In [None]:
sm1n = 'SM_CB'

# Load the AnnData of this spatial module ('/' in the module name is replaced
# by '-' to form the file name).
adata = sc.read_h5ad(os.path.join('spatial_module_2_adatas', sm1n.replace('/', '-') + '.h5ad'))

# Attach the registered coordinates; the columns are assigned as pandas Series,
# so values are matched to cells by index label.
adata.obs['ccfx'] = ccf_registration_df['ccfx_2']
adata.obs['ccfy'] = ccf_registration_df['ccfy_2']
adata.obs['ccfz'] = ccf_registration_df['ccfz_2']

sc.pl.umap(adata, color='leiden', legend_loc='on data', title=sm1n)
# Only slices whose id starts with 'co2' are drawn in the spatial panels.
adata_display = adata[adata.obs['slice_id'].str.startswith('co2')]

# Color the cells: one color per leiden category, read from
# adata.uns['leiden_colors']; cells matching no category stay black.
color_vec = np.array(['#000000'] * adata.shape[0])
for i, cluster in enumerate(adata.obs['leiden'].cat.categories):
    color_vec[adata.obs['leiden'] == cluster] = adata.uns['leiden_colors'][i]
# Stored once after the loop rather than on every iteration.
adata.obs['color_leiden'] = color_vec

# Spatial plots
slices_to_show = sort_slice_name(np.unique(adata_display.obs['slice_id']))

n_cols = 5
# Row count follows n_cols; at least one row so the grid is always valid.
n_rows = max(1, int(np.ceil(len(slices_to_show) / n_cols)))

# squeeze=False keeps `axes` two-dimensional even when there is a single row.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), dpi=150,
                         squeeze=False)

# axes.ravel() walks the grid row by row, i.e. panel index = row * n_cols + col.
for panel_idx, ax in enumerate(axes.ravel()):
    # Hide axis decorations on every panel, including unused trailing ones.
    ax.set_axis_off()
    if panel_idx >= len(slices_to_show):
        continue
    adata_slice = adata_display[adata_display.obs['slice_id'] == slices_to_show[panel_idx]]

    # Both coordinates are negated, which flips the panel horizontally and vertically.
    ax.scatter(-adata_slice.obs['ccfz'], -adata_slice.obs['ccfy'],
               c=adata_slice.obs['color_leiden'],
               edgecolor='none', s=4)
    ax.set_aspect('equal')
    ax.set_title(slices_to_show[panel_idx])

# Two-step merge (done after the plots above, so the figures show the clusters
# before merging): first merge the listed clusters, then merge every label
# other than '0' that is still present into a second group.
merge_clusters(adata.obs, 'leiden', ['0', '1', '3', '5', '7', '8', '10'])
merge_clusters(adata.obs, 'leiden', [c for c in np.unique(adata.obs['leiden']) if c != '0'])
# Save under the module name with a '_manual_merged' suffix.
adata.write_h5ad(os.path.join('spatial_module_2_adatas',
                              sm1n.replace('/', '-') + '_manual_merged.h5ad'))