Plot the APD90 difference across the domain of the three parameters of interest (Vhalf, Kmax and Ku).

In [1]:
import numpy as np
import os
import pandas as pd
import plotly.graph_objects as go

import modelling

In [2]:
# Colouring mode for the 3D scatter plot further down:
#   True  -> discrete categories (similar / SD higher / CS higher)
#   False -> continuous 'Portland' colour scale of the signed RMSD
category: bool = False

In [3]:
# 3D plot in parameter space
# Plot for known drugs
param_lib = modelling.BindingParameters()
drug_list = param_lib.drug_compounds

SA_model = modelling.SensitivityAnalysis()
param_names = SA_model.param_names

# One-row DataFrame (columns = param_names) with every parameter set to 1.
# Sized from param_names instead of a hard-coded 5, so it stays valid if the
# sensitivity-analysis model's parameter list changes length.
starting_param_df = pd.DataFrame([1] * len(param_names), index=param_names).T
ComparisonController = modelling.ModelComparison(starting_param_df)

# Colours and labels for the discrete "category" plot, ordered to match the
# colour codes 0 / 1 / 2 (similar / SD higher / CS higher).
discrete_colors = ['red', 'blue', 'black']
APD_diff_label = ['similar', 'SD higher', 'CS higher']

In [30]:
# Single figure that receives both the drug markers and the parameter-space
# markers.
fig = go.Figure()

# --- Known drugs: parameter values and error metrics, one row per drug ---
saved_data_dir = '../../../simulation_data/'
filename = 'SA_alldrugs.csv'
df = pd.read_csv(
    saved_data_dir + filename,
    skipinitialspace=True,
    index_col=[0],
    header=[0, 1],  # two-level column header: (group, field)
)

Vhalf_list, Kmax_list, Ku_list = (
    df['param_values'][name].values for name in ('Vhalf', 'Kmax', 'Ku')
)
drug_list = df['drug']['drug'].values

RMSError_drug, MAError_drug = (
    df[metric][metric].values for metric in ('RMSE', 'MAE')
)

# Read data for the parameter-space sweep, which is split across several CSV
# files sharing a common prefix. os.listdir returns entries in arbitrary,
# filesystem-dependent order, so sort the names to make the concatenation
# order (and hence the plotted point order) reproducible.
saved_data_dir = '../../../simulation_data/sensitivity_analysis/'
file_prefix = 'copy_SA_allparam_'
result_files = sorted(f for f in os.listdir(saved_data_dir)
                      if f.startswith(file_prefix))

space_frames = [
    pd.read_csv(saved_data_dir + file,
                header=[0, 1], index_col=[0],
                skipinitialspace=True)
    for file in result_files
]


def _concat_column(frames, group, column):
    """Concatenate ``frame[group][column]`` across all *frames* into one array.

    The list is seeded with an empty float array so dtype promotion matches
    the old incremental ``np.concatenate`` and an empty array is returned when
    no files matched. A single concatenate call per column avoids the
    quadratic copying of growing the array file by file.
    """
    return np.concatenate(
        [np.array([])] + [frame[group][column].values for frame in frames])


Vhalf_range = _concat_column(space_frames, 'param_values', 'Vhalf')
Kmax_range = _concat_column(space_frames, 'param_values', 'Kmax')
Ku_range = _concat_column(space_frames, 'param_values', 'Ku')

RMSError = _concat_column(space_frames, 'RMSE', 'RMSE')
MAError = _concat_column(space_frames, 'MAE', 'MAE')

param_id = _concat_column(space_frames, 'param_id', 'param_id')

def _apd_category_codes(rmse, mae, threshold=100):
    """Return an int colour code per point for the discrete category plot.

    0 where ``rmse < threshold`` ("similar"); otherwise 1 where ``mae > 0``
    ("SD higher"); otherwise 2 ("CS higher"). NaN values fail both
    comparisons and therefore fall into code 2.
    """
    rmse = np.asarray(rmse)
    mae = np.asarray(mae)
    return np.select([rmse < threshold, mae > 0], [0, 1], default=2)


