In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load in 

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the "../input/" directory.
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input directory and print the full path of every file found.
for root, _subdirs, names in os.walk('/kaggle/input'):
    for path in (os.path.join(root, name) for name in names):
        print(path)

# Any results you write to the current directory are saved as output.

In [None]:
from tensorflow.keras.layers import DepthwiseConv2D, Conv2D, BatchNormalization, Activation, ZeroPadding2D
from tensorflow.keras.layers import MaxPooling2D, GlobalAveragePooling2D, Dense, Dropout, Flatten
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam
from keras.callbacks import ReduceLROnPlateau, CSVLogger, EarlyStopping
from tensorflow import keras
import tensorflow as tf

import zipfile
import os

from keras.datasets import cifar10
from keras.datasets import cifar100
from keras.utils.np_utils import to_categorical

import numpy as np

import pandas as pd

# Train and Validate pruned and unpruned MobileNet models on Cifar10

## Obtain training, validation and test datasets

In [None]:
# Load CIFAR-10, cast the images to float32 scaled by 1/255 and one-hot encode the labels.
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
X_train = X_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

# Hold out the first 10,000 training samples for validation; train on the remainder.
X_val, y_val = X_train[:10000], y_train[:10000]
X_train, y_train = X_train[10000:], y_train[10000:]

## Define Model

In [None]:
def MobileNet(num_classes, alpha, delta):
    """Build a MobileNet-style classifier for 32x32x3 inputs.

    Args:
        num_classes: number of units in the final softmax layer.
        alpha: width multiplier; every convolution (and the hidden Dense layer)
            uses ``int(base_filters * alpha)`` filters/units.
        delta: value passed as ``depth_multiplier`` to every DepthwiseConv2D.

    Returns:
        An uncompiled ``keras.Model`` mapping the image input to class probabilities.
    """
    weight_decay = 1e-4

    def bn_relu(tensor):
        # BatchNormalization followed by ReLU, applied after every convolution.
        tensor = BatchNormalization()(tensor)
        return Activation('relu')(tensor)

    def depthwise(tensor, padding, strides=(1, 1)):
        # 3x3 depthwise convolution + BN + ReLU.
        tensor = DepthwiseConv2D(
            (3, 3),
            strides=strides,
            depth_multiplier=delta,
            padding=padding,
            kernel_initializer='he_normal',
            kernel_regularizer=l2(weight_decay),
            use_bias=False,
        )(tensor)
        return bn_relu(tensor)

    def conv(tensor, base_filters, kernel=(1, 1), padding='same'):
        # Regular (by default 1x1) convolution + BN + ReLU, width scaled by alpha.
        tensor = Conv2D(
            int(base_filters * alpha),
            kernel,
            strides=(1, 1),
            padding=padding,
            kernel_initializer='he_normal',
            kernel_regularizer=l2(weight_decay),
            use_bias=False,
        )(tensor)
        return bn_relu(tensor)

    def pool_dropout(tensor):
        # 3x3 / stride-2 max pooling ('same' padding) followed by 10% dropout.
        tensor = MaxPooling2D(pool_size=(3, 3), strides=2, padding='same', data_format='channels_last')(tensor)
        return Dropout(0.1)(tensor)

    img_input = keras.Input(shape=(32, 32, 3))

    # Stem: the only convolution without an L2 regulariser, then a 'valid' max pool.
    x = Conv2D(int(32 * alpha), (3, 3), strides=1, padding='valid',
               kernel_initializer='he_normal', use_bias=False)(img_input)
    x = bn_relu(x)
    x = MaxPooling2D(pool_size=(3, 3), strides=2, padding='valid', data_format='channels_last')(x)

    # 64-filter stage (note the full 3x3 convolution here instead of a 1x1 one).
    x = depthwise(x, 'same')
    x = conv(x, 64, kernel=(3, 3))
    x = pool_dropout(x)

    # 128-filter stage.
    x = depthwise(x, 'valid')
    x = conv(x, 128)
    x = depthwise(x, 'same')
    x = conv(x, 128)
    x = pool_dropout(x)

    # 256-filter stage.
    x = depthwise(x, 'valid')
    x = conv(x, 256)
    x = depthwise(x, 'same')
    x = conv(x, 256)
    x = pool_dropout(x)

    # 512-filter stage: one entry block followed by five identical blocks.
    x = depthwise(x, 'same')
    x = conv(x, 512, padding='valid')
    for _ in range(5):
        x = depthwise(x, 'same')
        x = conv(x, 512)
    x = pool_dropout(x)

    # 1024-filter stage; the first depthwise convolution uses stride 2.
    x = depthwise(x, 'same', strides=(2, 2))
    x = conv(x, 1024)
    x = depthwise(x, 'same')
    x = conv(x, 1024)

    # Classification head.
    x = GlobalAveragePooling2D()(x)
    x = Flatten()(x)
    x = Dropout(0.1)(x)
    x = Dense(int(1024 * alpha), activation='relu')(x)
    x = Dropout(0.1)(x)
    out = Dense(num_classes, activation='softmax')(x)

    return keras.Model(img_input, out)

