In [None]:
%load_ext autoreload
%autoreload 2
import tqdm
import matplotlib.pyplot as plt
import os
os.environ['HF_HOME'] = '/vol/bitbucket/mg2720/llm/huggingface'
import random
import sys
import torch
import numpy as np
sys.path.append(os.path.abspath('..'))
import abstract_gradient_training as agt
import models.llm as llm
import models.robust_regularizer
from abstract_gradient_training import AGTConfig
from abstract_gradient_training import certified_training_utils as ct_utils
from models.pipeline import train_llm_with_guidance, test_llm_accuracy, write_results_to_file
from datasets.imdb import get_loader_from_dataset, ImdbDataset
from datasets.spurious_words import all_imdb_spur
from metrics import llm_restart_avg_and_worst_group_acc

In [None]:
# Fix the base seed and seed every RNG the notebook relies on (torch, numpy,
# and the Python stdlib) so that runs are repeatable.
seed = 42
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(seed)

# PyTorch's reproducibility notes ask for this cuBLAS workspace setting when
# deterministic algorithms are enabled on CUDA.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.use_deterministic_algorithms(True)

# Make cuDNN pick deterministic kernels and disable its autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Shared configuration for all experiments below: batch sizes, loss, device,
# tokenizer, datasets and their loaders.
batch_size = 88
test_batch_size = 250
criterion = torch.nn.BCELoss()

# Prefer the second GPU when the host has one; fall back to the first GPU on
# single-GPU machines (where "cuda:1" does not exist) and to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

bert_tokenizer = llm.BertTokenizerWrapper(all_imdb_spur())
imdb_train = ImdbDataset(is_train=True)
imdb_test = ImdbDataset(is_train=False, grouped=True)

dl_masks_train = get_loader_from_dataset(imdb_train, batch_size=batch_size)
# The test loader uses `test_batch_size`, so the "test_batch_size" value the
# experiment cells log (read from dl_masks_test.batch_size) is the configured
# evaluation batch size rather than the training one.
dl_masks_test = get_loader_from_dataset(imdb_test, batch_size=test_batch_size)

## Experiments 

In [None]:
# Checkpoint layout: one sub-directory per training method under a common root.
model_root_save_dir = "saved_experiment_models/performance/imdb"
os.makedirs(model_root_save_dir, exist_ok=True)

methods = ["std", "r3", "smooth_r3", "rand_r4", "pgd_r4"]
save_dir_for_method = {}
for method in methods:
    # Record the directory for this method and make sure it exists on disk.
    method_dir = os.path.join(model_root_save_dir, method)
    save_dir_for_method[method] = method_dir
    os.makedirs(method_dir, exist_ok=True)

### Standard Training

In [None]:
std_method = "std"

# Hyperparameters for this method. `lmbda` is forwarded to
# train_llm_with_guidance and logged under the key "k".
num_epochs = 2
lr = 0.0001
restarts = 1
weight_decay = 0
lmbda = 0

# Train `restarts` fresh models. Train/test accuracies are summed over the
# runs (and divided by `restarts` when logged); every run is checkpointed so
# that the group metrics can be computed from the saved weights afterwards.
train_acc, test_acc = 0, 0
for i in range(restarts):
    # Build a brand-new model per restart, seeded from the restart index,
    # instead of re-initialising the weights of the previous one.
    torch.manual_seed(i + seed)
    curr_model = llm.BertModelWrapper(1, device)

    print(f"========== Training model with method {std_method} restart {i} ==========")
    train_llm_with_guidance(
        curr_model, bert_tokenizer, dl_masks_train, num_epochs, lr, criterion,
        std_method, lmbda, device, weight_decay=weight_decay,
    )

    print("Testing model accuracy for the training set")
    train_acc += test_llm_accuracy(curr_model, bert_tokenizer, dl_masks_train, device)
    print("Testing model accuracy for the test set")
    test_acc += test_llm_accuracy(curr_model, bert_tokenizer, dl_masks_test, device)

    checkpoint_path = os.path.join(save_dir_for_method[std_method], f"run_{i}.pt")
    torch.save(curr_model.state_dict(), checkpoint_path)

# Group metrics over the checkpoints saved above.
# NOTE(review): `empty_model` is presumably the shell the helper loads each
# checkpoint into — confirm against llm_restart_avg_and_worst_group_acc.
empty_model = llm.BertModelWrapper(1, device)
avg_acc, wg_acc, wg, *_ = llm_restart_avg_and_worst_group_acc(
    dl_masks_test, save_dir_for_method[std_method], empty_model, bert_tokenizer, device, num_groups=2
)

