# Optimizing Masks to create WTs

In [1]:
# importing necessary libraries and the cnn architecture I defined

from cnn_architecture import CNN2Model
from utils import *
from load_datasets import load_and_prep_dataset

import tensorflow_datasets as tfds
import tensorflow_probability as tfp
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy.io import loadmat
import copy

# all the extra stuff for supermasks

all the variables i have to check their meaning:
- use bias
- dynamik scaling
- sigmoid bias
- use learning phase


In [2]:
class MaskedDense(tf.keras.layers.Dense):
    
    # untrainable normal Dense layer
    # trainable mask, that is sigmoided (maybe squished) and then multiplied to Dense
    
    def __init__(self, units,*args, **kwargs):
        super(MaskedDense, self).__init__(units, *args, **kwargs)        

        
    def build(self, input_shape):
        super(MaskedDense, self).build(input_shape)
        
        # make bias and weights untrainable
        self._trainable_weights.remove(self.kernel)
        self._non_trainable_weights.append(self.kernel)
        self._trainable_weights.remove(self.bias)
        self._non_trainable_weights.append(self.bias)
        
        # create mask and make it trainable
        mask_init = tf.random.uniform(shape=self.kernel.shape,minval=-1, maxval=1, seed=None)
        self.kernel_mask = tf.Variable(initial_value=mask_init,
                                        trainable=True,
                                        validate_shape=True,
                                        name='mask',
                                        dtype=self.dtype,
                                        shape=self.kernel.shape)
        self._trainable_weights.append(self.kernel_mask)
    
    @tf.function
    def call(self, inputs):

        effective_mask = tf.cast(tfp.distributions.Bernoulli(probs=tf.nn.sigmoid(self.kernel_mask)).sample(), dtype=tf.float32)
        effective_kernel = tf.math.multiply(self.kernel, effective_mask)

        inputs = tf.convert_to_tensor(inputs)
        outputs = tf.linalg.matmul(inputs, effective_kernel)
        outputs = tf.nn.bias_add(outputs, self.bias)
        output =  self.activation(outputs)
        #output = self.activation(tf.matmul(inputs, tf.math.multiply(self.kernel, tf.math.round(tf.nn.sigmoid(self.kernel_mask)))))
        return output
    
    def get_mask(self):
        return  tf.nn.sigmoid(self.kernel_mask)
    
    def get_binary_mask(self):
        return tf.math.round(tf.nn.sigmoid(self.kernel_mask))

In [3]:
class CNN2ModelMasked(tf.keras.Model):
    
    # basic
    def __init__(self):
        super(CNN2ModelMasked, self).__init__()
        
        # set biases to a value that is not exactly 0.0, so they don't get handled like pruned values
        self.bias_in = tf.keras.initializers.Constant(value=0.0000000001)
        
        self.conv1 = tf.keras.layers.Conv2D(filters=64, kernel_size=3,activation="relu", padding="same",kernel_initializer='glorot_uniform', bias_initializer=self.bias_in) # [batchsize,32,32,64]
        self.conv2 = tf.keras.layers.Conv2D(filters=64, kernel_size=3,activation="relu", padding="same",kernel_initializer='glorot_uniform', bias_initializer=self.bias_in) # [batchsize,32,32,64]
        self.maxpool = tf.keras.layers.MaxPooling2D(pool_size=(2, 2),strides=(2, 2),input_shape=(32, 32, 64)) # [batchsize,16,16,64]
        self.flatten = tf.keras.layers.Flatten() # [batch_size,16384]
        self.dense1 = MaskedDense(256, activation="relu",kernel_initializer='glorot_uniform', bias_initializer=self.bias_in) # [batch_size,256]
        self.dense2 = MaskedDense(256, activation="relu",kernel_initializer='glorot_uniform', bias_initializer=self.bias_in) # [batch_size,256]
        self.dense3 = MaskedDense(10, activation="softmax",kernel_initializer='glorot_uniform', bias_initializer=self.bias_in) # [batch_size,256]
        
        # Making the weights of the conv layers untrainable
        self.conv1.trainable = False
        self.conv2.trainable = False
    
    @tf.function
    def call(self, inputs):
        
        # adjust the dense layers to be multiplayed with trainable mask (which gets assigned binary values for this step)
        x = self.conv1(inputs)
        x = self.conv2(x)
        x = self.maxpool(x)
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dense2(x)
        x = self.dense3(x)
        return x
    
    def get_masks(self):
        return [self.dense1.get_mask(), self.dense2.get_mask(), self.dense3.get_mask()]
        
    def get_binary_masks(self):
        return [self.dense1.get_binary_mask(), self.dense2.get_binary_mask(), self.dense3.get_binary_mask()]

