In [None]:
'''
Goal: Pick out markers for the vessel size experiment
'''

In [None]:
import os

# plt and sns are used by the styling and plotting cells below.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scanpy.external as sce
import seaborn as sns


# --- Identifiers and paths shared by the cells below ---
adata_name = 'venous_ec'                        # file-name prefix of the AnnData files
figures = "data/figures/figures"                # directory that receives the saved figures
data = "data/single_cell_files/scanpy_files"    # directory the .h5ad files are read from / written to
size = 15                                       # marker size handed to the UMAP plots

os.makedirs(figures, exist_ok=True)

# --- Plot styling ---
# The call order is kept as-is: scanpy, seaborn and the final update all write
# to matplotlib's rcParams.
sc.set_figure_params(
    dpi_save=300,
    fontsize=10,
    figsize=(2, 2),
)
sc.settings.figdir = figures

tick_rc = {'xtick.bottom': True, 'ytick.left': True}
sns.set_style('white', rc=tick_rc)
plt.rcParams.update({"font.family": "Arial"})

In [None]:
# Read the cell-typed AnnData file; the bare name on the last line makes the
# notebook display the object's summary as the cell output.
celltyped_path = f'{data}/{adata_name}_celltyped_no_cc.gz.h5ad'
adata_all = sc.read(celltyped_path)
adata_all

In [None]:
# Remove two layers from the loaded object, then write it out as the `_share` file.
# The membership guard keeps the cell re-runnable: a bare `del` raises KeyError
# when the cell is executed a second time against the same in-memory object.
for layer_name in ('log1p_cc_regress', 'cp10k'):
    if layer_name in adata_all.layers:
        del adata_all.layers[layer_name]
adata_all.write(f'{data}/{adata_name}_share.gz.h5ad', compression='gzip')

In [None]:
# for ct in ['Arterial EC','Venous EC']:
#     ct_adata = adata[adata.obs['Cell Subtype_no_cc'] == ct]
#     compare_obs_values_within_groups_to_excel(ct_adata, 'Treatment', group_column='Vessel size category',output_prefix=f"{figures}/{ct}_hyperoxia_degs_by_size")
#     compare_obs_values_within_groups_to_excel(ct_adata, 'Vessel size category', output_prefix=f"{figures}/{ct}_vessel_size_degs")
#     degs = pd.read_excel(f"{figures}/{ct}_hyperoxia_degs_by_size.xlsx", sheet_name=None, index_col=0, header=0)
#     hyperoxia_score = pd.DataFrame(index=degs['small'].index)
#     hyperoxia_score['small'] = degs['small']['scores']
#     hyperoxia_score['medium'] = degs['medium']['scores']
#     hyperoxia_score['large'] = degs['large']['scores']
#     hyperoxia_score = normalize_dataframe(hyperoxia_score)
#     hyperoxia_score['large_small_difference'] = hyperoxia_score['large'] - hyperoxia_score['small']
#     hyperoxia_score = hyperoxia_score.sort_values('large_small_difference')
#     with pd.ExcelWriter(
#         f"{figures}/{ct}_hyperoxia_degs_by_size.xlsx",
#         mode="a",
#         engine="openpyxl",
#         if_sheet_exists="replace",) as writer:
#         hyperoxia_score.to_excel(writer, sheet_name='normalized_scores_together')

