 <font size="10">**Pseudo-spatial visualisation: Matplotlib; outside functions**</font>
***

<div class="alert alert-info">
    
<b> <h1> ℹ️ <strong> <font size="6" color="black"> Important notebook information </font> </strong> </h1> </b>
    <hr>
    <font size="4" color="black">
        The purpose of this notebook is to use the generated combined scRNA-seq data and masks object to plot the data using the included mask sets <br> <br>
    In this notebook, everything is run outside of functions for easier manipulation</font>

# Import packages

In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D
import matplotlib.dates as mdates
import seaborn as sns
import cv2
import joypy
import math
import scipy
from scipy import stats
import re
import datetime
from dateutil.relativedelta import relativedelta

# Load the AnnData object

In [None]:
# Read the combined scRNA-seq + masks AnnData object from disk
adata = sc.read(filename='anndata.h5ad')

In [None]:
# Check data: as the last expression in the cell, adata's summary is displayed
adata

In [None]:
# Check the column dtypes of adata.obs - the dtype of cat2 decides which kind of
# plot the "Generate plot" cell draws (timeline, boolean bars, ridgeplot or masks)
adata.obs.info()

In [None]:
# Optional: check that a gene of interest is present - shows the adata.var rows whose
# var_names start with 'HOX' (edit the prefix as required)
adata.var[adata.var_names.str.startswith('HOX')]

# User-set parameters

In [None]:
# To compare gene expression within one celltype / group of interest, subset adata to
# that group beforehand and then run this notebook.

# Columns used to build the tables of values to plot

cat1 = 'spatial_location'    # column the masks relate to, i.e. the section labels (e.g. 12 sections)
cat2 = 'anno_col'  # column of annotations of interest; its dtype selects the plot type

#################################################################################
# Mode options (used when cat2 is not a datetime, bool or float/int column):

# manual = colour each section with the values given in manual_values below

# celltype_counts = the total number of cells of the specified celltype within each section

# celltype_percentage_within_sections = the percentage of the specified celltype within a section, relative to all other celltypes in the same section

# celltype_percentage_across_sections = the percentage of the specified celltype in each section, relative to all cells of that celltype across all sections, i.e. the celltype's broad distribution across sections

# gene_expression = average gene expression of a specified gene for each section using all cells within the section. Note: if sc.pp.scale was run these are average z-scores of the expression values, otherwise plain average expression values


# Mandatory arguments
mask_set = 'set1'

mode = 'gene_expression'   # manual, celltype_counts, celltype_percentage_within_sections, celltype_percentage_across_sections, gene_expression

# Celltype(s) or gene of interest; several celltypes are merged into one group,
# gene_expression mode supports a single gene only
plot_value = ['HOXD-AS2']  #['celltype_1','celltype_2','celltype_3']
#plot_value = list(adata.obs['anno_col'].unique())

scale = 'auto'             # colour bar range: auto (min/max of the plotted values) or manual (limits below)
cmap = plt.cm.viridis      # premade colormap, e.g. viridis, plasma, inferno, magma, cividis, Reds

# only used if scale = 'manual'
scale_lower_value = 0
scale_upper_value = 20


#################################################################################
# Options for continuous (float / int) variables:

# intended number of x-axis ticks between the lower and upper plot values - won't work for logged scale
# NOTE(review): confirm the plotting cell actually reads this rather than a hard-coded tick count
tick_no = 8

# plot log10 of the column of interest instead of the raw values
scale_log = False

#################################################################################
# Options for datetime columns

plot_covid = True           # overlay the covid-restriction period on the timeline
use_premade_info = True     # take dates/labels from adata.uns[cat2] instead of deriving them via cat3
cat3 = 'ID_col'             # ID column paired with each date when use_premade_info is False

#################################################################################

# If manual mode is chosen - one value per section below.
# NOTE(review): values are paired with sections in order of first appearance in
# adata.obs[cat1]; confirm that order matches the "Section N" comments.
manual_values = [
42,   # Section 1
3,   # Section 2
31,   # Section 3
41,   # Section 4
15,   # Section 5
26,   # Section 6
67,   # Section 7
28,   # Section 8
91,   # Section 9
10,   # Section 10
13,   # Section 11
22,   # Section 12
]

save_plot = './plot_X.pdf' # output path, or False to skip saving


# Generate plot

