# TODOs:

- final train and val loss (best)

# Google Colab initialization

On Google Colab, uncomment the lines below and run them to mount your Google Drive, or try the second approach (cloning the repository; not tested).

In [None]:
#USE_COLAB = True
#
#if USE_COLAB:
#    from google.colab import drive
#
#    drive.mount('/content/drive')
#    import sys
#
#    sys.path.insert(1, r'/content/drive/My Drive/')

In [None]:
# Set to True when running on Google Colab; the cells guarded by `if USE_COLAB:` then copy
# the generated graphs/ and tmp/ folders to Google Drive.
USE_COLAB = False

In [None]:
# Other try
# !git clone https://github.com/Alexanderstaehle/OM_project

In [None]:
# import sys
# sys.path.append("OM_project")

# Imports and Setup

In [None]:
import numpy as np
import seaborn as sns
import tensorflow as tf
from tensorflow import keras

from utils import ml_utils, visualization, data_loading, tf_models

In [None]:
# True: train every configuration from scratch and save weights, training states and sharpnesses.
# False: rebuild the models and load the artifacts saved by a previous run instead.
RETRAIN_FLAG = True
#RETRAIN_FLAG = False

In [None]:
# Named functions instead of assigned lambdas (PEP 8 E731): same call signature, but a real
# __name__ in tracebacks and room for a docstring.
def filename_bs(bs):
    """Return the ``format_="tf"`` path for ``'model_fixed_lr_diff_bs_{bs}'``.

    The path is produced by ``ml_utils.path_from_filename``.
    """
    return ml_utils.path_from_filename(f'model_fixed_lr_diff_bs_{bs}', format_="tf")


def filename_bs_lr(bs, lr):
    """Return the ``format_="tf"`` path for ``'model_lr_{lr}_diff_bs_{bs}'``.

    The path is produced by ``ml_utils.path_from_filename``.
    """
    return ml_utils.path_from_filename(f'model_lr_{lr}_diff_bs_{bs}', format_="tf")

In [None]:
# Named functions instead of assigned lambdas (PEP 8 E731): same positional call signature as
# before, used by every reload/training cell to name the weight checkpoints.
def filename_bs_opt_sam(bs, opt, sam):
    """Return the checkpoint path for a fixed-learning-rate run.

    Args:
        bs: batch size.
        opt: optimizer tag used in the name (e.g. "SGD-MOM", "ADAM").
        sam: SAM tag used in the name (e.g. "SAM", "ASAM", "NONE").

    Returns:
        ``ml_utils.path_from_filename('model_fixed_lr_diff_bs_{bs}_opt_{opt}_sam_{sam}', format_="tf")``.
    """
    return ml_utils.path_from_filename(
        f'model_fixed_lr_diff_bs_{bs}_opt_{opt}_sam_{sam}', format_="tf")


def filename_bs_lr_opt_sam(bs, lr, opt, sam):
    """Return the checkpoint path for a run whose learning rate varies with the batch size.

    Args:
        bs: batch size.
        lr: learning rate, embedded verbatim in the name.
        opt: optimizer tag used in the name.
        sam: SAM tag used in the name.

    Returns:
        ``ml_utils.path_from_filename('model_lr_{lr}_diff_bs_{bs}_opt_{opt}_sam_{sam}', format_="tf")``.
    """
    return ml_utils.path_from_filename(
        f'model_lr_{lr}_diff_bs_{bs}_opt_{opt}_sam_{sam}', format_="tf")

In [None]:
def build_and_load_sam_model_weights(train, optimizer, filename, adaptive=False, rho=0.05):
    """Rebuild a SAM-wrapped simple CNN and restore previously saved weights into it.

    Args:
        train: dataset handed to ``tf_models.build_simple_cnn_sam``.
        optimizer: Keras optimizer handed to the builder.
        filename: weights path passed to ``load_weights``.
        adaptive: forwarded to the builder (the ASAM training cells of this notebook use True).
        rho: forwarded to the builder (the ASAM training cells of this notebook use 2.0).

    Returns:
        The freshly built model with the weights from *filename* loaded.
    """
    sam_model = tf_models.build_simple_cnn_sam(train, optimizer, adaptive, rho)
    sam_model.load_weights(filename)
    return sam_model

In [None]:
def build_and_load_model_weights(train, optimizer, filename):
    """Rebuild the plain (non-SAM) simple CNN and restore previously saved weights into it.

    Args:
        train: dataset handed to ``tf_models.build_and_compile_simple_cnn``.
        optimizer: Keras optimizer handed to the builder.
        filename: weights path passed to ``load_weights``.

    Returns:
        The freshly built model with the weights from *filename* loaded.
    """
    cnn = tf_models.build_and_compile_simple_cnn(train, optimizer)
    cnn.load_weights(filename)
    return cnn

In [None]:
def callback_for_filename(filename):
    """Return the Keras callbacks shared by every training run.

    Args:
        filename: checkpoint path for the ``ModelCheckpoint`` callback.

    Returns:
        ``[EarlyStopping, ModelCheckpoint]``. Training stops after 10 epochs without a
        ``val_loss`` improvement and restores the best weights. Only the weights of the best
        ``val_loss`` epoch are written to *filename*.
    """
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=10,
        restore_best_weights=True,
    )
    best_checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filename,
        monitor='val_loss',
        mode='min',
        verbose=1,
        save_best_only=True,
        save_weights_only=True,
    )
    return [early_stopping, best_checkpoint]

In [None]:
# Environment and plotting setup shared by every experiment below.
data_loading.initialize_env()
sns.set_theme()
color_map = sns.color_palette(as_cmap=True)
ml_utils.check_tpu_gpu()
# Dataset used by all runs.
# dataset_name = 'MNIST'
dataset_name = 'Fashion_MNIST'
# Upper bound on epochs; runs may stop earlier through the EarlyStopping callback from callback_for_filename.
EPOCHS = 500

In [None]:
# Registries shared by all sections. Keys start with a (lr schedule, optimizer, SAM variant) tuple:
#   models          -> key + (batch_size,) or key + (batch_size, lr) -> Keras model
#   models_states   -> key -> {batch_size: state returned by ml_utils.train_model}
#   sharpnesses     -> per-model sharpness values (persisted with ml_utils.save_sharpnesses_dict)
#   initial_weights -> per-model weight snapshots (persisted with ml_utils.save_initial_weights_dict)
models = {}
models_states = {}
sharpnesses = ml_utils.init_sharpnesses_dict()
initial_weights = ml_utils.init_initial_weights_dict()

# Different batch sizes with fixed learning rate

## With sharpness-aware minimization

### SGD with Momentum + SAM

In [None]:
# Fixed-learning-rate sweep: the same lr for every batch size.
batch_sizes = [32, 64, 128, 256, 512, 1024]
lr = 0.001
training_epochs = EPOCHS

# Key prefix (lr schedule, optimizer, SAM variant); the batch size is appended per run.
key = ('fixed', 'sgd', 'sam')

In [None]:
# Reload path: rebuild each SAM model and restore the weights checkpointed during training
# (the ModelCheckpoint from callback_for_filename saves only the best-val_loss weights).
if not RETRAIN_FLAG:
    for batch_size in batch_sizes:
        # Read training data
        train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                          batch_size=batch_size,
                                                                          img_size=32)
        optimizer = keras.optimizers.SGD(learning_rate=lr, momentum=0.9)
        model = build_and_load_sam_model_weights(train, optimizer, filename_bs_opt_sam(batch_size, "SGD-MOM", "SAM"))
        models[key + (batch_size,)] = model

