In [None]:
import time
from copy import deepcopy
from jax import numpy as jnp
from flax import linen as nn

from modules.MNIST import MNIST
from modules.trainer import TrainerModule
from modules.pruner import *

MNIST dataset

In [None]:
# Unpack the three (images, labels) pairs returned by the project's MNIST loader.
(
  (train_images, train_labels),
  (val_images, val_labels),
  (test_images, test_labels),
) = MNIST("../MNIST_DATASET")

# Append a trailing axis of size 1 to every image array (the single grayscale channel).
train_images, val_images, test_images = (
  jnp.expand_dims(split_images, axis=-1)
  for split_images in (train_images, val_images, test_images)
)

In [None]:
# Initializer factory shared by all layers below; each layer calls init() to get its kernel initializer.
init = nn.initializers.xavier_normal


class CNN(nn.Module):
  """Small CNN: two (3x3 conv -> ReLU -> 2x2 average pool) stages, then a 256-unit
  dense layer with ReLU and a final 10-unit dense layer.

  Attributes:
    out_channels: mapping with the keys 'Conv_0' and 'Conv_1' giving the number of
      output features of the first and second convolution respectively.
  """
  # Held as a configurable mapping because these counts have to change during channel pruning.
  out_channels: dict

  @nn.compact
  def __call__(self, x):
    # Both convolutional stages are identical apart from their feature count, so
    # build them in order from the two keys of `out_channels`.
    for conv_key in ('Conv_0', 'Conv_1'):
      x = nn.Conv(features=self.out_channels[conv_key], kernel_size=(3, 3), kernel_init=init())(x)
      x = nn.avg_pool(nn.relu(x), window_shape=(2, 2), strides=(2, 2))

    # Collapse everything except the leading axis (presumably the batch axis) into one vector.
    flat = x.reshape((x.shape[0], -1))
    hidden = nn.relu(nn.Dense(features=256, kernel_init=init())(flat))
    # The 10 outputs are returned as-is; no activation is applied to them here.
    return nn.Dense(features=10, kernel_init=init())(hidden)

# Trainer wrapping the unpruned ("original") CNN. Positional arguments, in order:
# the model class, its constructor kwargs, the string "adam", the value 1e-3 and an
# all-ones array of shape (10, 28, 28, 1).
# NOTE(review): the last three are presumably optimizer name, learning rate and an
# example input used for initialization; confirm against TrainerModule's signature.
org_tm_cnn = TrainerModule(
  CNN,
  {'out_channels': {'Conv_0': 32, 'Conv_1': 64}},
  "adam",
  1e-3,
  jnp.ones((10, 28, 28, 1)),
)

In [None]:
# Train the unpruned model for 25 epochs using the train and validation splits;
# the value returned by train() is kept in `best_state`.
best_state = org_tm_cnn.train(
  train_data=(train_images, train_labels),
  val_data=(val_images, val_labels),
  num_epochs=25,
)

In [None]:
# Baseline figures for the unpruned model: test-split accuracy and model size.
dense_model_acc = org_tm_cnn.test(test_data=(test_images, test_labels))
dense_model_size = org_tm_cnn.get_model_size()

# The accuracy is multiplied by 100 for display, i.e. it is treated as a fraction here.
print(f"Dense model accuracy: {dense_model_acc*100:.2f}%")
print(f"Dense model size: {dense_model_size}")

Distribution of weight values

In [None]:
# Build the weight-distribution figure for the unpruned model, keep a handle to it,
# and show it explicitly.
org_wt_dist_fig = org_tm_cnn.plot_weight_distribution()
org_wt_dist_fig.show()

Fine-grained pruning: sensitivity scan

In [None]:
# Work on a deep copy so the fine-grained pruning experiments never touch org_tm_cnn.
fp_tm_cnn = deepcopy(org_tm_cnn)

# Run the sensitivity scan against the test split, then plot the returned
# sparsities against the returned accuracies.
accuracies, sparsities = sensitivity_scan(
  fp_tm_cnn,
  test_data=(test_images, test_labels),
  verbose=False,
)
plot_sensitivity_scan(sparsities, accuracies)

Number of parameters in each layer

Fine-grained pruning: applying per-layer sparsity

In [None]:
# Target sparsity for each layer, keyed by layer name.
sparsity_dict = {
  'Conv_0': 0.6,
  'Conv_1': 0.7,
  'Dense_0': 0.9,
  'Dense_1': 0.8,
}

# Prune the copied model according to the per-layer targets.
fgpruner = FineGrainedPruner(fp_tm_cnn, sparsity_dict)
fgpruner.apply()

# The "before pruning" figures are read from org_tm_cnn, which is a separate object
# from the deep copy fp_tm_cnn that was handed to the pruner above.
print(f"Model sparsity before pruning: {org_tm_cnn.get_model_sparsity()}")
print(f"Model size before pruning: {org_tm_cnn.get_model_size()}")
print(f"Model accuracy before pruning: {org_tm_cnn.test(test_data=(test_images, test_labels))}")

# The "after pruning" figures come from the pruned copy.
print(f"Model sparsity after pruning: {fp_tm_cnn.get_model_sparsity()}")
print(f"Model size after pruning: {fp_tm_cnn.get_model_size()}")
print(f"Model accuracy after pruning: {fp_tm_cnn.test(test_data=(test_images, test_labels))}")

# Weight distribution of the pruned copy, requested with count_nonzero_only=True.
wt_dist_fp_fig = fp_tm_cnn.plot_weight_distribution(count_nonzero_only=True)
wt_dist_fp_fig.show()

Fine-tuning the pruned model

