In [None]:
### SETUP

import os
import sys
import json
import numpy as np
import pandas as pd
from decouple import config
MAIN_PATH = config('MAIN_PATH')
sys.path.insert(1, MAIN_PATH)
from visualiser.core import ExperimentVisualise, plot_testing_rewards, plot_training_logs, plot_value_function
from visualiser.core import plot_episode, plot_training_action_summary, plot_training_logs
from visualiser.core import display_commands, plot_testing_metric, display_commands_v2#, plot_testing_average_metric
from metrics.metrics import time_in_range
from metrics.statistics import calc_stats
from agents.models.actor_critic_td3_bc import ActorCritic, QNetwork, PolicyNetwork
import numpy.polynomial.polynomial as poly 
from utils.worker import OnPolicyWorker, OffPolicyWorker, OfflineSampler
from omegaconf import DictConfig, OmegaConf
import matplotlib.gridspec
from matplotlib import pyplot as plt
import math
import torch
from utils.logger import setup_folders, Logger


# check_end_condition_total picks the checkpoint just before the first one whose
# |policy_grad| falls below this value.
PI_GRAD_THRESHOLD = 140
# NOTE(review): the three PI_CONCAV_* constants are not referenced by any code
# visible in this notebook — confirm whether they are still needed.
PI_CONCAV_THRESHOLD = 0.00008
PI_CONCAV_CONVERGENCE_THRESHOLD = 0.0001
PI_CONCAV_CONVERGENCE_N = 3
# Tolerance for the debug-only "policy gradient is roughly decreasing" check in
# check_end_condition_total (finite differences must all be below this).
# The misspelling "EPISLON" is kept because other code references this name.
MONOTONIC_EPISLON = 1e-5

# Column headers for the per-run CSV rows produced by print_run_row.
PER_RUN_HEADERS = ["name", "cohort", "subject", "seed"] + ["rew", "tir", "adj_tir", "tbr_1", "tbr_2", "tar_1", "tar_2", 't', "success_bool"]
# Column headers (after the name/cohort prefix) for get_means_str(..., stats=True).
HEADERS = ["rew_mn", "rew_sd", "tir_mn","tir_sd", "adj_tir_mn", "adj_tir_sd", "tbr_1_mn", "tbr_1_sd", "tbr_2_mn", "tbr_2_sd", "tar_1_mn", "tar_1_sd", "tar_2_mn", "tar_2_sd", "success"]


def create_file_paths(path, seeds, filename, FILES):
    """Append one ``path + filename + <seed> + '.csv'`` entry per seed to FILES.

    FILES is extended in place (callers use this to accumulate paths over several
    experiments) and the same list object is returned for convenience.
    """
    FILES.extend(path + filename + str(seed_value) + '.csv' for seed_value in seeds)
    return FILES
    
def generate_reward_list(exp_name, metric='reward'):
    """Return the per-row mean of ``metric`` across all seeds of one experiment.

    The per-seed testing CSVs of ``exp_name`` (located via
    ExperimentVisualise.get_file_paths) are loaded, every seed is truncated to the
    shortest one so rows line up, and the values are averaged across seeds.

    Args:
        exp_name: experiment id passed to ExperimentVisualise.
        metric: CSV column to average (default 'reward', the previously
            hard-coded column).

    Returns:
        list with one mean value per row shared by all seeds.

    Raises:
        ValueError: if the experiment has no seed files.
    """
    exp = ExperimentVisualise(id=exp_name, version=1.1, plot_version=1, test_seeds=5000)
    path, seeds, filename = exp.get_file_paths()
    files = create_file_paths(path, seeds, filename, [])
    if not files:
        raise ValueError(f"No result files found for experiment {exp_name!r}")

    columns = [pd.read_csv(file)[metric] for file in files]
    shortest = min(len(column) for column in columns)
    # .iloc makes the truncation explicitly positional, independent of the index labels.
    data = pd.concat([column.iloc[:shortest] for column in columns], axis=1)
    return list(data.mean(axis=1))
    