In [None]:
# Training path: one SGD-momentum + SAM model per batch size. Records the training state,
# the trained model and its pre-training weights.
if RETRAIN_FLAG:
    models_states[key] = {}
    for batch_size in batch_sizes:
        with tf.distribute.MirroredStrategy().scope():
            # Read training data
            train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                              batch_size=batch_size,
                                                                              img_size=32)

            optimizer = keras.optimizers.SGD(learning_rate=lr, momentum=0.9)
            model = tf_models.build_simple_cnn_sam(train, optimizer)
            # Snapshot the weights BEFORE training. Taking it after train_model would store the trained
            # weights, so the "distance from initial weights" plot could only show ~0. get_weights()
            # returns NumPy copies, so training does not modify the snapshot.
            # NOTE(review): assumes the builder creates the model's variables up front (it receives
            # `train`) — confirm the snapshot is non-empty.
            initial_weights[key + (batch_size,)] = model.get_weights()
            train_callbacks = callback_for_filename(filename_bs_opt_sam(batch_size, "SGD-MOM", "SAM"))

            models_states[key][batch_size] = ml_utils.train_model(model, train, validation, epochs=training_epochs,
                                                                  extra_callbacks=train_callbacks, verbose=1)
            models[key + (batch_size,)] = model
            ml_utils.save_initial_weights_dict(initial_weights)

In [None]:
# Name used both to save and to load this section's training states.
fixed_lr_state_filename = 'model_fixed_lr_diff_bs_state'

In [None]:
# Save the states recorded by the training cell, or load the ones saved by an earlier run.
if RETRAIN_FLAG:
    ml_utils.save_model_state(models_states[key], fixed_lr_state_filename)
else:
    models_states[key] = ml_utils.load_model_state(fixed_lr_state_filename)

In [None]:
visualization.plot_loss_by_param(models_states[key], 'batch size with fixed learning rate', 'fixed_lr_diff_bs_SGD_SAM')

#### Sharpness

In [None]:
# Sharpness of every trained model, persisted after each one; loaded from disk when not retraining.
# `model.base_model` is passed — presumably the network wrapped by the SAM model (confirm in tf_models).
# NOTE(review): `train` is still bound to the last loop iteration's dataset (batch size 1024),
# so all models are evaluated on that same dataset — confirm this is intended.
if RETRAIN_FLAG:
    for batch_size in batch_sizes:
        model = models[key + (batch_size,)]

        sharpness_bs = visualization.get_sharpness(model.base_model, train)
        sharpnesses[key + (batch_size,)] = sharpness_bs

        ml_utils.save_sharpnesses_dict(sharpnesses)

if not RETRAIN_FLAG:
    sharpnesses = ml_utils.load_sharpnesses_dict()

In [None]:
visualization.plot_sharpness(batch_sizes, sharpnesses, key)

#### Distance from initial weights

In [None]:
# Compares each model in `models` against its snapshot in `initial_weights`.
visualization.plot_distance_from_initial_weight(models, initial_weights, batch_sizes, key)

#### Runtime

In [None]:
# Per-batch-size timing statistics extracted from the recorded training states.
mean_times, convergence_epochs, overall_training_times = visualization.extract_times_for_batch_sizes(models_states,
                                                                                                     batch_sizes, key)

In [None]:
visualization.plot_mean_time_per_epoch(batch_sizes, mean_times)

In [None]:
visualization.histogram_num_of_train_epochs_until_conv(batch_sizes, convergence_epochs)

In [None]:
visualization.histogram_overall_time_until_end_of_epochs(batch_sizes, overall_training_times)

In [None]:
visualization.plot_sharpness_times_runtime(batch_sizes, overall_training_times, sharpnesses, key)

### SGD with Momentum + ASAM

In [None]:
# NOTE(review): not referenced anywhere in the visible notebook.
models_dict_fixed_sgd_asam = {}
# Fixed-learning-rate sweep with SGD-momentum + ASAM.
batch_sizes = [32, 64, 128, 256, 512, 1024]
#batch_sizes = [32]
lr = 0.001
training_epochs = EPOCHS

key = ('fixed', 'sgd', 'asam')

In [None]:
# Filled only by the reload cell; duplicates models[key + (batch_size,)].
models_by_batch_size_fixed_lr_sgd_asam = {}

In [None]:
# Reload path: rebuild each SGD-momentum + ASAM model and restore its checkpointed weights.
if not RETRAIN_FLAG:
    for batch_size in batch_sizes:
        # Read training data
        train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                          batch_size=batch_size,
                                                                          img_size=32)
        optimizer = keras.optimizers.SGD(learning_rate=lr, momentum=0.9)
        # Rebuild with the same ASAM settings the training cell uses (adaptive=True, rho=2.0).
        # The helper's defaults (adaptive=False, rho=0.05) are the plain-SAM configuration.
        model = build_and_load_sam_model_weights(train, optimizer, filename_bs_opt_sam(batch_size, "SGD-MOM", "ASAM"),
                                                 adaptive=True, rho=2.0)
        models_by_batch_size_fixed_lr_sgd_asam[batch_size] = model
        models[key + (batch_size,)] = model

In [None]:
# Training path: one SGD-momentum + ASAM model per batch size.
if RETRAIN_FLAG:
    models_states[key] = {}
    for batch_size in batch_sizes:
        with tf.distribute.MirroredStrategy().scope():
            # Read training data
            train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                              batch_size=batch_size,
                                                                              img_size=32)

            optimizer = keras.optimizers.SGD(learning_rate=lr, momentum=0.9)
            model = tf_models.build_simple_cnn_sam(train, optimizer, adaptive=True, rho=2.0)
            # Snapshot the weights BEFORE training (after train_model it would just be the trained
            # weights). get_weights() returns NumPy copies.
            # NOTE(review): assumes the builder creates the variables up front — confirm non-empty.
            initial_weights[key + (batch_size,)] = model.get_weights()
            train_callbacks = callback_for_filename(filename_bs_opt_sam(batch_size, "SGD-MOM", "ASAM"))

            models_states[key][batch_size] = ml_utils.train_model(model, train, validation, epochs=training_epochs,
                                                                  extra_callbacks=train_callbacks, verbose=1)
            models[key + (batch_size,)] = model
            ml_utils.save_initial_weights_dict(initial_weights)

In [None]:
# Name used both to save and to load this section's training states.
fixed_lr_sgd_asam_state_filename = 'model_fixed_lr_diff_bs_sgd_asam_state'

In [None]:
# Save the states recorded by the training cell, or load the ones saved by an earlier run.
if RETRAIN_FLAG:
    ml_utils.save_model_state(models_states[key], fixed_lr_sgd_asam_state_filename)
else:
    models_states[key] = ml_utils.load_model_state(fixed_lr_sgd_asam_state_filename)

In [None]:
# NOTE(review): the plot name reuses the state-file string, unlike the 'fixed_lr_diff_bs_SGD_SAM'
# style used in the SGD + SAM section — consider aligning the naming.
visualization.plot_loss_by_param(models_states[key], 'batch size with fixed learning rate, SGD and ASAM',
                                 'model_fixed_lr_diff_bs_sgd_asam_state')

#### Sharpness

In [None]:
# Sharpness of every trained model (via `model.base_model`), persisted after each one;
# loaded from disk when not retraining.
# NOTE(review): `train` is the last loop iteration's dataset (batch size 1024) for every model.
if RETRAIN_FLAG:
    for batch_size in batch_sizes:
        model = models[key + (batch_size,)]

        sharpness_bs = visualization.get_sharpness(model.base_model, train)
        sharpnesses[key + (batch_size,)] = sharpness_bs

        ml_utils.save_sharpnesses_dict(sharpnesses)

