# Iterative pruning pipeline
Model: Multi Layer Perceptron

*Pruning functions as class methods*

In [15]:
# Architecture label; run_experiment embeds it in the experiment name that
# prefixes the saved weight and result files.
ARCHITECTURE = 'MLP'
# Number of experiment repetitions.
# NOTE(review): not referenced in the visible code - run_experiment has its own
# `iterations=10` default; confirm whether it should be passed in.
ITERATIONS = 10


In [16]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
import json
import numpy as np
import pandas as pd
from tensorflow.keras import layers
from tqdm import tqdm
import matplotlib.pyplot as plt
from tqdm import tqdm
import foolbox as fb


# Enable eager execution through the TF1 compatibility API; the pruning methods
# further down call `.numpy()` on layer variables.
tf.compat.v1.enable_eager_execution()
tf.keras.backend.clear_session()  # For easy reset of notebook state.

# Prune, Train, Attack Pipeline

In [None]:
def run_experiment(structure='unstructured', method='random', scope='global', iterations=10):
    """Run the iterative prune / retrain / attack experiment and save the results.

    For every repetition a fresh base model is trained and its weights are
    saved. For every goal pruning ratio the base weights are reloaded and the
    model is pruned step by step through all ratios up to the goal: the
    intermediate steps are retrained for a fixed number of epochs, the goal
    step is retrained with early stopping and then evaluated and attacked.

    Args:
        structure: 'unstructured' or 'structured'.
        method: 'random' or 'magnitude'.
        scope: 'global' or 'local'.
        iterations: number of independent repetitions.

    Raises:
        ValueError: if `structure`, `method` or `scope` is not one of the
            supported values (previously such a call trained without pruning).
        AttributeError: if the model does not implement the selected
            `prune_<method>_<scope>_<struct|unstruct>` method.

    Side effects:
        Writes three JSON result files into `saved-results/`; each holds one
        list per repetition with one entry per pruning ratio.
    """
    import os

    # Validate before anything expensive happens, so a typo cannot produce a
    # full training run that never prunes.
    structure_suffixes = {'unstructured': 'unstruct', 'structured': 'struct'}
    if structure not in structure_suffixes:
        raise ValueError(f'Unsupported structure: {structure!r}')
    if method not in ('random', 'magnitude'):
        raise ValueError(f'Unsupported method: {method!r}')
    if scope not in ('global', 'local'):
        raise ValueError(f'Unsupported scope: {scope!r}')
    prune_method_name = f'prune_{method}_{scope}_{structure_suffixes[structure]}'

    experiment_name = f'{ARCHITECTURE}-{method}-{scope}-{structure}'
    pgd_success_rates = []
    cw_success_rates = []
    all_accuracies = []

    # Compression rates 1, 2, 4, ..., 64 -> pruning ratios 0, 0.5, 0.75, ...
    compression_rates = [tf.math.pow(2, x).numpy() for x in range(7)]
    pruning_ratios = [1-1/x for x in compression_rates]

    model = None
    for j in tqdm(range(iterations)):
        accuracies = []
        pgd_success_rate = []
        cw_success_rate = []
        # Release the previous repetition's model before the backend state is reset.
        model = None
        tf.compat.v1.reset_default_graph()
        tf.keras.backend.clear_session()

        model = initialize_base_model(j, experiment_name=experiment_name, save_weights=True)
        for index, pruning_ratio in tqdm(enumerate(pruning_ratios), total=len(pruning_ratios)):

            # Every goal ratio starts again from the saved base weights.
            model.load_weights(f'./saved-weights/{experiment_name}-{j}')

            for i in range(index + 1):
                print(f'current pruning ratio is{pruning_ratios[i]}, goal ratio ist {pruning_ratio}')
                is_goal_step = i == index

                getattr(model, prune_method_name)(pruning_ratios[i])

                # must recompile otherwise pruned weights will get updated
                # NOTE(review): the model's last layer already applies softmax
                # while the loss is built with from_logits=True - confirm.
                model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
                              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                              metrics=['accuracy'],
                              experimental_run_tf_function=False
                             )
                # Only the goal ratio is trained with early stopping.
                model = train_model(model, to_convergence=is_goal_step)

                if is_goal_step:
                    accuracies.append(model.evaluate(x_test, y_test, verbose=0))
                    pgd_success_rate.append(pgd_attack(model))
                    cw_success_rate.append(cw2_attack(model))
        all_accuracies.append(accuracies)
        pgd_success_rates.append(pgd_success_rate)
        cw_success_rates.append(cw_success_rate)

    # Write the results as JSON; make sure the target directory exists first.
    os.makedirs('saved-results', exist_ok=True)
    with open(f'saved-results/{experiment_name}-accuracies.json', 'w') as f:
        json.dump(all_accuracies, f)

    with open(f'saved-results/{experiment_name}-pgd-success.json', 'w') as f:
        json.dump(pgd_success_rates, f)

    # NOTE(review): unlike the two files above this path has no '.json'
    # extension; kept unchanged in case downstream code reads this exact name.
    with open(f'saved-results/{experiment_name}-cw2-success', 'w') as f:
        json.dump(cw_success_rates, f)

        
        

