In [None]:
"""
testing.ipynb

File for performing testing to implement lottery ticket experiments.

Authors: Jordan Bourdeau, Casey Forey
Date Created: 3/8/24
"""

%load_ext tensorboard
import copy
import functools
from importlib import reload
import numpy as np
import os
import pickle
import tensorflow as tf
from tensorflow import keras

from src.harness import constants as C
from src.harness import dataset as ds
from src.harness import history
from src.harness import experiment
from src.harness import mixins
from src.harness import model as mod
from src.harness import paths
from src.harness import pruning
from src.harness import rewind
from src.harness import training as train
from src.harness import utils

In [None]:
# Report how many GPU devices TensorFlow can see before running anything heavy
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

## Parameters

In [None]:
# Re-import the harness modules so edits made on disk are picked up
for harness_module in (ds, mod, pruning):
    reload(harness_module)

# MNIST is the dataset used by the rest of this notebook
mnist_dataset: ds.Dataset = ds.Dataset(ds.Datasets.MNIST)
X_train, X_test, Y_train, Y_test = mnist_dataset.load()
input_shape: tuple = mnist_dataset.input_shape
num_classes: int = mnist_dataset.num_classes

# Default hyperparameters; the batch size starts out as the whole training set
num_epochs: int = 10
batch_size: int = len(X_train)

# Summarize what was loaded
print(f'Input Shape: {input_shape}')
print(f'Num Classes: {num_classes}')
print(f'X_train Shape: {X_train.shape}, Y_train Shape: {Y_train.shape}')
print(f'X_test Shape: {X_test.shape}, Y_test Shape: {Y_test.shape}')

## Building

In [None]:
# Re-import the harness modules so edits made on disk are picked up
for harness_module in (ds, mod, utils):
    reload(harness_module)

# Seed first so the reference model below is built reproducibly
utils.set_seed(0)

# Zero-argument factory: the architecture builder with this dataset's shape and class count bound
make_lenet: callable = functools.partial(
    mod.create_lenet_300_100,
    input_shape,
    num_classes,
)

# Reference model built from the factory
original_model: keras.Model = make_lenet()
# original_model.summary()
# original_model.trainable_variables

# Companion mask model built from the same factory
original_mask_model: keras.Model = mod.create_masked_nn(make_lenet)
original_mask_model.summary()
# original_mask_model.trainable_variables

## Training

### Initialize Loss Function and Accuracy Objects

In [None]:
# Re-import the harness modules so edits made on disk are picked up
for harness_module in (C, train):
    reload(harness_module)

# Loss and accuracy objects reused by the evaluation calls in later cells
loss_fn: tf.keras.losses.Loss = C.LOSS_FUNCTION()
accuracy_metric: tf.keras.metrics.Metric = tf.keras.metrics.CategoricalAccuracy()

# Baseline evaluation of the reference model on the test split
test_loss, test_accuracy = train.test_step(
    original_model,
    X_test,
    Y_test,
    loss_fn,
    accuracy_metric,
)
print(f'Test Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}')

### Single Step of Training

In [None]:
reload(C)
reload(train)

# Test single step of training

# Define the optimizer outside of the function
optimizer = C.OPTIMIZER()
train_one_step: callable = train.get_train_one_step()
accuracy_metric.reset_state()

# Load a previously saved model and its mask model from disk
# NOTE(review): presumably the arguments are (seed, pruning step), which would require an
# earlier experiment run to have written these files - confirm against `mod.load_model`
model: keras.Model = mod.load_model(0, 1)
mask_model: keras.Model = mod.load_model(0, 1, masks=True)

# Sanity Check (repeated 100 times; each pass prints the loss and accuracy)
for _ in range(100):
    test_loss, test_accuracy = train.test_step(model, X_test, Y_test, loss_fn, accuracy_metric)
    print(f'\nTest Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}\n')

epochs: int = 1
batch_size: int = len(X_train)

train_accuracies: np.ndarray = np.zeros(epochs)
test_accuracies: np.ndarray = np.zeros(epochs)

# Snapshot the weights and masks before training so they can be compared afterwards
original_weights: list[np.ndarray] = copy.deepcopy(model.get_weights())
masks: list[np.ndarray] = copy.deepcopy(mask_model.get_weights())

# Single training steps
for _ in range(10):
    train_loss, train_accuracy = train_one_step(model, mask_model, X_train, Y_train, optimizer)
    test_loss, test_accuracy = train.test_step(model, X_test, Y_test, loss_fn, accuracy_metric)