## Define custom training loop

In [None]:
def _evaluate(model, dataset, loss_fn, accuracy_metric):
    """Return ``(mean_batch_loss, accuracy)`` of ``model`` over ``dataset`` as Python floats.

    ``accuracy_metric`` is reset before and after use so that no state leaks
    from one evaluation into the next.
    """
    accuracy_metric.reset_states()
    loss_sum, num_batches = 0.0, 0
    for x_batch, y_batch in dataset:
        predictions = model(x_batch, training=False)
        accuracy_metric.update_state(y_batch, predictions)
        loss_sum += loss_fn(y_batch, predictions)
        num_batches += 1
    accuracy = float(accuracy_metric.result())
    accuracy_metric.reset_states()
    return float(loss_sum / num_batches), accuracy


def pruningTrain(model, max_epochs, pruning_schedule, batch_size, weight_threshold, gradient_threshold, alpha, delta):
    """Train ``model`` with a custom loop that periodically prunes weights.

    Epochs are numbered ``1..max_epochs``. Every epoch whose number is a multiple
    of ``pruning_schedule`` is a *pruning* epoch (no training); all other epochs
    are ordinary training epochs followed by a validation pass.

    In a pruning epoch the gradient of the loss over the whole validation set is
    computed, and every trainable weight whose absolute value is below
    ``weight_threshold`` AND whose absolute gradient is below
    ``gradient_threshold`` is set to zero and recorded in a per-variable bitmask.
    Masked weights receive zero gradients afterwards and are re-zeroed after
    every optimizer step, so they stay pruned even when the optimizer carries
    momentum (e.g. Adam).

    The learning rate is multiplied by 0.2 when validation accuracy has not
    improved for 5 epochs.

    Reads the notebook-level arrays ``X_train``, ``y_train``, ``X_val`` and
    ``y_val``; requires ``model`` to be compiled (``model.optimizer`` is used).

    Args:
        model: compiled Keras model to train and prune in place.
        max_epochs: total number of epochs (training + pruning) to run.
        pruning_schedule: prune on every epoch divisible by this value.
        batch_size: mini-batch size for the training and validation datasets.
        weight_threshold: magnitude below which a weight is a pruning candidate.
        gradient_threshold: gradient magnitude below which a candidate is pruned.
        alpha, delta: only used to build the names of the output CSV files.

    Returns:
        None. Training and pruning histories are written to two CSV files in the
        current directory.
    """
    training_dictionary = {'epoch': [],
                           'loss': [],
                           'accuracy': [],
                           'val_loss': [],
                           'val_accuracy': []}

    pruning_dictionary = {'epoch': [],
                          'epoch_pruned': [],
                          'total_pruned': [],
                          'val_loss': [],
                          'val_accuracy': []}

    # NOTE(review): only the data loss is optimised here; the layers' L2
    # regularisation losses (model.losses) are not added, unlike model.fit.
    loss_fn = keras.losses.CategoricalCrossentropy()

    train_acc_metric = keras.metrics.CategoricalAccuracy()
    val_acc_metric = keras.metrics.CategoricalAccuracy()

    weights = model.trainable_weights
    # One bitmask per trainable variable: 1 = still trainable, 0 = pruned.
    bitmask_tensors = [np.ones(w.shape.as_list(), dtype=w.dtype.as_numpy_dtype) for w in weights]
    # Indices of variables that contain at least one pruned entry.
    pruned_layers = set()

    train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
    train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

    val_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val))
    val_dataset = val_dataset.batch(batch_size)

    # Validation accuracy trackers to reduce learning rate on plateau.
    # (The original initialised `max_val_acc_epoch` but read `max_val_epoch`.)
    max_val_acc, max_val_epoch = 0.0, 0
    patience = 5

    # total number of weights pruned thus far
    total_pruned = 0

    # Inclusive upper bound so that exactly `max_epochs` epochs are run.
    for epoch in range(1, max_epochs + 1):
        print('Epoch: {}'.format(epoch))
        if epoch % pruning_schedule == 0:
            # Gradient of the loss over the full validation set in one pass.
            # NOTE(review): this keeps all activations for X_val in memory at once.
            with tf.GradientTape() as tape:
                predictions = model(X_val, training=False)
                loss = loss_fn(y_val, predictions)
            grads = tape.gradient(loss, weights)

            # weights pruned this epoch
            epoch_pruned = 0
            for i, (weight, grad) in enumerate(zip(weights, grads)):
                values = keras.backend.get_value(weight)
                mask = bitmask_tensors[i]
                prev_pruned = int((mask == 0).sum())
                # Prune entries that are both small and have (nearly) stopped moving.
                prunable = np.logical_and(np.abs(values) < weight_threshold,
                                          np.abs(grad.numpy()) < gradient_threshold)
                mask[prunable] = 0
                layer_pruned = int((mask == 0).sum())
                if layer_pruned:
                    pruned_layers.add(i)
                epoch_pruned += layer_pruned - prev_pruned
                # Apply the cumulative mask so every pruned entry is exactly zero.
                weight.assign(values * mask)
            total_pruned += epoch_pruned
            print('Pruning complete')
            print('Weights pruned this epoch: {}, total weights pruned thus far: {}'.format(epoch_pruned, total_pruned))

            # Validation pass on the freshly pruned model (metric is reset inside).
            val_loss, val_acc = _evaluate(model, val_dataset, loss_fn, val_acc_metric)
            print('Validation accuracy after pruning: {}, loss after pruning: {}'.format(val_acc, val_loss))

            pruning_dictionary['epoch'].append(epoch)
            pruning_dictionary['epoch_pruned'].append(epoch_pruned)
            pruning_dictionary['total_pruned'].append(total_pruned)
            pruning_dictionary['val_loss'].append(val_loss)
            # Store a plain float; the original stored a tensor, which ends up
            # as a tensor repr in the CSV.
            pruning_dictionary['val_accuracy'].append(val_acc)
        else:
            train_acc_metric.reset_states()
            total_loss, counter = 0.0, 0
            for x_batch_train, y_batch_train in train_dataset:
                with tf.GradientTape() as tape:
                    predictions = model(x_batch_train, training=True)
                    loss = loss_fn(y_batch_train, predictions)
                total_loss += loss
                grads = tape.gradient(loss, weights)
                # Zero the gradients of pruned entries.
                masked_grads = [g * m for g, m in zip(grads, bitmask_tensors)]
                model.optimizer.apply_gradients(zip(masked_grads, weights))
                # A zero gradient is not enough for optimizers with momentum:
                # their moment estimates still move the weight, so re-zero it.
                for i in pruned_layers:
                    weights[i].assign(weights[i] * bitmask_tensors[i])
                train_acc_metric.update_state(y_batch_train, predictions)
                counter += 1
            train_loss_val = float(total_loss / counter)
            train_acc_val = float(train_acc_metric.result())
            train_acc_metric.reset_states()
            print('Training accuracy over epoch: {}, loss over epoch: {}'.format(train_acc_val, train_loss_val))

            # Run a validation loop at the end of each epoch
            val_loss, val_acc = _evaluate(model, val_dataset, loss_fn, val_acc_metric)
            print('Validation accuracy over epoch: {}, loss over epoch: {}'.format(val_acc, val_loss))

            # Reduce learning rate by a factor of 0.2 if validation accuracy has stagnated
            if val_acc > max_val_acc:
                max_val_acc, max_val_epoch = val_acc, epoch
            elif epoch - max_val_epoch >= patience:
                model.optimizer.lr.assign(model.optimizer.lr * 0.2)
                print('Learning rate now is {}'.format(float(keras.backend.get_value(model.optimizer.lr))))
                max_val_acc, max_val_epoch = val_acc, epoch

            training_dictionary['epoch'].append(epoch)
            training_dictionary['loss'].append(train_loss_val)
            training_dictionary['accuracy'].append(train_acc_val)
            training_dictionary['val_loss'].append(val_loss)
            training_dictionary['val_accuracy'].append(val_acc)

    # save training and pruning history
    run_suffix = '_alpha_' + str(alpha) + '_depth_' + str(delta) + '_pr_' + str(pruning_schedule) + '_wt_' + str(weight_threshold) + '_gt_' + str(gradient_threshold) + '_cifar10.csv'
    training_file = 'mobilenet_pruned_training' + run_suffix
    pruning_file = 'mobilenet_pruned_pruning' + run_suffix
    pd.DataFrame(training_dictionary).to_csv(training_file, index=False)
    pd.DataFrame(pruning_dictionary).to_csv(pruning_file, index=False)