if not RETRAIN_FLAG:
    sharpnesses = ml_utils.load_sharpnesses_dict()

In [None]:
visualization.plot_sharpness(batch_sizes, sharpnesses, key)

#### Distance from initial weights

In [None]:
visualization.plot_distance_from_initial_weight(models, initial_weights, batch_sizes, key)

#### Runtime

In [None]:
# Per-batch-size timing statistics extracted from the recorded training states.
mean_times, convergence_epochs, overall_training_times = visualization.extract_times_for_batch_sizes(models_states,
                                                                                                     batch_sizes, key)

In [None]:
visualization.plot_mean_time_per_epoch(batch_sizes, mean_times)

In [None]:
visualization.histogram_num_of_train_epochs_until_conv(batch_sizes, convergence_epochs)

In [None]:
visualization.histogram_overall_time_until_end_of_epochs(batch_sizes, overall_training_times)

In [None]:
visualization.plot_sharpness_times_runtime(batch_sizes, overall_training_times, sharpnesses, key)

### ADAM + SAM

In [None]:
# Fixed-learning-rate sweep with Adam + SAM.
batch_sizes = [32, 64, 128, 256, 512, 1024]
lr = 0.001
training_epochs = EPOCHS

key = ('fixed', 'adam', 'sam')

In [None]:
# Reload path: rebuild each Adam + SAM model and restore its checkpointed weights.
if not RETRAIN_FLAG:
    for batch_size in batch_sizes:
        # Read training data
        train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                          batch_size=batch_size,
                                                                          img_size=32)
        optimizer = keras.optimizers.Adam(learning_rate=lr)
        model = build_and_load_sam_model_weights(train, optimizer, filename_bs_opt_sam(batch_size, "ADAM", "SAM"))
        models[key + (batch_size,)] = model

In [None]:
# Training path: one Adam + SAM model per batch size.
if RETRAIN_FLAG:
    models_states[key] = {}
    for batch_size in batch_sizes:
        with tf.distribute.MirroredStrategy().scope():
            # Read training data
            train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                              batch_size=batch_size,
                                                                              img_size=32)

            optimizer = keras.optimizers.Adam(learning_rate=lr)
            model = tf_models.build_simple_cnn_sam(train, optimizer)
            # Snapshot the weights BEFORE training (after train_model it would just be the trained
            # weights). get_weights() returns NumPy copies.
            # NOTE(review): assumes the builder creates the variables up front — confirm non-empty.
            initial_weights[key + (batch_size,)] = model.get_weights()
            train_callbacks = callback_for_filename(filename_bs_opt_sam(batch_size, "ADAM", "SAM"))

            models_states[key][batch_size] = ml_utils.train_model(model, train, validation, epochs=training_epochs,
                                                                  extra_callbacks=train_callbacks, verbose=1)
            models[key + (batch_size,)] = model
            ml_utils.save_initial_weights_dict(initial_weights)

In [None]:
# Name used both to save/load the training states and as the loss-plot name.
state_filename = 'model_fixed_lr_diff_bs_adam_sam_state'

In [None]:
# Save the states recorded by the training cell, or load the ones saved by an earlier run.
if RETRAIN_FLAG:
    ml_utils.save_model_state(models_states[key], state_filename)
else:
    models_states[key] = ml_utils.load_model_state(state_filename)

In [None]:
visualization.plot_loss_by_param(models_states[key], 'batch size with fixed learning rate, ADAM and SAM',
                                 state_filename)

#### Sharpness

In [None]:
# Sharpness of every trained model (via `model.base_model`), persisted after each one;
# loaded from disk when not retraining.
# NOTE(review): `train` is the last loop iteration's dataset (batch size 1024) for every model.
if RETRAIN_FLAG:
    for batch_size in batch_sizes:
        model = models[key + (batch_size,)]

        sharpness_bs = visualization.get_sharpness(model.base_model, train)
        sharpnesses[key + (batch_size,)] = sharpness_bs

        ml_utils.save_sharpnesses_dict(sharpnesses)

if not RETRAIN_FLAG:
    sharpnesses = ml_utils.load_sharpnesses_dict()

In [None]:
visualization.plot_sharpness(batch_sizes, sharpnesses, key)

#### Distance from initial weights

In [None]:
visualization.plot_distance_from_initial_weight(models, initial_weights, batch_sizes, key)

#### Runtime

In [None]:
# Per-batch-size timing statistics extracted from the recorded training states.
mean_times, convergence_epochs, overall_training_times = visualization.extract_times_for_batch_sizes(models_states,
                                                                                                     batch_sizes, key)

In [None]:
visualization.plot_mean_time_per_epoch(batch_sizes, mean_times)

In [None]:
visualization.histogram_num_of_train_epochs_until_conv(batch_sizes, convergence_epochs)

In [None]:
visualization.histogram_overall_time_until_end_of_epochs(batch_sizes, overall_training_times)

In [None]:
visualization.plot_sharpness_times_runtime(batch_sizes, overall_training_times, sharpnesses, key)

### ADAM + ASAM

In [None]:
# Fixed-learning-rate sweep with Adam + ASAM.
batch_sizes = [32, 64, 128, 256, 512, 1024]
lr = 0.001
training_epochs = EPOCHS

key = ('fixed', 'adam', 'asam')

In [None]:
# Reload path: rebuild each Adam + ASAM model and restore its checkpointed weights.
if not RETRAIN_FLAG:
    for batch_size in batch_sizes:
        # Read training data
        train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                          batch_size=batch_size,
                                                                          img_size=32)
        optimizer = keras.optimizers.Adam(learning_rate=lr)
        # Rebuild with the same ASAM settings the training cell uses (adaptive=True, rho=2.0).
        # The helper's defaults (adaptive=False, rho=0.05) are the plain-SAM configuration.
        model = build_and_load_sam_model_weights(train, optimizer, filename_bs_opt_sam(batch_size, "ADAM", "ASAM"),
                                                 adaptive=True, rho=2.0)
        models[key + (batch_size,)] = model

In [None]:
# Training path: one Adam + ASAM model per batch size.
if RETRAIN_FLAG:
    models_states[key] = {}
    for batch_size in batch_sizes:
        with tf.distribute.MirroredStrategy().scope():
            # Read training data
            train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                              batch_size=batch_size,
                                                                              img_size=32)

            optimizer = keras.optimizers.Adam(learning_rate=lr)
            model = tf_models.build_simple_cnn_sam(train, optimizer, adaptive=True, rho=2.0)
            # Snapshot the weights BEFORE training (after train_model it would just be the trained
            # weights). get_weights() returns NumPy copies.
            # NOTE(review): assumes the builder creates the variables up front — confirm non-empty.
            initial_weights[key + (batch_size,)] = model.get_weights()
            train_callbacks = callback_for_filename(filename_bs_opt_sam(batch_size, "ADAM", "ASAM"))

            models_states[key][batch_size] = ml_utils.train_model(model, train, validation, epochs=training_epochs,
                                                                  extra_callbacks=train_callbacks, verbose=1)
            models[key + (batch_size,)] = model
            ml_utils.save_initial_weights_dict(initial_weights)

In [None]:
# Name used both to save/load the training states and as the loss-plot name.
state_filename = 'model_fixed_lr_diff_bs_adam_asam_state'