ending_weights: list[np.ndarray] = copy.deepcopy(model.get_weights())

# Compare the masks with the starting/ending training weights to make sure the ones masked off haven't changed.
# Entry 4 of the weight list is used (a layer near the end, so it is affected by backpropagation), and its
# first row is checked. The previous `[0][4]` indexing looked at row 4 of the first entry instead, which
# contradicted this comment and the `[4][0]` indexing used elsewhere in the notebook.
mask_row = masks[4][0]
start_row = original_weights[4][0]
end_row = ending_weights[4][0]
for idx, (mask, start_weight, end_weight) in enumerate(zip(mask_row, start_row, end_row)):
    if mask == 0:
        assert start_weight == end_weight, f'Masked weight changed at index {idx}'
    # This could technically fail if a weight received an exactly-zero update, but any nonzero update triggers it.
    # The branch previously repeated `mask == 0` and was therefore unreachable.
    elif mask == 1:
        assert start_weight != end_weight, f'Unmasked weight did not change at index {idx}'

### Training Loop Function

In [None]:
# Re-import the harness modules so edits made on disk are picked up
for harness_module in (C, ds, train):
    reload(harness_module)

# Settings for exercising the `training_loop` function
epochs: int = C.TRAINING_EPOCHS
batch_size: int = 60
X_train, _, _, _ = mnist_dataset.load()
num_datapoints: int = X_train.shape[0]

# Work on deep copies so the reference models are left as they are
model: keras.Model = copy.deepcopy(original_model)
mask_model: keras.Model = copy.deepcopy(original_mask_model)

# Evaluate once before training as a sanity check
test_loss, test_accuracy = train.test_step(
    model,
    X_test,
    Y_test,
    loss_fn,
    accuracy_metric,
)
print(f'\nTest Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}\n')

trial_data: history.TrialData = train.training_loop(
    0,
    model,
    mask_model,
    mnist_dataset,
    epochs,
    batch_size=batch_size,
)

# Nonzero entries of the recorded training accuracies are counted as completed iterations
iteration_count: int = np.sum(trial_data.train_accuracies != 0)
print(f'Took {iteration_count} iterations')
print(f'Ended on epoch {np.ceil(iteration_count * batch_size / num_datapoints)} out of {epochs}')
print(f'Ended with a best training accuracy of {np.max(trial_data.train_accuracies) * 100:.2f}% and test accuracy of {np.max(trial_data.test_accuracies) * 100:.2f}%')

### Train Function

In [None]:
reload(train)

# Testing `train` function

# Work on deep copies so the reference models are left as they are
model: keras.Model = copy.deepcopy(original_model)
mask_model: keras.Model = copy.deepcopy(original_mask_model)

# Sanity Check: evaluate before training
test_loss, test_accuracy = train.test_step(model, X_test, Y_test, loss_fn, accuracy_metric)
print(f'\nTest Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}\n')

trial_data = train.train(0, 0, model, mask_model, mnist_dataset, batch_size=C.BATCH_SIZE)

# Nonzero entries of the recorded test accuracies are counted as completed epochs
print(f'\nTook {np.sum(trial_data.test_accuracies != 0)} / {C.TRAINING_EPOCHS} epochs')
# Message previously read "test accuracy of training accuracy of"; now matches the training-loop cell's wording
print(f'Ended with a best training accuracy of {np.max(trial_data.train_accuracies) * 100:.2f}% and test accuracy of {np.max(trial_data.test_accuracies) * 100:.2f}%')

print('Test Accuracies:')
print(trial_data.test_accuracies)
print('Training Accuracies:')
print(trial_data.train_accuracies)

# Evaluate again after training
test_loss, test_accuracy = train.test_step(model, X_test, Y_test, loss_fn, accuracy_metric)
print(f'\nTest Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}')

In [None]:
# Check that a saved model can be read back from disk
loaded_model: keras.Model = mod.load_model(0, 0)

# Evaluate the reloaded model on the test split
test_loss, test_accuracy = train.test_step(
    loaded_model,
    X_test,
    Y_test,
    loss_fn,
    accuracy_metric,
)
print(f'\nTest Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}')

## Pruning

### Layerwise Pruning

In [None]:
# Re-import the harness modules so edits made on disk are picked up
for harness_module in (pruning, utils):
    reload(harness_module)

