In [None]:
import tensorflow as tf
import numpy as np
import datetime
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import h5py
import math
import heapq
from collections import Counter
import pickle
import gzip

from tensorflow_model_optimization.python.core.keras.compat import keras
import tensorflow_model_optimization as tfmot

from tensorflow.keras import Model, Sequential
from tensorflow.keras.utils import to_categorical, plot_model
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, Dropout, Layer, Add, MaxPool2D
from tensorflow.keras.layers import Input, Activation, Concatenate, Convolution2D, GlobalAveragePooling2D

from keras.callbacks import TensorBoard

from tensorflow.keras.optimizers import SGD
from tensorflow.keras.callbacks import LearningRateScheduler, ModelCheckpoint
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import regularizers
from tensorflow.keras.models import clone_model

We load the training and validation images and set up the data generators.

In [None]:
# Dataset locations (relative to the notebook's working directory) and the
# mini-batch size shared by every generator built below.
train_dir, val_dir = (
    '../imagenet2/imagenet2/train',
    '../imagenet2/imagenet2/val',
)

batch_size = 64
def random_crop_preprocessing(img, crop_size=224, mean=(104, 117, 123)):
    """Training preprocessing: random square crop followed by mean subtraction.

    Intended as ``ImageDataGenerator(preprocessing_function=...)``. Keras calls
    it with a single image, so the extra parameters keep their defaults there;
    they exist to make the function reusable and testable.

    Args:
        img: image array of shape (H, W, C), channels last.
        crop_size: side length of the square crop; must not exceed H or W.
        mean: per-channel values subtracted from the crop.

    Returns:
        The (crop_size, crop_size, C) crop with ``mean`` subtracted.

    Raises:
        ValueError: if the image is smaller than ``crop_size`` in either
            dimension. Previously this surfaced as an opaque
            ``np.random.randint`` error.

    Notes:
        ImageDataGenerator expects the preprocessing function to return an
        array of the same shape it received. With ``target_size=(224, 224)``
        the image is already 224x224 here, so the "random" crop always covers
        the whole image and only the mean subtraction has an effect.
        NOTE(review): (104, 117, 123) is the common ImageNet mean in BGR order,
        while Keras loads images as RGB by default. Confirm this matches how
        the trained weights were produced before changing it.
    """
    h, w, _ = img.shape
    if h < crop_size or w < crop_size:
        raise ValueError(
            f"image of size {h}x{w} is smaller than crop_size={crop_size}")
    # Same draw order as before (top, then left) so seeded runs reproduce.
    top = np.random.randint(0, h - crop_size + 1)
    left = np.random.randint(0, w - crop_size + 1)
    cropped = img[top:top + crop_size, left:left + crop_size, :]
    return cropped - np.asarray(mean, dtype=np.float32)

# Training pipeline: random horizontal flips plus random_crop_preprocessing
# (crop + per-channel mean subtraction). rescale=1. multiplies by one, so it
# leaves pixel values unchanged.
train_datagen = ImageDataGenerator(
    horizontal_flip=True,
    preprocessing_function=random_crop_preprocessing,
    rescale=1.,
)

def center_crop_preprocessing(img, crop_size=224, mean=(104, 117, 123)):
    """Evaluation preprocessing: centered square crop followed by mean subtraction.

    Intended as ``ImageDataGenerator(preprocessing_function=...)``. Keras calls
    it with a single image, so the extra parameters keep their defaults there.

    Args:
        img: image array of shape (H, W, C), channels last.
        crop_size: side length of the square crop; must not exceed H or W.
        mean: per-channel values subtracted from the crop.

    Returns:
        The (crop_size, crop_size, C) center crop with ``mean`` subtracted.

    Raises:
        ValueError: if the image is smaller than ``crop_size`` in either
            dimension. Previously the negative offset made the slice start
            from the end of the axis, silently returning a wrongly sized
            strip taken from the image edge.

    Notes:
        With ``target_size=(224, 224)`` the crop covers the whole image, so
        only the mean subtraction has an effect.
        NOTE(review): (104, 117, 123) is the common ImageNet mean in BGR order,
        while Keras loads images as RGB by default. Keep it consistent with
        the training preprocessing and with how the weights were trained.
    """
    h, w, _ = img.shape
    if h < crop_size or w < crop_size:
        raise ValueError(
            f"image of size {h}x{w} is smaller than crop_size={crop_size}")
    top = (h - crop_size) // 2
    left = (w - crop_size) // 2
    cropped = img[top:top + crop_size, left:left + crop_size, :]
    return cropped - np.asarray(mean, dtype=np.float32)

# Validation pipeline: deterministic center crop + mean subtraction, no
# augmentation. rescale=1. leaves pixel values unchanged.
val_datagen = ImageDataGenerator(
    preprocessing_function=center_crop_preprocessing,
    rescale=1.,
)


# All iterators resize images to 224x224 and yield one-hot labels.
train_generator = train_datagen.flow_from_directory(
    train_dir,
    class_mode='categorical',
    batch_size=batch_size,
    target_size=(224, 224),
)

