# Iterative pruning pipeline
Model: multilayer perceptron (LeNet-300-100)

Pruning methods: global and local iterative magnitude pruning

*The pruning functions are implemented as methods of the model class.*

In [2]:
# Identifier of this experiment.
EXPERIMENT_NAME = "mlp-global-magnitude-unstruct"


In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
import json
import numpy as np
import pandas as pd
from tensorflow.keras import layers
from tqdm import tqdm
import matplotlib.pyplot as plt
from tqdm import tqdm
import foolbox as fb


# Run eagerly: the pruning methods read weights with `.numpy()`, which needs eager tensors.
tf.compat.v1.enable_eager_execution()
tf.keras.backend.clear_session()  # For easy reset of notebook state.

# Prune, Train, Attack Pipeline

In [None]:
def _compile_for_finetuning(model):
    """Recompile `model` with a fresh Adam optimizer (lr 1e-3) before its next round of training.

    `from_logits=False` because LeNet300_100's last layer already applies a softmax.
    """
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=['accuracy'],
                  )


# Compression rate c keeps 1/c of the weights, i.e. prunes a fraction 1 - 1/c.
compression_rates = [1, 2, 4, 8, 16, 32, 64]
pruning_ratios = [1 - 1 / x for x in compression_rates]
# The CW2 attack is disabled by default; its result lists stay empty while this is False.
RUN_CW2 = False

loc_pruning_pgd_success_rates = []
loc_pruning_cw_success_rates = []
glob_pruning_pgd_success_rates = []
glob_pruning_cw_success_rates = []
loc_pruning_all_accuracies = []
glob_pruning_all_accuracies = []
for j in tqdm(range(3)):
    # The first call trains and saves the run-j base weights. The second only builds a second
    # model instance; its weights are overwritten from that checkpoint below.
    model_for_glob_pruning = initialize_base_model(j, save_weights=True)
    model_for_loc_pruning = initialize_base_model(j)
    base_weights_path = f'./saved-weights/base-model-weights-{j}'
    loc_accuracies = []
    loc_pgd_success_rate = []
    loc_cw_success_rate = []
    glob_accuracies = []
    glob_pgd_success_rate = []
    glob_cw_success_rate = []
    for index in tqdm(range(len(pruning_ratios))):
        # Every target ratio restarts from the same unpruned base weights, so global and local
        # pruning are compared from an identical starting point.
        model_for_glob_pruning.load_weights(base_weights_path)
        model_for_loc_pruning.load_weights(base_weights_path)
        # Iterative schedule: step through the ratios up to the target one. Fine-tune briefly
        # after each intermediate step and train to convergence only after the final step.
        for step, step_ratio in enumerate(pruning_ratios[:index + 1]):
            is_final_step = step == index
            if is_final_step:
                print('final pruning and eval')
            model_for_glob_pruning.prune_globally(step_ratio)
            _compile_for_finetuning(model_for_glob_pruning)
            model_for_loc_pruning.prune_locally(step_ratio)
            _compile_for_finetuning(model_for_loc_pruning)
            model_for_glob_pruning = train_model(model_for_glob_pruning, to_convergence=is_final_step)
            model_for_loc_pruning = train_model(model_for_loc_pruning, to_convergence=is_final_step)

        loc_accuracies.append(model_for_loc_pruning.evaluate(x_test, y_test, verbose=0))
        loc_pgd_success_rate.append(pgd_attack(model_for_loc_pruning))
        if RUN_CW2:
            loc_cw_success_rate.append(cw2_attack(model_for_loc_pruning))
        glob_accuracies.append(model_for_glob_pruning.evaluate(x_test, y_test, verbose=0))
        glob_pgd_success_rate.append(pgd_attack(model_for_glob_pruning))
        if RUN_CW2:
            glob_cw_success_rate.append(cw2_attack(model_for_glob_pruning))

    loc_pruning_all_accuracies.append(loc_accuracies)
    loc_pruning_pgd_success_rates.append(loc_pgd_success_rate)
    loc_pruning_cw_success_rates.append(loc_cw_success_rate)
    glob_pruning_all_accuracies.append(glob_accuracies)
    glob_pruning_pgd_success_rates.append(glob_pgd_success_rate)
    glob_pruning_cw_success_rates.append(glob_cw_success_rate)