In [None]:
# Fine-tune the pruned copy for 5 epochs. fgpruner.apply is handed to train() as a
# callback.
# NOTE(review): presumably train() invokes it so the pruning is re-applied while the
# weights are updated; confirm where TrainerModule.train calls its callbacks.
best_state = fp_tm_cnn.train(
  train_data=(train_images, train_labels),
  val_data=(val_images, val_labels),
  num_epochs=5,
  callbacks=[fgpruner.apply],
  verbose=False,
)

print(f"Pruned model sparsity after finetuning: {fp_tm_cnn.get_model_sparsity()}")
print(f"Pruned model size after finetuning: {fp_tm_cnn.get_model_size()}")
print(f"Pruned model accuracy after finetuning: {fp_tm_cnn.test(test_data=(test_images, test_labels))}")

# Bind the figure and show it explicitly, as the other plotting cells of this
# notebook do: a bare call in the middle of a cell discards the returned figure,
# because only a cell's last expression is displayed automatically.
wt_dist_fp_ft_fig = fp_tm_cnn.plot_weight_distribution(count_nonzero_only=True)
wt_dist_fp_ft_fig.show()

# Draw the pruned model's parameter counts (red) first, starting without an existing
# figure (fig=None), then pass the returned figure on so the original model's counts
# (blue) are added to it.
fig = fp_tm_cnn.plot_num_parameters(count_nonzero_only=True, color='red', fig=None)
fig = org_tm_cnn.plot_num_parameters(count_nonzero_only=True, color='blue', fig=fig)
fig.show()

Channel Pruning

In [None]:
# Channel pruning also starts from its own deep copy of the unpruned model.
cp_tm_cnn = deepcopy(org_tm_cnn)

cpruner = ChannelPruner(cp_tm_cnn)
cpruner.apply(prune_ratio=0.8)

# Rebuild the train state from the model's apply function, the current parameters
# and the existing optimizer.
# NOTE(review): presumably required because pruning changed the parameters the old
# train state was built for; confirm in TrainerModule.init_train_state.
cp_tm_cnn.init_train_state(
  cp_tm_cnn.model.apply,
  cp_tm_cnn.state.params,
  cp_tm_cnn.tx,
)
print(f"Pruned model size: {cp_tm_cnn.get_model_size()}")
print(f"Pruned model accuracy: {cp_tm_cnn.test(test_data=(test_images, test_labels))}")

# Fine-tune the channel-pruned model for 5 epochs and report the same figures again.
best_state = cp_tm_cnn.train(
  train_data=(train_images, train_labels),
  val_data=(val_images, val_labels),
  num_epochs=5,
  verbose=False,
)
print(f"Pruned model size after finetuning: {cp_tm_cnn.get_model_size()}")
print(f"Pruned model accuracy after finetuning: {cp_tm_cnn.test(test_data=(test_images, test_labels))}")

# Pruned model in red first (no existing figure), then the original model in blue
# on the figure returned by the first call.
fig = cp_tm_cnn.plot_num_parameters(count_nonzero_only=True, color='red', fig=None)
fig = org_tm_cnn.plot_num_parameters(count_nonzero_only=True, color='blue', fig=fig)
fig.show()

Computational improvements

In [None]:
def measure_latency(tm_cnn: TrainerModule, dummy_inp, n_warmup= 20, n_test = 1000):
  """Return the average wall-clock time per sample of `tm_cnn.test`, in microseconds.

  Args:
    tm_cnn: trainer whose `test(dummy_inp)` method is timed.
    dummy_inp: value forwarded unchanged to `tm_cnn.test`; `len(dummy_inp[1])` is
      taken as the number of samples processed by one call.
    n_warmup: number of untimed calls made before the measurement starts.
    n_test: number of timed calls the measurement is averaged over.

  Returns:
    Elapsed time divided by `n_test` and by the number of samples, times 1e6.

  Raises:
    ValueError: if `n_test` is not positive or `dummy_inp[1]` is empty, since the
      average would be undefined (or negative) in those cases.
  """
  if n_test <= 0:
    raise ValueError(f"n_test must be positive, got {n_test}")
  n_samples = len(dummy_inp[1])
  if n_samples == 0:
    raise ValueError("dummy_inp[1] is empty; cannot compute a per-sample latency")

  def _timed_call():
    result = tm_cnn.test(dummy_inp)
    # JAX dispatches work asynchronously, so a returned array may not be computed
    # yet. Wait for it when the result offers block_until_ready(); plain Python
    # values have no such method and are left alone.
    wait = getattr(result, "block_until_ready", None)
    if callable(wait):
      wait()

  # Untimed calls first, so one-off start-up costs are kept out of the measurement.
  for _ in range(n_warmup):
    _timed_call()

  start = time.perf_counter()
  for _ in range(n_test):
    _timed_call()
  elapsed = time.perf_counter() - start

  # seconds per call -> seconds per sample -> microseconds per sample
  rt_us = ((elapsed / n_test) / n_samples) * 1e6
  return rt_us

# One summary row per model: size, test accuracy and measured per-sample latency.
# Each row is a single string; it is only split across adjacent f-string literals
# to keep the source lines short.
print(
  f"Org Model            | Size: {org_tm_cnn.get_model_size()}   "
  f"| Acc: {org_tm_cnn.test(test_data=(test_images, test_labels)):.3f} "
  f"| Latency: {measure_latency(org_tm_cnn, (test_images, test_labels)):.2f} us"
)
print(
  f"Channel Pruned Model | Size: {cp_tm_cnn.get_model_size()} "
  f"| Acc: {cp_tm_cnn.test(test_data=(test_images, test_labels)):.3f} "
  f"| Latency: {measure_latency(cp_tm_cnn, (test_images, test_labels)):.2f} us"
)