In [None]:
# Save the states recorded by the training cell, or load the ones saved by an earlier run.
if RETRAIN_FLAG:
    ml_utils.save_model_state(models_states[key], state_filename)
else:
    models_states[key] = ml_utils.load_model_state(state_filename)

In [None]:
visualization.plot_loss_by_param(models_states[key], 'batch size with fixed learning rate, ADAM and ASAM',
                                 state_filename)

#### Sharpness

In [None]:
# Sharpness of every trained model (via `model.base_model`), persisted after each one;
# loaded from disk when not retraining.
# NOTE(review): `train` is the last loop iteration's dataset (batch size 1024) for every model.
if RETRAIN_FLAG:
    for batch_size in batch_sizes:
        model = models[key + (batch_size,)]

        sharpness_bs = visualization.get_sharpness(model.base_model, train)
        sharpnesses[key + (batch_size,)] = sharpness_bs

        ml_utils.save_sharpnesses_dict(sharpnesses)

if not RETRAIN_FLAG:
    sharpnesses = ml_utils.load_sharpnesses_dict()


In [None]:
visualization.plot_sharpness(batch_sizes, sharpnesses, key)

#### Distance from initial weights

In [None]:
visualization.plot_distance_from_initial_weight(models, initial_weights, batch_sizes, key)

#### Runtime

In [None]:
# Per-batch-size timing statistics extracted from the recorded training states.
mean_times, convergence_epochs, overall_training_times = visualization.extract_times_for_batch_sizes(models_states,
                                                                                                     batch_sizes, key)

In [None]:
visualization.plot_mean_time_per_epoch(batch_sizes, mean_times)

In [None]:
visualization.histogram_num_of_train_epochs_until_conv(batch_sizes, convergence_epochs)

In [None]:
visualization.histogram_overall_time_until_end_of_epochs(batch_sizes, overall_training_times)

In [None]:
visualization.plot_sharpness_times_runtime(batch_sizes, overall_training_times, sharpnesses, key)

## Without sharpness-aware minimization

### SGD with Momentum

In [None]:
# Fixed-learning-rate sweep with plain SGD-momentum (no SAM wrapper).
batch_sizes = [32, 64, 128, 256, 512, 1024]
lr = 0.001
training_epochs = EPOCHS

key = ('fixed', 'sgd', 'none')

In [None]:
# Reload path: rebuild each plain CNN and restore its checkpointed weights.
if not RETRAIN_FLAG:
    for batch_size in batch_sizes:
        # Read training data
        train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                          batch_size=batch_size,
                                                                          img_size=32)
        optimizer = keras.optimizers.SGD(learning_rate=lr, momentum=0.9)
        model = build_and_load_model_weights(train, optimizer, filename_bs_opt_sam(batch_size, "SGD-MOM", "NONE"))
        models[key + (batch_size,)] = model

In [None]:
# Training path: one plain SGD-momentum CNN per batch size.
if RETRAIN_FLAG:
    models_states[key] = {}
    for batch_size in batch_sizes:
        with tf.distribute.MirroredStrategy().scope():
            # Read training data
            train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                              batch_size=batch_size,
                                                                              img_size=32)

            optimizer = keras.optimizers.SGD(learning_rate=lr, momentum=0.9)
            model = tf_models.build_and_compile_simple_cnn(train, optimizer)
            # Snapshot the weights BEFORE training (after train_model it would just be the trained
            # weights). get_weights() returns NumPy copies.
            initial_weights[key + (batch_size,)] = model.get_weights()
            train_callbacks = callback_for_filename(filename_bs_opt_sam(batch_size, "SGD-MOM", "NONE"))

            models_states[key][batch_size] = ml_utils.train_model(model, train, validation, epochs=training_epochs,
                                                                  extra_callbacks=train_callbacks, verbose=1)
            models[key + (batch_size,)] = model
            ml_utils.save_initial_weights_dict(initial_weights)

In [None]:
# Name used both to save and to load this section's training states.
state_filename = 'model_fixed_lr_diff_bs_noSAM_state'

In [None]:
# Save the states recorded by the training cell, or load the ones saved by an earlier run.
if RETRAIN_FLAG:
    ml_utils.save_model_state(models_states[key], state_filename)
else:
    models_states[key] = ml_utils.load_model_state(state_filename)

In [None]:
visualization.plot_loss_by_param(models_states[key], 'batch size with fixed learning rate without SAM',
                                 'fixed_lr_diff_bs_SGD_noSAM')

#### Sharpness

In [None]:
# Sharpness of every trained model (no SAM wrapper, so the model itself is passed), persisted after
# each one; loaded from disk when not retraining.
# NOTE(review): `train` is the last loop iteration's dataset (batch size 1024) for every model.
if RETRAIN_FLAG:
    for batch_size in batch_sizes:
        model = models[key + (batch_size,)]

        sharpness_bs = visualization.get_sharpness(model, train)
        sharpnesses[key + (batch_size,)] = sharpness_bs

        ml_utils.save_sharpnesses_dict(sharpnesses)

if not RETRAIN_FLAG:
    sharpnesses = ml_utils.load_sharpnesses_dict()

In [None]:
visualization.plot_sharpness(batch_sizes, sharpnesses, key)

#### Distance from initial weights

In [None]:
visualization.plot_distance_from_initial_weight(models, initial_weights, batch_sizes, key)

#### Runtime

In [None]:
# Per-batch-size timing statistics extracted from the recorded training states.
mean_times, convergence_epochs, overall_training_times = visualization.extract_times_for_batch_sizes(models_states,
                                                                                                     batch_sizes, key)

In [None]:
visualization.plot_mean_time_per_epoch(batch_sizes, mean_times)

In [None]:
visualization.histogram_num_of_train_epochs_until_conv(batch_sizes, convergence_epochs)

In [None]:
visualization.histogram_overall_time_until_end_of_epochs(batch_sizes, overall_training_times)

In [None]:
visualization.plot_sharpness_times_runtime(batch_sizes, overall_training_times, sharpnesses, key)

### ADAM

In [None]:
# Fixed-learning-rate sweep with plain Adam (no SAM wrapper).
batch_sizes = [32, 64, 128, 256, 512, 1024]
lr = 0.001
training_epochs = EPOCHS

key = ('fixed', 'adam', 'none')

In [None]:
# Reload path: rebuild each plain CNN and restore its checkpointed weights.
if not RETRAIN_FLAG:
    for batch_size in batch_sizes:
        # Read training data
        train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                          batch_size=batch_size,
                                                                          img_size=32)
        optimizer = keras.optimizers.Adam(learning_rate=lr)
        model = build_and_load_model_weights(train, optimizer, filename_bs_opt_sam(batch_size, "ADAM", "NONE"))
        models[key + (batch_size,)] = model

In [None]:
# Training path: one plain Adam CNN per batch size.
if RETRAIN_FLAG:
    models_states[key] = {}
    for batch_size in batch_sizes:
        with tf.distribute.MirroredStrategy().scope():
            # Read training data
            train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                              batch_size=batch_size,
                                                                              img_size=32)

            optimizer = keras.optimizers.Adam(learning_rate=lr)
            model = tf_models.build_and_compile_simple_cnn(train, optimizer)
            # Snapshot the weights BEFORE training (after train_model it would just be the trained
            # weights). get_weights() returns NumPy copies.
            initial_weights[key + (batch_size,)] = model.get_weights()
            train_callbacks = callback_for_filename(filename_bs_opt_sam(batch_size, "ADAM", "NONE"))

            models_states[key][batch_size] = ml_utils.train_model(model, train, validation, epochs=training_epochs,
                                                                  extra_callbacks=train_callbacks, verbose=1)
            models[key + (batch_size,)] = model
            ml_utils.save_initial_weights_dict(initial_weights)