def generate_end_index(exp_name, return_proportion=False, debug_show=False, use_alt=False):
    """Select an early-stopping checkpoint index for ``exp_name`` from its training logs.

    The steps, value_grad, val_loss and policy_grad series from
    ExperimentVisualise.get_training_logs are handed to
    check_end_condition_total_alt when ``use_alt`` is true, otherwise to
    check_end_condition_total; ``debug_show`` is forwarded unchanged.

    Args:
        exp_name: experiment directory name under ``<MAIN_PATH>/results``.
        return_proportion: if true, return the index divided by
            (number of rows in experiment_summary.csv - 1).
            NOTE(review): assumes those rows line up with the logged checkpoints;
            a single-row summary divides by zero.
        debug_show: print diagnostics from the end-condition function.
        use_alt: use the "checkpoint after the largest policy gradient" rule
            instead of the gradient-threshold rule.

    Returns:
        The chosen index, or that index as a fraction of the summary length.
    """
    summary = pd.read_csv(MAIN_PATH + '/' + 'results/' + exp_name + '/experiment_summary.csv', header="infer")
    summary_rows = len(summary["value_grad"])

    logs = ExperimentVisualise(id=exp_name, version=1.1, plot_version=1, test_seeds=5000).get_training_logs()
    series = {name: list(logs[name]) for name in ('steps', 'value_grad', 'val_loss', 'policy_grad')}

    end_condition = check_end_condition_total_alt if use_alt else check_end_condition_total
    chosen = end_condition(
        series['value_grad'], series['val_loss'], series['policy_grad'], series['steps'],
        debug_show=debug_show,
    )
    if not return_proportion:
        return chosen
    return chosen / (summary_rows - 1)

def find_best_tir_index(exp_name, bound, workers=5): #bound exclusive
    """Return the 0-based episode index in ``[0, bound)`` with the highest mean adj_normo.

    Index ``k`` corresponds to rows with ``epi == k + 1`` in the worker testing
    CSVs read by take_rows (default offset). Ties resolve to the lowest index.

    Every worker CSV is read once and then filtered per episode. The previous
    version called take_rows once per episode, re-reading all ``workers`` files
    ``bound`` times. Row order within each episode is unchanged, so the means are
    identical.

    Raises:
        ValueError: if ``bound`` <= 0 (same exception type as before).
    """
    if bound <= 0:
        raise ValueError("bound must be a positive number of episodes")

    all_rows = take_rows(exp_name, workers=workers, episode=None)
    scores = [
        get_mn_sd(all_rows.loc[all_rows["epi"] == k + 1, "adj_normo"])[0]
        for k in range(bound)
    ]
    return max(range(bound), key=scores.__getitem__)

# modified graphing functions