In [None]:
# Re-apply the plotting defaults for this cell (2x2 inch figures, Arial, visible
# tick marks). The call order is unchanged because every call writes to rcParams.
sc.set_figure_params(
    dpi_save=300,
    fontsize=10,
    figsize=(2, 2),
)
sc.settings.figdir = figures
tick_rc = {'xtick.bottom': True, 'ytick.left': True}
sns.set_style('white', rc=tick_rc)
plt.rcParams.update({"font.family": "Arial"})
def custom_dotplot(
    adata, genes, x_obs, y_obs,
    x_order=None, y_order=None,
    min_expr=0.1, cmap='RdBu_r',
    dot_max_size=300, pad=0.5,
    scale_by_gene=False,
    figsize=None,
    save=None, dpi=300,
    show_gridlines=False
):
    """
    Dot plot of gene expression across two obs groupings.

    Every gene gets one block of columns (one column per ``x_obs`` group, in
    ``x_order``); rows are the ``y_obs`` groups. Dot colour is the mean
    expression of the gene within the (x group, y group) cells, dot area is the
    fraction of those cells whose expression exceeds ``min_expr``.

    Parameters
    ----------
    adata : AnnData
        Source of the obs columns and gene values (read through
        ``sc.get.obs_df`` with ``layer=None``). It is only read; its obs
        columns are not converted or otherwise modified.
    genes : sequence of str
        Genes to plot, left to right.
    x_obs, y_obs : str
        Names of the obs columns used for the columns within each gene block
        and for the rows, respectively.
    x_order, y_order : sequence, optional
        Group order along each axis. Defaults to the categories of the
        corresponding column.
    min_expr : float
        A cell counts as expressing when its value is strictly greater than this.
    cmap : str
        Matplotlib colormap used for the dots and the colorbar.
    dot_max_size : float
        Scatter marker size (``s``) of a group in which 100% of cells express.
    pad : float
        Margin added beyond the first and last column/row when setting the
        axis limits.
    scale_by_gene : bool
        If True, min-max scale the group means within each gene before
        colouring; otherwise the raw means are used.
    figsize : tuple, optional
        Figure size in inches. Defaults to 0.6 per column/row plus 2.
    save : str, optional
        If given, the figure is written to this path; otherwise it is shown.
    dpi : int
        Resolution used when saving.
    show_gridlines : bool
        If True, draw a dotted horizontal line through every row.

    Raises
    ------
    ValueError
        If none of the requested (x group, y group) combinations contains a
        cell, or if ``genes`` is empty.
    """
    genes = list(genes)

    # Build a private frame instead of casting adata.obs columns in place: the
    # caller frequently passes a view, and writing to a view's obs forces a copy.
    df = sc.get.obs_df(adata, keys=[x_obs, y_obs] + genes, layer=None)

    if x_order is None:
        x_order = df[x_obs].astype('category').cat.categories.tolist()
    if y_order is None:
        y_order = df[y_obs].astype('category').cat.categories.tolist()
    x_order = list(x_order)
    y_order = list(y_order)
    n_x = len(x_order)
    n_cols = len(genes) * n_x

    # Select the cells of every (x, y) group once; the selection does not depend on the gene.
    cells_by_group = {}
    for x_val in x_order:
        in_x = df[x_obs] == x_val
        for y_val in y_order:
            cells = df.loc[in_x & (df[y_obs] == y_val), genes]
            if cells.shape[0] > 0:
                cells_by_group[(x_val, y_val)] = cells

    # One record per (gene, x group, y group) that has at least one cell.
    records = []
    for gene_idx, gene in enumerate(genes):
        for x_idx, x_val in enumerate(x_order):
            for y_idx, y_val in enumerate(y_order):
                cells = cells_by_group.get((x_val, y_val))
                if cells is None:
                    continue
                expr = cells[gene]
                records.append({
                    "gene": gene,
                    "x_group": x_val,
                    "y_group": y_val,
                    # Columns are laid out gene-major: all x groups of gene 0, then gene 1, ...
                    "x_pos": gene_idx * n_x + x_idx,
                    "y_pos": y_idx,
                    "avg_expr": expr.mean(),
                    "prop_expr": (expr > min_expr).mean(),
                })

    if not records:
        raise ValueError(
            f"No cells found for any combination of {x_obs!r} in {x_order} "
            f"and {y_obs!r} in {y_order}; nothing to plot."
        )
    plot_df = pd.DataFrame(records)

    # Scale avg_expr within gene
    if scale_by_gene:
        plot_df["scaled_expr"] = plot_df.groupby("gene")["avg_expr"].transform(
            # the small epsilon avoids 0/0 when a gene has the same mean in every group
            lambda x: (x - x.min()) / (x.max() - x.min() + 1e-8)
        )
        color_col = "scaled_expr"
    else:
        color_col = "avg_expr"
    color_min = plot_df[color_col].min()
    color_max = plot_df[color_col].max()

    if figsize is None:
        figsize = (n_cols * 0.6 + 2, len(y_order) * 0.6 + 2)

    fig, ax = plt.subplots(figsize=figsize)

    # All dots in one call; colour limits are shared across the whole plot.
    ax.scatter(
        plot_df["x_pos"].to_numpy(),
        plot_df["y_pos"].to_numpy(),
        s=plot_df["prop_expr"].to_numpy() * dot_max_size,
        c=plot_df[color_col].to_numpy(),
        cmap=cmap,
        vmin=color_min,
        vmax=color_max,
        edgecolor='black',
        linewidth=0.5
    )

    # Optional gridlines
    if show_gridlines:
        for y in range(len(y_order)):
            ax.axhline(y, color='lightgray', linestyle=':', linewidth=0.5)

    # Vertical dashed lines between gene groups
    for i in range(1, len(genes)):
        ax.axvline(x=i * n_x - 0.5, color='gray', linestyle='--', linewidth=1)

    # Bottom axis: the x group of every column; left axis: the y groups.
    ax.set_xticks(range(n_cols))
    ax.set_xticklabels(x_order * len(genes), rotation=0, ha='center')
    ax.set_yticks(range(len(y_order)))
    ax.set_yticklabels(y_order)
    ax.set_xlim(-pad, n_cols - 1 + pad)
    ax.set_ylim(-pad, len(y_order) - 1 + pad)
    ax.invert_yaxis()
    ax.set_xlabel('')
    ax.set_ylabel(y_obs)

    # Top axis: one gene name centred over its block of columns. The centre is
    # computed from the block width so it is correct for any number of x groups.
    ax_gene = ax.secondary_xaxis('top')
    ax_gene.set_xticks([g * n_x + (n_x - 1) / 2 for g in range(len(genes))])
    ax_gene.set_xticklabels(genes, rotation=0, ha='center')

    # Colorbar
    norm = plt.Normalize(color_min, color_max)
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, location='right', pad=0.02)
    cbar.set_label('Mean expression\nin group')

    # Move axis to left to make room for both legends
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.72, box.height])

    # Dot size legend: up to four percentages spanning the observed range,
    # with the range ends rounded outwards to multiples of 5.
    min_pct = int(np.floor(plot_df["prop_expr"].min() * 100 / 5) * 5)
    max_pct = int(np.ceil(plot_df["prop_expr"].max() * 100 / 5) * 5)
    ticks = np.linspace(min_pct, max_pct, 4)
    if max_pct - min_pct >= 15:
        # wide enough range: snap the intermediate ticks to multiples of 5 too
        ticks = np.round(ticks / 5) * 5
    ticks = np.clip(ticks, 0, 100).astype(int)
    # Drop repeated values (e.g. when every group has the same percentage)
    # so the legend does not list the same entry several times.
    ticks = list(dict.fromkeys(ticks.tolist()))

    # Empty scatters serve only as legend handles of the right size.
    handles = [
        ax.scatter([], [], s=(p / 100) * dot_max_size, c='gray', edgecolors='black')
        for p in ticks
    ]
    labels = [f"{p}%" for p in ticks]

    fig.legend(
        handles,
        labels,
        title="Pct. Expressing",
        loc='center right',
        bbox_to_anchor=(1.40, 0.6),
        frameon=True
    )

    plt.tight_layout()

    if save:
        plt.savefig(save, dpi=dpi, bbox_inches='tight')
        print(f"Saved to {save}")
    else:
        plt.show()
