# INIT

In [None]:
########################
# load libraries
library(tidyverse)
library(lubridate)
library(RColorBrewer)
library(repr)
options(repr.plot.width=6, repr.plot.height=6)

########################
### MULTIPLOT
# SOURCE: http://www.cookbook-r.com/Graphs/Multiple_graphs_on_one_page_(ggplot2)/
# Draw several plot objects on one page using a grid layout.
#
# Args:
#   ...      : plot objects (e.g. ggplot objects) to draw.
#   plotlist : optional list of further plot objects, drawn after those in `...`.
#   file     : unused; kept only so existing calls that pass it keep working.
#   cols     : number of columns when `layout` is NULL; plots fill the layout
#              matrix column by column (matrix() default).
#   layout   : optional matrix whose cells hold plot indices; plot i is drawn
#              in the cells equal to i (it may span several cells).
#
# Called for its drawing side effect. With no plots at all it draws nothing and
# returns invisible NULL.
multiplot <- function(..., plotlist = NULL, file, cols = 1, layout = NULL) {
  library(grid)

  # Make a list from the ... arguments and plotlist
  plots <- c(list(...), plotlist)
  numPlots <- length(plots)

  # Nothing to draw: without this guard the default layout is built from
  # seq(1, 0) and the loop below indexes plots[[1]] on an empty list.
  if (numPlots == 0) {
    return(invisible(NULL))
  }

  # If layout is NULL, use 'cols' to build a cols-wide panel with enough rows
  if (is.null(layout)) {
    nrows <- ceiling(numPlots / cols)
    layout <- matrix(seq_len(cols * nrows), ncol = cols, nrow = nrows)
  }

  if (numPlots == 1) {
    print(plots[[1]])
  } else {
    # Set up the page
    grid.newpage()
    pushViewport(viewport(layout = grid.layout(nrow(layout), ncol(layout))))

    # Draw each plot in the layout cell(s) that carry its index
    for (i in seq_len(numPlots)) {
      matchidx <- as.data.frame(which(layout == i, arr.ind = TRUE))
      print(plots[[i]], vp = viewport(layout.pos.row = matchidx$row,
                                      layout.pos.col = matchidx$col))
    }
  }
}
print("done")

# SETTINGS

In [None]:
#######################
# Select Experiment
exp_name <- "FINAL"

# need the grid number for the config file, could also be an interim model from this exp_grid
exp_grid <- 1

# OPTIONAL: evaluate a continued run (exp_name gets "_<continued_grid>_continued")
continued <- FALSE
continued_grid <- 0

########################
### SAVE PLOTS TO DISK
# NOTE(review): the original comment on SAVE_PDF said "save to TIFF" while the
# flag is named SAVE_PDF; confirm which format the save code actually writes.
SAVE_PDF <- FALSE
SAVE_PNG <- TRUE        # save to PNG

###################
# load experiment
orig_exp_name <- exp_name
if (isTRUE(continued)) {
  exp_name <- paste0(exp_name, "_", continued_grid, "_continued")
  print(paste0("CONTINUED experiment name: ", exp_name))
  # NOTE(review): this prints continued_grid, but exp_data_path below is built
  # with exp_grid; confirm which grid number the continued model dir uses.
  print(paste0("exp_grid: ", continued_grid))
} else {
  print(paste0("Experiment name: ", exp_name))
  print(paste0("exp_grid: ", exp_grid))
}

########################
### set working directory to "SEPSIS"
# Set the SEPSIS_LOCATION environment variable to run from another checkout,
# e.g. "D:/ResearchData/rl_sepsis/SEPSIS"; otherwise the original path is used.
location <- Sys.getenv("SEPSIS_LOCATION", unset = "/Users/luca/Projects/rl_sepsis/SEPSIS")
setwd(location)
print(getwd())

########################
### SET PATHs (all end with a trailing "/" so file names can be appended)
data_path <- paste0(location, "/experiments/", orig_exp_name, "/data/")
exp_fig_path <- paste0(location, "/experiments/", orig_exp_name, "/figures/")
exp_data_path <- paste0(location, "/experiments/", orig_exp_name, "/models/", exp_name, "_", exp_grid, "/")
print(data_path)
print(exp_fig_path)
print(exp_data_path)

# Import data and preprocessing

In [None]:
########################
### Get DATA
# Column types shared by the three patient split files: every column is a
# double except the integer ids/labels listed here and the two timestamps.
split_col_types <- cols(
  .default = col_double(),
  PatientID = col_integer(),
  interval_start_time = col_datetime(format = ""),
  interval_end_time = col_datetime(format = ""),
  Discharge = col_integer(),
  discrete_action = col_integer(),
  Reward = col_integer(),
  discrete_action_original = col_integer(),
  row_id = col_integer(),
  row_id_next = col_integer()
)

# Column types shared by the three DQN Q-value files: doubles apart from the
# integer best_action and state_id.
qvalue_col_types <- cols(
  .default = col_double(),
  best_action = col_integer(),
  state_id = col_integer()
)

train_data <- read_csv(paste0(data_path, "train_data.csv"), col_types = split_col_types)
val_data <- read_csv(paste0(data_path, "val_data.csv"), col_types = split_col_types)
test_data <- read_csv(paste0(data_path, "test_data.csv"), col_types = split_col_types)

train_opt <- read_csv(paste0(exp_data_path, "DQN_Qvalues_traindata.csv"), col_types = qvalue_col_types)
val_opt <- read_csv(paste0(exp_data_path, "DQN_Qvalues_valdata.csv"), col_types = qvalue_col_types)
test_opt <- read_csv(paste0(exp_data_path, "DQN_Qvalues_testdata.csv"), col_types = qvalue_col_types)

