# R4 on ISIC 2019

In [None]:
%load_ext autoreload
%autoreload 2
import torch
import tqdm
import random
import matplotlib.pyplot as plt
import os
import sys
sys.path.append(os.path.abspath('..'))
import numpy as np
from models.R4_models import LesionNet
from models.pipeline import (train_model_with_certified_input_grad, train_model_with_pgd_robust_input_grad, write_results_to_file,
                             test_model_avg_and_wg_accuracy, train_model_with_smoothed_input_grad, test_model_accuracy,
                             test_delta_input_robustness, test_model_avg_and_wg_accuracy)
from metrics import (get_restart_avg_and_worst_group_accuracy_with_stddev, get_restart_macro_avg_acc_over_labels_with_stddev)
from datasets import isic

### Note: for ISIC, the macro-averaged accuracy over the LABELS must be computed separately, because the groups differ from the labels in this dataset: the labels are 0 and 1 (benign and malignant), while the groups are 0, 1 and 2, representing "cancer", "no patch no cancer" and "patch no cancer" respectively.

# Get the dataloaders 

In [None]:
# Debugging aid: make CUDA kernel launches synchronous so device-side errors are
# reported at the call that caused them. The CUDA runtime only honours this as an
# *environment variable* set before CUDA is initialised; a Python variable with the
# same name is ignored. Remove the os.environ line for full-speed training.
CUDA_LAUNCH_BLOCKING = 1
os.environ["CUDA_LAUNCH_BLOCKING"] = str(CUDA_LAUNCH_BLOCKING)

SEED = 0

# Prefer the second GPU (the original choice), but fall back to the first GPU on
# single-GPU hosts, where "cuda:1" does not exist, and to the CPU without CUDA.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Seed every random number generator used in this notebook.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# Binary cross-entropy loss, shared by all experiments below.
criterion = torch.nn.BCELoss()

In [None]:
# Root directory holding the ISIC data on disk.
DATA_ROOT = "/vol/bitbucket/mg2720/isic/"

# Train split, test split, and a second view of the test split built with grouped=True.
isic_train = isic.ISICDataset(DATA_ROOT, is_train=True)
isic_test = isic.ISICDataset(DATA_ROOT, is_train=False)
isic_test_grouped = isic.ISICDataset(DATA_ROOT, is_train=False, grouped=True)

# Each item unpacks into an (image, label, mask) triple; fetch one as a sanity check.
img, lbl, mask = isic_train[1]

# Class balance of the training split (label 1 vs. label 0).
train_pos = (isic_train.labels == 1).sum()
train_neg = (isic_train.labels == 0).sum()
print(f"Train: {train_pos} positive, {train_neg} negative")

In [None]:
# Scan the training set for the first sample whose mask has a positive sum.
# If no sample qualifies, idx stays 0 and the last sample visited is plotted.
idx = 0
for i in range(len(isic_train)):
    img, lbl, mask = isic_train[i]
    if not (mask.sum() > 0):
        continue
    idx = i
    break

# Show the image; permute(1, 2, 0) moves the leading axis last for imshow.
image_for_plot = img.permute(1, 2, 0).squeeze().numpy()
plt.imshow(image_for_plot)
plt.colorbar()
plt.show()

# Show the matching mask together with its shape and sum.
print(mask.shape, mask.sum())
mask_for_plot = mask.permute(1, 2, 0).squeeze().numpy()
plt.imshow(mask_for_plot)
plt.colorbar()
In [None]:
# The train and plain test loaders share one batch size; neither is shuffled.
batch_size = 256
dl_train, dl_test = (
    isic.get_loader_from_dataset(dataset, batch_size=batch_size, shuffle=False)
    for dataset in (isic_train, isic_test)
)
# The grouped test loader uses its own, smaller batch size.
dl_test_grouped = isic.get_loader_from_dataset(
    isic_test_grouped, batch_size=50, shuffle=False
)

## Experiments 

In [None]:
# Delta passed to test_delta_input_robustness in every experiment below.
DELTA_INPUT_ROBUSTNESS_PARAM = 1

# Every training method gets its own sub-directory for the per-restart checkpoints.
model_root_save_dir = "saved_experiment_models/performance/isic"
os.makedirs(model_root_save_dir, exist_ok=True)
methods = ["r4", "pgd_r4", "std", "smooth_r3", "rand_r4", "ibp_ex", "ibp_ex+r3", "r3"]

save_dir_for_method = {}
for method in methods:
    method_dir = os.path.join(model_root_save_dir, method)
    save_dir_for_method[method] = method_dir
    os.makedirs(method_dir, exist_ok=True)

### Standard Training