# NOTE(review): `adata` and `normalize_dataframe` are not defined in any visible
# cell of this notebook -- confirm they exist in the kernel before running.
# .copy() yields an independent AnnData so the obs assignments below do not
# write into a view of `adata`.
adata_ven = adata[adata.obs['Cell Subtype_no_cc'] == 'Venous EC'].copy()
adata_ven.obs['proliferation_score'] = normalize_dataframe(adata.obs[['proliferation_score']])

# Per-cell table for the histogram, without capillary cells.
df = sc.get.obs_df(adata_ven, ['proliferation_score', 'Vessel size category', 'Treatment'])
df.rename(columns={'proliferation_score': 'Proliferation score'}, inplace=True)
df = df.loc[df['Vessel size category'] != 'capillary']

# One panel per treatment; each vessel size category is normalised separately
# (common_norm=False), so the curves are proportions within each category.
fig, axs = plt.subplots(1, 2, figsize=(3, 2), sharey=True)
axs = axs.ravel()
for i, treat in enumerate(['Normoxia', 'Hyperoxia']):
    ax = sns.histplot(data=df[df['Treatment'] == treat], x="Proliferation score",
                      hue='Vessel size category', hue_order=['small', 'medium', 'large'],
                      palette=adata.uns['Vessel size category_colors'][1:],
                      stat='probability',
                      element='poly', fill=False, common_norm=False, bins=10, ax=axs[i])
    if i == 1:
        ax.get_legend().remove()
    else:
        # Restyle seaborn's own legend instead of calling ax.legend([...]):
        # ax.legend(labels) pairs labels with the axes' artists in draw order,
        # which does not follow hue_order, so the colours would be mislabelled.
        sns.move_legend(ax, "upper right", frameon=False, fontsize=8, title='')
    ax.set_title(treat)
    ax.set_ylabel('Proportion')
    ax.set_xlabel('')
    ax.set_xticks([0, 0.75])
    ax.set_xticklabels(['0', '0.75'])
