In [1]:
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
from tensorflow.keras import datasets, layers, utils, Sequential, Model
import matplotlib.pyplot as plt


# Dataset

In [2]:
# Fetch the MNIST train/test splits bundled with Keras.
(train_x, train_y), (test_x, test_y) = datasets.mnist.load_data()

# Flatten every image into a single row vector and divide the pixel values by 255.
train_x = train_x.reshape(len(train_x), -1) / 255.0
test_x = test_x.reshape(len(test_x), -1) / 255.0

# One-hot encode the integer labels over 10 classes.
train_y, test_y = (utils.to_categorical(labels, 10) for labels in (train_y, test_y))

print(train_x.shape, test_x.shape)

(60000, 784) (10000, 784)


# Model

In [5]:
class NN(tf.keras.Model):
    """Two-layer dense classifier whose `train_step` prunes weights by magnitude.

    The network is `Dense(512, relu)` followed by `Dense(10, softmax)`. While
    `train_on_sparse` is true, every training step recomputes one binary mask
    per trainable variable (entries with the smallest absolute values are
    masked out), zeroes the masked entries of the variables, and zeroes the
    matching gradient entries before the optimizer update.

    Args:
        inputs_shape: Number of input features (or a shape tuple) used by
            `build_model` to create the symbolic input.
        sparsity: Fraction in [0, 1]. For each variable the mask threshold is
            the `sparsity * 100`-th percentile of its absolute values.

    Raises:
        ValueError: If `sparsity` is outside [0, 1].
    """

    def __init__(self, inputs_shape=28 * 28, sparsity=0.5):
        super(NN, self).__init__()

        if not 0.0 <= sparsity <= 1.0:
            raise ValueError("sparsity must lie in [0, 1], got {!r}".format(sparsity))

        self.inputs_shape = inputs_shape
        # Fraction used to derive the per-variable percentile threshold.
        self.sparsity = sparsity
        self.fc1 = layers.Dense(512, activation="relu", kernel_initializer='glorot_uniform')
        self.fc2 = layers.Dense(10, activation="softmax", kernel_initializer='glorot_uniform')

        # When true, `train_step` applies the pruning masks.
        self.train_on_sparse = True
        # One mask per trainable variable, in `trainable_variables` order.
        # Filled by `reset_masks` / `update_masks`.
        self.masks = []

    def reset_masks(self):
        """Replace every mask with all-ones (nothing pruned) and return the list.

        The masks are created from `self.trainable_variables`, so they line up
        one-to-one with the variables `train_step` updates.

        Returns:
            The list stored in `self.masks`.

        Raises:
            ValueError: If the model has no trainable variables yet (for
                example because `build_model` has not been called).
        """
        trainable_vars = self.trainable_variables
        if not trainable_vars:
            raise ValueError(
                "NN has no trainable variables yet; call build_model() before reset_masks()."
            )

        self.masks = [tf.ones_like(var) for var in trainable_vars]
        return self.masks

    def update_masks(self, trainable_vars):
        """Recompute the magnitude-pruning mask of each given variable.

        For every variable the threshold is the `self.sparsity * 100`-th
        percentile of its absolute values; entries strictly below the
        threshold get mask value 0, all others 1.

        Args:
            trainable_vars: Iterable of variables/tensors to build masks for.

        Returns:
            The new list of masks, in the same order as `trainable_vars`.
            The same masks are also stored in `self.masks`.
        """
        percentile = self.sparsity * 100

        masks = []
        for wb in trainable_vars:
            magnitude = tf.math.abs(wb)
            threshold = tfp.stats.percentile(magnitude, q=percentile)
            # zeros_like/ones_like keep the mask in the variable's own dtype.
            masks.append(
                tf.where(magnitude < threshold, tf.zeros_like(wb), tf.ones_like(wb))
            )

        # Rebuild the whole list instead of writing by index, so this works
        # even if `reset_masks` was never called or the lengths differ.
        self.masks = masks
        return masks

    def train_step(self, data):
        """Run one optimization step, optionally under pruning masks.

        Args:
            data: An `(x, y)` pair as passed on by `fit()`.

        Returns:
            Dict mapping each metric name to its current value.

        Raises:
            RuntimeError: If `train_on_sparse` is set but the model has no
                trainable variables yet.
        """
        # Unpack the data. Its structure depends on the model and on what is
        # passed to `fit()`.
        x, y = data

        trainable_vars = self.trainable_variables

        # ----------- 1) UPDATE WEIGHTS/BIASES WITH MASKS -----------
        # Pruning happens BEFORE the forward pass: the tape differentiates the
        # values the variables had when the forward ops were recorded, so
        # masking afterwards would leave loss, predictions and gradients
        # describing the unpruned network.
        masks = None
        if self.train_on_sparse:
            if not trainable_vars:
                raise RuntimeError(
                    "NN has no trainable variables yet; call build_model() "
                    "before training with train_on_sparse enabled."
                )
            masks = self.update_masks(trainable_vars)
            for var, mask in zip(trainable_vars, masks):
                var.assign(tf.multiply(var, mask))

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        # Compute gradients
        gradients = tape.gradient(loss, trainable_vars)

        # ----------- 2) UPDATE GRADS WITH MASKS -----------
        if masks is not None:
            # Use the masks from step 1). A gradient of None (variable not
            # connected to the loss) is passed through untouched.
            gradients = [
                grad if grad is None else tf.multiply(grad, mask)
                for grad, mask in zip(gradients, masks)
            ]

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

    def call(self, x):
        """Forward pass: `fc1` followed by `fc2`."""
        x = self.fc1(x)
        x = self.fc2(x)
        return x

    def build_model(self):
        """Run `call` on a symbolic input and return the resulting functional `Model`.

        `inputs_shape` may be a single int (number of features) or a shape
        tuple; an int is wrapped into a one-element tuple before it is handed
        to `layers.Input`.
        """
        shape = self.inputs_shape
        if isinstance(shape, int):
            shape = (shape,)

        x = layers.Input(shape=shape)
        return Model(inputs=[x], outputs=self.call(x))

In [6]:
# Fix TensorFlow's global seed before any weights are created, with the aim of
# making weight initialisation and batch shuffling repeatable across
# "Restart Kernel -> Run All" (some device kernels may still be non-deterministic).
SEED = 42
tf.random.set_seed(SEED)

model = NN()
# Called for its side effect: running `call` on a symbolic input is what gives
# the Dense layers their weights before `reset_masks` reads them. The returned
# functional Model is not used.
model.build_model()
model.reset_masks()
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

# Just use `fit` as usual; keep the returned History so the training curve can
# be inspected or plotted afterwards.
history = model.fit(train_x, train_y, epochs=1)

 196/1875 [==>...........................] - ETA: 1:24 - loss: 1.0477 - accuracy: 0.7211

KeyboardInterrupt: 

# Test

In [None]:
def f(w1, w2):
    """Return 3 * w1**2 + 2 * w1 * w2 for the given operands."""
    quadratic_term = 3 * w1 ** 2
    cross_term = 2 * w1 * w2
    return quadratic_term + cross_term

# Point at which f is differentiated.
w1 = tf.Variable(5.)
w2 = tf.Variable(3.)

# Record the evaluation of f so it can be differentiated afterwards.
with tf.GradientTape() as tape:
    z = f(w1, w2)

# [dz/dw1, dz/dw2]; analytically 6*w1 + 2*w2 = 36 and 2*w1 = 10 at this point.
gradients = tape.gradient(z, [w1, w2])

In [None]:
# Display the list returned by `tape.gradient(z, [w1, w2])` in the previous cell.
gradients