In [None]:
"""
testing.ipynb

Notebook for testing the components used to implement the lottery ticket experiments.

Authors: Jordan Bourdeau, Casey Forey
Date Created: 3/8/24
"""

%load_ext tensorboard
import copy
import functools
from importlib import reload
import numpy as np
import os
import random
import tempfile
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow import keras
from keras.callbacks import Callback
from keras import backend as K
from keras import Sequential
from keras.layers import Dense, Input
from keras.losses import CategoricalCrossentropy

from tensorflow_model_optimization.sparsity import keras as sparsity
from tensorflow_model_optimization.sparsity.keras import ConstantSparsity, PolynomialDecay, prune_low_magnitude

from src.harness.constants import Constants as C
from src.harness.dataset import download_data, load_and_process_mnist
from src.harness.experiment import ExperimentData
from src.harness.model import create_lenet_300_100, create_masked_nn, create_pruned_lenet, save_model, load_model
from src.harness.pruning import create_pruning_callbacks, create_pruning_parameters
from src.harness.training import test_step, train, get_train_one_step, training_loop, TrainingRound
from src.harness.utils import count_params, get_layer_weight_counts, set_seed

In [None]:
# Report how many GPU devices TensorFlow can see (0 means everything runs on the CPU).
print("Num GPUs Available: ", len(tf.config.list_physical_devices(device_type='GPU')))

## Parameters

In [None]:
# Load the preprocessed MNIST train/test splits.
X_train, X_test, Y_train, Y_test = load_and_process_mnist()

# Model / training configuration.
num_epochs: int = 10
input_shape: tuple = X_train[0].shape  # shape of a single example (batch axis dropped by indexing)
num_classes: int = 10
batch_size: int = len(X_train)  # full-batch training: one batch holds the whole training set

num_train_samples: int = X_train.shape[0]

# Pruning configuration.
target_sparsity: float = 0.01
step_pruning_percent: float = 0.2
frequency: int = 1
# Pruning stops after (steps per epoch) * num_epochs, where steps per epoch = ceil(samples / batch_size).
end_step: int = np.ceil(num_train_samples / batch_size).astype(np.int32) * num_epochs
pruning_parameters: dict = create_pruning_parameters(target_sparsity, 0, end_step, frequency)

## Building

In [None]:
# Build the two reference models that every later cell deep-copies from.
# `original_model` comes from `create_lenet_300_100`; `original_mask_model` is produced by
# `create_masked_nn` from `create_pruned_lenet` using the pruning schedule from the Parameters cell.
# NOTE(review): `set_seed(0)` presumably seeds the global RNGs; if so, both initializations depend on
# this exact call order — keep the statements in this order for reproducibility.
set_seed(0)

original_model: keras.Model = create_lenet_300_100(input_shape, num_classes)
# original_model.summary()
# original_model.trainable_variables

original_mask_model: keras.Model = create_masked_nn(create_pruned_lenet, input_shape, num_classes, pruning_parameters)
# original_mask_model.summary()
# original_mask_model.trainable_variables

## Training

In [None]:
# Baseline: score the untrained reference model so later results have a point of comparison.
# The loss function and accuracy metric created here are reused by every evaluation below.
loss_fn: tf.keras.losses.Loss = C.LOSS_FUNCTION()
accuracy_metric: tf.keras.metrics.Metric = tf.keras.metrics.CategoricalAccuracy()

test_loss, test_accuracy = test_step(original_model, X_test, Y_test, loss_fn, accuracy_metric)
print('Test Loss: {:.6f}, Test Accuracy: {:.6f}'.format(test_loss, test_accuracy))

In [None]:
# Check that copy.deepcopy reproduces the reference models exactly.
# The previous version evaluated `original_model` here, so the copy was never actually exercised.
model: keras.Model = copy.deepcopy(original_model)
mask_model: keras.Model = copy.deepcopy(sparsity.strip_pruning(original_mask_model))

# The copy must carry exactly the same weight tensors as the original.
original_weights = original_model.get_weights()
copied_weights = model.get_weights()
assert len(copied_weights) == len(original_weights), 'deepcopy changed the number of weight tensors'
assert all(
    np.array_equal(copied, reference)
    for copied, reference in zip(copied_weights, original_weights)
), 'deepcopy of original_model does not reproduce its weights'

# Evaluate the *copy*; its results should match the baseline printed by the previous cell.
test_loss, test_accuracy = test_step(model, X_test, Y_test, loss_fn, accuracy_metric)
print(f'Test Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}')

