# Find best survival augmentation parameters

This notebooks does experiments to find the best augmentation parameters for the survival models and evaluates them.

In [2]:
import os

import pandas as pd
import mlflow
from dotenv import load_dotenv
import plotly.express as px

from cross_validate_survival import cross_validate_survival_model

In [3]:
load_dotenv()
mlflow.set_tracking_uri(os.environ.get("MLFLOW_TRACKING_URL"))

In [4]:
df_raw_merged = pd.read_parquet("./data/my_datasets/raw_merged.parquet")
all_device_uuids = df_raw_merged["device_uuid"].unique()

In [5]:
def load_exp_results(exp_name):
    run_path = "/home/nkuechen/Documents/Thesis"
    os.chdir(run_path)
    print(os.getcwd())

    surv_id = mlflow.get_experiment_by_name(exp_name)
    if surv_id is None:
        print(f"Could not find experiment {exp_name}")
        return None
    else:
        output_csv = f"code/thesis_code/data/aug_effect/{exp_name}.csv"
        command = f"mlflow experiments csv -o {output_csv} -x {surv_id.experiment_id}"
        os.system(command)
        return pd.read_csv(output_csv)

## Einfluss von Augmentationsparametern auf Metriken

In [6]:
surv_experiments = {}
for param, param_bool, param_values in [
    ("max_noise", "add_noise", [0, 10, 20]),
    ("max_noise_temperature", "add_noise_temperature", [0, 5]),
    ("random_max_time_warp_percent", "random_warp_status_times", [0, 0.5, 1]),
]:
    EXP_NAME = f"surv_{param}"
    surv_experiments[EXP_NAME] = (param, param_bool, param_values)
    if mlflow.get_experiment_by_name(EXP_NAME) is None:
        for n_dev in [63, 63, 63, 63, 63]:
            for model_class in ["CoxPHFitter", "RandomSurvivalForest"]:
                for param_value in param_values:
                    cross_validate_survival_model(
                        raw_merged_df=df_raw_merged,
                        model_class=model_class,
                        n_dev=n_dev,
                        n_aug=1,
                        train_df_params={
                            param_bool: True,
                            param: param_value,
                        },
                        all_device_uuids=all_device_uuids,
                        by_metric="metrics.c_index_ipcw",
                        mlflow_experiment=EXP_NAME,
                    )
    else:
        print(f"Skipping {EXP_NAME} because the experiment already exists.")

Skipping surv_max_noise because the experiment already exists.
Skipping surv_max_noise_temperature because the experiment already exists.
Skipping surv_random_max_time_warp_percent because the experiment already exists.


In [17]:
metric_to_name = {"metrics.ibs": "IBS", "metrics.c_index_ipcw": "C-Index IPCW"}
param_to_name = {
    "max_noise": "<i>max_jittering_battery_level</i>",
    "max_noise_temperature": "<i>max_jittering_air_temperature</i>",
    "random_max_time_warp_percent": "<i>max_jittering_measurement_interval</i>",
}

In [19]:
for surv_exp, (param, param_bool, param_values) in surv_experiments.items():
    for metric in ["metrics.c_index_ipcw", "metrics.ibs"]:
        surv_exp_results = load_exp_results(surv_exp)
        if surv_exp_results is not None:
            try:
                fig = px.box(
                    surv_exp_results,
                    x=f"params.{param}",
                    y=metric,
                    color="params.model_class",
                    title=f"Einfluss von {param_to_name[param]} auf den {metric_to_name[metric]}.",
                    width=600,
                    height=600,
                )
                
                fig.update_layout(legend_title="Modelltyp")
                fig.update_xaxes(title=param_to_name[param])
                fig.update_yaxes(title=metric_to_name[metric])
                
                fig.show()

                medians = {}
                for model_class, model_group in surv_exp_results.groupby(
                    by="params.model_class"
                ):
                    medians[model_class] = {}
                    for param_value, param_group in model_group.groupby(
                        by=f"params.{param}"
                    ):
                        medians[model_class][param_value] = param_group[metric].median()
                print(medians)
                medians_df = pd.DataFrame(medians).T
                pd.options.display.float_format = "{:.3f}".format
                print(medians_df)

            except ValueError:
                print(f"{surv_exp} does not seem to be finished.")
                raise

/home/nkuechen/Documents/Thesis
Experiment with ID 722659421398569041 has been exported as a CSV to file: code/thesis_code/data/aug_effect/surv_max_noise.csv.


{'CoxPHFitter': {0: 0.911025265517968, 10: 0.883786415912967, 20: 0.8934030329832127}, 'RandomSurvivalForest': {0: 0.9170407137330226, 10: 0.8797832793305128, 20: 0.9032985828954524}}
                        0     10    20
CoxPHFitter          0.911 0.884 0.893
RandomSurvivalForest 0.917 0.880 0.903
/home/nkuechen/Documents/Thesis
Experiment with ID 722659421398569041 has been exported as a CSV to file: code/thesis_code/data/aug_effect/surv_max_noise.csv.