In [None]:
######################################################################################
# column checker - if cat2 is categorical and EVERY category looks like a date
# (YYYY-MM-DD, month/day optionally without a leading zero), convert the column to
# datetime64 so the timeline branch of the plotting code below is used.
if isinstance(adata.obs[cat2].dtype, pd.CategoricalDtype):
    date_pattern = re.compile(r"^(\d{4})-(0[1-9]|1[0-2]|[1-9])-([1-9]|0[1-9]|[1-2]\d|3[0-1])$")
    cat2_unique = adata.obs[cat2].unique()
    # str() so non-string categories (ints, NaN) just fail the match instead of raising TypeError;
    # the length guard stops an empty column from being treated as "all dates"
    if len(cat2_unique) > 0 and all(date_pattern.search(str(value)) for value in cat2_unique):
        # pd.to_datetime (unlike numpy string parsing) accepts the non-zero-padded
        # months/days the pattern above allows
        adata.obs[cat2] = pd.to_datetime(adata.obs[cat2].astype(str), format='%Y-%m-%d')
######################################################################################
######################################################################################

# Choose the plot from the dtype of cat2:
#   datetime    -> timeline of the dates
#   bool        -> stacked True/False percentage bars per section (cat1)
#   float / int -> ridgeplot of the distribution per section (cat1)
#   otherwise   -> colour the mask polygons of `mask_set` by a per-section value chosen by `mode`
if pd.api.types.is_datetime64_any_dtype(adata.obs[cat2]):

    if use_premade_info:
        # adata.uns[cat2] supplies 'dates' (each unpacked into datetime.date) and matching 'labels'
        dates = [datetime.date(*list(d)) for d in adata.uns[cat2]['dates']]
        labels = ['{0:%d %b %Y}:\n{1}'.format(d, l) for l, d in zip(adata.uns[cat2]['labels'], dates)]
    else:
        # one date per ID in cat3 (if an ID has several dates, the last one wins)
        df_new = adata.obs[[cat3, cat2]].reset_index(drop=True).drop_duplicates().set_index(cat3)
        # .date() so every date compares cleanly with the datetime.date covid bounds below
        date_dict = {key: pd.Timestamp(value).date() for key, value in df_new[cat2].dropna().items()}
        labels = ['{0:%d %b %Y}:\n{1}'.format(d, l) for l, d in date_dict.items()]
        # a real list - matplotlib cannot convert a dict_values view
        dates = list(date_dict.values())

    if plot_covid:
        covid_start = datetime.date(2020, 3, 23)
        covid_end = datetime.date(2022, 2, 24)
        # pad the axis by a month and make sure the whole covid window is visible
        start_date = (min(covid_start, min(dates)) - relativedelta(months=1)).replace(day=1)
        end_date = (max(covid_end, max(dates)) + relativedelta(months=1)).replace(day=1)
    else:
        start_date = (min(dates) - relativedelta(months=1)).replace(day=1)
        end_date = (max(dates) + relativedelta(months=1)).replace(day=1)

    fig, ax = plt.subplots(figsize=(10, 4), constrained_layout=True)
    ax.set_ylim(-2, 1.75)
    ax.set_xlim(start_date, end_date)
    ax.axhline(0, xmin=0, xmax=1, c='red', zorder=1)

    ax.get_xaxis().set_major_locator(mdates.MonthLocator(interval=1))
    ax.get_xaxis().set_major_formatter(mdates.DateFormatter("%b %Y"))

    # two-tone marker for every date on the timeline
    ax.scatter(dates, np.zeros(len(dates)), s=120, c='green', zorder=2)
    ax.scatter(dates, np.zeros(len(dates)), s=30, c='darkgreen', zorder=3)

    # alternate labels above / below the line so neighbouring dates do not collide
    label_offsets = np.zeros(len(dates))
    label_offsets[::2] = 1.4
    label_offsets[1::2] = -2.3
    for offset, label, d in zip(label_offsets, labels, dates):
        ax.text(d, offset, label, ha='center', fontfamily='serif', fontweight='bold', color='royalblue', fontsize=12)

    stems = np.zeros(len(dates))
    stems[::2] = 1
    stems[1::2] = -1
    # use_line_collection was removed from Axes.stem in Matplotlib 3.8 (a LineCollection is always used)
    markerline, stemline, baseline = ax.stem(dates, stems)
    plt.setp(markerline, marker=',', color='green', markersize=5)
    plt.setp(stemline, color='green', linewidth=1.25)

    # hide lines around chart
    for spine in ["left", "top", "right", "bottom"]:
        ax.spines[spine].set_visible(False)

    ax.set_title(f'Timeline for {cat2}', fontweight="bold", fontfamily='serif', fontsize=16,
                 color='royalblue', y=1.4)

    # hand-drawn month ticks with abbreviated month names; the year is added under each January
    for month_start in pd.date_range(start=start_date, end=end_date, freq='1MS'):
        x_axis_pos = month_start.date()
        ax.axvline(x_axis_pos, ymin=0.5, ymax=0.57, c='black', zorder=1)
        ax.text(x_axis_pos, -0.3, x_axis_pos.strftime("%B")[:3])
        # test the month number, not the (locale-dependent) month name
        if x_axis_pos.month == 1:
            ax.text(x_axis_pos, 0.5, x_axis_pos.strftime("%Y"))

    if plot_covid:
        # add covid line segment
        ax.axvline(covid_start, ymin=0.6, ymax=0.7, c='purple', zorder=1)
        ax.axvline(covid_end, ymin=0.6, ymax=0.7, c='purple', zorder=1)
        ax.plot([covid_start, covid_end], [0.4, 0.4], linestyle='-', color='purple')
        ax.text(covid_start, 0.7, "Covid restrictions", color='purple', fontsize=14)

    # hide tick labels
    ax.set_xticks([])
    ax.set_yticks([])
    # No plt.show() here: with the inline backend it closes the figure, so the savefig
    # at the end of this cell would write an empty file. Jupyter still displays the
    # figure when the cell finishes.

