 <font size="10">**Pseudo-spatial visualisation: Plotly; outside functions**</font>
***

<div class="alert alert-info">
    
<h1> ℹ️ <strong> <font size="6" color="black"> Important notebook information </font> </strong> </h1>
    <hr>
    <font size="4" color="black">
        The purpose of this notebook is to use the previously generated combined scRNA-seq data and masks object to plot the data with the mask sets it includes. <br> <br>
    In this notebook everything is run outside of functions for easier manipulation.</font>
</div>

# Import packages

In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D
import matplotlib.dates as mdates
import seaborn as sns
import cv2
import joypy
import math
import scipy
from scipy import stats
import re
import datetime
from dateutil.relativedelta import relativedelta
import plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import os
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from ipywidgets import Button, Dropdown, HBox, VBox
from skimage.draw import polygon
from PIL import Image
from IPython.display import display
from plotly.offline import iplot

# Load in anndata object

In [None]:
# Read the combined scRNA-seq + masks AnnData object from disk
adata = sc.read(filename='anndata.h5ad')

In [None]:
# check anndata object: as the cell's last expression, Jupyter displays its summary
adata

In [None]:
# check column dtypes of adata.obs; the plotting cell below chooses the plot type
# (date timeline, numeric ridgeplot or mask plot) from the dtype of the `cat2` column
adata.obs.info()

In [None]:
# check the available mask sets and the corresponding obs column to use (presumably as `cat1`);
# the chosen set name becomes `mask_set`, used below e.g. as adata.uns[mask_set + '_polygons']
adata.uns['Mask_selector']

# User set parameters

In [None]:
# To compare gene expression for a single cell type / group of interest, subset
# adata to that group beforehand and then run this notebook.

# ---------------------------------------------------------------------------
# Columns of adata.obs used to build the tables of plottable values
# ---------------------------------------------------------------------------
cat1 = 'region_col'      # column the masks relate to, e.g. 12 sections
cat2 = 'col_interest'    # annotation column of interest

# ---------------------------------------------------------------------------
# Mandatory arguments
#
# `mode` options:
#   manual                              - use the numbers typed into `manual_values`
#   celltype_counts                     - total number of cells of the chosen cell type per section
#   celltype_percentage_within_sections - % of a section's cells that are the chosen cell type,
#                                         compared against all other cell types in that section
#   celltype_percentage_across_sections - % of all cells of the chosen cell type that fall in each
#                                         section, i.e. its broad distribution across sections
#   gene_expression                     - mean expression of the chosen gene over all cells of a
#                                         section (mean z-scores if sc.pp.scale was run, otherwise
#                                         mean expression values)
#   proportion                          - for boolean cat2 columns: % of a section's cells that are True
# ---------------------------------------------------------------------------
mask_set = 'set_1'
mode = 'mode_interest'

# Cell type(s) or gene of interest. Alternatives:
#   ['value_1', 'value_2', 'value_3']
#   list(adata.obs['anno_col'].unique())
#   [True]
plot_value = ['value']

# Colour bar
scale = 'auto'           # 'auto' or 'manual'
cmap = plt.cm.viridis    # premade colormap, e.g. viridis, plasma, inferno, magma, cividis, Reds
scale_lower_value = 0    # only used when scale == 'manual'
scale_upper_value = 20   # only used when scale == 'manual'

# ---------------------------------------------------------------------------
# Continuous / integer cat2 columns
# ---------------------------------------------------------------------------
scale_log = False        # plot the natural log of the cat2 values

# ---------------------------------------------------------------------------
# Datetime cat2 columns
# ---------------------------------------------------------------------------
plot_covid = False
use_premade_info = True
cat3 = 'ID_col'

# ---------------------------------------------------------------------------
# Manual mode: one value per section, in the order sections appear in adata.obs[cat1]
# ---------------------------------------------------------------------------
manual_values = [
    42, 3, 31, 41, 15, 26,   # sections 1-6
    67, 28, 91, 10, 13, 22,  # sections 7-12
]