In [None]:
std_method = "std"
# Hyperparameters.
class_weights = [1, 9]
num_epochs, lr, restarts, epsilon, weight_coeff, k = 8, 0.0003, 3, 0.1, -1, -1
# Train `restarts` models with the standard method. For each one, measure accuracy and
# delta-input-robustness and accumulate the values so they can be averaged afterwards.
train_acc, test_acc, num_robust, avg_delta, min_lower_bound, max_upper_bound = 0, 0, 0, 0, 0, 0
for i in range(restarts):
    # Start every restart from a freshly initialised model with its own seed.
    torch.manual_seed(i + SEED)
    curr_model = LesionNet(3, 1).to(device)

    print(f"========== Training model with method {std_method} restart {i} ==========")
    train_model_with_certified_input_grad(
        dl_train, num_epochs, curr_model, lr, criterion, epsilon, std_method, k, device, True,
        class_weights=class_weights
    )
    print("Testing model accuracy for the training set")
    train_acc += test_model_accuracy(curr_model, dl_train, device)
    print("Testing model accuracy for the test set")
    test_acc += test_model_accuracy(curr_model, dl_test_grouped, device)
    n_r, min_delta, m_l, m_u = test_delta_input_robustness(
        dl_test_grouped, curr_model, epsilon, DELTA_INPUT_ROBUSTNESS_PARAM, "binary_cross_entropy", device,
        has_conv=True
    )
    # Accumulate this restart's robust count (adding num_robust to itself would keep it at 0).
    num_robust += n_r
    avg_delta += min_delta
    min_lower_bound += m_l
    max_upper_bound += m_u
    # Per-restart group accuracies; wg_acc and wg are replaced by the aggregate after the loop.
    avg_g_acc, wg_acc, wg = test_model_avg_and_wg_accuracy(curr_model, dl_test_grouped, device, num_groups=3)
    torch.save(curr_model.state_dict(), os.path.join(save_dir_for_method[std_method], f"run_{i}.pt"))

# Aggregate metrics over the checkpoints saved above, using an untrained model as the container.
empty_model = LesionNet(3, 1).to(device)
avg_acc, wg_acc, wg, *_ = get_restart_avg_and_worst_group_accuracy_with_stddev(
    dl_test_grouped, save_dir_for_method[std_method], empty_model, device, num_groups=3
)
macro_avg_over_labels, _ = get_restart_macro_avg_acc_over_labels_with_stddev(
    dl_test_grouped, save_dir_for_method[std_method], empty_model, device, num_classes=2
)
write_results_to_file(
    "experiment_results/isic.yaml",
    {
        "train_acc": round(train_acc / restarts, 5),
        "test_acc": round(test_acc / restarts, 5),
        "avg_group_acc": round(avg_acc, 5),
        "macro_avg_over_labels": round(macro_avg_over_labels, 5),
        "worst_group_acc": round(wg_acc, 5),
        "worst_group": wg,
        "min_robust_delta": round(avg_delta / restarts, 5),
        "min_lower_bound": round(min_lower_bound / restarts, 5),
        "max_upper_bound": round(max_upper_bound / restarts, 5),
    },
    std_method,
)
write_results_to_file(
    "experiment_results/isic_params.yaml",
    {
        "epsilon": epsilon,
        "test_epsilon": epsilon,
        "k": k,
        "weight_coeff": weight_coeff,
        "num_epochs": num_epochs,
        "lr": lr,
        "restarts": restarts,
        "train_batch_size": dl_train.batch_size,
        # All test-time evaluation in this cell runs on dl_test_grouped.
        "test_batch_size": dl_test_grouped.batch_size,
        "class_weights": class_weights,
        "multi_class": False,
        "has_conv": True,
        "with_k_schedule": False,
    },
    std_method,
)

### RRR Training

In [None]:
rrr_method = "r3"
# Hyperparameters.
class_weights = [1, 9]
num_epochs, lr, restarts, epsilon, weight_coeff, k = 10, 0.0003, 3, 0.1, -1, 0.1
test_epsilon = 0.1
# Train `restarts` models with the R3 method. For each one, measure accuracy and
# delta-input-robustness and accumulate the values so they can be averaged afterwards.
train_acc, test_acc, num_robust, avg_delta, min_lower_bound, max_upper_bound = 0, 0, 0, 0, 0, 0
for i in range(restarts):
    # Start every restart from a freshly initialised model with its own seed.
    torch.manual_seed(i + SEED)
    curr_model = LesionNet(3, 1).to(device)

    print(f"========== Training model with method {rrr_method} restart {i} ==========")
    train_model_with_pgd_robust_input_grad(
        dl_train, num_epochs, curr_model, lr, criterion, epsilon, rrr_method, k, device, True,
        class_weights=class_weights
    )
    print("Testing model accuracy for the training set")
    train_acc += test_model_accuracy(curr_model, dl_train, device)
    print("Testing model accuracy for the test set")
    test_acc += test_model_accuracy(curr_model, dl_test_grouped, device)
    # Evaluate at test_epsilon, the value recorded as "test_epsilon" in the params file.
    n_r, min_delta, m_l, m_u = test_delta_input_robustness(
        dl_test_grouped, curr_model, test_epsilon, DELTA_INPUT_ROBUSTNESS_PARAM, "binary_cross_entropy", device,
        has_conv=True
    )
    # Accumulate this restart's robust count (adding num_robust to itself would keep it at 0).
    num_robust += n_r
    avg_delta += min_delta
    min_lower_bound += m_l
    max_upper_bound += m_u
    # Per-restart group accuracies; wg_acc and wg are replaced by the aggregate after the loop.
    avg_g_acc, wg_acc, wg = test_model_avg_and_wg_accuracy(curr_model, dl_test_grouped, device, num_groups=3)
    torch.save(curr_model.state_dict(), os.path.join(save_dir_for_method[rrr_method], f"run_{i}.pt"))

