# Optimizing Masks to create WTs

In [1]:
# importing necessary libraries and the cnn architecture I defined

from cnn_architecture import CNN2Model
from utils import *
from load_datasets import load_and_prep_dataset

import tensorflow_datasets as tfds
import tensorflow_probability as tfp
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy.io import loadmat
import copy

# all the extra stuff for supermasks

all the variables i have to check their meaning:
- use bias
- dynamik scaling
- sigmoid bias
- use learning phase


In [2]:
class MaskedDense(tf.keras.layers.Dense):
    
    # untrainable normal Dense layer
    # trainable mask, that is sigmoided (maybe squished) and then multiplied to Dense
    
    def __init__(self, units,*args, **kwargs):
        super(MaskedDense, self).__init__(units, *args, **kwargs)        

        
    def build(self, input_shape):
        super(MaskedDense, self).build(input_shape)
        mask_init = tf.random.uniform(shape=self.kernel.shape,minval=-1, maxval=1, seed=None)

        self._trainable_weights.remove(self.kernel)
        self._non_trainable_weights.append(self.kernel)
        self._trainable_weights.remove(self.bias)
        self._non_trainable_weights.append(self.bias)
        
        self.kernel_mask = tf.Variable(initial_value=mask_init,
                                        trainable=True,
                                        validate_shape=True,
                                        caching_device=None,
                                        name='mask',
                                        dtype=self.dtype,
                                        shape=self.kernel.shape)

        #self.kernel_mask = tf.get_variable('mask',
        #                                   shape=self.kernel.shape,
        #                                   dtype=self.dtype,
        #                                   initializer=mask_init,
        #                                   trainable=True)
        self._trainable_weights.append(self.kernel_mask)
    
    
    def get_effective_mask(self):
        # during train, clamp all of them to 0's and 1's sampled by bernoulli and backprop the probabilities
        # during test, clamp all of them to their rounded values
        # actually, sample them too
        which_to_clamp = tf.ones(self.kernel_mask.shape)
        binary_mask = lambda: tf.cast(tfp.distributions.Bernoulli(probs=tf.nn.sigmoid(self.kernel_mask)).sample(), dtype=tf.float32)

        return which_to_clamp * binary_mask + (1 - which_to_clamp) * tf.nn.sigmoid(self.kernel_mask)
    
    
    
    # same as original call() except round some sample to {0, 1} based on a sample
    def call(self, inputs):

        effective_mask = tf.cast(tfp.distributions.Bernoulli(probs=tf.nn.sigmoid(self.kernel_mask)).sample(), dtype=tf.float32)
        effective_kernel = tf.math.multiply(self.kernel, effective_mask)

        #if self.dynamic_scaling:
            #self.ones_in_mask = tf.reduce_sum(effective_mask)
            #self.multiplier = tf.div(tf.to_float(tf.size(effective_mask)), self.ones_in_mask)
            #effective_kernel = self.multiplier * effective_kernel

        # original code from https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/keras/layers/core.py:
        inputs = tf.convert_to_tensor(inputs)
        #print("inputs: ", inputs.shape)
        #print("kernel: ", effective_kernel.shape)
        outputs = tf.linalg.matmul(inputs, effective_kernel)
        outputs = tf.nn.bias_add(outputs, self.bias)
        if self.activation is not None:
            return self.activation(outputs)
        return outputs
    
    def get_mask(self):
        return  tf.nn.sigmoid(self.kernel_mask)
    
    def get_binary_mask(self):
        return tf.math.round(tf.nn.sigmoid(self.kernel_mask))