def plot_testing_average_metric(dict, groups, type, dis_len, metric, goal, fill,vline=None, title=None):
    """Plot the seed-averaged testing ``metric`` curve for each group of experiments.

    Args:
        dict: mapping key -> {'id': ExperimentVisualise, 'color': ..., 'label': ...}.
            (The name shadows the builtin; kept for caller compatibility.)
        groups: list of lists of keys into ``dict``. The seed files of each inner
            list are pooled into one curve, drawn with the colour, label and
            step scale of the group's last member.
        type: 'normal' shades mean ± 1 std across seeds; any other value shades
            the min..max envelope.
        dis_len: x-axis upper limit (steps).
        metric: CSV column to plot.
        goal: y value of the dashed horizontal reference line.
        fill: whether to draw the shaded band.
        vline: optional fraction of the last curve's step range at which to draw
            a vertical marker; no marker when None (previously None crashed).
        title: optional plot title.
    """
    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111)
    steps = None

    for group in groups:
        files = []
        for key in group:
            path, seeds, filename = dict[key]['id'].get_file_paths()
            files = create_file_paths(path, seeds, filename, files)

        columns = [pd.read_csv(file)[metric] for file in files]
        shortest = min(len(column) for column in columns)
        seed_data = pd.concat([column.iloc[:shortest] for column in columns], axis=1)

        # Statistics are taken over the seed columns only. Previously 'mean' was
        # added as a column first, so it was included in the std calculation.
        mean = seed_data.mean(axis=1)
        if type == 'normal':
            spread = seed_data.std(axis=1)
            upper, lower = mean + spread, mean - spread
        else:
            upper, lower = seed_data.max(axis=1), seed_data.min(axis=1)

        exp = dict[group[-1]]
        steps = (np.arange(len(mean)) + 1) * exp['id'].training_workers * exp['id'].args['agent']['n_step']

        ax.plot(steps, mean, color=exp['color'], label=exp['label'])
        if fill:
            ax.fill_between(steps, lower, upper, color=exp['color'], alpha=0.1)

    ax.axhline(y=goal, color='k', linestyle='--')

    if vline is not None and steps is not None and len(steps) > 0:
        start_step, end_step = steps[0], steps[-1]
        ax.axvline(x=start_step + vline * (end_step - start_step), color='b', linestyle='-')

    graph_title = title if title is not None else 'Average Rewards (Multiple Seeds)'
    ax.set_title(graph_title, fontsize=32)
    # ax.legend(loc="upper left", fontsize=16)
    ax.set_ylabel('Total Reward', fontsize=24)
    ax.set_xlabel('Steps', fontsize=24)
    ax.grid()
    ax.set_xlim(0, dis_len)
    ax.set_ylim(0, 320)
    plt.show()

def plot_training_logs(mode, exp_dict, dis_len, params,vline=None, cols=1,val_grad_poly=None):
    """Plot one subplot per training-log column for every shown experiment.

    Args:
        mode: which log table to load — 'ppo' (get_training_logs), 'aux'
            (get_aux_training_logs) or 'planning' (get_planning_training_logs).
        exp_dict: mapping key -> {'id': ExperimentVisualise, 'color': ..., 'show': bool}.
        dis_len: x-axis upper limit for every subplot.
        params: log column names; one subplot each, laid out in ``cols`` columns.
        vline: optional fraction of each experiment's step range at which to draw
            a vertical marker; no marker when None (previously None crashed).
        cols: number of subplot columns.
        val_grad_poly: optional callable (e.g. a numpy Polynomial) overlaid on the
            'value_grad' subplot across the experiment's step range.

    Raises:
        ValueError: for an unknown ``mode`` (previously printed and called exit()).
    """
    loaders = {
        # ['exp_var', 'true_var','val_loss', 'policy_grad', 'value_grad', 'pi_loss', 'avg_rew']
        'ppo': lambda vis: vis.get_training_logs(),
        # ['pi_aux_loss', 'vf_aux_loss', 'pi_aux_grad', 'vf_aux_grad']
        'aux': lambda vis: vis.get_aux_training_logs(),
        # ['plan_grad', 'plan_loss']
        'planning': lambda vis: vis.get_planning_training_logs(),
    }
    if mode not in loaders:
        raise ValueError(f"Invalid mode selected: {mode!r}")

    fig = plt.figure(figsize=(16, 26))
    tot_plots = len(params)
    n_rows = math.ceil(tot_plots / cols)
    subplots = []
    for i, param in enumerate(params):
        ax = fig.add_subplot(n_rows, cols, i + 1)
        ax.grid(True)
        ax.set_xlim(0, dis_len)
        ax.set_title(param)
        subplots.append(ax)

    for exp in exp_dict.values():
        if not exp['show']:
            continue
        d1 = loaders[mode](exp['id'])
        # Loop-invariant per experiment: hoisted out of the per-subplot loop.
        step_values = list(d1['steps'])
        start_step, end_step = step_values[0], step_values[-1]

        for ax, param in zip(subplots, params):
            if param == 'value_grad' and val_grad_poly is not None:
                polyline = np.linspace(start_step, end_step)
                ax.plot(polyline, val_grad_poly(polyline))
            ax.plot(d1['steps'], d1[param], color=exp['color'], label=exp['id'].id)
            if vline is not None:
                ax.axvline(x=start_step + vline * (end_step - start_step), color='b', linestyle='-')

    for ax in subplots:
        ax.legend(loc="upper right")
    plt.show()