# Shuffled validation iterator, passed as validation_data to fit().
val_generator = val_datagen.flow_from_directory(
    val_dir,
    class_mode='categorical',
    batch_size=batch_size,
    target_size=(224, 224),
)
# Unshuffled copy: predictions come back in the same order as
# val_generator1.classes, which the accuracy cells rely on.
val_generator1 = val_datagen.flow_from_directory(
    val_dir,
    class_mode='categorical',
    batch_size=batch_size,
    target_size=(224, 224),
    shuffle=False,
)

Found 9559 images belonging to 10 classes.
Found 3963 images belonging to 10 classes.
Found 3963 images belonging to 10 classes.


We define the SqueezeNet architecture as a function that builds the Keras model.

In [None]:
def _fire_module(x, fire_id, squeeze_filters, expand_filters):
    """Append one SqueezeNet "fire" module to tensor ``x``.

    The module is a 1x1 "squeeze" convolution feeding two parallel "expand"
    convolutions (1x1 and 3x3). Their outputs are concatenated along the
    channel axis. Layers are created in the order squeeze, expand1, expand2,
    concatenate, and named ``fire{fire_id}_squeeze`` / ``_expand1`` /
    ``_expand2``. Other cells look layers up by these names, so keep them
    stable.

    Args:
        x: input Keras tensor (channels_last).
        fire_id: integer used in the layer names.
        squeeze_filters: number of filters of the 1x1 squeeze convolution.
        expand_filters: number of filters of EACH expand convolution, so the
            module outputs ``2 * expand_filters`` channels.

    Returns:
        The concatenated output tensor.
    """
    squeeze = Convolution2D(
        squeeze_filters, (1, 1), activation='relu', kernel_initializer='glorot_uniform',
        padding='same', name=f'fire{fire_id}_squeeze',
        data_format="channels_last")(x)
    expand1 = Convolution2D(
        expand_filters, (1, 1), activation='relu', kernel_initializer='glorot_uniform',
        padding='same', name=f'fire{fire_id}_expand1',
        data_format="channels_last")(squeeze)
    expand2 = Convolution2D(
        expand_filters, (3, 3), activation='relu', kernel_initializer='glorot_uniform',
        padding='same', name=f'fire{fire_id}_expand2',
        data_format="channels_last")(squeeze)
    return Concatenate(axis=-1)([expand1, expand2])


def SqueezeNet(nb_classes, inputs=(224, 224,3)):
    """Build an (uncompiled) SqueezeNet classifier.

    Layout: conv1 (96 filters, 7x7, stride 2) -> maxpool1 -> fire2..fire4 ->
    maxpool4 -> fire5..fire8 -> maxpool8 -> fire9 -> dropout(0.5) ->
    conv10 (1x1, ``nb_classes`` filters) -> global average pooling -> softmax.
    The filter counts match the SqueezeNet v1.0 configuration.

    The graph (layer names, connections, concatenation order) is the same as
    the earlier hand-unrolled version, so weight files saved from that version
    load into this one.

    Args:
        nb_classes: number of output classes.
        inputs: input shape (H, W, C), channels last.

    Returns:
        A ``keras.Model`` mapping images to class probabilities.
    """
    input_img = Input(shape=inputs)
    conv1 = Convolution2D(
        96, (7, 7), activation='relu', kernel_initializer='glorot_uniform',
        strides=(2, 2), padding='same', name='conv1',
        data_format="channels_last")(input_img)
    x = MaxPooling2D(
        pool_size=(3, 3), strides=(2, 2), name='maxpool1',
        data_format="channels_last")(conv1)

    x = _fire_module(x, 2, squeeze_filters=16, expand_filters=64)
    x = _fire_module(x, 3, squeeze_filters=16, expand_filters=64)
    x = _fire_module(x, 4, squeeze_filters=32, expand_filters=128)
    x = MaxPooling2D(
        pool_size=(3, 3), strides=(2, 2), name='maxpool4',
        data_format="channels_last")(x)

    x = _fire_module(x, 5, squeeze_filters=32, expand_filters=128)
    x = _fire_module(x, 6, squeeze_filters=48, expand_filters=192)
    x = _fire_module(x, 7, squeeze_filters=48, expand_filters=192)
    x = _fire_module(x, 8, squeeze_filters=64, expand_filters=256)
    x = MaxPooling2D(
        pool_size=(3, 3), strides=(2, 2), name='maxpool8',
        data_format="channels_last")(x)

    x = _fire_module(x, 9, squeeze_filters=64, expand_filters=256)
    x = Dropout(0.5, name='fire9_dropout')(x)
    # 1x1 classifier convolution (with ReLU), one output map per class.
    conv10 = Convolution2D(
        nb_classes, (1, 1), activation='relu', kernel_initializer='glorot_uniform',
        padding='valid', name='conv10',
        data_format="channels_last")(x)

    global_avgpool10 = GlobalAveragePooling2D(data_format='channels_last')(conv10)
    softmax = Activation("softmax", name='softmax')(global_avgpool10)
    return Model(inputs=input_img, outputs=softmax)