fig.supxlabel('Proliferation score', y=0.15, x=0.52)
fig.tight_layout()
fig.savefig(f'{figures}/histplot_venous_ec_treat_proliferation_score.png', dpi=300, bbox_inches='tight')
plt.close(fig)

# Abbreviate the treatment to its first letter ('Normoxia' -> 'N', 'Hyperoxia' -> 'H')
# for the dot plot's column labels.
adata_ven.obs['Treatment'] = adata_ven.obs['Treatment'].str[0]
custom_dotplot(adata_ven[adata_ven.obs['Vessel size category'] != 'capillary'].copy(),
               ['Mki67', 'Top2a', 'Birc5', 'Hmgb2', 'Cenpf'],
               scale_by_gene=True,
               x_obs='Treatment',
               y_obs='Vessel size category',
               x_order=['N', 'H'],
               cmap='Reds',
               dot_max_size=100,
               figsize=(4, 2),
               save=f'{figures}/dotplot_proliferation_vec.png')
# custom_dotplot(adata[adata.obs['Cell Subtype_no_cc']=='Arterial EC'],['Cxcl12','Cxcr4'],scale_by_gene=True,  x_obs='Treatment', y_obs='Vessel size category',x_order=['Normoxia','Hyperoxia'],cmap='Reds',dot_max_size=300,save=f'{figures}/dotplot_cxcl12_aec.png',figsize=(4,3))

In [None]:
# Histogram of the venous EC proliferation score per treatment, using the
# one-letter treatment labels ('N', 'H').
# NOTE(review): this cell relies on `adata_ven` existing with 'Treatment' already
# reduced to its first letter, and on `adata` / `normalize_dataframe` being
# defined in the kernel -- confirm before running on a fresh kernel.
adata_ven.obs['proliferation_score'] = normalize_dataframe(adata.obs[['proliferation_score']])
df = sc.get.obs_df(adata_ven, ['proliferation_score', 'Vessel size category', 'Treatment'])
df.rename(columns={'proliferation_score': 'Proliferation score'}, inplace=True)
df = df.loc[df['Vessel size category'] != 'capillary']

fig, axs = plt.subplots(1, 2, figsize=(3, 2), sharey=True)
axs = axs.ravel()
for i, treat in enumerate(['N', 'H']):
    ax = sns.histplot(data=df[df['Treatment'] == treat], x="Proliferation score",
                      hue='Vessel size category', hue_order=['small', 'medium', 'large'],
                      palette=adata.uns['Vessel size category_colors'][1:],
                      stat='probability',
                      element='poly', fill=False, common_norm=False, bins=10, ax=axs[i])
    if i == 1:
        ax.get_legend().remove()
    else:
        # Restyle seaborn's own legend instead of calling ax.legend([...]):
        # ax.legend(labels) pairs labels with the axes' artists in draw order,
        # which does not follow hue_order, so the colours would be mislabelled.
        sns.move_legend(ax, "upper right", frameon=False, fontsize=8, title='')
    ax.set_title(treat)
    ax.set_ylabel('Proportion')
    ax.set_xlabel('')
    ax.set_xticks([0, 0.75])
    ax.set_xticklabels(['0', '0.75'])
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
fig.supxlabel('Proliferation score', y=0.15, x=0.52)
fig.tight_layout()
fig.savefig(f'{figures}/histplot_venous_ec_treat_proliferation_score.png', dpi=300, bbox_inches='tight')
plt.close(fig)

