In [None]:
"""
experiments.ipynb

File for running lottery ticket experiments.

Authors: Jordan Bourdeau, Casey Forey
Date Created: 3/8/24
"""

%load_ext tensorboard
import copy
import functools
from importlib import reload
import numpy as np
import os
import random
import tempfile
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow import keras
from keras.callbacks import Callback
from keras import backend as K
from keras import Sequential
from keras.layers import Dense, Input
from keras.losses import CategoricalCrossentropy

from tensorflow_model_optimization.sparsity import keras as sparsity
from tensorflow_model_optimization.sparsity.keras import ConstantSparsity, PolynomialDecay, prune_low_magnitude

import src.harness.constants as C
import src.harness.dataset as dataset
import src.harness.experiment as experiment
import src.harness.model as mod
import src.harness.paths as paths
import src.harness.pruning as prune
import src.harness.training as train
import src.harness.utils as utils

In [None]:
# Report how many GPUs TensorFlow can see, so it is obvious whether training
# below will run on an accelerator or fall back to the CPU.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

## Parameters

In [None]:
reload(mod)
reload(prune)

X_train, X_test, Y_train, Y_test = dataset.load_and_process_mnist()

# Epoch count and batch size must match what the training cell actually uses:
# `train.train` is called with `batch_size=C.BATCH_SIZE` and progress is reported
# out of `C.TRAINING_EPOCHS`. Otherwise the pruning schedule's `end_step` (number
# of optimizer steps) does not line up with the real training run.
num_epochs: int = C.TRAINING_EPOCHS
input_shape: tuple = X_train[0].shape
num_classes: int = 10
batch_size: int = C.BATCH_SIZE

num_train_samples: int = X_train.shape[0]

# Make pruning parameters
target_sparsity: float = 0.01
step_pruning_percent: float = 0.2
# Total optimizer steps = (batches per epoch) * (number of epochs).
steps_per_epoch: int = int(np.ceil(num_train_samples / batch_size))
end_step: int = steps_per_epoch * num_epochs
frequency: int = 1
# Positional args are (target_sparsity, begin_step=0, end_step, frequency) —
# presumably; confirm against prune.create_pruning_parameters.
create_pruning_parameters: callable = functools.partial(prune.create_pruning_parameters, target_sparsity, 0, end_step, frequency)

## Building

## Training

In [None]:
reload(mod)
reload(paths)
reload(train)

make_lenet: callable = functools.partial(mod.create_lenet_300_100, input_shape, num_classes)

# Hardcoded for now until actual experiment code is done
num_models: int = 2         # one freshly initialized model per seed
num_pruning_steps: int = 2  # training rounds per model (no pruning is applied yet)

for seed in range(num_models):
    utils.set_seed(seed)
    # Make models
    model: keras.Model = make_lenet()
    mask_model: keras.Model = mod.create_masked_nn(mod.create_pruned_lenet, input_shape, num_classes, create_pruning_parameters)

    mod.save_model(model, seed, 0, initial=True)

    # We aren't actually doing any pruning right now.
    # NOTE(review): `model` / `mask_model` are not replaced by the outputs of
    # `train.train`, so each step starts from whatever state `train.train` leaves
    # them in — confirm this is intended.
    for pruning_step in range(num_pruning_steps):
        loss_fn: tf.keras.losses.Loss = C.LOSS_FUNCTION()
        accuracy_metric: tf.keras.metrics.Metric = tf.keras.metrics.CategoricalAccuracy()

        # Sanity Check
        test_loss, test_accuracy = train.test_step(model, X_test, Y_test, loss_fn, accuracy_metric)
        print(f'\nModel {seed}\nStarting Test Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}\n')

        trained_model, pruned_mask_model, training_round = train.train(seed, pruning_step, model, mask_model, dataset.load_and_process_mnist, batch_size=C.BATCH_SIZE)

        # Epochs that produced a non-zero test accuracy are the ones actually run.
        print(f'\nTook {np.sum(training_round.test_accuracies != 0)} / {C.TRAINING_EPOCHS} epochs')
        print(f'Ended with a best training accuracy of {np.max(training_round.train_accuracies) * 100:.2f}% and test accuracy of {np.max(training_round.test_accuracies) * 100:.2f}%')

        print('Test Accuracies:')
        print(training_round.test_accuracies)
        print('Training Accuracies:')
        print(training_round.train_accuracies)

        # Get test parameters. Use a fresh metric so the final accuracy cannot be
        # blended with the sanity-check evaluation above (in case `train.test_step`
        # does not reset metric state itself — TODO confirm).
        accuracy_metric = tf.keras.metrics.CategoricalAccuracy()
        test_loss, test_accuracy = train.test_step(trained_model, X_test, Y_test, loss_fn, accuracy_metric)
        print(f'\nTest Loss: {test_loss:.6f}, Test Accuracy: {test_accuracy:.6f}')