{'CoxPHFitter': {0: 0.1414780421110422, 10: 0.0812717210801253, 20: 0.0748901789574013}, 'RandomSurvivalForest': {0: 0.1071394402156528, 10: 0.0783869576359958, 20: 0.0745946270946012}}
                        0     10    20
CoxPHFitter          0.141 0.081 0.075
RandomSurvivalForest 0.107 0.078 0.075
/home/nkuechen/Documents/Thesis
Experiment with ID 374792384496278134 has been exported as a CSV to file: code/thesis_code/data/aug_effect/surv_max_noise_temperature.csv.


{'CoxPHFitter': {0: 0.8876060479597238, 5: 0.9251955181290944}, 'RandomSurvivalForest': {0: 0.9031954637172714, 5: 0.8672459228065669}}
                         0     5
CoxPHFitter          0.888 0.925
RandomSurvivalForest 0.903 0.867
/home/nkuechen/Documents/Thesis
Experiment with ID 374792384496278134 has been exported as a CSV to file: code/thesis_code/data/aug_effect/surv_max_noise_temperature.csv.


{'CoxPHFitter': {0: 0.1228887989387457, 5: 0.1043801621520255}, 'RandomSurvivalForest': {0: 0.1076418242635192, 5: 0.1175934237139336}}
                         0     5
CoxPHFitter          0.123 0.104
RandomSurvivalForest 0.108 0.118
/home/nkuechen/Documents/Thesis
Experiment with ID 498940377533771453 has been exported as a CSV to file: code/thesis_code/data/aug_effect/surv_random_max_time_warp_percent.csv.


{'CoxPHFitter': {0.0: 0.9090590694321012, 0.5: 0.8982415532199213, 1.0: 0.8935735329404716}, 'RandomSurvivalForest': {0.0: 0.9138342006202284, 0.5: 0.9137510007536248, 1.0: 0.902268032304078}}
                      0.000  0.500  1.000
CoxPHFitter           0.909  0.898  0.894
RandomSurvivalForest  0.914  0.914  0.902
/home/nkuechen/Documents/Thesis
Experiment with ID 498940377533771453 has been exported as a CSV to file: code/thesis_code/data/aug_effect/surv_random_max_time_warp_percent.csv.


{'CoxPHFitter': {0.0: 0.119906446784803, 0.5: 0.0848494059540229, 1.0: 0.0781044852837499}, 'RandomSurvivalForest': {0.0: 0.1161328141415564, 0.5: 0.0824170532572442, 1.0: 0.0817900806051935}}
                      0.000  0.500  1.000
CoxPHFitter           0.120  0.085  0.078
RandomSurvivalForest  0.116  0.082  0.082


In [8]:
best_params = {
    "CoxPHFitter": {
        "max_noise": 20,
        "add_noise": True,
        "max_noise_temperature": 5,
        "add_noise_temperature": True,
        "random_max_time_warp_percent": 1.0,
        "random_warp_status_times": True,
    },
    "RandomSurvivalForest": {
        "max_noise": 20,
        "add_noise": True,
        "max_noise_temperature": 0,
        "add_noise_temperature": True,
        "random_max_time_warp_percent": 1.0,
        "random_warp_status_times": True,
    },
}

# for n_dev in [10, 20, 40]:
#     EXP_NAME = "surv_best_n_aug"
#     for i in range(5):
#         for model_class, train_df_params in best_params.items():
#             for n_aug in [1, 3, 5, 10]:
#                 cross_validate_survival_model(
#                     raw_merged_df=df_raw_merged,
#                     model_class=model_class,
#                     n_dev=n_dev,
#                     n_aug=n_aug,
#                     train_df_params=train_df_params,
#                     all_device_uuids=all_device_uuids,
#                     by_metric="metrics.ibs",
#                     mlflow_experiment=EXP_NAME,
#                 )

/home/nkuechen/Documents/Thesis/mlruns
!! Creating new split...
###Split 1/4!
c_index_ipcw=0.8514283994446536; ibs=0.12272213244863928
###Split 2/4!
c_index_ipcw=0.9284215111674322; ibs=0.058138277931657566
###Split 3/4!
c_index_ipcw=0.7015953055089669; ibs=0.220236186925914
###Split 4/4!
c_index_ipcw=0.8345338770871824; ibs=0.09967618710817203
/home/nkuechen/Documents/Thesis/mlruns
!! Creating new split...
###Split 1/4!
c_index_ipcw=0.8345679265816227; ibs=0.17938118366641606
###Split 2/4!


KeyboardInterrupt: 