# Saved model and matching mask model used as inputs to the pruning checks
loaded_model: keras.Model = mod.load_model(0, 0)
mask_model: keras.Model = mod.load_model(0, 0, masks=True)
# NOTE(review): the check calls at the end of this cell pass 0.5 directly rather than this value
target_sparsity = 0.1

def test_pruning_sparsity(model: keras.Model, mask_model: keras.Model, target_sparsity: float):
    """
    Test function to verify pruning correctness.
    NOTE: Sensitive to boundary conditions and rounding.

    Prunes a deep copy of `model` to `target_sparsity`, checks the overall and
    per-layer nonzero parameter counts, then prunes again to half that value.

    Args:
        model (keras.Model): Keras model to copy and test pruning on.
        mask_model (keras.Model): Mask model handed to `pruning.prune`; a deep copy is used.
        target_sparsity (float): Target sparsity to test with.

    Raises:
        AssertionError: If a parameter count falls outside the error tolerance.
    """
    # Operate on deep copies only; the arguments themselves are never passed to `pruning.prune`
    sparse_model: keras.Model = copy.deepcopy(model)
    mask_model = copy.deepcopy(mask_model)

    print(f'Test Pruning Sparsity: Target Sparsity = {target_sparsity}')

    total, nonzero = utils.count_total_and_nonzero_params(model)
    print(f'Before Pruning: Total Params: {total}, Nonzero Params: {nonzero}')
    pruning.prune(sparse_model, mask_model, pruning.low_magnitude_pruning, target_sparsity)

    # Add some small wiggle room for rounding- even with output being pruned at half the rate this is correct
    error_tolerance: int = int(target_sparsity * total / 20)

    pruned_total, pruned_nonzero = utils.count_total_and_nonzero_params(sparse_model)
    print(f'After Pruning:  Total Params: {pruned_total}, Nonzero Params: {pruned_nonzero}')

    assert pruned_total == total
    assert np.abs(pruned_nonzero - total * target_sparsity) < error_tolerance

    sparse_layer_weight_counts: list[tuple[int, int]] = utils.count_total_and_nonzero_params_per_layer(sparse_model)
    print(f'Layer total and nonzero weight counts: {sparse_layer_weight_counts}')

    # Test that pruning worked as expected.
    # Entries are consumed in pairs: (synapse counts, bias counts) for each layer.
    for idx in range(0, len(sparse_layer_weight_counts), 2):
        total_synapses, nonzero_synapses = sparse_layer_weight_counts[idx]
        total_biases, nonzero_biases = sparse_layer_weight_counts[idx + 1]
        expected_nonzero = (total_synapses + total_biases) * target_sparsity
        # Parenthesized so the bias count is subtracted along with the synapse count;
        # the previous expression `- nonzero_synapses + nonzero_biases` added it instead.
        actual_nonzero = nonzero_synapses + nonzero_biases
        assert np.abs(expected_nonzero - actual_nonzero) < error_tolerance

    # Test that we can prune the model to half of what it is currently at as well
    target_sparsity /= 2
    total, nonzero = utils.count_total_and_nonzero_params(sparse_model)
    pruning.prune(sparse_model, mask_model, pruning.low_magnitude_pruning, target_sparsity)
    pruned_total, pruned_nonzero = utils.count_total_and_nonzero_params(sparse_model)

    # The tolerance computed for the first target is reused here
    assert np.abs(pruned_nonzero - int(total * target_sparsity)) < error_tolerance
    
