# Import

In [None]:
import os
import numpy as np
from sklearn.model_selection import train_test_split
import tensorflow as tf
from modules_demo import models
from matplotlib import pyplot as plt
import copy
import pickle as pkl

# restrict CUDA to the listed device id(s)
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

# seed used by the random generators and by the train/validation split
SEED = 5435

### Set SEED

In [None]:
# Seed Python hashing, NumPy and TensorFlow for reproducible results.
# NOTE(review): PYTHONHASHSEED is read when an interpreter starts, so setting
# it here presumably only affects processes launched afterwards - confirm.
os.environ['PYTHONHASHSEED'] = str(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# prepare GPU: enable on-demand memory growth on every physical GPU
physical_gpus = tf.config.list_physical_devices('GPU')

try:
    # with no GPU available the loop body simply never runs
    for _gpu in physical_gpus:
        tf.config.experimental.set_memory_growth(_gpu, True)
except RuntimeError as e:
    # report the problem and go on with the default memory policy
    print(e)

logical_gpus = tf.config.list_logical_devices('GPU')
print('number of Logical GPUs:', len(logical_gpus))

## Hyper-Parameters

In [None]:
# Weight of the cross-entropy (H) inside the loss function:
#     loss = loss_weight*H + (1-loss_weight)*regularization_term
# The final pruning intensity is strictly linked to loss_weight:
# the lower the value, the stronger the pruning.
loss_weight = 0.1

# Threshold used at inference time to binarize the probability masks.
threshold_to_binarize_mask_values = 0.5

# Used at training time by the loss "Mask_lp_distance_from_binary":
# values close to 1 promote the final binarization and the pruning strength
# of the probability masks, values close to 0.5 do not.
regularization_binarizer = 0.9

batch_size = 64

# Maximum number of epochs allowed for each training stage (Early Stopping
# terminates a stage sooner if it converges). A stage set to 0 is skipped.
#   - initial train: the model is trained without the pruning masks
#   - pruning train: both the weights and the masks are trained
# For "one-step-train" set 0 and 1000, for "two-step-train" set 1000 and 1000.
max_epochs_initial_train = 0
max_epochs_pruning_train = 1000

# Learning rate of each stage; we suggest a lower one for the pruning
# train, e.g., 10x lower.
learning_rate_initial_train = 1e-3
learning_rate_pruning_train = 1e-4

# Where to save the model weights; if None the saving is skipped.
weights_path = None

## Dataset

In [None]:
# load cifar100
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()

# one-hot encode the labels (100 classes)
y_train = tf.keras.utils.to_categorical(y_train, 100)
y_test = tf.keras.utils.to_categorical(y_test, 100)

# hold out 20% of the training samples as validation set
x_train, x_val, y_train, y_val = train_test_split(
        x_train, y_train, test_size=0.2, random_state=SEED,
    )

# enlarge the images to this spatial size
new_image_shape = (224, 224)


def to_my_tensor_x(x):
    """Convert `x` to a tensor and resize it to `new_image_shape`."""
    return tf.image.resize(tf.convert_to_tensor(x), new_image_shape)


# create tensors
x_train = to_my_tensor_x(x_train)
x_val = to_my_tensor_x(x_val)
x_test = to_my_tensor_x(x_test)

## Retrieve model

In [None]:
# shape of a single (resized) input image
input_shape = np.shape(x_train[0])

# retrieve AlexNet built for that input shape
model = models.model_AlexNet(input_shape)

model.summary()

In [None]:
# Count the weights of the model: each weight tensor contributes the product
# of its dimensions. To count only the trainable weights iterate over
# "model.trainable_weights" instead.
w_tmp_list = model.weights
tot_par_init = sum(
        int(np.prod(tuple(np.shape(w_tmp)), dtype = np.int64))
        for w_tmp in w_tmp_list
    )
print(f'Number of weights of the model BEFORE PRUNING: {tot_par_init}')

### Prepare model for training

In [None]:
# plain classification set-up for the initial train (no pruning masks)
model.compile(
        optimizer = tf.keras.optimizers.Adam(
                learning_rate = learning_rate_initial_train,
            ),
        loss = tf.keras.losses.CategoricalCrossentropy(),
        metrics = ['accuracy'],
    )

# stop on NaN, shrink the learning rate on plateaus of the validation loss
# and stop early (restoring the best weights) when it no longer improves
callback_list = [
        tf.keras.callbacks.TerminateOnNaN(),
        tf.keras.callbacks.ReduceLROnPlateau(
                monitor = 'val_loss',
                factor = 0.2,
                patience = 40,
                # NOTE(review): 1e-5 is the value of the former spelling
                # "10e-6"; the pruning callbacks use 1e-6 - confirm that
                # the difference is intended.
                min_lr = 1e-5,
            ),
        tf.keras.callbacks.EarlyStopping(
                monitor = 'val_loss',
                patience = 100,
                restore_best_weights = True,
            ),
    ]

### Train

In [None]:
# initial train of the model; skipped when no epochs are allowed
if max_epochs_initial_train > 0:
    model.fit(
        x = x_train,
        y = y_train,
        validation_data = (x_val, y_val),
        epochs = max_epochs_initial_train,
        batch_size = batch_size,
        callbacks = callback_list,
        verbose = 2,
    )

### Model performances

In [None]:
# test-set performance of the model before any pruning
model_evaluate = model.evaluate(
        x_test,
        y_test,
        batch_size = 100,
    )

# Pruning

## Prepare model for pruning

In [None]:
# Get a "Model_for_pruning" instance. This class inherits from tf.keras.Model
# and adds the functionalities needed to handle the pruning train.
# The function adds the "mask layers" to the model: each Dense layer of the
# model, apart from the output layer, is coupled with a trainable binary mask.
# The zeros of a binary mask mark the pruned neurons, the ones the kept ones.
model_pruning = models.add_mask_pruning(model = model, model_name = 'model_pruning')

# Notice that the model is now composed of "Functional" layers.
# Under the hood a "Functional" is a tf.keras.Model used as a layer.
# - Each Functional named "sub_model" contains a part of the original model
#   layers, and its last layer is a Dense layer.
# - Every "sub_model" but the last one is followed by a Functional
#   "pruning_mask": it takes the Dense output of the previous sub_model,
#   creates a trainable binary mask of the same dimension and multiplies the
#   input element-wise by the mask (the mask applies the pruning).
#   The output, in general, cannot be pruned.
# - Another output is introduced: the concatenation of the prob_masks,
#   created by the Functional "model_concatenate_mask". It is the input of
#   the regularization term, which regulates the number of 1s/0s in the final
#   masks, hence the intensity of the pruning.
model_pruning.summary(expand_nested=True)

### Dataset for Pruning

In [None]:
# The additional output of the model requires an additional label in the
# dataset. The regularization term does not use it, so a fake all-zero label
# is coupled with each of the already existing ones.
def zeros(y):
    """Return an all-zero label for the second output, one row per item of `y`."""
    return np.zeros((len(y), ) + model_pruning.outputs[1].shape[1:])


y_train_fit = [y_train, zeros(y_train)]
y_val_fit = [y_val, zeros(y_val)]
y_test_fit = [y_test, zeros(y_test)]

### Random Mask Initialization

In [None]:
# Draw the initial prob_masks from a gaussian distribution (mean 0.5,
# standard deviation 0.1). The initialization slightly influences the final
# performances; however, it can modify the final pruning intensity.
init_prob_masks = np.random.normal(
        loc = 0.5,
        scale = 0.1,
        size = np.shape(model_pruning.read_prob_masks()),
    )

# the prob_mask values must stay inside [0, 1]: clip the samples
# just inside that range before writing them into the model
init_prob_masks = np.clip(init_prob_masks, 0.00001, 0.9999)
model_pruning.write_prob_masks(init_prob_masks)

### Prepare for Pruning

In [None]:
def Mask_lp_distance_from_binary(threshold = 0.5, p = 2, name = 'pruning'):
    """
    Build the regularization loss that pushes the prob_mask away from
    "threshold".

    Args:
        threshold: float, the value from where the mask samples are
            pushed away (default 0.5).
        p: the norm degree (default 2).
        name: String, the name of the loss.

    Returns:
        A tf.keras.losses.Loss instance computing
        -mean(|mask-threshold|**p); the labels are ignored.
    """

    class Loss_Mask_lp_distance_from_binary(tf.keras.losses.Loss):
        """
        Regularization term used to binarize the prob_mask,
        it promotes the mask to have values close to 0 or 1. The higher
        "threshold" the more the 0s in the prob_mask.

        regularization = -mean(|mask-threshold|**p)
        """

        def __init__(self, threshold = 0.5, p = 2, name = 'pruning'):
            super().__init__(name = name)
            self.set_attributes(threshold, p)

        def call(self, _, y_pred):
            # the first argument (the label) is deliberately ignored
            y_pred = tf.convert_to_tensor(y_pred)
            diff = tf.math.abs(y_pred - self.threshold)
            # a scalar exponent in the dtype of the predictions broadcasts
            # over the whole mask and avoids dtype mismatches
            diff = tf.pow(diff, tf.cast(self.p, diff.dtype))
            # the negative sign turns "far from threshold" into a lower loss
            measured_sparsity = -tf.reduce_mean(diff)
            return measured_sparsity

        def set_attributes(self, threshold = 0.5, p = 2):
            """Set the threshold and the norm degree used by the loss."""
            self.threshold = threshold
            self.p = p
            return

    return Loss_Mask_lp_distance_from_binary(threshold, p, name = name)

In [None]:
# One loss per model output:
#   output 0 (classification) -> crossentropy
#   output 1 (prob_masks)     -> regularization term
# total loss = loss_weight*crossentropy + (1-loss_weight)*regularization_term
loss = [
    [tf.keras.losses.CategoricalCrossentropy(name = 'entropy')],
    [Mask_lp_distance_from_binary(
        threshold = regularization_binarizer,
        name = 'loss_pruning',
    )],
]

# the accuracy is tracked on the classification output only
metrics = [['accuracy'], []]

regularization_weight = 1 - loss_weight
loss_weights = [loss_weight, regularization_weight]

# setting a lower learning rate helps the convergence of the masks
model_pruning.compile(
    optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate_pruning_train),
    loss = loss,
    metrics = metrics,
    loss_weights = loss_weights,
)