save_plot = './plot_X.pdf'   # output path, or False to skip saving

# Generate plot

In [None]:
######################################################################################
# Normalise the dtype of the annotation column `cat2` so the plotting cell can branch on it:
#   * boolean columns become string categories ("True"/"False");
#   * categorical columns whose every observed value is a date string of the form
#     YYYY-M-D / YYYY-MM-DD are converted to datetime64[ns] (drawn as a timeline).
# Raw string: "\d" in a normal string literal is an invalid escape sequence.
_ISO_DATE_PATTERN = re.compile(r"^(\d{4})-(0[1-9]|1[0-2]|[1-9])-([1-9]|0[1-9]|[1-2]\d|3[0-1])$")

if adata.obs[cat2].dtype == "bool":
    adata.obs[cat2] = adata.obs[cat2].astype("str").astype("category")

# pd.api.types.is_categorical_dtype is deprecated; isinstance on the dtype is the replacement
if isinstance(adata.obs[cat2].dtype, pd.CategoricalDtype):
    _observed_categories = adata.obs[cat2].dropna().unique()
    # Only convert when every observed value is a date string. Non-string categories
    # (ints, floats, NaN) used to crash re.search with a TypeError.
    if len(_observed_categories) > 0 and all(
        isinstance(category, str) and _ISO_DATE_PATTERN.search(category)
        for category in _observed_categories
    ):
        adata.obs[cat2] = adata.obs[cat2].astype('datetime64[ns]')
######################################################################################
######################################################################################

######################################################################################
# Helpers for the plotting code below.

def _rgba_to_hex(rgba):
    """Return '#rrggbb' for a matplotlib RGB(A) tuple whose channels lie in [0, 1]."""
    return "#%02x%02x%02x" % tuple(int(255 * channel) for channel in rgba[:3])


def _plotly_colorscale(colormap, n_steps=11):
    """Sample a matplotlib Colormap into a Plotly colorscale: [[position, '#rrggbb'], ...]."""
    return [
        [step / (n_steps - 1), _rgba_to_hex(colormap(step / (n_steps - 1)))]
        for step in range(n_steps)
    ]


def _section_value(values_by_section, key):
    """Return the value of the section that polygon `key` belongs to.

    An exact label match wins. Otherwise the longest section label contained in `key`
    is used, so that a key such as 'Section 10' is not claimed by 'Section 1'.

    Raises:
        KeyError: if no section label matches `key`.
    """
    if key in values_by_section:
        return values_by_section[key]
    matches = [label for label in values_by_section if str(label) in str(key)]
    if not matches:
        raise KeyError(f"No section label matches polygon key {key!r}")
    return values_by_section[max(matches, key=lambda label: len(str(label)))]


######################################################################################