######################################################################################

elif adata.obs[cat2].dtype == bool:

    sub_df = adata.obs[[cat1, cat2]]
    plot_df = pd.DataFrame(index=['True', 'False'])

    # percentage of True / False cells in every section
    for s in sub_df[cat1].unique():
        df_sub = sub_df[sub_df[cat1].isin([s])]
        true_counts = df_sub[cat2].sum()
        total = len(df_sub[cat2])
        t_percent = round((true_counts / total) * 100, 2)
        f_percent = 100 - t_percent
        plot_df[s] = [t_percent, f_percent]

    # one horizontal bar per section, reversed so the first section is drawn at the top
    plot_df = plot_df.T.iloc[::-1]

    ax = plot_df.plot(kind='barh', stacked=True, legend=False, figsize=(10, 10), cmap=cmap)

    for spine in ['top', 'right', 'bottom', 'left']:
        ax.spines[spine].set_visible(False)

    # write each segment's percentage just above the segment
    for p in ax.patches:
        width, height = p.get_width(), p.get_height()
        x, y = p.get_xy()
        ax.text(x + width / 2,
                y + height + 0.2,
                '{:.2f}%'.format(width),
                horizontalalignment='center',
                verticalalignment='center',
                color='black',
                fontsize=10,
                )

    for tick in ax.yaxis.get_major_ticks():
        # Tick.label was removed in Matplotlib 3.8; label1 is the left-hand tick label
        tick.label1.set_position((-0.03, 0))
        tick.label1.set_fontsize(12)

    plt.tick_params(left=False)

    plt.xlabel('Percentage (%)', fontsize=14)
    plt.ylabel(cat1, fontsize=16, color='black', alpha=1)

    plt.title(f'Boolean proportion percentage of {cat2} per anatomical section', fontsize=16, y=1.03)

    plt.legend(title=f'{cat2}', bbox_to_anchor=(1.4, 1), labelspacing=2, fontsize=12, title_fontsize=15, prop={'size': 15})

######################################################################################