In [None]:
# Name used both to save and to load this section's training states.
state_filename = 'model_fixed_lr_diff_bs_ADAM_noSAM_state'

In [None]:
# Save the states recorded by the training cell, or load the ones saved by an earlier run.
if RETRAIN_FLAG:
    ml_utils.save_model_state(models_states[key], state_filename)
else:
    models_states[key] = ml_utils.load_model_state(state_filename)

In [None]:
visualization.plot_loss_by_param(models_states[key], 'batch size with fixed learning rate without SAM and ADAM',
                                 'fixed_lr_diff_bs_ADAM_noSAM')

#### Sharpness

In [None]:
# Sharpness of every trained model (no SAM wrapper, so the model itself is passed), persisted after
# each one; loaded from disk when not retraining.
# NOTE(review): `train` is the last loop iteration's dataset (batch size 1024) for every model.
if RETRAIN_FLAG:
    for batch_size in batch_sizes:
        model = models[key + (batch_size,)]

        sharpness_bs = visualization.get_sharpness(model, train)
        sharpnesses[key + (batch_size,)] = sharpness_bs

        ml_utils.save_sharpnesses_dict(sharpnesses)

if not RETRAIN_FLAG:
    sharpnesses = ml_utils.load_sharpnesses_dict()

In [None]:
visualization.plot_sharpness(batch_sizes, sharpnesses, key)

#### Distance from initial weights

In [None]:
visualization.plot_distance_from_initial_weight(models, initial_weights, batch_sizes, key)

#### Runtime

In [None]:
# Per-batch-size timing statistics extracted from the recorded training states.
mean_times, convergence_epochs, overall_training_times = visualization.extract_times_for_batch_sizes(models_states,
                                                                                                     batch_sizes, key)

In [None]:
visualization.plot_mean_time_per_epoch(batch_sizes, mean_times)

In [None]:
visualization.histogram_num_of_train_epochs_until_conv(batch_sizes, convergence_epochs)

In [None]:
visualization.histogram_overall_time_until_end_of_epochs(batch_sizes, overall_training_times)

In [None]:
visualization.plot_sharpness_times_runtime(batch_sizes, overall_training_times, sharpnesses, key)

# Different batch sizes with a learning rate that scales linearly with the batch size

## With sharpness-aware minimization

### SGD with Momentum + SAM

In [None]:
# Increasing-lr sweep: learning_rates[i] is paired with batch_sizes[i] via zip below, so the lr
# doubles whenever the batch size doubles.
learning_rates = [0.001, 0.002, 0.004, 0.008, 0.016, 0.032]
batch_sizes = [32, 64, 128, 256, 512, 1024]
training_epochs = EPOCHS

key = ('increasing', 'sgd', 'sam')

In [None]:
# NOTE(review): neither dict is referenced anywhere in the visible notebook.
models_by_batch_size_diff_lr = {}
model_history_dict_diff_lr = {}

In [None]:
# Name used both to save and to load this section's training states.
diff_lr_state_filename = 'model_diff_lr_diff_bs_state'

In [None]:
# Reload path: model keys here also carry the learning rate, i.e. key + (batch_size, lr).
# The loop rebinds the global `lr`.
if not RETRAIN_FLAG:
    for batch_size, lr in zip(batch_sizes, learning_rates):
        # Read training data
        train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                          batch_size=batch_size,
                                                                          img_size=32)
        optimizer = keras.optimizers.SGD(learning_rate=lr, momentum=0.9)
        model = build_and_load_sam_model_weights(train, optimizer,
                                                 filename_bs_lr_opt_sam(batch_size, lr, "SGD-MOM", "SAM"))
        models[key + (batch_size, lr)] = model

In [None]:
# Training path: one SGD-momentum + SAM model per (batch size, learning rate) pair.
if RETRAIN_FLAG:
    models_states[key] = {}
    for batch_size, lr in zip(batch_sizes, learning_rates):
        with tf.distribute.MirroredStrategy().scope():
            # Read training data
            train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                              batch_size=batch_size,
                                                                              img_size=32)

            optimizer = keras.optimizers.SGD(learning_rate=lr, momentum=0.9)
            # Plain SAM (builder defaults). This matches the key ('increasing', 'sgd', 'sam'), the "SAM"
            # checkpoint name and the reload cell. `adaptive=True, rho=2.0` would train the same ASAM
            # model as the SGD + ASAM section. Checkpoints written by the old code used the ASAM
            # settings, so retrain to refresh them.
            model = tf_models.build_simple_cnn_sam(train, optimizer)
            # Snapshot the weights BEFORE training (after train_model it would just be the trained
            # weights). get_weights() returns NumPy copies.
            # NOTE(review): assumes the builder creates the variables up front — confirm non-empty.
            initial_weights[key + (batch_size, lr)] = model.get_weights()
            train_callbacks = callback_for_filename(filename_bs_lr_opt_sam(batch_size, lr, "SGD-MOM", "SAM"))

            models_states[key][batch_size] = ml_utils.train_model(model, train, validation, epochs=training_epochs,
                                                                  extra_callbacks=train_callbacks, verbose=1)
            models[key + (batch_size, lr)] = model
            ml_utils.save_initial_weights_dict(initial_weights)

In [None]:
# Save the states recorded by the training cell, or load the ones saved by an earlier run.
if RETRAIN_FLAG:
    ml_utils.save_model_state(models_states[key], diff_lr_state_filename)
else:
    models_states[key] = ml_utils.load_model_state(diff_lr_state_filename)

In [None]:
visualization.plot_loss_by_param(models_states[key], 'batch size with increasing learning rate',
                                 'diff_lr_diff_bs')

#### Sharpness

In [None]:
# Sharpness of every trained model (via `model.base_model`), keyed by key + (batch_size, lr) and
# persisted after each one; loaded from disk when not retraining.
# NOTE(review): `train` is the last loop iteration's dataset (batch size 1024) for every model.
if RETRAIN_FLAG:
    for batch_size, lr in zip(batch_sizes, learning_rates):
        model = models[key + (batch_size, lr)]

        sharpness = visualization.get_sharpness(model.base_model, train)
        sharpnesses[key + (batch_size, lr)] = sharpness

        ml_utils.save_sharpnesses_dict(sharpnesses)

if not RETRAIN_FLAG:
    sharpnesses = ml_utils.load_sharpnesses_dict()

In [None]:
visualization.plot_sharpness(batch_sizes, sharpnesses, key, learning_rates)

#### Distance from initial weights

In [None]:
visualization.plot_distance_from_initial_weight(models, initial_weights, batch_sizes, key, learning_rates)

#### Runtime

In [None]:
# Per-batch-size timing statistics (models_states[key] is keyed by batch size only).
mean_times, convergence_epochs, overall_training_times = visualization.extract_times_for_batch_sizes(models_states,
                                                                                                     batch_sizes, key)

In [None]:
visualization.plot_mean_time_per_epoch(batch_sizes, mean_times)

In [None]:
visualization.histogram_num_of_train_epochs_until_conv(batch_sizes, convergence_epochs)

In [None]:
visualization.histogram_overall_time_until_end_of_epochs(batch_sizes, overall_training_times)

In [None]:
# Save to Drive in case we run on Google Colab
if USE_COLAB:
    !cp -r /content/graphs/ /content/drive/MyDrive/
    !cp -r /content/tmp/ /content/drive/MyDrive/

