# Save all the unused functions from optimal_transport notebook

In [None]:
# Custom functions for single-cell fate
def customized_embedding(
    x,
    y,
    vector,
    title=None,
    ax=None,
    order_points=True,
    set_ticks=False,
    col_range=None,
    buffer_pct=0.03,
    point_size=1,
    color_map=None,
    smooth_operator=None,
    set_lim=True,
    vmax=None,
    vmin=None,
    color_bar=False,
    color_bar_label="",
    color_bar_title="",
):
    """
    Plot a vector on an embedding.

    Parameters
    ----------
    x: `np.array`
        x coordinate of the embedding
    y: `np.array`
        y coordinate of the embedding
    vector: `np.array`
        A vector to be plotted; it is cast to float before plotting.
    title: `str`, optional (default: None)
        Axis title. No title is set when None.
    ax: `axis`, optional (default: None)
        An external ax object can be passed here. A new figure/axis is
        created when None.
    order_points: `bool`, optional (default: True)
        If True, draw points in ascending order of their value so that the
        highest values are drawn last.
    set_ticks: `bool`, optional (default: False)
        If False, remove the figure ticks and switch the axis frame off.
    col_range: `tuple`, optional (default: None)
        Pair of percentiles (each within [0,100]). When given, any of
        `vmin`/`vmax` that is None is taken from the corresponding percentile
        of the values instead of their minimum/maximum.
    buffer_pct: `float`, optional (default: 0.03)
        Extra space around the points, as a fraction of the coordinate range.
    point_size: `int`, optional (default: 1)
        Size of the data point
    color_map: {plt.cm.Reds,plt.cm.Blues,...}, optional (default: None)
        Colormap for the scatter. When None, `plt.cm.Reds` darkened by
        `darken_cmap(..., 0.9)` is used.
    smooth_operator: `np.array`, optional (default: None)
        A smooth matrix; when given, `np.dot(smooth_operator, vector)` is
        plotted instead of the raw vector.
    set_lim: `bool`, optional (default: True)
        Set the plot range (x_limit, and y_limit) automatically.
    vmax: `float`, optional (default: None)
        Maximum color range (saturation). When None it is derived from the
        data (see `col_range`).
    vmin: `float`, optional (default: None)
        Minimum color range. When None it is derived from the data
        (see `col_range`).
    color_bar: `bool`, optional (default: False)
        If True, plot the color bar.
    color_bar_label: `str`, optional (default: "")
        Label drawn along the color bar.
    color_bar_title: `str`, optional (default: "")
        Title drawn above the color bar.

    Returns
    -------
    ax:
        The figure axis
    """

    if color_map is None:
        color_map = darken_cmap(plt.cm.Reds, 0.9)
    if ax is None:
        _, ax = plt.subplots()

    coldat = vector.astype(float)

    if smooth_operator is None:
        coldat = coldat.squeeze()
    else:
        coldat = np.dot(smooth_operator, coldat).squeeze()

    if order_points:
        o = np.argsort(coldat)
    else:
        o = np.arange(len(coldat))

    if vmin is None:
        if col_range is None:
            vmin = np.min(coldat)
        else:
            vmin = np.percentile(coldat, col_range[0])

    if vmax is None:
        if col_range is None:
            vmax = np.max(coldat)
        else:
            vmax = np.percentile(coldat, col_range[1])

    # Degenerate colour range (e.g. both percentiles coincide): fall back to
    # the data maximum as the upper limit.
    if vmax == vmin:
        vmax = coldat.max()

    ax.scatter(
        x[o], y[o], c=coldat[o], s=point_size, cmap=color_map, vmin=vmin, vmax=vmax
    )

    if not set_ticks:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis("off")

    if set_lim:
        # np.ptp() instead of ndarray.ptp(): the method was removed in NumPy 2.0.
        x_pad = np.ptp(x) * buffer_pct
        y_pad = np.ptp(y) * buffer_pct
        ax.set_xlim(x.min() - x_pad, x.max() + x_pad)
        ax.set_ylim(y.min() - y_pad, y.max() + y_pad)

    if title is not None:
        ax.set_title(title)

    if color_bar:
        # Imported lazily: only the colour-bar branch needs it.
        from matplotlib.colors import Normalize as mpl_Normalize

        norm = mpl_Normalize(vmin=vmin, vmax=vmax)
        Clb = plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=color_map), ax=ax)
        Clb.set_label(
            color_bar_label,
            rotation=270,
            labelpad=20,
        )
        Clb.ax.set_title(color_bar_title)
    return ax