In [3]:
class CNN2ModelMasked(tf.keras.Model):
    
    # basic
    def __init__(self):
        super(CNN2ModelMasked, self).__init__()
        
        # set biases to a value that is not exactly 0.0, so they don't get handled like pruned values
        self.bias_in = tf.keras.initializers.Constant(value=0.0000000001)
        
        self.conv1 = tf.keras.layers.Conv2D(filters=64, kernel_size=3,activation="relu", padding="same",kernel_initializer='glorot_uniform', bias_initializer=self.bias_in) # [batchsize,32,32,64]
        self.conv2 = tf.keras.layers.Conv2D(filters=64, kernel_size=3,activation="relu", padding="same",kernel_initializer='glorot_uniform', bias_initializer=self.bias_in) # [batchsize,32,32,64]
        self.maxpool = tf.keras.layers.MaxPooling2D(pool_size=(2, 2),strides=(2, 2),input_shape=(32, 32, 64)) # [batchsize,16,16,64]
        self.flatten = tf.keras.layers.Flatten() # [batch_size,16384]
        self.dense1 = MaskedDense(256, activation="relu",kernel_initializer='glorot_uniform', bias_initializer=self.bias_in) # [batch_size,256]
        self.dense2 = MaskedDense(256, activation="relu",kernel_initializer='glorot_uniform', bias_initializer=self.bias_in) # [batch_size,256]
        self.dense3 = MaskedDense(10, activation="softmax",kernel_initializer='glorot_uniform', bias_initializer=self.bias_in) # [batch_size,256]
        
        # Making the weights of the conv layers untrainable
        self.conv1.trainable = False
        self.conv2.trainable = False
    
    @tf.function
    def call(self, inputs):
        
        # adjust the dense layers to be multiplayed with trainable mask (which gets assigned binary values for this step)
        x = self.conv1(inputs)
        x = self.conv2(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dense2(x)
        x = self.dense3(x)
        return x
    
    def get_masks(self):
        return [self.dense1.get_mask(), self.dense2.get_mask(), self.dense3.get_mask()]
        
    def get_binary_masks(self):
        return [self.dense1.get_binary_mask(), self.dense2.get_binary_mask(), self.dense3.get_binary_mask()]

In [4]:
# modified train loop to also work with sparse networks (such that pruned weights remain frozen at 0.0)

def train_mask(train, test, model, num_epochs=5):
    
    # hyperparameters
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002)
    loss_function= tf.keras.losses.CategoricalCrossentropy()
    
    # initializing training statistics
    train_accuracy = tf.keras.metrics.Accuracy(name='test_accuracy')
    test_accuracy = tf.keras.metrics.Accuracy(name='train_accuracy')
    train_losses = tf.keras.metrics.CategoricalCrossentropy(name='train_losses')
    test_losses = tf.keras.metrics.CategoricalCrossentropy(name='test_losses')
    train_acc = []
    test_acc = []
    train_l =[]
    test_l = []

    for epoch in tqdm(range(num_epochs), leave=False, desc="training epochs"):
        
        #train step
        for x, t in train:
            with tf.GradientTape() as tape:
                pred = model(x)
                loss = loss_function(t, pred)
                train_losses.update_state(t, pred)
                train_accuracy.update_state(tf.argmax(t,1), tf.argmax(pred,1))
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))
            
        # test step
        for x, t in test:
            pred = model(x)
            test_accuracy.update_state(tf.argmax(t,1), tf.argmax(pred,1))
            test_losses.update_state(t, pred)
        
        # updataing training statistics
        train_acc.append(train_accuracy.result().numpy())
        test_acc.append(test_accuracy.result().numpy())
        train_l.append(train_losses.result().numpy())
        test_l.append(test_losses.result().numpy())
        train_accuracy.reset_state()
        test_accuracy.reset_state()
        train_losses.reset_state()
        test_losses.reset_state()
        
    # collecting losses in a dictionary
    losses = { "test loss":test_l , "training loss":train_l , "test accuracy":test_acc , "training accuracy":train_acc}
    
    return  losses

In [8]:
train_dataset, test_dataset = load_and_prep_dataset("CIFAR", batch_size=60, shuffle_size=512)

model = CNN2ModelMasked()
model(list(train_dataset)[0][0])
initial_weights = model.get_weights()
initial_mask = model.get_masks()
initial_b_mask = model.get_binary_masks()
print(initial_mask)
print(initial_b_mask)
print("pruning_rates: ", get_pruning_rates(initial_b_mask))
print(model.trainable_variables)
model.summary()

losses = train_mask(train_dataset, test_dataset, model)
plot_losses("CIFAR", "TestSuperMaskOtimization", losses,"CNN Loss and Accuracy for supermask model")