########################
# Action distribution
# 25 real actions on a 5 x 5 grid of dose bins. expand.grid varies its first
# argument fastest, so real action a has discrete_IV = a %% 5 and
# discrete_VP = a %/% 5.
dose_bins <- c(0, 1, 2, 3, 4)
action_mappings <- expand.grid(discrete_IV = dose_bins, discrete_VP = dose_bins) %>%
  mutate(real_discrete_action = 0:24)

# Real action -> model action: model_actions[a + 1]. One row per VP bin; the
# real actions with IV bin 0 and VP bin > 0 (5, 10, 15, 20) have no model action.
model_actions <- c(
   0,  1,  2,  3,  4,
  NA,  5,  6,  7,  8,
  NA,  9, 10, 11, 12,
  NA, 13, 14, 15, 16,
  NA, 17, 18, 19, 20
)

# Model action -> real action: real_actions_ind[m + 1] (the inverse lookup).
real_actions_ind <- c(
   0,  1,  2,  3,  4,
       6,  7,  8,  9,
      11, 12, 13, 14,
      16, 17, 18, 19,
      21, 22, 23, 24
)

action_mappings <- action_mappings %>%
  mutate(model_action = model_actions[real_discrete_action + 1])
DQN_action_mappings <- action_mappings %>%
  rename(DQN_discrete_action = real_discrete_action)

# Apply mapping: translate 0-based model action ids into real action ids
to_real_action <- function(model_action) real_actions_ind[model_action + 1]

train_data <- train_data %>% mutate(real_discrete_action = to_real_action(discrete_action))
val_data <- val_data %>% mutate(real_discrete_action = to_real_action(discrete_action))
test_data <- test_data %>% mutate(real_discrete_action = to_real_action(discrete_action))
train_opt <- train_opt %>% mutate(DQN_discrete_action = to_real_action(best_action))
val_opt <- val_opt %>% mutate(DQN_discrete_action = to_real_action(best_action))
test_opt <- test_opt %>% mutate(DQN_discrete_action = to_real_action(best_action))

########################
# Attach the DQN output to a split and derive per-patient columns.
# Rows are matched on state_id (the split's row_id renamed). Per PatientID:
#   sum_reward    - total Reward over the patient's rows
#   relative_time - difftime in hours since the patient's earliest
#                   interval_start_time
#   discharge     - 1 if sum_reward > 0, 0 if < 0, otherwise 9 (zero or NA)
join_policy_outcomes <- function(split_data, policy_qvalues) {
  split_data %>%
    rename(state_id = row_id) %>%
    left_join(policy_qvalues, by = "state_id") %>%
    group_by(PatientID) %>%
    mutate(
      sum_reward = sum(Reward),
      relative_time = difftime(interval_start_time, min(interval_start_time), units = "hours"),
      # sum_reward is constant within a patient, so test its sign directly.
      # Summing it again (n * total) is redundant and can overflow R's 32-bit
      # integers, which would yield NA and silently map the patient to 9.
      discharge = case_when(sum_reward > 0 ~ 1, sum_reward < 0 ~ 0, TRUE ~ 9)
    ) %>%
    ungroup()
}

train <- join_policy_outcomes(train_data, train_opt)
val <- join_policy_outcomes(val_data, val_opt)
test <- join_policy_outcomes(test_data, test_opt)

########################
# Dose-bin view of each split.
# action_mappings and real_actions_ind come from the "Action distribution" step
# above; they are unchanged since then, so they are not redefined here.

# Lookup tables from a real action id to its IV / vasopressor dose bins, with
# columns prefixed per actor so both can be joined onto the same rows.
PHY_mapping <- action_mappings %>%
  transmute(PHY_IV = discrete_IV, PHY_VP = discrete_VP, PHY_action = real_discrete_action)
DQN_mapping <- action_mappings %>%
  transmute(DQN_IV = discrete_IV, DQN_VP = discrete_VP, DQN_action = real_discrete_action)

# Adds PHY_action / DQN_action (real action ids of the physician's and the
# DQN's chosen actions), their dose bins, and VP_diff / IV_diff (DQN minus
# physician).
# left_join keeps exactly the input rows. A full_join would also append one
# placeholder row per action id absent from the data, with NA relative_time and
# NA bins for the other actor, which shows up as an extra NA time point in the
# per-time summaries.
# The result stays grouped by (DQN_IV, DQN_VP, DQN_action), matching the
# previous output, in case downstream code relies on that grouping.
add_dose_columns <- function(split_df) {
  split_df %>%
    mutate(
      PHY_action = real_actions_ind[discrete_action + 1],
      DQN_action = real_actions_ind[best_action + 1]
    ) %>%
    left_join(PHY_mapping, by = "PHY_action") %>%
    left_join(DQN_mapping, by = "DQN_action") %>%
    mutate(VP_diff = DQN_VP - PHY_VP, IV_diff = DQN_IV - PHY_IV) %>%
    group_by(DQN_IV, DQN_VP, DQN_action)
}

dose_train <- add_dose_columns(train)
dose_val <- add_dose_columns(val)
dose_test <- add_dose_columns(test)

########################
### preprocess data
# Adds 20 indicator columns <DOSE>_OnOFF_<k> for k in 0..4 and DOSE in
# PHY_VP, PHY_IV, DQN_VP, DQN_IV. Each column is 1 when that dose bin equals k,
# 0 when it differs, and NA when the bin itself is NA. Columns are appended bin
# by bin in the order PHY_VP, PHY_IV, DQN_VP, DQN_IV; grouping is left as is.
add_onoff_indicators <- function(dose_df) {
  for (bin in 0:4) {
    for (dose_col in c("PHY_VP", "PHY_IV", "DQN_VP", "DQN_IV")) {
      indicator_col <- paste0(dose_col, "_OnOFF_", bin)
      dose_df <- dose_df %>%
        mutate(!!indicator_col := as.numeric(.data[[dose_col]] == bin))
    }
  }
  dose_df
}

