<a href="https://colab.research.google.com/github/ChrisW2420/FedDistill/blob/main/Pruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Prototype

## Import packages

In [1]:
!pip install -q tensorflow-model-optimization

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/242.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m30.7/242.5 kB[0m [31m794.6 kB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.9/242.5 kB[0m [31m1.1 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━[0m [32m204.8/242.5 kB[0m [31m2.0 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.sparsity import keras as sparsity
import tf_keras as keras
import tempfile
from keras.callbacks import EarlyStopping, Callback

In [None]:
!pip install wandb
import wandb
wandb.login()
from wandb.keras import WandbMetricsLogger

## Prepare Dataset

In [3]:
# Prepare the train and test dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Normalize data
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
validation_split = 0.1

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


## Models

In [4]:
def smallCNN():
  model = keras.Sequential(
      [
          keras.Input(shape=(28, 28, 1)),
          keras.layers.Conv2D(8, (3, 3), strides=(2, 2), padding="same"),
          keras.layers.ReLU(),
          keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
          keras.layers.Conv2D(8, (3, 3), strides=(2, 2), padding="same"),
          keras.layers.Flatten(),
          keras.layers.Dense(10),
      ],
      name="smallcnn",
  )
  return model

def mediumCNN():
  model = keras.Sequential(
      [
          keras.Input(shape=(28, 28, 1)),
          keras.layers.Conv2D(8, (3, 3), strides=(2, 2), padding="same"),
          keras.layers.ReLU(),
          keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
          keras.layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
          keras.layers.ReLU(),
          keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
          keras.layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
          keras.layers.Flatten(),
          keras.layers.Dense(10),
      ],
      name="mediumcnn",
  )
  return model

def bigCNN():
  model = keras.Sequential(
      [
          keras.Input(shape=(28, 28, 1)),
          keras.layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
          keras.layers.ReLU(),
          keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
          keras.layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
          keras.layers.ReLU(),
          keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
          keras.layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
          keras.layers.Flatten(),
          keras.layers.Dense(10),
      ],
      name="bigcnn",
  )
  return model

In [5]:
# early stopping when training converges on validation loss
early_stopping = EarlyStopping(
    monitor='val_loss',
    min_delta=0.001,  # only consider as improvement significant changes
    patience=2,      # number of epochs with no improvement after which training will be stopped
    verbose=1,
    mode='min'        # 'min' because we want to minimize the loss
)

def trainCNN_weights(model, _epoch):
  model.compile(
      optimizer='adam',
      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[keras.metrics.SparseCategoricalAccuracy()]
  )

  model.fit(x_train, y_train, epochs=_epoch,validation_split=validation_split, callbacks=[early_stopping])
  model.evaluate(x_test, y_test)

  return model

  # _, pretrained_weights = tempfile.mkstemp('.tf')

  # model.save_weights(pretrained_weights)

  # return pretrained_weights

#Pruning

### Basic implementation

In [6]:
# functions
def prune_finetrain(base_model, _epochs):
  callbacks = [
      sparsity.UpdatePruningStep(),
      early_stopping
  ]

  model_for_pruning = sparsity.prune_low_magnitude(base_model) #default constant sparsity of 50%
  model_for_pruning.summary()

  model_for_pruning.compile(
        optimizer='adam',
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy()]
  )

  model_for_pruning.fit(
      x_train,
      y_train,
      validation_split=validation_split,
      callbacks=callbacks,
      epochs=_epochs,
  )

  return model_for_pruning

def get_model_sparsity(model):
    total_weights = 0
    zero_weights = 0
    for weight in model.get_weights():
        total_weights += weight.size
        zero_weights += np.count_nonzero(weight == 0)
    return zero_weights / total_weights

def get_gzipped_model_size(model):
  # Returns size of gzipped model, in bytes.
  import os
  import zipfile

  _, keras_file = tempfile.mkstemp('.h5')
  model.save(keras_file, include_optimizer=False)

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(keras_file)

  return os.path.getsize(zipped_file)

In [7]:
# training base model

model = mediumCNN()
model = trainCNN_weights(model, 5)
_, pretrained_weights = tempfile.mkstemp('.tf')
model.save_weights(pretrained_weights)

# pruning
base_model = mediumCNN()
base_model.load_weights(pretrained_weights) # optional but recommended.
pruned_model = prune_finetrain(base_model, 5)

# continue training base model for performance comparison
base_model_copy = mediumCNN()
base_model_copy.load_weights(pretrained_weights)
base_model_copy = trainCNN_weights(base_model_copy, 5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
Model: "mediumcnn"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_conv2d  (None, 14, 14, 8)         154       
 _3 (PruneLowMagnitude)                                          
                                                                 
 prune_low_magnitude_re_lu_  (None, 14, 14, 8)         1         
 2 (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_max_po  (None, 14, 14, 8)         1         
 oling2d_2 (PruneLowMagnitu                                      
 de)                                                             
                                                                 
 prune_low_magnitude_conv2d  (None, 7, 7, 16)          2322      
 _4 (PruneLowMagnitude)                                          
       

In [8]:
pruned_model_stripped = sparsity.strip_pruning(pruned_model)
print("final model")
pruned_model.summary()

final model
Model: "mediumcnn"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_conv2d  (None, 14, 14, 8)         154       
 _3 (PruneLowMagnitude)                                          
                                                                 
 prune_low_magnitude_re_lu_  (None, 14, 14, 8)         1         
 2 (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_max_po  (None, 14, 14, 8)         1         
 oling2d_2 (PruneLowMagnitu                                      
 de)                                                             
                                                                 
 prune_low_magnitude_conv2d  (None, 7, 7, 16)          2322      
 _4 (PruneLowMagnitude)                                          
                                             

In [12]:
# compare accuracy, sparsity and file size
print(f"Base Model test accuracy: {base_model_copy.evaluate(x_test, y_test)[1]:.2f}%")
print(f"Pruned Model test accuracy: {pruned_model.evaluate(x_test, y_test)[1]:.2f}%")

sparsity_percentage = get_model_sparsity(base_model_copy) * 100
print(f"Base Model sparsity: {sparsity_percentage:.2f}%")

sparsity_percentage = get_model_sparsity(pruned_model) * 100
print(f"Pruned Model sparsity: {sparsity_percentage:.2f}%")

print('\n')
print("Size of gzipped base model: %.2f bytes" % (get_gzipped_model_size(base_model_copy)))
print("Size of gzipped pruned model: %.2f bytes" % (get_gzipped_model_size(pruned_model_stripped)))

Base Model test accuracy: 0.99%
Pruned Model test accuracy: 0.99%
Base Model sparsity: 0.00%
Pruned Model sparsity: 49.74%






Size of gzipped base model: 25653.00 bytes
Size of gzipped pruned model: 16903.00 bytes


## Pruning to as small as possible with > 99.5% accuracy as best

In [41]:
# function
# save good performance ones
class SaveSparseModelCallback(Callback):
    def __init__(self, model, path='best_sparse_model.h5'):
        super(SaveSparseModelCallback, self).__init__()
        self.model = model
        self.best_accuracy = 0.98
        self.path = path

    def on_train_batch_end(self, batch, logs=None):
        current_accuracy = logs.get('sparse_categorical_accuracy')
        if current_accuracy > self.best_accuracy:
            print(f"batch {batch + 1}: training accuracy {current_accuracy*100:.2f}% exceeds {self.best_accuracy*100:.2f}%. Saving model.")
            self.model = sparsity.strip_pruning(self.model)
            self.model.save(self.path)
            self.best_accuracy = current_accuracy


def prune_finetrain(base_model, _epochs):
  steps_per_epoch = len(x_train)*(1-validation_split) // batch_size
  print('steps_per_epoch: ', steps_per_epoch)
  pruning_schedule = sparsity.PolynomialDecay(initial_sparsity=0.30, final_sparsity=0.60,
                                              begin_step=0, end_step=3000) # increase sparsity


  model_for_pruning = sparsity.prune_low_magnitude(base_model
                                                  , pruning_schedule=pruning_schedule
                                                   )
  model_for_pruning.summary()


  callbacks = [
    sparsity.UpdatePruningStep(),
    sparsity.PruningSummaries(log_dir='/path/to/logs'),
    # SaveSparseModelCallback(pruned_model),
    # early_stopping
]


  model_for_pruning.compile(
        optimizer='adam',
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy()]
  )

  model_for_pruning.fit(
      x_train,
      y_train,
      batch_size=batch_size,
      validation_split=validation_split,
      callbacks=callbacks,
      epochs=_epochs,
  )

  return model_for_pruning

In [42]:
# adaptive pruning
base_model = mediumCNN()
base_model.load_weights(pretrained_weights) # optional but recommended.
adaptive_pruned_model = prune_finetrain(base_model, 3)

steps_per_epoch:  843.0
Model: "mediumcnn"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_conv2d  (None, 14, 14, 8)         154       
 _39 (PruneLowMagnitude)                                         
                                                                 
 prune_low_magnitude_re_lu_  (None, 14, 14, 8)         1         
 26 (PruneLowMagnitude)                                          
                                                                 
 prune_low_magnitude_max_po  (None, 14, 14, 8)         1         
 oling2d_26 (PruneLowMagnit                                      
 ude)                                                            
                                                                 
 prune_low_magnitude_conv2d  (None, 7, 7, 16)          2322      
 _40 (PruneLowMagnitude)                                         
                                 



Epoch 2/3
 162/1688 [=>............................] - ETA: 15s - loss: 0.0519 - sparse_categorical_accuracy: 0.9848

KeyboardInterrupt: 

In [28]:
# compare performance between adaptive pruning and basic pruning
print(f"adaptive Model test accuracy: {adaptive_pruned_model.evaluate(x_test, y_test)[1]:.2f}%")
print(f"pruned Model test accuracy: {pruned_model.evaluate(x_test, y_test)[1]:.2f}%")

sparsity_percentage = get_model_sparsity(adaptive_pruned_model) * 100
print(f"adaptive Model sparsity: {sparsity_percentage:.2f}%")

sparsity_percentage = get_model_sparsity(pruned_model) * 100
print(f"Pruned Model sparsity: {sparsity_percentage:.2f}%")

adaptive_pruned_model = sparsity.strip_pruning(adaptive_pruned_model)
print('\n')
print("Size of gzipped adaptive model: %.2f bytes" % (get_gzipped_model_size(adaptive_pruned_model)))
print("Size of gzipped pruned model: %.2f bytes" % (get_gzipped_model_size(pruned_model_stripped)))





Size of gzipped adaptive model: 14595.00 bytes
Size of gzipped pruned model: 16903.00 bytes