def darken_cmap(cmap, scale_factor):
    """
    Generate a darkened copy of a color map for plotting.

    Parameters
    ----------
    cmap:
        Color map to darken. It must expose `N`, be callable with an integer
        index returning a color whose first three entries are R, G, B, and
        provide `from_list(name, colors)`.
    scale_factor: `float`
        Factor applied to the R, G and B channels of every color.

    Returns
    -------
    The color map built by `cmap.from_list` from the scaled colors. The alpha
    channel of every color is set to 1.
    """

    # Start from ones so the alpha column is already 1.
    cdat = np.ones((cmap.N, 4))
    for ii in range(cmap.N):
        cdat[ii, :3] = np.multiply(cmap(ii)[:3], scale_factor)

    # The first argument of from_list is the colormap *name*; pass a string
    # rather than the integer ``cmap.N`` the previous version used.
    dark_name = f"{getattr(cmap, 'name', 'cmap')}_darkened"
    return cmap.from_list(dark_name, cdat)

In [None]:
#Main function
def single_cell_transition(
    adata,
    selected_state_id_list,
    source="transition_map",
    map_backward=True,
    initial_point_size=3,
    color_bar=True,
    savefig=False,
    df=None,
    fig_dir="/fast/AG_Haghverdi/Shashank/fig_cospar/",
):
    """
    Plot, for each selected cell, its transition probabilities on the embedding.

    One panel is drawn per selected cell: the row of the transition map that
    belongs to the cell is normalised by its maximum, painted on the
    ``X_emb`` embedding via `customized_embedding`, and the selected cell is
    marked with a blue star.

    Parameters
    ----------
    adata:
        Annotated data object; ``adata.obsm["X_emb"]`` supplies the 2-D
        embedding (first two columns are used).
    selected_state_id_list: `list`
        Row indices into the (possibly transposed) transition map. IDs
        outside ``range(number of initial cells)`` are dropped.
    source: `str` or matrix, optional (default: "transition_map")
        The transition map itself, or a key into ``adata.uns`` holding it
        (NOTE(review): the string lookup follows the CoSpar convention --
        confirm the key exists in your object). Rows are expected to match
        ``df.index`` and columns ``df.columns``.
    map_backward: `bool`, optional (default: True)
        If True the map is transposed, so selected IDs refer to
        ``df.columns`` and probabilities are placed on ``df.index`` cells.
    initial_point_size: `int`, optional (default: 3)
        Relative size of the star that marks the selected cell.
    color_bar: `bool`, optional (default: True)
        If True, draw a color bar next to each panel.
    savefig: `bool`, optional (default: False)
        If True, save the figure as a PNG inside `fig_dir`.
    df: `pd.DataFrame`, optional (default: None)
        Frame whose index/columns hold the integer cell positions (into the
        embedding) of the map's rows/columns. When None, the notebook-level
        variable ``df`` is looked up at call time.
    fig_dir: `str`, optional
        Directory the figure is written to when `savefig` is True.

    Returns
    -------
    None
    """
    if df is None:
        # Resolve the notebook-level ``df`` when the function is *called*;
        # the former ``df=df`` default froze whatever ``df`` was at
        # definition time and failed if it did not exist yet.
        try:
            df = globals()["df"]
        except KeyError:
            raise ValueError(
                "`df` was not passed and no global `df` is defined"
            ) from None

    fig_width = 5.5
    fig_height = 5
    point_size = 3

    # Use the object that was passed in, not a notebook global.
    embedding = adata.obsm["X_emb"]
    x_emb = embedding[:, 0]
    y_emb = embedding[:, 1]

    if isinstance(source, str):
        Tmap = adata.uns[source]
        source_label = source
    else:
        Tmap = source
        # Keeps the saved file name readable when a raw matrix is passed.
        source_label = "custom_map"
    if hasattr(Tmap, "toarray"):
        # Sparse matrices: densify so row slices can be assigned directly.
        Tmap = Tmap.toarray()
    Tmap = np.asarray(Tmap)

    row_ids = np.array(list(df.index)).astype(int)
    col_ids = np.array(list(df.columns)).astype(int)
    if map_backward:
        cell_id_t1, cell_id_t2 = col_ids, row_ids
        Tmap = Tmap.T
    else:
        cell_id_t1, cell_id_t2 = row_ids, col_ids

    # Keep only IDs that address an existing initial cell.
    selected_state_id_list = np.array(selected_state_id_list)
    full_id_list = np.arange(len(cell_id_t1))
    valid_idx = np.isin(full_id_list, selected_state_id_list)
    if np.sum(valid_idx) < len(selected_state_id_list):
        selected_state_id_list = full_id_list[valid_idx]

    row = len(selected_state_id_list)
    if row == 0:
        # Nothing valid to draw; a zero-height figure would raise.
        return
    col = 1

    fig = plt.figure(figsize=(fig_width * col, fig_height * row))
    for j, target_cell_ID in enumerate(selected_state_id_list):
        ax0 = plt.subplot(row, col, col * j + 1)
        if target_cell_ID >= Tmap.shape[0]:
            # No matching row in the map: leave the panel empty.
            continue

        prob_vec = np.zeros(len(x_emb))
        prob_vec[cell_id_t2] = Tmap[target_cell_ID, :]
        peak = np.max(prob_vec)
        if peak != 0:
            # Avoid 0/0 -> NaN for cells without any transition mass.
            prob_vec = prob_vec / peak

        customized_embedding(
            x_emb,
            y_emb,
            prob_vec,
            point_size=point_size,
            ax=ax0,
            color_bar=color_bar,
            color_bar_label="Probability",
        )
        ax0.plot(
            x_emb[cell_id_t1][target_cell_ID],
            y_emb[cell_id_t1][target_cell_ID],
            "*b",
            markersize=initial_point_size * point_size,
        )
        if map_backward:
            ax0.set_title(f"ID (t2): {target_cell_ID}")
        else:
            ax0.set_title(f"ID (t1): {target_cell_ID}")

    plt.tight_layout()
    if savefig:
        import os

        fig.savefig(
            os.path.join(
                fig_dir,
                f"_single_cell_transition_{source_label}_{map_backward}.png",
            )
        )
    