[<tf.Tensor: shape=(16384, 256), dtype=float32, numpy=
array([[0.4581738 , 0.5799989 , 0.4621882 , ..., 0.52411056, 0.36852047,
        0.67673594],
       [0.43519253, 0.6635228 , 0.58269286, ..., 0.40841594, 0.4972405 ,
        0.6766733 ],
       [0.3719052 , 0.4826121 , 0.626531  , ..., 0.706634  , 0.68747675,
        0.5611489 ],
       ...,
       [0.65025574, 0.4183221 , 0.43523479, ..., 0.38998443, 0.60151505,
        0.4569376 ],
       [0.6262519 , 0.44637433, 0.27587998, ..., 0.30124587, 0.5658544 ,
        0.49375638],
       [0.57453865, 0.5532706 , 0.4838324 , ..., 0.3154695 , 0.64649326,
        0.515374  ]], dtype=float32)>, <tf.Tensor: shape=(256, 256), dtype=float32, numpy=
array([[0.72930586, 0.62081826, 0.5630467 , ..., 0.5992753 , 0.7032884 ,
        0.42649275],
       [0.6091988 , 0.68663853, 0.546972  , ..., 0.33514825, 0.66069543,
        0.3636892 ],
       [0.71843505, 0.39507845, 0.47479647, ..., 0.67338884, 0.30302197,
        0.6124765 ],
       ...,
     

                                                                                                                       

ValueError: No gradients provided for any variable: (['masked_dense_9/mask:0', 'masked_dense_10/mask:0', 'masked_dense_11/mask:0'],). Provided `grads_and_vars` is ((None, <tf.Variable 'masked_dense_9/mask:0' shape=(16384, 256) dtype=float32, numpy=
array([[-0.16769671,  0.32276893, -0.15153646, ...,  0.09651709,
        -0.5385692 ,  0.7388115 ],
       [-0.2606964 ,  0.6790328 ,  0.33383775, ..., -0.37051773,
        -0.01103806,  0.7385254 ],
       [-0.52405214, -0.0695796 ,  0.5173633 , ...,  0.879092  ,
         0.78834915,  0.24582624],
       ...,
       [ 0.6201637 , -0.32966518, -0.2605245 , ..., -0.44737768,
         0.4117818 , -0.17267728],
       [ 0.5161705 , -0.21533084, -0.96499133, ..., -0.84137225,
         0.2649567 , -0.02497578],
       [ 0.3003931 ,  0.21389413, -0.06469297, ..., -0.7746711 ,
         0.6036601 ,  0.06151533]], dtype=float32)>), (None, <tf.Variable 'masked_dense_10/mask:0' shape=(256, 256) dtype=float32, numpy=
array([[ 0.99110365,  0.49302268,  0.25353622, ...,  0.4024465 ,
         0.86300635, -0.29617524],
       [ 0.44394565,  0.78445053,  0.18844366, ..., -0.6849911 ,
         0.66639495, -0.55938745],
       [ 0.936712  , -0.42601442, -0.1008997 , ...,  0.7235527 ,
        -0.8329487 ,  0.4577341 ],
       ...,
       [ 0.94095564,  0.35209584, -0.16997004, ..., -0.23656487,
         0.17381835, -0.46345973],
       [-0.96217513,  0.57883596,  0.6726358 , ..., -0.67818403,
        -0.87713027,  0.35162497],
       [-0.48331118,  0.6636536 ,  0.6246724 , ...,  0.76682377,
         0.33307028, -0.6073289 ]], dtype=float32)>), (None, <tf.Variable 'masked_dense_11/mask:0' shape=(256, 10) dtype=float32, numpy=
array([[ 0.6802385 , -0.31566882, -0.62436986, ...,  0.46554422,
        -0.72229314,  0.69384885],
       [ 0.5450978 , -0.29220414,  0.10997224, ..., -0.57311463,
        -0.08744979,  0.0216403 ],
       [ 0.44412112, -0.82029605, -0.60668993, ...,  0.42050028,
         0.01555204,  0.16507578],
       ...,
       [-0.3382361 , -0.506763  ,  0.7402952 , ...,  0.31500053,
         0.09572721, -0.53084564],
       [-0.05083132,  0.8928139 , -0.1665554 , ..., -0.21433973,
        -0.993649  , -0.3003471 ],
       [ 0.89363384,  0.34592652,  0.5137596 , ..., -0.6141715 ,
         0.09741783, -0.03924131]], dtype=float32)>)).

debugging to do:
- check paper for optimizer
- make the call function simpler
- research other examples of unusual trainable parameters in models