# R4 on Plant Phenotyping Dataset

In [None]:
%load_ext autoreload
%autoreload 2
import torch
import tqdm
import random
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
sys.path.append(os.path.abspath('..'))
from models.R4_models import PlantNet
from models.robust_regularizer import input_gradient_interval_regularizer
from models.pipeline import (train_model_with_pgd_robust_input_grad, train_model_with_certified_input_grad,
                             test_model_accuracy, test_delta_input_robustness, write_results_to_file,
                             load_params_or_results_from_file, uniformize_magnitudes_schedule,
                             train_model_with_smoothed_input_grad, test_model_avg_and_wg_accuracy)
from datasets import plant
from metrics import get_restart_avg_and_worst_group_accuracy_with_stddev, get_restart_macro_avg_acc_over_labels_with_stddev

## Get the dataloaders

In [None]:
# CUDA reads CUDA_LAUNCH_BLOCKING from the process environment; a plain Python variable with that
# name has no effect. Synchronous kernel launches make CUDA errors surface at the failing call.
# This is a debugging aid that slows execution -- remove it for timing-sensitive runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
SEED = 0
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print(device)
# Seed every random number generator the notebook relies on.
random.seed(SEED)
np.random.seed(SEED)
# torch.manual_seed covers the CPU and all CUDA devices; torch.cuda.manual_seed alone only seeds the
# *current* GPU, which is not necessarily the one selected above.
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
# Batch sizes used when building the train / test dataloaders.
batch_size = 50
test_batch_size = 10
# Loss shared by all training runs in this notebook.
criterion = torch.nn.BCELoss()

In [None]:
# Locations of the split definitions, the RGB images and the preprocessed masks.
# NOTE(review): the ".pyu" extension of the masks file is unusual -- confirm the file name.
SPLIT_ROOT = "/vol/bitbucket/mg2720/plant/rgb_dataset_splits"
DATA_ROOT = "/vol/bitbucket/mg2720/plant/rgb_data"
MASKS_FILE = "/vol/bitbucket/mg2720/plant/mask/preprocessed_masks.pyu"

# Both datasets share every argument except the last flag
# (presumably split index 2 and an is-train switch -- confirm against plant.PlantDataset).
plant_train_2 = plant.PlantDataset(
    SPLIT_ROOT, DATA_ROOT, MASKS_FILE, 2, True
)
plant_test_2 = plant.PlantDataset(
    SPLIT_ROOT, DATA_ROOT, MASKS_FILE, 2, False
)
print(len(plant_train_2), len(plant_test_2))

# Count how many training labels equal 0 and how many equal 1.
num_neg = (plant_train_2.data_labels == 0).sum()
num_pos = (plant_train_2.data_labels == 1).sum()
print(num_neg, num_pos)

In [None]:
# Build the train and test dataloaders (in that order) with their respective batch sizes.
dl_train, dl_test = (
    plant.get_dataloader(dataset, size)
    for dataset, size in ((plant_train_2, batch_size), (plant_test_2, test_batch_size))
)

