In [1]:
import matplotlib.pyplot as plt
import numpy as np # linear algebra
import os
import pandas as pd # data processing
import tensorflow as tf
from tensorflow.keras import datasets, optimizers, metrics
import time
# import tensorflow.compat.v1 as tf
# tf.disable_v2_behavior()
# tf.compat.v1.disable_eager_execution()

BASE_PATH = "/kaggle/input"
print(tf.__version__)

2.1.0-rc0


In [2]:
# inputs shapes
img_rows, img_cols = (28, 28)
num_input = img_rows * img_cols
num_classes = 10
batch_size = 64
epochs = 10

In [3]:
(xtrain, ytrain), (xval, yval) = datasets.fashion_mnist.load_data()
print('data shapes:', xtrain.shape, ytrain.shape, xval.shape, yval.shape)

xtrain = tf.convert_to_tensor(xtrain, dtype=tf.float32)/255.
xval = tf.convert_to_tensor(xval, dtype=tf.float32)/255.

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
data shapes: (60000, 28, 28) (60000,) (10000, 28, 28) (10000,)


In [4]:
train_dataset = tf.data.Dataset.from_tensor_slices((xtrain, ytrain)).batch(batch_size)
val_dataset = tf.data.Dataset.from_tensor_slices((xval, yval)).batch(batch_size)

In [5]:
# network = Sequential([layers.Dense(1000, activation='relu'),
#                       layers.Dense(1000, activation='relu'),
#                       layers.Dense(500, activation='relu'),
#                       layers.Dense(200, activation='relu'),
#                       layers.Dense(10)])

# network.build(input_shape=(None, num_input))
# # network.summary()

In [6]:
class FmnistModel(object):
    
    def __init__(self, output_shapes, param_initializer):
        
        self.output_shapes = output_shapes
        self.initializer = param_initializer
        self.trainable_params = []
        
        # intialize & store the weights for the model
        for i in range(len(self.output_shapes)):
            weight = self.get_weight(self.output_shapes[i], 
                                     name='weight_{}'.format(i))
            bias= self.get_bias(self.output_shapes[i][-1],
                                name='bias_{}'.format(i))
            
            self.trainable_params.append(weight)
            self.trainable_params.append(bias)

    def __call__(self, x):
        """
        input transformations
        """
        x = self.dense(x, self.trainable_params[0], self.trainable_params[1])
        x = self.dense(x, self.trainable_params[2], self.trainable_params[3])
        x = self.dense(x, self.trainable_params[4], self.trainable_params[5])
        x = self.dense(x, self.trainable_params[6], self.trainable_params[7])
        logits = tf.add(tf.matmul(x, self.trainable_params[8]), self.trainable_params[9])
        return logits
        
    def dense(self, x, W, b):
        
        """
        A function with operations of simple dense layer
        """
        
        # intialize
        x_is_sparse, W_is_sparse = False, False
#         sparse_limit = tf.constant(0.3)
#         # check for sparsity
#         if tf.greater(tf.nn.zero_fraction(x), sparse_limit):
#             x_is_sparse = True
#         if tf.greater(tf.nn.zero_fraction(W), sparse_limit):
#             W_is_sparse = True
        # matmul x, W
        xW = tf.matmul(x, W, a_is_sparse = x_is_sparse, 
                       b_is_sparse = W_is_sparse)
        return tf.nn.relu(tf.add(xW, b))
    
    def get_weight(self, shape , name):
        """
        to intialize the weights given shape
        """
        return tf.Variable(self.initializer(shape) , name=name)
    
    def get_bias(self, units, name):
        """
        to intialize the bias with given no.of units
        """
        return tf.Variable(self.initializer([units]), name=name)

In [7]:
# define weight intializer
initializer = tf.initializers.glorot_uniform()
# define no.of classes
num_classes = 10
# output shapes
shapes = [
    [ 28*28*1 , 1000 ] , 
    [ 1000 , 1000 ] ,
    [ 1000 , 500 ] , 
    [ 500 , 200 ] ,
    [ 200 , num_classes] ,
]

# initialize the model with output_shapes & param_intializer
network = FmnistModel(output_shapes = shapes, param_initializer = initializer)
len(network.trainable_params)

10

In [8]:
def loss_fn(y_pred, y_true):
    """
    loss function
    """
    return tf.reduce_sum(tf.square(y_pred-y_true))