# Write every result list to CSV (one row per run) and JSON.
results_to_save = {
    'mlp-loc-accuracies': loc_pruning_all_accuracies,
    'mlp-loc-pgd-success': loc_pruning_pgd_success_rates,
    'mlp-loc-cw2-success': loc_pruning_cw_success_rates,
    'mlp-glob-accuracies': glob_pruning_all_accuracies,
    'mlp-glob-pgd-success': glob_pruning_pgd_success_rates,
    'mlp-glob-cw2-success': glob_pruning_cw_success_rates,
}
for stem, values in results_to_save.items():
    pd.DataFrame(values).to_csv(f'saved-results/{stem}.csv', index=False)
    with open(f'saved-results/{stem}.json', 'w') as f:
        json.dump(values, f)

In [48]:
# Persist each result list twice: as CSV (one row per run) and as JSON.
for stem, values in {
    'mlp-loc-accuracies': loc_pruning_all_accuracies,
    'mlp-loc-pgd-success': loc_pruning_pgd_success_rates,
    'mlp-loc-cw2-success': loc_pruning_cw_success_rates,
    'mlp-glob-accuracies': glob_pruning_all_accuracies,
    'mlp-glob-pgd-success': glob_pruning_pgd_success_rates,
    'mlp-glob-cw2-success': glob_pruning_cw_success_rates,
}.items():
    pd.DataFrame(values).to_csv('saved-results/' + stem + '.csv', index=False)
    with open('saved-results/' + stem + '.json', 'w') as f:
        json.dump(values, f)

In [34]:
def get_average_success_rates(all_success_rates):
    success_per_pruning_rate=[]
    for i in range(len(all_success_rates)):
        for j in range(len(all_success_rates[i])):

            try:
                success_per_pruning_rate[j].append(all_success_rates[i][j])
            except:
                success_per_pruning_rate.append([])
                success_per_pruning_rate[j].append(all_success_rates[i][j])
    avg_success_per_pruning_rate = [sum(x)/len(x) for x in success_per_pruning_rate]
    return avg_success_per_pruning_rate

In [3]:
def get_average_accuracies(all_accuracies, metric_index=1):
    """Average one evaluation metric position-wise across runs.

    Args:
        all_accuracies: list of runs; each run is a list of `model.evaluate` results
            (e.g. `[loss, accuracy]`), one per pruning ratio. Runs may differ in length.
        metric_index: which entry of each evaluation result to average (default 1, the
            accuracy when the model was compiled with `metrics=['accuracy']`).
    Returns:
        List of mean metric values, one per position (empty for empty input).
    """
    acc_per_pruning_rate = []
    for run in all_accuracies:
        for j, evaluation in enumerate(run):
            # j grows by one at a time, so at most one new column is ever needed.
            if j == len(acc_per_pruning_rate):
                acc_per_pruning_rate.append([])
            acc_per_pruning_rate[j].append(evaluation[metric_index])
    return [sum(column) / len(column) for column in acc_per_pruning_rate]

# Helper Functions

In [4]:
def train_model(model, to_convergence=True, batch_size=64, max_epochs=100,
                finetune_epochs=2, patience=3):
    """Fit `model` on the global MNIST training split and return it (it is trained in place).

    Args:
        model: a compiled Keras model.
        to_convergence: if truthy, train for up to `max_epochs` epochs. Training stops early
            once `val_loss` has not improved for `patience` epochs. Otherwise run a short
            fine-tuning pass of `finetune_epochs` epochs.
        batch_size: mini-batch size.
        max_epochs, finetune_epochs, patience: see `to_convergence`.
    Returns:
        The same `model` object.
    """
    if to_convergence:
        epochs = max_epochs
        callbacks = [tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=patience)]
    else:
        epochs = finetune_epochs
        callbacks = None
    # NOTE(review): the test split doubles as validation data and drives early stopping, so
    # the test accuracy reported later is not a fully held-out estimate.
    model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=epochs,
        callbacks=callbacks,
        validation_data=(x_test, y_test),
        )
    return model