## MobileNet (alpha = 0.5, depth = 4) unpruned

In [None]:
# Unpruned reference model: 10 classes, width multiplier 0.5, depth multiplier 4.
model = MobileNet(num_classes=10, alpha=0.5, delta=4)
model.compile(loss='categorical_crossentropy', optimizer='Adam', metrics=['accuracy'])

# Callbacks for the fit call: learning-rate reduction on plateau and a per-epoch CSV log.
lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1), patience=5, cooldown=0, min_lr=1e-6, verbose=1)
logger = CSVLogger('mobilenet_unpruned_alpha_0.5_depth_4_cifar10.csv')

# Last expression of the cell: shows the total parameter count.
model.count_params()

In [None]:
# `max_epochs` is not assigned by any earlier cell, so on a fresh kernel this
# cell raised NameError. 65 is the value used by the matching pruned
# (alpha = 0.5, depth = 4) run further down.
max_epochs = 65
history_unpruned_10_alpha_0_5_depth_4 = model.fit(
    X_train, y_train,
    epochs=max_epochs,
    validation_data=(X_val, y_val),
    shuffle=True,
    callbacks=[logger, lr_reducer],
)

In [None]:
# Save the unpruned model (without optimizer state), zip it, and report both file sizes.
unpruned_10_alpha_0_5_depth_4_h5 = 'unpruned_cifar10_alpha_0.5_depth_4.h5'
tf.keras.models.save_model(model, unpruned_10_alpha_0_5_depth_4_h5, include_optimizer=False)
unpruned_10_alpha_0_5_depth_4_zip = 'unpruned_10_alpha_0.5_depth_4.zip'
with zipfile.ZipFile(unpruned_10_alpha_0_5_depth_4_zip, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(unpruned_10_alpha_0_5_depth_4_h5)
# The messages previously said "pruned model" although this is the unpruned one.
print('Size of the unpruned model before compression: {} MB'.format(os.path.getsize(unpruned_10_alpha_0_5_depth_4_h5) / float(2**20)))
print('Size of the unpruned model after compression: {} MB'.format(os.path.getsize(unpruned_10_alpha_0_5_depth_4_zip) / float(2**20)))

## MobileNet (alpha = 0.5, depth = 4) pruned

In [None]:
# Fresh model for the pruning run: 10 classes, width multiplier 0.5, depth multiplier 4.
model = MobileNet(num_classes=10, alpha=0.5, delta=4)
model.compile(loss='categorical_crossentropy', optimizer='Adam', metrics=['accuracy'])

In [None]:
# Run configuration for the pruned (alpha = 0.5, depth = 4) model.
alpha, delta = 0.5, 4
max_epochs = 65             # number of epochs to train the model
pruning_schedule = 6        # pruning frequency: every epoch divisible by this value prunes
batch_size = 128            # mini-batch size
weight_threshold = 0.05     # weights whose magnitude is below this are pruning candidates
gradient_threshold = 0.05   # gradients below this indicate the weight has reached its resting value

pruningTrain(
    model,
    max_epochs=max_epochs,
    pruning_schedule=pruning_schedule,
    batch_size=batch_size,
    weight_threshold=weight_threshold,
    gradient_threshold=gradient_threshold,
    alpha=alpha,
    delta=delta,
)

## MobileNet (alpha = 1, depth = 4) unpruned

In [None]:
# Unpruned reference model: 10 classes, width multiplier 1, depth multiplier 4.
model = MobileNet(num_classes=10, alpha=1, delta=4)
model.compile(loss='categorical_crossentropy', optimizer='Adam', metrics=['accuracy'])

# Callbacks: cut the learning rate to 20% when val_accuracy plateaus, and log each epoch to CSV.
lr_reducer = ReduceLROnPlateau(monitor='val_accuracy', factor=0.2, patience=5, cooldown=0, min_lr=1e-6, verbose=1)
logger = CSVLogger('mobilenet_unpruned_alpha_1_depth_4_cifar10.csv')

# Last expression of the cell: shows the total parameter count.
model.count_params()

In [None]:
# Train the unpruned model; `max_epochs` is whatever the most recently run cell set it to.
history_unpruned_10_alpha_1_depth_4 = model.fit(
    X_train, y_train,
    epochs=max_epochs,
    validation_data=(X_val, y_val),
    shuffle=True,
    callbacks=[lr_reducer, logger],
)

In [None]:
# Save the unpruned model (without optimizer state), zip it, and report both file sizes.
unpruned_10_alpha_1_depth_4_h5 = 'unpruned_cifar10_alpha_1_depth_4.h5'
tf.keras.models.save_model(model, unpruned_10_alpha_1_depth_4_h5, include_optimizer=False)
unpruned_10_alpha_1_depth_4_zip = 'unpruned_10_alpha_1_depth_4.zip'
with zipfile.ZipFile(unpruned_10_alpha_1_depth_4_zip, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(unpruned_10_alpha_1_depth_4_h5)
# The messages previously said "pruned model" although this is the unpruned one.
print('Size of the unpruned model before compression: {} MB'.format(os.path.getsize(unpruned_10_alpha_1_depth_4_h5) / float(2**20)))
print('Size of the unpruned model after compression: {} MB'.format(os.path.getsize(unpruned_10_alpha_1_depth_4_zip) / float(2**20)))

## MobileNet (alpha = 1, delta = 4) pruned

In [None]:
# Fresh model for the pruning run: 10 classes, width multiplier 1, depth multiplier 4.
model = MobileNet(num_classes=10, alpha=1, delta=4)
model.compile(loss='categorical_crossentropy', optimizer='Adam', metrics=['accuracy'])

In [None]:
# Run configuration for the pruned (alpha = 1, depth = 4) model.
alpha, delta = 1, 4
max_epochs = 100            # number of epochs to train the model
pruning_schedule = 12       # pruning frequency: every epoch divisible by this value prunes
batch_size = 128            # mini-batch size
weight_threshold = 0.05     # weights whose magnitude is below this are pruning candidates
gradient_threshold = 0.05   # gradients below this indicate the weight has reached its resting value

pruningTrain(
    model,
    max_epochs=max_epochs,
    pruning_schedule=pruning_schedule,
    batch_size=batch_size,
    weight_threshold=weight_threshold,
    gradient_threshold=gradient_threshold,
    alpha=alpha,
    delta=delta,
)

## MobileNet (alpha = 0.5, depth = 2) unpruned

In [None]:
# Unpruned reference model: 10 classes, width multiplier 0.5, depth multiplier 2.
model = MobileNet(num_classes=10, alpha=0.5, delta=2)
model.compile(loss='categorical_crossentropy', optimizer='Adam', metrics=['accuracy'])

# Per-epoch CSV log for this run.
logger = CSVLogger('mobilenet_unpruned_alpha_0.5_depth_2_cifar10.csv')

# Last expression of the cell: shows the total parameter count.
model.count_params()

In [None]:
# Train the unpruned model; `max_epochs` and `lr_reducer` come from previously run cells.
history_unpruned_10_alpha_0_5_depth_2 = model.fit(
    X_train, y_train,
    epochs=max_epochs,
    validation_data=(X_val, y_val),
    shuffle=True,
    callbacks=[lr_reducer, logger],
)

In [None]:
# Save the unpruned model (without optimizer state), zip it, and report both file sizes.
unpruned_10_alpha_0_5_depth_2_h5 = 'unpruned_cifar10_alpha_0.5_depth_2.h5'
tf.keras.models.save_model(model, unpruned_10_alpha_0_5_depth_2_h5, include_optimizer=False)
unpruned_10_alpha_0_5_depth_2_zip = 'unpruned_10_alpha_0.5_depth_2.zip'
with zipfile.ZipFile(unpruned_10_alpha_0_5_depth_2_zip, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(unpruned_10_alpha_0_5_depth_2_h5)
# The messages previously said "pruned model" although this is the unpruned one.
print('Size of the unpruned model before compression: {} MB'.format(os.path.getsize(unpruned_10_alpha_0_5_depth_2_h5) / float(2**20)))
print('Size of the unpruned model after compression: {} MB'.format(os.path.getsize(unpruned_10_alpha_0_5_depth_2_zip) / float(2**20)))

## MobileNet (alpha = 0.5, depth = 2) pruned

In [None]:
# Fresh model for the pruning run: 10 classes, width multiplier 0.5, depth multiplier 2.
model = MobileNet(num_classes=10, alpha=0.5, delta=2)
model.compile(
    loss='categorical_crossentropy',
    optimizer='Adam',
    metrics=['accuracy'],
)

In [None]:
# Run configuration for the pruned (alpha = 0.5, depth = 2) model.
alpha, delta = 0.5, 2
# number of epochs to train the model
max_epochs = 75
# pruning frequency: every epoch divisible by this value is a pruning epoch
pruning_schedule = 6
# mini-batch size
batch_size = 128
# Weight threshold: weights below this value will be pruned from the network at each pruning step
weight_threshold = 0.05
# Gradient threshold: gradients below this threshold indicate that the weight has reached its resting value
gradient_threshold = 0.05

# Pass the configured alpha/delta (the call previously hard-coded 1, 4), so the
# history CSVs are named after this model and do not collide with other runs.
pruningTrain(model, max_epochs, pruning_schedule, batch_size, weight_threshold, gradient_threshold, alpha, delta)

## MobileNet (alpha = 1, depth = 2) unpruned

In [None]:
# Unpruned reference model: 10 classes, width multiplier 1, depth multiplier 2.
model = MobileNet(num_classes=10, alpha=1, delta=2)
model.compile(loss='categorical_crossentropy', optimizer='Adam', metrics=['accuracy'])

# Per-epoch CSV log for this run.
logger = CSVLogger('mobilenet_unpruned_alpha_1_depth_2_cifar10.csv')

# Last expression of the cell: shows the total parameter count.
model.count_params()

In [None]:
# Train the unpruned model; `max_epochs` and `lr_reducer` come from previously run cells.
history_unpruned_10_alpha_1_depth_2 = model.fit(
    X_train, y_train,
    epochs=max_epochs,
    validation_data=(X_val, y_val),
    shuffle=True,
    callbacks=[lr_reducer, logger],
)

In [None]:
# Save the unpruned model (without optimizer state), zip it, and report both file sizes.
unpruned_10_alpha_1_depth_2_h5 = 'unpruned_cifar10_alpha_1_depth_2.h5'
tf.keras.models.save_model(model, unpruned_10_alpha_1_depth_2_h5, include_optimizer=False)
unpruned_10_alpha_1_depth_2_zip = 'unpruned_10_alpha_1_depth_2.zip'
with zipfile.ZipFile(unpruned_10_alpha_1_depth_2_zip, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(unpruned_10_alpha_1_depth_2_h5)
# The messages previously said "pruned model" although this is the unpruned one.
print('Size of the unpruned model before compression: {} MB'.format(os.path.getsize(unpruned_10_alpha_1_depth_2_h5) / float(2**20)))
print('Size of the unpruned model after compression: {} MB'.format(os.path.getsize(unpruned_10_alpha_1_depth_2_zip) / float(2**20)))

## MobileNet (alpha = 1, depth = 2) pruned

In [None]:
# Fresh model for the pruning run: 10 classes, width multiplier 1, depth multiplier 2.
model = MobileNet(num_classes=10, alpha=1, delta=2)
model.compile(
    loss='categorical_crossentropy',
    optimizer='Adam',
    metrics=['accuracy'],
)

In [None]:
# Run configuration for the pruned (alpha = 1, depth = 2) model.
alpha, delta = 1, 2
# number of epochs to train the model
max_epochs = 75
# pruning frequency: every epoch divisible by this value is a pruning epoch
pruning_schedule = 6
# mini-batch size
batch_size = 128
# Weight threshold: weights below this value will be pruned from the network at each pruning step
weight_threshold = 0.05
# Gradient threshold: gradients below this threshold indicate that the weight has reached its resting value
gradient_threshold = 0.05

# Pass the configured alpha/delta (the call previously hard-coded 1, 4), so the
# history CSVs are named after this model and do not collide with other runs.
pruningTrain(model, max_epochs, pruning_schedule, batch_size, weight_threshold, gradient_threshold, alpha, delta)

## MobileNet baseline (alpha = 1, delta = 1) unpruned

In [None]:
# Unpruned baseline model: 10 classes, width multiplier 1, depth multiplier 1.
model = MobileNet(num_classes=10, alpha=1, delta=1)
model.compile(loss='categorical_crossentropy', optimizer='Adam', metrics=['accuracy'])

# Per-epoch CSV log for this run.
logger = CSVLogger('mobilenet_unpruned_alpha_1_depth_1_cifar10.csv')

# Last expression of the cell: shows the total parameter count.
model.count_params()

In [None]:
# Train the unpruned baseline; `max_epochs` and `lr_reducer` come from previously run cells.
history_unpruned_10_alpha_1_depth_1 = model.fit(
    X_train, y_train,
    epochs=max_epochs,
    validation_data=(X_val, y_val),
    shuffle=True,
    callbacks=[lr_reducer, logger],
)

In [None]:
# Save the unpruned model (without optimizer state), zip it, and report both file sizes.
unpruned_10_alpha_1_depth_1_h5 = 'unpruned_cifar10_alpha_1_depth_1.h5'
tf.keras.models.save_model(model, unpruned_10_alpha_1_depth_1_h5, include_optimizer=False)
unpruned_10_alpha_1_depth_1_zip = 'unpruned_10_alpha_1_depth_1.zip'
with zipfile.ZipFile(unpruned_10_alpha_1_depth_1_zip, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(unpruned_10_alpha_1_depth_1_h5)
# The messages previously said "pruned model" although this is the unpruned one.
print('Size of the unpruned model before compression: {} MB'.format(os.path.getsize(unpruned_10_alpha_1_depth_1_h5) / float(2**20)))
print('Size of the unpruned model after compression: {} MB'.format(os.path.getsize(unpruned_10_alpha_1_depth_1_zip) / float(2**20)))

## MobileNet baseline (alpha = 1, delta = 1) pruned

In [None]:
# Fresh baseline model for the pruning run: 10 classes, width multiplier 1, depth multiplier 1.
model = MobileNet(num_classes=10, alpha=1, delta=1)
model.compile(
    loss='categorical_crossentropy',
    optimizer='Adam',
    metrics=['accuracy'],
)

In [None]:
# Run configuration for the pruned baseline (alpha = 1, depth = 1) model.
alpha, delta = 1, 1
# number of epochs to train the model
max_epochs = 75
# pruning frequency: every epoch divisible by this value is a pruning epoch
pruning_schedule = 6
# mini-batch size
batch_size = 128
# Weight threshold: weights below this value will be pruned from the network at each pruning step
weight_threshold = 0.05
# Gradient threshold: gradients below this threshold indicate that the weight has reached its resting value
gradient_threshold = 0.05

# Pass the configured alpha/delta (the call previously hard-coded 1, 4), so the
# history CSVs are named after this model and do not collide with other runs.
pruningTrain(model, max_epochs, pruning_schedule, batch_size, weight_threshold, gradient_threshold, alpha, delta)