# intialize optimizer with lr = 0.01
optimizer = optimizers.SGD(lr=0.01)

In [9]:
acc_meter = metrics.Accuracy()
val_acc = metrics.Accuracy()
epochs = 10

# iter over epochs
for e in range(epochs):
    epoch_start = time.time()
    # iter over train data
    for step, (xt, yt) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            # [bs, 28, 28] => [bs, 784]
            xt = tf.reshape(xt, (-1, 28*28))
            # [bs, 784] => [bs, 10]
            y_pred = network(xt)
            # [bs] => [bs, 10]
            y_true = tf.one_hot(yt, depth=10)
            # compute loss
            loss = loss_fn(y_pred, y_true)/batch_size
        
        # calc train accuracy
        acc_meter.update_state(tf.argmax(y_pred, axis=1), yt)
        # compute grads & apply them
        grads = tape.gradient(loss, network.trainable_params)
        optimizer.apply_gradients(zip(grads, network.trainable_params))

    # iter over val data
    for xv, yv in val_dataset:
        xv = tf.reshape(xv, (-1, 28*28))
        y_pred_val = network(xv)
        val_loss = loss_fn(y_pred_val, tf.one_hot(yv, depth=10))/batch_size
        val_acc.update_state(tf.argmax(y_pred_val, axis=1), yv)
        
    epoch_end = time.time()

    print('Epoch: %d' %e, ' Loss: %.5f' %float(loss), 
          ' Acc: %.3f' %acc_meter.result().numpy(), 
          ' Val_Loss: %.5f' %float(val_loss), 
          ' Val_Acc: %.3f' %val_acc.result().numpy(),
          ' Time(sec): %.2f' %(epoch_end-epoch_start))
    
    # reset states of acc meters
    acc_meter.reset_states()
    val_acc.reset_states()

Epoch: 0  Loss: 0.15970  Acc: 0.796  Val_Loss: 0.05350  Val_Acc: 0.826  Time(sec): 21.99
Epoch: 1  Loss: 0.13152  Acc: 0.847  Val_Loss: 0.04982  Val_Acc: 0.842  Time(sec): 21.58
Epoch: 2  Loss: 0.11621  Acc: 0.860  Val_Loss: 0.04669  Val_Acc: 0.849  Time(sec): 21.62
Epoch: 3  Loss: 0.10676  Acc: 0.868  Val_Loss: 0.04500  Val_Acc: 0.856  Time(sec): 21.79
Epoch: 4  Loss: 0.10064  Acc: 0.874  Val_Loss: 0.04330  Val_Acc: 0.861  Time(sec): 21.73
Epoch: 5  Loss: 0.09489  Acc: 0.879  Val_Loss: 0.04182  Val_Acc: 0.863  Time(sec): 21.69
Epoch: 6  Loss: 0.09003  Acc: 0.883  Val_Loss: 0.04006  Val_Acc: 0.866  Time(sec): 21.41
Epoch: 7  Loss: 0.08711  Acc: 0.887  Val_Loss: 0.03845  Val_Acc: 0.868  Time(sec): 21.28
Epoch: 8  Loss: 0.08390  Acc: 0.890  Val_Loss: 0.03677  Val_Acc: 0.870  Time(sec): 21.25
Epoch: 9  Loss: 0.08149  Acc: 0.892  Val_Loss: 0.03511  Val_Acc: 0.871  Time(sec): 21.11


In [10]:
def weight_pruning(w, s):
    
    """Performs pruning on a weight matrix w:

    1. Compute absolute value of all elements.
    2. The indices of the top k% elements according to their absolute values are selected.
    3. A new tensor is formed with indices of topK% elements set to 1.
    4. The new tensor will be used as mask & multiplied with the original weights

    Args:(w: tf.Variable, s: float)
    ------
        w: The weight matrix.
        k: The percentage of values (units) that should be pruned from the matrix.

    Returns: tf.Variable
    -------
        The pruned weight matrix.

    """
    # store the original w shape
    w_shape = tf.shape(w)
    # calc % of weights to retain (notice multiplication with 1-s) & type cast to int32
    s = tf.cast(tf.size(w, out_type=tf.float32)*tf.constant(1-s), dtype=tf.int32)
    # flatten w
    w_reshaped = tf.reshape(w, [-1])
    # get indices to keep only top s% weights
    _, indices = tf.nn.top_k(tf.abs(w_reshaped), s, sorted=True)
    # make a mask with top indices values = 1
    mask = tf.scatter_nd(tf.reshape(indices, [-1, 1]),
                         tf.ones([s], tf.float32), tf.shape(w_reshaped),
                         name = 'pruning_mask')
    # multiply, reshape, assign & return the weight
    return w.assign(tf.reshape(w_reshaped * mask, w_shape))