dose_time_train <- add_onoff_indicators(dose_train)
dose_time_val <- add_onoff_indicators(dose_val)
dose_time_test <- add_onoff_indicators(dose_test)

########################
### SUMMARISE DATA for Physician and DQN
# Percentage of rows falling in each dose bin (0..4) at every relative_time,
# for one actor ("PHY" or "DQN") and one drug ("IV" or "VP"). Reads the
# <actor>_<drug>_OnOFF_<k> indicators and returns an ungrouped table with
# relative_time and <actor>_prop_<drug>_0 ... <actor>_prop_<drug>_4.
summarise_dose_props <- function(dose_time_df, actor, drug) {
  bins <- 0:4
  indicator_cols <- paste0(actor, "_", drug, "_OnOFF_", bins)
  prop_cols <- paste0(actor, "_prop_", drug, "_", bins)

  dose_time_df %>%
    group_by(relative_time) %>%
    summarise(across(all_of(indicator_cols), ~ mean(.x) * 100)) %>%
    rename_with(~ prop_cols[match(.x, indicator_cols)], all_of(indicator_cols)) %>%
    ungroup()
}

# Physician
PHY_IV_time_train_melt <- summarise_dose_props(dose_time_train, "PHY", "IV")
PHY_IV_time_val_melt <- summarise_dose_props(dose_time_val, "PHY", "IV")
PHY_IV_time_test_melt <- summarise_dose_props(dose_time_test, "PHY", "IV")

PHY_VP_time_train_melt <- summarise_dose_props(dose_time_train, "PHY", "VP")
PHY_VP_time_val_melt <- summarise_dose_props(dose_time_val, "PHY", "VP")
PHY_VP_time_test_melt <- summarise_dose_props(dose_time_test, "PHY", "VP")

# DQN
DQN_IV_time_train_melt <- summarise_dose_props(dose_time_train, "DQN", "IV")
DQN_IV_time_val_melt <- summarise_dose_props(dose_time_val, "DQN", "IV")
DQN_IV_time_test_melt <- summarise_dose_props(dose_time_test, "DQN", "IV")

DQN_VP_time_train_melt <- summarise_dose_props(dose_time_train, "DQN", "VP")
DQN_VP_time_val_melt <- summarise_dose_props(dose_time_val, "DQN", "VP")
DQN_VP_time_test_melt <- summarise_dose_props(dose_time_test, "DQN", "VP")

print("Data preprocessing done")

# Action matrix Plots

In [None]:
########################
options(repr.plot.width=10, repr.plot.height=4)
action_max = 44000

# action matrix Physicians
train_PHY_actionmatrix <- train %>% full_join(action_mappings, by = "real_discrete_action") %>% group_by(discrete_IV, discrete_VP, real_discrete_action) %>% summarise(action_count = n())
AM_TRAIN_PHY = ggplot(train_PHY_actionmatrix, aes(discrete_VP, discrete_IV)) +
    geom_raster(aes(fill = action_count)) + 
    geom_text(aes(label = round(action_count - 1, 3))) + 
    theme_minimal() +
    scale_fill_gradient(low = "white", high = "blue", name = 'Action Count') + #, limit=c(0,action_max)) + 
    xlab('Maximum Vasopressor Dosage') + 
    ylab('Mean IV Fluids') + 
    ggtitle(label='MIMIC training dataset - Physician action matrix')

# action matrix DQN
train_DQN_actionmatrix <- train %>% full_join(DQN_action_mappings, by = "DQN_discrete_action") %>% group_by(discrete_IV, discrete_VP, DQN_discrete_action) %>% summarise(action_count = n())
AM_TRAIN_DQN = ggplot(train_DQN_actionmatrix, aes(discrete_VP, discrete_IV)) +
    geom_raster(aes(fill = action_count)) + 
    geom_text(aes(label = round(action_count - 1, 3))) + 
    theme_minimal() +
    scale_fill_gradient(low = "white", high = "blue", name = 'Action Count') + #, limit=c(0,action_max)) + 
    xlab('Maximum Vasopressor Dosage') + 
    ylab('Mean IV Fluids') + 
    ggtitle(label='MIMIC training dataset - Optimal policy action matrix')

# action matrix Physicians
val_PHY_actionmatrix <- val %>% full_join(action_mappings, by = "real_discrete_action") %>% group_by(discrete_IV, discrete_VP, real_discrete_action) %>% summarise(action_count = n())
AM_VAL_PHY = ggplot(val_PHY_actionmatrix, aes(discrete_VP, discrete_IV)) +
    geom_raster(aes(fill = action_count)) + 
    geom_text(aes(label = round(action_count - 1, 3))) + 
    theme_minimal() +
    scale_fill_gradient(low = "white", high = "blue", name = 'Action Count') + #, limit=c(0,action_max)) + 
    xlab('Maximum Vasopressor Dosage') + 
    ylab('Mean IV Fluids') + 
    ggtitle(label='MIMIC validation dataset - Physician action matrix')

# action matrix DQN
val_DQN_actionmatrix <- val %>% full_join(DQN_action_mappings, by = "DQN_discrete_action") %>% group_by(discrete_IV, discrete_VP, DQN_discrete_action) %>% summarise(action_count = n())
AM_VAL_DQN = ggplot(val_DQN_actionmatrix, aes(discrete_VP, discrete_IV)) +
    geom_raster(aes(fill = action_count)) + 
    geom_text(aes(label = round(action_count - 1, 3))) + 
    theme_minimal() +
    scale_fill_gradient(low = "white", high = "blue", name = 'Action Count') + #, limit=c(0,action_max)) + 
    xlab('Maximum Vasopressor Dosage') + 
    ylab('Mean IV Fluids') + 
    ggtitle(label='MIMIC validation dataset - Optimal policy action matrix')