def calculate_r_squared(y_data, RSS):
    """Return the coefficient of determination ``R^2 = 1 - RSS / TSS``.

    Args:
        y_data: observed values (1-D sequence of numbers).
        RSS: residual sum of squares of the fitted model.

    Returns:
        R^2. ``nan`` when ``y_data`` is empty or has zero variance (TSS == 0),
        where R^2 is undefined. Previously these cases produced -inf/nan via a
        divide-by-zero warning, or raised ZeroDivisionError.
    """
    y = np.asarray(y_data, dtype=float)
    if y.size == 0:
        return float('nan')
    tss = float(np.sum((y - y.mean()) ** 2))
    if tss == 0.0:
        return float('nan')
    return 1 - RSS / tss

def generate_poly(steps, val_grad, deg=2):
    """Least-squares fit a polynomial of degree ``deg`` to ``val_grad`` versus ``steps``.

    Args:
        steps: x values.
        val_grad: y values.
        deg: polynomial degree (default 2, the previously hard-coded quadratic).

    Returns:
        ``(r_squared, fitted)``. ``fitted`` is the numpy Polynomial returned by
        ``Polynomial.fit`` (callable on step values). ``r_squared`` comes from
        calculate_r_squared.
    """
    fitted, (resid, _rank, _sv, _rcond) = poly.Polynomial.fit(steps, val_grad, deg, full=True)
    if len(resid):
        rss = resid[0]
    else:
        # lstsq returns an empty residual array when there are <= deg + 1 points or
        # the system is rank-deficient. Indexing it raised IndexError, so compute
        # the residual sum of squares directly instead.
        x = np.asarray(steps, dtype=float)
        y = np.asarray(val_grad, dtype=float)
        rss = float(np.sum((y - fitted(x)) ** 2))
    return calculate_r_squared(val_grad, rss), fitted

def check_end_condition_total(
        val_grads, val_losses, pi_grads, steps,
        debug_show=False
    ):
    """Choose an early-stopping checkpoint from the policy-gradient trace.

    The result is the index just before the first entry of ``pi_grads`` whose
    absolute value is below PI_GRAD_THRESHOLD. If there is no such entry, the
    result is the last checkpoint. It is clamped to ``[0, len(val_grads) - 1]``
    (so -1 when ``val_grads`` is empty).

    ``val_grads`` only supplies the number of checkpoints; ``val_losses`` is unused.
    With ``debug_show`` the raw gradients, their finite-difference slopes over
    ``steps`` and a warning (when the slopes are not all below MONOTONIC_EPISLON)
    are printed.
    """
    n_checkpoints = len(val_grads)

    # Finite-difference slope of the policy gradient between consecutive checkpoints.
    slopes = [
        (pi_grads[k + 1] - pi_grads[k]) / (steps[k + 1] - steps[k])
        for k in range(n_checkpoints - 1)
    ]

    if debug_show:
        print(pi_grads)
        print(slopes)

    is_roughly_decreasing = all([slope < MONOTONIC_EPISLON for slope in slopes])
    if debug_show and not is_roughly_decreasing:
        print("WARNING! Not monotonic")

    first_below = next(
        (k for k, grad in enumerate(pi_grads) if abs(grad) < PI_GRAD_THRESHOLD),
        None,
    )
    # "Past the end" by default; the clamp turns that into the last checkpoint.
    chosen = n_checkpoints if first_below is None else first_below - 1
    return min(max(chosen, 0), n_checkpoints - 1)

