# Write Sweep Configs

Set up sweep configurations for tuning the models' hyperparameters.

To train on a specified dataset, four choices must be made:

### 1. Which model to use
This is the model architecture, which is specified by the model name.
The model name must be one of the following:
- cph:  Cox Proportional Hazards (Auton-Survival)
- dcph: Deep Cox Proportional Hazards (Auton-Survival)
- dcm:  Deep Cox Mixtures (Auton-Survival)
- dsm:  Deep Survival Machines (Auton-Survival)
- rf:   Random Forest (scikit-learn)
- km:   Kaplan-Meier (extension of Random Forest)
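Before writing any configs, a mistyped model code can be caught by checking it against a lookup table of the codes listed above. This is a minimal sketch, assuming model codes are plain strings. `MODEL_NAMES` and `validate_model_types` are hypothetical helpers, not part of `disruption_survival_analysis`.

```python
# Hypothetical lookup table mirroring the model codes listed above
MODEL_NAMES = {
    "cph": "Cox Proportional Hazards (Auton-Survival)",
    "dcph": "Deep Cox Proportional Hazards (Auton-Survival)",
    "dcm": "Deep Cox Mixtures (Auton-Survival)",
    "dsm": "Deep Survival Machines (Auton-Survival)",
    "rf": "Random Forest (scikit-learn)",
    "km": "Kaplan-Meier (extension of Random Forest)",
}


def validate_model_types(model_types):
    """Raise ValueError if any code is not a recognized model name."""
    unknown = [m for m in model_types if m not in MODEL_NAMES]
    if unknown:
        raise ValueError(
            f"Unknown model types {unknown}; expected one of {sorted(MODEL_NAMES)}"
        )


validate_model_types(["dsm", "rf"])  # passes silently
```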

### 2. Which alarm to use
The metrics we calculate are based on true positives and false positives for disruption prediction,
so they depend on which type of alarm is used with the model.
The alarm name must be one of the following:
- sthr: Simple threshold. The alarm is triggered when the model's predicted risk is above a specified threshold. The thresholds are a predefined list of values between 0 and 1 set in the code.
- athr: All thresholds. Equivalent to the simple threshold, except that the thresholds are every unique value of the model's predicted risk.
- hyst: Hysteresis. The alarm is triggered when the model's risk exceeds an upper threshold and does not fall below a lower threshold for a specified time interval.
- ettd: Expected time to disruption. The alarm is triggered when the model expects the disruption to occur within the next x seconds.
- ethy: Expected time to disruption hysteresis. The alarm is triggered when the model expects the disruption to occur within the next x seconds and continues to return this result for a specified time interval.
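The threshold-style alarms can be sketched on a single shot's time series. This is a minimal illustration, not the project's implementation: it assumes `times`, `risk`, and `expected_ttd` are NumPy arrays sampled at the same time points (in seconds), and that each alarm reports only the time it first fires. The function names are hypothetical.

```python
import numpy as np


def simple_threshold_alarm(times, risk, threshold):
    """sthr: time of the first sample whose risk is above `threshold`, or None."""
    idx = np.flatnonzero(risk > threshold)
    return float(times[idx[0]]) if idx.size else None


def hysteresis_alarm(times, risk, upper, lower, min_duration):
    """hyst: alarm once risk has risen above `upper` and then stayed at or
    above `lower` for at least `min_duration` seconds."""
    start = None  # time at which risk last crossed the upper threshold
    for t, r in zip(times, risk):
        if start is None:
            if r > upper:
                start = t
        elif r < lower:
            start = None  # risk fell below the lower threshold: reset
        if start is not None and t - start >= min_duration:
            return float(t)
    return None


def ettd_alarm(times, expected_ttd, horizon):
    """ettd: time of the first sample where the expected time to disruption
    is within `horizon` seconds, or None."""
    idx = np.flatnonzero(expected_ttd <= horizon)
    return float(times[idx[0]]) if idx.size else None


times = np.arange(6) * 0.25  # 0.0, 0.25, ..., 1.25 seconds
risk = np.array([0.1, 0.6, 0.4, 0.7, 0.8, 0.9])
expected_ttd = np.array([1.0, 0.8, 0.5, 0.2, 0.1, 0.05])

print(simple_threshold_alarm(times, risk, 0.5))       # 0.25
print(hysteresis_alarm(times, risk, 0.5, 0.3, 0.5))   # 0.75
print(hysteresis_alarm(times, risk, 0.5, 0.45, 0.5))  # 1.25 (dip at 0.5 s resets)
print(ettd_alarm(times, expected_ttd, 0.3))           # 0.75
```

In this sketch the all-thresholds variant would simply sweep `simple_threshold_alarm` over `np.unique(risk)` instead of a fixed list.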

### 3. Which metric to use
The metric is the function used to evaluate the model's performance and tune hyperparameters.
The metric name must be one of the following:
- auroc: Area under the ROC curve, where the true positive and false positive rates are computed per shot.
- auwtc: Area under the warning time curve. Similar to an ROC curve, but the y-axis is the warning time before disruption.
- maxf1: Maximum F1 score. The F1 score is calculated for each point on the ROC curve, and the maximum is returned.
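The per-shot quantities behind `auroc` and `maxf1` can be sketched with NumPy. This is a minimal illustration under stated assumptions: `alarmed[i, j]` is a hypothetical boolean array marking whether the alarm at threshold `i` counts for shot `j` (for disruptive shots, already taking the required warning time into account), and the AUROC uses the trapezoidal rule after adding the (0, 0) and (1, 1) endpoints. `auwtc` is not covered, and the function names are hypothetical.

```python
import numpy as np


def roc_counts(alarmed, disrupted):
    """Per-threshold TP/FP/FN counts at the shot level.

    alarmed:   bool array, shape (n_thresholds, n_shots) (hypothetical input)
    disrupted: bool array, shape (n_shots,)
    """
    tp = (alarmed & disrupted).sum(axis=1)
    fp = (alarmed & ~disrupted).sum(axis=1)
    fn = (~alarmed & disrupted).sum(axis=1)
    return tp, fp, fn


def auroc(tp, fp, disrupted):
    """Area under the per-shot ROC curve via the trapezoidal rule."""
    tpr = tp / disrupted.sum()
    fpr = fp / (~disrupted).sum()
    # Add the (0, 0) and (1, 1) endpoints, then sort by FPR (ties broken by TPR)
    fpr = np.concatenate(([0.0], fpr, [1.0]))
    tpr = np.concatenate(([0.0], tpr, [1.0]))
    order = np.lexsort((tpr, fpr))
    fpr, tpr = fpr[order], tpr[order]
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2))


def max_f1(tp, fp, fn):
    """Maximum F1 = 2TP / (2TP + FP + FN) over all thresholds."""
    denom = 2 * tp + fp + fn
    f1 = np.where(denom > 0, 2 * tp / np.maximum(denom, 1), 0.0)
    return float(f1.max())


disrupted = np.array([True, True, False, False])
alarmed = np.array([
    [True, True, True, True],     # low threshold: every shot alarms
    [True, True, False, True],
    [True, False, False, False],  # high threshold: one shot alarms
])
tp, fp, fn = roc_counts(alarmed, disrupted)
print(auroc(tp, fp, disrupted))  # 0.875
print(max_f1(tp, fp, fn))        # 0.8
```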

### 4. Required warning time
For a disruption prediction to be useful, it must come with a warning time long enough for mitigation strategies to be triggered. A prediction counts as a true positive only if it comes before the disruption with a warning time greater than or equal to the required warning time.
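The true-positive rule can be written as a small per-shot classifier. This is a minimal sketch, assuming times are in seconds and `None` marks either a missing alarm or a non-disruptive shot. The text does not say how a late alarm on a disruptive shot is scored, so counting it as a miss is an assumption. `classify_shot` is a hypothetical name.

```python
def classify_shot(alarm_time, disruption_time, required_warning_time):
    """Score one shot as 'TP', 'FP', 'FN', or 'TN'.

    alarm_time:      time of the first alarm in seconds, or None if no alarm fired
    disruption_time: time of the disruption in seconds, or None for a
                     non-disruptive shot
    """
    if disruption_time is None:
        return "FP" if alarm_time is not None else "TN"
    if alarm_time is not None and disruption_time - alarm_time >= required_warning_time:
        return "TP"
    # Assumption: no alarm, or an alarm with too little warning, counts as a miss
    return "FN"


print(classify_shot(0.5, 1.0, 0.02))   # TP
print(classify_shot(0.99, 1.0, 0.02))  # FN (only about 0.01 s of warning)
print(classify_shot(None, 1.0, 0.02))  # FN
print(classify_shot(0.3, None, 0.02))  # FP
```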

```python
from itertools import product

from disruption_survival_analysis.sweep_config import make_sweep_config, write_sweep_config

# Datasets to use
devices = ["synthetic"]
dataset_paths = ["test"]

# Models to create sweeps for
# Options: cph, dcph, dcm, dsm, rf, km
model_types = ["dsm", "rf"]

# Alarm types to use
# Options: sthr, athr, hyst, ettd, ethy
alarm_types = ["sthr"]

# Validation metrics to use
# Options: auroc, auwtc, maxf1
metrics = ["auroc"]

# Required warning times to train on (in seconds)
required_warning_times = [0.02]

# Write one sweep config for every combination of the choices above
for device, dataset_path, model_type, alarm_type, metric, required_warning_time in product(
    devices, dataset_paths, model_types, alarm_types, metrics, required_warning_times
):
    sweep_config = make_sweep_config(
        device, dataset_path, model_type, alarm_type, metric, required_warning_time
    )
    write_sweep_config(sweep_config)
```