#################################################
### ACTION MATRICES - AmsterdamUMCdb (test) and filtered subsets

# Count how often each discrete action occurs in `df`, one row per action cell
# of `mappings`.
#   df         : per-timestep rows carrying the action column `action_col`
#   mappings   : action lookup table joined on `action_col`; supplies the
#                discrete_IV and discrete_VP columns
#   action_col : join column name ("real_discrete_action" or
#                "DQN_discrete_action")
# Returns discrete_IV, discrete_VP, <action_col>, action_count, still grouped
# by discrete_IV and discrete_VP.
# Rows of `df` are flagged before the full join. An action that never occurs
# in `df` survives the join as one NA-filled row, which n() would count as 1;
# counting only flagged rows gives 0 there and the true count everywhere else.
count_actions <- function(df, mappings, action_col) {
  df %>%
    mutate(from_df = TRUE) %>%
    full_join(mappings, by = action_col) %>%
    group_by(discrete_IV, discrete_VP, .data[[action_col]]) %>%
    summarise(action_count = sum(!is.na(from_df)), .groups = "drop_last")
}

# Heat map of action counts (vasopressor level on x, IV-fluid level on y),
# each cell labelled with its count.
plot_action_matrix <- function(counts, title, subtitle = waiver()) {
  ggplot(counts, aes(discrete_VP, discrete_IV)) +
    geom_raster(aes(fill = action_count)) +
    geom_text(aes(label = action_count)) +
    theme_minimal() +
    scale_fill_gradient(low = "white", high = "blue", name = "Action Count") +
    xlab("Maximum Vasopressor Dosage") +
    ylab("Mean IV Fluids") +
    ggtitle(label = title, subtitle = subtitle)
}

# action matrix Physicians (test)
test_PHY_actionmatrix <- count_actions(test, action_mappings, "real_discrete_action")
AM_TEST_PHY <- plot_action_matrix(test_PHY_actionmatrix,
                                  "AmsterdamUMCdb dataset - Physician action matrix")

# action matrix DQN (test)
test_DQN_actionmatrix <- count_actions(test, DQN_action_mappings, "DQN_discrete_action")
AM_TEST_DQN <- plot_action_matrix(test_DQN_actionmatrix,
                                  "AmsterdamUMCdb dataset - Optimal policy action matrix")


#################################################
### FILTERED

# MIMIC: keep relative_time >= 24. Presumably sepsis onset sits at
# relative_time 24 (other cells plot relative_time - 24) - TODO confirm.
filt_train_PHY_actionmatrix <- count_actions(filter(train, relative_time >= 24),
                                             action_mappings, "real_discrete_action")
filt_AM_TRAIN_PHY <- plot_action_matrix(filt_train_PHY_actionmatrix,
                                        "MIMIC training dataset - Physician action matrix",
                                        "First 48H after sepsis onset")

filt_train_DQN_actionmatrix <- count_actions(filter(train, relative_time >= 24),
                                             DQN_action_mappings, "DQN_discrete_action")
filt_AM_TRAIN_DQN <- plot_action_matrix(filt_train_DQN_actionmatrix,
                                        "MIMIC training dataset - Optimal policy action matrix",
                                        "First 48H after sepsis onset")

filt_val_PHY_actionmatrix <- count_actions(filter(val, relative_time >= 24),
                                           action_mappings, "real_discrete_action")
filt_AM_VAL_PHY <- plot_action_matrix(filt_val_PHY_actionmatrix,
                                      "MIMIC validation dataset - Physician action matrix",
                                      "First 48H after sepsis onset")

filt_val_DQN_actionmatrix <- count_actions(filter(val, relative_time >= 24),
                                           DQN_action_mappings, "DQN_discrete_action")
filt_AM_VAL_DQN <- plot_action_matrix(filt_val_DQN_actionmatrix,
                                      "MIMIC validation dataset - Optimal policy action matrix",
                                      "First 48H after sepsis onset")

# AmsterdamUMCdb: keep relative_time <= 60.
# NOTE(review): the subtitle says "First 48H of admission" but the cut-off is
# 60 - confirm which one is intended.
filt_test_PHY_actionmatrix <- count_actions(filter(test, relative_time <= 60),
                                            action_mappings, "real_discrete_action")
filt_AM_TEST_PHY <- plot_action_matrix(filt_test_PHY_actionmatrix,
                                       "AmsterdamUMCdb - Physician action matrix",
                                       "First 48H of admission")

filt_test_DQN_actionmatrix <- count_actions(filter(test, relative_time <= 60),
                                            DQN_action_mappings, "DQN_discrete_action")
filt_AM_TEST_DQN <- plot_action_matrix(filt_test_DQN_actionmatrix,
                                       "AmsterdamUMCdb - Optimal policy action matrix",
                                       "First 48H of admission")

# Draw `plots` with multiplot() into `filename` through the bitmap `device`
# (png or tiff). on.exit() closes the device even if drawing fails.
save_multiplot <- function(device, filename, plots, width, height, res, cols = 3) {
  device(filename = filename, width = width, height = height,
         units = "in", res = res, pointsize = 6)
  on.exit(dev.off(), add = TRUE)
  suppressWarnings(multiplot(plotlist = plots, cols = cols))
  invisible(filename)
}

# multiplot() fills its grid column by column. With cols = 3 each column is one
# dataset (train, validation, test): physician matrix on top, optimal policy
# below.
action_matrix_plots <- list(AM_TRAIN_PHY, AM_TRAIN_DQN,
                            AM_VAL_PHY, AM_VAL_DQN,
                            AM_TEST_PHY, AM_TEST_DQN)