def test_global_pruning(model: keras.Model, mask_model: keras.Model, target_sparsity: float):
    """
    Testing function to demonstrate correctness of global pruning.

    Prunes a deep copy of `model` with `global_pruning=True` and checks that the
    overall nonzero parameter count is within tolerance of the target.

    Args:
        model (keras.Model): Keras model to copy and test pruning on.
        mask_model (keras.Model): Mask model handed to `pruning.prune`; a deep copy is used.
        target_sparsity (float): Target sparsity to test with.

    Raises:
        AssertionError: If the nonzero parameter count falls outside the error tolerance.
    """

    print(f'Test Global Pruning: Target Sparsity = {target_sparsity}')

    # Global pruning will not necessarily have equal pruning in each layer, but overall will be correct
    sparse_model = copy.deepcopy(model)
    mask_model = copy.deepcopy(mask_model)

    print(f'Mask model: {mask_model.trainable_variables[0][:10]}')
    print(f'Sparse model: {sparse_model.trainable_variables[0][:10]}')

    # Gather and report the baseline before pruning so the output reads in chronological order
    total, nonzero = utils.count_total_and_nonzero_params(model)
    print(f'Before Pruning: Total Params: {total}, Nonzero Params: {nonzero}')

    pruning.prune(sparse_model, mask_model, pruning.low_magnitude_pruning, target_sparsity, global_pruning=True)

    # Counted once and reused for both the report and the assertion
    pruned_total, pruned_nonzero = utils.count_total_and_nonzero_params(sparse_model)
    print(f'After Pruning:  Total Params: {pruned_total}, Nonzero Params: {pruned_nonzero}')

    # Add some small wiggle room for rounding- even with output being pruned at half the rate this is correct
    error_tolerance: int = int(target_sparsity * total / 20)
    assert np.abs(pruned_nonzero - total * target_sparsity) < error_tolerance

    sparse_layer_weight_counts: list[tuple[int, int]] = utils.count_total_and_nonzero_params_per_layer(sparse_model)
    print(f'Layer total and nonzero weight counts: {sparse_layer_weight_counts}')

    # Show a slice of trainable variable 4 from each model after pruning
    print(f'Mask model: {mask_model.trainable_variables[4][0][:10]}')
    print(f'Sparse model: {sparse_model.trainable_variables[4][0][:10]}')
    
# Run both pruning checks at a target sparsity of 0.5, each preceded by a blank line
for pruning_check in (test_pruning_sparsity, test_global_pruning):
    print()
    pruning_check(loaded_model, mask_model, 0.5)


### Rewinding

In [None]:
# Re-import the harness modules so edits made on disk are picked up
for harness_module in (mod, rewind):
    reload(harness_module)

# Work on deep copies and keep the starting weights for comparison
model: keras.Model = copy.deepcopy(original_model)
mask_model: keras.Model = copy.deepcopy(original_mask_model)
original_weights = model.get_weights()

# Two rewind strategies with their leading arguments bound; only the random one is applied below
rewind_to_original_weights: callable = functools.partial(
    rewind.rewind_to_original_init,
    0,
)
rewind_to_random_weights: callable = functools.partial(
    rewind.rewind_to_random_init,
    0,
    tf.initializers.GlorotUniform(),
)
rewind.rewind_model_weights(model, mask_model, rewind_to_random_weights)

# Print the same slice before and after rewinding
rewound_weights = model.get_weights()
print(original_weights[0][0][:10])
print(rewound_weights[0][0][:10])

### Prune Low Magnitude

In [None]:
reload(mod)
reload(pruning)

# Work on deep copies so the reference models are left as they are
model: keras.Model = copy.deepcopy(original_model)
mask_model: keras.Model = copy.deepcopy(original_mask_model)

# Asserting that every array in the mask model's weights are 1s before pruning.
# This check used to appear twice in a row with nothing changing the masks in between,
# so a single pass is equivalent.
for layer in mask_model.layers:
    for weights in layer.get_weights():
        assert np.all(weights == 1), "Error: Not all elements in mask model's weights are 1s before pruning"

pruning.prune(model, mask_model, pruning.low_magnitude_pruning, 0.5)

## Experiments

In [None]:
# Re-import the harness modules so edits made on disk are picked up
for harness_module in (experiment, pruning, rewind, train):
    reload(harness_module)

# Pruning parameters
first_step_pruning: float = 0.2
target_sparsity: float = 0.10
make_lenet: callable = functools.partial(
    mod.create_lenet_300_100,
    input_shape,
    num_classes,
)

# NOTE(review): `global_pruning` is assigned here but not passed to the experiment call below
global_pruning: bool = False
# NOTE(review): `model` is whatever an earlier cell left in the namespace
sparsities: list[float] = pruning.get_sparsity_percents(
    model,
    first_step_pruning,
    target_sparsity,
)
experiment_data: history.ExperimentData = experiment.run_iterative_pruning_experiment(
    0,
    make_lenet,
    mnist_dataset,
    sparsities,
)

In [None]:
# Load the saved initial model and show a slice of its weights for reference
loaded_model: keras.Model = mod.load_model(0, 0, initial=True)
print(loaded_model.get_weights()[0][0])

