In [None]:
import math
import pickle
import re

import numpy as np
import pandas as pd
import plotly.express as px

#Dash Imports
import dash
from dash import dcc
from dash import html
from dash.dependencies import Input,Output
from dash.exceptions import PreventUpdate

def LinearPrediction(CtrlDf, Gene1Df, Gene2Df):
    """Return the additive (linear) double-knockout prediction for every sample combination.

    For each column of ``CtrlDf`` (one column per cell type), every control value ``c`` is
    combined with every gene-1 value ``g1`` and every gene-2 value ``g2`` as
    ``c - ((c - g1) + (c - g2))``: the control count minus the sum of the two
    single-knockout changes.

    Args:
        CtrlDf: control counts; each row is treated as one sample.
        Gene1Df: single-knockout counts for the first gene; must contain every column
            of ``CtrlDf``.
        Gene2Df: single-knockout counts for the second gene; must contain every column
            of ``CtrlDf``.

    Returns:
        DataFrame with the columns of ``CtrlDf`` and
        ``len(CtrlDf) * len(Gene1Df) * len(Gene2Df)`` rows, ordered control-major,
        then gene 1, then gene 2.
    """
    columns = {}
    for cell in CtrlDf.columns:
        # Broadcast to shape (n_ctrl, n_gene1, n_gene2); a C-order ravel reproduces the
        # nesting order control -> gene 1 -> gene 2 of an explicit triple loop.
        ctrl = CtrlDf[cell].to_numpy()[:, np.newaxis, np.newaxis]
        gene1 = Gene1Df[cell].to_numpy()[np.newaxis, :, np.newaxis]
        gene2 = Gene2Df[cell].to_numpy()[np.newaxis, np.newaxis, :]
        # Same arithmetic (and operation order) as the scalar formula, evaluated in numpy.
        columns[cell] = (ctrl - ((ctrl - gene1) + (ctrl - gene2))).ravel()

    # Build the frame once instead of inserting one column at a time.
    return pd.DataFrame(columns)

def Log2(treated_cell_count, control_cell_count, Sudo):
    """Return log2(treated / control), or the string "N/A" when it is undefined.

    Args:
        treated_cell_count: numerator count; any value ``float()`` accepts.
        control_cell_count: denominator count; any value ``float()`` accepts.
        Sudo: when truthy, add a pseudo-count of 1 to both counts before the ratio.

    Returns:
        The base-2 logarithm of the ratio as a float, or "N/A" if either (adjusted)
        count is zero or negative.
    """
    treated = float(treated_cell_count)
    control = float(control_cell_count)

    if Sudo:
        treated += 1
        control += 1

    if treated <= 0 or control <= 0:
        return "N/A"

    # math.log2 is more accurate than math.log(x, 2), which divides two rounded logs.
    return math.log2(treated / control)

external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']
app = dash.Dash(__name__, external_stylesheets=external_stylesheets)
# Flask instance underlying the Dash app, exposed at module level.
server = app.server

# Measured counts indexed by (Timepoint, Gene). Sorting the MultiIndex keeps .loc
# lookups off pandas' unsorted-index slow path.
Measured_Data = (
    pd.read_csv('./Measured_Data.csv')
    .set_index(['Timepoint', 'Gene'])
    .sort_index()
)

# scGen predictions, accessed as predictions[hour key][gene-pair key].
# NOTE(review): pickle.load can execute arbitrary code; only ship a trusted scgen_pred.pickle.
with open('./scgen_pred.pickle', 'rb') as handle:
    predictions = pickle.load(handle)

# Replace every prediction table with a copy sorted by its row index.
for hour_predictions in predictions.values():
    for pair_key in list(hour_predictions):
        hour_predictions[pair_key] = hour_predictions[pair_key].sort_index()

# Dropdown value -> readable hour label (also the Timepoint label used in Measured_Data).
hour_string_dict = {f'_{hours}_h': f'{hours}h' for hours in (18, 24, 36, 48, 72)}

# Genes offered in the dropdowns at each hour.
_18_h_gene_list = ["tbx16", "epha4a", "cdx4", "smo", "mafba", "egr2b", "noto", "tbxta"]
_24_h_gene_list = [
    "cdx4", "noto", "tbx16", "zc4h2", "egr2b", "smo", "tbxta", "epha4a", "hand2", "hoxb1a",
    "tbx1", "mafba",
]
_36_h_gene_list = [
    "tbx1", "met", "hgfa", "hoxb1a", "foxd3", "zc4h2", "smo", "phox2a", "tbx16", "cdx4",
    "tfap2a", "hand2", "epha4a", "noto", "tbxta", "egr2b", "mafba",
]
_48_h_gene_list = [
    "hand2", "tfap2a", "foxi1", "zc4h2", "tbx1", "phox2a", "foxd3", "epha4a", "mafba",
    "hoxb1a", "egr2b",
]
_72_h_gene_list = ["foxd3", "tfap2a", "foxi1"]