filt_action_matrix_plots <- list(filt_AM_TRAIN_PHY, filt_AM_TRAIN_DQN,
                                 filt_AM_VAL_PHY, filt_AM_VAL_DQN,
                                 filt_AM_TEST_PHY, filt_AM_TEST_DQN)

# File names use paste0(): paste() would put a space between exp_fig_path and
# the file name.
if (isTRUE(SAVE_PNG)) {
  ### SAVE ACTION MATRIX PNGs
  save_multiplot(png, paste0(exp_fig_path, "Multiplot_ActionMatrix.png"),
                 action_matrix_plots, width = 18, height = 8, res = 400)
  save_multiplot(png, paste0(exp_fig_path, "Multiplot_ActionMatrix_FILT.png"),
                 filt_action_matrix_plots, width = 18, height = 8, res = 400)
}
if (isTRUE(SAVE_PDF)) {
  ### SAVE ACTION MATRIX TIFFs (the SAVE_PDF flag writes TIFF files)
  save_multiplot(tiff, paste0(exp_fig_path, "Multiplot_ActionMatrix.tiff"),
                 action_matrix_plots, width = 18, height = 8, res = 200)
  save_multiplot(tiff, paste0(exp_fig_path, "Multiplot_ActionMatrix_FILT.tiff"),
                 filt_action_matrix_plots, width = 18, height = 8, res = 200)
}

########################
### SHOW
options(repr.plot.width = 18, repr.plot.height = 8)
suppressWarnings(multiplot(plotlist = action_matrix_plots, cols = 3))
suppressWarnings(multiplot(plotlist = filt_action_matrix_plots, cols = 3))

# Q Value - % Survival Calibration plots

In [None]:
### TO DO: looks like cumulative lineplot (change aesthetics)

# X-axis range constants.
# NOTE(review): not used in this cell (the plots below hard-code xlim(-5, 10)).
xlim_min <- -15
xlim_max <- 15

# Approximately floor `x` to `level` decimal places by shifting down half a
# step before round(); level = 1 gives 0.1-wide Q-value bins.
floor_dec <- function(x, level = 1) round(x - 5 * 10^(-level - 1), level)

# Bin one dataset's physician-action Q values and summarise each bin:
#   prop_dead : mean(1 - Discharge)
#   sd_dead   : standard error of Discharge within the bin
#   count     : number of rows in the bin
# count is computed with n() in the same summarise() so it always lines up with
# bin_Q. A separate table() would drop an NA bin and rely on matching order.
phy_q_calibration <- function(df) {
  df %>%
    mutate(bin_Q = floor_dec(phy_action_Qvalue)) %>%
    group_by(bin_Q) %>%
    summarise(
      prop_dead = mean(1 - Discharge),
      sd_dead = sd(Discharge) / sqrt(n()),
      count = n()
    )
}

# Calibration panel: per-bin prop_dead (%) as a faint line plus a LOESS trend.
# Bin counts are drawn as bars divided by `bar_scale` and read off a secondary
# axis that multiplies back by `bar_scale`.
# NOTE(review): y is prop_dead * 100 while the first panel's axis title reads
# '% Patient survival' - confirm which one is intended.
phy_q_calibration_plot <- function(plot_df, subtitle, bar_scale,
                                   y_title, count_axis_title) {
  ggplot(plot_df, aes(x = bin_Q, y = prop_dead * 100)) +
    geom_line(alpha = 0.3) +
    geom_smooth(span = 0.4, method = "loess", formula = y ~ x, se = FALSE) +
    geom_col(aes(y = count / bar_scale)) +
    scale_y_continuous(sec.axis = sec_axis(~ . * bar_scale, name = count_axis_title)) +
    xlim(-5, 10) +
    theme_bw() +
    ggtitle(subtitle = subtitle, label = "Physician Q value - mortality calibration") +
    ylab(y_title) +
    xlab("Q value")
}

PHYQ_train_plot_df <- phy_q_calibration(train)
PHYQ_val_plot_df <- phy_q_calibration(val)
PHYQ_test_plot_df <- phy_q_calibration(test)

PHYQ_train_survival <- phy_q_calibration_plot(PHYQ_train_plot_df, "MIMIC training dataset",
                                              bar_scale = 60, y_title = "% Patient survival",
                                              count_axis_title = "")
PHYQ_val_survival <- phy_q_calibration_plot(PHYQ_val_plot_df, "MIMIC validation dataset",
                                            bar_scale = 30, y_title = "",
                                            count_axis_title = "")
PHYQ_test_survival <- phy_q_calibration_plot(PHYQ_test_plot_df, "AmsterdamUMCdb dataset",
                                             bar_scale = 30, y_title = "",
                                             count_axis_title = "Q Value count")

########################
# Draw `plots` with multiplot() into `filename` through the bitmap `device`
# (png or tiff). on.exit() closes the device even if drawing fails.
save_multiplot <- function(device, filename, plots, width, height, res, cols = 3) {
  device(filename = filename, width = width, height = height,
         units = "in", res = res, pointsize = 6)
  on.exit(dev.off(), add = TRUE)
  suppressWarnings(multiplot(plotlist = plots, cols = cols))
  invisible(filename)
}

calibration_plots <- list(PHYQ_train_survival, PHYQ_val_survival, PHYQ_test_survival)