In [None]:
def visualize_gradient(model, batch_input, batch_labels, batch_mask, epsilon, elem_idx, has_conv, curr_device,
                       channel_to_view=1):
    """Plot one batch element, its mask, and its input gradient with the gradient's interval bounds.

    The batch tensors and the model are moved to ``curr_device``, then
    ``input_gradient_interval_regularizer`` is called with ``return_grads=True`` and
    ``regularizer_type="r4"``. The shapes of the returned gradient tensors are printed and a
    3x2 matplotlib figure is drawn. Nothing is returned.

    Args:
        model: Model whose input gradient is inspected; ``model.to(curr_device)`` is called on it.
        batch_input: Batch of inputs. Element ``elem_idx`` is displayed after permuting its axes
            with ``(1, 2, 0)`` (presumably channel-first images -- confirm against the dataset).
        batch_labels: Labels of the batch, forwarded to the regularizer.
        batch_mask: Masks of the batch, forwarded to the regularizer as ``batch_masks``; element
            ``elem_idx`` is displayed in grayscale.
        epsilon: Perturbation size forwarded to the regularizer.
        elem_idx: Index of the batch element to visualize.
        has_conv: Forwarded to the regularizer unchanged.
        curr_device: Device on which the computation is performed.
        channel_to_view: Index of the gradient channel to display. Only one channel is shown
            because the bounds of a multi-channel gradient do not represent RGB values.
            Defaults to 1, the previously hard-coded channel.
    """
    batch_input = batch_input.to(curr_device)
    batch_labels = batch_labels.to(curr_device)
    batch_mask = batch_mask.to(curr_device)
    model.to(curr_device)
    grad_bounds = input_gradient_interval_regularizer(
        model, batch_input, batch_labels, "binary_cross_entropy", epsilon, 0.0, return_grads=True,
        regularizer_type="r4", batch_masks=batch_mask, has_conv=has_conv, device=curr_device
    )
    # grad_bounds[0] starts with the gradient itself; grad_bounds[1] is the (lower, upper) bound pair.
    dx_n, _ = grad_bounds[0]
    dx_l, dx_u = grad_bounds[1]
    print(f"input lower bound shape: {dx_l.shape}")
    print(f"input upper bound shape: {dx_u.shape}")
    print(f"input gradient shape: {dx_n.shape}")

    def _channel_image(grad):
        # Select the requested element and channel and convert it to a numpy array for imshow.
        return grad[elem_idx][channel_to_view].squeeze().cpu().detach().numpy()

    fig, ax = plt.subplots(3, 2, figsize=(14, 13))

    # Row 0: the input and its mask.
    lesion = batch_input[elem_idx].permute(1, 2, 0).cpu().numpy()
    mask = batch_mask[elem_idx].permute(1, 2, 0).cpu().numpy()
    ax[0][0].imshow(lesion)
    ax[0][0].set_title(f"Input at index {elem_idx}")
    im_mask = ax[0][1].imshow(mask, cmap='gray')
    ax[0][1].set_title(f"Mask at index {elem_idx}")
    fig.colorbar(im_mask, ax=ax[0][1])

    # Row 1: lower and upper bound of the gradient.
    im_dx_l = ax[1][0].imshow(_channel_image(dx_l))
    ax[1][0].set_title(f"Lower bound of gradient at index {elem_idx}")
    fig.colorbar(im_dx_l, ax=ax[1][0])
    im_dx_u = ax[1][1].imshow(_channel_image(dx_u))
    ax[1][1].set_title(f"Upper bound of gradient at index {elem_idx}")
    cbu = fig.colorbar(im_dx_u, ax=ax[1][1])
    # The colorbar of the upper bound is drawn with its axis inverted.
    cbu.ax.invert_yaxis()

    # Row 2: the gradient itself; the last cell has no content, so hide its empty frame.
    im_dx_n = ax[2][0].imshow(_channel_image(dx_n))
    ax[2][0].set_title(f"Gradient at index {elem_idx}")
    fig.colorbar(im_dx_n, ax=ax[2][0])
    ax[2][1].set_axis_off()

## Experiments 

In [None]:
# Delta passed to test_delta_input_robustness by every experiment below.
DELTA_INPUT_ROBUSTNESS_PARAM = 1

# Root directory for the checkpoints, with one sub-directory per training method.
model_root_save_dir = "saved_experiment_models/performance/plant"
os.makedirs(model_root_save_dir, exist_ok=True)
methods = ["std", "r3", "r4", "ibp_ex", "ibp_ex+r3", "smooth_r3", "rand_r4", "pgd_r4"]
save_dir_for_method = dict(
    zip(methods, (os.path.join(model_root_save_dir, name) for name in methods))
)
for method in methods:
    method_dir = save_dir_for_method[method]
    os.makedirs(method_dir, exist_ok=True)

### Standard Training

In [None]:
std_method = "std"
# Hyperparameters. epsilon is passed both to the trainer and to the robustness test;
# weight_coeff is only recorded in the parameter file.
class_weights = [2.5, 1]
num_epochs, lr, restarts, epsilon, weight_coeff, k = 12, 0.0001, 3, 0.01, -1, -1
# Train `restarts` models and accumulate accuracy and delta-input-robustness statistics over them
train_acc, test_acc, num_robust, avg_delta, min_lower_bound, max_upper_bound = 0, 0, 0, 0, 0, 0
for i in range(restarts):
    # Start every restart from a freshly constructed model with a restart-specific seed
    torch.manual_seed(i + SEED)
    curr_model = PlantNet(3, 1).to(device)

    print(f"========== Training model with method {std_method} restart {i} ==========")
    train_model_with_certified_input_grad(
        dl_train, num_epochs, curr_model, lr, criterion, epsilon, std_method, k, device, True,
        class_weights=class_weights
    )
    print("Testing model accuracy for the training set")
    train_acc += test_model_accuracy(curr_model, dl_train, device)
    print("Testing model accuracy for the test set")
    test_acc += test_model_accuracy(curr_model, dl_test, device)
    n_r, min_delta, m_l, m_u = test_delta_input_robustness(
        dl_test, curr_model, epsilon, DELTA_INPUT_ROBUSTNESS_PARAM, "binary_cross_entropy", device, has_conv=True
    )
    # Accumulate this restart's value (adding num_robust to itself would keep it at 0)
    num_robust += n_r
    avg_delta += min_delta
    min_lower_bound += m_l
    max_upper_bound += m_u
    # Per-restart group accuracies; the reported values are recomputed from the checkpoints below
    avg_g_acc, wg_acc, wg = test_model_avg_and_wg_accuracy(curr_model, dl_test, device, num_groups=2)
    torch.save(curr_model.state_dict(), os.path.join(save_dir_for_method[std_method], f"run_{i}.pt"))