# stop on NaN, shrink the learning rate on plateaus of the validation loss
# and stop early (restoring the best weights) when it no longer improves
callback_pruning_list = [
    tf.keras.callbacks.TerminateOnNaN(),
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor = 'val_loss',
        factor = 0.2,
        patience = 40,
        min_lr = 1e-6,
    ),
    tf.keras.callbacks.EarlyStopping(
        monitor = 'val_loss',
        patience = 100,
        restore_best_weights = True,
    ),
]

In [None]:
# Before training, be sure to call the methods below.
# With verbose = True the names of all the involved layers are printed.
verbose = False

# 1) be sure that pruning is "on"
model_pruning.activate_pruning(verbose = verbose)

# 2) make the masks trainable by disabling the binarization
#    and by activating the stochastic behavior
model_pruning.set_model_for_training(verbose = verbose)

# 3) set the attribute "trainable = True" on the mask layers
model_pruning.set_mask_trainability(
        mask_trainability = True,
        verbose = verbose,
    )

# 4) set the attribute "trainable = True" on all the original layers
model_pruning.set_normal_layers_trainability(True, verbose = verbose)

## Pruning

In [None]:
# pruning train (weights and masks); skipped when no epochs are allowed
if max_epochs_pruning_train > 0:
    model_pruning.fit(
        x_train,
        y_train_fit,
        validation_data = (x_val, y_val_fit),
        epochs = max_epochs_pruning_train,
        batch_size = batch_size,
        callbacks = callback_pruning_list,
        verbose = 2,
    )

    # store the trained weights, unless saving is disabled
    if weights_path is not None:
        model_pruning.save_weights(weights_path)