# File names use paste0(): paste() would put a space between exp_fig_path and
# the file name.
if (isTRUE(SAVE_PNG)) {
  ### PNG
  save_multiplot(png, paste0(exp_fig_path, "Multiplot_Calibration_PHY_Qvalue.png"),
                 calibration_plots, width = 12, height = 5, res = 400)
}
if (isTRUE(SAVE_PDF)) {
  ### SAVE TIFF (the SAVE_PDF flag writes TIFF files)
  save_multiplot(tiff, paste0(exp_fig_path, "Multiplot_Calibration_PHY_Qvalue.tiff"),
                 calibration_plots, width = 12, height = 5, res = 200)
  print("done")
}

########################
### SHOW
options(repr.plot.width = 15, repr.plot.height = 5)
suppressWarnings(multiplot(plotlist = calibration_plots, cols = 3))

# reset
options(repr.plot.width = 5, repr.plot.height = 5)
par(mfrow = c(1, 1))

# Dose initiation

In [None]:
########################
### Percentage of rows with each treatment "on", per relative_time value, for
### physician (PHY_*) and optimal-policy (DQN_*) doses, in long key/value form.
# Each dose is mapped to 1 ("on") or 0 ("off") by the predicates supplied.
# Vasopressors are always on for > 0 and off for == 0. A value matching
# neither predicate (e.g. NA) becomes NA, which makes that time point's mean
# NA; complete.cases() then drops that time point before reshaping.
# For the AmsterdamUMCdb test set, IV fluids count as "on" from level 3 upward.
# NOTE(review): an earlier note described the test threshold as IV > mode(IV);
# the code hard-codes >= 3 - confirm the intended cut-off.
dose_onoff_long <- function(dose_df, iv_is_on, iv_is_off) {
  as_flag <- function(x, is_on, is_off) case_when(is_on(x) ~ 1, is_off(x) ~ 0)
  vp_is_on <- function(x) x > 0
  vp_is_off <- function(x) x == 0

  per_time <- dose_df %>%
    mutate(
      PHY_VP_OnOFF = as_flag(PHY_VP, vp_is_on, vp_is_off),
      PHY_IV_OnOFF = as_flag(PHY_IV, iv_is_on, iv_is_off),
      DQN_VP_OnOFF = as_flag(DQN_VP, vp_is_on, vp_is_off),
      DQN_IV_OnOFF = as_flag(DQN_IV, iv_is_on, iv_is_off)
    ) %>%
    group_by(relative_time) %>%
    summarise(
      PHY_prop_VP = mean(PHY_VP_OnOFF) * 100,
      PHY_prop_IV = mean(PHY_IV_OnOFF) * 100,
      DQN_prop_VP = mean(DQN_VP_OnOFF) * 100,
      DQN_prop_IV = mean(DQN_IV_OnOFF) * 100
    ) %>%
    ungroup()

  per_time[complete.cases(per_time), ] %>% gather(key, value, 2:5)
}

NOdose_time_train <- dose_onoff_long(dose_train,
                                     iv_is_on = function(x) x > 0,
                                     iv_is_off = function(x) x == 0)
NOdose_time_val <- dose_onoff_long(dose_val,
                                   iv_is_on = function(x) x > 0,
                                   iv_is_off = function(x) x == 0)
NOdose_time_test <- dose_onoff_long(dose_test,
                                    iv_is_on = function(x) x >= 3,
                                    iv_is_off = function(x) x < 3)

########################
### CREATE PLOTS
# One colour per long-format key. levels(as.factor(key)) sorts the keys
# (DQN_prop_IV, DQN_prop_VP, PHY_prop_IV, PHY_prop_VP), and both the colours
# and the legend labels below follow that order.
myColors <- brewer.pal(9, "Greens")
myColors2 <- brewer.pal(9, "Reds")
myColors3 <- c(myColors[8], myColors2[8], myColors[3], myColors2[3])
names(myColors3) <- levels(as.factor(NOdose_time_train$key))
colScale <- scale_colour_manual(
  name = "Actions:", values = myColors3,
  labels = c("Optimal policy Fluids", "Optimal policy Vasopressors",
             "Physician Fluids ", "Physician Vasopressors")
)

# LOESS-smoothed percentage of patients on each treatment over time.
#   dose_long : long table (relative_time, key, value) from the dose cell
#   x_offset  : subtracted from relative_time (24 for MIMIC, 0 for
#               AmsterdamUMCdb)
#   subtitle  : dataset name shown under the title
#   y_title   : y-axis title ("" where the panel shares the first one's axis)
#   y_max     : upper y limit
# The x scale is set once; adding a second scale_x_continuous() would replace
# the first and emit a "Scale for x is already present" message.
dose_time_plot <- function(dose_long, x_offset, subtitle, y_title, y_max) {
  ggplot(dose_long, aes(relative_time - x_offset, value, color = as.factor(key))) +
    geom_smooth(method = "loess", se = TRUE, span = 0.4) +
    colScale +
    xlab("Relative Hours") +
    ylab(y_title) +
    ylim(0, y_max) +
    ggtitle(subtitle = subtitle, label = "Treatment progression") +
    theme_bw() +
    scale_x_continuous(breaks = c(-24, -12, 0, 12, 24, 36, 48, 60, 72),
                       limits = c(-24, 72))
}

dose_time_train_plot <- dose_time_plot(NOdose_time_train, x_offset = 24,
                                       subtitle = "MIMIC Training dataset",
                                       y_title = "% patients on treatment", y_max = 100)
dose_time_val_plot <- dose_time_plot(NOdose_time_val, x_offset = 24,
                                     subtitle = "MIMIC Validation dataset",
                                     y_title = "", y_max = 100)
dose_time_test_plot <- dose_time_plot(NOdose_time_test, x_offset = 0,
                                      subtitle = "AmsterdamUMCdb dataset",
                                      y_title = "", y_max = 105)

library(ggpubr)
library(grid)
library(gridExtra)

