# Optimizing Masks to create WTs

In [1]:
# importing necessary libraries and the cnn architecture I defined

from cnn_architecture import CNN2Model
from utils import *
from load_datasets import load_and_prep_dataset

import tensorflow_datasets as tfds
import tensorflow_probability as tfp
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy.io import loadmat
import copy

# all the extra stuff for supermasks

all the variables i have to check their meaning:
- use bias
- dynamik scaling
- sigmoid bias
- use learning phase


In [2]:
class MaskedDense(tf.keras.layers.Dense):
    
    # untrainable normal Dense layer
    # trainable mask, that is sigmoided (maybe squished) and then multiplied to Dense
    
    def __init__(self, units,*args, **kwargs):
        super(MaskedDense, self).__init__(units, *args, **kwargs)        

        
    def build(self, input_shape):
        super(MaskedDense, self).build(input_shape)
        
        # make bias and weights untrainable
        self._trainable_weights.remove(self.kernel)
        self._non_trainable_weights.append(self.kernel)
        self._trainable_weights.remove(self.bias)
        self._non_trainable_weights.append(self.bias)
        
        # create mask and make it trainable
        mask_init = tf.random.uniform(shape=self.kernel.shape,minval=-1, maxval=1, seed=None)
        self.kernel_mask = tf.Variable(initial_value=mask_init,
                                        trainable=True,
                                        validate_shape=True,
                                        name='mask',
                                        dtype=self.dtype,
                                        shape=self.kernel.shape)
        self._trainable_weights.append(self.kernel_mask)
    
    
    def call(self, inputs):

        effective_mask = tf.cast(tfp.distributions.Bernoulli(probs=tf.nn.sigmoid(self.kernel_mask)).sample(), dtype=tf.float32)
        effective_kernel = tf.math.multiply(self.kernel, effective_mask)

        inputs = tf.convert_to_tensor(inputs)
        outputs = tf.linalg.matmul(inputs, effective_kernel)
        outputs = tf.nn.bias_add(outputs, self.bias)
        output =  self.activation(outputs)
        #output = self.activation(tf.matmul(inputs, tf.math.multiply(self.kernel, tf.math.round(tf.nn.sigmoid(self.kernel_mask)))))
        return output
    
    def get_mask(self):
        return  tf.nn.sigmoid(self.kernel_mask)
    
    def get_binary_mask(self):
        return tf.math.round(tf.nn.sigmoid(self.kernel_mask))