### SGD with Momentum + ASAM

In [None]:
# Increasing-lr sweep with SGD-momentum + ASAM (lr paired with batch size via zip).
learning_rates = [0.001, 0.002, 0.004, 0.008, 0.016, 0.032]
batch_sizes = [32, 64, 128, 256, 512, 1024]
training_epochs = EPOCHS

key = ('increasing', 'sgd', 'asam')

In [None]:
# Name used both to save and to load this section's training states.
state_filename = 'model_diff_lr_diff_bs_sgd_asam_state'

In [None]:
# Reload path: rebuild each SGD-momentum + ASAM model and restore its checkpointed weights.
if not RETRAIN_FLAG:
    for batch_size, lr in zip(batch_sizes, learning_rates):
        # Read training data
        train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                          batch_size=batch_size,
                                                                          img_size=32)
        optimizer = keras.optimizers.SGD(learning_rate=lr, momentum=0.9)
        # Rebuild with the same ASAM settings the training cell uses (adaptive=True, rho=2.0).
        # The helper's defaults (adaptive=False, rho=0.05) are the plain-SAM configuration.
        model = build_and_load_sam_model_weights(train, optimizer,
                                                 filename_bs_lr_opt_sam(batch_size, lr, "SGD-MOM", "ASAM"),
                                                 adaptive=True, rho=2.0)
        models[key + (batch_size, lr)] = model

In [None]:
# Training path: one SGD-momentum + ASAM model per (batch size, learning rate) pair.
if RETRAIN_FLAG:
    models_states[key] = {}
    for batch_size, lr in zip(batch_sizes, learning_rates):
        with tf.distribute.MirroredStrategy().scope():
            # Read training data
            train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                              batch_size=batch_size,
                                                                              img_size=32)

            optimizer = keras.optimizers.SGD(learning_rate=lr, momentum=0.9)
            model = tf_models.build_simple_cnn_sam(train, optimizer, adaptive=True, rho=2.0)
            # Snapshot the weights BEFORE training (after train_model it would just be the trained
            # weights). get_weights() returns NumPy copies.
            # NOTE(review): assumes the builder creates the variables up front — confirm non-empty.
            initial_weights[key + (batch_size, lr)] = model.get_weights()
            train_callbacks = callback_for_filename(filename_bs_lr_opt_sam(batch_size, lr, "SGD-MOM", "ASAM"))

            models_states[key][batch_size] = ml_utils.train_model(model, train, validation, epochs=training_epochs,
                                                                  extra_callbacks=train_callbacks, verbose=1)
            models[key + (batch_size, lr)] = model
            ml_utils.save_initial_weights_dict(initial_weights)

In [None]:
# Save the states recorded by the training cell, or load the ones saved by an earlier run.
if RETRAIN_FLAG:
    ml_utils.save_model_state(models_states[key], state_filename)
else:
    models_states[key] = ml_utils.load_model_state(state_filename)

In [None]:
visualization.plot_loss_by_param(models_states[key], 'batch size with increasing learning rate with SGD and ASAM',
                                 'diff_lr_diff_bs_SGD_ASAM')

#### Sharpness

In [None]:
# Sharpness of every trained model (via `model.base_model`), keyed by key + (batch_size, lr) and
# persisted after each one; loaded from disk when not retraining.
# NOTE(review): `train` is the last loop iteration's dataset (batch size 1024) for every model.
if RETRAIN_FLAG:
    for batch_size, lr in zip(batch_sizes, learning_rates):
        model = models[key + (batch_size, lr)]

        sharpness = visualization.get_sharpness(model.base_model, train)
        sharpnesses[key + (batch_size, lr)] = sharpness

        ml_utils.save_sharpnesses_dict(sharpnesses)

if not RETRAIN_FLAG:
    sharpnesses = ml_utils.load_sharpnesses_dict()

In [None]:
visualization.plot_sharpness(batch_sizes, sharpnesses, key, learning_rates)

#### Distance from initial weights

In [None]:
visualization.plot_distance_from_initial_weight(models, initial_weights, batch_sizes, key, learning_rates)

#### Runtime

In [None]:
# Per-batch-size timing statistics (models_states[key] is keyed by batch size only).
mean_times, convergence_epochs, overall_training_times = visualization.extract_times_for_batch_sizes(models_states,
                                                                                                     batch_sizes, key)

In [None]:
visualization.plot_mean_time_per_epoch(batch_sizes, mean_times)

In [None]:
visualization.histogram_num_of_train_epochs_until_conv(batch_sizes, convergence_epochs)

In [None]:
visualization.histogram_overall_time_until_end_of_epochs(batch_sizes, overall_training_times)

In [None]:
# Save to Drive in case we run on Google Colab
if USE_COLAB:
    !cp -r /content/graphs/ /content/drive/MyDrive/
    !cp -r /content/tmp/ /content/drive/MyDrive/

### ADAM + SAM

In [None]:
learning_rates = [0.001, 0.002, 0.004, 0.008, 0.016, 0.032]
batch_sizes = [32, 64, 128, 256, 512, 1024]
training_epochs = EPOCHS

key = ('increasing', 'adam', 'sam')

In [None]:
state_filename = 'model_diff_lr_diff_bs_adam_sam_state'

In [None]:
if not RETRAIN_FLAG:
    for batch_size, lr in zip(batch_sizes, learning_rates):
        # Read training data
        train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                          batch_size=batch_size,
                                                                          img_size=32)
        optimizer = keras.optimizers.Adam(learning_rate=lr)
        model = build_and_load_sam_model_weights(train, optimizer,
                                                 filename_bs_lr_opt_sam(batch_size, lr, "ADAM", "SAM"))
        models[key + (batch_size, lr)] = model

In [None]:
if RETRAIN_FLAG:
    models_states[key] = {}
    for batch_size, lr in zip(batch_sizes, learning_rates):
        with tf.distribute.MirroredStrategy().scope():
            # Read training data
            train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                              batch_size=batch_size,
                                                                              img_size=32)

            optimizer = keras.optimizers.Adam(learning_rate=lr)
            model = tf_models.build_simple_cnn_sam(train, optimizer)
            train_callbacks = callback_for_filename(filename_bs_lr_opt_sam(batch_size, lr, "ADAM", "SAM"))

            models_states[key][batch_size] = ml_utils.train_model(model, train, validation, epochs=training_epochs,
                                                                  extra_callbacks=train_callbacks, verbose=1)
            models[key + (batch_size, lr)] = model
            initial_weights[key + (batch_size, lr)] = model.get_weights()
            ml_utils.save_initial_weights_dict(initial_weights)

In [None]:
# Persist (when retraining) or restore (otherwise) the per-batch-size results returned by
# ml_utils.train_model for the current `key`, using `state_filename` as the storage name.
if RETRAIN_FLAG:
    ml_utils.save_model_state(models_states[key], state_filename)
else:
    models_states[key] = ml_utils.load_model_state(state_filename)

In [None]:
# Loss per batch size for ADAM + SAM; the last argument looks like the figure's file stem — TODO confirm.
visualization.plot_loss_by_param(models_states[key], 'batch size with increasing learning rate with ADAM and SAM',
                                 'diff_lr_diff_bs_ADAM_SAM')

#### Sharpness