# Three dose-progression panels side by side, sharing one legend at the bottom.
p1 <- dose_time_train_plot
p2 <- dose_time_val_plot
p3 <- dose_time_test_plot
dose_time_arranged <- ggarrange(p1, p2, p3, ncol = 3,
                                common.legend = TRUE, legend = "bottom")

# A ggarrange() result is only drawn when printed. Printing explicitly means
# the files no longer depend on top-level auto-printing, so the saves can be
# gated on SAVE_PNG / SAVE_PDF like the other cells. `finally` closes the
# device even if drawing fails. File names use paste0(), not paste(), so no
# space is inserted after exp_fig_path.
if (isTRUE(SAVE_PNG)) {
  ### PNG
  png(filename = paste0(exp_fig_path, "Multiplot_IV_VP_ANY_DoseDiff.png"),
      width = 18, height = 5, units = "in", res = 400, pointsize = 6)
  tryCatch(print(dose_time_arranged), finally = dev.off())
}
if (isTRUE(SAVE_PDF)) {
  ### TIFF (the SAVE_PDF flag writes TIFF files)
  tiff(filename = paste0(exp_fig_path, "Multiplot_IV_VP_ANY_DoseDiff.tiff"),
       width = 18, height = 5, units = "in", res = 200, pointsize = 6)
  tryCatch(print(dose_time_arranged), finally = dev.off())
}

########################
### SHOW PLOTS
options(repr.plot.width = 18, repr.plot.height = 5)
dose_time_arranged


# Area over time plots

In [None]:
### Stacked-area plots: share of patients at each discrete dose level over
### time, per dataset, for physician (PHY) and optimal-policy (DQN) actions.
# Each *_time_*_melt table is used with relative_time in column 1 and the five
# dose-level columns in 2:6 (assumed from the gather(2:6) calls - TODO confirm
# against the cell that builds them).

# Drop incomplete time points, then reshape columns 2:6 into key/value pairs.
dose_levels_long <- function(melt_df) {
  melt_df[complete.cases(melt_df), ] %>% gather(key, value, 2:6)
}

# Name a palette by the sorted dose-level keys of `melt_df`, using the same
# columns (2:6) that are plotted, so scale_fill_manual() matches by name.
palette_for_keys <- function(colours, melt_df) {
  names(colours) <- levels(as.factor(dose_levels_long(melt_df)$key))
  colours
}

# One stacked-area panel of LOESS-smoothed dose-level shares.
#   x_offset    : subtracted from relative_time (24 for MIMIC, 0 for
#                 AmsterdamUMCdb)
#   show_legend : FALSE hides the legend
#   y_limits    : optional ylim(), e.g. c(0, 101)
action_area_plot <- function(melt_df, colours, labels, title, subtitle,
                             x_offset, x_breaks, x_limits,
                             show_legend, y_limits = NULL) {
  p <- dose_levels_long(melt_df) %>%
    ggplot(aes(x = relative_time - x_offset, y = value, fill = key, group = key)) +
    stat_smooth(geom = "area", method = "loess", span = 0.29,
                position = position_stack(reverse = FALSE)) +
    scale_fill_manual(name = "Actions", values = colours, labels = labels) +
    xlab("Relative Hours") +
    ylab("% patients") +
    ggtitle(subtitle = subtitle, label = title) +
    theme(panel.border = element_blank(), panel.grid.major = element_blank(),
          panel.grid.minor = element_blank(), panel.background = element_blank()) +
    scale_x_continuous(breaks = x_breaks, limits = x_limits)
  if (!show_legend) {
    p <- p + theme(legend.position = "none")
  }
  if (!is.null(y_limits)) {
    p <- p + ylim(y_limits)
  }
  p
}

# Train and validation panels (MIMIC, shifted by -24 h, legend hidden) and the
# test panel (AmsterdamUMCdb, unshifted, legend shown) for one action type.
# `mimic_y_limits` applies to the two MIMIC panels only.
action_area_trio <- function(train_melt, val_melt, test_melt, colours, labels,
                             title, mimic_y_limits = NULL) {
  mimic_panel <- function(melt_df, subtitle) {
    action_area_plot(melt_df, colours, labels, title, subtitle,
                     x_offset = 24, x_breaks = c(-24, -12, 0, 12, 24, 36),
                     x_limits = c(-24, 48), show_legend = FALSE,
                     y_limits = mimic_y_limits)
  }
  list(
    train = mimic_panel(train_melt, "MIMIC training dataset"),
    val = mimic_panel(val_melt, "MIMIC validation dataset"),
    test = action_area_plot(test_melt, colours, labels, title,
                            "AmsterdamUMCdb dataset",
                            x_offset = 0, x_breaks = c(0, 12, 24, 36, 48, 60),
                            x_limits = c(0, 74), show_legend = TRUE)
  )
}

### IV
myColorsF <- palette_for_keys(brewer.pal(5, "Greens"), PHY_IV_time_test_melt)
phy_iv_area <- action_area_trio(
  PHY_IV_time_train_melt, PHY_IV_time_val_melt, PHY_IV_time_test_melt,
  myColorsF, c("PHY IV 1", "PHY IV 2", "PHY IV 3 ", "PHY IV 4", "PHY IV 5"),
  "IV fluids physician action distribution"
)
PHY_IV_trainarea <- phy_iv_area$train
PHY_IV_valarea <- phy_iv_area$val
PHY_IV_testarea <- phy_iv_area$test

myColorsG <- palette_for_keys(brewer.pal(5, "Greens"), DQN_IV_time_test_melt)
dqn_iv_area <- action_area_trio(
  DQN_IV_time_train_melt, DQN_IV_time_val_melt, DQN_IV_time_test_melt,
  myColorsG, c("OPT IV 1", "OPT IV 2", "OPT IV 3", "OPT IV 4 ", "OPT IV 5"),
  "IV fluids optimal policy action distribution"
)
DQN_IV_trainarea <- dqn_iv_area$train
DQN_IV_valarea <- dqn_iv_area$val
DQN_IV_testarea <- dqn_iv_area$test

