In [None]:
import scanpy as sc
import os
import numpy as np
import pandas as pd
import sys
import seaborn as sns
import matplotlib.pyplot as plt
import pylab as pl
import glob
import diopy

In [None]:
output_folder = r'./'
data_folder = '../../data'

# glob() returns paths in filesystem-dependent order; sort them so the order in which
# samples are concatenated (and therefore the obs order of the merged objects) is
# reproducible across machines. Paths containing 'dc3000' are excluded.
spatial_datas = sorted(
    path
    for path in glob.glob(os.path.join(data_folder, 'integration', '*', 'spatial_data.h5ad'))
    if 'dc3000' not in path
)
seq_datas = sorted(
    path
    for path in glob.glob(os.path.join(data_folder, 'integration', '*', 'seq_data.h5ad'))
    if 'dc3000' not in path
)

# Fail early with a clear message instead of an opaque error from concatenating nothing.
if not spatial_datas:
    raise FileNotFoundError(f'No spatial_data.h5ad files found under {data_folder}/integration')
if not seq_datas:
    raise FileNotFoundError(f'No seq_data.h5ad files found under {data_folder}/integration')

spatial_adatas = [sc.read(path) for path in spatial_datas]
ad_sp_only = sc.concat(spatial_adatas)

seq_adatas = [sc.read(path) for path in seq_datas]
ad_sc_only = sc.concat(seq_adatas)

In [None]:
# Every figure below is written beneath this folder (relative to the working directory).
figure_output_folder = '6ac'

In [None]:
# Default scanpy figure settings: 300 dpi on screen, 400 dpi for saved files.
sc.set_figure_params(dpi_save=400, dpi=300)

In [None]:
# Create the figure output folder. exist_ok=True makes re-running this cell a no-op,
# without a bare `except` that would also hide real failures (e.g. permission errors).
os.makedirs(figure_output_folder, exist_ok=True)

In [None]:
# Independent working copies: later edits to ad_sc / ad_sp do not touch the *_only objects.
ad_sc = ad_sc_only.copy()
ad_sp = ad_sp_only.copy()

In [None]:
# NOTE(review): ad_sc is reassigned from scrna further down before it is read again,
# so this label never reaches a figure.
rep1_label = 'rep1'
ad_sc.obs['batch'] = rep1_label

### Figure 6c

In [None]:
# Tag every observation with its source modality before the two objects are concatenated.
for modality_adata, modality_name in ((ad_sc_only, 'seq'), (ad_sp_only, 'spatial')):
    modality_adata.obs['modality'] = modality_name

In [None]:
# Mirror the spatial 'pseudotime' values into a 'dpt_pseudotime' column before
# ad_sp_only is concatenated with the seq data.
spatial_pseudotime = ad_sp_only.obs['pseudotime'].values
ad_sp_only.obs['dpt_pseudotime'] = spatial_pseudotime

In [None]:
# Stack the seq observations first, then the spatial ones, into a single AnnData.
adatas_to_join = [ad_sc_only, ad_sp_only]
ad_sc_sp = sc.concat(adatas_to_join)

### Joint rep1/rep2 embedding with matching cluster colors

In [None]:
# Path of the diopy-exported .h5 object (a single file, despite the variable name).
h5_path_parts = (data_folder, "temp_objects", "AvrRpt2_alone2.h5")
seq_folder = os.path.join(*h5_path_parts)
scrna = diopy.input.read_h5(file=seq_folder)

In [None]:
# Categorical copy of the Seurat cluster labels; its category order is what the
# colour lookup further down iterates over.
seurat_labels = ad_sc_sp.obs['seurat_clusters'].values
ad_sc_sp.obs['Clusters'] = pd.Categorical(seurat_labels)

In [None]:
# 24-colour cluster palette; entry k is used for the k-th 'Clusters' category.
ad_sc_sp.uns['Clusters_colors'] = [
    '#023fa5', '#7d87b9', '#bec1d4', '#d6bcc0', '#bb7784', '#8e063b',
    '#4a6fe3', '#8595e1', '#b5bbe3', '#e6afb9', '#e07b91', '#d33f6a',
    '#11c638', '#8dd593', '#c6dec7', '#ead3c6', '#f0b98d', '#ef9708',
    '#0fcfc0', '#9cded6', '#d5eae7', '#f3e1eb', '#f6c4e1', '#f79cd4',
]

