In [None]:
"""
testing.ipynb

File for performing testing to implement lottery ticket experiments.

Authors: Jordan Bourdeau, Casey Forey
Date Created: 3/8/24
"""

%load_ext tensorboard
import copy
import functools
from importlib import reload
import numpy as np
import os
import random
import tempfile
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow import keras
from keras.callbacks import Callback
from keras import backend as K
from keras import Sequential
from keras.layers import Dense, Input
from keras.losses import CategoricalCrossentropy

import src.harness.constants as C
import src.harness.dataset as dataset
import src.harness.experiment as experiment
import src.harness.model as mod
import src.harness.pruning as pruning
import src.harness.rewind as rewind
import src.harness.training as train
import src.harness.utils as utils

In [None]:
# Report how many GPUs TensorFlow can see before any heavy work starts.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

## Parameters

In [None]:
# Pick up any edits made to the model/pruning modules since they were first imported.
reload(mod)
reload(pruning)

# Load the MNIST train/test splits through the project's dataset helper.
X_train, X_test, Y_train, Y_test = dataset.load_and_process_mnist()

# NOTE(review): later cells read C.TRAINING_EPOCHS / C.BATCH_SIZE rather than
# `num_epochs` / `batch_size` defined here — confirm these are still needed.
num_epochs: int = 10
# Shape of a single training example; passed to the network factory in the Building section.
input_shape: tuple = X_train[0].shape
num_classes: int = 10
# One batch spans the whole training set (same value as `num_train_samples` below).
batch_size: int = len(X_train)

num_train_samples: int = X_train.shape[0]

## Building

In [None]:
# Reload project modules so edits to them take effect without restarting the kernel.
reload(mod)
reload(utils)

# Create a model with the same architecture using all Keras components to check its accuracy with the same parameters
# Seed before building anything (presumably seeds every RNG in use — confirm in utils.set_seed).
utils.set_seed(0)
# Zero-argument factory for the network; reused to build both the model and its mask model.
make_lenet: callable = functools.partial(mod.create_lenet_300_100, input_shape, num_classes)

# Reference model: later cells deep-copy it instead of training it directly.
original_model: keras.Model = make_lenet()
# original_model.summary()
# original_model.trainable_variables

# Companion mask model built from the same factory; a later cell asserts that a copy of
# it holds only ones before any pruning happens.
original_mask_model: keras.Model = mod.create_masked_nn(make_lenet)
original_mask_model.summary()
# original_mask_model.trainable_variables

## Training

In [None]:
# Reload constants and training code so the latest definitions are used.
reload(C)
reload(train)

# Use the original model as a reference
# The loss function and accuracy metric created here are reused by the following cells.
loss_fn: tf.keras.losses.Loss = C.LOSS_FUNCTION()
accuracy_metric: tf.keras.metrics.Metric = tf.keras.metrics.CategoricalAccuracy()

# Baseline: evaluate the untrained reference model on the test split.
test_loss, test_accuracy = train.test_step(original_model, X_test, Y_test, loss_fn, accuracy_metric)
print(f'Test Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}')

In [None]:
reload(train)

# Test that deepcopy is exactly the same as the original model
# Copy originals
model: keras.Model = copy.deepcopy(original_model)
mask_model: keras.Model = copy.deepcopy(original_mask_model)

# The copies must carry exactly the same parameters as their sources.
for source, duplicate in ((original_model, model), (original_mask_model, mask_model)):
    for source_weights, copied_weights in zip(source.get_weights(), duplicate.get_weights()):
        assert np.array_equal(source_weights, copied_weights), 'Deep copy changed the model weights'

# Evaluate the *copy*, not the original: the reference model was already evaluated in the
# previous cell, so these numbers should match the ones printed there.
# Reset the shared metric first so earlier evaluations cannot leak into this reading.
accuracy_metric.reset_states()
test_loss, test_accuracy = train.test_step(model, X_test, Y_test, loss_fn, accuracy_metric)
print(f'Test Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}')

In [None]:
reload(C)
reload(train)

# Exercise the single-step training function by hand, one full-dataset step per epoch.

# The optimizer is built here and handed to the step function on every call.
optimizer = C.OPTIMIZER()
train_one_step: callable = train.get_train_one_step()
accuracy_metric.reset_states()