# Group accuracies over the checkpoints saved for this method
empty_model = PlantNet(3, 1).to(device)
avg_acc, wg_acc, wg, *_ = get_restart_avg_and_worst_group_accuracy_with_stddev(
    dl_test, save_dir_for_method[std_method], empty_model, device, num_groups=2
)
write_results_to_file("experiment_results/plant.yaml",
                      {"train_acc": round(train_acc / restarts, 5),
                       "test_acc": round(test_acc / restarts, 5),
                       "avg_group_acc": round(avg_acc, 5),
                       "worst_group_acc": round(wg_acc, 5),
                       "worst_group": wg,
                       "min_robust_delta": round(avg_delta / restarts, 5),
                       "min_lower_bound": round(min_lower_bound / restarts, 5),
                       "max_upper_bound": round(max_upper_bound / restarts, 5)
                       }, std_method)
write_results_to_file("experiment_results/plant_params.yaml",
                        {"epsilon": epsilon,
                         "test_epsilon": epsilon,
                         "k": k,
                         "weight_coeff": weight_coeff,
                         "num_epochs": num_epochs,
                         "lr": lr,
                         "restarts": restarts,
                         "train_batch_size": dl_train.batch_size,
                         # The test batch size comes from the test dataloader, not the training one
                         "test_batch_size": dl_test.batch_size,
                         "class_weights": class_weights,
                         "multi_class": False,
                         "has_conv": True,
                         "with_k_schedule": False}, std_method)

### RRR Training

In [None]:
rrr_method = "r3"
# Hyperparameters. epsilon is passed both to the trainer and to the robustness test.
# NOTE(review): weight_coeff is recorded in the parameter file but is not passed to the trainer
# in this cell -- confirm that the trainer's default matches the recorded value.
class_weights = [4.5, 1]
num_epochs, lr, restarts, epsilon, weight_coeff, k = 20, 0.0001, 3, 0.01, 0.00005, 25
# Train `restarts` R3 models and accumulate accuracy and delta-input-robustness statistics over them
train_acc, test_acc, num_robust, avg_delta, min_lower_bound, max_upper_bound = 0, 0, 0, 0, 0, 0
for i in range(restarts):
    # Start every restart from a freshly constructed model with a restart-specific seed
    torch.manual_seed(i + SEED)
    curr_model = PlantNet(3, 1).to(device)

    print(f"========== Training model with method {rrr_method} restart {i} ==========")
    train_model_with_pgd_robust_input_grad(
        dl_train, num_epochs, curr_model, lr, criterion, epsilon, rrr_method, k, device,
        class_weights=class_weights
    )
    print("Testing model accuracy for the training set")
    train_acc += test_model_accuracy(curr_model, dl_train, device)
    print("Testing model accuracy for the test set")
    test_acc += test_model_accuracy(curr_model, dl_test, device)
    n_r, min_delta, m_l, m_u = test_delta_input_robustness(
        dl_test, curr_model, epsilon, DELTA_INPUT_ROBUSTNESS_PARAM, "binary_cross_entropy", device, has_conv=True
    )
    # Accumulate this restart's value (adding num_robust to itself would keep it at 0)
    num_robust += n_r
    avg_delta += min_delta
    min_lower_bound += m_l
    max_upper_bound += m_u
    # Per-restart group accuracies; the reported values are recomputed from the checkpoints below
    avg_g_acc, wg_acc, wg = test_model_avg_and_wg_accuracy(curr_model, dl_test, device, num_groups=2)
    torch.save(curr_model.state_dict(), os.path.join(save_dir_for_method[rrr_method], f"run_{i}.pt"))
# Group accuracies over the checkpoints saved for this method
empty_model = PlantNet(3, 1).to(device)
avg_acc, wg_acc, wg, *_ = get_restart_avg_and_worst_group_accuracy_with_stddev(
    dl_test, save_dir_for_method[rrr_method], empty_model, device, num_groups=2
)
write_results_to_file("experiment_results/plant.yaml",
                      {"train_acc": round(train_acc / restarts, 5),
                       "test_acc": round(test_acc / restarts, 5),
                       "avg_group_acc": round(avg_acc, 5),
                       "worst_group_acc": round(wg_acc, 5),
                       "worst_group": wg,
                       "min_robust_delta": round(avg_delta / restarts, 5),
                       "min_lower_bound": round(min_lower_bound / restarts, 5),
                       "max_upper_bound": round(max_upper_bound / restarts, 5)
                       }, rrr_method)