In [None]:
metrics = ["metrics.ibs", "metrics.c_index_ipcw"]
surv_exp_results = load_exp_results(EXP_NAME)
for metric in metrics:
    if surv_exp_results is not None:
        try:
            for n_dev, n_dev_group in surv_exp_results.groupby(by="params.n_dev"):
                print(n_dev)
                fig = px.box(
                    n_dev_group,
                    x=f"params.n_aug",
                    y=metric,
                    color="params.model_class",
                    title=f"Einfluss des Augmentationsanteils auf den {metric} ({n_dev} Trainingsgeräte).",
                    width=1000,
                    height=600,
                )
                fig.show()

                medians = {}
                for model_class, model_group in n_dev_group.groupby(
                    by="params.model_class"
                ):
                    medians[model_class] = {}
                    for param_value, param_group in model_group.groupby(
                        by=f"params.n_aug"
                    ):
                        medians[model_class][param_value] = param_group[metric].median()
                print(medians)
                medians_df = pd.DataFrame(medians).T
                pd.options.display.float_format = "{:.3f}".format
                print(medians_df)

        except ValueError:
            print(f"{EXP_NAME} does not seem to be finished.")
            raise

/home/nkuechen/Documents/Thesis
Experiment with ID 549101132328494905 has been exported as a CSV to file: code/thesis_code/data/aug_effect/surv_best_n_aug.csv.
10


{'CoxPHFitter': {1: 0.1233255903529012, 3: 0.1273578312169854, 5: 0.128836803158937, 10: 0.1385599433686683}, 'RandomSurvivalForest': {1: 0.1318152343095358, 3: 0.1325563961075138, 5: 0.1372332058723137, 10: 0.1602220008534381}}
                        1     3     5     10
CoxPHFitter          0.123 0.127 0.129 0.139
RandomSurvivalForest 0.132 0.133 0.137 0.160
20


{'CoxPHFitter': {1: 0.1003221002023705, 3: 0.1187229034393419, 5: 0.1134429382142091, 10: 0.1417889671461055}, 'RandomSurvivalForest': {1: 0.0921202751244178, 3: 0.1187523344824649, 5: 0.1149388558319368, 10: 0.122744088474559}}
                        1     3     5     10
CoxPHFitter          0.100 0.119 0.113 0.142
RandomSurvivalForest 0.092 0.119 0.115 0.123
40


{'CoxPHFitter': {1: 0.077595432466064, 3: 0.1052388078626398, 5: 0.1411317531800066, 10: 0.1667454600129219}, 'RandomSurvivalForest': {1: 0.0794084970761958, 3: 0.1080476178883913, 5: 0.1451717129304351, 10: 0.1626288216927617}}
                        1     3     5     10
CoxPHFitter          0.078 0.105 0.141 0.167
RandomSurvivalForest 0.079 0.108 0.145 0.163
63


{'CoxPHFitter': {1: 0.0746495119241774, 3: 0.0868421561931239, 5: 0.1110299081560714, 10: 0.1411311260077432}, 'RandomSurvivalForest': {1: 0.071879425953094, 3: 0.0899734580803458, 5: 0.0998735878882198, 10: 0.099973832909225}}
                        1     3     5     10
CoxPHFitter          0.075 0.087 0.111 0.141
RandomSurvivalForest 0.072 0.090 0.100 0.100
10


{'CoxPHFitter': {1: 0.8324820796000194, 3: 0.7935810833973016, 5: 0.8721587901708034, 10: 0.8303688786958806}, 'RandomSurvivalForest': {1: 0.8538165776683185, 3: 0.8594103906322789, 5: 0.8612782769463693, 10: 0.835141126228242}}
                        1     3     5     10
CoxPHFitter          0.832 0.794 0.872 0.830
RandomSurvivalForest 0.854 0.859 0.861 0.835
20


{'CoxPHFitter': {1: 0.8697312832332095, 3: 0.8570729448118408, 5: 0.8777612924600435, 10: 0.8781814467009876}, 'RandomSurvivalForest': {1: 0.8799575692480897, 3: 0.8854830613846141, 5: 0.8446318823871816, 10: 0.8293954061970635}}
                        1     3     5     10
CoxPHFitter          0.870 0.857 0.878 0.878
RandomSurvivalForest 0.880 0.885 0.845 0.829
40


{'CoxPHFitter': {1: 0.8901404708012876, 3: 0.8591310408877293, 5: 0.8861568404040516, 10: 0.878674430252512}, 'RandomSurvivalForest': {1: 0.8962603478656659, 3: 0.8679190089335711, 5: 0.8523386177956631, 10: 0.8666785815707215}}
                        1     3     5     10
CoxPHFitter          0.890 0.859 0.886 0.879
RandomSurvivalForest 0.896 0.868 0.852 0.867
63


{'CoxPHFitter': {1: 0.8938691130082416, 3: 0.8792632051800933, 5: 0.8833649631341653, 10: 0.8869818416857144}, 'RandomSurvivalForest': {1: 0.8918477104096155, 3: 0.8850359645229194, 5: 0.8850587644698535, 10: 0.89074749289642}}
                        1     3     5     10
CoxPHFitter          0.894 0.879 0.883 0.887
RandomSurvivalForest 0.892 0.885 0.885 0.891