We build the SqueezeNet model and load its trained weights.

In [None]:
# Rebuild the architecture, then restore the previously trained weights.
num_classes = 10
model = SqueezeNet(nb_classes=num_classes, inputs=(224, 224, 3))
model.load_weights('Squeeze_net.h5')
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv1 (Conv2D)                 (None, 112, 112, 96  14208       ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 maxpool1 (MaxPooling2D)        (None, 55, 55, 96)   0           ['conv1[0][0]']                  
                                                                                              

In [None]:
# Save a copy of the baseline (unpruned) weights under a separate file name.
model.save_weights('trained_model.weights.h5')

We check the accuracy of the baseline (unpruned) model.

In [None]:
# Top-1 / top-5 accuracy of the baseline model on the whole validation set.
# val_generator1 is the unshuffled validation iterator, so the rows of
# `predictions` follow the order of val_generator1.classes.
# len(val_generator1) counts the final partial batch (ceil division). The
# former `samples // batch_size` silently left the last
# `samples % batch_size` validation images out of the metrics.
val_steps = len(val_generator1)

predictions = model.predict(val_generator1, steps=val_steps)

true_labels = val_generator1.classes[:len(predictions)]
assert len(true_labels) == len(predictions) == val_generator1.samples, \
    "predictions do not cover the whole validation set"

num_classes = len(val_generator1.class_indices)
true_labels_one_hot = tf.keras.utils.to_categorical(true_labels, num_classes)

# Top-1: argmax equals the label. Top-5: label is among the 5 highest scores.
top1_metric = tf.keras.metrics.CategoricalAccuracy()
top5_metric = tf.keras.metrics.TopKCategoricalAccuracy(k=5)

top1_metric.update_state(true_labels_one_hot, predictions)
top5_metric.update_state(true_labels_one_hot, predictions)

print("Top-1 Accuracy on validation set:", top1_metric.result().numpy())
print("Top-5 Accuracy on validation set:", top5_metric.result().numpy())

Top-1 Accuracy on validation set: 0.77715164
Top-5 Accuracy on validation set: 0.9695184


We prune the model using the per-layer sparsity values indicated in the paper, then fine-tune the remaining weights.

In [None]:
# Fine-tuning schedule for the pruned model: full passes over the training set.
total_epochs = 10
steps_per_epoch = train_generator.samples // batch_size

# Target fraction of zeroed weights per layer, keyed by layer name.
# Layers not listed here are left dense.
desired_sparsities = dict(
    fire2_expand2=0.66,
    fire3_expand2=0.66,
    fire4_expand2=0.66,
    fire5_expand2=0.66,
    fire6_expand1=0.5,
    fire6_expand2=0.66,
    fire7_squeeze=0.5,
    fire7_expand2=0.66,
    fire8_expand1=0.5,
    fire8_expand2=0.66,
    fire9_squeeze=0.5,
    fire9_expand2=0.7,
    conv10=0.8,
)

def custom_clone(layer):
    """``clone_function`` for ``clone_model``: wrap selected layers for pruning.

    A Conv2D or Dense layer whose name is a key of ``desired_sparsities`` is
    wrapped with ``prune_low_magnitude``. It gets a ``ConstantSparsity``
    schedule at that layer's target sparsity, pruning every 100 steps between
    step 0 and step 1000. Every other layer is returned unchanged.

    Layer names are matched exactly. The previous substring test
    (``key in layer.name``) would, for example, also have wrapped "conv10"
    for a key "conv1". With overlapping keys, the first match in dict order
    would win.

    NOTE(review): returning the original ``layer`` object (not a copy) follows
    the tfmot examples. It presumably means the clone shares these layers, and
    their variables, with the source model. Confirm before reusing the source
    model after fine-tuning.
    """
    if not isinstance(layer, (tf.keras.layers.Conv2D, tf.keras.layers.Dense)):
        return layer
    final_sparsity = desired_sparsities.get(layer.name)
    if final_sparsity is None:
        return layer
    print(f"Wrapping layer {layer.name} with final_sparsity={final_sparsity}")
    pruning_schedule = tfmot.sparsity.keras.ConstantSparsity(
        target_sparsity=final_sparsity,
        begin_step=0,
        end_step=1000,
        frequency=100,
    )
    return tfmot.sparsity.keras.prune_low_magnitude(
        layer, pruning_schedule=pruning_schedule)

# Rebuild the baseline graph, letting custom_clone wrap the selected layers
# with prune_low_magnitude. All other layers are passed through as-is.
# NOTE(review): custom_clone returns the original layer objects, so un-wrapped
# layers are presumably shared with `model`. Fine-tuning below may then also
# change `model`'s weights; confirm if `model` is reused later.
pruned_model = clone_model(model, clone_function=custom_clone)
# Likely a no-op: the cloned functional model is already built from its Input.
pruned_model.build(model.input_shape)
pruned_model.summary()

# Fine-tune with a small learning rate after pruning.
pruned_model.compile(optimizer=tf.keras.optimizers.Adam(0.0001),
                     loss='categorical_crossentropy',
                     metrics=['accuracy'])


# UpdatePruningStep advances the pruning step counter during training.
# tfmot requires it when fitting a model that contains pruning wrappers.
pruning_callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep()
]