# Helper Functions

In [14]:
def train_model(model, to_convergence=True):
    """Fit `model` on the module-level MNIST arrays and return it.

    Args:
        model: a compiled Keras model.
        to_convergence: if truthy, train for up to 500 epochs with early
            stopping on `val_loss` (patience 3); otherwise train for a fixed
            100 epochs without callbacks.

    Returns:
        The same `model` object after training.

    The test split (`x_test`, `y_test`) is used as validation data in both modes.
    """
    shared_fit_args = dict(
        x=x_train,
        y=y_train,
        batch_size=64,
        validation_data=(x_test, y_test),
    )
    # Plain truthiness: the former `== True` / `== False` comparisons silently
    # skipped training for any flag that was neither, e.g. None.
    if to_convergence:
        early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)
        model.fit(epochs=500, callbacks=[early_stopping], **shared_fit_args)
    else:
        model.fit(epochs=100, **shared_fit_args)
    return model


def pgd_attack(model_to_attack):
    """Run an L-inf PGD attack (epsilon 15/255) on the module-level `x`, `y`.

    Returns the number of non-zero entries in the third element of the attack
    output divided by `len(y)` - presumably the attack success rate (foolbox
    returns per-sample success flags in that position; confirm for the
    installed version).
    """
    wrapped_model = fb.models.TensorFlowModel(model_to_attack, bounds=(0,1))
    pgd = fb.attacks.LinfProjectedGradientDescentAttack()
    attack_output = pgd(wrapped_model, x, y, epsilons=[15/255])
    success_flags = attack_output[2]
    successful = np.count_nonzero(success_flags)
    return successful/len(y)

def cw2_attack(model_to_attack):
    """Run an L2 Carlini-Wagner attack (epsilon 0.5) on the module-level `x`, `y`.

    Returns the number of non-zero entries in the third element of the attack
    output divided by `len(y)` - presumably the attack success rate (foolbox
    returns per-sample success flags in that position; confirm for the
    installed version).
    """
    wrapped_model = fb.models.TensorFlowModel(model_to_attack, bounds=(0,1))
    carlini_wagner = fb.attacks.L2CarliniWagnerAttack()
    attack_output = carlini_wagner(wrapped_model, x, y, epsilons=[.5])
    success_flags = attack_output[2]
    successful = np.count_nonzero(success_flags)
    return successful/len(y)

def initialize_base_model(index, save_weights=False, experiment_name=None):
    """Build, compile and briefly train a fresh LeNet300_100 base model.

    Args:
        index: repetition number; used as suffix of the saved weights file.
        save_weights: if truthy, store the trained weights under
            `./saved-weights/<experiment_name>-<index>`.
        experiment_name: prefix of the weights file. When omitted, the
            module-level `EXPERIMENT_NAME` is used, as before this parameter
            existed.

    Returns:
        The trained model.
    """
    model = LeNet300_100()
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'],
                  experimental_run_tf_function=False
                 )

    # NOTE(review): with epochs=1 the patience-3 early stopping can never
    # trigger - confirm whether a single epoch is intended for the base model.
    callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)
    model.fit(x=x_train,
              y=y_train,
              batch_size=64,
              epochs=1,
              callbacks=[callback],
              validation_data=(x_test, y_test),
             )
    if save_weights:
        if experiment_name is None:
            # Backward-compatible fallback for callers that rely on the global.
            experiment_name = EXPERIMENT_NAME
        model.save_weights(f'./saved-weights/{experiment_name}-{index}')
    return model
    

# Load Data