# Aggregate metrics over the checkpoints saved above, using an untrained model as the container.
empty_model = LesionNet(3, 1).to(device)
avg_acc, wg_acc, wg, *_ = get_restart_avg_and_worst_group_accuracy_with_stddev(
    dl_test_grouped, save_dir_for_method[rrr_method], empty_model, device, num_groups=3
)
macro_avg_over_labels, _ = get_restart_macro_avg_acc_over_labels_with_stddev(
    dl_test_grouped, save_dir_for_method[rrr_method], empty_model, device, num_classes=2
)
write_results_to_file(
    "experiment_results/isic.yaml",
    {
        "train_acc": round(train_acc / restarts, 5),
        "test_acc": round(test_acc / restarts, 5),
        "avg_group_acc": round(avg_acc, 5),
        "macro_avg_over_labels": round(macro_avg_over_labels, 5),
        "worst_group_acc": round(wg_acc, 5),
        "worst_group": wg,
        "min_robust_delta": round(avg_delta / restarts, 5),
        "min_lower_bound": round(min_lower_bound / restarts, 5),
        "max_upper_bound": round(max_upper_bound / restarts, 5),
    },
    rrr_method,
)
write_results_to_file(
    "experiment_results/isic_params.yaml",
    {
        "epsilon": epsilon,
        "test_epsilon": test_epsilon,
        "k": k,
        "weight_coeff": weight_coeff,
        "num_epochs": num_epochs,
        "lr": lr,
        "restarts": restarts,
        "train_batch_size": dl_train.batch_size,
        # All test-time evaluation in this cell runs on dl_test_grouped.
        "test_batch_size": dl_test_grouped.batch_size,
        "class_weights": class_weights,
        "multi_class": False,
        "has_conv": True,
        "with_k_schedule": False,
    },
    rrr_method,
)

### Smoothed-R3 Training

In [None]:
smooth_r3_method = "smooth_r3"
# Hyperparameters.
class_weights = [1, 9]
num_epochs, lr, weight_decay, restarts, epsilon, k, num_samples = 14, 0.0007, 1e-5, 3, 0.1, 0.1, 3
test_epsilon = 0.1
# Train `restarts` models with the Smoothed-R3 method. For each one, measure accuracy and
# delta-input-robustness and accumulate the values so they can be averaged afterwards.
train_acc, test_acc, num_robust, avg_delta, min_lower_bound, max_upper_bound = 0, 0, 0, 0, 0, 0
for i in range(restarts):
    # Start every restart from a freshly initialised model with its own seed.
    torch.manual_seed(i + SEED)
    curr_model = LesionNet(3, 1).to(device)

    print(f"========== Training model with method {smooth_r3_method} restart {i} ==========")
    train_model_with_smoothed_input_grad(
        dl_train, num_epochs, curr_model, lr, criterion, epsilon, smooth_r3_method, k, device,
        weight_decay=weight_decay, class_weights=class_weights, num_samples=num_samples
    )
    print("Testing model accuracy for the training set")
    train_acc += test_model_accuracy(curr_model, dl_train, device)
    print("Testing model accuracy for the test set")
    test_acc += test_model_accuracy(curr_model, dl_test_grouped, device)
    # Evaluate at test_epsilon, the value recorded as "test_epsilon" in the params file.
    n_r, min_delta, m_l, m_u = test_delta_input_robustness(
        dl_test_grouped, curr_model, test_epsilon, DELTA_INPUT_ROBUSTNESS_PARAM, "binary_cross_entropy", device,
        has_conv=True
    )
    # Accumulate this restart's robust count (adding num_robust to itself would keep it at 0).
    num_robust += n_r
    avg_delta += min_delta
    min_lower_bound += m_l
    max_upper_bound += m_u
    # Per-restart group accuracies; wg_acc and wg are replaced by the aggregate after the loop.
    avg_g_acc, wg_acc, wg = test_model_avg_and_wg_accuracy(curr_model, dl_test_grouped, device, num_groups=3)
    torch.save(curr_model.state_dict(), os.path.join(save_dir_for_method[smooth_r3_method], f"run_{i}.pt"))

