This Notebook is used to calculate the summary statistics for both clinical and RL-algorithms.

For RL algorithms: <br>
The results for experiments will be saved under "results/FOLDER_ID/AlgoSub_SEED".<br>
e.g., 'results/adult/G0_1'; G2P2C algorithm for adult subject 0 and seed 1.

In [10]:
import os 
import sys 
from decouple import config 
MAIN_PATH = 'C:/Users/davet/Desktop/Thesis/G2P2C/'
sys.path.insert(1, MAIN_PATH)
from visualiser.core import experiment_error_check
from visualiser.statistics import get_summary_stats, compare_algorithms

# First we quickly check and verify that the RL experiments have properly completed based on the 
# general template/guidelines used. You can customise additional parameters, check func "experiment_error_check".
# experiment_error_check(cohort="adolescent", algorithm='PPO', algoAbbreviation='P', subjects=['0', '2', '6'])
result_path = {'TD3_1':'results/L2Norm_LR1e0/NoCutOff/NoiseApplication/TD3/Model1/NormDist/sigma_2e-1',
               'TD3_2':'results/L2Norm_LR1e-1/NoCutOff/NoiseApplication/TD3/Model1/NormDist/sigma_2e-1',
               'TD3_3':'results/L2Norm_LR1e-2/NoCutOff/NoiseApplication/TD3/Model1/NormDist/sigma_2e-1'}

We calculate statistics (individual/cohort-levels) for identified clinical and RL metrics.
Target_metrics are time in normoglycemia (TIR), hypoglycemia, hyperglycemia, severe hypoglycemia, Severe hyperglycemia, LBGI (Low Blood Glucose Index), HBGI (High Blood Glucose Index), RI (Risk Index), Failure Rate, Reward.

In [11]:
# metrics and stats for RL agents and clinical, for a selected algorthm.
# Check parameters to get more information being calculated.
# Normal flow in the calculation: 
# For each subject, looking at 1500 simulations we calculate the metrics, you can get inter-quartile, mean, std etc.
# Using the mean for each metric we calculate cohort level metrics and statistics 
    
# get_summary_stats(cohort="adolescent", algo_type='rl', algorithm='G2P2C', algoAbbreviation='G', 
#                   metric=['50%', '25%', '75%','mean', 'std'], 
#                   verbose=False, show_res=True, sort=[False, 'hgbi'],
#                   subjects=['0', '2', '6'])
# 
# get_summary_stats(cohort="adolescent", algo_type='clinical', algorithm='BBI', algoAbbreviation='BBI', 
#                   metric=['50%', '25%', '75%','mean', 'std'], 
#                   verbose=False, show_res=True, sort=[False, 'hgbi'],
#                   subjects=['0', '2', '6'])
# patient = '0'
patient = '2'
# patient = '6'
get_summary_stats(cohort="adolescent", algo_type='rl', algorithm='TD3_1', algoAbbreviation='TD3', 
                  metric=['25%', '50%', '75%','mean', 'std'], 
                  verbose=False, show_res=True, sort=[False, 'hgbi'],
                  # subjects=['0', '2', '6'],
                  subjects=[patient],
                  result_path=result_path)
print("="*100)
get_summary_stats(cohort="adolescent", algo_type='rl', algorithm='TD3_2', algoAbbreviation='TD3', 
                  metric=['25%', '50%', '75%','mean', 'std'], 
                  verbose=False, show_res=True, sort=[False, 'hgbi'],
                  # subjects=['0', '2', '6'],
                  subjects=[patient],
                  result_path=result_path)
print("="*100)
get_summary_stats(cohort="adolescent", algo_type='rl', algorithm='TD3_3', algoAbbreviation='TD3', 
                  metric=['25%', '50%', '75%','mean', 'std'], 
                  verbose=False, show_res=True, sort=[False, 'hgbi'],
                  # subjects=['0', '2', '6'],
                  subjects=[patient],
                  result_path=result_path)





Summary statistics for adolescent cohort, TD3_1 Algorithm

Summarised cohort statistics (mean):
    normo   hypo  hyper  S_hypo  S_hyper   lgbi  hgbi     ri  reward   fail
id                                                                         
2   76.85  11.66   0.01   11.47      0.0  10.44  1.24  11.68   45.58  100.0

Averaged cohort statistics:
      normo   hypo  hyper  S_hypo  S_hyper   lgbi  hgbi     ri  reward   fail
25%   72.46   6.90   0.00    7.46      0.0   8.48  0.70   9.87   40.41  100.0
50%   77.14  10.87   0.00   10.81      0.0  10.07  1.11  11.24   45.54  100.0
75%   82.14  15.25   0.00   14.67      0.0  12.31  1.76  13.69   50.18  100.0
mean  76.85  11.66   0.01   11.47      0.0  10.44  1.24  11.68   45.58  100.0
std    6.94   6.18   0.19    5.15      0.0   2.53  0.65   2.74    7.52  100.0

Summary statistics for adolescent cohort, TD3_2 Algorithm

Summarised cohort statistics (mean):
    normo  hypo  hyper  S_hypo  S_hyper  lgbi   hgbi    ri  reward  fail
id      

Unnamed: 0,normo,hypo,hyper,S_hypo,S_hyper,lgbi,hgbi,ri,reward,fail
25%,26.35,0.0,8.93,0.0,0.0,0.0,3.88,10.5,103.6,87.27
50%,53.47,0.76,12.5,1.05,13.19,3.74,14.83,19.16,110.63,87.27
75%,81.98,3.95,19.16,2.88,54.49,7.28,37.62,38.09,156.63,87.27
mean,55.41,3.02,13.61,2.28,25.67,4.51,18.64,23.16,131.9,87.27
std,26.1,5.02,7.03,3.98,25.31,4.97,15.57,13.1,51.87,87.27


In [31]:
# Adolescents
compare_algorithms(cohort="adolescent", 
                   algo_types= ['rl', 'rl', 'rl', 'rl'],
                   algos=['DDPG1', 'DDPG2', 'DDPG3'], 
                   abbreviations=['DDPG', 'DDPG', 'DDPG'],
                   subjects=['0', '2', '6'],
                   result_path=result_path)


Compare algorithm performance for the adolescent cohort
      normo  hypo  hyper  S_hypo  S_hyper  lgbi   hgbi     ri  reward    fail
Algo                                                                         
TD1   24.96  0.00  24.16    0.00    50.89  0.04  33.85  33.89  111.62  100.00
TD2   24.96  0.00  24.16    0.00    50.89  0.04  33.85  33.89  111.62  100.00
TD3   35.22  0.72  22.50    0.39    41.17  1.56  27.94  29.50  141.27   83.89


Unnamed: 0_level_0,normo,hypo,hyper,S_hypo,S_hyper,lgbi,hgbi,ri,reward,fail
Algo,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
TD1,24.96,0.0,24.16,0.0,50.89,0.04,33.85,33.89,111.62,100.0
TD2,24.96,0.0,24.16,0.0,50.89,0.04,33.85,33.89,111.62,100.0
TD3,35.22,0.72,22.5,0.39,41.17,1.56,27.94,29.5,141.27,83.89