In [10]:
# MNIST train/test split as provided by Keras.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Flatten every image to a 784-vector and scale the pixel values by 1/255.
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255

# Attack set read by pgd_attack / cw2_attack: the first 500 *training* samples
# and their labels, converted to tensors.
x = tf.convert_to_tensor(x_train[:500].reshape(500,28*28))
# Wrapping in a list and indexing [0] again yields the 500 labels as one tensor.
y = tf.convert_to_tensor([y_train[:500]])[0];

# Define Model

In [11]:
class CustomLayer(layers.Layer):
    """Bias-free dense layer whose kernel is multiplied by a pruning mask.

    The layer owns two variables of identical shape: the trainable kernel
    `w` ('unpruned_weights') and the non-trainable `mask` ('pruning_mask'),
    initialised to ones. The forward pass uses `w * mask` as effective kernel.

    Args:
        units: number of output units.
        activation: 'relu' or 'softmax'; anything else raises ValueError when
            the layer is called.
    """

    def __init__(self, units=32, activation='relu'):
        super().__init__()
        self.units = units
        self.activation = activation

    def build(self, input_shape):
        """Create the kernel and its mask once the input width is known."""
        kernel_shape = (input_shape[-1], self.units)
        self.w = self.add_weight(
            shape=kernel_shape,
            initializer='random_normal',
            trainable=True,
            name='unpruned_weights',
        )
        self.mask = self.add_weight(
            shape=(self.w.shape),
            initializer='ones',
            trainable=False,
            name='pruning_mask',
        )

    def call(self, inputs):
        """Apply the masked kernel to `inputs`, then the configured activation."""
        masked_kernel = tf.multiply(self.w, self.mask)
        pre_activation = tf.matmul(inputs, masked_kernel)

        if self.activation == 'relu':
            return tf.keras.activations.relu(pre_activation)
        if self.activation == 'softmax':
            return tf.keras.activations.softmax(pre_activation)
        raise ValueError('Activation function not implemented')