if pd.api.types.is_datetime64_any_dtype(adata.obs[cat2]):
    # ---- Timeline of the dates in cat2 ------------------------------------------------
    if use_premade_info:
        # Dates and labels precomputed in adata.uns[<mask_set>_<cat2>]
        premade_info = adata.uns[mask_set + '_' + cat2]
        dates = [datetime.date(*list(d)) for d in premade_info['dates']]
        labels = ['{0:%d %b %Y}:<br>{1}'.format(d, l) for l, d in zip(premade_info['labels'], dates)]
    else:
        # One date per cat3 identifier, taken from the unique (cat3, cat2) pairs in obs
        df_new = (
            adata.obs[[cat3, cat2]]
            .reset_index(drop=True)
            .drop_duplicates()
            .reset_index(drop=True)
            .set_index(cat3)
        )
        # NOTE(review): assumes each cat3 value maps to a single date; if an identifier
        # has several dates only the last one is kept.
        date_dict = df_new[cat2].to_dict()
        labels = ['{0:%d %b %Y}:<br>{1}'.format(d, l) for l, d in date_dict.items()]
        dates = [d.date() for d in date_dict.values()]

    # x-range: one month of padding either side, widened to include the COVID window if requested
    first_date, last_date = min(dates), max(dates)
    if plot_covid:
        covid_start = datetime.date(2020, 3, 23)
        covid_end = datetime.date(2022, 2, 24)
        first_date = min(first_date, covid_start)
        last_date = max(last_date, covid_end)
    start_date = (first_date - relativedelta(months=1)).replace(day=1)
    end_date = (last_date + relativedelta(months=1)).replace(day=1)

    # Alternate markers above (+1) and below (-1) the axis so neighbouring labels do not collide
    stems = np.zeros(len(dates))
    stems[::2] = 1
    stems[1::2] = -1

    data = [
        go.Scatter(
            x=dates,
            y=stems,
            mode='markers',
            marker=dict(color='red'),
            text=labels,
            hoverinfo='text+x+y',
        )
    ]

    # One vertical stem per event. zip keeps duplicate dates paired with their own stem;
    # dates.index() always returned the first occurrence.
    layout = go.Layout(
        shapes=[
            dict(
                type='line',
                xref='x',
                yref='y',
                x0=d,
                y0=0,
                x1=d,
                y1=stem,
                line=dict(color='black', width=1),
            )
            for d, stem in zip(dates, stems)
        ],
        title=f'Timeline for {cat2}',
    )

    fig = go.Figure(data, layout)
    fig.add_hline(y=0)

    # Month tick marks and labels; January additionally gets the year
    for month_start in pd.date_range(start=start_date, end=end_date, freq='1MS'):
        x_axis_pos = month_start.date()
        fig.add_vline(x=x_axis_pos, y0=0.45, y1=0.55, line_width=1, line_color="black")
        fig.add_annotation(dict(font=dict(color='black', size=7),
                                x=x_axis_pos,
                                y=-0.3,
                                showarrow=False,
                                text=x_axis_pos.strftime("%B")[:3],
                                textangle=0,
                                xanchor='left',
                                xref="x",
                                yref="y"))
        # month number instead of the locale-dependent month name
        if x_axis_pos.month == 1:
            fig.add_annotation(dict(font=dict(color='black', size=10),
                                    x=x_axis_pos,
                                    y=0.5,
                                    showarrow=False,
                                    text=x_axis_pos.strftime("%Y"),
                                    textangle=0,
                                    xanchor='left',
                                    xref="x",
                                    yref="y"))

    if plot_covid:
        fig.add_vline(x=covid_start, y0=0.55, y1=0.65, line_width=1, line_color="purple")
        fig.add_vline(x=covid_end, y0=0.55, y1=0.65, line_width=1, line_color="purple")
        fig.add_shape(type='line',
                      x0=covid_start,
                      y0=0.25,
                      x1=covid_end,
                      y1=0.25,
                      line=dict(color='purple', width=1),
                      xref='x',
                      yref='y')
        fig.add_annotation(dict(font=dict(color='purple', size=10),
                                x=covid_start,
                                y=0.5,
                                showarrow=False,
                                text='Covid',
                                textangle=0,
                                xanchor='left',
                                xref="x",
                                yref="y"))

    fig.update_xaxes(visible=False, fixedrange=False, autorange=True,
                     rangeslider=dict(autorange=True, thickness=0.3, bgcolor="#e4f7fe"))
    fig.update_yaxes(visible=False, fixedrange=True)
    fig.update_layout({'plot_bgcolor': 'rgba(0,0,0,0)', 'paper_bgcolor': 'rgba(0,0,0,0)'},
                      autosize=False, width=600, height=400)
    fig.update_xaxes(type="date", dtick="M1")

######################################################################################