def _hover_data(labels, rmsd, md):
    """Stack per-point hover fields into an (N, 3) object array.

    Row i is ``[label, rmsd, md]`` for point i, which is what the
    ``%{customdata[k]}`` references in the hover templates index into.
    """
    data = np.empty((len(labels), 3), dtype=object)
    data[:, 0] = labels
    data[:, 1] = rmsd
    data[:, 2] = md
    return data


if category:
    # Pin the colour range to the full code range (0 .. n_colours-1) so each
    # code always maps to the same colour. Otherwise plotly rescales to the
    # codes actually present and colours shift when a category is absent.
    code_range = dict(cmin=0, cmax=len(discrete_colors) - 1)

    fig.add_trace(
        go.Scatter3d(
            x=Vhalf_list,
            y=Kmax_list,
            z=Ku_list,
            mode='markers',
            marker_symbol='diamond',
            name='drugs',
            marker=dict(
                color=_apd_category_codes(RMSError_drug, MAError_drug),
                colorscale=discrete_colors,
                **code_range
            )
        )
    )

    fig.add_trace(
        go.Scatter3d(
            x=Vhalf_range,
            y=Kmax_range,
            z=Ku_range,
            mode='markers',
            name='space',
            marker=dict(
                color=_apd_category_codes(RMSError, MAError),
                colorscale=discrete_colors,
                opacity=0.5,
                size=5,
                **code_range
            )
        )
    )
else:
    # Give each RMSD the sign of its mean difference (MD). np.copysign avoids
    # the 0/0 -> NaN (and RuntimeWarning) that x * y / |y| produced when
    # MD == 0; such points now keep a positive RMSD.
    RMSError_drug = np.copysign(RMSError_drug, MAError_drug)
    RMSError_space = np.copysign(RMSError, MAError)

    # Shared colour range for both traces so they use one scale. The
    # nan-aware reductions keep a missing value from making the bounds
    # depend on element order, as the builtin min/max would.
    all_signed = np.concatenate((RMSError_drug, RMSError_space))
    cmin = np.nanmin(all_signed)
    cmax = np.nanmax(all_signed)

    # Sized from the data rather than a hard-coded drug count.
    hovertext = _hover_data(drug_list, RMSError_drug, MAError_drug)
    fig.add_trace(
        go.Scatter3d(
            x=Vhalf_list,
            y=Kmax_list,
            z=Ku_list,
            mode='markers',
            marker_symbol='diamond',
            name='',
            customdata=hovertext,
            hovertemplate='<b>%{customdata[0]}</b> <br>RMSD = %{customdata[1]:.2e} <br>MD = %{customdata[2]:.2e}',
            marker=dict(
                color=RMSError_drug,
                colorscale='Portland',
                colorbar=dict(thickness=20),
                cmin=cmin,
                cmax=cmax
            )
        )
    )

    hovertext = _hover_data(param_id, RMSError_space, MAError)
    fig.add_trace(
        go.Scatter3d(
            x=Vhalf_range,
            y=Kmax_range,
            z=Ku_range,
            mode='markers',
            name='',
            customdata=hovertext,
            hovertemplate='<b>id: %{customdata[0]}</b> <br>RMSD = %{customdata[1]} <br>MD = %{customdata[2]}',
            marker=dict(
                color=RMSError_space,
                colorscale='Portland',
                opacity=0.5,
                size=5,
                colorbar=dict(thickness=20),
                cmin=cmin,
                cmax=cmax
            )
        )
    )

# Axis titles for the three parameters; Kmax and Ku use logarithmic axes
# (dtick=1 on a log axis places a tick at every power of ten).
log_axis = dict(type='log', dtick=1)
fig.update_layout(
    scene=dict(
        xaxis_title='Vhalf',
        yaxis_title='Kmax',
        zaxis_title='Ku',
        yaxis=dict(log_axis),
        zaxis=dict(log_axis),
    )
)

fig.show()