class LeNet300_100(tf.keras.Model):
    def __init__(self):
        """Create the three masked dense layers: 300 -> 100 -> 10 units."""
        super().__init__()
        self.dense1 = CustomLayer(300)
        self.dense2 = CustomLayer(100)
        # Output layer; CustomLayer applies softmax for this activation.
        self.dense3 = CustomLayer(10, activation='softmax')
        
    def call(self, inputs):
        """Forward pass: feed `inputs` through dense1, dense2 and dense3 in order."""
        hidden_first = self.dense1(inputs)
        hidden_second = self.dense2(hidden_first)
        return self.dense3(hidden_second)
    
    def prune_random_local_unstruct(self, ratio):
        """Randomly prune every layer until `ratio` of its weights are masked.

        Works layer by layer (local scope) on individual weights
        (unstructured). `ratio` is cumulative: weights whose mask entry is
        already zero count towards the target, so only the missing number of
        still-unpruned weights is removed in this call.

        The previous body was written for a convolutional network (undefined
        layout-conversion helpers, `random` never imported, weight indices
        that do not exist here) and could not run on this model.

        Args:
            ratio: target fraction of pruned weights per layer, in [0, 1].
        """
        # get_weights() returns [kernel, mask] pairs, one pair per CustomLayer -
        # the layout prune_locally and prune_globally rely on.
        weights = self.get_weights()
        for kernel_index in range(0, len(weights) - 1, 2):
            mask_index = kernel_index + 1
            shape = weights[kernel_index].shape
            flat_kernel = weights[kernel_index].flatten()
            flat_mask = weights[mask_index].flatten()

            target_pruned = ratio * flat_kernel.size
            unpruned = np.flatnonzero(flat_mask)
            already_pruned = flat_kernel.size - unpruned.size
            # Clamp at zero: a negative count would turn the slice below into
            # "everything except the last few" and prune far too much.
            to_prune_now = max(int(target_pruned - already_pruned), 0)

            np.random.shuffle(unpruned)
            selected = unpruned[:to_prune_now]
            flat_mask[selected] = 0
            flat_kernel[selected] = 0

            weights[kernel_index] = flat_kernel.reshape(shape)
            weights[mask_index] = flat_mask.reshape(shape)
        self.set_weights(weights)
    
    def prune_magnitude_global_unstruct(self, ratio):
        """Mask the `ratio` smallest-magnitude weights across all layers.

        All kernels are ranked together (global scope) and the
        `int(total * ratio)` weights with the smallest absolute value are
        zeroed in both kernel and mask. Already pruned weights are zero and
        therefore rank first, which makes `ratio` cumulative.

        The previous body indexed the weight list with a convolutional layout
        (kernels at 0, 3, 6, 9, 12 and masks two positions later) that does
        not exist for this model.

        Args:
            ratio: target fraction of pruned weights over the whole model.
        """
        # get_weights() returns [kernel, mask] pairs, one pair per CustomLayer.
        weights = self.get_weights()
        kernel_indices = range(0, len(weights) - 1, 2)

        flat_kernels = np.concatenate([weights[i].flatten() for i in kernel_indices])
        flat_masks = np.concatenate([weights[i + 1].flatten() for i in kernel_indices])

        no_of_weights_to_prune = int(flat_kernels.size * ratio)
        indices_to_delete = np.abs(flat_kernels).argsort()[:no_of_weights_to_prune]
        flat_masks[indices_to_delete] = 0
        flat_kernels[indices_to_delete] = 0

        # Cut the flat vectors back into the per-layer shapes.
        offset = 0
        for i in kernel_indices:
            shape = weights[i].shape
            size = weights[i].size
            weights[i] = flat_kernels[offset:offset + size].reshape(shape)
            weights[i + 1] = flat_masks[offset:offset + size].reshape(shape)
            offset += size
        self.set_weights(weights)
        
        
            
        
        
    
    def prune_magnitude_local_unstruct(self, ratio):
        """Mask the `ratio` smallest-magnitude weights of every layer.

        Each kernel is ranked on its own (local scope); the
        `int(size * ratio)` weights with the smallest absolute value are
        zeroed in both kernel and mask. Already pruned weights are zero and
        rank first, which makes `ratio` cumulative.

        The previous body was written for a convolutional network (undefined
        layout-conversion helpers, weight indices that do not exist here) and
        could not run on this model.

        Args:
            ratio: target fraction of pruned weights per layer.
        """
        # get_weights() returns [kernel, mask] pairs, one pair per CustomLayer.
        weights = self.get_weights()
        for kernel_index in range(0, len(weights) - 1, 2):
            mask_index = kernel_index + 1
            shape = weights[kernel_index].shape
            flat_kernel = weights[kernel_index].flatten()
            flat_mask = weights[mask_index].flatten()

            no_of_weights_to_prune = int(flat_kernel.size * ratio)
            indices_to_delete = np.abs(flat_kernel).argsort()[:no_of_weights_to_prune]
            flat_mask[indices_to_delete] = 0
            flat_kernel[indices_to_delete] = 0

            weights[kernel_index] = flat_kernel.reshape(shape)
            weights[mask_index] = flat_mask.reshape(shape)
        self.set_weights(weights)
        return
    
    def prune_random_local_struct(self, ratio):
        """Randomly remove whole kernel rows until `ratio` of each layer's rows are masked.

        A row of a kernel holds all outgoing weights of one input unit of that
        layer. `ratio` is cumulative: rows whose mask is already all zero
        count towards the target, so only the missing number of rows is
        removed in this call.

        The previous body was written for a convolutional network (undefined
        layout-conversion helpers, `random` never imported, weight indices
        that do not exist here) and could not run on this model. Unpruned
        rows are now detected through the mask instead of a row sum.

        Args:
            ratio: target fraction of pruned rows per layer, in [0, 1].

        Returns:
            True.
        """
        # get_weights() returns [kernel, mask] pairs, one pair per CustomLayer.
        weights = self.get_weights()
        for kernel_index in range(0, len(weights) - 1, 2):
            mask_index = kernel_index + 1
            # Work on copies so the arrays handed out by Keras are never mutated.
            kernel = weights[kernel_index].copy()
            mask = weights[mask_index].copy()

            no_of_rows = len(kernel)
            no_of_rows_to_prune = int(ratio * no_of_rows)
            # A row is still alive while any of its mask entries is set.
            unpruned_rows = np.flatnonzero(mask.any(axis=1))
            already_pruned = no_of_rows - unpruned_rows.size
            # Clamp at zero so a negative count cannot invert the slice below.
            to_prune_now = max(no_of_rows_to_prune - already_pruned, 0)

            np.random.shuffle(unpruned_rows)
            rows_to_prune = unpruned_rows[:to_prune_now]
            kernel[rows_to_prune] = 0
            mask[rows_to_prune] = 0

            weights[kernel_index] = kernel
            weights[mask_index] = mask
        self.set_weights(weights)
        return True

    def prune_random_global_struct(self, ratio):
        """Placeholder for random, global, structured pruning.

        Always raises; the `return False` that used to follow the raise could
        never execute and was removed. The exception type is kept as
        `Warning` so existing handlers keep working.

        Raises:
            Warning: always, with the message 'Not yet implemented'.
        """
        raise Warning('Not yet implemented')
    def prune_magnitude_local_struct(self, ratio):
        """Remove the `ratio` kernel rows with the smallest L1 norm in every layer.

        A row of a kernel holds all outgoing weights of one input unit of that
        layer. Rows are ranked per layer by the sum of absolute weights and
        the `int(rows * ratio)` smallest are zeroed in both kernel and mask.
        Already pruned rows have norm zero and rank first, which makes `ratio`
        cumulative.

        The previous body was written for a convolutional network (undefined
        layout-conversion helpers, weight indices that do not exist here) and
        could not run on this model.

        Args:
            ratio: target fraction of pruned rows per layer.

        Returns:
            True.
        """
        # get_weights() returns [kernel, mask] pairs, one pair per CustomLayer.
        weights = self.get_weights()
        for kernel_index in range(0, len(weights) - 1, 2):
            mask_index = kernel_index + 1
            # Work on copies so the arrays handed out by Keras are never mutated.
            kernel = weights[kernel_index].copy()
            mask = weights[mask_index].copy()

            no_of_rows_to_prune = int(ratio * len(kernel))
            row_norms = np.abs(kernel).sum(axis=1)
            rows_to_prune = np.argsort(row_norms)[:no_of_rows_to_prune]
            kernel[rows_to_prune] = 0
            mask[rows_to_prune] = 0

            weights[kernel_index] = kernel
            weights[mask_index] = mask
        self.set_weights(weights)
        return True
        
    def prune_magnitude_global_struct(self, ratio):
        """Zero the smallest-L1-norm 5x5 channels across the listed conv layers.

        NOTE(review): this body handles convolutional kernels only. It calls
        convert_from_hwio_to_iohw / convert_from_iohw_to_hwio, which are not
        defined in the visible part of this notebook, and indexes the weight
        list at positions 0 and 3. The dense part is an empty stub that is
        never called. Confirm whether this method is usable for the model
        defined here before relying on it.

        Args:
            ratio: fraction of all collected channels to zero.

        Returns:
            Always False.
        """
        def prune_conv_layers(conv_layers_to_prune, weights):
            #kernels = []
            #converted_shape = iohw_weights.shape
            
            #no_of_channels_to_prune = int(np.round(ratio * no_of_channels))
            #channels = tf.reshape(iohw_weights, (no_of_channels,converted_shape[2],converted_shape[3])).numpy()
            # Accumulator for the channels of all layers; hard-codes 5x5 kernels.
            all_channels = np.empty((0,5,5))
            original_shapes = []  # never read
            for layer_to_prune in conv_layers_to_prune:
                iohw_weights = convert_from_hwio_to_iohw(weights[layer_to_prune])
                
                converted_shape = iohw_weights.shape
                no_of_channels = converted_shape[0]*converted_shape[1]
                # Flatten the two leading axes so each entry is one 2-D channel.
                channels = tf.reshape(iohw_weights, (no_of_channels,converted_shape[2],converted_shape[3])).numpy()

                all_channels = np.concatenate((all_channels, channels))
            # NOTE(review): this mask is updated below but never written back
            # into `weights`, so the model's stored masks stay unchanged.
            mask = np.ones(all_channels.shape)

            # Rank channels by the sum of absolute values, smallest first.
            vals = [np.sum(np.abs(channel)) for channel in all_channels]
            no_of_channels_to_prune = int(ratio * len(vals))
            channels_to_prune = np.argsort(vals)[:no_of_channels_to_prune]

            for channel_to_prune in channels_to_prune:
                all_channels[channel_to_prune] = tf.zeros((5,5))
                mask[channel_to_prune] = tf.zeros((5,5))
                
            # Cut the pruned channels back into the per-layer shapes.
            z = 0
            for i, layer_to_prune in enumerate(conv_layers_to_prune):
                original_shape = convert_from_hwio_to_iohw(weights[layer_to_prune]).shape
                pruned_layer = tf.reshape(all_channels[z:z + original_shape[0]*original_shape[1]], original_shape)
                weights[layer_to_prune] = convert_from_iohw_to_hwio(pruned_layer)
                # NOTE(review): the offset is assigned, not accumulated; this is
                # only a correct start offset for the second layer.
                z = original_shape[0]*original_shape[1]
            
            
                
            return weights
        # Stub: does nothing and is never called below.
        def prune_dense_layers(conv_layers_to_prune, weights):
            return True
        weights = self.get_weights()
        conv_layers_to_prune = [0,3]
        dense_layers_to_prune = [6,9,12]  # never read
        weights = prune_conv_layers(conv_layers_to_prune, weights)
        self.set_weights(weights)
        #raise Warning('Not yet implemented')
        return False

        

    
    
    
    def prune_globally(self,ratio):
        """Mask the `ratio` smallest-magnitude weights across dense1, dense2 and dense3.

        The three kernels (and their masks) are flattened into one vector,
        the `int(total * ratio)` entries with the smallest absolute weight are
        set to zero in both vectors, and the pieces are reshaped and written
        back in the order [w1, m1, w2, m2, w3, m3].

        Args:
            ratio: fraction of all weights to mask.
        """
        dense_layers = (self.dense1, self.dense2, self.dense3)
        shape1, shape2, shape3 = (layer.w.shape for layer in dense_layers)
        size1 = shape1[0]*shape1[1]
        size2 = shape2[0]*shape2[1]
        size3 = shape3[0]*shape3[1]

        flat_weights = np.concatenate([layer.w.numpy().flatten() for layer in dense_layers])
        flat_mask = np.concatenate([layer.mask.numpy().flatten() for layer in dense_layers])

        prune_count = int(len(flat_weights)*ratio)
        smallest = np.abs(flat_weights).argsort()[:prune_count]
        flat_mask[smallest] = 0
        flat_weights[smallest] = 0

        # Split the flat vectors back into the three per-layer pieces.
        pieces = []
        for flat in (flat_weights, flat_mask):
            pieces.append((
                flat[:size1].reshape(shape1),
                flat[size1:size1+size2].reshape(shape2),
                flat[-size3:].reshape(shape3),
            ))
        (w1, w2, w3), (m1, m2, m3) = pieces
        self.set_weights([w1,m1,w2,m2,w3,m3])
        return
    
    def prune_locally(self, ratio):
        """Mask the `ratio` smallest-magnitude weights of each kernel separately.

        Kernels are read from positions 0, 2 and 4 of get_weights(); the array
        right after each kernel is treated as its mask. In every kernel the
        `int(size * ratio)` entries with the smallest absolute value are set
        to zero in both kernel and mask.

        Args:
            ratio: fraction of weights to mask per layer.
        """
        kernel_positions = (0, 2, 4)
        all_arrays = self.get_weights()
        for position, kernel in enumerate(all_arrays):
            if position not in kernel_positions:
                continue
            original_shape = kernel.shape
            flat_kernel = kernel.flatten()
            flat_mask = all_arrays[position+1].flatten()

            prune_count = int(len(flat_kernel)*ratio)
            smallest = np.abs(flat_kernel).argsort()[:prune_count]
            flat_mask[smallest] = 0
            flat_kernel[smallest] = 0

            all_arrays[position+1] = flat_mask.reshape(original_shape)
            all_arrays[position] = flat_kernel.reshape(original_shape)
        self.set_weights(all_arrays)
        return

# Compile and Train Model

In [None]:
# Build the baseline model, train it for one epoch and store it on disk.
model = LeNet300_100()

# NOTE(review): the model's last layer applies softmax, while the loss is
# configured with from_logits=True - confirm which of the two is intended.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) ,
              metrics=['accuracy'],
              experimental_run_tf_function=False
             )

# With epochs=1 below, the patience-3 early stopping cannot trigger.
callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)
# The test split doubles as validation data.
model.fit(x=x_train,
          y=y_train,
          batch_size=128,
          epochs=1,
          callbacks=[callback],
          validation_data=(x_test, y_test),
         )

# Persist both the full model and, separately, only its weights.
model.save('./saved-models/mini-pipeline-mlp-baseline-model')
model.save_weights('./saved-models/weights')