# Aggregate metrics over the checkpoints saved above, using an untrained model as the container.
empty_model = LesionNet(3, 1).to(device)
avg_acc, wg_acc, wg, _, std_dev_wg = get_restart_avg_and_worst_group_accuracy_with_stddev(
    dl_test_grouped, save_dir_for_method[smooth_r3_method], empty_model, device, num_groups=3
)
macro_avg_over_labels, std_dev_labels = get_restart_macro_avg_acc_over_labels_with_stddev(
    dl_test_grouped, save_dir_for_method[smooth_r3_method], empty_model, device, num_classes=2
)
write_results_to_file(
    "experiment_results/isic.yaml",
    {
        "train_acc": round(train_acc / restarts, 5),
        "test_acc": round(test_acc / restarts, 5),
        "avg_group_acc": round(avg_acc, 5),
        "macro_avg_over_labels": round(macro_avg_over_labels, 5),
        "std_dev_over_labels": round(std_dev_labels, 5),
        "worst_group_acc": round(wg_acc, 5),
        "std_dev_worst_group": round(std_dev_wg, 5),
        "worst_group": wg,
        "min_robust_delta": round(avg_delta / restarts, 5),
        "min_lower_bound": round(min_lower_bound / restarts, 5),
        "max_upper_bound": round(max_upper_bound / restarts, 5),
    },
    smooth_r3_method,
)
write_results_to_file(
    "experiment_results/isic_params.yaml",
    {
        "epsilon": epsilon,
        "test_epsilon": test_epsilon,
        "k": k,
        "weight_decay": weight_decay,
        "num_epochs": num_epochs,
        "lr": lr,
        "restarts": restarts,
        "train_batch_size": dl_train.batch_size,
        # All test-time evaluation in this cell runs on dl_test_grouped.
        "test_batch_size": dl_test_grouped.batch_size,
        "class_weights": class_weights,
        "multi_class": False,
        "has_conv": True,
        "num_samples": num_samples,
        "with_k_schedule": False,
    },
    smooth_r3_method,
)

### IBP-Ex Training

In [None]:
ibp_ex_method = "ibp_ex"
# Hyperparameters.
# NOTE(review): weight_coeff is recorded in the params file but is not passed to the
# training call in this cell — confirm that this is intentional.
class_weights = [1, 9]
num_epochs, lr, restarts, epsilon, weight_coeff, k = 12, 0.001, 3, 0.2, 5e-5, 0.3
test_epsilon = 0.1
# Train `restarts` models with the IBP-Ex method. For each one, measure accuracy and
# delta-input-robustness and accumulate the values so they can be averaged afterwards.
train_acc, test_acc, num_robust, avg_delta, min_lower_bound, max_upper_bound = 0, 0, 0, 0, 0, 0
for i in range(restarts):
    # Start every restart from a freshly initialised model with its own seed.
    torch.manual_seed(i + SEED)
    curr_model = LesionNet(3, 1).to(device)

    print(f"========== Training model with method {ibp_ex_method} restart {i} ==========")
    train_model_with_certified_input_grad(
        dl_train, num_epochs, curr_model, lr, criterion, epsilon, ibp_ex_method, k, device, True,
        class_weights=class_weights
    )
    print("Testing model accuracy for the training set")
    train_acc += test_model_accuracy(curr_model, dl_train, device)
    print("Testing model accuracy for the test set")
    test_acc += test_model_accuracy(curr_model, dl_test_grouped, device)
    # Evaluate at test_epsilon rather than the (larger) training epsilon, so that the
    # measurement matches the "test_epsilon" value recorded in the params file.
    n_r, min_delta, m_l, m_u = test_delta_input_robustness(
        dl_test_grouped, curr_model, test_epsilon, DELTA_INPUT_ROBUSTNESS_PARAM, "binary_cross_entropy", device,
        has_conv=True
    )
    # Accumulate this restart's robust count (adding num_robust to itself would keep it at 0).
    num_robust += n_r
    avg_delta += min_delta
    min_lower_bound += m_l
    max_upper_bound += m_u
    # Per-restart group accuracies; wg_acc and wg are replaced by the aggregate after the loop.
    avg_g_acc, wg_acc, wg = test_model_avg_and_wg_accuracy(curr_model, dl_test_grouped, device, num_groups=3)
    torch.save(curr_model.state_dict(), os.path.join(save_dir_for_method[ibp_ex_method], f"run_{i}.pt"))