dict_gene_string_list = {
    '_18_h': _18_h_gene_list,
    '_24_h': _24_h_gene_list,
    '_36_h': _36_h_gene_list,
    '_48_h': _48_h_gene_list,
    '_72_h': _72_h_gene_list,
}

app.layout = html.Div([
    html.H1("VAE Double Knockout Data"),

    # Controls: hour/gene selectors, analysis options and the fold-change threshold slider.
    html.Div([
        html.Div(
            [
                html.Label(['Hour Selection'], style={'font-weight': 'bold'}),
                dcc.Dropdown(
                    id='hour_dropdown',
                    options=[
                        {'label': label, 'value': key}
                        for key, label in hour_string_dict.items()
                    ],
                ),
            ],
            style={'width': '15%', 'display': 'inline-block'},
        ),
        html.Div(
            [
                html.Label(['Gene to Focus'], style={'font-weight': 'bold'}),
                dcc.Dropdown(id='gene1_dropdown'),
            ],
            style={'width': '15%', 'display': 'inline-block'},
        ),
        html.Div(
            [
                html.Label(['Genes to Test'], style={'font-weight': 'bold'}),
                dcc.Dropdown(id='gene2_dropdown', multi=True),
            ],
            style={'width': '70%', 'display': 'inline-block'},
        ),

        html.Br(),

        dcc.Checklist(
            id='checklist',
            options=[
                {'label': label, 'value': value}
                for value, label in (
                    ('BCC', 'Only Report Values that come from more than 10 Cells'),
                    ('DL2', 'Calculate Fold Change from an Average of the Measured Genes Rather than LinPred'),
                    ('SD1', 'Sudo One added to Log2 calculations'),
                    ('NEG', 'Treat Negative Predictions as Zero'),
                    ('BFC', 'Only Report Fold Change Larger than:'),
                )
            ],
        ),
        html.Div(
            [
                dcc.Slider(
                    id='Big_Fold_Change',
                    min=0, max=2, value=1, step=0.1,
                    marks={tick: str(tick) for tick in (0, 0.5, 1, 1.5, 2)},
                ),
            ],
            style={'width': '30%', 'display': 'inline-block'},
        ),
    ]),

    # Outputs: the fold-change heatmap and the box plot driven by hovering over it.
    html.Div([
        dcc.Graph(id='heatmap', style={'width': '48%', 'display': 'inline-block'}),
        dcc.Graph(id='boxplot', style={'width': '48%', 'display': 'inline-block'}),
    ]),
])

@app.callback(
    Output('gene1_dropdown', 'options'),
    Input("hour_dropdown", "value")
)
def gene1_dropdown_avail(hour):
    """Return the focus-gene dropdown options for the selected hour.

    Dash invokes this callback on page load with ``hour=None``. In that case, and for any
    unknown hour, an empty option list is returned instead of raising KeyError.
    """
    return [{'label': gene, 'value': gene} for gene in dict_gene_string_list.get(hour, [])]
#----------------------------------------
@app.callback(
    Output('gene2_dropdown', 'options'),
    Input("hour_dropdown", "value"),
    Input('gene1_dropdown', 'value')
)
def gene2_dropdown_avail(hour, gene1):
    """Return the test-gene options at ``hour``: every gene listed for it except ``gene1``.

    Returns an empty list until an hour is chosen. ``gene1`` may be None (not chosen yet)
    or stale (not listed for a newly selected hour); then nothing is excluded, where
    ``list.remove`` used to raise ValueError.
    """
    return [
        {'label': gene, 'value': gene}
        for gene in dict_gene_string_list.get(hour, [])
        if gene != gene1
    ]
#----------------------------------------
def _find_prediction_key(hour, gene1, gene2):
    """Return the key of ``predictions[hour]`` holding the gene1 x gene2 scGen prediction.

    A key matches when both gene names occur in it as substrings. Because one gene name
    can be a substring of another (e.g. 'tbx1' and 'tbx16'), keys whose alphanumeric
    tokens contain both genes exactly are preferred. Among equally good matches the last
    key in iteration order is returned.

    Raises:
        KeyError: if no key mentions both genes.
    """
    candidates = [key for key in predictions[hour] if gene1 in key and gene2 in key]
    if not candidates:
        raise KeyError(f'No scGen prediction for {gene1} x {gene2} at {hour}')
    exact = [
        key for key in candidates
        if {gene1, gene2} <= set(re.findall(r'[0-9A-Za-z]+', str(key)))
    ]
    return (exact or candidates)[-1]