In [None]:
# Map each 'Clusters' category to the palette colour at the same position
# (raises IndexError if there are more categories than colours).
cluster_categories = ad_sc_sp.obs['Clusters'].cat.categories
dic = {
    cluster: ad_sc_sp.uns['Clusters_colors'][position]
    for position, cluster in enumerate(cluster_categories)
}

In [None]:
# NOTE(review): this dict is replaced by an ordered colour array in the following cell.
colors_key = 'SCT_snn_res.1_colors'
scrna.uns[colors_key] = dic

In [None]:
# Re-express the cluster->colour mapping as an array ordered like scrna's own
# 'SCT_snn_res.1' categories (scanpy's `<key>_colors` convention). A category
# missing from `dic` raises KeyError.
cluster_palette = pd.Series(dic)
scrna_categories = scrna.obs['SCT_snn_res.1'].cat.categories.values
scrna.uns['SCT_snn_res.1_colors'] = cluster_palette.loc[scrna_categories].values

In [None]:
sc.set_figure_params(dpi=400, dpi_save=400, figsize=(5, 4))
rep_umap_pdf = os.path.join(figure_output_folder, '6c_rep1_rep2.pdf')

# UMAP of scrna coloured by 'SCT_snn_res.1' using the shared cluster palette.
fig = sc.pl.umap(scrna, color='SCT_snn_res.1', return_fig=True)
fig.tight_layout()
fig.savefig(rep_umap_pdf)
plt.show()
plt.close()

## Figure 6d

In [None]:
# Start the cell-type panels from a fresh copy of the diopy-loaded object,
# replacing the earlier ad_sc copy of ad_sc_only.
ad_sc = scrna.copy()

In [None]:
# Normalise the three capitalised cell-type labels to lower case; other labels are untouched.
celltype_renames = {'Epidermis': 'epidermis', 'Mesophyll': 'mesophyll', 'Vasculature': 'vasculature'}
ad_sc.obs['celltype'] = ad_sc.obs['celltype'].replace(celltype_renames)

In [None]:
# Drop cells labelled 'Unknown'. .copy() materialises the subset so the following cells
# can set .uns on a real AnnData instead of implicitly converting a view (which emits
# an ImplicitModificationWarning).
ad_sc = ad_sc[ad_sc.obs['celltype'] != 'Unknown'].copy()

In [None]:
# Cell-type palette (bright pink, green, gold); scanpy pairs these with the
# 'celltype' categories in category order.
celltype_palette = ['#FF007F', '#008000', '#FFD700']
ad_sc.uns['celltype_colors'] = celltype_palette

In [None]:
sc.set_figure_params(facecolor='black', figsize=(10, 10), dpi=300)
celltype_sc_pdf = os.path.join(figure_output_folder, 'Figure6d_celltype_sc.pdf')
fig = sc.pl.embedding(
    ad_sc,
    basis='umap',
    color=['celltype'],
    vmax=0.1,
    frameon=False,
    size=5,
    return_fig=True,
)

# Recreate the legend on the current axes with black label text, then a white title.
legend = plt.legend()
for legend_text in legend.get_texts():
    legend_text.set_color("black")
plt.title('celltype', {'color': 'white'})
fig.savefig(celltype_sc_pdf)
plt.show()

In [None]:
# Expose the x/y spot coordinates as an embedding so sc.pl.embedding(basis='spatial') can draw them.
ad_sp.obsm['X_spatial'] = ad_sp.obs.loc[:, ['x', 'y']].to_numpy()

In [None]:
# Same bright pink / green / gold palette as the scRNA-seq cell-type panel,
# paired with ad_sp's 'celltype' categories in category order.
celltype_palette_sp = ['#FF007F', '#008000', '#FFD700']
ad_sp.uns['celltype_colors'] = celltype_palette_sp

In [None]:
sc.set_figure_params(facecolor='black', figsize=(10, 10), dpi=300)
ad_sp_9hr_avr = ad_sp[ad_sp.obs.batch == '9hr_avr']
fig = sc.pl.embedding(
    ad_sp_9hr_avr,
    basis='spatial',
    color=['celltype'],
    vmax=0.1,
    frameon=False,
    size=10,
    return_fig=True,
)