# NOTE(review): validation_steps uses floor division, so the last partial
# validation batch is not evaluated during fit().
pruned_model.fit(
    train_generator,
    epochs=total_epochs,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_generator,
    validation_steps=val_generator.samples // batch_size,
    callbacks=pruning_callbacks
)

# Remove the pruning wrappers. What remains are plain layers whose pruned
# weights are zero.
pruned_model_stripped = tfmot.sparsity.keras.strip_pruning(pruned_model)
pruned_model_stripped.summary()


Wrapping layer fire2_expand2 with final_sparsity=0.66
Wrapping layer fire3_expand2 with final_sparsity=0.66
Wrapping layer fire4_expand2 with final_sparsity=0.66
Wrapping layer fire5_expand2 with final_sparsity=0.66
Wrapping layer fire6_expand1 with final_sparsity=0.5
Wrapping layer fire6_expand2 with final_sparsity=0.66
Wrapping layer fire7_squeeze with final_sparsity=0.5
Wrapping layer fire7_expand2 with final_sparsity=0.66
Wrapping layer fire8_expand1 with final_sparsity=0.5
Wrapping layer fire8_expand2 with final_sparsity=0.66
Wrapping layer fire9_squeeze with final_sparsity=0.5
Wrapping layer fire9_expand2 with final_sparsity=0.7
Wrapping layer conv10 with final_sparsity=0.8
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 224, 224, 3  0           []                               

We check the accuracy of the pruned model

In [None]:
# Top-1 / top-5 accuracy of the pruned model on the whole validation set.
# val_generator1 is the unshuffled validation iterator, so the rows of
# `predictions` follow the order of val_generator1.classes.
# len(val_generator1) counts the final partial batch (ceil division). The
# former `samples // batch_size` silently left the last
# `samples % batch_size` validation images out of the metrics.
val_steps = len(val_generator1)

predictions = pruned_model_stripped.predict(val_generator1, steps=val_steps)

true_labels = val_generator1.classes[:len(predictions)]
assert len(true_labels) == len(predictions) == val_generator1.samples, \
    "predictions do not cover the whole validation set"

num_classes = len(val_generator1.class_indices)
true_labels_one_hot = tf.keras.utils.to_categorical(true_labels, num_classes)

# Top-1: argmax equals the label. Top-5: label is among the 5 highest scores.
top1_metric = tf.keras.metrics.CategoricalAccuracy()
top5_metric = tf.keras.metrics.TopKCategoricalAccuracy(k=5)

top1_metric.update_state(true_labels_one_hot, predictions)
top5_metric.update_state(true_labels_one_hot, predictions)

print("Top-1 Accuracy on validation set:", top1_metric.result().numpy())
print("Top-5 Accuracy on validation set:", top5_metric.result().numpy())

Top-1 Accuracy on validation set: 0.77459013
Top-5 Accuracy on validation set: 0.9710553


We verify that the model has been pruned by checking the sparsity (fraction of zero weights) of each layer.

In [None]:
# For every parameter tensor of the stripped pruned model, report how many
# entries are exactly zero.
for layer in pruned_model_stripped.layers:
    weights = layer.get_weights()
    if not weights:
        continue  # parameter-free layer (pooling, concatenation, ...)
    print(f"Layer: {layer.name}")
    for idx, w in enumerate(weights):
        total = np.prod(w.shape)
        nonzeros = np.count_nonzero(w)
        sparsity = 1 - nonzeros / total
        print(f"  Weight {idx}: shape={w.shape}, total={total}, nonzeros={nonzeros}, sparsity={sparsity:.2%}")


Layer: conv1
  Weight 0: shape=(7, 7, 3, 96), total=14112, nonzeros=14112, sparsity=0.00%
  Weight 1: shape=(96,), total=96, nonzeros=96, sparsity=0.00%
Layer: fire2_squeeze
  Weight 0: shape=(1, 1, 96, 16), total=1536, nonzeros=1536, sparsity=0.00%
  Weight 1: shape=(16,), total=16, nonzeros=16, sparsity=0.00%
Layer: fire2_expand1
  Weight 0: shape=(1, 1, 16, 64), total=1024, nonzeros=1024, sparsity=0.00%
  Weight 1: shape=(64,), total=64, nonzeros=64, sparsity=0.00%
Layer: fire2_expand2
  Weight 0: shape=(3, 3, 16, 64), total=9216, nonzeros=3133, sparsity=66.00%
  Weight 1: shape=(64,), total=64, nonzeros=64, sparsity=0.00%