def check_end_condition_total_alt(
        val_grads, val_losses, pi_grads, steps,
        debug_show=False
    ):
    """Choose the checkpoint right after the one with the largest policy gradient.

    Only ``pi_grads[0 .. len(val_grads) - 2]`` are considered, so the following
    checkpoint always exists. The argmax (first one on ties) plus one is returned,
    which always lies in ``[1, len(val_grads) - 1]``.

    ``val_grads`` only supplies the number of checkpoints; ``val_losses`` is unused.
    With fewer than two checkpoints there is nothing to compare, so index 0 is
    returned. Previously ``max()`` over an empty range raised ValueError.

    With ``debug_show`` the raw gradients and their finite-difference slopes over
    ``steps`` are printed. The slopes are now computed only in that case, so
    repeated step values no longer cause a ZeroDivisionError otherwise.
    """
    epochs = len(val_grads)

    if debug_show:
        pi_grad_grads = [
            (pi_grads[n + 1] - pi_grads[n]) / (steps[n + 1] - steps[n])
            for n in range(epochs - 1)
        ]
        print(pi_grads)
        print(pi_grad_grads)

    if epochs < 2:
        return max(epochs - 1, 0)

    max_grad_ind = max(range(epochs - 1), key=lambda ind: pi_grads[ind])
    return max_grad_ind + 1

# def generate_stats(exp_dir, n_val_trials=5, val_offset=5000, )

def take_rows(exp_name, offset_ind=5000, workers=5, episode=0, full_length_t=288):
    """Load and concatenate the worker testing summaries of one experiment.

    Reads ``<MAIN_PATH>/results/<exp_name>/testing/worker_episode_summary_<offset_ind + n>.csv``
    for ``n in range(workers)`` and adds two derived columns to each file:

    * ``full_length``: 1 if ``t == full_length_t``, else 0.
    * ``adj_normo``: ``normo * t / full_length_t``, i.e. ``normo`` weighted by the
      fraction of the full episode length that was reached.

    Args:
        exp_name: experiment directory name under ``<MAIN_PATH>/results``.
        offset_ind: first file index (e.g. 5000/6000/8000/9000 in this notebook).
        workers: number of consecutive worker files to read.
        episode: keep only rows whose ``epi`` equals this value; None keeps all rows.
        full_length_t: episode length counted as complete. The default of 288 is
            the previously hard-coded value (presumably 24 h of 5-minute steps —
            confirm).

    Returns:
        One DataFrame with a fresh RangeIndex.
    """
    exp_folder = MAIN_PATH + '/results/' + exp_name + '/'

    frames = []
    for n in range(workers):
        df = pd.read_csv(exp_folder + f'testing/worker_episode_summary_{offset_ind + n}.csv')
        df["full_length"] = (df['t'] == full_length_t).astype(int)
        df["adj_normo"] = (df["t"] / full_length_t) * df["normo"]
        if episode is not None:
            df = df[df['epi'] == episode]
        frames.append(df)
    return pd.concat(frames, ignore_index=True)

def print_mn_sd(mn_sd, rounding=1):
    """Format a ``(mean, sd)`` pair as ``"mean±sd"``, each rounded to ``rounding`` places.

    Despite its name, this returns the string rather than printing it.
    """
    mean, spread = mn_sd
    return "{}±{}".format(round(mean, rounding), round(spread, rounding))

def get_mn_sd(li):
    """Return ``(mean, standard deviation)`` of ``li`` via np.mean / np.std (ddof=0)."""
    average = np.mean(li)
    spread = np.std(li)
    return average, spread

def get_prop(li):
    """Return ``sum(li) / len(li)``, e.g. the fraction of 1/True entries.

    Raises ZeroDivisionError for an empty sequence.
    """
    total = sum(li)
    return total / len(li)