In [None]:
# Plot styling for this cell: 1x1 inch panels, Arial, visible tick marks.
sc.set_figure_params(
    dpi_save=300,
    fontsize=10,
    figsize=(1, 1),
)
sc.settings.figdir = figures
tick_rc = {'xtick.bottom': True, 'ytick.left': True}
sns.set_style('white', rc=tick_rc)
size = 15
plt.rcParams.update({"font.family": "Arial"})

# UMAPs of the three genes, side by side and without colorbars.
sc.pl.umap(
    adata,
    color=['Cxcl12', 'Cxcr4', 'Ackr3'],
    colorbar_loc=None,
    wspace=0.2,
    hspace=0.25,
    ncols=3,
    s=size/4,
    frameon=False,
    cmap='viridis',
    save='cxcl12_signaling.png',
)

# Dot plots by vessel size category, capillary cells excluded:
# Cxcl12/Cxcr4 in arterial EC, Ackr3 in venous EC.
not_capillary = adata.obs['Vessel size category'] != 'capillary'
is_arterial = adata.obs['Cell Subtype_no_cc'] == 'Arterial EC'
arterial_dots = sc.pl.DotPlot(
    adata[(is_arterial & not_capillary)],
    ['Cxcl12', 'Cxcr4'],
    standard_scale='var',
    groupby='Vessel size category',
)
arterial_dots.style(cmap='viridis').savefig(f'{figures}/dotplot_art_cxcl12_signal.png', dpi=300)

is_venous = adata.obs['Cell Subtype_no_cc'] == 'Venous EC'
venous_dots = sc.pl.DotPlot(
    adata[(is_venous & not_capillary)],
    ['Ackr3'],
    standard_scale='var',
    groupby='Vessel size category',
)
venous_dots.style(cmap='viridis').savefig(f'{figures}/dotplot_ven_cxcl12_signal.png', dpi=300)


In [None]:

# Esr2 trend along pseudotime, restyled on the axes palantir drew on.
# NOTE(review): `palantir` is not imported in any visible cell -- confirm it is
# available in the kernel.
palantir.plot.plot_gene_trends(adata, ['Esr2'])
ax = plt.gca()
ax.set_xlim(0, 1.01)

# Drop the first plotted line and recolour the one that follows it with the
# first 'Cell Subtype_no_cc' colour.
trend_lines = ax.get_lines()
trend_lines[0].remove()
trend_lines[1].set_color(adata.uns['Cell Subtype_no_cc_colors'][0])

ax.get_legend().remove()
ax.set(xlabel='Pseudotime', ylabel='Gene expression')
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)

# Resize the figure that owns these axes to 2x2 inches and save it.
fig = ax.figure
fig.set_size_inches(2, 2)
fig.savefig(f'{figures}/histplot_esr2.png', bbox_inches='tight', dpi=300)

In [None]:
# Esr2 on the UMAP, then as a dot plot across vessel size categories in
# arterial EC (capillary cells excluded).
sc.set_figure_params(
    dpi_save=300,
    fontsize=10,
    figsize=(2, 2),
)
sc.pl.umap(
    adata,
    color='Esr2',
    frameon=False,
    s=size,
    cmap='viridis',
    save='_Esr2.png',
)

is_arterial = adata.obs['Cell Subtype_no_cc'] == 'Arterial EC'
not_capillary = adata.obs['Vessel size category'] != 'capillary'
esr2_dots = sc.pl.DotPlot(
    adata[(is_arterial & not_capillary)],
    ['Esr2'],
    groupby=['Vessel size category'],
)
esr2_dots.style(cmap='viridis').savefig(f'{figures}/dotplot_arterial_ec_esr2.png', dpi=300)

In [None]:
# adata.obs['Cxcl12_ge'] = np.array(adata[:,['Cxcl12']].X.todense()).flatten()
# df = correlate_genes_with_pseudotime(adata,layer='log1p',method='spearman',pseudotime='Cxcl12_ge').dropna(how='all')
# df.head(100)

# sc.pl.umap(adata,color=df.head(50).index,hspace=0.3)