elif pd.api.types.is_numeric_dtype(adata.obs[cat2]) and not pd.api.types.is_bool_dtype(adata.obs[cat2]):
    # ---- Ridgeplot of a continuous / integer cat2 per cat1 section --------------------
    # (any numeric dtype, not only float64/int32/int64)
    fig = go.Figure()
    sections = adata.obs[cat1].unique()
    n_sections = len(sections)
    # plt.get_cmap accepts a Colormap or a name; cm.get_cmap was removed in matplotlib 3.9
    colormap = plt.get_cmap(cmap)
    for index, section in enumerate(sections):
        values = adata.obs[cat2][adata.obs[cat1].isin([section])].values

        if scale_log:
            values = np.log(values)

        # Plotly's rgb() expects 0-255 channels; the colormap returns 0-1, so use hex instead
        fig.add_trace(go.Violin(x=values, line_color=_rgba_to_hex(colormap(index / n_sections)), name=section))

    fig.update_traces(orientation='h', side='positive', width=2, points=False)
    fig.update_layout(xaxis_showgrid=False, xaxis_zeroline=False,
                      title_text=f'Ridgeplot of the continual variable {cat2} across {cat1}',
                      legend_title="Trace selector",
                      xaxis_title=f"{cat2}",
                      yaxis_title=f"{cat1}")
    fig.update_layout({'plot_bgcolor': 'rgba(0,0,0,0)', 'paper_bgcolor': 'rgba(0,0,0,0)'},
                      autosize=False, width=800, height=800)

######################################################################################

