In [None]:
!pip install plotly dash dash-html-components dash-core-components dash-table

In [1]:
import numpy as np
import pandas as pd

from dash import Dash, dcc, html, Input, Output
import plotly.express as px

In [2]:
# Load per-model evaluation results and per-accession dataset statistics.
data = pd.read_csv('../../data/AllData.tsv.xz', sep='\t')
data_counts = pd.read_csv('../../data/stats_per_accession.txt', sep='\t',
                          usecols=['accession', 'num_datapoints', 'datapoints_mutant_percentage'])

# balance_ratio = |50 - datapoints_mutant_percentage|: distance of the mutant share from an
# even 50/50 split (0 = perfectly balanced). The raw percentage is not needed afterwards.
data_counts = (data_counts
               .assign(balance_ratio=(50 - data_counts.datapoints_mutant_percentage).abs())
               .drop(columns='datapoints_mutant_percentage'))

# Keep the fold-mean Pearson r of random-split CV on the complete subset and annotate each
# row with its accession's dataset size and balance (accession itself is then dropped).
# validate='many_to_one' fails loudly if an accession occurs more than once in the stats
# file, which would otherwise silently duplicate performance rows and bias the bin means.
is_random_cv_mean = (data.crossval == 'random') & (data.folds == 'Mean') & (data.subset == 'complete')
subplot1_data = (data.loc[is_random_cv_mean, ['accession', 'modeltype', 'Pearson r']]
                 .merge(data_counts, on='accession', validate='many_to_one')
                 .drop(columns=['accession']))

In [26]:
app = Dash(__name__)

# Page: the heatmap followed by two range sliders that define the bin boundaries
# used by the callback below (x = dataset size, y = balance ratio).
app.layout = html.Div(
    [
        dcc.Graph(id="heatmap-graph"),
        html.P("#datapoints:"),
        # Dataset-size slider spans the observed range of num_datapoints.
        dcc.RangeSlider(
            id='heatmap-x-range-slider',
            value=[3000, 5000],
            step=1000,
            min=subplot1_data.num_datapoints.min(),
            max=subplot1_data.num_datapoints.max(),
        ),
        html.P("balance ratio:"),
        # balance_ratio is |50 - mutant percentage|, hence the fixed 0..50 span
        # (assuming the percentage itself lies in 0..100).
        dcc.RangeSlider(
            id='heatmap-y-range-slider',
            value=[12, 40],
            step=1,
            min=0,
            max=50,
        ),
    ]
)


def _three_bin_intervals(values, low, high, pad):
    """Build three contiguous right-closed intervals split at ``low`` and ``high``.

    The outer edges are pushed beyond BOTH the data range and the slider handles, so the
    intervals are always valid (left <= right) even when a handle sits outside the observed
    data range, and every value in ``values`` falls into exactly one interval.

    Args:
        values: pandas Series holding the numeric column to be binned.
        low: lower slider handle (assumed <= ``high``, as a RangeSlider reports it).
        high: upper slider handle.
        pad: margin subtracted from / added to the outermost edges.

    Returns:
        pd.IntervalIndex ``[(lower, low], (low, high], (high, upper]]``.
    """
    lower_edge = min(values.min(), low) - pad
    upper_edge = max(values.max(), high) + pad
    return pd.IntervalIndex.from_tuples([(lower_edge, low), (low, high), (high, upper_edge)])


def _ordered(labels, order):
    """Return ``labels`` sorted by their position in ``order``; unknown labels go last (stable)."""
    rank = {label: position for position, label in enumerate(order)}
    return sorted(labels, key=lambda label: rank.get(label, len(rank)))


@app.callback(
    Output("heatmap-graph", "figure"),
    [Input("heatmap-x-range-slider", "value"),
     Input("heatmap-y-range-slider", "value")])
def update_bar_chart(slider_x, slider_y):
    """Redraw the model-difference heatmap for the current slider positions.

    Each dataset is binned into 3 size bins (split at ``slider_x``) and 3 balance bins
    (split at ``slider_y``). Each cell shows the mean Pearson r summed over model types
    with the QSAR mean negated, i.e. PCM - QSAR when those are the two model types.

    Args:
        slider_x: ``[low, high]`` from the #datapoints range slider.
        slider_y: ``[low, high]`` from the balance-ratio range slider.

    Returns:
        plotly Figure produced by ``px.imshow``.
    """
    low_x, high_x = slider_x
    low_y, high_y = slider_y

    num_intervals = _three_bin_intervals(subplot1_data.num_datapoints, low_x, high_x, pad=1)
    ratio_intervals = _three_bin_intervals(subplot1_data.balance_ratio, low_y, high_y, pad=0.01)

    # Interval bins are converted to strings so plotly can serialise the axis labels.
    binned = subplot1_data.assign(
        num_bins=pd.cut(subplot1_data.num_datapoints, num_intervals).astype(str),
        ratio_bins=pd.cut(subplot1_data.balance_ratio, ratio_intervals).astype(str),
    )

    # Mean Pearson r per model type and heatmap cell.
    cell_means = binned.groupby(['modeltype', 'ratio_bins', 'num_bins'])['Pearson r'].mean()
    # Negate QSAR so that summing over model types yields "other model(s) - QSAR".
    is_qsar = cell_means.index.get_level_values('modeltype') == 'QSAR'
    signed_means = cell_means.where(~is_qsar, -cell_means)
    # A cell is only filled when every model type has data for it; otherwise it stays NaN
    # (rendered blank) instead of a misleading partial sum or 0.
    n_models = subplot1_data.modeltype.nunique()
    heatmap = (signed_means
               .groupby(level=['ratio_bins', 'num_bins'])
               .sum(min_count=n_models)
               .unstack('num_bins'))

    # String labels sort lexicographically (e.g. "(3000, 5000]" < "(99, 3000]");
    # restore the numeric order of the intervals on both axes.
    heatmap = heatmap.reindex(
        index=_ordered(heatmap.index, [str(interval) for interval in ratio_intervals]),
        columns=_ordered(heatmap.columns, [str(interval) for interval in num_intervals]),
    )

    # Centre the diverging RdYlGn scale on 0 so yellow means "no difference".
    fig = px.imshow(heatmap, color_continuous_scale='RdYlGn', color_continuous_midpoint=0,
                    origin='lower',
                    labels=dict(x="Amount of data", y="Data balance",
                                color="Difference between PCM and QSAR (random CV)"))
    return fig

# Start the Dash server. `run_server` is deprecated in Dash 2 and removed in Dash 3;
# `run` is the supported entry point.
if __name__ == '__main__':
    app.run(debug=False)