# degs = pd.read_excel(f"{figures}/Venous EC_vessel_size_degs.xlsx", sheet_name='medium v large', index_col=0, header=0)
# sc.pl.umap(adata,color=degs.head(50).index,hspace=0.3)
# degs = pd.read_excel(f"{figures}/Arterial EC_vessel_size_degs.xlsx", sheet_name='medium v large', index_col=0, header=0)
# sc.pl.umap(adata,color=degs.head(50).index,hspace=0.3)

In [None]:
# NOTE(review): this cell uses names that no visible cell defines or imports:
# `adata`, `top_n_genes`, `arterial_large_genes`, `arterial_small_genes`,
# `venous_large_genes`, `venous_small_genes`, `correlate_genes_with_pseudotime`
# and `venn3`. Confirm they exist in the kernel before running.


def _save_venn3(gene_lists, title, out_path):
    """Draw a three-set Venn diagram (Arterial / Venous / VSM) and save it.

    gene_lists: three iterables of gene names, in the order arterial, venous, VSM.
    title: figure title.
    out_path: file the figure is written to (300 dpi, tight bounding box).
    """
    fig, ax = plt.subplots()
    venn = venn3(
        [set(genes) for genes in gene_lists],
        set_labels=('Arterial', 'Venous', 'VSM'),
        set_colors=('#4A90E2', '#E35D6A', '#48C774'),
        alpha=0.7,
        ax=ax,
    )
    # Entries can be missing for regions that are not drawn, hence the guard.
    for text in list(venn.set_labels) + list(venn.subset_labels):
        if text is not None:
            text.set_fontsize(12)
    ax.set_title(title)
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close(fig)


adata_mural = sc.read(f"{data}/venous_ec_mural_velocity.gz.h5ad")

print(adata_mural.uns['Cell Subtype_no_cc_colors'])

sc.pl.umap(adata_mural, color='palantir_pseudotime', title='Pseudotime', cmap='viridis', frameon=False, size=size, save='mural_pseudotime.png')
sc.pl.umap(adata_mural, color='Cell Subtype_no_cc', title='', legend_loc='on data', frameon=False, size=size, save='mural_celltype.png')

# Correlate genes with pseudotime within vascular smooth muscle cells and take
# both ends of the returned table.
# NOTE(review): presumably the table is sorted from most positive to most
# negative correlation -- confirm against correlate_genes_with_pseudotime.
df = correlate_genes_with_pseudotime(adata_mural[adata_mural.obs['Cell Subtype_no_cc'] == 'Vascular smooth muscle'], layer='log1p', method='pearson', pseudotime='palantir_pseudotime')
df = df.dropna(how='all')
mural_large_genes = df.head(top_n_genes).index.tolist()
mural_small_genes = df.tail(top_n_genes).index.tolist()[::-1]

# Venn diagrams of the top genes shared between arterial EC, venous EC and VSM.
# The titles follow `top_n_genes` and name the end of the table each list was
# taken from: head -> positively correlated, tail -> negatively correlated.
_save_venn3(
    [arterial_large_genes, venous_large_genes, mural_large_genes],
    f"Top {top_n_genes} genes positively correlated with pseudotime",
    f'{figures}/venn_diagram_large_mural.png',
)
_save_venn3(
    [arterial_small_genes, venous_small_genes, mural_small_genes],
    f"Top {top_n_genes} genes negatively correlated with pseudotime",
    f'{figures}/venn_diagram_small_mural.png',
)

# Genes shared by all three cell types, shown side by side on the endothelial
# UMAP (left) and the mural UMAP (right); one file per gene.
overlap_large_all = set(arterial_large_genes) & set(venous_large_genes) & set(mural_large_genes)
overlap_small_all = set(arterial_small_genes) & set(venous_small_genes) & set(mural_small_genes)
for gene in sorted(overlap_large_all | overlap_small_all):
    fig, axs = plt.subplots(1, 2, figsize=(3, 1.5))
    axs = axs.ravel()
    for panel_ax, panel_adata in zip(axs, (adata, adata_mural)):
        sc.pl.umap(panel_adata,
                   color=gene,
                   size=size/2,
                   ax=panel_ax,
                   frameon=False,
                   cmap='viridis',
                   show=False
                   )
    fig.savefig(f'{figures}/umap_size_mural_endo_overlap_{gene}.png', dpi=300, bbox_inches='tight')
    # close this specific figure so open figures do not accumulate over the loop
    plt.close(fig)