# Aggregate metrics over the checkpoints saved above, using an untrained model as the container.
empty_model = LesionNet(3, 1).to(device)
avg_acc, wg_acc, wg, *_ = get_restart_avg_and_worst_group_accuracy_with_stddev(
    dl_test_grouped, save_dir_for_method[ibp_ex_method], empty_model, device, num_groups=3
)
macro_avg_over_labels, _ = get_restart_macro_avg_acc_over_labels_with_stddev(
    dl_test_grouped, save_dir_for_method[ibp_ex_method], empty_model, device, num_classes=2
)
write_results_to_file(
    "experiment_results/isic.yaml",
    {
        "train_acc": round(train_acc / restarts, 5),
        "test_acc": round(test_acc / restarts, 5),
        "avg_group_acc": round(avg_acc, 5),
        "macro_avg_over_labels": round(macro_avg_over_labels, 5),
        "worst_group_acc": round(wg_acc, 5),
        "worst_group": wg,
        "min_robust_delta": round(avg_delta / restarts, 5),
        "min_lower_bound": round(min_lower_bound / restarts, 5),
        "max_upper_bound": round(max_upper_bound / restarts, 5),
    },
    ibp_ex_method,
)
write_results_to_file(
    "experiment_results/isic_params.yaml",
    {
        "epsilon": epsilon,
        "test_epsilon": test_epsilon,
        "k": k,
        "weight_coeff": weight_coeff,
        "num_epochs": num_epochs,
        "lr": lr,
        "restarts": restarts,
        "train_batch_size": dl_train.batch_size,
        # All test-time evaluation in this cell runs on dl_test_grouped.
        "test_batch_size": dl_test_grouped.batch_size,
        "class_weights": class_weights,
        "multi_class": False,
        "has_conv": True,
        "with_k_schedule": False,
    },
    ibp_ex_method,
)

### IBP-Ex+R3 Training

In [None]:
ibp_ex_and_r3_method = "ibp_ex+r3"
# Hyperparameters.
# NOTE(review): weight_coeff is recorded in the params file but is not passed to the
# training call in this cell — confirm that this is intentional.
class_weights = [1, 9]
num_epochs, lr, restarts, epsilon, weight_coeff, k = 14, 0.001, 3, 0.25, 1e-4, 0.325
test_epsilon = 0.1
# Train `restarts` models with the IBP-Ex+R3 method. For each one, measure accuracy and
# delta-input-robustness and accumulate the values so they can be averaged afterwards.
train_acc, test_acc, num_robust, avg_delta, min_lower_bound, max_upper_bound = 0, 0, 0, 0, 0, 0
for i in range(restarts):
    # Start every restart from a freshly initialised model with its own seed.
    torch.manual_seed(i + SEED)
    curr_model = LesionNet(3, 1).to(device)

    print(f"========== Training model with method {ibp_ex_and_r3_method} restart {i} ==========")
    train_model_with_certified_input_grad(
        dl_train, num_epochs, curr_model, lr, criterion, epsilon, ibp_ex_and_r3_method, k, device, True,
        class_weights=class_weights
    )
    print("Testing model accuracy for the training set")
    train_acc += test_model_accuracy(curr_model, dl_train, device)
    print("Testing model accuracy for the test set")
    test_acc += test_model_accuracy(curr_model, dl_test_grouped, device)
    # Evaluate at test_epsilon rather than the (larger) training epsilon, so that the
    # measurement matches the "test_epsilon" value recorded in the params file.
    n_r, min_delta, m_l, m_u = test_delta_input_robustness(
        dl_test_grouped, curr_model, test_epsilon, DELTA_INPUT_ROBUSTNESS_PARAM, "binary_cross_entropy", device,
        has_conv=True
    )
    # Accumulate this restart's robust count (adding num_robust to itself would keep it at 0).
    num_robust += n_r
    avg_delta += min_delta
    min_lower_bound += m_l
    max_upper_bound += m_u
    # Per-restart group accuracies; wg_acc and wg are replaced by the aggregate after the loop.
    avg_g_acc, wg_acc, wg = test_model_avg_and_wg_accuracy(curr_model, dl_test_grouped, device, num_groups=3)
    torch.save(curr_model.state_dict(), os.path.join(save_dir_for_method[ibp_ex_and_r3_method], f"run_{i}.pt"))