# Work on fresh copies so the reference models stay untouched.
model: keras.Model = copy.deepcopy(original_model)
mask_model: keras.Model = copy.deepcopy(original_mask_model)

# Sanity check: accuracy of the copy before any training.
test_loss, test_accuracy = train.test_step(model, X_test, Y_test, loss_fn, accuracy_metric)
print(f'\nTest Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}\n')

epochs: int = C.TRAINING_EPOCHS

# Per-epoch accuracy history, filled in as training proceeds.
train_accuracies: np.ndarray = np.zeros(epochs)
test_accuracies: np.ndarray = np.zeros(epochs)

for epoch_number in range(1, epochs + 1):
    slot = epoch_number - 1

    train_loss, train_accuracy = train_one_step(model, mask_model, X_train, Y_train, optimizer)
    train_accuracies[slot] = train_accuracy

    test_loss, test_accuracy = train.test_step(model, X_test, Y_test, loss_fn, accuracy_metric)
    test_accuracies[slot] = test_accuracy

    print(
        f'Epoch {epoch_number} Train Loss: {train_loss:.6f}, Train Accuracy: {train_accuracy:.6f}, '
        f'Test Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}'
    )

print('Test Accuracies:', test_accuracies, sep='\n')
print('Training Accuracies:', train_accuracies, sep='\n')

# Final evaluation with a freshly reset metric.
accuracy_metric.reset_states()
test_loss, test_accuracy = train.test_step(model, X_test, Y_test, loss_fn, accuracy_metric)
print(f'\nTest Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}')

In [None]:
# Reload constants and training code so the latest definitions are used.
reload(C)
reload(train)

# Testing `training_loop` function
epochs: int = C.TRAINING_EPOCHS

# Copy originals
# Fresh deep copies keep `original_model` / `original_mask_model` untrained for later cells.
model: keras.Model = copy.deepcopy(original_model)
mask_model: keras.Model = copy.deepcopy(original_mask_model)

# Sanity Check
# NOTE(review): `accuracy_metric` is not reset in this cell; confirm whether
# train.test_step resets it or whether state from earlier cells carries over.
test_loss, test_accuracy = train.test_step(model, X_test, Y_test, loss_fn, accuracy_metric)
print(f'\nTest Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}\n')

# First argument is 0 (presumably the pruning step or seed — confirm against train.training_loop).
# The dataset loader itself is passed in, not the already-loaded arrays.
training_round = train.training_loop(0, model, mask_model, dataset.load_and_process_mnist, epochs)

# Counts the epochs whose recorded test accuracy is nonzero (presumably unreached epochs
# stay at zero, e.g. after early stopping — confirm).
print(f'Took {np.sum(training_round.test_accuracies != 0)} / {epochs} epochs')
print(f'Ended with a best training accuracy of {np.max(training_round.train_accuracies) * 100:.2f}% and test accuracy of {np.max(training_round.test_accuracies) * 100:.2f}%')

print(f'Test Accuracies:')
print(training_round.test_accuracies)
print(f'Training Accuracies:')
print(training_round.train_accuracies)

# Get test parameters
test_loss, test_accuracy = train.test_step(model, X_test, Y_test, loss_fn, accuracy_metric)
print(f'\nTest Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}')

In [None]:
# Testing `train` function

# Copy originals
# Fresh deep copies keep the reference models untrained for later cells.
model: keras.Model = copy.deepcopy(original_model)
mask_model: keras.Model = copy.deepcopy(original_mask_model)

# Sanity Check: accuracy of the copy before any training.
test_loss, test_accuracy = train.test_step(model, X_test, Y_test, loss_fn, accuracy_metric)
print(f'\nTest Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}\n')

# The two leading zeros are presumably the seed and pruning step — the next cell reloads
# the result with mod.load_model(0, 0). Confirm against train.train.
training_round = train.train(0, 0, model, mask_model, dataset.load_and_process_mnist, batch_size=C.BATCH_SIZE)