In [4]:
# modified train loop to also work with sparse networks (such that pruned weights remain frozen at 0.0)

def train_mask(train, test, model, num_epochs=5):
    
    # hyperparameters
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002)
    loss_function= tf.keras.losses.CategoricalCrossentropy()
    
    # initializing training statistics
    train_accuracy = tf.keras.metrics.Accuracy(name='test_accuracy')
    test_accuracy = tf.keras.metrics.Accuracy(name='train_accuracy')
    train_losses = tf.keras.metrics.CategoricalCrossentropy(name='train_losses')
    test_losses = tf.keras.metrics.CategoricalCrossentropy(name='test_losses')
    train_acc = []
    test_acc = []
    train_l =[]
    test_l = []

    for epoch in tqdm(range(num_epochs), leave=False, desc="training epochs"):
        
        #train step
        for x, t in train:
            with tf.GradientTape() as tape:
                pred = model(x)
                loss = loss_function(t, pred)
                train_losses.update_state(t, pred)
                train_accuracy.update_state(tf.argmax(t,1), tf.argmax(pred,1))
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))
            
        # test step
        for x, t in test:
            pred = model(x)
            test_accuracy.update_state(tf.argmax(t,1), tf.argmax(pred,1))
            test_losses.update_state(t, pred)
        
        # updataing training statistics
        train_acc.append(train_accuracy.result().numpy())
        test_acc.append(test_accuracy.result().numpy())
        train_l.append(train_losses.result().numpy())
        test_l.append(test_losses.result().numpy())
        train_accuracy.reset_state()
        test_accuracy.reset_state()
        train_losses.reset_state()
        test_losses.reset_state()
        
    # collecting losses in a dictionary
    losses = { "test loss":test_l , "training loss":train_l , "test accuracy":test_acc , "training accuracy":train_acc}
    
    return  losses

In [5]:
train_dataset, test_dataset = load_and_prep_dataset("CIFAR", batch_size=60, shuffle_size=512)

model = CNN2ModelMasked()
model(list(train_dataset)[0][0])
initial_weights = model.get_weights()
initial_mask = model.get_masks()
initial_b_mask = model.get_binary_masks()
print(initial_mask)
print(initial_b_mask)
print("pruning_rates: ", get_pruning_rates(initial_b_mask))
print(model.trainable_variables)
model.summary()

losses = train_mask(train_dataset, test_dataset, model)
plot_losses("CIFAR", "TestSuperMaskOptimization", losses,"CNN Loss and Accuracy for supermask model")