In [None]:
# Compute the sharpness of every (batch_size, lr) model, measured on the SAM model's `.base_model`.
# NOTE(review): `train` is the dataset from the last training-loop iteration (batch size 1024),
# so all models are evaluated on that same dataset — confirm this is intended.
if RETRAIN_FLAG:
    for batch_size, lr in zip(batch_sizes, learning_rates):
        model = models[key + (batch_size, lr)]

        sharpness = visualization.get_sharpness(model.base_model, train)
        sharpnesses[key + (batch_size, lr)] = sharpness

        # Saved after every model, presumably so partial results survive an interrupted run.
        ml_utils.save_sharpnesses_dict(sharpnesses)

# Without retraining, replace the in-memory dict with the previously saved one.
if not RETRAIN_FLAG:
    sharpnesses = ml_utils.load_sharpnesses_dict()

In [None]:
visualization.plot_sharpness(batch_sizes, sharpnesses, key, learning_rates)

#### Distance from initial weights

In [None]:
# Per the helper's name: how far each model's weights moved from the snapshot in `initial_weights`.
visualization.plot_distance_from_initial_weight(models, initial_weights, batch_sizes, key, learning_rates)

#### Runtime

In [None]:
# Timing statistics per batch size, judging by the unpacked names.
mean_times, convergence_epochs, overall_training_times = visualization.extract_times_for_batch_sizes(models_states,
                                                                                                     batch_sizes, key)

In [None]:
visualization.plot_mean_time_per_epoch(batch_sizes, mean_times)

In [None]:
visualization.histogram_num_of_train_epochs_until_conv(batch_sizes, convergence_epochs)

In [None]:
visualization.histogram_overall_time_until_end_of_epochs(batch_sizes, overall_training_times)

In [None]:
# Save to Drive in case we run on Google Colab
# (copies the local graphs/ and tmp/ folders into MyDrive).
if USE_COLAB:
    !cp -r /content/graphs/ /content/drive/MyDrive/
    !cp -r /content/tmp/ /content/drive/MyDrive/

### ADAM + ASAM

In [None]:
learning_rates = [0.001, 0.002, 0.004, 0.008, 0.016, 0.032]
batch_sizes = [32, 64, 128, 256, 512, 1024]
training_epochs = EPOCHS

key = ('increasing', 'adam', 'asam')

In [None]:
state_filename = 'model_diff_lr_diff_bs_adam_asam_state'

In [None]:
if not RETRAIN_FLAG:
    for batch_size, lr in zip(batch_sizes, learning_rates):
        # Read training data
        train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                          batch_size=batch_size,
                                                                          img_size=32)
        optimizer = keras.optimizers.Adam(learning_rate=lr)
        model = build_and_load_sam_model_weights(train, optimizer,
                                                 filename_bs_lr_opt_sam(batch_size, lr, "ADAM", "ASAM"))
        models[key + (batch_size, lr)] = model

In [None]:
if RETRAIN_FLAG:
    models_states[key] = {}
    for batch_size, lr in zip(batch_sizes, learning_rates):
        with tf.distribute.MirroredStrategy().scope():
            # Read training data
            train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                              batch_size=batch_size,
                                                                              img_size=32)

            optimizer = keras.optimizers.Adam(learning_rate=lr)
            model = tf_models.build_simple_cnn_sam(train, optimizer)
            train_callbacks = callback_for_filename(filename_bs_lr_opt_sam(batch_size, lr, "ADAM", "ASAM"))

            models_states[key][batch_size] = ml_utils.train_model(model, train, validation, epochs=training_epochs,
                                                                  extra_callbacks=train_callbacks, verbose=1)
            models[key + (batch_size, lr)] = model
            initial_weights[key + (batch_size, lr)] = model.get_weights()
            ml_utils.save_initial_weights_dict(initial_weights)

In [None]:
# Persist (when retraining) or restore (otherwise) the per-batch-size results returned by
# ml_utils.train_model for the current `key`, using `state_filename` as the storage name.
if RETRAIN_FLAG:
    ml_utils.save_model_state(models_states[key], state_filename)
else:
    models_states[key] = ml_utils.load_model_state(state_filename)

In [None]:
# Loss per batch size for ADAM + ASAM; the last argument looks like the figure's file stem — TODO confirm.
visualization.plot_loss_by_param(models_states[key], 'batch size with increasing learning rate with ADAM and ASAM',
                                 'diff_lr_diff_bs_ADAM_ASAM')

#### Sharpness

In [None]:
# Compute the sharpness of every (batch_size, lr) model, measured on the wrapper's `.base_model`.
# NOTE(review): `train` is the dataset from the last training-loop iteration (batch size 1024),
# so all models are evaluated on that same dataset — confirm this is intended.
if RETRAIN_FLAG:
    for batch_size, lr in zip(batch_sizes, learning_rates):
        model = models[key + (batch_size, lr)]

        sharpness = visualization.get_sharpness(model.base_model, train)
        sharpnesses[key + (batch_size, lr)] = sharpness

        # Saved after every model, presumably so partial results survive an interrupted run.
        ml_utils.save_sharpnesses_dict(sharpnesses)

# Without retraining, replace the in-memory dict with the previously saved one.
if not RETRAIN_FLAG:
    sharpnesses = ml_utils.load_sharpnesses_dict()

In [None]:
visualization.plot_sharpness(batch_sizes, sharpnesses, key, learning_rates)

#### Distance from initial weights

In [None]:
# Per the helper's name: how far each model's weights moved from the snapshot in `initial_weights`.
visualization.plot_distance_from_initial_weight(models, initial_weights, batch_sizes, key, learning_rates)

#### Runtime

In [None]:
# Timing statistics per batch size, judging by the unpacked names.
mean_times, convergence_epochs, overall_training_times = visualization.extract_times_for_batch_sizes(models_states,
                                                                                                     batch_sizes, key)

In [None]:
visualization.plot_mean_time_per_epoch(batch_sizes, mean_times)

In [None]:
visualization.histogram_num_of_train_epochs_until_conv(batch_sizes, convergence_epochs)

In [None]:
visualization.histogram_overall_time_until_end_of_epochs(batch_sizes, overall_training_times)

In [None]:
# Save to Drive in case we run on Google Colab
# (copies the local graphs/ and tmp/ folders into MyDrive).
if USE_COLAB:
    !cp -r /content/graphs/ /content/drive/MyDrive/
    !cp -r /content/tmp/ /content/drive/MyDrive/

## Without sharpness-aware minimization

### SGD with Momentum

In [None]:
learning_rates = [0.001, 0.002, 0.004, 0.008, 0.016, 0.032]
batch_sizes = [32, 64, 128, 256, 512, 1024]
training_epochs = EPOCHS

key = ('increasing', 'SGD', 'none')

In [None]:
state_filename = 'model_diff_lr_diff_bs_sgd_state'

In [None]:
if not RETRAIN_FLAG:
    for batch_size, lr in zip(batch_sizes, learning_rates):
        # Read training data
        train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                          batch_size=batch_size,
                                                                          img_size=32)
        optimizer = keras.optimizers.SGD(learning_rate=lr, momentum=0.9)
        model = build_and_load_model_weights(train, optimizer,
                                             filename_bs_lr_opt_sam(batch_size, lr, "SGD", "NONE"))
        models[key + (batch_size, lr)] = model

In [None]:
if RETRAIN_FLAG:
    models_states[key] = {}
    for batch_size, lr in zip(batch_sizes, learning_rates):
        with tf.distribute.MirroredStrategy().scope():
            # Read training data
            train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                              batch_size=batch_size,
                                                                              img_size=32)

            optimizer = keras.optimizers.SGD(learning_rate=lr, momentum=0.9)
            model = tf_models.build_and_compile_simple_cnn(train, optimizer)
            train_callbacks = callback_for_filename(filename_bs_lr_opt_sam(batch_size, lr, "SGD", "NONE"))

            models_states[key][batch_size] = ml_utils.train_model(model, train, validation, epochs=training_epochs,
                                                                  extra_callbacks=train_callbacks, verbose=1)
            models[key + (batch_size, lr)] = model
            initial_weights[key + (batch_size, lr)] = model.get_weights()
            ml_utils.save_initial_weights_dict(initial_weights)