def prune_weights(model, pruning_ratio, layer_indices=(0, 2, 4)):
    """Return a pruned copy of `model.get_weights()`; the model itself is left untouched.

    Each index in `layer_indices` is expected to point at a kernel whose pruning mask sits
    at the next index, as in LeNet300_100's `[w1, m1, w2, m2, w3, m3]`. For every such pair,
    the int(size * pruning_ratio) kernel entries with the smallest |w| are set to 0 in both
    the kernel and the mask. Indices beyond the end of the weight list are ignored.

    Args:
        model: object exposing `get_weights()` (e.g. a Keras model).
        pruning_ratio: fraction in [0, 1] of each kernel to prune.
        layer_indices: positions of the kernels to prune.
    Returns:
        The list of weight arrays with pruning applied.
    """
    weights = model.get_weights()
    for index in layer_indices:
        if index >= len(weights):
            continue
        flat_kernel = weights[index].flatten()
        flat_mask = weights[index + 1].flatten()
        n_prune = int(len(flat_kernel) * pruning_ratio)
        # argsort is ascending, so the first n_prune positions hold the smallest magnitudes.
        smallest = np.argsort(np.abs(flat_kernel))[:n_prune]
        flat_kernel[smallest] = 0
        flat_mask[smallest] = 0
        shape = weights[index + 1].shape
        weights[index + 1] = flat_mask.reshape(shape)
        weights[index] = flat_kernel.reshape(shape)
    return weights



def pgd_attack(model_to_attack, epsilon=15/255, inputs=None, labels=None):
    """Return the fraction of samples on which an L-inf PGD attack (foolbox) succeeds.

    Args:
        model_to_attack: Keras model whose inputs lie in [0, 1].
        epsilon: L-inf perturbation budget (default 15/255).
        inputs, labels: samples to attack; default to the global attack set `x`, `y`.
    Returns:
        Success rate in [0, 1].
    """
    inputs = x if inputs is None else inputs
    labels = y if labels is None else labels
    # NOTE(review): foolbox treats model outputs as logits, but LeNet300_100 ends in a softmax,
    # which can weaken the attack's gradients; consider a logit-output model for attacks.
    fmodel = fb.models.TensorFlowModel(model_to_attack, bounds=(0, 1))
    attack = fb.attacks.LinfProjectedGradientDescentAttack()
    # The third element of the result is the per-sample success mask (one row per epsilon).
    _, _, success = attack(fmodel, inputs, labels, epsilons=[epsilon])
    return np.count_nonzero(success) / len(labels)

def cw2_attack(model_to_attack, epsilon=.5, inputs=None, labels=None):
    """Return the fraction of samples on which an L2 Carlini-Wagner attack (foolbox) succeeds.

    Args:
        model_to_attack: Keras model whose inputs lie in [0, 1].
        epsilon: L2 perturbation budget used to decide success (default 0.5).
        inputs, labels: samples to attack; default to the global attack set `x`, `y`.
    Returns:
        Success rate in [0, 1].
    """
    inputs = x if inputs is None else inputs
    labels = y if labels is None else labels
    fmodel = fb.models.TensorFlowModel(model_to_attack, bounds=(0, 1))
    attack = fb.attacks.L2CarliniWagnerAttack()
    # The third element of the result is the per-sample success mask (one row per epsilon).
    _, _, success = attack(fmodel, inputs, labels, epsilons=[epsilon])
    return np.count_nonzero(success) / len(labels)

def initialize_base_model(index, save_weights=False, epochs=1):
    """Build, compile and train a fresh LeNet300_100 on the global MNIST split.

    Args:
        index: run index; only used to name the saved checkpoint.
        save_weights: if True, save the trained weights to
            `./saved-weights/base-model-weights-{index}`.
        epochs: number of training epochs (default 1). The EarlyStopping callback
            (patience 3) can only stop training when this is greater than 3.
    Returns:
        The trained model.
    """
    model = LeNet300_100()
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
                  # The last layer already applies a softmax, so outputs are probabilities.
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=['accuracy'],
                  experimental_run_tf_function=False
                  )

    callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)
    model.fit(x=x_train,
              y=y_train,
              batch_size=64,
              epochs=epochs,
              callbacks=[callback],
              validation_data=(x_test, y_test),
              )
    if save_weights:
        model.save_weights(f'./saved-weights/base-model-weights-{index}')
    return model
    