@app.callback(
    Output('heatmap', 'figure'),
    Input("hour_dropdown", "value"),
    Input("gene1_dropdown", "value"),
    Input("gene2_dropdown", "value"),
    Input('checklist', 'value'),
    Input('Big_Fold_Change', 'value')
)
def make_heatmap(hour, gene1, gene2_muti, checklist, BFC_Input):
    """Render the log2 fold-change heatmap of scGen predictions for gene1 paired with each gene2.

    Checklist options:
        'BCC': report 'N/A' (and so drop the cell type) unless both the mean scGen count
               and the mean linear-prediction count exceed 10.
        'DL2': compare against the mean measured single-knockout counts (average of the two
               log2 ratios) instead of the linear prediction.
        'SD1': add a pseudo-count of 1 before taking log2.
        'NEG': clip negative scGen and linear-prediction counts to zero.
        'BFC': keep only cell types whose largest |log2 fold change| reaches ``BFC_Input``.

    Args:
        hour: value of the hour dropdown (a key of ``hour_string_dict``).
        gene1: the focus gene.
        gene2_muti: genes to pair with ``gene1``.
        checklist: selected checklist values, or None.
        BFC_Input: slider threshold used with 'BFC'.

    Returns:
        A plotly figure with one row per cell type and one column per ``gene1-gene2`` pair.

    Raises:
        PreventUpdate: until an hour, a focus gene and at least one test gene are selected.
        KeyError: if no scGen prediction exists for a requested pair.
    """
    if hour is None or gene1 is None or not gene2_muti:
        # Dash fires this callback on page load with empty inputs; keep the current figure.
        raise PreventUpdate

    selected = checklist or []
    only_big_counts = 'BCC' in selected
    only_big_changes = 'BFC' in selected
    versus_measured = 'DL2' in selected
    negative_as_zero = 'NEG' in selected
    pseudo_one = 'SD1' in selected
    big_cell_count = 10

    hour_readable = hour_string_dict[hour]
    if versus_measured:
        title = f'Log2 Fold Change of ScGen Compared to Measured Counts {hour_readable}'
    else:
        title = f'Log2 Fold Change of ScGen Compared to Linear Prediction at {hour_readable}'

    # Inputs shared by every pair are looked up once rather than once per pair or per cell.
    ctrl_data = Measured_Data.loc[hour_readable, 'ctrl-inj'].reset_index(drop=True)
    gene1_data = Measured_Data.loc[hour_readable, gene1]
    gene1_means = gene1_data.mean(axis=0)

    rows = []
    for gene2 in gene2_muti:
        gene2_data = Measured_Data.loc[hour_readable, gene2]
        gene2_means = gene2_data.mean(axis=0)

        scgen_df = predictions[hour][_find_prediction_key(hour, gene1, gene2)]
        linpred_df = LinearPrediction(ctrl_data, gene1_data, gene2_data)
        if negative_as_zero:
            # clip() returns copies; masking in place would permanently modify the shared
            # `predictions` tables, so unticking 'NEG' would no longer restore them.
            scgen_df = scgen_df.clip(lower=0)
            linpred_df = linpred_df.clip(lower=0)

        # Collapse samples into one mean count per cell type.
        scgen_means = scgen_df.mean(axis=0)
        linpred_means = linpred_df.mean(axis=0)

        label = f'{gene1}-{gene2}'
        for cell in scgen_means.index:
            if only_big_counts and not (
                scgen_means[cell] > big_cell_count and linpred_means[cell] > big_cell_count
            ):
                log2 = 'N/A'
            elif versus_measured:
                # Average of the fold changes against each single knockout's measured mean.
                log_g1 = Log2(scgen_means[cell], gene1_means[cell], pseudo_one)
                log_g2 = Log2(scgen_means[cell], gene2_means[cell], pseudo_one)
                if log_g1 == 'N/A' or log_g2 == 'N/A':
                    log2 = 'N/A'
                else:
                    log2 = (log_g1 + log_g2) / 2
            else:
                log2 = Log2(scgen_means[cell], linpred_means[cell], pseudo_one)
            rows.append([label, cell, log2])

    # One row per cell type, one column per gene pair.
    fold_change_df = pd.DataFrame(rows, columns=['Gene', 'Cell Type', 'Log2']).pivot(
        index='Cell Type', columns='Gene', values='Log2'
    )

    # Drop cell types with any 'N/A' entry and, when several pairs are shown, cell types
    # whose fold change is identical for every pair (they carry no contrast).
    keep = []
    for cell in fold_change_df.index:
        gene_cell_list = fold_change_df.loc[cell].tolist()
        if 'N/A' in gene_cell_list:
            continue
        if len(gene2_muti) > 1 and gene_cell_list[:-1] == gene_cell_list[1:]:
            continue
        keep.append(cell)
    fold_change_df_T = fold_change_df.loc[keep].astype('float64')

    if only_big_changes:
        largest_change = fold_change_df_T.abs().max(axis=1)
        fold_change_df_T = fold_change_df_T[~(largest_change < BFC_Input)]

    # Colour range symmetric about zero so 0 sits at the neutral middle of RdBu_r; using
    # max(|value|) also keeps the range ordered when every value is positive and equal.
    color_limit = fold_change_df_T.abs().max().max()
    if pd.isna(color_limit):
        # Nothing survived the filters.
        color_limit = 0.0

    HeatMap = px.imshow(
        fold_change_df_T,
        title=title,
        range_color=[-color_limit, color_limit],
        color_continuous_scale='RdBu_r',
    )

    # NOTE(review): writes a fixed-name file on every callback regardless of the genes shown;
    # looks like a debugging export - consider removing it or making it opt-in.
    HeatMap.write_html("./tfap2a_foxd3_heatmap.html")

    return HeatMap