def get_means_str(df, stats=True):
    """Summarise a results DataFrame as one output row.

    Each of reward, normo, adj_normo, hypo, sev_hypo, hyper and sev_hyper is
    reduced to ``"mean±sd"`` (via get_mn_sd and print_mn_sd, 1 decimal). The
    success value is get_prop over ``full_length``.

    Args:
        df: DataFrame containing the columns above.
        stats: True -> comma-separated values (mean and sd as separate fields for
            each metric, then the success proportion), in the order of HEADERS.
            False -> LaTeX table-row fragment (adj_normo omitted, success as a
            percentage rounded to 2 decimals).
    """
    metric_columns = ("reward", "normo", "adj_normo", "hypo", "sev_hypo", "hyper", "sev_hyper")
    cells = {column: print_mn_sd(get_mn_sd(list(df[column]))) for column in metric_columns}
    success = get_prop(list(df["full_length"]))

    if not stats:
        return (
            f" & {cells['reward']} & {cells['normo']} & {cells['hypo']} & {cells['sev_hypo']}"
            f" & {cells['hyper']} & {cells['sev_hyper']} & {round(100*success,2)}\\% \\\\ \\hline"
        )

    fields = [piece for column in metric_columns for piece in cells[column].split('±')]
    fields.append(success)
    return ','.join(str(field) for field in fields)

def print_run_row(df_row, nickname):
    """Format one run's results as a CSV line aligned with PER_RUN_HEADERS.

    Fields, in order:
    - nickname;
    - cohort ("adlscnt" if patient_id < 10, else "adult");
    - patient_id, seed, reward, normo, adj_normo;
    - hypo, sev_hypo, hyper, sev_hyper;
    - t, full_length.

    adj_normo was previously missing. That shifted every value after "tir" one
    column left of its header, and the row was one field short.

    Args:
        df_row: mapping/Series with the keys above (e.g. a row from iterrows()).
        nickname: display name of the run.
    """
    value_keys = (
        "patient_id", "seed",
        "reward", "normo", "adj_normo",
        "hypo", "sev_hypo", "hyper", "sev_hyper",
        "t", "full_length",
    )
    cohort = "adlscnt" if df_row["patient_id"] < 10 else "adult"
    fields = [nickname, cohort] + [df_row[key] for key in value_keys]
    return ','.join(str(field) for field in fields)

# Argument for display_commands_v2 (from visualiser.core); empty list here.
command = []
# Step count used as an x-axis limit (the plotting helpers take a `dis_len`
# parameter for set_xlim). NOTE(review): not passed anywhere in the visible cells.
dis_len = 3.0 * 1000000
# dis_len = 2 * 100000

# NOTE(review): disp_arr is not used by the visible cells — confirm it is still needed.
disp_arr = display_commands_v2(command)



In [None]:
### GET RESULTS FROM VALIDATION TRIALS

""" 
Testing-summary file offsets (the offset_ind argument of take_rows):
5000 - Validation during training
6000 - validation of the last network
8000 - early stopping
9000 - picking best adj normoglycemia (TIR)
"""
CSV_MODE = False
IQL_EARLY_STOPPING = True
SHOW_PER_RUN = False
SHOW_PER_SUBJ = False
REWARD_ONLY = False


SEP_ST = ',' if CSV_MODE else ' '
# LaTeX rows need "& " before the cohort/subject cell. In CSV mode the separator
# already delimits it; previously "& adult" leaked into the CSV cohort column and
# LaTeX per-subject rows had no "&" at all.
COHORT_PREFIX = '' if CSV_MODE else '& '

if REWARD_ONLY:
    GET_STR_FUNC = lambda t_df : "Reward: " + print_mn_sd(get_mn_sd(list(t_df["reward"]))) if not CSV_MODE else ','.join([str(i) for i in get_mn_sd(list(t_df["reward"]))])
else:
    GET_STR_FUNC = lambda t_df :  get_means_str(t_df, CSV_MODE)

NICKNAMES = {
    "offline_td3_hp5" : "TD3+BC",
    "offline_iql_t3"  : "IQL"
}