In [None]:
# Running without custom functions
# Marker sizes used by the manual plotting cell.
point_size = 3
initial_point_size = 3

# Cluster labels plus the first two embedding axes of every cell in ot_adata.
# (Disabled alternative: slice ot_adata[ot_adata.obs['time']==0] instead.)
state_info = ot_adata.obs["clusters"]
x_emb, y_emb = ot_adata.obsm["X_emb"][:, 0], ot_adata.obsm["X_emb"][:, 1]

# Integer cell ids taken from the row / column labels of df.
cell_id_t1 = np.array(list(df.index)).astype(int)
cell_id_t2 = np.array(list(df.columns)).astype(int)

# Subplot grid: one row per selected cell, a single column.
row, col = len(selected_state_id_list), 1



In [None]:
from matplotlib.colors import Normalize as mpl_Normalize
# One panel per selected cell: its row of ``mat_day0_1`` is normalised by its
# maximum, painted on the embedding, and the selected cell is starred.
color_map = plt.cm.Reds
# Size the grid from the selection; a hard-coded 2-row grid fails as soon as
# a third cell is selected.
n_panels = len(selected_state_id_list)
for j, target_cell_ID in enumerate(selected_state_id_list):
    ax0 = plt.subplot(n_panels, 1, j + 1)
    if target_cell_ID >= mat_day0_1.shape[0]:
        # No matching row in the matrix: leave the panel empty.
        continue

    prob_vec = np.zeros(len(x_emb))
    prob_vec[cell_id_t2] = mat_day0_1[target_cell_ID, :]
    peak = np.max(prob_vec)
    if peak != 0:
        # Avoid 0/0 -> NaN when the cell has no transition mass.
        prob_vec = prob_vec / peak
    # NOTE: a second probability vector for one cell type (``prob_vec_2``
    # built from ``cell_id_mpg`` / ``target_mpg``) was sketched here but
    # never enabled.

    coldat = prob_vec.astype(float).squeeze()
    # Ascending order so the highest probabilities are drawn last (on top).
    o = np.argsort(coldat)
    vmin = np.min(coldat)
    vmax = np.max(coldat)
    ax0.scatter(
        x_emb[o],
        y_emb[o],
        c=coldat[o],
        s=5,
        cmap=color_map,
        vmin=vmin,
        vmax=vmax,
        zorder=2,
    )
    # NOTE: an optional grey background layer (s=4.5, alpha=0.5, zorder=1)
    # was tried here and is currently disabled.

    ax0.set_xticks([])
    ax0.set_yticks([])
    ax0.axis("off")
    # np.ptp() instead of ndarray.ptp(): the method was removed in NumPy 2.0.
    x_pad = np.ptp(x_emb) * 0.03
    y_pad = np.ptp(y_emb) * 0.03
    ax0.set_xlim(x_emb.min() - x_pad, x_emb.max() + x_pad)
    ax0.set_ylim(y_emb.min() - y_pad, y_emb.max() + y_pad)

    norm = mpl_Normalize(vmin=vmin, vmax=vmax)
    Clb = plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=color_map), ax=ax0)
    Clb.set_label(
        "Probability",
        rotation=270,
        labelpad=20,
    )
    Clb.ax.set_title("")

    ax0.plot(
        x_emb[cell_id_t1][target_cell_ID],
        y_emb[cell_id_t1][target_cell_ID],
        "*b",
        markersize=initial_point_size * point_size,
    )
    ax0.set_title(f"ID (t2): {target_cell_ID}")
    ax0.figure.set_size_inches(15, 16)