# Aggregate metrics over the checkpoints saved above, using an untrained model as the container.
empty_model = LesionNet(3, 1).to(device)
avg_acc, wg_acc, wg, *_ = get_restart_avg_and_worst_group_accuracy_with_stddev(
    dl_test_grouped, save_dir_for_method[ibp_ex_and_r3_method], empty_model, device, num_groups=3
)
macro_avg_over_labels, _ = get_restart_macro_avg_acc_over_labels_with_stddev(
    dl_test_grouped, save_dir_for_method[ibp_ex_and_r3_method], empty_model, device, num_classes=2
)
write_results_to_file(
    "experiment_results/isic.yaml",
    {
        "train_acc": round(train_acc / restarts, 5),
        "test_acc": round(test_acc / restarts, 5),
        "avg_group_acc": round(avg_acc, 5),
        "macro_avg_over_labels": round(macro_avg_over_labels, 5),
        "worst_group_acc": round(wg_acc, 5),
        "worst_group": wg,
        "min_robust_delta": round(avg_delta / restarts, 5),
        "min_lower_bound": round(min_lower_bound / restarts, 5),
        "max_upper_bound": round(max_upper_bound / restarts, 5),
    },
    ibp_ex_and_r3_method,
)
write_results_to_file(
    "experiment_results/isic_params.yaml",
    {
        "epsilon": epsilon,
        "test_epsilon": test_epsilon,
        "k": k,
        "weight_coeff": weight_coeff,
        "num_epochs": num_epochs,
        "lr": lr,
        "restarts": restarts,
        "train_batch_size": dl_train.batch_size,
        # All test-time evaluation in this cell runs on dl_test_grouped.
        "test_batch_size": dl_test_grouped.batch_size,
        "class_weights": class_weights,
        "multi_class": False,
        "has_conv": True,
        "with_k_schedule": False,
    },
    ibp_ex_and_r3_method,
)

### R4 Training

In [None]:
r4_method = "r4"
# Hyperparameters.
class_weights = [1, 9]
num_epochs, lr, restarts, epsilon, weight_coeff, k, weight_decay = 25, 0.001, 3, 0.3, -1, 1, 1e-4
test_epsilon = 0.1
# Train `restarts` models with the R4 method. For each one, measure accuracy and
# delta-input-robustness and accumulate the values so they can be averaged afterwards.
train_acc, test_acc, num_robust, avg_delta, min_lower_bound, max_upper_bound = 0, 0, 0, 0, 0, 0
for i in range(restarts):
    # Start every restart from a freshly initialised model with its own seed.
    torch.manual_seed(i + SEED)
    curr_model = LesionNet(3, 1).to(device)

    print(f"========== Training model with method {r4_method} restart {i} ==========")
    train_model_with_certified_input_grad(
        dl_train, num_epochs, curr_model, lr, criterion, epsilon, r4_method, k, device, True,
        class_weights=class_weights, weight_decay=weight_decay
    )
    print("Testing model accuracy for the training set")
    train_acc += test_model_accuracy(curr_model, dl_train, device)
    print("Testing model accuracy for the test set")
    test_acc += test_model_accuracy(curr_model, dl_test_grouped, device)
    # Evaluate at test_epsilon rather than the (larger) training epsilon, so that the
    # measurement matches the "test_epsilon" value recorded in the params file.
    n_r, min_delta, m_l, m_u = test_delta_input_robustness(
        dl_test_grouped, curr_model, test_epsilon, DELTA_INPUT_ROBUSTNESS_PARAM, "binary_cross_entropy", device,
        has_conv=True
    )
    # Accumulate this restart's robust count (adding num_robust to itself would keep it at 0).
    num_robust += n_r
    avg_delta += min_delta
    min_lower_bound += m_l
    max_upper_bound += m_u
    # Per-restart group accuracies; wg_acc and wg are replaced by the aggregate after the loop.
    avg_g_acc, wg_acc, wg = test_model_avg_and_wg_accuracy(curr_model, dl_test_grouped, device, num_groups=3)
    torch.save(curr_model.state_dict(), os.path.join(save_dir_for_method[r4_method], f"run_{i}.pt"))

# Aggregate metrics over the checkpoints saved above, using an untrained model as the container.
empty_model = LesionNet(3, 1).to(device)
avg_acc, wg_acc, wg, *_ = get_restart_avg_and_worst_group_accuracy_with_stddev(
    dl_test_grouped, save_dir_for_method[r4_method], empty_model, device, num_groups=3
)
macro_avg_over_labels, _ = get_restart_macro_avg_acc_over_labels_with_stddev(
    dl_test_grouped, save_dir_for_method[r4_method], empty_model, device, num_classes=2
)
write_results_to_file(
    "experiment_results/isic.yaml",
    {
        "train_acc": round(train_acc / restarts, 5),
        "test_acc": round(test_acc / restarts, 5),
        "avg_group_acc": round(avg_acc, 5),
        "macro_avg_over_labels": round(macro_avg_over_labels, 5),
        "worst_group_acc": round(wg_acc, 5),
        "worst_group": wg,
        "min_robust_delta": round(avg_delta / restarts, 5),
        "min_lower_bound": round(min_lower_bound / restarts, 5),
        "max_upper_bound": round(max_upper_bound / restarts, 5),
    },
    r4_method,
)
write_results_to_file(
    "experiment_results/isic_params.yaml",
    {
        "epsilon": epsilon,
        "test_epsilon": test_epsilon,
        "k": k,
        "weight_coeff": weight_coeff,
        "num_epochs": num_epochs,
        "lr": lr,
        "restarts": restarts,
        "train_batch_size": dl_train.batch_size,
        # All test-time evaluation in this cell runs on dl_test_grouped.
        "test_batch_size": dl_test_grouped.batch_size,
        "class_weights": class_weights,
        "multi_class": False,
        "weight_decay": weight_decay,
        "has_conv": True,
        "with_k_schedule": False,
    },
    r4_method,
)