SUBJ_NAMES = ["adolescent" + str(n) for n in range(10)] + ['No Name']*10 + ["adult" + str(n) for n in range(10)]

if CSV_MODE: 
    if SHOW_PER_RUN:
        print(','.join(PER_RUN_HEADERS))
    else:
        if SHOW_PER_SUBJ: prefix = ("name", "subject")
        else: prefix = ("name", "cohort")

        if REWARD_ONLY: print(*prefix, "reward_mn","reward_sd", sep=',')
        else: 
            print(','.join(list(prefix)+ HEADERS))
else:
    print("name & cohort & reward & TIR & TBR1 & TBR2 & TAR1 & TAR2 & Success \\\\ \\hline")

for RUN_NAME in [
    # "offline_td3_hp5",
    # "offline_td3_clnonly",
    # "offline_td3_clnandevl",
    # "offline_td3_trnonly",

    # "offline_td3_nobc",
    # "offline_td3_purebc",

    "offline_iql_t3" ,
    "offline_iql_clnandevl",

    # "offline_td3_ope",
    # "offline_iql_ope",
    # "offline_td3_ope_purebc"

    ]:
    nickname = NICKNAMES.get(RUN_NAME, RUN_NAME)

    rows_total = []
    for patient_id in list(range(0,10)) + list(range(20,30)):
        patient_rows = []
        for seed in range(3):
            exp_name = f"{RUN_NAME}_{patient_id}_{seed}"

            # reward_list / chosen_ind are only consumed by the commented-out
            # selection strategies below; the active strategy ignores them.
            reward_list = generate_reward_list(exp_name)

            #early stopping from validation
            # rows = take_rows(exp_name, 8000, 50, None)

            #best reward from validation
            # rows = take_rows(exp_name, 9000, 50, None)

            #early stopping from training
            # chosen_ind = generate_end_index(exp_name, use_alt=IQL_EARLY_STOPPING)
            # rows = take_rows(exp_name, 5000, 4, chosen_ind+1)

            #take last net
            # chosen_ind = len(reward_list)-1
            # rows = take_rows(exp_name, 5000, 50, chosen_ind+1)

            # take last net from validation
            chosen_ind = len(reward_list)-1
            rows = take_rows(exp_name, 6000, 5, 1)

            #take best reward
            # chosen_ind = find_best_tir_index(exp_name, bound=len(reward_list), workers=4) 
            # chosen_ind = max(list(range(len(reward_list))), key = lambda ind : reward_list[ind])
            # rows = take_rows(exp_name, 5000, 4, chosen_ind+1)

            rows["seed"] = seed
            rows["patient_id"] = patient_id

            patient_rows.append(rows)

        #add all rows
        rows_total += patient_rows

        #add best row only
        # best_row = max(patient_rows, key = lambda row : get_mn_sd(row["adj_normo"])[0])
        # rows_total.append(best_row)

    total_df = pd.concat(rows_total, ignore_index=True)
    if SHOW_PER_RUN:
        for _, row in total_df.iterrows():
            print(print_run_row(row, nickname))
    elif SHOW_PER_SUBJ:
        for subj in list(range(20,30)) + list(range(0,10)):
            subj_df = total_df[total_df["patient_id"] == subj]
            print(nickname, COHORT_PREFIX + SUBJ_NAMES[subj], GET_STR_FUNC(subj_df), sep=SEP_ST)
    else:
        adult_df = total_df[total_df["patient_id"] > 19]
        print(nickname, COHORT_PREFIX + "adult", GET_STR_FUNC(adult_df), sep=SEP_ST)

        adolescent_df = total_df[total_df["patient_id"] < 10]
        print(nickname, COHORT_PREFIX + "adlscnt", GET_STR_FUNC(adolescent_df), sep=SEP_ST)

    # total_df.to_csv(f"raw_runs_vld_early.csv", index=False)