In [11]:
# intialize the metris & constants
acc_meter = metrics.Accuracy()
val_acc = metrics.Accuracy()
epochs = 10
total_params = len(network.trainable_params)

# iter over epochs
for e in range(epochs):
    epoch_start = time.time()
    # iter over train data
    for step, (xt, yt) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            # [bs, 28, 28] => [bs, 784]
            xt = tf.reshape(xt, (-1, 28*28))
            # [bs, 784] => [bs, 10]
            y_pred = network(xt)
            # [bs] => [bs, 10]
            y_true = tf.one_hot(yt, depth=10)
            # compute loss
            loss = loss_fn(y_pred, y_true)/batch_size

        # calc train accuracy
        acc_meter.update_state(tf.argmax(y_pred, axis=1), yt)
        # compute grads & apply them
        grads = tape.gradient(loss, network.trainable_params)
        optimizer.apply_gradients(zip(grads, network.trainable_params))
        
    # pruning weights after end of the epoch
    for i in range(0, total_params-2, 2):
        network.trainable_params[i] = weight_pruning(network.trainable_params[i], s = 0.4)

    # iter over val data
    for xv, yv in val_dataset:
        xv = tf.reshape(xv, (-1, 28*28))
        y_pred_val = network(xv)
        val_loss = loss_fn(y_pred_val, tf.one_hot(yv, depth=10))/batch_size
        val_acc.update_state(tf.argmax(y_pred_val, axis=1), yv)
        
    epoch_end = time.time()
    
    # print the epoch results
    print('Epoch: %d' %e, ' Loss: %.5f' %float(loss), 
          ' Acc: %.3f' %acc_meter.result().numpy(), 
          ' Val_Loss: %.5f' %float(val_loss), 
          ' Val_Acc: %.3f' %val_acc.result().numpy(),
          ' Time(sec): %.2f' %(epoch_end-epoch_start))
    
    # reset states of acc meters
    acc_meter.reset_states()
    val_acc.reset_states()

Epoch: 0  Loss: 0.07948  Acc: 0.895  Val_Loss: 0.04976  Val_Acc: 0.858  Time(sec): 21.94
Epoch: 1  Loss: 0.09410  Acc: 0.889  Val_Loss: 0.03721  Val_Acc: 0.869  Time(sec): 16.31
Epoch: 2  Loss: 0.09023  Acc: 0.891  Val_Loss: 0.03664  Val_Acc: 0.871  Time(sec): 16.53
Epoch: 3  Loss: 0.08793  Acc: 0.892  Val_Loss: 0.03661  Val_Acc: 0.872  Time(sec): 16.76
Epoch: 4  Loss: 0.08628  Acc: 0.892  Val_Loss: 0.03660  Val_Acc: 0.872  Time(sec): 16.40
Epoch: 5  Loss: 0.08489  Acc: 0.893  Val_Loss: 0.03653  Val_Acc: 0.872  Time(sec): 16.39
Epoch: 6  Loss: 0.08376  Acc: 0.893  Val_Loss: 0.03641  Val_Acc: 0.872  Time(sec): 16.38
Epoch: 7  Loss: 0.08271  Acc: 0.893  Val_Loss: 0.03630  Val_Acc: 0.872  Time(sec): 16.48
Epoch: 8  Loss: 0.08177  Acc: 0.893  Val_Loss: 0.03620  Val_Acc: 0.873  Time(sec): 16.51
Epoch: 9  Loss: 0.08091  Acc: 0.894  Val_Loss: 0.03610  Val_Acc: 0.873  Time(sec): 16.47


- As network & training time is already small in our case, Sparse matmul is taking the overhead for conversion & the time taken for epoch is actually more than incase of dense - So, guess not an option in this case.
- For stepwise pruning, again time gain is outweight by purning operation which will occur starting from 1000th step of every epoch.
- There is about 20-25% time gain (after 1st epoch), when using dense matmul & weights are pruned (s = 0.4) after each epoch.