In [None]:
# reload the stored weights of the model, when a path is configured
if weights_path is not None:
    model_pruning.load_weights(weights_path)

### Visualization

In [None]:
# test-set evaluation right after the pruning train,
# i.e. before the model is switched to inference mode
pruning_training_evaluate = model_pruning.evaluate(
    x_test, y_test_fit, batch_size = 100,
)

In [None]:
# Switch to inference: the stochastic behavior of the masks is disabled
# (they become fixed) and the masks are binarized with a threshold value
# in [0, 1], passed here as a narrow [minval, maxval] interval around it.
model_pruning.set_model_for_inference(
    minval = threshold_to_binarize_mask_values-0.0001,
    maxval = threshold_to_binarize_mask_values+0.0001,
    verbose = verbose,
)

# test-set evaluation with the fixed binary masks
pruning_evaluate = model_pruning.evaluate(
    x_test, y_test_fit, batch_size = 100,
)

In [None]:
# read all the trained prob_masks
prob_masks = model_pruning.read_prob_masks()

# Read one realization of the masks.
# If the model is in "inference mode" the masks are fixed and the result is
# unique; if it is in "training mode" the inference masks may vary.
inference_masks = model_pruning.read_inference_masks(
    threshold_value = 0.5,
    minval = threshold_to_binarize_mask_values-0.0001,
    maxval = threshold_to_binarize_mask_values+0.0001,
)

