In [1]:
# untar
!tar -xvzf dataset.tar.gz

train_images.pkl
train_labels.pkl
val_images.pkl
val_labels.pkl


In [2]:
import os
import numpy as np
import pickle

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, models, regularizers
from tensorflow.keras.layers import *
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from copy import deepcopy

from matplotlib import pyplot as plt

%matplotlib inline

print(tf.version.VERSION)

2.12.0


In [None]:
root_dir = "../pruned_models_and_notebooks_global_thresholding/"

In [3]:
# load train
train_images = pickle.load(open("../train_images.pkl", "rb"))
train_labels = pickle.load(open("../train_labels.pkl", "rb"))

# load val
val_images = pickle.load(open("../val_images.pkl", "rb"))
val_labels = pickle.load(open("../val_labels.pkl", "rb"))


datagen = ImageDataGenerator(
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
)

In [4]:
class CustomModel(keras.models.Sequential):
    def __init__(self, weights, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.custom_masks = [tf.cast(weight != 0, tf.float32) for weight in weights]

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        trainable_vars = self.trainable_variables
        
        gradients = tape.gradient(loss, trainable_vars)
        gradients = [tf.multiply(grad, self.custom_masks[i]) for i, grad in enumerate(gradients)]
        
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.compiled_metrics.update_state(y, y_pred)
        
        return {m.name: m.result() for m in self.metrics}
    
def get_custom_model(weights):
    model = CustomModel(deepcopy(weights))
    model.add(Conv2D(32, (3, 3), padding="same", kernel_regularizer=regularizers.l2(1e-5), input_shape=(25,25,3)))
    model.add(Activation("relu"))
    model.add(Conv2D(32, (3, 3), kernel_regularizer=regularizers.l2(1e-5)))
    model.add(Activation("relu"))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    model.add(Conv2D(64, (3, 3), padding="same", kernel_regularizer=regularizers.l2(1e-5)))
    model.add(Activation("relu"))
    model.add(Conv2D(64, (3, 3), kernel_regularizer=regularizers.l2(1e-5)))
    model.add(Activation("relu"))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(512))
    model.add(Activation("relu"))
    model.add(Dropout(0.5))
    model.add(Dense(5))
    model.add(Activation("softmax"))
    return model


def get_original_model():
    model = models.Sequential()
    model.add(Conv2D(32, (3, 3), padding="same", kernel_regularizer=regularizers.l2(1e-5), input_shape=(25,25,3)))
    model.add(Activation("relu"))
    model.add(Conv2D(32, (3, 3), kernel_regularizer=regularizers.l2(1e-5)))
    model.add(Activation("relu"))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    model.add(Conv2D(64, (3, 3), padding="same", kernel_regularizer=regularizers.l2(1e-5)))
    model.add(Activation("relu"))
    model.add(Conv2D(64, (3, 3), kernel_regularizer=regularizers.l2(1e-5)))
    model.add(Activation("relu"))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(512))
    model.add(Activation("relu"))
    model.add(Dropout(0.5))
    model.add(Dense(5))
    model.add(Activation("softmax"))
    return model

In [5]:
batch_size = 128

train_dataset = datagen.flow(train_images, train_labels, batch_size=batch_size)

In [6]:
def measure_sparsity(weights):
    num_zeros = 0
    num_nonzeros = 0
    for weight in weights:
        z = tf.math.count_nonzero(tf.equal(weight, 0)).numpy()
        nz = tf.size(weight).numpy() - z
        num_zeros += z
        num_nonzeros += nz

    return num_zeros / (num_zeros + num_nonzeros)

In [7]:
def categorical_loss_with_label_smoothing(y, yhat):
    y = tf.one_hot(tf.cast(y, tf.int32), 5)
    yhat = tf.expand_dims(yhat, axis=1)
    return tf.keras.losses.categorical_crossentropy(y, yhat, label_smoothing=0.1)

In [8]:
def calc_std(weights):
    tmp = []
    for w in weights:
        tmp.append(w.numpy().flatten())
    oned_stackedweights = np.hstack(tmp)
    return np.std(oned_stackedweights)