# Flag the cells of the third cluster category (index 2) as a categorical
# True/False column, then colour the UMAP of the time == 1 cells by that flag.
ot_adata.obs["cluster_dummy"] = (
    ot_adata.obs["clusters"] == ot_adata.obs["clusters"].cat.categories[2]
).astype("category")
sc.pl.umap(ot_adata[ot_adata.obs["time"] == 1], color="cluster_dummy")
        
# Inert string literal: it is evaluated and discarded, so nothing inside runs.
# It holds scratch code that draws the same two scatter groups (split by
# ``szt_hi``) in both orders on two axes, one of which is titled
# 'plot red before blue'.
# ``axs`` and ``szt_hi`` are not defined anywhere in the visible code.
'''axs[0, 1].scatter(df.loc[~szt_hi, 'Pcp'], df.loc[~szt_hi, 'Pcp_3day'], c='b', label='SzT=0')
axs[0, 1].scatter(df.loc[szt_hi, 'Pcp'], df.loc[szt_hi, 'Pcp_3day'], c='r', label='SzT=1')
axs[0, 1].legend()

axs[0, 0].set_title('plot red before blue')
axs[0, 0].scatter(df.loc[szt_hi, 'Pcp'], df.loc[szt_hi, 'Pcp_3day'], c='r', label='SzT=1')
axs[0, 0].scatter(df.loc[~szt_hi, 'Pcp'], df.loc[~szt_hi, 'Pcp_3day'], c='b', label='SzT=0')
axs[0, 0].legend()
'''