# Recreate the legend on the current axes with black label text, then a white title.
legend = plt.legend()
for legend_text in legend.get_texts():
    legend_text.set_color("black")
plt.title('celltype', {'color': 'white'})
fig.savefig(os.path.join(figure_output_folder, 'Figure6d_celltype_sp.pdf'))
plt.show()

## Figure 6e

In [None]:
# One spatial panel per predicted cluster (9hr_avr section only), highlighting the spots
# assigned to that cluster. Files go to <figure_output_folder>/cluster_projections/.
cluster_projection_folder = os.path.join(figure_output_folder, 'cluster_projections')
os.makedirs(cluster_projection_folder, exist_ok=True)

sc.set_figure_params(facecolor='white', figsize=(10, 10), dpi=300)
for cluster_to_map in np.unique(ad_sp.obs['predicted.r_clusters']):
    # 1 for spots predicted as this cluster, 0 elsewhere.
    ad_sp.obs['cluster_map'] = [
        1 if label == cluster_to_map else 0 for label in ad_sp.obs['predicted.r_clusters']
    ]
    fig = sc.pl.embedding(
        ad_sp[ad_sp.obs.batch == '9hr_avr'],
        basis='spatial',
        color=['cluster_map'],
        vmax=1.3,
        vmin=-.4,
        frameon=False,
        size=10,
        return_fig=True,
        cmap='Purples',
    )
    # cluster_map is plotted as a continuous value, so no legend is added here: a bare
    # plt.legend() would only produce an empty legend and a "No artists with labels" warning.
    plt.title(f'Mapped cluster {cluster_to_map}', {'color': 'black'})
    # f-string so each cluster gets its own file (the literal previously lacked the `f`
    # prefix, and every iteration overwrote 'Figure6e_cluster_{cluster_to_map}.pdf').
    fig.savefig(os.path.join(cluster_projection_folder, f'Figure6e_cluster_{cluster_to_map}.pdf'))
    plt.close(fig)

### Figure 6g

In [None]:
# Seq-only observations of the joint object. .copy() turns the view into a real AnnData,
# because the next cells add obs columns and replace the index.
ad_sc = ad_sc_sp[ad_sc_sp.obs.modality == 'seq'].copy()

In [None]:
# Label each seq cell by replicate: obs names containing 'rep2' are rep2, the rest rep1.
is_rep2 = ad_sc.obs.index.str.contains('rep2', regex=False)
ad_sc.obs['batch'] = np.where(is_rep2, 'rep2', 'rep1')

In [None]:
# Canonical rep1 names '<f0>_col_<f1>_rep1_<f2>' built from the first three
# '_'-separated fields of each rep1 obs name, in obs order.
rep1_obs_names = ad_sc[ad_sc.obs['batch'] == 'rep1'].obs.index
renames = ['{0}_col_{1}_rep1_{2}'.format(*name.split('_')[:3]) for name in rep1_obs_names]

In [None]:
# Rebuild the full obs-name list: rep1 cells take the next canonical name from `renames`
# (which holds only rep1 cells, in obs order); every other cell keeps its name.
# `renames` is consumed sequentially. Indexing it by the global position i was only
# correct when every rep1 cell preceded every rep2 cell.
rep1_renames = iter(renames)
new_indices = []
for obs_name, batch in zip(ad_sc.obs.index.values, ad_sc.obs.batch.values):
    if batch == 'rep1':
        new_indices.append(next(rep1_renames))
    else:
        new_indices.append(obs_name)
# Every rep1 rename must have been used exactly once.
assert next(rep1_renames, None) is None, 'more rep1 renames than rep1 cells'

In [None]:
# Apply the rebuilt cell names through AnnData's obs_names setter rather than writing
# obs.index directly, so AnnData itself validates the new names.
ad_sc.obs_names = new_indices

In [None]:
# Copy the UMAP coordinates into plain obs columns so they travel with obs in the later merge.
umap_coords = scrna.obsm['X_umap']
scrna.obs['umap_x'], scrna.obs['umap_y'] = umap_coords[:, 0], umap_coords[:, 1]

In [None]:
# Sample prefix: every '_'-separated field except the last, glued together without separators.
scrna.obs['sample_prune'] = [
    ''.join(fields[:-1])
    for fields in (obs_name.split('_') for obs_name in scrna.obs.index)
]