In [None]:
# Persist (when retraining) or restore (otherwise) the per-batch-size results returned by
# ml_utils.train_model for the current `key`, using `state_filename` as the storage name.
if RETRAIN_FLAG:
    ml_utils.save_model_state(models_states[key], state_filename)
else:
    models_states[key] = ml_utils.load_model_state(state_filename)

In [None]:
# Loss per batch size for plain SGD; the last argument looks like the figure's file stem — TODO confirm.
visualization.plot_loss_by_param(models_states[key], 'batch size with increasing learning rate with SGD',
                                 'diff_lr_diff_bs_SGD_NONE')

#### Sharpness

In [None]:
# Compute the sharpness of every (batch_size, lr) model. These models are not SAM-wrapped,
# so the model itself (not `.base_model`) is passed.
# NOTE(review): `train` is the dataset from the last training-loop iteration (batch size 1024),
# so all models are evaluated on that same dataset — confirm this is intended.
if RETRAIN_FLAG:
    for batch_size, lr in zip(batch_sizes, learning_rates):
        model = models[key + (batch_size, lr)]

        sharpness = visualization.get_sharpness(model, train)
        sharpnesses[key + (batch_size, lr)] = sharpness

        # Saved after every model, presumably so partial results survive an interrupted run.
        ml_utils.save_sharpnesses_dict(sharpnesses)

# Without retraining, replace the in-memory dict with the previously saved one.
if not RETRAIN_FLAG:
    sharpnesses = ml_utils.load_sharpnesses_dict()

In [None]:
visualization.plot_sharpness(batch_sizes, sharpnesses, key, learning_rates)

#### Distance from initial weights

In [None]:
# Per the helper's name: how far each model's weights moved from the snapshot in `initial_weights`.
visualization.plot_distance_from_initial_weight(models, initial_weights, batch_sizes, key, learning_rates)

#### Runtime

In [None]:
# Timing statistics per batch size, judging by the unpacked names.
mean_times, convergence_epochs, overall_training_times = visualization.extract_times_for_batch_sizes(models_states,
                                                                                                     batch_sizes, key)

In [None]:
visualization.plot_mean_time_per_epoch(batch_sizes, mean_times)

In [None]:
visualization.histogram_num_of_train_epochs_until_conv(batch_sizes, convergence_epochs)

In [None]:
visualization.histogram_overall_time_until_end_of_epochs(batch_sizes, overall_training_times)

In [None]:
# Save to Drive in case we run on Google Colab
# (copies the local graphs/ and tmp/ folders into MyDrive).
if USE_COLAB:
    !cp -r /content/graphs/ /content/drive/MyDrive/
    !cp -r /content/tmp/ /content/drive/MyDrive/

### ADAM

In [None]:
learning_rates = [0.001, 0.002, 0.004, 0.008, 0.016, 0.032]
batch_sizes = [32, 64, 128, 256, 512, 1024]
training_epochs = EPOCHS

key = ('increasing', 'adam', 'none')

In [None]:
state_filename = 'model_diff_lr_diff_bs_adam_state'

In [None]:
if not RETRAIN_FLAG:
    for batch_size, lr in zip(batch_sizes, learning_rates):
        # Read training data
        train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                          batch_size=batch_size,
                                                                          img_size=32)
        optimizer = keras.optimizers.Adam(learning_rate=lr)
        model = build_and_load_model_weights(train, optimizer,
                                             filename_bs_lr_opt_sam(batch_size, lr, "ADAM", "NONE"))
        models[key + (batch_size, lr)] = model

In [None]:
if RETRAIN_FLAG:
    models_states[key] = {}
    for batch_size, lr in zip(batch_sizes, learning_rates):
        with tf.distribute.MirroredStrategy().scope():
            # Read training data
            train, validation = data_loading.load_batched_and_resized_dataset(dataset_name=dataset_name,
                                                                              batch_size=batch_size,
                                                                              img_size=32)

            optimizer = keras.optimizers.Adam(learning_rate=lr)
            model = tf_models.build_and_compile_simple_cnn(train, optimizer)
            train_callbacks = callback_for_filename(filename_bs_lr_opt_sam(batch_size, lr, "ADAM", "NONE"))

            models_states[key][batch_size] = ml_utils.train_model(model, train, validation, epochs=training_epochs,
                                                                  extra_callbacks=train_callbacks, verbose=1)
            models[key + (batch_size, lr)] = model
            initial_weights[key + (batch_size, lr)] = model.get_weights()
            ml_utils.save_initial_weights_dict(initial_weights)

In [None]:
# Persist (when retraining) or restore (otherwise) the per-batch-size results returned by
# ml_utils.train_model for the current `key`, using `state_filename` as the storage name.
if RETRAIN_FLAG:
    ml_utils.save_model_state(models_states[key], state_filename)
else:
    models_states[key] = ml_utils.load_model_state(state_filename)

In [None]:
# Loss per batch size for plain Adam; the last argument looks like the figure's file stem — TODO confirm.
visualization.plot_loss_by_param(models_states[key], 'batch size with increasing learning rate with ADAM',
                                 'diff_lr_diff_bs_ADAM_NONE')

#### Sharpness

In [None]:
# Compute the sharpness of every (batch_size, lr) model. These models are not SAM-wrapped,
# so the model itself (not `.base_model`) is passed.
# NOTE(review): `train` is the dataset from the last training-loop iteration (batch size 1024),
# so all models are evaluated on that same dataset — confirm this is intended.
if RETRAIN_FLAG:
    for batch_size, lr in zip(batch_sizes, learning_rates):
        model = models[key + (batch_size, lr)]

        sharpness = visualization.get_sharpness(model, train)
        sharpnesses[key + (batch_size, lr)] = sharpness

        # Saved after every model, presumably so partial results survive an interrupted run.
        ml_utils.save_sharpnesses_dict(sharpnesses)

# Without retraining, replace the in-memory dict with the previously saved one.
if not RETRAIN_FLAG:
    sharpnesses = ml_utils.load_sharpnesses_dict()

In [None]:
visualization.plot_sharpness(batch_sizes, sharpnesses, key, learning_rates)

#### Distance from initial weights

In [None]:
# Per the helper's name: how far each model's weights moved from the snapshot in `initial_weights`.
visualization.plot_distance_from_initial_weight(models, initial_weights, batch_sizes, key, learning_rates)

#### Runtime

In [None]:
# Timing statistics per batch size, judging by the unpacked names.
mean_times, convergence_epochs, overall_training_times = visualization.extract_times_for_batch_sizes(models_states,
                                                                                                     batch_sizes, key)

In [None]:
visualization.plot_mean_time_per_epoch(batch_sizes, mean_times)

In [None]:
visualization.histogram_num_of_train_epochs_until_conv(batch_sizes, convergence_epochs)

In [None]:
visualization.histogram_overall_time_until_end_of_epochs(batch_sizes, overall_training_times)

In [None]:
# Combined view of sharpness and overall training time for this configuration.
visualization.plot_sharpness_times_runtime(batch_sizes, overall_training_times, sharpnesses, key, learning_rates)