[<tf.Tensor: shape=(16384, 256), dtype=float32, numpy=
array([[0.59537953, 0.41279688, 0.2965429 , ..., 0.67885566, 0.6025972 ,
        0.56398875],
       [0.3956442 , 0.27382806, 0.60157245, ..., 0.5319057 , 0.3677409 ,
        0.6960501 ],
       [0.66915065, 0.6015967 , 0.7213296 , ..., 0.5422151 , 0.43727148,
        0.7220916 ],
       ...,
       [0.3562014 , 0.37151465, 0.67716545, ..., 0.6896454 , 0.4130926 ,
        0.63781345],
       [0.5311647 , 0.30909315, 0.37218198, ..., 0.7097203 , 0.5879254 ,
        0.5107565 ],
       [0.6952853 , 0.41865733, 0.3240383 , ..., 0.5837637 , 0.5600426 ,
        0.6335896 ]], dtype=float32)>, <tf.Tensor: shape=(256, 256), dtype=float32, numpy=
array([[0.7300684 , 0.5553849 , 0.39425352, ..., 0.6950708 , 0.26959568,
        0.26990044],
       [0.47060034, 0.5187653 , 0.34267816, ..., 0.27809468, 0.5124563 ,
        0.5464804 ],
       [0.33388066, 0.43779075, 0.537593  , ..., 0.6862121 , 0.38142988,
        0.36360753],
       ...,
     

                                                                                                                       

ValueError: No gradients provided for any variable: (['masked_dense/mask:0', 'masked_dense_1/mask:0', 'masked_dense_2/mask:0'],). Provided `grads_and_vars` is ((None, <tf.Variable 'masked_dense/mask:0' shape=(16384, 256) dtype=float32, numpy=
array([[ 0.38624954, -0.35241508, -0.86381507, ...,  0.74851775,
         0.41629863,  0.25736618],
       [-0.42364788, -0.9752865 ,  0.4120214 , ...,  0.12779641,
        -0.54192066,  0.8285587 ],
       [ 0.7043462 ,  0.4121225 ,  0.9510665 , ...,  0.16926336,
        -0.25224304,  0.95486045],
       ...,
       [-0.5918896 , -0.5257244 ,  0.7407756 , ...,  0.79846215,
        -0.3511951 ,  0.5658865 ],
       [ 0.12482071, -0.8043623 , -0.52286744, ...,  0.89402604,
         0.35539556,  0.04303265],
       [ 0.8249464 , -0.3282876 , -0.7352748 , ...,  0.33824348,
         0.24133515,  0.5476477 ]], dtype=float32)>), (None, <tf.Variable 'masked_dense_1/mask:0' shape=(256, 256) dtype=float32, numpy=
array([[ 0.9949696 ,  0.22245216, -0.42946744, ...,  0.8239341 ,
        -0.9966748 , -0.9951277 ],
       [-0.11773443,  0.07509637, -0.65138197, ..., -0.9539323 ,
         0.04983544,  0.18645978],
       [-0.6906853 , -0.25013304,  0.15065646, ...,  0.78246975,
        -0.48348355, -0.55974054],
       ...,
       [-0.5895989 ,  0.16284013,  0.8015759 , ...,  0.75324774,
        -0.66168   ,  0.6958852 ],
       [ 0.64037466,  0.27150702, -0.12357831, ..., -0.3145256 ,
        -0.45562577,  0.65652156],
       [ 0.7273257 ,  0.9206035 ,  0.6498828 , ...,  0.82206273,
         0.9152143 ,  0.9761505 ]], dtype=float32)>), (None, <tf.Variable 'masked_dense_2/mask:0' shape=(256, 10) dtype=float32, numpy=
array([[ 0.8343587 ,  0.9642339 , -0.37855053, ...,  0.65979934,
        -0.479177  , -0.97033453],
       [ 0.1748693 ,  0.679363  , -0.47277403, ...,  0.9636688 ,
        -0.5824561 ,  0.85412073],
       [ 0.7098551 ,  0.36310554,  0.19838381, ..., -0.05459714,
         0.6844263 , -0.92587805],
       ...,
       [-0.29909205, -0.50934076,  0.5422611 , ...,  0.41986918,
        -0.33481193,  0.7812269 ],
       [ 0.07545018, -0.58397675,  0.80845404, ..., -0.24385357,
        -0.8144741 , -0.85531497],
       [ 0.59705234,  0.183671  ,  0.6438403 , ..., -0.7426691 ,
        -0.01123381,  0.6296878 ]], dtype=float32)>)).

debugging to do:
- check paper for optimizer
- make the call function simpler
- research other examples of unusual trainable parameters in models