# Loop variable is named `pruning_round` so the builtin `round` is not shadowed in the notebook namespace
for pruning_round in experiment_data.pruning_rounds:
    print(pruning_round)
    print(pruning_round.masks[0][0])
    print(pruning_round.initial_weights[0][0])
    print(pruning_round.final_weights[0][0])

In [None]:
# Re-import the harness modules so edits made on disk are picked up
for harness_module in (experiment, paths, pruning, rewind, train):
    reload(harness_module)

# Exercise the high level experiment API
experiment_directory: str = 'testing_experiment'

# Parameter getter with the dataset and two numeric settings bound up front
get_experiment_parameters = functools.partial(
    experiment.get_lenet_300_100_experiment_parameters,
    ds.Datasets.MNIST,
    0.2,
    0.65,
)
experiment_summary: history.ExperimentSummary = experiment.run_experiments(
    1,
    experiment_directory,
    get_experiment_parameters,
    experiment.run_iterative_pruning_experiment,
)

In [None]:
# Number of pruning rounds recorded for the experiment stored under key 0
print(len(experiment_summary.experiments[0].pruning_rounds))

# TODO: Check out why the final weights aren't still masked off
# NOTE(review): this loop rebinds the notebook-level names `experiment_data` and `original_model`
for seed, experiment_data in experiment_summary.experiments.items():
    original_model: keras.Model = mod.load_model(seed, 0, initial=True)
    # Loop variable is named `pruning_round` so the builtin `round` is not shadowed in the notebook namespace
    for pruning_round in experiment_data.pruning_rounds:
        print(f'Pruning Round {pruning_round.pruning_step}')
        # print(f'Original Model:')
        # print(original_model.get_weights()[4][0][:20])
        # print('Masks:')
        # print(mask_model.get_weights()[4][0][:20])
        print('Initial Weights:')
        print(pruning_round.initial_weights[4][0][:20])
        print('Final Weights')
        print(pruning_round.final_weights[4][0][:20])
        # print()
        # break
        

In [None]:
# NOTE(review): this cell repeats the preceding cell verbatim; consider deleting one of the two
# Number of pruning rounds recorded for the experiment stored under key 0
print(len(experiment_summary.experiments[0].pruning_rounds))

# TODO: Check out why the final weights aren't still masked off
# NOTE(review): this loop rebinds the notebook-level names `experiment_data` and `original_model`
for seed, experiment_data in experiment_summary.experiments.items():
    original_model: keras.Model = mod.load_model(seed, 0, initial=True)
    # Loop variable is named `pruning_round` so the builtin `round` is not shadowed in the notebook namespace
    for pruning_round in experiment_data.pruning_rounds:
        print(f'Pruning Round {pruning_round.pruning_step}')
        # print(f'Original Model:')
        # print(original_model.get_weights()[4][0][:20])
        # print('Masks:')
        # print(mask_model.get_weights()[4][0][:20])
        print('Initial Weights:')
        print(pruning_round.initial_weights[4][0][:20])
        print('Final Weights')
        print(pruning_round.final_weights[4][0][:20])
        # print()
        # break
        

In [None]:
# TODO: Check out why the final weights aren't still masked off
# NOTE(review): the file is loaded through `history.ExperimentData`, yet the result is used like the
# `ExperimentSummary` from the high-level API cell (it is iterated via `.experiments`) - confirm the loader class
# NOTE(review): the `.pkl` extension suggests pickle; only load files this project wrote itself
loaded_experiment_summary = history.ExperimentData.load_from(os.path.join('testing_experiment', 'experiment_summary.pkl'))
for seed, experiment_data in loaded_experiment_summary.experiments.items():
    original_model: keras.Model = mod.load_model(seed, 0, initial=True)
    # Loop variable is named `pruning_round` so the builtin `round` is not shadowed in the notebook namespace
    for pruning_round in experiment_data.pruning_rounds:
        print(f'Pruning Round {pruning_round.pruning_step}')
        # Masks saved for this pruning step, loaded for side-by-side comparison with the weights
        mask_model: keras.Model = mod.load_model(seed, pruning_round.pruning_step, masks=True)
        print('Original Model:')
        print(original_model.get_weights()[4][0][:20])
        print('Masks:')
        print(mask_model.get_weights()[4][0][:20])
        print('Initial Weights:')
        print(pruning_round.initial_weights[4][0][:20])
        print('Final Weights')
        print(pruning_round.final_weights[4][0][:20])
        print()
        