else:
    # ---- Mask plot: one coloured polygon per section ----------------------------------
    # Section labels the values belong to, in the same order as `values`.
    section_labels = adata.obs[cat1].unique()

    if mode == 'manual':
        values = manual_values
    elif mode == 'gene_expression':
        # NOTE(review): assumes the columns of this varm table are in the same order as
        # adata.obs[cat1].unique() — confirm against how the table was built.
        df_of_values = (adata.varm[mask_set + '_Sectional_gene_expression'].T)[plot_value].T
        values = [value for col in df_of_values for value in df_of_values[col].values]

    elif mode in ['celltype_counts', 'celltype_percentage_within_sections',
                  'celltype_percentage_across_sections', 'proportion']:

        if len(plot_value) > 1:
            # Merge several annotations into one combined group
            adata.obs['combined_annotation'] = adata.obs[cat2].copy().astype(str)
            for value in plot_value:
                adata.obs.loc[adata.obs[cat2].isin([value]), 'combined_annotation'] = 'combined_annotation'
            cat2 = 'combined_annotation'
            plot_value = ['combined_annotation']

        adata.obs[cat1] = adata.obs[cat1].astype('category')
        adata.obs[cat2] = adata.obs[cat2].astype('category')

        # sections x annotations counts table
        counts_table = pd.crosstab(adata.obs[cat1], adata.obs[cat2])
        # The crosstab rows are in category order, not order of appearance, so label the
        # values with the crosstab index rather than adata.obs[cat1].unique().
        section_labels = counts_table.index

        if mode == 'proportion':
            # "True" comes from boolean columns converted to strings; it is absent when no cell is True
            true_counts = counts_table["True"] if "True" in counts_table.columns else 0
            values = (true_counts / counts_table.sum(axis=1) * 100).values
        else:
            if mode == 'celltype_counts':
                value_table = counts_table
            elif mode == 'celltype_percentage_across_sections':
                value_table = round((counts_table / counts_table.sum()) * 100, 2)
            else:  # celltype_percentage_within_sections
                value_table = round(((counts_table.T) / (counts_table.T).sum()) * 100, 2).T
            df_of_values = value_table[plot_value]
            values = [value for col in df_of_values for value in df_of_values[col].values]

    else:
        raise ValueError('Mode option not correct. Please use one of the following: manual, celltype_counts, '
                         'celltype_percentage_within_sections, celltype_percentage_across_sections, '
                         'gene_expression or proportion')

    values_dict = dict(zip(section_labels, values))

    # ---- colour scale on the range of values --------------------------------------------
    if scale == 'auto':
        norm = mpl.colors.Normalize(vmin=min(values), vmax=max(values))
    elif scale == 'manual':
        norm = mpl.colors.Normalize(vmin=scale_lower_value, vmax=scale_upper_value)
    else:
        raise ValueError('Scale option not correct. Please use either auto or manual')
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)

    # ---- figure ---------------------------------------------------------------------------
    all_options_selected = len(plot_value) == len(adata.obs[cat2].unique())
    fig = go.Figure()

    if not all_options_selected:
        # Invisible trace that only carries the colour bar. It is added once, and uses the
        # same colormap and limits as the polygons (previously viridis + min/max were hard-coded).
        fig.add_trace(go.Scatter(
            x=[None],
            y=[None],
            mode='markers',
            marker=dict(
                colorscale=_plotly_colorscale(sm.cmap),
                showscale=True,
                cmin=norm.vmin,
                cmax=norm.vmax,
                colorbar=dict(thickness=10, tickmode='auto'),
            ),
            hoverinfo='none',
            showlegend=False,
        ))

    if all_options_selected:
        title = 'All options in category selected'
    elif mode == 'manual':
        title = 'Manual values for every section'
    elif mode == 'gene_expression':
        title = f'Mean gene expression of {plot_value[0]} for each section'
    elif mode == 'celltype_counts':
        title = f'Number of counts for {plot_value[0]}'
    elif mode == 'celltype_percentage_within_sections':
        title = f'Percentage of {plot_value[0]} compared within section'
    elif mode == 'celltype_percentage_across_sections':
        title = f'Percentage of {plot_value[0]} across sections'
    else:  # proportion (mode was validated above)
        title = f'Percentage of truthful {cat2} values within section'

    polygons = adata.uns[mask_set + '_polygons']
    for key in polygons.keys():
        # One lookup drives both the fill colour and the hover text, so they always agree
        section_value = _section_value(values_dict, key)

        if all_options_selected:
            detail_lines = []
        elif mode == 'manual':
            detail_lines = [f"Manual value inputted: {section_value}"]
        elif mode == 'gene_expression':
            detail_lines = [f"Section mean gene expression value: {section_value}"]
        elif mode == 'celltype_counts':
            detail_lines = [
                f"Number of {plot_value[0]} cells in {key}:    {section_value} cells",
                f"Total number of {plot_value[0]} cells in data:    {sum(values)} cells",
            ]
        elif mode == 'celltype_percentage_within_sections':
            detail_lines = [f"{plot_value[0]} represents {section_value}% of the cells within {key}"]
        elif mode == 'celltype_percentage_across_sections':
            detail_lines = [f"{key} contains {section_value}% of the cells for {plot_value[0]}"]
        else:  # proportion
            detail_lines = [f"{round(section_value, 2)}% (2dp) of the cells within {key} are true for {cat2}"]

        text = "<br>".join([f"<b>{key}</b>", "", *detail_lines]) if detail_lines else f"<b>{key}</b>"
        colour = _rgba_to_hex(sm.to_rgba(section_value))

        fig.add_trace(go.Scatter(
            x=list(*polygons[key][:, :, 0, 0]),
            y=list(*polygons[key][:, :, 0, 1]),
            showlegend=False,
            mode="lines",
            fill='toself',
            line=dict(color=colour, width=1),
            fillcolor=colour,
            hoveron='fills',
            text=text,
            hoverinfo='text+x+y',
        ))

    fig.update_layout(title_text=title)
    fig.update_xaxes(visible=False, fixedrange=True)
    fig.update_yaxes(visible=False, autorange="reversed", fixedrange=True)
    fig.update_layout({'plot_bgcolor': 'rgba(0,0,0,0)', 'paper_bgcolor': 'rgba(0,0,0,0)'},
                      autosize=False, width=600, height=800)


fig.show()

# Save only when an output path was given; False / None / '' disable saving.
# (The old `save_plot != False` check called write_image(None) when save_plot was None.)
if save_plot:
    fig.write_image(save_plot)