In [None]:
# # Test single step of training

# # Define the optimizer outside of the function
# optimizer = C.OPTIMIZER()
# train_one_step: callable = get_train_one_step()
# accuracy_metric.reset_states()

# # Copy originals
# model: keras.Model = copy.deepcopy(original_model)
# mask_model: keras.Model = copy.deepcopy(sparsity.strip_pruning(original_mask_model))

# # Sanity Check
# test_loss, test_accuracy = test_step(model, X_test, Y_test, loss_fn, accuracy_metric)
# print(f'\nTest Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}\n')

# epochs: int = C.TRAINING_EPOCHS

# train_accuracies: np.array = np.zeros(epochs)
# test_accuracies: np.array = np.zeros(epochs)

# for i in range(epochs):
#     train_loss, train_accuracy = train_one_step(model, mask_model, X_train, Y_train, optimizer)
#     train_accuracies[i] = train_accuracy

#     test_loss, test_accuracy = test_step(model, X_test, Y_test, loss_fn, accuracy_metric)
#     test_accuracies[i] = test_accuracy

#     print(f'Iteration {i + 1} Train Loss: {train_loss:.6f}, Train Accuracy: {train_accuracy:.6f}, Test Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}')

# print(f'Test Accuracies:')
# print(test_accuracies)
# print(f'Training Accuracies:')
# print(train_accuracies)

# # Get test parameters
# accuracy_metric.reset_states()
# test_loss, test_accuracy = test_step(model, X_test, Y_test, loss_fn, accuracy_metric)
# print(f'\nTest Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}')

In [None]:
# # Testing `training_loop` function

# epochs: int = C.TRAINING_EPOCHS

# # Copy originals
# model: keras.Model = copy.deepcopy(original_model)
# mask_model: keras.Model = copy.deepcopy(sparsity.strip_pruning(original_mask_model))

# # Sanity Check
# test_loss, test_accuracy = test_step(model, X_test, Y_test, loss_fn, accuracy_metric)
# print(f'\nTest Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}\n')

# trained_model, training_round = training_loop(0, model, mask_model, load_and_process_mnist, epochs, len(X_train))

# print(f'Took {np.sum(training_round.test_accuracies != 0)} / {epochs} epochs')
# print(f'Ended with a best training accuracy of {np.max(training_round.train_accuracies) * 100:.2f}% and test accuracy of {np.max(training_round.test_accuracies) * 100:.2f}%')

# print(f'Test Accuracies:')
# print(training_round.test_accuracies)
# print(f'Training Accuracies:')
# print(training_round.train_accuracies)

# # Get test parameters
# test_loss, test_accuracy = test_step(trained_model, X_test, Y_test, loss_fn, accuracy_metric)
# print(f'\nTest Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}')

In [None]:
# # Testing `train` function

# # Copy originals
# model: keras.Model = copy.deepcopy(original_model)
# mask_model: keras.Model = copy.deepcopy(sparsity.strip_pruning(original_mask_model))

# # Sanity Check
# test_loss, test_accuracy = test_step(model, X_test, Y_test, loss_fn, accuracy_metric)
# print(f'\nTest Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}\n')

# trained_model, pruned_mask_model, training_round = train(0, 0, model, mask_model, load_and_process_mnist, batch_size=len(X_train))

# print(f'\nTook {np.sum(training_round.test_accuracies != 0)} / {C.TRAINING_EPOCHS} epochs')
# print(f'Ended with a best training accuracy of {np.max(training_round.train_accuracies) * 100:.2f}% and test accuracy of training accuracy of {np.max(training_round.test_accuracies) * 100:.2f}%')

# print(f'Test Accuracies:')
# print(training_round.test_accuracies)
# print(f'Training Accuracies:')
# print(training_round.train_accuracies)

# # Get test parameters
# test_loss, test_accuracy = test_step(trained_model, X_test, Y_test, loss_fn, accuracy_metric)
# print(f'\nTest Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}')

In [None]:
# Reload a previously saved model and re-evaluate it on the test set.
# NOTE(review): `load_model(0, 0)` presumably reads a checkpoint written by an earlier `train(0, 0, ...)`
# run; on a fresh kernel that call only happens in a later cell, so this cell may fail — confirm.
loaded_model: keras.Model = load_model(0, 0)
test_loss, test_accuracy = test_step(loaded_model, X_test, Y_test, loss_fn, accuracy_metric)
print('\nTest Loss: {:.6f}, Test Accuracy: {:.6f}'.format(test_loss, test_accuracy))