# Load Data

In [5]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Flatten each 28x28 image into a 784-vector and scale pixels by 1/255 (into the [0, 1]
# range that the attacks use as bounds).
x_train = x_train.reshape(60000, 28 * 28).astype('float32') / 255
x_test = x_test.reshape(10000, 28 * 28).astype('float32') / 255

# Attack set: the first 500 training samples, as tensors.
# NOTE(review): adversarial success is therefore measured on training data, not the test split.
x = tf.convert_to_tensor(x_train[:500])
y = tf.convert_to_tensor(y_train[:500])

# Define Model

In [6]:
class CustomLayer(layers.Layer):
    """Bias-free dense layer whose kernel is multiplied element-wise by a pruning mask.

    The mask is a non-trainable weight initialised to ones. A 0 in the mask removes that
    connection from the forward pass (and, since the kernel is multiplied by the mask, the
    masked kernel entry also receives zero gradient).

    Args:
        units: number of output units.
        activation: 'relu' (default), 'softmax', or None for a linear (logit) output.
        **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`).
    Raises:
        ValueError: if `activation` is not one of the supported values.
    """

    _SUPPORTED_ACTIVATIONS = ('relu', 'softmax', None)

    def __init__(self, units=32, activation='relu', **kwargs):
        super(CustomLayer, self).__init__(**kwargs)
        # Fail at construction rather than on the first forward pass.
        if activation not in self._SUPPORTED_ACTIVATIONS:
            raise ValueError('Activation function not implemented')
        self.units = units
        self.activation = activation

    def build(self, input_shape):
        self.w = self.add_weight(shape=(input_shape[-1], self.units),
                                 initializer='random_normal',
                                 trainable=True,
                                 name='unpruned_weights')
        self.mask = self.add_weight(shape=(self.w.shape),
                                    initializer='ones',
                                    trainable=False,
                                    name='pruning_mask')

    def call(self, inputs):
        outputs = tf.matmul(inputs, tf.multiply(self.w, self.mask))
        if self.activation == 'relu':
            return tf.keras.activations.relu(outputs)
        if self.activation == 'softmax':
            return tf.keras.activations.softmax(outputs)
        if self.activation is None:
            return outputs
        raise ValueError('Activation function not implemented')

class LeNet300_100(tf.keras.Model):
    """LeNet-300-100 MLP: masked dense layers of 300, 100 and 10 units (ReLU, ReLU, softmax).

    Every layer is a `CustomLayer`, so each has a kernel `w` and a same-shaped pruning `mask`.
    Both pruning methods rank weights by |w|, zero the smallest ones in the kernel, and set
    the matching mask entries to 0. Entries that are already 0 (pruned earlier) sort first,
    so calling a method again with a larger ratio extends the existing pruning.
    """

    def __init__(self):
        super(LeNet300_100, self).__init__()
        self.dense1 = CustomLayer(300)
        self.dense2 = CustomLayer(100)
        self.dense3 = CustomLayer(10, activation='softmax')

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        return self.dense3(x)

    def _masked_layers(self):
        """Prunable layers, in forward order."""
        return [self.dense1, self.dense2, self.dense3]

    @staticmethod
    def _zero_smallest(flat_weights, flat_mask, ratio):
        """In place: zero the int(len * ratio) smallest-|w| entries of both flat arrays."""
        n_prune = int(len(flat_weights) * ratio)
        smallest = np.argsort(np.abs(flat_weights))[:n_prune]
        flat_weights[smallest] = 0
        flat_mask[smallest] = 0

    def prune_globally(self, ratio):
        """Prune a `ratio` fraction of all weights, ranked by |w| across every layer at once.

        Because the ranking is shared, a layer with small weights overall can lose more than
        `ratio` of its own entries.
        """
        masked_layers = self._masked_layers()
        kernels = [layer.w.numpy() for layer in masked_layers]
        masks = [layer.mask.numpy() for layer in masked_layers]
        # np.concatenate copies, so the in-place zeroing below never touches `kernels`.
        flat_weights = np.concatenate([kernel.ravel() for kernel in kernels])
        flat_mask = np.concatenate([mask.ravel() for mask in masks])
        self._zero_smallest(flat_weights, flat_mask, ratio)
        # Cut the concatenated vectors back into per-layer pieces at the layer-size boundaries.
        boundaries = np.cumsum([kernel.size for kernel in kernels])[:-1]
        for layer, kernel, w_part, m_part in zip(masked_layers, kernels,
                                                 np.split(flat_weights, boundaries),
                                                 np.split(flat_mask, boundaries)):
            layer.w.assign(w_part.reshape(kernel.shape))
            layer.mask.assign(m_part.reshape(kernel.shape))

    def prune_locally(self, ratio):
        """Prune a `ratio` fraction of each layer's weights, ranked by |w| within that layer."""
        for layer in self._masked_layers():
            kernel = layer.w.numpy()
            flat_weights = kernel.flatten()
            flat_mask = layer.mask.numpy().flatten()
            self._zero_smallest(flat_weights, flat_mask, ratio)
            layer.w.assign(flat_weights.reshape(kernel.shape))
            layer.mask.assign(flat_mask.reshape(kernel.shape))

# Compile and Train Model

In [None]:
model = LeNet300_100()

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              # The last layer already applies a softmax, so outputs are probabilities.
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'],
              experimental_run_tf_function=False
              )