In [None]:
# for visualization purposes we reshape each mask as a square matrix
# (a 64x64 image requires every mask to hold exactly 4096 values)
mask_reshape = (64, 64)

[init_prob_masks_plot, prob_masks_plot, inference_masks_plot] = [
        [
            np.reshape(p, mask_reshape)
            for p
            in m
        ]
        for m
        in [ init_prob_masks, prob_masks, inference_masks]
    ]

# every i-th column shows the masks associated with dense layer i-th;
# squeeze = False keeps the axes 2-D even when there is a single mask,
# so that each row of axes can always be iterated
fig, axss = plt.subplots(3, len(prob_masks), squeeze = False)
fig.set_figheight(10)
for axs, mask_tmp, title in zip(axss,
                         [init_prob_masks_plot,
                          prob_masks_plot,
                          inference_masks_plot],
                         ['initial random prob_mask',
                          'trained prob mask',
                          'final mask realization'],
                        ):

    for ax, m in zip(axs, mask_tmp):

        ax.imshow(m, cmap = 'gray', vmin = 0, vmax = 1)
        ax.title.set_text(title)

plt.show()
# mean value of the trained prob_masks; note that the summary cell reads the
# mean of the binarized masks as the fraction of REMAINING (not pruned) neurons
print(f'masks mean = {np.mean(prob_masks)}')

In [None]:
# The histogram of the prob masks shows how much they are binarized.
# In general, the more binary they are, the higher the performances, but the
# fewer the levels of pruning one can obtain by varying the threshold that
# binarizes the masks.
all_prob_values = np.ravel(prob_masks_plot)
plt.hist(all_prob_values, 100)
plt.title('histogram all of prob masks')
plt.xlabel('bins')
plt.ylabel('count')
plt.show()

In [None]:
# Summary of the pruning train.
# model_evaluate[1]: "model" is compiled with one loss and the 'accuracy'
# metric, so index 1 is read as the accuracy.
# pruning_evaluate[3]: presumably the accuracy of the classification output,
# listed after the loss values - TODO confirm the order.
# The mean of the binary inference masks is the fraction of kept neurons.
print(
    f'initial model acc = {model_evaluate[1]*100}%',
    f'\npruning model acc = {pruning_evaluate[3]*100}%',
    f'\n\n% of remaining neurons = {100*np.mean(inference_masks)}%',
    f'\n% of pruned neurons = {100-100*np.mean(inference_masks)}%',
)

# Pruned Model

In [None]:
# Build the pruned model: the mask layers are no longer there and the Dense
# layers have been pruned (only the neurons corresponding to the 1s in the
# masks are kept, along with their weights).
model_pruned = model_pruning.return_pruned_model(
    minval = threshold_to_binarize_mask_values - 0.0001,
    maxval = threshold_to_binarize_mask_values + 0.0001,
    model_name = 'pruned_model',
)

In [None]:
# Notice that the model is identical to the original one, except for the
# number of parameters, which is lower (provided that the masks contained
# at least one 0).
model_pruned.summary()

In [None]:
# compile the pruned model so that it can be evaluated
model_pruned.compile(
    optimizer = tf.keras.optimizers.Adam(),
    loss = tf.keras.losses.CategoricalCrossentropy(),
    metrics = ['accuracy'],
)

In [None]:
# test-set performance of the pruned model
model_pruned_evaluate = model_pruned.evaluate(
    x_test, y_test, batch_size = 100,
)