#-----------------------------------------------------------------------------------------------------------------------
@app.callback(
    Output('boxplot', 'figure'),
    Input('hour_dropdown', "value"),
    Input('heatmap', 'hoverData'),
    Input('checklist', 'value')
)
def HoverToBoxPlot(hour, hoverdata, checklist):
    """Box-plot the counts of the hovered cell type for every condition of the hovered pair.

    Args:
        hour: value of the hour dropdown (a key of ``hour_string_dict``).
        hoverdata: the heatmap's Dash ``hoverData``; the point's ``x`` is the
            ``gene1-gene2`` column label and ``y`` is the cell type.
        checklist: selected checklist values, or None; only 'NEG' is used here.

    Returns:
        A plotly box plot with one box per condition: Control, gene1, gene2, Scgen, LinPred.

    Raises:
        PreventUpdate: until an hour is selected and a heatmap cell has been hovered.
        KeyError: if no scGen prediction key names both hovered genes.
    """
    if hour is None or not hoverdata or not hoverdata.get('points'):
        # Fired on page load and before any hover with empty inputs.
        raise PreventUpdate

    point = hoverdata['points'][0]
    genes = point['x'].split('-')
    gene1 = genes[0]
    gene2 = genes[1]
    cell = point['y']

    # Both gene names must occur in the key; prefer keys whose tokens match them exactly,
    # because one gene name can be a substring of another (e.g. 'tbx1' / 'tbx16').
    candidates = [key for key in predictions[hour] if gene1 in key and gene2 in key]
    if not candidates:
        raise KeyError(f'No scGen prediction for {gene1} x {gene2} at {hour}')
    exact = [
        key for key in candidates
        if {gene1, gene2} <= set(re.findall(r'[0-9A-Za-z]+', str(key)))
    ]
    gene_key = (exact or candidates)[-1]

    hour_readable = hour_string_dict[hour]

    CtrlData = Measured_Data.loc[hour_readable, 'ctrl-inj'].reset_index(drop=True)
    gene1_df = Measured_Data.loc[hour_readable, gene1].reset_index(drop=True)
    gene2_df = Measured_Data.loc[hour_readable, gene2].reset_index(drop=True)
    gene_1x2_df = predictions[hour][gene_key]
    LinPredDf = LinearPrediction(CtrlData, gene1_df, gene2_df)

    if checklist and 'NEG' in checklist:
        # clip() returns copies; masking in place would permanently zero the shared
        # `predictions` tables.
        gene_1x2_df = gene_1x2_df.clip(lower=0)
        LinPredDf = LinPredDf.clip(lower=0)

    DataFrame_Dict = {
        "Control": CtrlData,
        gene1: gene1_df,
        gene2: gene2_df,
        'Scgen': gene_1x2_df,
        'LinPred': LinPredDf
    }

    # Long format: one row per sample, with its condition label in 'Gene_Knock'.
    pieces = []
    for label, frame in DataFrame_Dict.items():
        piece = frame[cell].to_frame().reset_index(drop=True)
        piece.insert(0, 'Gene_Knock', label)
        pieces.append(piece)
    # DataFrame.append was removed in pandas 2.0; concatenate once instead.
    df_to_graph = pd.concat(pieces, ignore_index=True)

    boxplot = px.box(df_to_graph,
        x = 'Gene_Knock', y = cell, labels= {'Gene_Knock': 'Condition', cell : 'Number of Cells'},
        title = f'Number of {cell} cells at {hour_readable}'
    )

    return boxplot

#----------------------------------------
if __name__ == '__main__':
    # Dash 3 removed run_server() in favour of run(); fall back to run_server() only on
    # older Dash releases that predate run().
    run = getattr(app, 'run', None) or app.run_server
    run()