# Log the accuracy summary for this method.
std_results = {
    "train_acc": round(train_acc / restarts, 5),
    "test_acc": round(test_acc / restarts, 5),
    "avg_group_acc": round(avg_acc, 5),
    "worst_group_acc": round(wg_acc, 5),
    "worst_group": wg,
}
write_results_to_file("experiment_results/imdb_bert.yaml", std_results, std_method)

# Log the hyperparameters that produced those results.
std_params = {
    "k": lmbda,
    "weight_decay": weight_decay,
    "num_epochs": num_epochs,
    "lr": lr,
    "restarts": restarts,
    "train_batch_size": dl_masks_train.batch_size,
    "test_batch_size": dl_masks_test.batch_size,
    "class_weights": -1,
    "multi_class": False,
}
write_results_to_file("experiment_results/imdb_bert_params.yaml", std_params, std_method)

### R3 Training

In [None]:
# No manual importlib.reload here: `%autoreload 2` is enabled in the first
# cell and already picks up edits to models.pipeline / models.robust_regularizer.
# A manual reload would also not rebind the functions that were imported with
# `from models.pipeline import ...`.
r3_method = "r3"

# Hyperparameters for this method. `lmbda` is forwarded to
# train_llm_with_guidance and logged under the key "k".
num_epochs = 3
lr = 0.0001
restarts = 1
weight_decay = 0
lmbda = 5e+8
num_frags = 5

# Train `restarts` fresh R3 models. Train/test accuracies are summed over the
# runs (and divided by `restarts` when logged); every run is checkpointed so
# that the group metrics can be computed from the saved weights afterwards.
train_acc, test_acc = 0, 0
for i in range(restarts):
    # Build a brand-new model per restart, seeded from the restart index,
    # instead of re-initialising the weights of the previous one.
    torch.manual_seed(i + seed)
    curr_model = llm.BertModelWrapper(1, device)

    print(f"========== Training model with method {r3_method} restart {i} ==========")
    train_llm_with_guidance(
        curr_model, bert_tokenizer, dl_masks_train, num_epochs, lr, criterion, r3_method, lmbda, device,
        num_fragments=num_frags, weight_decay=weight_decay,
    )

    print("Testing model accuracy for the training set")
    train_acc += test_llm_accuracy(curr_model, bert_tokenizer, dl_masks_train, device)
    print("Testing model accuracy for the test set")
    test_acc += test_llm_accuracy(curr_model, bert_tokenizer, dl_masks_test, device)

    checkpoint_path = os.path.join(save_dir_for_method[r3_method], f"run_{i}.pt")
    torch.save(curr_model.state_dict(), checkpoint_path)

# Group metrics over the checkpoints saved above.
# NOTE(review): `empty_model` is presumably the shell the helper loads each
# checkpoint into — confirm against llm_restart_avg_and_worst_group_acc.
empty_model = llm.BertModelWrapper(1, device)
avg_acc, wg_acc, wg, *_ = llm_restart_avg_and_worst_group_acc(
    dl_masks_test, save_dir_for_method[r3_method], empty_model, bert_tokenizer, device, num_groups=2
)

# Log the accuracy summary for this method.
write_results_to_file("experiment_results/imdb_bert.yaml",
                      {"train_acc": round(train_acc / restarts, 5),
                       "test_acc": round(test_acc / restarts, 5),
                       "avg_group_acc": round(avg_acc, 5),
                       "worst_group_acc": round(wg_acc, 5),
                       "worst_group": wg}, r3_method)

# Log the hyperparameters that produced those results.
write_results_to_file("experiment_results/imdb_bert_params.yaml",
                      {"k": lmbda,
                       "weight_decay": weight_decay,
                       "num_epochs": num_epochs,
                       "lr": lr,
                       "restarts": restarts,
                       "train_batch_size": dl_masks_train.batch_size,
                       "test_batch_size": dl_masks_test.batch_size,
                       "class_weights": -1,
                       "multi_class": False}, r3_method)

### Smooth-R3 Training

In [None]:
# No manual importlib.reload here: `%autoreload 2` is enabled in the first
# cell and already picks up edits to models.pipeline / models.robust_regularizer.
# A manual reload would also not rebind the functions that were imported with
# `from models.pipeline import ...`.
smooth_r3_method = "smooth_r3"

# Hyperparameters for this method. `lmbda` is forwarded to
# train_llm_with_guidance and logged under the key "k".
num_epochs = 2
lr = 0.00005
restarts = 1
weight_decay = 0
lmbda = 1e+9
num_frags = 11
n_samples = 3
alpha = 0.1