### PGD-R4 Training

In [None]:
pgd_r4_method = "pgd_r4"
# Hyperparameters.
# NOTE(review): weight_coeff is recorded in the params file but is not passed to the
# training call in this cell — confirm that this is intentional.
class_weights = [1, 9]
num_epochs, lr, restarts, epsilon, weight_coeff, k, weight_decay = 25, 0.0008, 3, 0.1, 1e-5, 0.0005, 1e-4
test_epsilon = 0.1
# Train `restarts` models with the PGD-R4 method. For each one, measure accuracy and
# delta-input-robustness and accumulate the values so they can be averaged afterwards.
train_acc, test_acc, num_robust, avg_delta, min_lower_bound, max_upper_bound = 0, 0, 0, 0, 0, 0
for i in range(restarts):
    # Start every restart from a freshly initialised model with its own seed.
    torch.manual_seed(i + SEED)
    curr_model = LesionNet(3, 1).to(device)

    print(f"========== Training model with method {pgd_r4_method} restart {i} ==========")
    train_model_with_pgd_robust_input_grad(
        dl_train, num_epochs, curr_model, lr, criterion, epsilon, pgd_r4_method, k, device, True,
        class_weights=class_weights, weight_decay=weight_decay, num_iterations=5
    )
    print("Testing model accuracy for the training set")
    train_acc += test_model_accuracy(curr_model, dl_train, device)
    print("Testing model accuracy for the test set")
    test_acc += test_model_accuracy(curr_model, dl_test_grouped, device)
    # Evaluate at test_epsilon, the value recorded as "test_epsilon" in the params file.
    n_r, min_delta, m_l, m_u = test_delta_input_robustness(
        dl_test_grouped, curr_model, test_epsilon, DELTA_INPUT_ROBUSTNESS_PARAM, "binary_cross_entropy", device,
        has_conv=True
    )
    # Accumulate this restart's robust count (adding num_robust to itself would keep it at 0).
    num_robust += n_r
    avg_delta += min_delta
    min_lower_bound += m_l
    max_upper_bound += m_u
    # Per-restart group accuracies; wg_acc and wg are replaced by the aggregate after the loop.
    avg_g_acc, wg_acc, wg = test_model_avg_and_wg_accuracy(curr_model, dl_test_grouped, device, num_groups=3)
    torch.save(curr_model.state_dict(), os.path.join(save_dir_for_method[pgd_r4_method], f"run_{i}.pt"))

# Aggregate metrics over the checkpoints saved above, using an untrained model as the container.
empty_model = LesionNet(3, 1).to(device)
avg_acc, wg_acc, wg, *_ = get_restart_avg_and_worst_group_accuracy_with_stddev(
    dl_test_grouped, save_dir_for_method[pgd_r4_method], empty_model, device, num_groups=3
)
macro_avg_over_labels, _ = get_restart_macro_avg_acc_over_labels_with_stddev(
    dl_test_grouped, save_dir_for_method[pgd_r4_method], empty_model, device, num_classes=2
)
write_results_to_file(
    "experiment_results/isic.yaml",
    {
        "train_acc": round(train_acc / restarts, 5),
        "test_acc": round(test_acc / restarts, 5),
        "avg_group_acc": round(avg_acc, 5),
        "macro_avg_over_labels": round(macro_avg_over_labels, 5),
        "worst_group_acc": round(wg_acc, 5),
        "worst_group": wg,
        "min_robust_delta": round(avg_delta / restarts, 5),
        "min_lower_bound": round(min_lower_bound / restarts, 5),
        "max_upper_bound": round(max_upper_bound / restarts, 5),
    },
    pgd_r4_method,
)
write_results_to_file(
    "experiment_results/isic_params.yaml",
    {
        "epsilon": epsilon,
        "test_epsilon": test_epsilon,
        "k": k,
        "weight_coeff": weight_coeff,
        "num_epochs": num_epochs,
        "lr": lr,
        "restarts": restarts,
        "train_batch_size": dl_train.batch_size,
        # All test-time evaluation in this cell runs on dl_test_grouped.
        "test_batch_size": dl_test_grouped.batch_size,
        "class_weights": class_weights,
        "weight_decay": weight_decay,
        "multi_class": False,
        "has_conv": True,
        "with_k_schedule": False,
    },
    pgd_r4_method,
)