elif pd.api.types.is_float_dtype(adata.obs[cat2]) or pd.api.types.is_integer_dtype(adata.obs[cat2]):

    if scale_log:
        adata.obs['log_scaled_col'] = np.log10(adata.obs[cat2])
        cat_use = 'log_scaled_col'
    else:
        cat_use = cat2

    x_lim_range = [adata.obs[cat_use].min(), adata.obs[cat_use].max()]

    # tick_no evenly spaced integer ticks over the range; on a linear scale every tick
    # after the first is rounded up to the next hundred
    ticks = np.linspace(int(x_lim_range[0]), int(x_lim_range[1]), num=tick_no, dtype=int)
    if scale_log:
        array = [ticks[0]] + [math.ceil(i) for i in ticks[1:]]
    else:
        array = [ticks[0]] + [math.ceil(i / 100) * 100 for i in ticks[1:]]

    fig, axes = joypy.joyplot(data=adata.obs[[cat_use, cat1]], by=cat1, colormap=cmap, fade=True, range_style='group', x_range=x_lim_range, tails=0, xlim='max', figsize=(10, 10), overlap=0, ylabelsize=12, xlabelsize=12)

    plt.title(f'Ridgeplot of the continual variable {cat2} across {cat1}', fontsize=16, y=1.03)
    axes[-1].set_xticks(array)
    x_label = f'log10 of {cat2}' if scale_log else cat2
    axes[-1].set_xlabel(x_label, fontsize=16, color='black', alpha=1)
    axes[-1].xaxis.set_label_coords(0.5, -0.07)

    axes[-1].yaxis.set_visible(True)
    axes[-1].set_yticks([])
    axes[-1].set_ylabel(cat1, fontsize=16, color='black', alpha=1)
    axes[-1].yaxis.set_label_coords(-0.15, 0.5)

    # colour of each ridge, read from its first legend handle; the last entry of
    # `axes` is skipped (presumably joypy's shared x-axis frame - confirm)
    patches = []
    for ridge_ax in axes[:-1]:
        current_handles, _ = ridge_ax.get_legend_handles_labels()
        patches.append(current_handles[0].get_edgecolor().tolist()[0])

    # NOTE(review): labels follow order of first appearance in adata.obs[cat1], while
    # joypy orders ridges by its own grouping of `by` - confirm the two orders agree.
    labels = list(adata.obs[cat1].unique())
    legend_patches = []
    gmeans = []
    for colour, label in zip(patches, labels):
        # the plotted values live in adata.obs (the previous code referenced an undefined `data`)
        section_values = adata.obs.loc[adata.obs[cat1].isin([label]), cat_use]
        std_ = round(section_values.std())
        kur_ = round(stats.kurtosis(section_values), 2)
        gmean_ = round(stats.gmean(section_values))
        gmeans.append(gmean_)
        legend_patches.append(mpatches.Patch(color=colour, label=r'{a}:     {b},     {c},     {d}'.format(a=label, b=std_, c=kur_, d=gmean_)))

    legend1 = plt.legend(handles=legend_patches, title='Section:         std,        K,        GM', bbox_to_anchor=(1.6, 1), labelspacing=2, fontsize=12, title_fontsize=15)

    # short red marker at the geometric mean of each ridge
    for ridge_ax, X in zip(axes, gmeans):
        ridge_ax.axvline(X, color='red', lw=2, alpha=1, ymax=0.1)

    # 2nd legend explaining the overlay; add_artist keeps legend1, which plt.legend would replace
    line1 = Line2D([], [], color='red', marker='|', linestyle='None',
                   markersize=10, markeredgewidth=1.5, label='Geometric mean')

    legend2 = plt.legend(handles=[line1], title='Graphical overlays', bbox_to_anchor=(1.41, 0.25), labelspacing=2, fontsize=12, title_fontsize=15)

    plt.gca().add_artist(legend1)

######################################################################################

else:

    celltype_modes = ['celltype_counts', 'celltype_percentage_within_sections', 'celltype_percentage_across_sections']

    # Build `values` (one per section) and `values_dict` (section label -> value)
    if mode == 'manual':
        values = manual_values
        # NOTE(review): manual values are paired with sections in order of first appearance
        values_dict = dict(zip(adata.obs[cat1].unique(), values))

    elif mode == 'gene_expression':
        if len(plot_value) > 1:
            raise Exception("Current implementation only supporting single gene expression")

        # index = columns of the varm table named in adata.uns['masks'][mask_set]['varm']
        section_values = adata.varm[adata.uns['masks'][mask_set]['varm']].T[plot_value[0]]
        values = list(section_values.values)
        if set(section_values.index) == set(adata.obs[cat1].unique()):
            # the varm columns are the section labels, so pair values by label
            values_dict = section_values.to_dict()
        else:
            # NOTE(review): labels unknown - pair by position with sections in order of first appearance
            values_dict = dict(zip(adata.obs[cat1].unique(), values))

    elif mode in celltype_modes:

        if len(plot_value) > 1:
            # merge the selected celltypes into a single 'combined_annotation' group
            adata.obs['combined_annotation'] = adata.obs[cat2].copy().astype(str)
            for value in plot_value:
                adata.obs.loc[adata.obs[cat2].isin([value]), 'combined_annotation'] = 'combined_annotation'
            cat2 = 'combined_annotation'
            plot_value = ['combined_annotation']

        adata.obs[cat1] = adata.obs[cat1].astype('category')
        adata.obs[cat2] = adata.obs[cat2].astype('category')

        # generate counts table: rows = sections (cat1), columns = annotations (cat2)
        counts_table = pd.crosstab(adata.obs[cat1], adata.obs[cat2])

        if mode == 'celltype_counts':
            table = counts_table
        elif mode == 'celltype_percentage_across_sections':
            # each column sums to 100: distribution of a celltype across sections
            table = ((counts_table / counts_table.sum()) * 100).round(2)
        else:  # 'celltype_percentage_within_sections'
            # each row sums to 100: composition of a section
            table = (counts_table.div(counts_table.sum(axis=1), axis=0) * 100).round(2)

        section_values = table[plot_value[0]]
        values = list(section_values.values)
        # key by the crosstab's own index: it is in category order, which generally
        # differs from the appearance order of adata.obs[cat1].unique()
        values_dict = section_values.to_dict()

    else:
        raise Exception('Mode option not correct. Please use one of the following: manual, celltype_counts, celltype_percentage_within_sections, celltype_percentage_across_sections or gene_expression')