Layer: fire3_squeeze
  Weight 0: shape=(1, 1, 128, 16), total=2048, nonzeros=2048, sparsity=0.00%
  Weight 1: shape=(16,), total=16, nonzeros=16, sparsity=0.00%
Layer: fire3_expand1
  Weight 0: shape=(1, 1, 16, 64), total=1024, nonzeros=1024, sparsity=0.00%
  Weight 1: shape=(64,), total=64, nonzeros=64, sparsity=0.00%
Layer: fire3_expand2
  Weight 0: shape=(3, 3,

We save the pruned model

In [None]:
# Save the stripped pruned model twice: as a TensorFlow SavedModel directory
# (reloaded below with tf.keras.models.load_model) and as a weights-only file.
pruned_model_stripped.save("pruned_squeezenet_model", save_format="tf")
pruned_model_stripped.save_weights('Squeeze_net_pruned.weights.h5')





INFO:tensorflow:Assets written to: pruned_squeezenet_model\assets


INFO:tensorflow:Assets written to: pruned_squeezenet_model\assets


In [None]:
# Reload the pruned model from its SavedModel directory. This reloaded copy is
# the input to weight clustering.
loaded_model = tf.keras.models.load_model("pruned_squeezenet_model")
loaded_model.summary()





Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv1 (Conv2D)                 (None, 112, 112, 96  14208       ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 maxpool1 (MaxPooling2D)        (None, 55, 55, 96)   0           ['conv1[0][0]']                  
                                                                                              

We quantize the weights by clustering each layer's weights into $2^6 = 64$ shared values (6-bit weight sharing), then fine-tune the model.

In [None]:
batch_size=64
# Steps per epoch must come from the TRAINING set, because fit() iterates
# train_generator. It was previously computed from val_generator.samples, so
# each clustering fine-tuning epoch saw only a fraction of the training data.
steps_per_epoch = train_generator.samples // batch_size
n_cluster=64

clustering_params = {
    'number_of_clusters': n_cluster,
}


# NOTE(review): plain cluster_weights does not pin pruned (zero) weights to
# zero, so the sparsity from the pruning step may not survive clustering.
# tfmot documents a sparsity-preserving clustering variant
# (preserve_sparsity=True in its experimental clustering API). Consider using
# it, and re-run the per-layer sparsity check on the clustered model.
clustered_model = tfmot.clustering.keras.cluster_weights(loaded_model, **clustering_params)

clustered_model.compile(
    optimizer=tf.keras.optimizers.Adam(0.0001),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)


class UpdateClusterStepCallback(tf.keras.callbacks.Callback):
    """After every training batch, call ``update_cluster_step()`` on each model
    layer that defines it. Layers without that method are skipped.

    NOTE(review): nothing in this notebook shows that a tfmot clustering
    wrapper exposes ``update_cluster_step``. If none does, this callback has
    no effect.
    """

    def on_batch_end(self, batch, logs=None):
        steppable = (layer for layer in self.model.layers
                     if hasattr(layer, 'update_cluster_step'))
        for layer in steppable:
            layer.update_cluster_step()

# Callback list for fine-tuning the clustered model (see the NOTE on
# UpdateClusterStepCallback about whether it has any effect).
clustering_callbacks = [UpdateClusterStepCallback()]


# Fine-tune the clustered model for 10 epochs. steps_per_epoch and batch_size
# come from the clustering setup cell. NOTE(review): validation_steps uses
# floor division, so the last partial validation batch is skipped.
clustered_model.fit(
    train_generator,
    epochs=10,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_generator,
    validation_steps=val_generator.samples // batch_size,
    callbacks=clustering_callbacks
)

# Remove the clustering wrappers, leaving plain layers that hold the
# clustered weight values.
final_model = tfmot.clustering.keras.strip_clustering(clustered_model)

final_model.summary()


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv1 (Conv2D)                 (None, 112, 112, 96  14208       ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 maxpool1 (MaxPooling2D)        (None, 55, 55, 96)   0           ['conv1[0][0]']  

We verify that the model has been quantized by checking that each layer has at most 64 distinct weights

In [None]:
def verify_quantized_weights(model, max_unique=n_cluster):
    """Print, for each parameter tensor of ``model``, how many distinct values
    it holds, and whether that count is within ``max_unique``.

    Layers without parameters produce no output. The function only reports;
    it returns ``None``. Note that not every tensor is clustered (e.g. bias
    vectors may exceed the bound legitimately).
    """
    for layer in model.layers:
        for idx, weight in enumerate(layer.get_weights()):
            num_unique = np.unique(weight).size
            print(f"Layer '{layer.name}', weight index {idx}: {num_unique} unique values")
            if num_unique <= max_unique:
                print(f"  OK: {num_unique} <= {max_unique} unique values.")
            else:
                print(f"  Warning: {num_unique} > {max_unique} unique values!")

# Check every tensor of the clustered model against the number of clusters.
verify_quantized_weights(final_model, n_cluster)


Layer 'conv1', weight index 0: 64 unique values
  OK: 64 <= 64 unique values.
Layer 'conv1', weight index 1: 96 unique values
Layer 'fire2_squeeze', weight index 0: 64 unique values
  OK: 64 <= 64 unique values.
Layer 'fire2_squeeze', weight index 1: 16 unique values
  OK: 16 <= 64 unique values.
Layer 'fire2_expand1', weight index 0: 64 unique values
  OK: 64 <= 64 unique values.
Layer 'fire2_expand1', weight index 1: 64 unique values
  OK: 64 <= 64 unique values.
Layer 'fire2_expand2', weight index 0: 63 unique values
  OK: 63 <= 64 unique values.
Layer 'fire2_expand2', weight index 1: 64 unique values
  OK: 64 <= 64 unique values.
Layer 'fire3_squeeze', weight index 0: 64 unique values
  OK: 64 <= 64 unique values.
Layer 'fire3_squeeze', weight index 1: 16 unique values
  OK: 16 <= 64 unique values.
Layer 'fire3_expand1', weight index 0: 64 unique values
  OK: 64 <= 64 unique values.
Layer 'fire3_expand1', weight index 1: 64 unique values
  OK: 64 <= 64 unique values.
Layer 'fire3_e

We check the accuracy of the pruned and quantized model

In [None]:
# Top-1 / top-5 accuracy of the pruned + clustered model on the whole
# validation set. val_generator1 is unshuffled, so the rows of `predictions`
# follow the order of val_generator1.classes.
# len(val_generator1) counts the final partial batch (ceil division). The
# former `samples // batch_size` silently left the last
# `samples % batch_size` validation images out of the metrics.
val_steps = len(val_generator1)

predictions = final_model.predict(val_generator1, steps=val_steps)

true_labels = val_generator1.classes[:len(predictions)]
assert len(true_labels) == len(predictions) == val_generator1.samples, \
    "predictions do not cover the whole validation set"

num_classes = len(val_generator1.class_indices)
true_labels_one_hot = tf.keras.utils.to_categorical(true_labels, num_classes)

# Top-1: argmax equals the label. Top-5: label is among the 5 highest scores.
top1_metric = tf.keras.metrics.CategoricalAccuracy()
top5_metric = tf.keras.metrics.TopKCategoricalAccuracy(k=5)

top1_metric.update_state(true_labels_one_hot, predictions)
top5_metric.update_state(true_labels_one_hot, predictions)

print("Top-1 Accuracy on validation set:", top1_metric.result().numpy())
print("Top-5 Accuracy on validation set:", top5_metric.result().numpy())

Top-1 Accuracy on validation set: 0.7789447
Top-5 Accuracy on validation set: 0.97131145


In [None]:
# Save the clustered model both as a weights-only file and as a SavedModel
# directory (the directory is reloaded in the next cell).
final_model.save_weights('quantized_model_6bits.weights.h5')
final_model.save("final_quantized_model_6bits", save_format="tf")





INFO:tensorflow:Assets written to: final_quantized_model_6bits\assets


INFO:tensorflow:Assets written to: final_quantized_model_6bits\assets


In [None]:
# Reload the clustered model. Its weights are the tensors that get
# Huffman-encoded.
loaded_final_model = tf.keras.models.load_model("final_quantized_model_6bits")
loaded_final_model.summary()





Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv1 (Conv2D)                 (None, 112, 112, 96  14208       ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 maxpool1 (MaxPooling2D)        (None, 55, 55, 96)   0           ['conv1[0][0]']                  
                                                                                              

We apply Huffman encoding to store the model weights more compactly.

In [None]:
class HuffmanNode:
    """One node of a Huffman tree.

    Leaves carry a ``symbol`` (a weight value). Merged nodes are created with
    ``symbol=None`` and get their two children attached afterwards. ``freq``
    is the occurrence count (for merged nodes, the sum of the children's).
    Nodes compare by ``freq`` only, which is all ``heapq`` needs to pop the
    least frequent subtrees first.
    """

    def __init__(self, symbol, freq):
        self.symbol, self.freq = symbol, freq
        self.left = self.right = None

    def __lt__(self, other):
        return other.freq > self.freq

def build_huffman_tree(freq_dict):
    """Build a Huffman tree from a ``{symbol: frequency}`` mapping.

    The two least frequent subtrees are repeatedly popped from a min-heap
    (ordered by ``HuffmanNode.freq``) and merged under a new node with
    ``symbol=None``, until a single tree remains.

    Returns:
        The root ``HuffmanNode``. With one symbol, the root is that single
        leaf. With an empty ``freq_dict``, returns ``None`` so callers can
        treat empty input as an empty code book. (Previously an empty
        mapping raised ``IndexError`` on ``heap[0]``.)
    """
    heap = []
    for symbol, freq in freq_dict.items():
        heapq.heappush(heap, HuffmanNode(symbol, freq))
    if not heap:
        return None
    while len(heap) > 1:
        node1 = heapq.heappop(heap)
        node2 = heapq.heappop(heap)
        merged = HuffmanNode(None, node1.freq + node2.freq)
        merged.left = node1
        merged.right = node2
        heapq.heappush(heap, merged)
    return heap[0]

def build_huffman_codes(root):
    """Map every leaf symbol of a Huffman tree to its '0'/'1' code word.

    Left edges contribute "0" and right edges "1". A node without children is
    a leaf. ``root`` may be ``None`` (no symbols), which yields ``{}``.

    A tree that is a single leaf (input with only one distinct value) gets the
    code "0". Previously it received the empty code "". Every value then
    encoded to zero bits, and decoding returned an empty array that could not
    be reshaped back to the original tensor.

    The walk is iterative, so deep (skewed) trees cannot hit Python's
    recursion limit.
    """
    if root is None:
        return {}
    if root.left is None and root.right is None:
        return {root.symbol: "0"}
    codes = {}
    stack = [(root, "")]
    while stack:
        node, code = stack.pop()
        if node.left is None and node.right is None:
            codes[node.symbol] = code
            continue
        # Push right first so the left subtree is visited first, matching a
        # recursive pre-order walk.
        if node.right is not None:
            stack.append((node.right, code + "1"))
        if node.left is not None:
            stack.append((node.left, code + "0"))
    return codes

def huffman_encode_array(arr):
    """Huffman-encode every element of ``arr``, flattened in row-major order.

    The code book is built from the value frequencies of this array alone.

    Returns:
        (bits, codes): ``bits`` is a string of '0'/'1' characters holding the
        concatenated code words. ``codes`` maps each distinct value to its
        code word and is needed to decode ``bits``.
    """
    values = arr.flatten()
    tree = build_huffman_tree(Counter(values))
    codes = build_huffman_codes(tree)
    bits = "".join([codes[value] for value in values])
    return bits, codes

def bits_to_bytes(bit_string):
    """Pack a '0'/'1' string into bytes, most significant bit first.

    The string is right-padded with '0' bits up to a multiple of 8.

    Returns:
        (packed, extra_bits): the packed ``bytes`` and the number of padding
        bits (0-7) that were appended and must be stripped when decoding.
    """
    extra_bits = -len(bit_string) % 8
    padded = bit_string + "0" * extra_bits
    packed = bytes(int(padded[start:start + 8], 2) for start in range(0, len(padded), 8))
    return packed, extra_bits


# Encode each weight tensor separately, in get_weights() order. Every tensor
# gets its own code book, stored next to its packed bits.
weights = loaded_final_model.get_weights()

encoded_weights = {}
for idx, w in enumerate(weights):
    encoded, codes = huffman_encode_array(w)
    encoded_bytes, pad = bits_to_bytes(encoded)
    encoded_weights[f"weight_{idx}"] = dict(
        encoded_bytes=encoded_bytes,
        pad=pad,
        codes=codes,
        original_shape=w.shape,
    )
    print(f"Weight {idx}: original shape {w.shape}, encoded length: {len(encoded_bytes)} bytes (after conversion)")


# Persist everything as a gzip-compressed pickle.
with gzip.open("huffman_compressed_weights.pkl.gz", "wb") as f:
    pickle.dump(encoded_weights, f)

print("Huffman compression complete and saved to 'huffman_compressed_weights.pkl.gz'.")


Weight 0: original shape (7, 7, 3, 96), encoded length: 10120 bytes (after conversion)
Weight 1: original shape (96,), encoded length: 80 bytes (after conversion)
Weight 2: original shape (1, 1, 96, 16), encoded length: 1137 bytes (after conversion)
Weight 3: original shape (16,), encoded length: 8 bytes (after conversion)
Weight 4: original shape (1, 1, 16, 64), encoded length: 761 bytes (after conversion)
Weight 5: original shape (64,), encoded length: 48 bytes (after conversion)
Weight 6: original shape (3, 3, 16, 64), encoded length: 3315 bytes (after conversion)
Weight 7: original shape (64,), encoded length: 48 bytes (after conversion)
Weight 8: original shape (1, 1, 128, 16), encoded length: 1513 bytes (after conversion)
Weight 9: original shape (16,), encoded length: 8 bytes (after conversion)
Weight 10: original shape (1, 1, 16, 64), encoded length: 757 bytes (after conversion)
Weight 11: original shape (64,), encoded length: 48 bytes (after conversion)
Weight 12: original sha

We load the model back to check that it was stored correctly

In [None]:

def bytes_to_bits(b):
    """Expand a bytes-like object into a '0'/'1' string, 8 bits per byte, MSB first.

    Padding bits added when packing are included and must be stripped by the
    caller.
    """
    return "".join(map("{:08b}".format, b))


def huffman_decode(encoded, codes, dtype=np.float32):
    """Decode a Huffman bit string back into a flat numpy array.

    Args:
        encoded: string of '0'/'1' characters, with byte padding already removed.
        codes: ``{symbol: code_word}`` mapping used by the encoder (prefix-free).
        dtype: dtype of the returned array. Defaults to float32, as before.

    Returns:
        1-D array of decoded symbols, in encoding order.

    Raises:
        ValueError: if ``encoded`` ends with bits that do not form a complete
            code word. That indicates a truncated stream or wrongly removed
            padding, and was previously ignored silently.
    """
    inv_codes = {code: symbol for symbol, code in codes.items()}
    decoded_vals = []
    current_code = ""
    for bit in encoded:
        current_code += bit
        if current_code in inv_codes:
            decoded_vals.append(inv_codes[current_code])
            current_code = ""
    if current_code:
        raise ValueError(
            f"{len(current_code)} trailing bit(s) do not form a complete code word; "
            "the bit stream is truncated or its padding was not removed correctly")
    return np.array(decoded_vals, dtype=dtype)

# NOTE(security): pickle.load can execute arbitrary code stored in the file.
# Only load archives written by this notebook or another trusted source.
with gzip.open("huffman_compressed_weights.pkl.gz", "rb") as f:
    encoded_weights = pickle.load(f)

# Rebuild the tensors in their original order: weight_0, weight_1, ...
decoded_weights = []
for idx in range(len(encoded_weights)):
    weight_info = encoded_weights[f"weight_{idx}"]
    encoded_bytes, pad, codes, original_shape = (
        weight_info[key] for key in ("encoded_bytes", "pad", "codes", "original_shape")
    )

    # Drop the zero bits that were appended to fill the last byte.
    bit_string = bytes_to_bits(encoded_bytes)
    bit_string = bit_string[:len(bit_string) - pad]

    flat_decoded = huffman_decode(bit_string, codes)
    weight_decoded = flat_decoded.reshape(original_shape)
    decoded_weights.append(weight_decoded)
    print(f"Decoded weight_{idx}: shape {weight_decoded.shape}")

# Fresh model with the same architecture. set_weights expects the tensors in
# the order get_weights() produced them when they were encoded.
restored_model = SqueezeNet(nb_classes=10, inputs=(224, 224, 3))
restored_model.set_weights(decoded_weights)
restored_model.summary()


Decoded weight_0: shape (7, 7, 3, 96)
Decoded weight_1: shape (96,)
Decoded weight_2: shape (1, 1, 96, 16)
Decoded weight_3: shape (16,)
Decoded weight_4: shape (1, 1, 16, 64)
Decoded weight_5: shape (64,)
Decoded weight_6: shape (3, 3, 16, 64)
Decoded weight_7: shape (64,)
Decoded weight_8: shape (1, 1, 128, 16)
Decoded weight_9: shape (16,)
Decoded weight_10: shape (1, 1, 16, 64)
Decoded weight_11: shape (64,)
Decoded weight_12: shape (3, 3, 16, 64)
Decoded weight_13: shape (64,)
Decoded weight_14: shape (1, 1, 128, 32)
Decoded weight_15: shape (32,)
Decoded weight_16: shape (1, 1, 32, 128)
Decoded weight_17: shape (128,)
Decoded weight_18: shape (3, 3, 32, 128)
Decoded weight_19: shape (128,)
Decoded weight_20: shape (1, 1, 256, 32)
Decoded weight_21: shape (32,)
Decoded weight_22: shape (1, 1, 32, 128)
Decoded weight_23: shape (128,)
Decoded weight_24: shape (3, 3, 32, 128)
Decoded weight_25: shape (128,)
Decoded weight_26: shape (1, 1, 256, 48)
Decoded weight_27: shape (48,)
Decod

In [None]:
# Top-1 / top-5 accuracy of the model restored from the Huffman-encoded
# weights, on the whole validation set. val_generator1 is unshuffled, so the
# rows of `predictions` follow the order of val_generator1.classes.
# len(val_generator1) counts the final partial batch (ceil division). The
# former `samples // batch_size` silently left the last
# `samples % batch_size` validation images out of the metrics.
val_steps = len(val_generator1)

predictions = restored_model.predict(val_generator1, steps=val_steps)
true_labels = val_generator1.classes[:len(predictions)]
assert len(true_labels) == len(predictions) == val_generator1.samples, \
    "predictions do not cover the whole validation set"

num_classes = len(val_generator1.class_indices)
true_labels_one_hot = tf.keras.utils.to_categorical(true_labels, num_classes)

# Top-1: argmax equals the label. Top-5: label is among the 5 highest scores.
top1_metric = tf.keras.metrics.CategoricalAccuracy()
top5_metric = tf.keras.metrics.TopKCategoricalAccuracy(k=5)

top1_metric.update_state(true_labels_one_hot, predictions)
top5_metric.update_state(true_labels_one_hot, predictions)

print("Top-1 Accuracy on validation set:", top1_metric.result().numpy())
print("Top-5 Accuracy on validation set:", top5_metric.result().numpy())

Top-1 Accuracy on validation set: 0.7789447
Top-5 Accuracy on validation set: 0.97131145