def prune(weights, stdval, factor=0.1):
    pruned_weights = deepcopy(weights)
    threshold = stdval * factor
    for i, w in enumerate(pruned_weights):
        mask = tf.cast(tf.greater(
            tf.abs(w), threshold), tf.float32)
        pruned_weights[i] = (tf.multiply(w, mask))
    return pruned_weights

In [10]:
max_iterations = 100

model = get_original_model()
model_path = "unstructured_pruning_v1_factor_1.6973684210526316_sparsity_0.9168725640165077_val_acc_0.21465346217155457.h5"
model_path = os.path.join(root_dir, model_path)
model.load_weights(model_path)

custom_model = get_custom_model(model.trainable_weights)
custom_model.load_weights(model_path)

es_callback = tf.keras.callbacks.EarlyStopping(
    patience=10, 
    monitor="val_accuracy"
)

prev_acc = 0
es_stop = 10

val_accuracies = []
sparsities = []

for i in range(max_iterations):
    if i > 0:
        curr_weights_std = calc_std(custom_model.trainable_weights)
        pruned_weights = prune(custom_model.trainable_weights, curr_weights_std, factor=3)
        custom_model = get_custom_model(pruned_weights)
        custom_model.set_weights(pruned_weights)

    checkpoint_path = f"magnitude_pruning_itr_{i}.h5"

    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_path,
        save_weights_only=True,
        verbose=1,
        monitor="val_accuracy"
    )

    custom_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.0001, weight_decay=1e-6),
        loss=categorical_loss_with_label_smoothing,
        metrics=["accuracy"]
    )

    custom_model.fit(
        x=train_images, 
        y=train_labels, 
        epochs=50, 
        batch_size=128, 
        validation_data=(val_images, val_labels), 
        callbacks=[cp_callback, es_callback]
    )

    post_results = custom_model.evaluate(val_images, val_labels)
    model_sparsity = measure_sparsity(custom_model.trainable_weights)
    
    sparsities.append(model_sparsity)
    val_accuracies.append(post_results[1])

    print(f"Post retraining val loss: {post_results[0]} | val acc: {post_results[1]} | sparsity: {model_sparsity}")
    
    if post_results[1] < prev_acc:
        es_stop -= 1
    
    if es_stop == 0:
        break

Epoch 1/50
Epoch 1: saving model to magnitude_pruning_itr_0.h5
Epoch 2/50
Epoch 2: saving model to magnitude_pruning_itr_0.h5
Epoch 3/50
Epoch 3: saving model to magnitude_pruning_itr_0.h5
Epoch 4/50
Epoch 4: saving model to magnitude_pruning_itr_0.h5
Epoch 5/50
Epoch 5: saving model to magnitude_pruning_itr_0.h5
Epoch 6/50
Epoch 6: saving model to magnitude_pruning_itr_0.h5
Epoch 7/50
Epoch 7: saving model to magnitude_pruning_itr_0.h5
Epoch 8/50
Epoch 8: saving model to magnitude_pruning_itr_0.h5
Epoch 9/50
Epoch 9: saving model to magnitude_pruning_itr_0.h5
Epoch 10/50
Epoch 10: saving model to magnitude_pruning_itr_0.h5
Epoch 11/50
Epoch 11: saving model to magnitude_pruning_itr_0.h5
Epoch 12/50
Epoch 12: saving model to magnitude_pruning_itr_0.h5
Epoch 13/50
Epoch 13: saving model to magnitude_pruning_itr_0.h5
Epoch 14/50
Epoch 14: saving model to magnitude_pruning_itr_0.h5
Epoch 15/50
Epoch 15: saving model to magnitude_pruning_itr_0.h5
Epoch 16/50
Epoch 16: saving model to magni

KeyboardInterrupt: ignored

In [18]:
model = get_original_model()
model_path = "magnitude_pruning_itr_6.h5"
model.load_weights(model_path)

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.0001, weight_decay=1e-6),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=["accuracy"]
)

final_results = model.evaluate(val_images, val_labels)
final_model_sparsity = measure_sparsity(model.trainable_weights)

print(f"Post retraining val loss: {final_results[0]} | val acc: {final_results[1]} | sparsity: {final_model_sparsity}")

Post retraining val loss: 0.7592514753341675 | val acc: 0.7346534729003906 | sparsity: 0.9666168015610532