# Train `restarts` fresh Smooth-R3 models. Train/test accuracies are summed
# over the runs (and divided by `restarts` when logged); every run is
# checkpointed so that the group metrics can be computed afterwards.
train_acc, test_acc = 0, 0
for i in range(restarts):
    # Build a brand-new model per restart, seeded from the restart index,
    # instead of re-initialising the weights of the previous one.
    torch.manual_seed(i + seed)
    curr_model = llm.BertModelWrapper(1, device)

    print(f"========== Training model with method {smooth_r3_method} restart {i} ==========")
    train_llm_with_guidance(
        curr_model, bert_tokenizer, dl_masks_train, num_epochs, lr, criterion, smooth_r3_method, lmbda, device,
        num_fragments=num_frags, weight_decay=weight_decay, num_samples=n_samples, alpha=alpha,
    )

    print("Testing model accuracy for the training set")
    train_acc += test_llm_accuracy(curr_model, bert_tokenizer, dl_masks_train, device)
    print("Testing model accuracy for the test set")
    test_acc += test_llm_accuracy(curr_model, bert_tokenizer, dl_masks_test, device)

    checkpoint_path = os.path.join(save_dir_for_method[smooth_r3_method], f"run_{i}.pt")
    torch.save(curr_model.state_dict(), checkpoint_path)

# Group metrics over the checkpoints saved above.
# NOTE(review): `empty_model` is presumably the shell the helper loads each
# checkpoint into — confirm against llm_restart_avg_and_worst_group_acc.
empty_model = llm.BertModelWrapper(1, device)
avg_acc, wg_acc, wg, *_ = llm_restart_avg_and_worst_group_acc(
    dl_masks_test, save_dir_for_method[smooth_r3_method], empty_model, bert_tokenizer, device, num_groups=2
)

# Log the accuracy summary for this method.
write_results_to_file("experiment_results/imdb_bert.yaml",
                      {"train_acc": round(train_acc / restarts, 5),
                       "test_acc": round(test_acc / restarts, 5),
                       "avg_group_acc": round(avg_acc, 5),
                       "worst_group_acc": round(wg_acc, 5),
                       "worst_group": wg}, smooth_r3_method)

# Log the hyperparameters that produced those results.
write_results_to_file("experiment_results/imdb_bert_params.yaml",
                      {"k": lmbda,
                       "alpha": alpha,
                       "weight_decay": weight_decay,
                       "num_epochs": num_epochs,
                       "lr": lr,
                       "restarts": restarts,
                       "train_batch_size": dl_masks_train.batch_size,
                       "test_batch_size": dl_masks_test.batch_size,
                       "num_samples": n_samples,
                       "class_weights": -1,
                       "multi_class": False}, smooth_r3_method)

### Rand-R4 Training

In [None]:
# No manual importlib.reload here: `%autoreload 2` is enabled in the first
# cell and already picks up edits to models.pipeline / models.robust_regularizer.
# A manual reload would also not rebind the functions that were imported with
# `from models.pipeline import ...`.
rand_r4_method = "rand_r4"

# Hyperparameters for this method. `lmbda` is forwarded to
# train_llm_with_guidance and logged under the key "k".
num_epochs = 1
lr = 0.00005
restarts = 1
weight_decay = 0
lmbda = 5e+10
num_frags = 11
n_samples = 3
alpha = 0.75

# Train `restarts` fresh Rand-R4 models. Train/test accuracies are summed
# over the runs (and divided by `restarts` when logged); every run is
# checkpointed so that the group metrics can be computed afterwards.
train_acc, test_acc = 0, 0
for i in range(restarts):
    # Build a brand-new model per restart, seeded from the restart index,
    # instead of re-initialising the weights of the previous one.
    torch.manual_seed(i + seed)
    curr_model = llm.BertModelWrapper(1, device)

    print(f"========== Training model with method {rand_r4_method} restart {i} ==========")
    train_llm_with_guidance(
        curr_model, bert_tokenizer, dl_masks_train, num_epochs, lr, criterion, rand_r4_method, lmbda, device,
        num_fragments=num_frags, weight_decay=weight_decay, num_samples=n_samples, alpha=alpha,
    )

    print("Testing model accuracy for the training set")
    train_acc += test_llm_accuracy(curr_model, bert_tokenizer, dl_masks_train, device)
    print("Testing model accuracy for the test set")
    test_acc += test_llm_accuracy(curr_model, bert_tokenizer, dl_masks_test, device)

    checkpoint_path = os.path.join(save_dir_for_method[rand_r4_method], f"run_{i}.pt")
    torch.save(curr_model.state_dict(), checkpoint_path)

# Group metrics over the checkpoints saved above.
# NOTE(review): `empty_model` is presumably the shell the helper loads each
# checkpoint into — confirm against llm_restart_avg_and_worst_group_acc.
empty_model = llm.BertModelWrapper(1, device)
avg_acc, wg_acc, wg, *_ = llm_restart_avg_and_worst_group_acc(
    dl_masks_test, save_dir_for_method[rand_r4_method], empty_model, bert_tokenizer, device, num_groups=2
)