### Rand-R4 Training

In [None]:
rand_r4_method = "rand_r4"
# Hyperparameters.
# NOTE(review): weight_coeff is recorded in the params file but is not passed to the
# training call in this cell — confirm that this is intentional.
class_weights = [1, 9]
num_epochs, lr, restarts, epsilon, weight_coeff, k, weight_decay, num_samples = 25, 0.0008, 3, 0.1, 1e-5, 0.0007, 1e-4, 3
test_epsilon = 0.1
# Train `restarts` models with the Rand-R4 method. For each one, measure accuracy and
# delta-input-robustness and accumulate the values so they can be averaged afterwards.
train_acc, test_acc, num_robust, avg_delta, min_lower_bound, max_upper_bound = 0, 0, 0, 0, 0, 0
for i in range(restarts):
    # Start every restart from a freshly initialised model with its own seed.
    torch.manual_seed(i + SEED)
    curr_model = LesionNet(3, 1).to(device)

    print(f"========== Training model with method {rand_r4_method} restart {i} ==========")
    train_model_with_smoothed_input_grad(
        dl_train, num_epochs, curr_model, lr, criterion, epsilon, rand_r4_method, k, device,
        class_weights=class_weights, weight_decay=weight_decay, num_samples=num_samples
    )
    print("Testing model accuracy for the training set")
    train_acc += test_model_accuracy(curr_model, dl_train, device)
    print("Testing model accuracy for the test set")
    test_acc += test_model_accuracy(curr_model, dl_test_grouped, device)
    # Evaluate at test_epsilon, the value recorded as "test_epsilon" in the params file.
    n_r, min_delta, m_l, m_u = test_delta_input_robustness(
        dl_test_grouped, curr_model, test_epsilon, DELTA_INPUT_ROBUSTNESS_PARAM, "binary_cross_entropy", device,
        has_conv=True
    )
    # Accumulate this restart's robust count (adding num_robust to itself would keep it at 0).
    num_robust += n_r
    avg_delta += min_delta
    min_lower_bound += m_l
    max_upper_bound += m_u
    # Per-restart group accuracies; wg_acc and wg are replaced by the aggregate after the loop.
    avg_g_acc, wg_acc, wg = test_model_avg_and_wg_accuracy(curr_model, dl_test_grouped, device, num_groups=3)
    torch.save(curr_model.state_dict(), os.path.join(save_dir_for_method[rand_r4_method], f"run_{i}.pt"))

# Aggregate metrics over the checkpoints saved above, using an untrained model as the container.
empty_model = LesionNet(3, 1).to(device)
avg_acc, wg_acc, wg, _, std_dev_wg = get_restart_avg_and_worst_group_accuracy_with_stddev(
    dl_test_grouped, save_dir_for_method[rand_r4_method], empty_model, device, num_groups=3
)
macro_avg_over_labels, std_dev_all = get_restart_macro_avg_acc_over_labels_with_stddev(
    dl_test_grouped, save_dir_for_method[rand_r4_method], empty_model, device, num_classes=2
)
write_results_to_file(
    "experiment_results/isic.yaml",
    {
        "train_acc": round(train_acc / restarts, 5),
        "test_acc": round(test_acc / restarts, 5),
        "avg_group_acc": round(avg_acc, 5),
        "macro_avg_over_labels": round(macro_avg_over_labels, 5),
        "std_dev_over_labels": round(std_dev_all, 5),
        "worst_group_acc": round(wg_acc, 5),
        "std_dev_wg": round(std_dev_wg, 5),
        "worst_group": wg,
        "min_robust_delta": round(avg_delta / restarts, 5),
        "min_lower_bound": round(min_lower_bound / restarts, 5),
        "max_upper_bound": round(max_upper_bound / restarts, 5),
    },
    rand_r4_method,
)
write_results_to_file(
    "experiment_results/isic_params.yaml",
    {
        "epsilon": epsilon,
        "test_epsilon": test_epsilon,
        "k": k,
        "weight_coeff": weight_coeff,
        "num_epochs": num_epochs,
        "lr": lr,
        "restarts": restarts,
        "train_batch_size": dl_train.batch_size,
        # All test-time evaluation in this cell runs on dl_test_grouped.
        "test_batch_size": dl_test_grouped.batch_size,
        "class_weights": class_weights,
        "weight_decay": weight_decay,
        "multi_class": False,
        "has_conv": True,
        "num_samples": num_samples,
        "with_k_schedule": False,
    },
    rand_r4_method,
)