######################################################################################

    # create a color scale on the range of values inputted
    if scale == 'auto':
        norm = mpl.colors.Normalize(vmin=min(values), vmax=max(values))
    elif scale == 'manual':
        norm = mpl.colors.Normalize(vmin=scale_lower_value, vmax=scale_upper_value)
    else:
        raise Exception('Scale option not correct. Please use either auto or manual')
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)

######################################################################################

    # white canvas; every mask polygon is filled with the colour of its section's value
    polygons = adata.uns[mask_set + '_polygons']
    base_img = np.full(adata.uns[mask_set + '_shape'], 255, dtype=np.uint8)

    for n in polygons.keys():
        if n in values_dict:
            section_value = values_dict[n]
        else:
            # polygon names contain the section label; take the LONGEST matching label
            # so that e.g. 'Section 1' cannot shadow 'Section 10'
            matching = [key for key in values_dict if key in n]
            if not matching:
                raise KeyError(f'No section in {cat1} matches mask polygon {n!r}')
            section_value = values_dict[max(matching, key=len)]
        colour = tuple(int(255 * c) for c in sm.to_rgba(section_value)[:3])
        cv2.fillPoly(base_img, pts=tuple([polygons[n][0]]), color=colour)

######################################################################################

    # plot the embryo which can be used in all modes for categorical plotting
    plt.figure(figsize=(20, 16))
    plt.imshow(base_img)
    plt.axis('off')

######################################################################################

    if mode in celltype_modes and len(plot_value) == len(adata.obs[cat2].unique()):
        plt.title('All options in category selected', fontsize=40, y=1.05)

    else:
        # sm is not attached to any Axes, so say where the colorbar goes (required since Matplotlib 3.8)
        cb = plt.colorbar(sm, ax=plt.gca())
        cb.ax.tick_params(labelsize=20)

        if mode == 'manual':
            plt.title('Manual values for every section', fontsize=40, y=1.05)
            cb.set_label("Manual values", fontsize=30, rotation=270, labelpad=70)

        else:
            if mode == 'gene_expression':
                title = f'Mean gene expression of {plot_value[0]} for each section'
                bar_label = "Expression"
            elif mode == 'celltype_counts':
                title = f'Number of counts for {plot_value[0]}'
                bar_label = "No. cells"
            elif mode == 'celltype_percentage_within_sections':
                title = f'Percentage of {plot_value[0]} compared within section'
                bar_label = "Percentage %"
            else:  # 'celltype_percentage_across_sections'
                title = f'Percentage of {plot_value[0]} across sections'
                bar_label = "Percentage %"

            plt.title(title, fontsize=40, y=1.05)

            if scale == 'manual':
                cb.set_label(bar_label + " (scale values representative of manually set upper and lower thresholds)", fontsize=20, rotation=270, labelpad=70)
            else:  # 'auto'
                cb.set_label(bar_label, fontsize=30, rotation=270, labelpad=70)
######################################################################################

# Save the current figure if an output path was given (save_plot = False skips saving).
# bbox_inches='tight' keeps legends/colorbars placed outside the axes
# (bbox_to_anchor > 1) from being clipped out of the saved file.
if save_plot:
    plt.savefig(save_plot, bbox_inches='tight')