# NOTE(review): with epochs=1 this EarlyStopping callback (patience=3) can never trigger.
callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)
model.fit(x=x_train,
          y=y_train,
          batch_size=128,
          epochs=1,
          callbacks=[callback],
          validation_data=(x_test, y_test),
          )

model.save('./saved-models/mini-pipeline-mlp-baseline-model')
model.save_weights('./saved-models/weights')

In [88]:
# Toy example for the random-pruning sketch: 100 weights, with every even index already pruned.
# `zz` was previously undefined in this cell; it is rebuilt to match the displayed output.
weights = np.array(list(range(100)))
zz = weights.copy()
zz[::2] = 0
zz

array([ 0,  1,  0,  3,  0,  5,  0,  7,  0,  9,  0, 11,  0, 13,  0, 15,  0,
       17,  0, 19,  0, 21,  0, 23,  0, 25,  0, 27,  0, 29,  0, 31,  0, 33,
        0, 35,  0, 37,  0, 39,  0, 41,  0, 43,  0, 45,  0, 47,  0, 49,  0,
       51,  0, 53,  0, 55,  0, 57,  0, 59,  0, 61,  0, 63,  0, 65,  0, 67,
        0, 69,  0, 71,  0, 73,  0, 75,  0, 77,  0, 79,  0, 81,  0, 83,  0,
       85,  0, 87,  0, 89,  0, 91,  0, 93,  0, 95,  0, 97,  0, 99])

In [93]:

def random_prune_indices(weights, rate, rng=None):
    """Pick random surviving weights to prune so that a `rate` fraction of all weights is 0.

    Entries of `weights` that are already 0 count as pruned. The result holds the flat
    indices of randomly chosen non-zero entries that still need to be zeroed to reach the
    target. It is empty when the target is already met.

    Args:
        weights: array of weights (any shape; indices refer to the flattened array).
        rate: target pruned fraction in [0, 1].
        rng: optional `numpy.random.Generator` for reproducible choices.
    Returns:
        1-D integer array of flat indices to set to 0.
    """
    rng = np.random.default_rng() if rng is None else rng
    flat = np.asarray(weights).ravel()
    n_target = rate * flat.size
    surviving = np.flatnonzero(flat)
    n_already_pruned = flat.size - surviving.size
    # Clamp at 0: a negative count would slice from the end and prune almost every survivor.
    n_left = max(int(n_target - n_already_pruned), 0)
    return rng.permutation(surviving)[:n_left]