write_results_to_file("experiment_results/plant_params.yaml",
                        {"epsilon": epsilon,
                         "test_epsilon": epsilon,
                         "k": k,
                         "weight_coeff": weight_coeff,
                         "num_epochs": num_epochs,
                         "lr": lr,
                         "restarts": restarts,
                         "train_batch_size": dl_train.batch_size,
                         # The test batch size comes from the test dataloader, not the training one
                         "test_batch_size": dl_test.batch_size,
                         "class_weights": class_weights,
                         "multi_class": False,
                         "has_conv": True,
                         "with_k_schedule": False}, rrr_method)

### R4 Training

In [None]:
r4_method = "r4"
# Hyperparameters. Training uses epsilon; the robustness test uses the separate test_epsilon.
# alpha is the argument given to plant.make_soft_masks; weight_coeff is only recorded.
class_weights = [2.3, 1]
num_epochs, lr, restarts, epsilon, weight_coeff, k, alpha = 12, 0.0002, 3, 0.02, -1, 0.01, 0.7
test_epsilon = 0.01
# Train `restarts` R4 models and accumulate accuracy and delta-input-robustness statistics over them
train_acc, test_acc, num_robust, avg_delta, min_lower_bound, max_upper_bound = 0, 0, 0, 0, 0, 0
# Training (and training-set accuracy) uses the soft-mask version of the training dataloader
new_dl_train = plant.make_soft_masks(dl_train, alpha)
for i in range(restarts):
    # Start every restart from a freshly constructed model.
    # NOTE(review): restart 2 reuses seed 0, i.e. the same seed as restart 0 -- confirm this is intended.
    torch.manual_seed(0 if i == 2 else i)
    curr_model = PlantNet(3, 1).to(device)

    print(f"========== Training model with method {r4_method} restart {i} ==========")
    train_model_with_certified_input_grad(
        new_dl_train, num_epochs, curr_model, lr, criterion, epsilon, r4_method, k, device, True,
        class_weights=class_weights
    )
    print("Testing model accuracy for the training set")
    train_acc += test_model_accuracy(curr_model, new_dl_train, device)
    print("Testing model accuracy for the test set")
    test_acc += test_model_accuracy(curr_model, dl_test, device)
    n_r, min_delta, m_l, m_u = test_delta_input_robustness(
        dl_test, curr_model, test_epsilon, DELTA_INPUT_ROBUSTNESS_PARAM, "binary_cross_entropy", device, has_conv=True
    )
    # Accumulate this restart's value (adding num_robust to itself would keep it at 0)
    num_robust += n_r
    avg_delta += min_delta
    min_lower_bound += m_l
    max_upper_bound += m_u
    # Per-restart group accuracies; the reported values are recomputed from the checkpoints below
    avg_g_acc, wg_acc, wg = test_model_avg_and_wg_accuracy(curr_model, dl_test, device, num_groups=2)
    torch.save(curr_model.state_dict(), os.path.join(save_dir_for_method[r4_method], f"run_{i}.pt"))
# Group accuracies over the checkpoints saved for this method
empty_model = PlantNet(3, 1).to(device)
avg_acc, wg_acc, wg, *_ = get_restart_avg_and_worst_group_accuracy_with_stddev(
    dl_test, save_dir_for_method[r4_method], empty_model, device, num_groups=2
)
write_results_to_file("experiment_results/plant.yaml",
                      {"train_acc": round(train_acc / restarts, 5),
                       "test_acc": round(test_acc / restarts, 5),
                       "avg_group_acc": round(avg_acc, 5),
                       "worst_group_acc": round(wg_acc, 5),
                       "worst_group": wg,
                       "min_robust_delta": round(avg_delta / restarts, 5),
                       "min_lower_bound": round(min_lower_bound / restarts, 5),
                       "max_upper_bound": round(max_upper_bound / restarts, 5)}, r4_method)
write_results_to_file("experiment_results/plant_params.yaml",
                        {"epsilon": epsilon,
                         "test_epsilon": test_epsilon,
                         "k": k,
                         "weight_coeff": weight_coeff,
                         "num_epochs": num_epochs,
                         "lr": lr,
                         "alpha_soft": alpha,
                         "restarts": restarts,
                         "train_batch_size": dl_train.batch_size,
                         # The test batch size comes from the test dataloader, not the training one
                         "test_batch_size": dl_test.batch_size,
                         "class_weights": class_weights,
                         "multi_class": False,
                         "has_conv": True,
                         "with_k_schedule": False}, r4_method)

### IBP-Ex Training

