In [None]:
import tensorflow as tf
import numpy as np
from wr_callback import WeightsReset, PrintEpoch
from matplotlib import pyplot as plt
import time
from datasets import load_dataset, Dataset as D
from utils import plot_history, get_dataset_name, get_csv_filename
import csv

from simple_model import make_model

## Params

In [None]:
import random

DATASET = D.IMAGENETTE
BATCH_SIZE = 32
PENULTIMATE = 512  # passed to make_model as its 2nd argument (presumably the width of the penultimate layer)
EPOCHS = 80
SEED = 42  # global RNG seed so repeated runs of the notebook are comparable

# Seed every RNG the experiment may touch (Python, NumPy, TensorFlow) before the
# model is built, so initial weights and resets repeat across runs.
# GPU kernels can still be non-deterministic, so losses may differ slightly.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# One row per weight-reset configuration. Column i is the `rand_lvl` handed to
# WeightsReset for these layers, in this order (see the training cell):
#   [dense_1, dense, conv2d_3, conv2d_2]
# NOTE(review): rand_lvl presumably is the fraction of the layer's weights that
# WeightsReset re-randomises (1.0 = all, 0.0 = none) — confirm in wr_callback.
rand_configs = [
    [1.0, 1.0, 1.0, 1.0],
    [1.0, 1.0, 1.0, 0.5],
    [1.0, 1.0, 0.5, 0.5],
    [1.0, 1.0, 1.0, 0.0],
    [1.0, 1.0, 0.5, 0.0],
    [1.0, 0.5, 0.5, 0.0],
    [1.0, 1.0, 0.0, 0.0],
    [1.0, 0.5, 0.0, 0.0],
    [0.5, 0.5, 0.0, 0.0],
    [1.0, 0.0, 0.0, 0.0],
    [0.5, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0],
]

## Dataset init

In [None]:
# Train/test splits plus the input image shape and class count that
# make_model consumes in the next section.
(dataset_train, dataset_test,
 im_shape, num_classes) = load_dataset(DATASET, batch_size=BATCH_SIZE)

## Model init

In [None]:
# Reset Keras' global state before building the model. Keras auto-names layers
# from a per-session counter, so re-running this cell without a reset would give
# "dense_2", "dense_3", ... instead of "dense", "dense_1", ..., and the
# model.get_layer(name=...) look-ups in the training cell would fail.
# NOTE(review): assumes make_model relies on auto-generated layer names; if it
# names layers explicitly this call is simply harmless.
tf.keras.backend.clear_session()

model = make_model(
    im_shape, PENULTIMATE, num_classes
)

In [None]:
# Show the layer table; the layer names listed here must match the ones the
# training cell looks up with model.get_layer(name=...).
model.summary()

In [None]:
# Snapshot of the starting weights; the training cell restores them with
# model.set_weights(...) before every configuration so all runs share one init.
# NOTE: re-running this cell after training would snapshot *trained* weights —
# re-run the model-init cell first.
model_init_weights = [layer_weights for layer_weights in model.get_weights()]

## Train the model with each weight-reset (WR) configuration

In [None]:
glorot_init = tf.keras.initializers.GlorotNormal()
he_init = tf.keras.initializers.HeNormal()
# NOTE(review): these unseeded initializer instances are shared by several layers
# and by every configuration. Some Keras 2.x releases warn that an unseeded
# initializer instance returns identical values each time it is called with the
# same shape; if WeightsReset calls it once per reset, "random" resets may repeat.
# Confirm against wr_callback and the installed Keras version.

GPU_COOLDOWN_S = 30  # pause between configurations to let the GPU cool down

# Per-configuration results, in the order of rand_configs.
train_loss_values = []          # training loss at the epoch with the lowest val_loss
test_loss_values = []           # lowest val_loss reached
epochs_per_config = []          # number of epochs actually run
training_hist_per_config = []   # full Keras history dict

dataset_name = get_dataset_name(DATASET)
csv_file_name = get_csv_filename('configs', dataset_name)

# Start a fresh CSV containing only the header; one row is appended after each
# configuration so partial results survive an interrupted run.
with open(csv_file_name, 'w', encoding='UTF8', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['config', 'train loss', 'best test loss'])

for config_idx, config in enumerate(rand_configs):
    print(f'---config {config}---')
    # Same starting weights and a brand-new optimizer for every configuration,
    # so no Adam state carries over from the previous run.
    model.set_weights(model_init_weights)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )

    # config[i] is the rand_lvl for the i-th layer below.
    # NOTE(review): these are Keras auto-generated names and depend on the
    # model-init cell having been run exactly once in this session.
    layers_for_reset = [
        {"layer": model.get_layer(name = "dense_1"), "rand_lvl": config[0], "weights_initializer": glorot_init},
        {"layer": model.get_layer(name = "dense"), "rand_lvl": config[1], "weights_initializer": he_init},
        {"layer": model.get_layer(name = "conv2d_3"), "rand_lvl": config[2], "weights_initializer": he_init},
        {"layer": model.get_layer(name = "conv2d_2"), "rand_lvl": config[3], "weights_initializer": he_init}
    ]

    wr = WeightsReset(
        layers_for_reset,
        perform_reset = True, collect_stats = False, collect_weights = False, train_dataset = dataset_train)

    model_hist = model.fit(
        dataset_train,
        epochs=EPOCHS,
        validation_data=dataset_test,
        callbacks=[wr, PrintEpoch()],
        verbose=0
    )
    history = model_hist.history

    # np.argmin would return the first NaN epoch if training diverged at some
    # point, reporting NaN as the "best" loss; skip NaNs unless every epoch is NaN.
    val_losses = np.asarray(history['val_loss'], dtype=float)
    if np.all(np.isnan(val_losses)):
        best_epoch = 0
    else:
        best_epoch = int(np.nanargmin(val_losses))

    train_loss_values.append(history['loss'][best_epoch])
    test_loss_values.append(history['val_loss'][best_epoch])
    epochs_per_config.append(len(history['val_loss']))
    training_hist_per_config.append(history)
    # The train loss reported is the one at the best-validation epoch, not the
    # minimum training loss.
    print(f'train loss @ best epoch = {train_loss_values[-1]}, best test loss = {test_loss_values[-1]}, '
          f'best epoch = {best_epoch + 1}, total epochs = {epochs_per_config[-1]}')
    print('---end---')

    with open(csv_file_name, 'a', encoding='UTF8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([f'{config}', train_loss_values[-1], test_loss_values[-1]])

    # Cool down between runs only; no point sleeping after the last one.
    if config_idx < len(rand_configs) - 1:
        time.sleep(GPU_COOLDOWN_S)

## Plot results

In [None]:
# Train vs. test loss per weight-reset configuration. Both values are taken at
# the epoch with the lowest validation loss of each run (see the training cell).
fig, ax = plt.subplots()
config_ids = list(range(len(train_loss_values)))
ax.plot(config_ids, train_loss_values, marker='o', label='train')
ax.plot(config_ids, test_loss_values, marker='o', label='test')
ax.set_title('Loss at best epoch per weight-reset configuration')
ax.set_xlabel('configuration number')
ax.set_ylabel('loss')
# Configurations are categorical: one integer tick per configuration index.
ax.set_xticks(config_ids)
ax.legend(loc='upper right')
plt.show()