In [None]:
# The pairwise datasets in this module return (x1, x2, label), where
# label -> 0 if x1 and x2 are from the same class
# label -> 1 if x1 and x2 are from different classes

# imports from general packages
import os
import torch
import numpy as np
from collections import defaultdict
from torchvision import datasets
from torch.utils.data import Dataset
from torchvision import transforms

# Citation:
# 1. https://discuss.pytorch.org/t/about-normalization-using-pre-trained-vgg16-networks/23560/6
def get_mean_and_std_of_dataset(dataset, batch_size = 128, num_workers = 4):
    """Compute the per-channel mean and population standard deviation of *dataset*.

    Two passes are made over the data: the first accumulates per-image channel
    means, the second accumulates squared deviations from the global mean over
    every pixel.

    Args:
        dataset: indexable dataset whose items are ``(image, label)`` pairs and
            whose images are tensors of shape ``(C, H, W)``, all the same size
            (they must be stackable into batches by the DataLoader).
        batch_size: DataLoader batch size used for both passes.
        num_workers: DataLoader worker processes used for both passes.

    Returns:
        ``(mean, std)``: tensors of shape ``(C,)``.
    """
    single_img, _ = dataset[0]
    assert torch.is_tensor(single_img)
    dim_1, dim_2 = single_img.shape[1], single_img.shape[2]

    loader = torch.utils.data.DataLoader(dataset, batch_size = batch_size, num_workers = num_workers, shuffle = False)
    num_images = len(loader.dataset)

    # Because every image has the same H*W, the average of per-image channel
    # means equals the mean over all pixels.
    mean = 0.0
    for images, _ in loader:
        batch_samples = images.size(0)
        images = images.view(batch_samples, images.size(1), -1)
        mean += images.mean(2).sum(0)
    mean = mean / num_images

    var = 0.0
    for images, _ in loader:
        batch_samples = images.size(0)
        images = images.view(batch_samples, images.size(1), -1)
        var += ((images - mean.unsqueeze(1)) ** 2).sum([0, 2])

    # Population (biased) std over all N * H * W pixels per channel.
    std = torch.sqrt(var / (num_images * dim_1 * dim_2))

    return mean, std

def print_dataset_information(dataset, dataset_name, train, verbose):
    """Print a short summary of *dataset*: name, split flag, size, first image shape, transform.

    Nothing is printed unless *verbose* is truthy. ``dataset[0]`` must unpack
    into ``(image, label)`` with an image exposing ``shape``, and the dataset
    must expose a ``transform`` attribute.
    """
    if not verbose:
        return

    print()
    print("Dataset name:", dataset_name)
    print("Is train: ", train)
    print("Number of elements: ", len(dataset))

    first_image, _label = dataset[0]
    print(first_image.shape)
    print()

    print("Transform: ", dataset.transform)
    print()