In [3]:
class CNN2ModelMasked(tf.keras.Model):
    
    # basic
    def __init__(self):
        super(CNN2ModelMasked, self).__init__()
        
        # set biases to a value that is not exactly 0.0, so they don't get handled like pruned values
        self.bias_in = tf.keras.initializers.Constant(value=0.0000000001)
        
        self.conv1 = tf.keras.layers.Conv2D(filters=64, kernel_size=3,activation="relu", padding="same",kernel_initializer='glorot_uniform', bias_initializer=self.bias_in) # [batchsize,32,32,64]
        self.conv2 = tf.keras.layers.Conv2D(filters=64, kernel_size=3,activation="relu", padding="same",kernel_initializer='glorot_uniform', bias_initializer=self.bias_in) # [batchsize,32,32,64]
        self.maxpool = tf.keras.layers.MaxPooling2D(pool_size=(2, 2),strides=(2, 2),input_shape=(32, 32, 64)) # [batchsize,16,16,64]
        self.flatten = tf.keras.layers.Flatten() # [batch_size,16384]
        self.dense1 = MaskedDense(256, activation="relu",kernel_initializer='glorot_uniform', bias_initializer=self.bias_in) # [batch_size,256]
        self.dense2 = MaskedDense(256, activation="relu",kernel_initializer='glorot_uniform', bias_initializer=self.bias_in) # [batch_size,256]
        self.dense3 = MaskedDense(10, activation="softmax",kernel_initializer='glorot_uniform', bias_initializer=self.bias_in) # [batch_size,256]
        
        # Making the weights of the conv layers untrainable
        self.conv1.trainable = False
        self.conv2.trainable = False
    
    @tf.function
    def call(self, inputs):
        
        # adjust the dense layers to be multiplayed with trainable mask (which gets assigned binary values for this step)
        x = self.conv1(inputs)
        x = self.conv2(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dense2(x)
        x = self.dense3(x)
        return x
    
    def get_masks(self):
        return [self.dense1.get_mask(), self.dense2.get_mask(), self.dense3.get_mask()]
        
    def get_binary_masks(self):
        return [self.dense1.get_binary_mask(), self.dense2.get_binary_mask(), self.dense3.get_binary_mask()]

In [4]:
# modified train loop to also work with sparse networks (such that pruned weights remain frozen at 0.0)

def train_mask(train, test, model, num_epochs=5):
    
    # hyperparameters
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002)
    loss_function= tf.keras.losses.CategoricalCrossentropy()
    
    # initializing training statistics
    train_accuracy = tf.keras.metrics.Accuracy(name='test_accuracy')
    test_accuracy = tf.keras.metrics.Accuracy(name='train_accuracy')
    train_losses = tf.keras.metrics.CategoricalCrossentropy(name='train_losses')
    test_losses = tf.keras.metrics.CategoricalCrossentropy(name='test_losses')
    train_acc = []
    test_acc = []
    train_l =[]
    test_l = []

    for epoch in tqdm(range(num_epochs), leave=False, desc="training epochs"):
        
        #train step
        for x, t in train:
            with tf.GradientTape() as tape:
                pred = model(x)
                loss = loss_function(t, pred)
                train_losses.update_state(t, pred)
                train_accuracy.update_state(tf.argmax(t,1), tf.argmax(pred,1))
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))
            
        # test step
        for x, t in test:
            pred = model(x)
            test_accuracy.update_state(tf.argmax(t,1), tf.argmax(pred,1))
            test_losses.update_state(t, pred)
        
        # updataing training statistics
        train_acc.append(train_accuracy.result().numpy())
        test_acc.append(test_accuracy.result().numpy())
        train_l.append(train_losses.result().numpy())
        test_l.append(test_losses.result().numpy())
        train_accuracy.reset_state()
        test_accuracy.reset_state()
        train_losses.reset_state()
        test_losses.reset_state()
        
    # collecting losses in a dictionary
    losses = { "test loss":test_l , "training loss":train_l , "test accuracy":test_acc , "training accuracy":train_acc}
    
    return  losses

In [5]:
train_dataset, test_dataset = load_and_prep_dataset("CIFAR", batch_size=60, shuffle_size=512)

model = CNN2ModelMasked()
model(list(train_dataset)[0][0])
initial_weights = model.get_weights()
initial_mask = model.get_masks()
initial_b_mask = model.get_binary_masks()
print(initial_mask)
print(initial_b_mask)
print("pruning_rates: ", get_pruning_rates(initial_b_mask))
print(model.trainable_variables)
model.summary()

losses = train_mask(train_dataset, test_dataset, model)
plot_losses("CIFAR", "TestSuperMaskOptimization", losses,"CNN Loss and Accuracy for supermask model")