# Counts the epochs whose recorded test accuracy is nonzero.
print(f'\nTook {np.sum(training_round.test_accuracies != 0)} / {C.TRAINING_EPOCHS} epochs')
# Message fixed: it previously read "test accuracy of training accuracy of".
print(f'Ended with a best training accuracy of {np.max(training_round.train_accuracies) * 100:.2f}% and test accuracy of {np.max(training_round.test_accuracies) * 100:.2f}%')

print('Test Accuracies:')
print(training_round.test_accuracies)
print('Training Accuracies:')
print(training_round.train_accuracies)

# Get test parameters
test_loss, test_accuracy = train.test_step(model, X_test, Y_test, loss_fn, accuracy_metric)
print(f'\nTest Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}')

In [None]:
# Test loading the model back
# Arguments (0, 0) mirror the leading arguments given to train.train in the previous cell
# (presumably seed and pruning step — confirm against mod.load_model).
loaded_model: keras.Model = mod.load_model(0, 0)

# Get test parameters
# Evaluate the reloaded model on the test split for comparison with the in-memory model.
test_loss, test_accuracy = train.test_step(loaded_model, X_test, Y_test, loss_fn, accuracy_metric)
print(f'\nTest Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}')

In [None]:
# Baseline parameter count for the trained, still-unpruned model.
param_counts = utils.count_total_and_nonzero_params(model)
print('Nonzero parameters after training but before pruning:')
# The second element of the returned pair is the nonzero count.
print(param_counts[1])

## Pruning

In [None]:
# Test loading the model back
loaded_model: keras.Model = mod.load_model(0, 0)
target_sparsity = 0.5

# Get test parameters
test_loss, test_accuracy = train.test_step(loaded_model, X_test, Y_test, loss_fn, accuracy_metric)
print(f'\nTest Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}')

# Prune a copy so the loaded model stays available as the "before" reference.
sparse_model = copy.deepcopy(loaded_model)

total, nonzero = utils.count_total_and_nonzero_params(loaded_model)
print(f'Total Params: {total}, Nonzero Params: {nonzero}')
print('Pruning')
pruning.prune(sparse_model, pruning.low_magnitude_pruning, target_sparsity)

# Report the counts of the *pruned* model; the original printed the pre-pruning
# counts a second time, which hid the effect of pruning.
total2, nonzero2 = utils.count_total_and_nonzero_params(sparse_model)
print(f'Total Params: {total2}, Nonzero Params: {nonzero2}')

# Pruning zeroes parameters but must not remove them.
assert total2 == total
# With target_sparsity = 0.5, half of the nonzero parameters are expected to remain.
assert nonzero2 == nonzero // 2

In [None]:
reload(mod)
reload(rewind)
    
model: keras.Model = copy.deepcopy(original_model)
mask_model: keras.Model = copy.deepcopy(original_mask_model)
original_weights = model.get_weights()

rewind_to_original_weights: callable = functools.partial(rewind.rewind_to_original_init, 0)
rewind_to_random_weights: callable = functools.partial(rewind.rewind_to_random_init, 0, tf.initializers.GlorotUniform())
rewind.rewind_model_weights(model, mask_model, rewind_to_random_weights)

print(original_weights[0][0][:10])
print(model.get_weights()[0][0][:10])

In [None]:
reload(mod)
reload(pruning)

model: keras.Model = copy.deepcopy(original_model)
mask_model: keras.Model = copy.deepcopy(original_mask_model)


def _assert_masks_all_ones(masks: keras.Model, when: str) -> None:
    """Fail if any weight array in any layer of `masks` holds a value other than 1.

    `when` ("before" or "after") names the stage relative to the mask update, so a
    failure message points at the right step.
    """
    for layer in masks.layers:
        for weights in layer.get_weights():
            assert np.all(weights == 1), f"Error: Not all elements in mask model's weights are 1s {when} updating masks"


# Asserting that every array in the mask model's weights are 1s
# (the original message for this first check wrongly said "after updating masks").
_assert_masks_all_ones(mask_model, 'before')

pruning.update_masks(model, mask_model)

# Asserting that every array in the mask model's weights are still 1s:
# nothing has been pruned yet, so updating the masks must not zero anything.
_assert_masks_all_ones(mask_model, 'after')

# Prune the model copy to 50% sparsity with low-magnitude pruning.
pruning.prune(model, pruning.low_magnitude_pruning, 0.5)