def get_custom_data_transform(dataset_name, mean, std):
    """Build the evaluation/training transform: resize to 32x32, ToTensor, Normalize(mean, std).

    Raises:
        ValueError: if *dataset_name* is not ``"ImageFolder"``.
    """
    if dataset_name not in ["ImageFolder"]:
        raise ValueError("Given dataset name is not supported.")

    pipeline = [
        transforms.Resize(size = (32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    return transforms.Compose(pipeline)

def load_dataset_with_custom_data_transform(dataset_name, transform, train = None, path = None):
    """Load *dataset_name* from *path* with *transform* applied to every image.

    Args:
        dataset_name: only ``"ImageFolder"`` is supported.
        transform: transform passed to the torchvision dataset.
        train: accepted so callers can pass the split flag uniformly;
            ImageFolder has no train/test split, so it is ignored.
        path: root directory of the ImageFolder tree (required).

    Returns:
        The torchvision dataset.

    Raises:
        ValueError: unsupported *dataset_name*, or *path* not given.
    """
    if dataset_name != "ImageFolder":
        raise ValueError("Dataset is not supported")

    # Previously ``path`` was an undefined name here (NameError on every call).
    if path is None:
        raise ValueError("A dataset path is required to load an ImageFolder dataset")

    return datasets.ImageFolder(root = path, transform = transform)

def _load_dataset_with_basic_transform(dataset_name, path):
    """Load *dataset_name* with a statistics-only transform (resize + ToTensor, no Normalize).

    NOTE(review): reconstructed because ``load_dataset`` called an undefined
    ``load_dataset_with_basic_transform``; confirm against the intended helper.
    """
    if dataset_name != "ImageFolder":
        raise ValueError("Dataset is not supported")

    # Use the same 32x32 resize as get_custom_data_transform so the statistics
    # describe the pixels that will later be normalized. A fixed size is also
    # needed so the DataLoader in get_mean_and_std_of_dataset can batch images.
    basic_transform = transforms.Compose([
        transforms.Resize(size = (32, 32)),
        transforms.ToTensor(),
    ])
    return datasets.ImageFolder(root = path, transform = basic_transform)


def load_dataset(dataset_name, dataset_path, train, id_dataset_name, id_dataset_path, augment, return_mean = False):
    """Load a dataset normalized with statistics of an in-distribution training set.

    The per-channel mean/std are computed on ``id_dataset_name`` at
    ``id_dataset_path`` and then used to normalize ``dataset_name`` at
    ``dataset_path``.

    Args:
        dataset_name, dataset_path: the dataset to return.
        train: split flag forwarded to the loader.
        id_dataset_name, id_dataset_path: dataset the normalization stats come from.
        augment: currently ignored; no augmentation transform is applied.
        return_mean: also return the computed mean when True.

    Returns:
        ``dataset``, or ``(dataset, mean)`` when *return_mean* is True.
    """
    id_train_dataset = _load_dataset_with_basic_transform(dataset_name = id_dataset_name, path = id_dataset_path)
    mean, std = get_mean_and_std_of_dataset(dataset = id_train_dataset)

    print()
    print("Mean: ", mean)
    print("Std: ", std)
    print()

    custom_transform = get_custom_data_transform(dataset_name = dataset_name, mean = mean, std = std)
    dataset = load_dataset_with_custom_data_transform(dataset_name = dataset_name, train = train, path = dataset_path, transform = custom_transform)

    print()
    print("###########################")
    print("printing information on the loaded dataset!")
    print("###########################")
    print()

    print_dataset_information(dataset = dataset, dataset_name = dataset_name, train = train, verbose = True)

    if return_mean:
        return dataset, mean
    return dataset

def create_label_to_index_mapping(dataset):
    """Group dataset positions by label.

    Each ``dataset[i]`` must unpack into ``(sample, label)``; tensor labels are
    converted to Python scalars with ``.item()``.

    Returns:
        ``defaultdict(list)`` mapping label -> list of indices, in index order.
    """
    label_to_indices = defaultdict(list)
    for position in range(len(dataset)):
        _, target = dataset[position]
        target = target.item() if torch.is_tensor(target) else target
        label_to_indices[target].append(position)
    return label_to_indices

def generate_per_class_sample(dataset_name, dataset_path, save_path, train, sample_size):
    """Draw *sample_size* distinct indices per class and save them to *save_path*.

    Each output row is ``label,idx_1,...,idx_k`` with indices sorted ascending;
    rows are ordered by label.

    Raises:
        ValueError (from numpy): if a class has fewer than *sample_size* items.
    """
    # load_dataset has no ``path``/``download`` parameters (the old call raised
    # TypeError). NOTE(review): normalization statistics are taken from the same
    # dataset/path; confirm that is the intended in-distribution source.
    dataset = load_dataset(dataset_name = dataset_name,
                           dataset_path = dataset_path,
                           train = train,
                           id_dataset_name = dataset_name,
                           id_dataset_path = dataset_path,
                           augment = False)
    mapping = create_label_to_index_mapping(dataset)
    labels = sorted(mapping.keys())
    print("Labels: ", labels)

    per_class_samples = []

    for label in labels:
        indices = np.random.choice(a = mapping[label], size = sample_size, replace = False)
        indices = np.sort(a = indices)
        per_class_samples.append([label] + indices.tolist())

    np.savetxt(save_path, np.array(per_class_samples, dtype = np.int32), delimiter=",", fmt="%d")

def load_per_class_samples(save_path):
    """Read a per-class sample file (as written by generate_per_class_sample).

    Each non-blank line is ``label,idx_1,...,idx_k``.

    Returns:
        dict mapping int label -> list of int indices.

    Raises:
        AssertionError: if *save_path* is not an existing file.
    """
    assert os.path.isfile(save_path)
    class_to_sample_indices_mapping = {}

    with open(save_path, 'r') as file:
        for line in file:
            # Skip blank lines (e.g. a trailing empty line after hand edits),
            # which would otherwise fail in int("").
            if not line.strip():
                continue
            tokens = line.split(",")
            class_to_sample_indices_mapping[int(tokens[0])] = [int(token) for token in tokens[1:]]

    return class_to_sample_indices_mapping

# prepare validation dataset
def generate_validation_dataset(dataset_name, dataset_path, save_path, half_validation_dataset_size):
    """Create a fixed, balanced set of validation pairs and save it to *save_path*.

    Writes ``2 * half_validation_dataset_size`` rows ``index_1,index_2,label``.
    Even rows are same-class pairs (label 0, two distinct indices). Odd rows are
    different-class pairs (label 1). Indices refer to the ``train=False`` split.

    Raises:
        ValueError (from numpy): if a chosen class has fewer than two items, or
            there are fewer than two classes.
    """
    # load_dataset has no ``path``/``download`` parameters (the old call raised
    # TypeError). NOTE(review): normalization statistics are taken from the same
    # dataset/path; confirm that is the intended in-distribution source.
    dataset = load_dataset(dataset_name = dataset_name,
                           dataset_path = dataset_path,
                           train = False,
                           id_dataset_name = dataset_name,
                           id_dataset_path = dataset_path,
                           augment = False)
    mapping = create_label_to_index_mapping(dataset)
    labels = sorted(mapping.keys())
    print("Labels: ", labels)

    validation_dataset = []

    for i in range(half_validation_dataset_size * 2):
        # same class -> label 0
        if i % 2 == 0:
            random_label = np.random.choice(labels)
            indices = np.random.choice(a = mapping[random_label], size = 2, replace = False)
            data_point = [indices[0], indices[1], 0]

        # different class -> label 1
        else:
            random_labels = np.random.choice(a = labels, size = 2, replace = False)
            index_1 = np.random.choice(a = mapping[random_labels[0]])
            index_2 = np.random.choice(a = mapping[random_labels[1]])
            data_point = [index_1, index_2, 1]

        validation_dataset.append(data_point)

    np.savetxt(save_path, np.array(validation_dataset, dtype= np.int32), delimiter=",", fmt = "%d")

def load_validation_dataset(save_path):
    """Read saved validation pairs (as written by generate_validation_dataset).

    Each non-blank line must be ``index_1,index_2,label``.

    Returns:
        list of ``[index_1, index_2, label]`` int lists, in file order.

    Raises:
        AssertionError: if *save_path* is not a file, or a line lacks exactly 3 fields.
    """
    assert os.path.isfile(save_path)

    validation_dataset = []
    with open(save_path, 'r') as file:
        for line in file:
            # Skip blank lines, which would otherwise trip the 3-field assertion.
            if not line.strip():
                continue
            tokens = line.split(",")
            assert len(tokens) == 3
            validation_dataset.append([int(token) for token in tokens])

    return validation_dataset

def is_valid_class_to_sample_mapping(mapping):
    """Return True when every value of *mapping* has the same length (vacuously True if empty)."""
    distinct_lengths = {len(mapping[key]) for key in mapping}
    return len(distinct_lengths) <= 1

class PreSavedPerClassSampledDataset(Dataset):
    """View of *dataset* restricted to pre-saved per-class sample indices.

    The index file (see load_per_class_samples) must list the same number of
    samples for every class. Items are ordered by ascending class label, then
    by the order the indices appear in the file.
    """

    def __init__(self, dataset, sample_indices_path):
        self.dataset = dataset

        mapping = load_per_class_samples(save_path = sample_indices_path)
        self.class_to_sample_indices_mapping = mapping
        assert is_valid_class_to_sample_mapping(mapping)

        self.class_labels = sorted(mapping)
        assert len(self.class_labels) > 0
        self.sample_size_per_class = len(mapping[self.class_labels[0]])

        # Flatten class-by-class in label order.
        self.img_indices = [img_index
                            for class_label in self.class_labels
                            for img_index in mapping[class_label]]
        self.len = len(self.img_indices)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        return self.dataset[self.img_indices[index]]

    def get_class_labels(self):
        return self.class_labels

    def get_sample_size_per_class(self):
        return self.sample_size_per_class

    def get_sample_from_class(self, class_label, sample_index):
        """Return the *sample_index*-th saved datapoint of *class_label*."""
        assert 0 <= sample_index < self.sample_size_per_class
        img_index = self.class_to_sample_indices_mapping[class_label][sample_index]
        return self.dataset[img_index]

# pairwise dataset, with random pairings made over the given dataset
# label = 0, both data points from the pair are from the same class
# label = 1, datapoints from the pair are from different classes
class PairwiseDatasetRandom(Dataset):
    """Randomly paired samples of *dataset* for a fixed-length "epoch".

    Even indices yield a same-class pair (label 0, two distinct items); odd
    indices yield a pair from two different classes (label 1). Pairs are
    redrawn on every access via ``np.random``.
    """

    def __init__(self, dataset, epoch_length = 10000):
        self.dataset = dataset
        self.mapping = create_label_to_index_mapping(self.dataset)
        self.labels = list(self.mapping.keys())
        self.len = epoch_length

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        same_class = index % 2 == 0

        if same_class:
            chosen_label = np.random.choice(self.labels)
            first, second = np.random.choice(a = self.mapping[chosen_label], size = 2, replace = False)
        else:
            label_a, label_b = np.random.choice(a = self.labels, size = 2, replace = False)
            first = np.random.choice(a = self.mapping[label_a])
            second = np.random.choice(a = self.mapping[label_b])

        img1, _ = self.dataset[first]
        img2, _ = self.dataset[second]

        return img1, img2, 0 if same_class else 1

# pairwise validation dataset
class PairwiseDatasetPreSaved(Dataset):
    """Pairwise dataset whose ``(index_1, index_2, label)`` triples come from a saved file.

    Intended for validation, so every evaluation sees the same pairs.
    """

    def __init__(self, dataset, combination_path):
        self.dataset = dataset
        self.combinations = load_validation_dataset(save_path = combination_path)
        self.len = len(self.combinations)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        triple = self.combinations[index]
        left, right, pair_label = triple[0], triple[1], triple[2]

        left_img, _ = self.dataset[left]
        right_img, _ = self.dataset[right]

        return left_img, right_img, pair_label


# dataset for classification task (not pairwise task)
# that contains selected classes
class SelectiveClassDataset(torch.utils.data.Dataset):
    """Classification view of *dataset* restricted to the classes in a labels file.

    *labels_path* lists one integer class label per line (blank lines are
    ignored). Only datapoints whose label is listed are kept, and labels are
    remapped to their 0-based position in the file, so a model can be trained
    with ``num_classes == number of listed labels``.
    """

    def __init__(self, dataset, labels_path):
        self.dataset = dataset

        self.partial_labels = []
        with open(labels_path, 'r') as f:
            for line in f:
                if line.strip():
                    self.partial_labels.append(int(line))

        # original label -> new contiguous label (a later duplicate wins).
        self.remapped_labels = {label: i for i, label in enumerate(self.partial_labels)}

        self.subset_indices = []
        for index in range(len(self.dataset)):
            _, label = self.dataset[index]
            # Normalize tensor labels the same way __getitem__ does, and use
            # the dict for O(1) membership instead of scanning the list.
            if torch.is_tensor(label):
                label = label.item()
            if label in self.remapped_labels:
                self.subset_indices.append(index)

        self.len = len(self.subset_indices)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        img, label = self.dataset[self.subset_indices[index]]
        if torch.is_tensor(label):
            label = label.item()

        new_label = self.remapped_labels[label]
        return img, new_label

# dataset that contains randomly sampled datapoints of the original dataset
class RandomlySampledDataset(torch.utils.data.Dataset):
    """Per-class subsample of *dataset*.

    For each class (in ascending label order) keep either *fixed_sample_size*
    items or ``int(class_size * base_rate)`` items. They are drawn without
    replacement when *choose_randomly* is True; otherwise the first ones in
    index order are kept.
    """

    def __init__(self, dataset, base_rate, choose_randomly = True, fixed_sample_size = None):
        self.dataset = dataset

        assert 0.0 < base_rate <= 1.0
        self.base_rate = base_rate

        mapping = create_label_to_index_mapping(dataset = self.dataset)
        sorted_labels = sorted(int(key) for key in mapping.keys())

        print("Labels: ", sorted_labels)

        collected = []
        for label in sorted_labels:
            class_indices = mapping[label]

            if fixed_sample_size is None:
                count = int(len(class_indices) * self.base_rate)
            else:
                count = fixed_sample_size

            if not choose_randomly:
                chosen = class_indices[0 : count]
            else:
                chosen = np.random.choice(a = class_indices, size = count, replace = False).tolist()

            collected.extend(chosen)

        self.sample_indices = collected
        self.len = len(collected)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        return self.dataset[self.sample_indices[index]]

def check_dataset(dataset):
    """Print every class label (ascending) with the number of datapoints it has."""
    label_to_indices = create_label_to_index_mapping(dataset = dataset)
    ordered_labels = sorted(label_to_indices)

    print()
    print("Printing classes and number of element in each class.")
    print()

    for label in ordered_labels:
        print("Class: ", label, " Num datapoints: ", len(label_to_indices[label]))

    print()


In [None]:
# Citation:
# 1. https://thenewstack.io/tutorial-train-a-deep-learning-model-in-pytorch-and-export-it-to-onnx/
# 2. https://discuss.pytorch.org/t/creating-custom-dataset-from-inbuilt-pytorch-datasets-along-with-data-transformations/58270/2
# 3. https://medium.com/@sergioalves94/deep-learning-in-pytorch-with-cifar-10-dataset-858b504a6b54
# 4. https://github.com/fangpin/siamese-pytorch/blob/master/train.py
# 5. https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html

import torch
import time
import numpy as np
import os
import shutil
import copy

def print_message(message, verbose = True):
    """Print *message* between two blank lines, but only when *verbose* is truthy."""
    if not verbose:
        return
    print("", message, "", sep = "\n")

# should_use_scheduler -> True, use scheduler at every batch (instead of every epoch) like Outlier-exposure paper
# should_use_scheduler -> False, use scheduler at every epoch
def train_model(model, train_loader, optimizer, scheduler, should_use_scheduler):
    """Run one training epoch of a pairwise model with mean BCE-with-logits loss.

    Args:
        model: module called as ``model.forward(img1, img2)``; must produce one
            logit per pair.
        train_loader: iterable of ``(img1, img2, label)`` batches, label 0/1.
        optimizer: optimizer over the model's parameters.
        scheduler: optional LR scheduler (may be None).
        should_use_scheduler: when True (and a scheduler is given) the scheduler
            is stepped after every batch; otherwise the caller steps it per epoch.
    """
    loss_fn = torch.nn.BCEWithLogitsLoss(reduction = 'mean')
    model.train()
    use_cuda = torch.cuda.is_available()

    for img1, img2, label in train_loader:
        if use_cuda:
            img1 = img1.cuda()
            img2 = img2.cuda()
            label = label.cuda()

        model.zero_grad()
        # Flatten to (batch,) instead of squeeze(): squeeze() turns a batch of
        # size 1 into a 0-d tensor whose shape no longer matches the (1,)
        # target, and BCEWithLogitsLoss raises on the size mismatch.
        output = model.forward(img1, img2).reshape(-1)
        loss = loss_fn(output, label.float().reshape(-1))

        loss.backward()
        optimizer.step()

        if should_use_scheduler and scheduler is not None:
            scheduler.step()

def test_model(model, test_loader):
    model.eval()
    loss_fn = torch.nn.BCEWithLogitsLoss(reduction = 'sum')

    total_loss = 0
    num_correct = 0
    num_total = 0

    with torch.no_grad():
        for idx, (img1, img2, label) in enumerate(test_loader):
            if torch.cuda.is_available():
                img1 = img1.cuda()
                img2 = img2.cuda()
                label = label.cuda()

            output = model.forward(img1, img2).squeeze()
            loss = loss_fn(output, label.float())
            total_loss += loss.item()

            pred = (output > 0).long()
            num_correct += (pred.squeeze() == label.squeeze()).float().sum().item()
            num_total += pred.shape[0]

    accuracy = float(num_correct) / num_total
    average_loss = float(total_loss) / num_total
    return average_loss, accuracy

class PairwiseModelTrainer:
    """Drive epoch-level training of a pairwise model and record its history.

    Per-batch training and evaluation are delegated to ``train_model`` and
    ``test_model``. This class runs the epoch loop, steps the scheduler (per
    epoch or per batch), tracks loss/accuracy/time, and can write logs to disk.
    """

    def __init__(self, model, model_name, train_loader, validation_loader, optimizer, scheduler = None):
        self.model = model
        self.model_name = model_name

        self.train_loader = train_loader
        self.validation_loader = validation_loader

        self.optimizer = optimizer
        # Optional LR scheduler; None disables scheduling.
        self.scheduler = scheduler

        # Best validation accuracy so far and the 1-based epoch it occurred at.
        self.max_val_accuracy = None
        self.best_epoch = None

        self.train_loss_history = []
        self.train_accuracy_history = []
        self.val_loss_history = []
        self.val_accuracy_history = []
        # Wall-clock seconds per epoch (training pass plus both evaluations).
        self.time_history = []

    def _current_learning_rate(self):
        """Return the LR to report: the scheduler's last LR, else the optimizer's first param group LR."""
        if self.scheduler is not None:
            return self.scheduler.get_last_lr()[0]
        return self.optimizer.param_groups[0]["lr"]

    def run_training(self, num_epochs, model_path = None, verbose = True, should_use_scheduler = False):
        """Train for *num_epochs* epochs, evaluating on the train and validation loaders after each.

        Args:
            num_epochs: number of epochs.
            model_path: if given, the model's state_dict after the LAST epoch is saved here.
            verbose: print per-epoch metrics and the total training time.
            should_use_scheduler: True -> scheduler stepped every batch inside
                train_model (Outlier-Exposure style); False -> stepped once per epoch here.
        """
        print_message(message = "Model training is starting...", verbose = verbose)

        best_model_params = None
        for epoch in range(1, num_epochs + 1):
            start_time = time.time()

            # run the training step
            train_model(model = self.model, train_loader = self.train_loader, optimizer = self.optimizer, scheduler = self.scheduler, should_use_scheduler = should_use_scheduler)

            # calculate the training and validation loss and accuracy
            train_loss, train_accuracy = test_model(model = self.model, test_loader = self.train_loader)
            val_loss, val_accuracy = test_model(model = self.model, test_loader = self.validation_loader)

            end_time = time.time()

            if not should_use_scheduler and self.scheduler is not None:
                self.scheduler.step()

            self.train_loss_history.append(train_loss)
            self.train_accuracy_history.append(train_accuracy)
            self.val_loss_history.append(val_loss)
            self.val_accuracy_history.append(val_accuracy)

            self.time_history.append(end_time - start_time)

            if verbose:
                print("Epoch: ", epoch, "/", num_epochs, " Train Accuracy: ", train_accuracy, " Val Accuracy: ", val_accuracy)
                print("            Train loss: ", train_loss, " Val loss: ", val_loss)
                # Must not assume a scheduler: it defaults to None.
                print("Learning rate: ", self._current_learning_rate())
                print()

            if epoch == 1 or val_accuracy > self.max_val_accuracy:
                self.best_epoch = epoch
                self.max_val_accuracy = val_accuracy
                best_model_params = copy.deepcopy(self.model.state_dict())

        if model_path is not None:
            # NOTE(review): the final-epoch weights are saved; the best-epoch
            # snapshot (best_model_params) is kept but unused. Swap the lines
            # below deliberately if best-epoch weights are wanted.
            # torch.save(best_model_params, model_path)
            torch.save(self.model.state_dict(), model_path)

        if verbose:
            total_time = sum(self.time_history)

            print()
            print("Total training time: ", total_time, " seconds")
            print()

    def report_peak_performance(self):
        """Print the epoch and value of the best validation accuracy, or a not-trained notice."""
        if self.max_val_accuracy is None:
            print("Model has not been trained yet.")
        else:
            print()
            print("Model peaked in validation accuracy at epoch ", self.best_epoch)
            print("Model peak validation accuracy: ", self.max_val_accuracy)
            print()

    def save_log(self, log_dir):
        """Write per-epoch histories and timing into *log_dir* (which must already exist).

        Writes ``train_loss``, ``val_loss``, ``acc`` and ``val_acc`` files named
        ``<name>_per_epoch.txt``, one value per line. It also writes
        ``training_time.txt``, whose first value is the total training time in
        HOURS, followed by each epoch's time in SECONDS.
        """
        print_message("Log directory: " + log_dir, verbose = 1)

        filename_suffix = "_per_epoch.txt"
        records = [
            ("train_loss", self.train_loss_history),
            ("val_loss", self.val_loss_history),
            ("acc", self.train_accuracy_history),
            ("val_acc", self.val_accuracy_history),
        ]

        for filename, history in records:
            filepath = os.path.join(log_dir, filename + filename_suffix)
            np.savetxt(filepath, np.array(history), delimiter=",", fmt = "%1.4e")

        total_time_hours = np.sum(np.array(self.time_history)) / 3600.0
        appended_log = np.array([total_time_hours] + self.time_history)
        timer_path = os.path.join(log_dir, "training_time.txt")
        np.savetxt(timer_path, appended_log, delimiter=",", fmt = "%1.4e")


In [None]:
from __future__ import print_function
import torch
import numpy as np
import argparse
import os
import sys

# import from our scripts
from utils.pytorch_pairwise_dataset import load_dataset
from utils.pytorch_pairwise_dataset import RandomlySampledDataset
from utils.pytorch_pairwise_dataset import SelectiveClassDataset
from utils.pytorch_pairwise_dataset import check_dataset

from utils.wide_resnet_pytorch import create_wide_resnet
from utils.resnet_pytorch import ResNet18, ResNet34, ResNet50

from utils.pytorch_classifier_trainer import ClassifierTrainer

from utils.siamese_network import process_shared_model_name_wide_resnet

from utils.plotting_log_utils import plot_loss
from utils.plotting_log_utils import plot_accuracy

from utils.get_readable_timestamp import get_readable_timestamp

def parse_arguments():
    """Parse the classifier-training command line and return it as a dict.

    Every flag accepts both a single-dash and a double-dash spelling.

    Returns:
        dict mapping each flag name (without dashes) to its parsed value.
    """
    def str_to_bool(value):
        # argparse's ``type=bool`` would call bool("False") -> True, so boolean
        # flags must be parsed from their text explicitly.
        if isinstance(value, bool):
            return value
        normalized = value.strip().lower()
        if normalized in ("1", "true", "t", "yes", "y"):
            return True
        if normalized in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError("Expected a boolean value, got: " + repr(value))

    ap = argparse.ArgumentParser()

    # dataset and model name and path arguments
    ap.add_argument("-dataset_name", "--dataset_name", type = str, default = "CIFAR10")
    ap.add_argument("-dataset_path", "--dataset_path", type = str)
    ap.add_argument("-num_classes", "--num_classes", type = int, default = 10)
    ap.add_argument("-model_type", "--model_type", type = str, default = "ResNet34")

    # training arguments
    ap.add_argument("-train_batch_size", "--train_batch_size", type = int, default = 32)
    ap.add_argument("-val_batch_size", "--val_batch_size", type = int, default = 1024)
    ap.add_argument("-num_workers", "--num_workers", type = int, default = 4)

    ap.add_argument("-lr", "--lr", type = float, default = 0.1)
    ap.add_argument("-momentum", "--momentum", type = float, default = 0.9)
    ap.add_argument("-weight_decay", "--weight_decay", type = float, default = 0.0005)

    ap.add_argument("-num_epochs", "--num_epochs", type = int, default = 200)
    ap.add_argument("-use_nesterov", "--use_nesterov", type = int, default = 1)
    ap.add_argument("-verbose", "--verbose", type = str_to_bool, default = True)
    ap.add_argument("-use_default_scheduler", "--use_default_scheduler", type = int, default = 0, choices = [0, 1])
    ap.add_argument("-should_plot", "--should_plot", type = int, default = 0)

    # saving directory arguments
    ap.add_argument("-model_name", "--model_name", type = str, default = "MSP model")
    # Required: the training run joins this path with a timestamp, which
    # fails with TypeError when it is None.
    ap.add_argument("-model_save_path", "--model_save_path", type = str, required = True)
    ap.add_argument("-plot_directory", "--plot_directory", type = str, default = "./")
    ap.add_argument("-log_directory", "--log_directory", type = str, default = "./")

    # training on partial datasets
    ap.add_argument("-base_rate", "--base_rate", type = float, default = 1.0)
    ap.add_argument("-use_partial_dataset", "--use_partial_dataset", type = int, default = 0, choices = [0, 1])
    ap.add_argument("-partial_dataset_path_prefix", "--partial_dataset_path_prefix", type = str, default = "./")
    ap.add_argument("-partial_dataset_filename", "--partial_dataset_filename", type = str, default = "dataset_1/partial_dataset_labels.txt")

    script_arguments = vars(ap.parse_args())
    return script_arguments

def print_arguments(args):
    """Echo each parsed argument as a ``Key: ... Value: ...`` line, framed by blank lines."""
    print()
    print("Arguments given for the script...")
    for name, value in args.items():
        print("Key: ", name, " Value: ", value)
    print()

def create_dataloaders(args):
    """Build the train and validation DataLoaders described by *args*.

    Both splits are normalized with statistics of the same dataset/path. The
    training split is optionally subsampled per class (``base_rate`` < 1,
    first items kept). Both splits are optionally restricted to the classes in
    the partial-dataset labels file. Per-class counts are printed for both.

    Returns:
        ``(train_dataloader, validation_dataloader)``; only the train loader shuffles.
    """
    source = {"dataset_name": args["dataset_name"],
              "dataset_path": args["dataset_path"],
              "id_dataset_name": args["dataset_name"],
              "id_dataset_path": args["dataset_path"]}

    train_dataset = load_dataset(train = True, augment = True, **source)

    if args["base_rate"] < 1.0:
        train_dataset = RandomlySampledDataset(dataset = train_dataset, base_rate = args["base_rate"], choose_randomly = False)

    validation_dataset = load_dataset(train = False, augment = False, **source)

    if args["use_partial_dataset"] == 1:
        labels_path = args["partial_dataset_path_prefix"] + args["partial_dataset_filename"]
        train_dataset, validation_dataset = (
            SelectiveClassDataset(dataset = train_dataset, labels_path = labels_path),
            SelectiveClassDataset(dataset = validation_dataset, labels_path = labels_path),
        )

    for split in (train_dataset, validation_dataset):
        check_dataset(dataset = split)

    train_dataloader = torch.utils.data.DataLoader(dataset = train_dataset,
                                                   batch_size = args["train_batch_size"],
                                                   shuffle = True,
                                                   num_workers = args["num_workers"])
    validation_dataloader = torch.utils.data.DataLoader(dataset = validation_dataset,
                                                        batch_size = args["val_batch_size"],
                                                        shuffle = False,
                                                        num_workers = args["num_workers"])

    return train_dataloader, validation_dataloader

def create_resnet_given_params(model_name, num_classes, num_input_channels):
    """Build the ResNet variant named *model_name*, with its final classification layer.

    Raises:
        ValueError: if *model_name* is not ResNet18, ResNet34 or ResNet50.
    """
    # Lambdas defer the name lookups so only the requested constructor is touched.
    candidates = (
        ("ResNet18", lambda: ResNet18),
        ("ResNet34", lambda: ResNet34),
        ("ResNet50", lambda: ResNet50),
    )
    for candidate_name, get_constructor in candidates:
        if model_name == candidate_name:
            constructor = get_constructor()
            return constructor(contains_last_layer = True, num_input_channels = num_input_channels, num_classes = num_classes)

    raise ValueError('Model name not supported')

def create_wide_resnet_model(args):
    """Build a WideResNet whose depth/width/dropout are decoded from ``args["model_type"]``.

    RGB datasets (CIFAR variants, SVHN) get 3 input channels and MNIST gets 1.

    Raises:
        ValueError: for any other ``args["dataset_name"]``.
    """
    architecture_map = process_shared_model_name_wide_resnet(shared_model_name = args["model_type"])
    dataset_name = args["dataset_name"]

    kwargs = {"depth": architecture_map["depth"],
              "widen_factor": architecture_map["widen_factor"],
              "dropRate": architecture_map["dropRate"],
              "num_classes": args["num_classes"],
              "contains_last_layer": True}

    if dataset_name in ("CIFAR10", "CIFAR100", "CIFAR", "SVHN", "CIFAR100Coarse"):
        channels = 3
    elif dataset_name == "MNIST":
        channels = 1
    else:
        raise ValueError("Dataset name not supported")

    kwargs["num_input_channels"] = channels
    return create_wide_resnet(**kwargs)

def create_resnet_model(args):
    """Build the ResNet named by ``args["model_type"]`` for ``args["dataset_name"]``.

    RGB datasets (CIFAR variants, SVHN) get 3 input channels and MNIST gets 1.

    Raises:
        ValueError: for any other dataset name.
    """
    dataset_name = args["dataset_name"]
    kwargs = {"model_name": args["model_type"],
              "num_classes": args["num_classes"]}

    if dataset_name in ("CIFAR10", "CIFAR100", "CIFAR", "SVHN", "CIFAR100Coarse"):
        kwargs["num_input_channels"] = 3
    elif dataset_name == "MNIST":
        kwargs["num_input_channels"] = 1
    else:
        raise ValueError("Dataset name not supported")

    return create_resnet_given_params(**kwargs)

def create_model(args):
    """Create the model named by ``args["model_type"]`` and move it to the GPU when one is available.

    Names starting with "WideResNet" build a WideResNet. Any other name
    containing "ResNet" builds a plain ResNet.

    Raises:
        ValueError: for any other model type.
    """
    model_type = args["model_type"]

    if model_type.startswith("WideResNet"):
        model = create_wide_resnet_model(args)
    elif "ResNet" in model_type:
        model = create_resnet_model(args)
    else:
        raise ValueError("Given model type is not supported")

    if torch.cuda.is_available():
        model.cuda()

    return model

def create_optimizer_and_scheduler(args, model, len_train_dataloader):
    """Create an SGD optimizer and its cosine learning-rate scheduler.

    With ``use_default_scheduler == 1`` the scheduler is CosineAnnealingLR over
    ``num_epochs``, meant to be stepped per epoch. Otherwise it is the
    Outlier-Exposure LambdaLR, meant to be stepped per batch: it decays from
    ``lr`` to ``1e-6`` along a cosine over ``num_epochs * len_train_dataloader``
    steps.

    Raises:
        ValueError: if ``args["use_nesterov"]`` is neither 0 nor 1.
    """
    if args["use_nesterov"] not in (0, 1):
        raise ValueError("argument for using nesterov momentum, is not supported")
    use_nesterov = True if args["use_nesterov"] == 1 else False

    optimizer = torch.optim.SGD(model.parameters(),
                                lr = args["lr"],
                                momentum = args["momentum"],
                                nesterov = use_nesterov,
                                weight_decay = args["weight_decay"])

    if args["use_default_scheduler"] == 1:
        return optimizer, torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = args["num_epochs"])

    def lr_factor(step):
        # LambdaLR expects a multiplicative factor of the base LR: it starts
        # at 1 and bottoms out at 1e-6 / lr (i.e. an absolute LR of 1e-6).
        total_steps = args["num_epochs"] * len_train_dataloader
        factor_max = 1
        factor_min = 1e-6 / args["lr"]
        return factor_min + (factor_max - factor_min) * 0.5 * (1 + np.cos(step / total_steps * np.pi))

    # scheduler from outlier exposure
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda = lr_factor)
    return optimizer, scheduler

def run_model_training(args):
    """Run one full classifier-training job described by the parsed CLI *args*.

    Steps: build the model, dataloaders, optimizer and scheduler; train with
    ClassifierTrainer; report peak validation performance; optionally plot
    loss/accuracy curves; save the training log. Model, plot and log outputs
    go under a shared ``<timestamp>_<random int>`` name so one run's artifacts
    can be matched up.
    """
    model = create_model(args = args)
    if args["verbose"]:
        print("Model created.")
        print(model)

    train_dataloader, val_dataloader = create_dataloaders(args = args)
    if args["verbose"]:
        print()
        print("Train dataloader and validation dataloader created")
        print()

    # len(train_dataloader) is the number of batches per epoch; the non-default
    # scheduler uses it to size its per-batch cosine schedule.
    optimizer, scheduler = create_optimizer_and_scheduler(args = args, model = model, len_train_dataloader = len(train_dataloader))
    if args["verbose"]:
        print("Optimizer and scheduler created")
        print("Optimizer: ", optimizer)
        print("scheduler: ", scheduler)
        print("")

    # Random suffix reduces collisions between runs started in the same second.
    readable_timestamp = get_readable_timestamp() + "_" + str(np.random.randint(1000000))
    model_name = args["model_name"]
    # NOTE(review): args["model_save_path"] has no default; if it is None,
    # os.path.join raises TypeError here, after model and data setup.
    model_path = os.path.join(args["model_save_path"], readable_timestamp)
    if not os.path.isdir(args["model_save_path"]):
        os.makedirs(args["model_save_path"])

    if args["verbose"]:
        print("Model name: ", model_name)
        print("model path: ", model_path)
        print("")

    model_trainer = ClassifierTrainer(model = model, model_name = model_name,
                                      train_loader = train_dataloader, validation_loader = val_dataloader,
                                      optimizer = optimizer, scheduler = scheduler)

    # should_use_scheduler -> True, use scheduler at every batch (instead of every epoch) like Outlier-exposure paper
    # should_use_scheduler -> False, use scheduler at every epoch
    # The default CosineAnnealingLR is sized in epochs; the LambdaLR alternative is sized in batches.
    if args["use_default_scheduler"] == 1:
        should_use_scheduler = False
    else:
        should_use_scheduler = True

    model_trainer.run_training(num_epochs = args["num_epochs"],
                               model_path = model_path,
                               verbose = args["verbose"],
                               should_use_scheduler = should_use_scheduler)

    model_trainer.report_peak_performance()

    # NOTE(review): the plot directory is created even when should_plot is 0.
    plot_directory = os.path.join(args["plot_directory"], readable_timestamp)
    if not os.path.isdir(plot_directory):
        os.makedirs(plot_directory)

    if args["should_plot"] > 0:
        if args["verbose"]:
            print("Plot directory: ", plot_directory)
            print("Plotting loss and accuracy...")
            print()

        plot_loss(model_name, model_trainer.train_loss_history, model_trainer.val_loss_history, plot_directory)
        plot_accuracy(model_name, model_trainer.train_accuracy_history, model_trainer.val_accuracy_history, plot_directory)

    log_directory = os.path.join(args["log_directory"], readable_timestamp)
    if not os.path.isdir(log_directory):
        os.makedirs(log_directory)
    if args["verbose"]:
        print("Log directory: ", log_directory)
        print("Saving training log...")
        print()

    model_trainer.save_log(log_directory)

if __name__ == '__main__':
    # Script entry point: parse CLI flags, echo them, then launch training.
    script_arguments = parse_arguments()
    print_arguments(script_arguments)
    run_model_training(script_arguments)