In [None]:
# Final summary: accuracy of the initial model, of the model with masks and
# of the pruned model, followed by the fraction of kept/pruned neurons.
# Index 1 is read as the accuracy of the single-output models; index 3 of
# pruning_evaluate is presumably the accuracy of the classification output
# of the two-output model - TODO confirm the order.
print(
    f'initial model acc = {model_evaluate[1]*100}%',
    f'\npruning model acc = {pruning_evaluate[3]*100}%',
    f'\npruned model acc = {model_pruned_evaluate[1]*100}%',
    
    f'\n\n% of remaining neurons = {100*np.mean(inference_masks)}%',
    f'\n% of pruned neurons = {100-100*np.mean(inference_masks)}%',
)

In [None]:
# file that collects the results of all the runs
file_name = 'for_sujuk.pkl'

# Load the results stored so far; start from an empty list when the file
# is missing, empty or cannot be unpickled.
# NOTE: unpickling can execute arbitrary code, load only files you trust.
try:
    with open(file_name, 'rb') as f:
        results = pkl.load(f)
except (OSError, EOFError, pkl.UnpicklingError):
    results = []

# a run is identified by its weights_path: it is stored only once
update_results = not any(
        r['weights_path'] == weights_path
        for r in results
    )

# param_removed_list and top_1_list are not created by this cell: they must
# already exist in the session, otherwise there is nothing to store
try:
    sweep_lists = (param_removed_list, top_1_list)
except NameError:
    sweep_lists = None

if update_results and sweep_lists is None:
    print('results NOT updated: param_removed_list/top_1_list are not defined')
elif update_results:

    new_results = {
        'param_removed_list': sweep_lists[0],
        'top_1_list': sweep_lists[1],
        'loss_weight': loss_weight,
        'threshold_to_binarize_mask_values': threshold_to_binarize_mask_values,
        'regularization_binarizer': regularization_binarizer,
        'weights_path': weights_path,
    }

    results.append(new_results)
    with open(file_name, 'wb') as f:
        pkl.dump(results, f)

    # reported only once the file has been written and closed
    print('results updated')

## Further Analysis

In [None]:
# explore "accuracy" and "% of pruned neurons" variations
# by changing "threshold_to_binarize_mask_values" 

# all "threshold_to_binarize_mask_values" values to explore
meanval_list = np.arange(-0.05, 1, 0.05)

for meanval in meanval_list:
    
    maxval = meanval + 0.0001
    minval = meanval - 0.0001
    
    # create a pruned model based on prob_masks > meanval
    model_tmp = model_pruning.return_pruned_model(
            minval = minval,
            maxval = maxval,
            model_name = 'pruned_model',
        )
    
    # compile before running "evaluate"
    model_tmp.compile(
            optimizer = tf.keras.optimizers.Adam(),
            loss = tf.keras.losses.CategoricalCrossentropy(),
            metrics = ['accuracy'],
        )
    
    # read the inference masks to compute the % of pruned parameters
    inference_masks = model_pruning.read_inference_masks(
            threshold_value = 0.5,
            minval = minval,
            maxval = maxval,
            silence_alert = True,
        )
    
    # find accuracy of the pruned model
    evaluate_tmp = model_tmp.evaluate(
            x_test,
            y_test,
            batch_size = 100,
            verbose = False,
        )

    
    # find the number of weights of the pruned model
    w_tmp_list = model_tmp.weights
    tot_par_tmp = 0
    for w_tmp in w_tmp_list:
        par_tmp = 1
        for shape_tmp in np.shape(w_tmp):
            par_tmp = par_tmp * shape_tmp
        tot_par_tmp = tot_par_tmp + par_tmp
    
    # compute the % of pruned parameters
    percentage_pruned_param = (tot_par_init - tot_par_tmp)/tot_par_init*100
    
    print()
    print(f'T = {meanval} --- accuracy top 1 = {np.round(evaluate_tmp[1]*100, 5)}%')
    for i, m in enumerate(inference_masks):
        print(f'{i+1}-th Dense layer mantained neurons: {np.sum(m)}')
    print(f'-- % of pruned parameters = {np.round(percentage_pruned_param, 3)}%') 
    print()

In [None]:
# at inference, one can move the threshold to obtain 
# models with different pruning intensities and different accuracy

plt.plot(
        100-np.array(param_removed_list)[np.argsort(top_1_list)],
        np.array(top_1_list)[np.argsort(top_1_list)],
        'x--',
    )
plt.xlabel('% of PRUNED parameters')
plt.ylabel('accuracy top_1');