In [None]:
# Show element [1] of `count_params(model)` (presumably the nonzero-parameter count — confirm).
# NOTE(review): the training cells above are commented out, so `model` is still the untrained
# deep copy made earlier; the "after training" wording only holds if those cells are re-enabled.
print('Nonzero parameters after training but before pruning:')
print(count_params(model)[1])

## Pruning

In [None]:
def create_pruning_parameters(target_sparsity: float, begin_step: int, end_step: int, frequency: int) -> dict:
    """
    Build a pruning-parameter dict holding a constant-sparsity schedule.

    NOTE(review): this redefines (and from here on shadows) the `create_pruning_parameters`
    imported from `src.harness.pruning`; confirm the two stay in sync.

    :param target_sparsity: Fraction of weights (between 0 and 1) the schedule prunes away.
    :param begin_step:      Training step at which pruning starts.
    :param end_step:        Training step at which pruning stops.
    :param frequency:       Number of training steps between successive pruning updates.

    :returns: A dict whose single 'pruning_schedule' entry is a `ConstantSparsity` schedule.
    """
    schedule = sparsity.ConstantSparsity(
        target_sparsity=target_sparsity,
        begin_step=begin_step,
        end_step=end_step,
        frequency=frequency,
    )
    return dict(pruning_schedule=schedule)

def get_pruning_callbacks(
    monitor: str = 'val_loss',
    patience: int = C.PATIENCE,
    minimum_delta: float = C.MINIMUM_DELTA
    ) -> list:
    """
    Build the list of Keras callbacks used while training a pruned model.

    :param monitor:       Quantity watched by early stopping (e.g. 'val_loss').
    :param patience:      Number of epochs without improvement in `monitor` before training stops.
    :param minimum_delta: Smallest change in `monitor` that counts as an improvement.

    :returns: A two-element list: an `UpdatePruningStep` callback followed by an `EarlyStopping` callback.
    """
    step_updater = sparsity.UpdatePruningStep()
    early_stopper = tf.keras.callbacks.EarlyStopping(
        monitor=monitor,
        patience=patience,
        min_delta=minimum_delta,
    )
    # Pruning summaries are currently disabled:
    # sparsity.PruningSummaries(log_dir = logdir, profile_batch=0),
    return [step_updater, early_stopper]

# Each layer's parameters consist of its synaptic connections from the previous layer (weights) and its neurons (biases).


def get_pruning_percents(
        layer_weight_counts: list[int], 
        first_step_pruning_percent: float,
        target_sparsity: float
        ) -> list[float]:
    """
    Compute the sparsity reached at each step of iterative pruning.

    At every step, each layer keeps round(current * (1 - first_step_pruning_percent)) of its
    remaining parameters, i.e. a constant fraction of the nonzero parameters is pruned per step.
    Steps continue while the fraction of parameters still remaining (summed over all layers)
    exceeds `target_sparsity`, or until rounding makes further pruning impossible
    (e.g. round(2 * 0.8) == 2). Previously that stall caused an infinite loop.

    NOTE(review): `target_sparsity` is compared against the *remaining* fraction (a density),
    whereas `ConstantSparsity` interprets a parameter of the same name as the fraction pruned.

    :param layer_weight_counts:        Parameter count per layer; layers with 0 parameters are ignored.
                                       The list is not modified.
    :param first_step_pruning_percent: Fraction (0-1) of each layer's remaining parameters pruned per step.
    :param target_sparsity:            Remaining-parameter fraction at which pruning stops.

    :returns: One entry per pruning step: the unweighted mean over nonempty layers of each layer's
              sparsity ((original - remaining) / original), rounded to 5 decimals.
    """

    def total_sparsity(
            original_weight_counts: list[int], 
            current_weight_counts: list[int]
            ) -> float:
        """
        Fraction of all parameters still remaining (despite the name, this is a density).
        """
        return np.sum(current_weight_counts) / np.sum(original_weight_counts)
    
    def sparsify(
            original_weight_counts: list[int], 
            current_weight_counts: list[int], 
            original_pruning_percent: float
            ) -> float:
        """
        Apply one pruning step in place to `current_weight_counts` and return the
        mean per-layer sparsity afterwards.
        """
        sparsities: list[float] = []
        for idx, (original, current) in enumerate(zip(original_weight_counts, current_weight_counts)):
            if original == 0:
                continue
            new_weight_count = np.round(current * (1 - original_pruning_percent))
            sparsities.append((original - new_weight_count) / original)
            current_weight_counts[idx] = new_weight_count
        return np.round(np.mean(sparsities), decimals=5)

    # Nothing to prune (and avoids a 0/0 in total_sparsity).
    if np.sum(layer_weight_counts) == 0:
        return []

    sparsities: list[float] = []
    # Work on a copy so the caller's list is left untouched.
    current_weight_counts: list = list(layer_weight_counts)

    while total_sparsity(layer_weight_counts, current_weight_counts) > target_sparsity:
        previous_weight_counts = list(current_weight_counts)
        step_sparsity = sparsify(layer_weight_counts, current_weight_counts, first_step_pruning_percent)
        if current_weight_counts == previous_weight_counts:
            # No layer shrank (rounding stall, or a non-positive pruning percent); every further
            # step would be identical, so stop instead of looping forever.
            break
        sparsities.append(step_sparsity)

    return sparsities