In [None]:
ibp_ex_method = "ibp_ex"
# Hyperparameters. Training uses epsilon; the robustness test uses the separate test_epsilon.
# alpha is the argument given to plant.make_soft_masks; weight_coeff is only recorded.
# NOTE(review): gain is assigned but not used anywhere in this cell.
class_weights = [1.9, 1]
num_epochs, lr, restarts, epsilon, weight_coeff, k, alpha, gain = 18, 0.0001, 3, 0.01, 5e-5, 0.01, 0.725, 0.1
test_epsilon = 0.01
# Training (and training-set accuracy) uses the soft-mask version of the training dataloader
new_dl_train = plant.make_soft_masks(dl_train, alpha)
# Train `restarts` IBP-Ex models and accumulate accuracy and delta-input-robustness statistics over them
train_acc, test_acc, num_robust, avg_delta, min_lower_bound, max_upper_bound = 0, 0, 0, 0, 0, 0
for i in range(restarts):
    # Start every restart from a freshly constructed model. Seed per restart (as the standard and
    # R3 experiments do) so the run does not depend on the RNG state left by earlier cells.
    torch.manual_seed(i + SEED)
    curr_model = PlantNet(3, 1).to(device)

    print(f"========== Training model with method {ibp_ex_method} restart {i} ==========")
    train_model_with_certified_input_grad(
        new_dl_train, num_epochs, curr_model, lr, criterion, epsilon, ibp_ex_method, k, device, True,
        class_weights=class_weights
    )
    print("Testing model accuracy for the training set")
    train_acc += test_model_accuracy(curr_model, new_dl_train, device)
    print("Testing model accuracy for the test set")
    test_acc += test_model_accuracy(curr_model, dl_test, device)
    n_r, min_delta, m_l, m_u = test_delta_input_robustness(
        dl_test, curr_model, test_epsilon, DELTA_INPUT_ROBUSTNESS_PARAM, "binary_cross_entropy", device, has_conv=True
    )
    # Accumulate this restart's value (adding num_robust to itself would keep it at 0)
    num_robust += n_r
    avg_delta += min_delta
    min_lower_bound += m_l
    max_upper_bound += m_u
    # Per-restart group accuracies; the reported values are recomputed from the checkpoints below
    avg_g_acc, wg_acc, wg = test_model_avg_and_wg_accuracy(curr_model, dl_test, device, num_groups=2)
    torch.save(curr_model.state_dict(), os.path.join(save_dir_for_method[ibp_ex_method], f"run_{i}.pt"))
# Group accuracies over the checkpoints saved for this method
empty_model = PlantNet(3, 1).to(device)
avg_acc, wg_acc, wg, *_ = get_restart_avg_and_worst_group_accuracy_with_stddev(
    dl_test, save_dir_for_method[ibp_ex_method], empty_model, device, num_groups=2
)
write_results_to_file("experiment_results/plant.yaml",
                      {"train_acc": round(train_acc / restarts, 5),
                       "test_acc": round(test_acc / restarts, 5),
                       "avg_group_acc": round(avg_acc, 5),
                       "worst_group_acc": round(wg_acc, 5),
                       "worst_group": wg,
                       "min_robust_delta": round(avg_delta / restarts, 5),
                       "min_lower_bound": round(min_lower_bound / restarts, 5),
                       "max_upper_bound": round(max_upper_bound / restarts, 5)}, ibp_ex_method)
write_results_to_file("experiment_results/plant_params.yaml",
                        {"epsilon": epsilon,
                         "test_epsilon": test_epsilon,
                         "k": k,
                         "weight_coeff": weight_coeff,
                         "num_epochs": num_epochs,
                         "lr": lr,
                         "restarts": restarts,
                         "train_batch_size": dl_train.batch_size,
                         # The test batch size comes from the test dataloader, not the training one
                         "test_batch_size": dl_test.batch_size,
                         "class_weights": class_weights,
                         "multi_class": False,
                         "has_conv": True,
                         "with_k_schedule": False}, ibp_ex_method)

### IBP-Ex + R3 Training

In [None]:
ibp_ex_and_r3_method = "ibp_ex+r3"
# Hyperparameters. Training uses epsilon; the robustness test uses the separate test_epsilon.
# alpha is the argument given to plant.make_soft_masks; weight_decay is forwarded to the trainer.
class_weights = [1.85, 1]
num_epochs, lr, restarts, epsilon, k, alpha, weight_decay = 20, 0.00008, 3, 0.01, 0.01, 0.72, 5e-5
test_epsilon = 0.01
# Training (and training-set accuracy) uses the soft-mask version of the training dataloader
new_dl_train = plant.make_soft_masks(dl_train, alpha)
# Train `restarts` IBP-Ex + R3 models and accumulate accuracy and delta-input-robustness statistics over them
train_acc, test_acc, num_robust, avg_delta, min_lower_bound, max_upper_bound = 0, 0, 0, 0, 0, 0
for i in range(restarts):
    # Start every restart from a freshly constructed model. Seed per restart (as the standard and
    # R3 experiments do) so the run does not depend on the RNG state left by earlier cells.
    torch.manual_seed(i + SEED)
    curr_model = PlantNet(3, 1).to(device)

    print(f"========== Training model with method {ibp_ex_and_r3_method} restart {i} ==========")
    train_model_with_certified_input_grad(
        new_dl_train, num_epochs, curr_model, lr, criterion, epsilon, ibp_ex_and_r3_method, k, device, True,
        class_weights=class_weights, weight_decay=weight_decay
    )
    print("Testing model accuracy for the training set")
    train_acc += test_model_accuracy(curr_model, new_dl_train, device)
    print("Testing model accuracy for the test set")
    test_acc += test_model_accuracy(curr_model, dl_test, device)
    n_r, min_delta, m_l, m_u = test_delta_input_robustness(
        dl_test, curr_model, test_epsilon, DELTA_INPUT_ROBUSTNESS_PARAM, "binary_cross_entropy", device, has_conv=True
    )
    # Accumulate this restart's value (adding num_robust to itself would keep it at 0)
    num_robust += n_r
    avg_delta += min_delta
    min_lower_bound += m_l
    max_upper_bound += m_u
    # Per-restart group accuracies; the reported values are recomputed from the checkpoints below
    avg_g_acc, wg_acc, wg = test_model_avg_and_wg_accuracy(curr_model, dl_test, device, num_groups=2)
    torch.save(curr_model.state_dict(), os.path.join(save_dir_for_method[ibp_ex_and_r3_method], f"run_{i}.pt"))