### VP
myColorsH <- palette_for_keys(brewer.pal(5, "Reds"), PHY_VP_time_test_melt)
phy_vp_area <- action_area_trio(
  PHY_VP_time_train_melt, PHY_VP_time_val_melt, PHY_VP_time_test_melt,
  myColorsH, c("PHY VP 1", "PHY VP 2", "PHY VP 3 ", "PHY VP 4", "PHY VP 5"),
  "Vasopressor physician action distribution",
  mimic_y_limits = c(0, 101)
)
PHY_VP_trainarea <- phy_vp_area$train
PHY_VP_valarea <- phy_vp_area$val
PHY_VP_testarea <- phy_vp_area$test

myColorsI <- palette_for_keys(brewer.pal(5, "Reds"), DQN_VP_time_test_melt)
dqn_vp_area <- action_area_trio(
  DQN_VP_time_train_melt, DQN_VP_time_val_melt, DQN_VP_time_test_melt,
  myColorsI, c("OPT VP 1", "OPT VP 2", "OPT VP 3 ", "OPT VP 4", "OPT VP 5"),
  "Vasopressor optimal policy action distribution",
  mimic_y_limits = c(0, 101)
)
DQN_VP_trainarea <- dqn_vp_area$train
DQN_VP_valarea <- dqn_vp_area$val
DQN_VP_testarea <- dqn_vp_area$test


###########################
### LEGENDS
library(ggpubr)

# The AmsterdamUMCdb (test) area plots keep their legends and use exactly the
# data, scales and labels wanted here, so they serve as the legend sources
# instead of rebuilding identical plots.
VP_OPT <- DQN_VP_testarea
VP_PHY <- PHY_VP_testarea
IV_OPT <- DQN_IV_testarea
IV_PHY <- PHY_IV_testarea

# Extract the legends and wrap each one as a standalone ggplot.
VP_OPT_leg <- get_legend(VP_OPT)
VP_OPT_leg_plot <- as_ggplot(VP_OPT_leg) +
  theme(legend.justification = c(0, 0), legend.position = c(0, 0))

VP_PHY_leg <- get_legend(VP_PHY)
VP_PHY_leg_plot <- as_ggplot(VP_PHY_leg) + theme(legend.position = "left")

IV_OPT_leg <- get_legend(IV_OPT)
IV_OPT_leg_plot <- as_ggplot(IV_OPT_leg) + theme(legend.position = "left")

IV_PHY_leg <- get_legend(IV_PHY)
IV_PHY_leg_plot <- as_ggplot(IV_PHY_leg) + theme(legend.position = "left")



########################
### SAVE PLOTS
# Write the four train/val/test area multiplots (3 panels per row) to disk.
#
# open_device: a graphics-device constructor with the png()/tiff() signature.
# ext:         file extension appended to each file stem ("png" or "tiff").
# res:         device resolution in pixels per inch.
# fig_path:    output location prefix; file names are built with paste0(),
#              so this is assumed to already end with a path separator —
#              TODO confirm where exp_fig_path is defined. (The previous
#              paste() call used its default sep = " " and inserted a stray
#              space before every file name.)
#
# Each device is closed in `finally`, so an error inside multiplot() cannot
# leave a dangling open device that silently captures later plots.
save_area_multiplots <- function(open_device, ext, res, fig_path = exp_fig_path) {
  plot_sets <- list(
    Multiplot_PHY_IV_AREA = list(PHY_IV_trainarea, PHY_IV_valarea, PHY_IV_testarea),
    Multiplot_DQN_IV_AREA = list(DQN_IV_trainarea, DQN_IV_valarea, DQN_IV_testarea),
    Multiplot_PHY_VP_AREA = list(PHY_VP_trainarea, PHY_VP_valarea, PHY_VP_testarea),
    Multiplot_DQN_VP_AREA = list(DQN_VP_trainarea, DQN_VP_valarea, DQN_VP_testarea)
  )
  for (stem in names(plot_sets)) {
    open_device(
      filename = paste0(fig_path, stem, ".", ext),
      width = 32, height = 4, units = "in", res = res, pointsize = 6
    )
    tryCatch(
      suppressWarnings(multiplot(plotlist = plot_sets[[stem]], cols = 3)),
      finally = dev.off()
    )
  }
  invisible(NULL)
}

if (isTRUE(SAVE_PNG)) {
  options(repr.plot.width = 32, repr.plot.height = 4)
  save_area_multiplots(png, ext = "png", res = 400)
}
if (isTRUE(SAVE_PDF)) {
  # Despite the flag name, this branch writes TIFF files.
  options(repr.plot.width = 32, repr.plot.height = 4)
  save_area_multiplots(tiff, ext = "tiff", res = 200)
  print("done")
}

########################
### SHOW PLOTS
# Render the four train/val/test rows inline, one three-panel multiplot per
# row: PHY IV, DQN IV, PHY VP, DQN VP (in that order).
options(repr.plot.width = 26, repr.plot.height = 4)
invisible(lapply(
  list(
    list(PHY_IV_trainarea, PHY_IV_valarea, PHY_IV_testarea),
    list(DQN_IV_trainarea, DQN_IV_valarea, DQN_IV_testarea),
    list(PHY_VP_trainarea, PHY_VP_valarea, PHY_VP_testarea),
    list(DQN_VP_trainarea, DQN_VP_valarea, DQN_VP_testarea)
  ),
  function(row_plots) {
    suppressWarnings(multiplot(plotlist = row_plots, cols = 3))
  }
))