In [None]:
reload(mod)
reload(pruning)
reload(utils)

# Pruning Parameters
first_step_pruning: float = 0.2
target_sparsity: float = 0.8
make_lenet: callable = functools.partial(mod.create_lenet_300_100, input_shape, num_classes)

global_pruning: bool = False

# Hardcoded for now until actual experiment code is done
for seed in range(1):
    utils.set_seed(seed)
    # Make models and save them
    model: keras.Model = make_lenet()
    mask_model: keras.Model = mod.create_masked_nn(make_lenet)
    mod.save_model(model, seed, 0, initial=True)
    mod.save_model(mask_model, seed, 0, masks=True, initial=True)

    # Create the rewind rule
    rewind_to_original_weights: callable = functools.partial(rewind.rewind_to_original_init, seed)

    # Get the pruning percents and iterate over them
    pruning_percents: list[float] = pruning.get_sparsity_percents(model, first_step_pruning, target_sparsity)
    for pruning_step, sparsity in enumerate(pruning_percents):
        # Retrieve the initial model saved above.
        # NOTE(review): the value is not used below, and the annotation now follows the other
        # load_model cells, which treat the return as a keras.Model — confirm that holds
        # with initial=True.
        original_weights: keras.Model = mod.load_model(seed, initial=True)

        # Prune the model to the new sparsity
        pruning.prune(model, pruning.low_magnitude_pruning, sparsity, global_pruning=global_pruning)

        # Update mask model
        pruning.update_masks(model, mask_model)

        # Reset unpruned weights to original values.
        rewind.rewind_model_weights(model, mask_model, rewind_to_original_weights)

        # A fresh loss and metric per round, so no metric state carries over between rounds.
        loss_fn: tf.keras.losses.Loss = C.LOSS_FUNCTION()
        accuracy_metric: tf.keras.metrics.Metric = tf.keras.metrics.CategoricalAccuracy()

        # Sanity Check: accuracy right after pruning and rewinding, before retraining.
        test_loss, test_accuracy = train.test_step(model, X_test, Y_test, loss_fn, accuracy_metric)
        print(f'\nModel {seed}\nStarting Test Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}\n')

        training_round = train.train(seed, pruning_step, model, mask_model, dataset.load_and_process_mnist, batch_size=C.BATCH_SIZE)

        print(f'\nTook {np.sum(training_round.test_accuracies != 0)} / {C.TRAINING_EPOCHS} epochs')
        # Message fixed: it previously read "test accuracy of training accuracy of".
        print(f'Ended with a best training accuracy of {np.max(training_round.train_accuracies) * 100:.2f}% and test accuracy of {np.max(training_round.test_accuracies) * 100:.2f}%')

## Experiments

In [None]:
# Reload project modules so recent edits are picked up.
reload(experiment)
reload(pruning)
reload(rewind)
reload(train)

# Pruning Parameters
first_step_pruning: float = 0.2
target_sparsity: float = 0.05
make_lenet: callable = functools.partial(mod.create_lenet_300_100, input_shape, num_classes)

# NOTE(review): `global_pruning` is assigned but not passed to anything in this cell.
global_pruning: bool = False
# NOTE(review): `model` is not created in this cell; it is whatever earlier cells left in
# the kernel (most recently the manually pruned model). Confirm get_sparsity_percents does
# not depend on the model's current weights, or build a fresh model here.
sparsities: list[float] = pruning.get_sparsity_percents(model, first_step_pruning, target_sparsity)
# Run the full experiment; the leading 0 is presumably the seed — confirm against
# experiment.run_iterative_pruning_experiment.
experiment_data: experiment.ExperimentData = experiment.run_iterative_pruning_experiment(
    0,
    make_lenet,
    dataset.load_and_process_mnist,
    sparsities
)

In [None]:
reload(experiment)
reload(train)
reload(utils)

# Print the sparsity recorded for each pruning round of the experiment.
# The loop variable is deliberately not called `round`: a notebook-level variable with that
# name would shadow the builtin round() for the rest of the kernel session.
for step_index in range(len(sparsities)):
    pruning_round = experiment_data.pruning_rounds[step_index]
    print(f'{pruning_round.get_sparsity():.3f}%')