In [1]:
import math
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import modelling

In [2]:
# Relative path of the directory from which the sensitivity-analysis CSV
# files are read; the trailing slash matters because file names are appended
# to it by plain string concatenation.
saved_data_dir = "../../../simulation_data/sensitivity_analysis/"

In [7]:
# --- Selection -------------------------------------------------------------
# Binding parameter whose sensitivity is visualised, and the drug to inspect.
param_interest = 'N'
drug_list = ['dofetilide', 'verapamil', 'terfenadine',
             'cisapride', 'quinidine', 'sotalol']
drug = drug_list[5]  # 'sotalol'

# --- Figure ----------------------------------------------------------------
# Two stacked panels; the upper one takes 70 % of the height, the lower 30 %.
fig = make_subplots(rows=2, cols=1, row_heights=[0.7, 0.3])

# --- Data ------------------------------------------------------------------
# The CSV has a two-level column header. Columns listed by the original
# author: drug_conc_Hill, peak_current, Hill_curve, param_values,
# drug_conc_AP, APD_trapping, APD_conductance; this cell additionally reads
# the RMSE and MAE columns. Rows are ordered by the parameter of interest.
filename = 'SA_{}_{}.csv'.format(drug, param_interest)
df = (
    pd.read_csv(saved_data_dir + filename,
                header=[0, 1], index_col=[0],
                skipinitialspace=True)
    .sort_values(by=[('param_values', param_interest)])
)

# Reference value of the parameter for the chosen drug, taken from the
# project's parameter library.
param_lib = modelling.BindingParameters()
param_true = param_lib.binding_parameters[drug][param_interest]

# Swept parameter values and the two error metrics, one entry per row.
param_range = df[('param_values', param_interest)].values
RMSError = df[('RMSE', 'RMSE')].values
MAError = df[('MAE', 'MAE')].values
MAError = df['MAE']['MAE'].values
# Bottom panel: both error metrics plotted against the swept parameter value.
fig.add_trace(
    go.Scatter(
        visible=True,
        x=param_range,
        y=RMSError,
        mode='lines+markers',
        name='root mean square difference'
    ), row=2, col=1
)
fig.add_trace(
    go.Scatter(
        visible=True,
        x=param_range,
        y=MAError,
        mode='lines+markers',
        name='mean absolute difference',
        line=dict(color="#ff0000")
    ), row=2, col=1
)

# Vertical extent shared by every vertical marker line in the bottom panel.
# NaN entries are ignored; pooling both metrics means a metric that is NaN
# everywhere no longer aborts the cell as long as the other one has values.
all_errors = np.concatenate([RMSError, MAError])
min_error = np.nanmin(all_errors)
max_error = np.nanmax(all_errors)

# Vertical line marking the reference ("true") value of the parameter.
fig.add_trace(
    go.Scatter(
        visible=True,
        x=[param_true, param_true],
        y=[min_error, max_error],
        mode='lines',
        name='true value',
        line=dict(color="#ff0000")
    ), row=2, col=1
)

# For every row (i.e. every swept parameter value) add three traces, in this
# order: APD_trapping and APD_conductance in the top panel, then a vertical
# marker at the parameter value in the bottom panel. Only the traces of the
# first row start out visible; the slider toggles the others.
row_min_drug_conc = []
row_max_drug_conc = []
for r in range(len(df.index)):
    # Slice once; the list index keeps the result a one-row DataFrame so the
    # per-column dtypes are left untouched.
    row = df.iloc[[r]]
    changing_param_value = row['param_values'][param_interest].values[0]

    # Concentrations and the two APD curves stored for this row.
    drug_conc_AP = row['drug_conc_AP'].values[0]
    APD_trapping = row['APD_trapping'].values[0]
    APD_conductance = row['APD_conductance'].values[0]

    show_initially = (r == 0)

    fig.add_trace(
        go.Scatter(
            visible=show_initially,
            x=drug_conc_AP,
            y=APD_trapping,
            mode='lines+markers',
            name='APD_trapping'
        ), row=1, col=1,
    )
    fig.add_trace(
        go.Scatter(
            visible=show_initially,
            x=drug_conc_AP,
            y=APD_conductance,
            mode='lines+markers',
            name='APD_conductance'
        ), row=1, col=1,
    )
    fig.add_trace(
        go.Scatter(
            visible=show_initially,
            x=[changing_param_value, changing_param_value],
            y=[min_error, max_error],
            mode='lines',
            name='param_value'
        ), row=2, col=1,
    )

    row_min_drug_conc.append(min(drug_conc_AP))
    row_max_drug_conc.append(max(drug_conc_AP))

# Overall concentration span across all rows (used for the x-axis range).
min_drug_conc = min(row_min_drug_conc)
max_drug_conc = max(row_max_drug_conc)

# Trace layout of `fig.data`: the first three traces (RMSE, MAE, true-value
# line) are always shown; after them come three traces per parameter value
# (APD_trapping, APD_conductance, parameter marker).
n_static_traces = 3
traces_per_step = 3
n_steps = (len(fig.data) - n_static_traces) // traces_per_step

# One slider step per parameter value: show the static traces plus the three
# traces belonging to that value, and update the figure title accordingly.
sets = []
for i in range(n_steps):
    label = "%.3f" % param_range[i]

    visible = [False] * len(fig.data)
    visible[:n_static_traces] = [True] * n_static_traces
    first = n_static_traces + traces_per_step * i
    visible[first:first + traces_per_step] = [True] * traces_per_step

    sets.append(dict(
        method="update",
        args=[{"visible": visible},
              {"title": "Drug " + drug + " at parameter " + param_interest +
               " = " + label}],
        label=label
    ))

sliders = [dict(
    # The traces of the first parameter value are the ones visible when the
    # figure is first drawn, so the slider must start on step 0 as well.
    active=0,
    currentvalue={"prefix": param_interest + " = "},
    # NOTE(review): the top padding is tied to the number of rows; a fixed
    # pixel value may have been intended - confirm.
    pad={"t": len(df.index)},
    steps=sets
)]
fig.update_layout(sliders=sliders)

# Top panel: logarithmic concentration axis. For a log axis plotly expects
# the range in log10 units, hence the np.log10 calls.
fig.update_xaxes(title_text="Normalised drug concentration", type="log",
                 range=[np.log10(min_drug_conc), np.log10(max_drug_conc)],
                 row=1, col=1)
fig.update_xaxes(title_text="Parameter value", row=2, col=1)

# Fixed y-range for the APD90 panel so the view does not rescale between
# slider steps.
fig.update_yaxes(title_text="APD90", range=[0, 1050], row=1, col=1)
fig.update_yaxes(title_text="APD difference", row=2, col=1)
fig.show()