# Group accuracies over the checkpoints saved for this method
empty_model = PlantNet(3, 1).to(device)
avg_acc, wg_acc, wg, *_ = get_restart_avg_and_worst_group_accuracy_with_stddev(
    dl_test, save_dir_for_method[ibp_ex_and_r3_method], empty_model, device, num_groups=2
)
write_results_to_file("experiment_results/plant.yaml",
                      {"train_acc": round(train_acc / restarts, 5),
                       "test_acc": round(test_acc / restarts, 5),
                       "avg_group_acc": round(avg_acc, 5),
                       "worst_group_acc": round(wg_acc, 5),
                       "worst_group": wg,
                       "min_robust_delta": round(avg_delta / restarts, 5),
                       "min_lower_bound": round(min_lower_bound / restarts, 5),
                       "max_upper_bound": round(max_upper_bound / restarts, 5)}, ibp_ex_and_r3_method)
write_results_to_file("experiment_results/plant_params.yaml",
                        {"epsilon": epsilon,
                         "test_epsilon": test_epsilon,
                         "k": k,
                         "num_epochs": num_epochs,
                         "lr": lr,
                         "restarts": restarts,
                         "train_batch_size": dl_train.batch_size,
                         # The test batch size comes from the test dataloader, not the training one
                         "test_batch_size": dl_test.batch_size,
                         "class_weights": class_weights,
                         "multi_class": False,
                         "weight_decay": weight_decay,
                         "has_conv": True,
                         "with_k_schedule": False}, ibp_ex_and_r3_method)

### PGD-R4 Training

In [None]:
pgd_r4_method = "pgd_r4"
# Hyperparameters. Training uses epsilon; the robustness test uses the separate test_epsilon.
# weight_coeff is forwarded to the trainer as weight_reg_coeff.
class_weights = [2.3, 1]
num_epochs, lr, restarts, epsilon, weight_coeff, k = 15, 5e-5, 1, 0.03, 5e-5, 0.003
test_epsilon = 0.01
# Train `restarts` PGD-R4 models and accumulate accuracy and delta-input-robustness statistics over them
train_acc, test_acc, num_robust, avg_delta, min_lower_bound, max_upper_bound = 0, 0, 0, 0, 0, 0
for i in range(restarts):
    # Start every restart from a freshly constructed model, seeded with the loop index
    torch.manual_seed(i)
    # NOTE(review): the index is overridden after seeding, so this run is logged and saved as
    # restart 2 (run_2.pt) while only one restart is trained -- confirm this is still intended.
    i = 2
    curr_model = PlantNet(3, 1).to(device)

    print(f"========== Training model with method {pgd_r4_method} restart {i} ==========")
    train_model_with_pgd_robust_input_grad(
        dl_train, num_epochs, curr_model, lr, criterion, epsilon, pgd_r4_method, k, device,
        weight_reg_coeff=weight_coeff, class_weights=class_weights
    )
    print("Testing model accuracy for the training set")
    train_acc += test_model_accuracy(curr_model, dl_train, device)
    print("Testing model accuracy for the test set")
    test_acc += test_model_accuracy(curr_model, dl_test, device)
    n_r, min_delta, m_l, m_u = test_delta_input_robustness(
        dl_test, curr_model, test_epsilon, DELTA_INPUT_ROBUSTNESS_PARAM, "binary_cross_entropy", device, has_conv=True
    )
    # Accumulate this restart's value (adding num_robust to itself would keep it at 0)
    num_robust += n_r
    avg_delta += min_delta
    min_lower_bound += m_l
    max_upper_bound += m_u
    # Per-restart group accuracies; the reported values are recomputed from the checkpoints below
    avg_g_acc, wg_acc, wg = test_model_avg_and_wg_accuracy(curr_model, dl_test, device, num_groups=2)
    torch.save(curr_model.state_dict(), os.path.join(save_dir_for_method[pgd_r4_method], f"run_{i}.pt"))