# Log the accuracy summary for this method.
write_results_to_file("experiment_results/imdb_bert.yaml",
                      {"train_acc": round(train_acc / restarts, 5),
                       "test_acc": round(test_acc / restarts, 5),
                       "avg_group_acc": round(avg_acc, 5),
                       "worst_group_acc": round(wg_acc, 5),
                       "worst_group": wg}, rand_r4_method)

# Log the hyperparameters that produced those results.
write_results_to_file("experiment_results/imdb_bert_params.yaml",
                      {"k": lmbda,
                       "alpha": alpha,
                       "weight_decay": weight_decay,
                       "num_epochs": num_epochs,
                       "lr": lr,
                       "restarts": restarts,
                       "train_batch_size": dl_masks_train.batch_size,
                       "test_batch_size": dl_masks_test.batch_size,
                       "num_samples": n_samples,
                       "class_weights": -1,
                       "multi_class": False}, rand_r4_method)

### PGD-R4 Training (GCG)

In [None]:
# No manual importlib.reload here: `%autoreload 2` is enabled in the first
# cell and already picks up edits to models.pipeline / models.robust_regularizer.
# A manual reload would also not rebind the functions that were imported with
# `from models.pipeline import ...`.
pgd_r4_method = "pgd_r4"

# Hyperparameters for this method. `lmbda` is forwarded to
# train_llm_with_guidance and logged under the key "k".
num_epochs = 1
lr = 0.00005
restarts = 1
weight_decay = 0
lmbda = 5e+10
num_frags = 6
alpha = 0.3

# Train `restarts` fresh PGD-R4 models. Train/test accuracies are summed over
# the runs (and divided by `restarts` when logged); every run is checkpointed
# so that the group metrics can be computed from the saved weights afterwards.
train_acc, test_acc = 0, 0
for i in range(restarts):
    # Build a brand-new model per restart, seeded from the restart index,
    # instead of re-initialising the weights of the previous one.
    torch.manual_seed(i + seed)
    curr_model = llm.BertModelWrapper(1, device)

    print(f"========== Training model with method {pgd_r4_method} restart {i} ==========")
    # `alpha` is forwarded so that the value logged in the params file below
    # is the one training actually receives, as in the smooth_r3 / rand_r4
    # cells. NOTE(review): confirm that the pgd_r4 branch of
    # train_llm_with_guidance consumes `alpha`.
    train_llm_with_guidance(
        curr_model, bert_tokenizer, dl_masks_train, num_epochs, lr, criterion, pgd_r4_method,
        lmbda, device, num_fragments=num_frags, weight_decay=weight_decay, alpha=alpha,
    )

    print("Testing model accuracy for the training set")
    train_acc += test_llm_accuracy(curr_model, bert_tokenizer, dl_masks_train, device)
    print("Testing model accuracy for the test set")
    test_acc += test_llm_accuracy(curr_model, bert_tokenizer, dl_masks_test, device)

    checkpoint_path = os.path.join(save_dir_for_method[pgd_r4_method], f"run_{i}.pt")
    torch.save(curr_model.state_dict(), checkpoint_path)

# Group metrics over the checkpoints saved above.
# NOTE(review): `empty_model` is presumably the shell the helper loads each
# checkpoint into — confirm against llm_restart_avg_and_worst_group_acc.
empty_model = llm.BertModelWrapper(1, device)
avg_acc, wg_acc, wg, *_ = llm_restart_avg_and_worst_group_acc(
    dl_masks_test, save_dir_for_method[pgd_r4_method], empty_model, bert_tokenizer, device, num_groups=2
)

# Log the accuracy summary for this method.
write_results_to_file("experiment_results/imdb_bert.yaml",
                      {"train_acc": round(train_acc / restarts, 5),
                       "test_acc": round(test_acc / restarts, 5),
                       "avg_group_acc": round(avg_acc, 5),
                       "worst_group_acc": round(wg_acc, 5),
                       "worst_group": wg}, pgd_r4_method)

# Log the hyperparameters that produced those results.
write_results_to_file("experiment_results/imdb_bert_params.yaml",
                      {"k": lmbda,
                       "alpha": alpha,
                       "weight_decay": weight_decay,
                       "num_epochs": num_epochs,
                       "lr": lr,
                       "restarts": restarts,
                       "train_batch_size": dl_masks_train.batch_size,
                       "test_batch_size": dl_masks_test.batch_size,
                       "class_weights": -1,
                       "multi_class": False}, pgd_r4_method)