In [None]:
# Diagnostic UMAP of scrna coloured by sample prefix (displayed inline, not saved to disk).
sc.pl.umap(scrna, color='sample_prune')

In [None]:
# Remove the 'col_' token from the seq cell names so they match scrna's names for the merge;
# set through obs_names so AnnData validates the new names.
ad_sc.obs_names = [obs_name.replace('col_', '') for obs_name in ad_sc.obs_names]

In [None]:
# Remove the 'col_' token from scrna's cell names so they match ad_sc's names for the merge;
# set through obs_names so AnnData validates the new names.
scrna.obs_names = [obs_name.replace('col_', '') for obs_name in scrna.obs_names]

In [None]:
# Attach scrna's per-cell metadata (including umap_x / umap_y) to each seq cell by obs name.
# validate='many_to_one' raises a MergeError if scrna has duplicate obs names, which would
# otherwise silently duplicate rows and change the number of observations.
# Columns present in both frames get pandas' default '_x' (ad_sc) / '_y' (scrna) suffixes.
ad_sc.obs = ad_sc.obs.merge(
    scrna.obs,
    how='left',
    left_index=True,
    right_index=True,
    validate='many_to_one',
)

In [None]:
# Rebuild ad_sc's UMAP embedding from the coordinates carried over by the merge.
ad_sc.obsm['X_umap'] = ad_sc.obs.loc[:, ['umap_x', 'umap_y']].to_numpy()

In [None]:
# Keep dpt_pseudotime only for mesophyll cells (either capitalisation); all other cells get 0.
mesophyll_labels = ('Mesophyll', 'mesophyll')
new_pseudo = [
    pseudo if celltype in mesophyll_labels else 0
    for celltype, pseudo in zip(ad_sc.obs.celltype_x.values, ad_sc.obs.dpt_pseudotime.values)
]

In [None]:
# Store the masked pseudotime (0 for every non-mesophyll cell) as the 'pseudotime' column the UMAPs below are coloured by.
ad_sc.obs['pseudotime'] = new_pseudo

In [None]:
# Pseudotime UMAP over all seq cells, saved to <figure_output_folder>/seq_pseudotime/.
# The folder created must be the one savefig writes to. Previously the code tried to create
# '<figure_output_folder>/figures/seq_pseudotime' (which failed silently because 'figures'
# did not exist) and then saved into 'seq_pseudotime', raising FileNotFoundError on a fresh run.
seq_pseudotime_folder = os.path.join(figure_output_folder, 'seq_pseudotime')
os.makedirs(seq_pseudotime_folder, exist_ok=True)

fig = sc.pl.umap(ad_sc, color='pseudotime', cmap='jet', size=50, return_fig=True, show=False)
fig.tight_layout()
fig.savefig(os.path.join(seq_pseudotime_folder, 'all_batch.pdf'))
plt.close(fig)

In [None]:
# Sample tag per cell: drop the last two '_'-separated fields, glue the rest together,
# then remove any remaining 'col' substring.
tag_list = [
    ''.join(obs_name.split('_')[:-2]).replace('col', '')
    for obs_name in ad_sc.obs.index.values
]


In [None]:
# Per-cell sample tag derived from the obs names; the per-sample pseudotime plots split on it.
ad_sc.obs['sample_names'] = tag_list 

In [None]:
# Per-sample pseudotime UMAPs for rep1 cells, one PDF per sample in
# <figure_output_folder>/seq_pseudotime/.
seq_pseudotime_folder = os.path.join(figure_output_folder, 'seq_pseudotime')
os.makedirs(seq_pseudotime_folder, exist_ok=True)

sc.set_figure_params(dpi=300, figsize=(5, 5))
ad_sc_rep1 = ad_sc[ad_sc.obs.batch == 'rep1']
# Iterate only over samples that actually have rep1 cells. Looping over every sample name
# would pass empty subsets to sc.pl.umap for samples that exist only in rep2.
for sn in np.unique(ad_sc_rep1.obs['sample_names']):
    fig = sc.pl.umap(
        ad_sc_rep1[ad_sc_rep1.obs['sample_names'] == sn],
        color='pseudotime',
        cmap='jet',
        size=50,
        return_fig=True,
        title=sn,
    )
    fig.tight_layout()
    fig.savefig(os.path.join(seq_pseudotime_folder, f'{sn}.pdf'))
    plt.show()
    plt.close(fig)