# Group accuracies over the checkpoints saved for this method
empty_model = PlantNet(3, 1).to(device)
avg_acc, wg_acc, wg, *_ = get_restart_avg_and_worst_group_accuracy_with_stddev(
    dl_test, save_dir_for_method[pgd_r4_method], empty_model, device, num_groups=2
)
write_results_to_file("experiment_results/plant.yaml",
                      {"train_acc": round(train_acc / restarts, 5),
                       "test_acc": round(test_acc / restarts, 5),
                       "avg_group_acc": round(avg_acc, 5),
                       "worst_group_acc": round(wg_acc, 5),
                       "worst_group": wg,
                       "min_robust_delta": round(avg_delta / restarts, 5),
                       "min_lower_bound": round(min_lower_bound / restarts, 5),
                       "max_upper_bound": round(max_upper_bound / restarts, 5)}, pgd_r4_method)
write_results_to_file("experiment_results/plant_params.yaml",
                        {"epsilon": epsilon,
                         "test_epsilon": test_epsilon,
                         "k": k,
                         "weight_coeff": weight_coeff,
                         "num_epochs": num_epochs,
                         "lr": lr,
                         "restarts": restarts,
                         "train_batch_size": dl_train.batch_size,
                         # The test batch size comes from the test dataloader, not the training one
                         "test_batch_size": dl_test.batch_size,
                         "class_weights": class_weights,
                         "multi_class": False,
                         "has_conv": True,
                         "with_k_schedule": False}, pgd_r4_method)

### Smoothed-R3 Training

In [None]:
smooth_r3 = "smooth_r3"
# Hyperparameters. Training uses epsilon; the robustness test uses the separate test_epsilon.
# weight_coeff is forwarded to the trainer as weight_reg_coeff, num_samples as num_samples.
class_weights = [2.4, 1]
num_epochs, lr, restarts, epsilon, weight_coeff, k, num_samples = 17, 7e-5, 3, 0.01, 1e-4, 0.01, 3
test_epsilon = 0.01
# Train `restarts` Smoothed-R3 models and accumulate accuracy and delta-input-robustness statistics over them
train_acc, test_acc, num_robust, avg_delta, min_lower_bound, max_upper_bound = 0, 0, 0, 0, 0, 0
# One explicit seed per restart; the list must hold at least `restarts` entries
seeds = [0, 1, 12]
for i in range(restarts):
    # Start every restart from a freshly constructed model with its own seed
    torch.manual_seed(seeds[i])
    curr_model = PlantNet(3, 1).to(device)

    print(f"========== Training model with method {smooth_r3} restart {i} ==========")
    train_model_with_smoothed_input_grad(
        dl_train, num_epochs, curr_model, lr, criterion, epsilon, smooth_r3, k, device, weight_reg_coeff=weight_coeff,
        class_weights=class_weights, num_samples=num_samples
    )
    print("Testing model accuracy for the training set")
    train_acc += test_model_accuracy(curr_model, dl_train, device)
    print("Testing model accuracy for the test set")
    test_acc += test_model_accuracy(curr_model, dl_test, device)
    n_r, min_delta, m_l, m_u = test_delta_input_robustness(
        dl_test, curr_model, test_epsilon, DELTA_INPUT_ROBUSTNESS_PARAM, "binary_cross_entropy", device, has_conv=True
    )
    # Accumulate this restart's value (adding num_robust to itself would keep it at 0)
    num_robust += n_r
    avg_delta += min_delta
    min_lower_bound += m_l
    max_upper_bound += m_u
    # Per-restart group accuracies; the reported values are recomputed from the checkpoints below
    avg_g_acc, wg_acc, wg = test_model_avg_and_wg_accuracy(curr_model, dl_test, device, num_groups=2)
    # Save each restart under its own index; a constant file name would make every restart
    # overwrite the previous checkpoint
    torch.save(curr_model.state_dict(), os.path.join(save_dir_for_method[smooth_r3], f"run_{i}.pt"))
# Group accuracies (with standard deviations) over the checkpoints saved for this method
empty_model = PlantNet(3, 1).to(device)
avg_acc, wg_acc, wg, std_dev_all, std_dev_wg = get_restart_avg_and_worst_group_accuracy_with_stddev(
    dl_test, save_dir_for_method[smooth_r3], empty_model, device, num_groups=2
)
write_results_to_file("experiment_results/plant.yaml",
                      {"train_acc": round(train_acc / restarts, 5),
                       "test_acc": round(test_acc / restarts, 5),
                       "avg_group_acc": round(avg_acc, 5),
                       "std_dev_group_acc": round(std_dev_all, 5),
                       "worst_group_acc": round(wg_acc, 5),
                       "std_dev_wg_acc": round(std_dev_wg, 5),
                       "worst_group": wg,
                       "min_robust_delta": round(avg_delta / restarts, 5),
                       "min_lower_bound": round(min_lower_bound / restarts, 5),
                       "max_upper_bound": round(max_upper_bound / restarts, 5)}, smooth_r3)