[<tf.Tensor: shape=(16384, 256), dtype=float32, numpy=
array([[0.53760535, 0.46728128, 0.324164  , ..., 0.6236474 , 0.39963368,
        0.38473284],
       [0.5977696 , 0.2712129 , 0.6589951 , ..., 0.3204576 , 0.31190428,
        0.45388857],
       [0.60054004, 0.51819366, 0.6498607 , ..., 0.45939526, 0.36383328,
        0.69260967],
       ...,
       [0.5446179 , 0.28087568, 0.63103867, ..., 0.71974254, 0.42722768,
        0.51104695],
       [0.27105245, 0.33154792, 0.7295897 , ..., 0.37883714, 0.322522  ,
        0.36402732],
       [0.71628326, 0.4329152 , 0.2865379 , ..., 0.28553247, 0.3928525 ,
        0.3451348 ]], dtype=float32)>, <tf.Tensor: shape=(256, 256), dtype=float32, numpy=
array([[0.65370953, 0.2837926 , 0.61603445, ..., 0.69430786, 0.3465285 ,
        0.6250208 ],
       [0.28235665, 0.32301968, 0.2797009 , ..., 0.50128764, 0.37503606,
        0.2780315 ],
       [0.42282945, 0.29413077, 0.4734474 , ..., 0.6665324 , 0.6972801 ,
        0.5295898 ],
       ...,
     

                                                                                                                       

ValueError: No gradients provided for any variable: (['masked_dense/mask:0', 'masked_dense_1/mask:0', 'masked_dense_2/mask:0'],). Provided `grads_and_vars` is ((None, <tf.Variable 'masked_dense/mask:0' shape=(16384, 256) dtype=float32, numpy=
array([[ 0.15070581, -0.13106203, -0.7347009 , ...,  0.5050585 ,
        -0.40699172, -0.46950746],
       [ 0.3961804 , -0.98847747,  0.6588192 , ..., -0.75166965,
        -0.79123163, -0.18497133],
       [ 0.40771604,  0.07280684,  0.6184268 , ..., -0.16277742,
        -0.55876493,  0.8123481 ],
       ...,
       [ 0.17894769, -0.9401221 ,  0.5366752 , ...,  0.94318485,
        -0.29317117,  0.04419494],
       [-0.9892895 , -0.7011924 ,  0.9925418 , ..., -0.4944868 ,
        -0.7422056 , -0.55792665],
       [ 0.9260993 , -0.26996684, -0.9122586 , ..., -0.91718173,
        -0.43533754, -0.64049435]], dtype=float32)>), (None, <tf.Variable 'masked_dense_1/mask:0' shape=(256, 256) dtype=float32, numpy=
array([[ 0.63538504, -0.9257262 ,  0.47274995, ...,  0.8203368 ,
        -0.63433385,  0.5109143 ],
       [-0.9328017 , -0.73992896, -0.94594574, ...,  0.00515056,
        -0.51067185, -0.954247  ],
       [-0.3111689 , -0.87540555, -0.10631037, ...,  0.69254327,
         0.8343792 ,  0.11849761],
       ...,
       [-0.252342  ,  0.87439346,  0.08004522, ..., -0.6409266 ,
         0.49668264,  0.5864594 ],
       [ 0.14189792, -0.47228813, -0.41450596, ...,  0.2370553 ,
         0.05270553, -0.44875073],
       [ 0.9366102 , -0.9388621 ,  0.36147094, ..., -0.91936445,
         0.67487574, -0.30333972]], dtype=float32)>), (None, <tf.Variable 'masked_dense_2/mask:0' shape=(256, 10) dtype=float32, numpy=
array([[-0.7375388 , -0.44195628,  0.2902515 , ..., -0.9249046 ,
        -0.84397936, -0.19851017],
       [ 0.38751173,  0.9416125 , -0.14278531, ...,  0.15969372,
         0.9029684 , -0.92737246],
       [ 0.37761712,  0.02486968, -0.38356018, ..., -0.93937945,
        -0.33539557, -0.84828234],
       ...,
       [-0.14054036,  0.16727209,  0.01426864, ...,  0.03117347,
         0.14169931, -0.65410614],
       [ 0.5838053 ,  0.48769236, -0.15653753, ...,  0.03008032,
        -0.39825583, -0.86773825],
       [-0.70815086,  0.18554115, -0.42761683, ..., -0.8131945 ,
        -0.8805325 ,  0.60060644]], dtype=float32)>)).

debugging to do:
- check paper for optimizer
- make the call function simpler
- research other examples of unusual trainable parameters in models