name & cohort & reward & TIR & TBR1 & TBR2 & TAR1 & TAR2 & Success \\ \hline
IQL & adult  & 212.4±33.5 & 31.6±7.2 & 0.0±0.0 & 0.0±0.0 & 22.3±8.7 & 46.2±13.8 & 100.0\% \\ \hline
IQL & adlscnt  & 205.8±46.9 & 49.5±11.6 & 0.6±3.7 & 0.3±1.8 & 15.7±10.9 & 33.9±18.1 & 71.33\% \\ \hline
offline_iql_clnandevl & adult  & 178.0±53.1 & 29.8±10.0 & 0.0±0.0 & 0.0±0.0 & 15.3±7.8 & 54.9±13.4 & 72.0\% \\ \hline
offline_iql_clnandevl & adlscnt  & 184.2±43.9 & 41.8±13.9 & 0.4±2.3 & 0.4±2.2 & 15.9±10.9 & 41.6±16.4 & 61.33\% \\ \hline


In [None]:
### GET RESULTS FROM OFFLINE POLICY EVALUATION
# Emits one CSV line per (run, patient, seed): run name, patient id and seed,
# followed by the second line of results/<run>_<patient>_<seed>/ope_summary.csv
# with ';' replaced by ','.

OPE_COLUMNS = (
    "name,patient_id,seed,"
    "critic_loss_mean,critic_loss_sd,critic_loss_n,"
    "critic_eval_mean,critic_eval_sd,critic_eval_n,"
    "action_diff_mean,action_diff_sd,action_diff_n,"
    "action_diff_ins_mean,action_diff_ins_sd,action_diff_ins_n"
)
OPE_PATIENT_IDS = list(range(0,10)) + list(range(20,30))

print(OPE_COLUMNS)
for RUN_NAME in [
    # "offline_td3_hp5",
    # "offline_td3_clnonly",
    # "offline_td3_clnandevl",
    # "offline_td3_trnonly",

    # "offline_td3_nobc",
    # "offline_td3_purebc",

    # "offline_iql_t3" ,
    # "offline_iql_clnandevl",

    "offline_td3_ope",
    "offline_iql_ope",
    "offline_td3_ope_purebc"

    ]:
    for patient_id in OPE_PATIENT_IDS:
        for seed in range(3):
            exp_name = f"{RUN_NAME}_{patient_id}_{seed}"
            ope_path = MAIN_PATH + '/results/' + exp_name + '/ope_summary.csv'
            with open(ope_path, 'r') as f:
                summary_lines = f.read().splitlines()
            summary_row = summary_lines[1].replace(';', ',')
            print(RUN_NAME, patient_id, seed, summary_row, sep=',')



name,patient_id,seed,critic_loss_mean,critic_loss_sd,critic_loss_n,critic_eval_mean,critic_eval_sd,critic_eval_n,action_diff_mean,action_diff_sd,action_diff_n,action_diff_ins_mean,action_diff_ins_sd,action_diff_ins_n
offline_td3_ope,0,0,0.16345657,3.2794454,1024,64.78299,24.728788,1024,0.02956325,0.041293684,1024,0.0003644476071083862,0.001630906456688033,1024
offline_td3_ope,0,1,0.3202726,3.9404101,1024,65.15315,24.590292,1024,0.022304222,0.030507877,1024,0.0002925705478468318,0.0009212276684817768,1024
offline_td3_ope,0,2,0.16469747,3.6956928,1024,66.259674,23.277203,1024,0.030081172,0.04299115,1024,0.00040759255735961985,0.0016020982973573518,1024
offline_td3_ope,1,0,0.29644048,3.456633,1024,50.81781,21.525686,1024,0.03726221,0.05122145,1024,0.0010297716261845466,0.004611724024166396,1024
offline_td3_ope,1,1,0.34887636,3.7611873,1024,50.832905,21.831858,1024,0.04212531,0.05961373,1024,0.0012704214504780033,0.008641182851947838,1024
offline_td3_ope,1,2,0.47489777,4.87359,1024,50.9777