write_results_to_file("experiment_results/plant_params.yaml",
                        {"epsilon": epsilon,
                         "test_epsilon": test_epsilon,
                         "k": k,
                         "weight_coeff": weight_coeff,
                         "num_epochs": num_epochs,
                         "lr": lr,
                         "restarts": restarts,
                         "train_batch_size": dl_train.batch_size,
                         # The test batch size comes from the test dataloader, not the training one
                         "test_batch_size": dl_test.batch_size,
                         "class_weights": class_weights,
                         "multi_class": False,
                         "has_conv": True,
                         "num_samples": num_samples,
                         "with_k_schedule": False}, smooth_r3)

### Rand-R4 Training

In [None]:
rand_r4 = "rand_r4"
# Hyperparameters. Training uses epsilon; the robustness test uses the separate test_epsilon.
# weight_coeff is forwarded to the trainer as weight_reg_coeff, num_samples as num_samples.
class_weights = [2.1, 1]
num_epochs, lr, restarts, epsilon, weight_coeff, k, num_samples = 21, 5e-5, 3, 0.025, 3e-4, 0.05, 3
test_epsilon = 0.01
# Train `restarts` Rand-R4 models and accumulate accuracy and delta-input-robustness statistics over them
train_acc, test_acc, num_robust, avg_delta, min_lower_bound, max_upper_bound = 0, 0, 0, 0, 0, 0
for i in range(restarts):
    # Start every restart from a freshly constructed model, seeded with the restart index
    torch.manual_seed(i)
    curr_model = PlantNet(3, 1).to(device)

    print(f"========== Training model with method {rand_r4} restart {i} ==========")
    train_model_with_smoothed_input_grad(
        dl_train, num_epochs, curr_model, lr, criterion, epsilon, rand_r4, k, device, weight_reg_coeff=weight_coeff,
        class_weights=class_weights, num_samples=num_samples
    )
    print("Testing model accuracy for the training set")
    train_acc += test_model_accuracy(curr_model, dl_train, device)
    print("Testing model accuracy for the test set")
    test_acc += test_model_accuracy(curr_model, dl_test, device)
    n_r, min_delta, m_l, m_u = test_delta_input_robustness(
        dl_test, curr_model, test_epsilon, DELTA_INPUT_ROBUSTNESS_PARAM, "binary_cross_entropy", device, has_conv=True
    )
    # Accumulate this restart's value (adding num_robust to itself would keep it at 0)
    num_robust += n_r
    avg_delta += min_delta
    min_lower_bound += m_l
    max_upper_bound += m_u
    # Per-restart group accuracies; the reported values are recomputed from the checkpoints below
    avg_g_acc, wg_acc, wg = test_model_avg_and_wg_accuracy(curr_model, dl_test, device, num_groups=2)
    torch.save(curr_model.state_dict(), os.path.join(save_dir_for_method[rand_r4], f"run_{i}.pt"))
# Group accuracies (with standard deviations) over the checkpoints saved for this method
empty_model = PlantNet(3, 1).to(device)
avg_acc, wg_acc, wg, std_dev_all, std_dev_wg = get_restart_avg_and_worst_group_accuracy_with_stddev(
    dl_test, save_dir_for_method[rand_r4], empty_model, device, num_groups=2
)
write_results_to_file("experiment_results/plant.yaml",
                      {"train_acc": round(train_acc / restarts, 5),
                       "test_acc": round(test_acc / restarts, 5),
                       "avg_group_acc": round(avg_acc, 5),
                       "std_dev_group_acc": round(std_dev_all, 5),
                       "worst_group_acc": round(wg_acc, 5),
                       "std_dev_wg_acc": round(std_dev_wg, 5),
                       "worst_group": wg,
                       "min_robust_delta": round(avg_delta / restarts, 5),
                       "min_lower_bound": round(min_lower_bound / restarts, 5),
                       "max_upper_bound": round(max_upper_bound / restarts, 5)}, rand_r4)
write_results_to_file("experiment_results/plant_params.yaml",
                        {"epsilon": epsilon,
                         "test_epsilon": test_epsilon,
                         "k": k,
                         "weight_coeff": weight_coeff,
                         "num_epochs": num_epochs,
                         "lr": lr,
                         "restarts": restarts,
                         "train_batch_size": dl_train.batch_size,
                         # The test batch size comes from the test dataloader, not the training one
                         "test_batch_size": dl_test.batch_size,
                         "class_weights": class_weights,
                         "multi_class": False,
                         "has_conv": True,
                         "num_samples": num_samples,
                         "with_k_schedule": False}, rand_r4)