In [None]:
# Display the full result of `count_params` for the current `model` as this cell's output.
count_params(model)

In [None]:
# Fresh, stripped deep copy of the reference mask model.
# NOTE(review): redundant — the next cell re-creates `mask_model` the same way before using it.
mask_model: keras.Model = copy.deepcopy(sparsity.strip_pruning(original_mask_model))

In [None]:
# Testing pruning: run `train` on fresh deep copies of the reference models and report the results.

# Copy originals
model: keras.Model = copy.deepcopy(original_model)
mask_model: keras.Model = copy.deepcopy(sparsity.strip_pruning(original_mask_model))

# Sanity check: baseline performance of the untrained copy.
test_loss, test_accuracy = test_step(model, X_test, Y_test, loss_fn, accuracy_metric)
print(f'\nTest Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}\n')

# `batch_size` is the full-batch size configured in the Parameters cell (== len(X_train)).
trained_model, pruned_mask_model, training_round = train(0, 0, model, mask_model, load_and_process_mnist, batch_size=batch_size)

# NOTE(review): assumes epochs that never ran (early stopping) leave a 0 in `test_accuracies`.
print(f'\nTook {np.sum(training_round.test_accuracies != 0)} / {C.TRAINING_EPOCHS} epochs')
print(f'Ended with a best training accuracy of {np.max(training_round.train_accuracies) * 100:.2f}% and test accuracy of {np.max(training_round.test_accuracies) * 100:.2f}%')

print('Test Accuracies:')
print(training_round.test_accuracies)
print('Training Accuracies:')
print(training_round.train_accuracies)

# Final evaluation of the trained model.
test_loss, test_accuracy = test_step(trained_model, X_test, Y_test, loss_fn, accuracy_metric)
print(f'\nTest Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}')

In [None]:
# Assemble the pruning configuration: training callbacks, the constant-sparsity schedule,
# and the per-step sparsity targets derived from the trained model's layer sizes.
callbacks: list[tf.keras.callbacks.Callback] = get_pruning_callbacks()
pruning_parameters: dict = create_pruning_parameters(
    target_sparsity=target_sparsity,
    begin_step=0,
    end_step=end_step,
    frequency=frequency,
)
layer_weight_counts: list[int] = get_layer_weight_counts(trained_model)
pruning_percents: list[float] = get_pruning_percents(
    layer_weight_counts=layer_weight_counts,
    first_step_pruning_percent=step_pruning_percent,
    target_sparsity=target_sparsity,
)
print(pruning_percents)

In [None]:
# Print the Keras layer summary of the current `mask_model`.
# NOTE(review): `train` returned a separate `pruned_mask_model`; confirm which of the two was meant here.
mask_model.summary()

## Experiments

In [None]:
def run_experiment(
    random_seed: int,
    create_model: callable,
    create_masked_model: callable,
    pruning_percents: np.array,
    ) -> ExperimentData:
    """
    Run one lottery-ticket experiment (not implemented yet).

    :param random_seed:         Seed for the experiment (presumably passed to `set_seed` — confirm).
    :param create_model:        Factory for the model to train (presumably like `create_lenet_300_100`).
    :param create_masked_model: Factory for the matching masked model (presumably like `create_masked_nn`).
    :param pruning_percents:    Per-step sparsity schedule, e.g. the list returned by `get_pruning_percents`.

    :returns: The collected experiment data.

    :raises NotImplementedError: Always, until the experiment loop is written.
    """
    # Fail loudly instead of silently returning None, which contradicts the declared return type.
    raise NotImplementedError('run_experiment is not implemented yet')