In [None]:
import tensorflow as tf
import numpy as np
from wr_callback import WeightsReset, PrintEpoch
from matplotlib import pyplot as plt
import time
from datasets import load_dataset, Dataset as D
from utils import plot_history, get_dataset_name, get_csv_filename
from simple_model import make_model
import csv
from reg_configs import reg_configs

## Params

In [None]:
# Dataset selector; passed to load_dataset and get_dataset_name below.
DATASET = D.IMAGENETTE
# Batch size forwarded to load_dataset.
BATCH_SIZE = 32
# Second positional argument of make_model.
# NOTE(review): presumably the width of the penultimate layer; confirm in simple_model.
PENULTIMATE = 512
# Number of epochs used by the model.fit calls below.
EPOCHS = 80

## Dataset init

In [None]:
# Load the train and test datasets; im_shape and num_classes are later passed
# to make_model when building each network.
dataset_train, dataset_test, im_shape, num_classes = load_dataset(DATASET, batch_size = BATCH_SIZE)

## Model init

In [None]:
# Build the baseline network; its initial weights are snapshotted in a later
# cell so the other training runs can start from the same point.
model = make_model(im_shape, PENULTIMATE, num_classes)

In [None]:
# Show the architecture of the freshly built model.
model.summary()

In [None]:
# Snapshot the untrained weights; the later runs restore them with
# model.set_weights so every configuration starts from the same initialization.
model_init_weights = model.get_weights()

## CSV init

In [None]:
# Resolve the results file for this dataset and (re)create it containing only
# the header row; the training cells below append one row per run.
dataset_name = get_dataset_name(DATASET)
csv_file_name = get_csv_filename('compare', dataset_name)

csv_header = ['config', 'best train loss', 'best test loss']
with open(csv_file_name, 'w+', encoding='UTF8', newline='') as results_file:
    csv.writer(results_file).writerow(csv_header)

## Train model with WR

In [None]:
# Initializers attached to the layers listed in layers_for_reset below.
glorot_init = tf.keras.initializers.GlorotNormal()
he_init = tf.keras.initializers.HeNormal()

# Plain categorical cross-entropy is tracked as a metric as well; the
# best-epoch selection further down reads that metric from the history.
model.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=["accuracy", "categorical_crossentropy"],
)

# One entry per layer handed to the WeightsReset callback.
layers_for_reset = [
    dict(layer=model.get_layer(name="dense_1"), rand_lvl=1.0, weights_initializer=glorot_init),
    dict(layer=model.get_layer(name="dense"), rand_lvl=1.0, weights_initializer=he_init),
]

wr = WeightsReset(
    layers_for_reset,
    perform_reset=True,
    collect_stats=False,
    collect_weights=False,
    train_dataset=dataset_train,
)

model_wr_hist = model.fit(
    dataset_train,
    validation_data=dataset_test,
    epochs=EPOCHS,
    callbacks=[wr, PrintEpoch()],
    verbose=0,
)

# Pick the epoch with the lowest validation cross-entropy and report the
# train/test cross-entropy recorded at that epoch.
wr_metrics = model_wr_hist.history
best_epoch_wr = np.argmin(wr_metrics['val_categorical_crossentropy'])
best_train_loss_wr = wr_metrics['categorical_crossentropy'][best_epoch_wr]
best_test_loss_wr = wr_metrics['val_categorical_crossentropy'][best_epoch_wr]

print(f'WR model, best train loss {best_train_loss_wr}, best test loss {best_test_loss_wr}')

# Append this run's row to the comparison CSV.
with open(csv_file_name, 'a', encoding='UTF8', newline='') as results_file:
    csv.writer(results_file).writerow(['WR model', best_train_loss_wr, best_test_loss_wr])

In [None]:
# Visualize the per-epoch metrics recorded during the WR run.
plot_history(model_wr_hist.history)

## Train model without WR

In [None]:
# Start from a clean Keras state, rebuild the same architecture and restore
# the weights snapshotted before the WR run.
tf.keras.backend.clear_session()
model = make_model(im_shape, PENULTIMATE, num_classes)
model.set_weights(model_init_weights)

model.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=["accuracy", "categorical_crossentropy"],
)

# The callback is still attached, but with an empty layer list and
# perform_reset=False.
wr = WeightsReset([], perform_reset=False, train_dataset=dataset_train)

model_nowr_hist = model.fit(
    dataset_train,
    validation_data=dataset_test,
    epochs=EPOCHS,
    callbacks=[wr, PrintEpoch()],
    verbose=0,
)

# Pick the epoch with the lowest validation cross-entropy and report the
# train/test cross-entropy recorded at that epoch.
nowr_metrics = model_nowr_hist.history
best_epoch_nowr = np.argmin(nowr_metrics['val_categorical_crossentropy'])
best_train_loss_nowr = nowr_metrics['categorical_crossentropy'][best_epoch_nowr]
best_test_loss_nowr = nowr_metrics['val_categorical_crossentropy'][best_epoch_nowr]

print(f'NOWR model, best train loss {best_train_loss_nowr}, best test loss {best_test_loss_nowr}')

# Append this run's row to the comparison CSV.
with open(csv_file_name, 'a', encoding='UTF8', newline='') as results_file:
    csv.writer(results_file).writerow(['No WR model', best_train_loss_nowr, best_test_loss_nowr])

In [None]:
# Visualize the per-epoch metrics recorded during the run without WR.
plot_history(model_nowr_hist.history)

## Train model with other regularizations

In [None]:
# Train one model per regularization config and append its best-epoch losses
# to the comparison CSV.
#
# The CSV is opened in append mode for each row (as the WR / no-WR cells do)
# instead of being held open across the whole sweep: every finished config is
# then written and closed immediately, so an interrupted or crashed kernel
# cannot lose rows that were still sitting in the file buffer.

# Pause between configurations, in seconds, to let the GPU cool down.
COOLDOWN_SECONDS = 100

for config in reg_configs:
    # Clean Keras state, same architecture plus the regularization under test,
    # and the same initial weights as the other runs.
    tf.keras.backend.clear_session()
    model = make_model(
        im_shape, PENULTIMATE, num_classes,
        reg = config
    )
    model.set_weights(model_init_weights)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss="categorical_crossentropy",
        metrics=["accuracy", "categorical_crossentropy"],
    )

    # Callback attached with an empty layer list and perform_reset=False.
    wr = WeightsReset([], perform_reset = False, train_dataset = dataset_train)

    print(f'---Config: {config}--')
    model_hist = model.fit(
        dataset_train,
        epochs=EPOCHS,
        validation_data=dataset_test,
        callbacks=[wr,PrintEpoch()],
        verbose=0
    )

    # Select by the cross-entropy metric rather than 'loss', and report the
    # train/test values recorded at the epoch with the lowest validation value.
    best_epoch = np.argmin(model_hist.history['val_categorical_crossentropy'])
    best_train_loss = model_hist.history['categorical_crossentropy'][best_epoch]
    best_test_loss = model_hist.history['val_categorical_crossentropy'][best_epoch]
    print(f'best train loss {best_train_loss}, best test loss {best_test_loss}')
    print('---end---')

    # Persist this config's result right away.
    with open(csv_file_name, 'a', encoding='UTF8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([f'config {config}', best_train_loss, best_test_loss])

    time.sleep(COOLDOWN_SECONDS) # gpu cooler :)