In [None]:
import matplotlib.pyplot as plt
import numpy as np

import tensorflow as tf
import tensorflow_probability as tfp
import time

In [None]:
# Load the MNIST train/test splits through Keras and unpack images/labels.
mnist_train_split, mnist_test_split = tf.keras.datasets.mnist.load_data()
train_images, train_labels = mnist_train_split
test_images, test_labels = mnist_test_split
def preprocess_images(images):
  """Reshape a batch of images to (N, 28, 28) and scale the values by 1/255.

  Args:
    images: array exposing `.shape` and `.reshape` (e.g. a NumPy array) whose
      first axis indexes samples and whose remaining elements number 28 * 28.

  Returns:
    The reshaped array divided by 255.
  """
  n_samples = images.shape[0]
  as_grids = images.reshape((n_samples, 28, 28))
  return as_grids / 255.

# Dataset sizes and the batch size shared by the rest of the notebook.
train_size, test_size = 60000, 10000
batch_size = 200

# Scale the images, then append a trailing channel axis -> (N, 28, 28, 1).
train_images = tf.expand_dims(preprocess_images(train_images), axis=-1)
test_images = tf.expand_dims(preprocess_images(test_images), axis=-1)

# Give the label arrays a trailing axis as well (one label column per sample).
train_labels = train_labels[..., np.newaxis]
test_labels = test_labels[..., np.newaxis]

In [None]:
# Training pipeline: shuffle with a buffer of `train_size` (re-drawn on every
# iteration over the dataset), then emit fixed-size batches.
train_examples = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
train_examples = train_examples.shuffle(train_size, reshuffle_each_iteration=True)
train_dataset = train_examples.batch(batch_size, drop_remainder=True)

# Evaluation pipeline: same batching; a trailing partial batch would be dropped.
test_examples = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
test_examples = test_examples.shuffle(test_size)
test_dataset = test_examples.batch(batch_size, drop_remainder=True)

In [None]:
# Sanity check: display the shapes of the label and image arrays.
(train_labels.shape, train_images.shape)

In [None]:
import tensorflow as tf
# Restrict the TensorFlow logger to records of level ERROR.
tf_logger = tf.get_logger()
tf_logger.setLevel('ERROR')

In [None]:

class CustomDropout(tf.keras.layers.Layer):
    """Loss-guided gating over the outputs of several decision makers.

    The layer owns a linear + sigmoid gate model (`mask_w`, `mask_b`) that maps
    the flattened decision-maker outputs to one gate value per decision maker.
    Each decision maker's class scores are multiplied by its gate.

    Alongside the gated outputs the layer returns a target mask and the
    predicted gates, so a caller can train the gate model against the target:

    * training: the target marks the (sample, decision maker) pairs whose
      cross-entropy is at or below the `(1 - rate)` percentile of all losses
      in the batch;
    * inference: the target is all ones.

    Args:
        rate: fraction of (sample, decision maker) pairs marked as dropped in
            the target mask, e.g. `rate=0.7` keeps the 30% with lowest loss.
        input_dim: size of the flattened input (number of decision makers
            times number of classes).
        **kwargs: forwarded to `tf.keras.layers.Layer`.
    """

    N_CLASSES = 10  # width of each decision maker's class-score vector

    def __init__(self, rate, input_dim, **kwargs):
        super(CustomDropout, self).__init__(**kwargs)
        # Stored as the *keep* fraction: it is used directly as the percentile.
        self.rate = 1-rate
        self.input_dim = input_dim
        # Read the module-level setting once so weights and reshapes always agree.
        self.n_decision_makers = n_decision_makers
        self.mask_w = self.add_weight(shape=(self.input_dim, self.n_decision_makers), trainable=True)
        self.mask_b = self.add_weight(shape=(self.n_decision_makers,), initializer="zeros", trainable=True)

    def _gate_grid(self, inputs):
        """Return the predicted gates broadcast to the shape of `inputs`."""
        batch = tf.shape(inputs)[0]
        flat_inputs = tf.reshape(inputs, [batch, -1])
        gates = tf.nn.sigmoid(tf.matmul(flat_inputs, self.mask_w) + self.mask_b)
        # Put gate i on every class of decision maker i. Tiling the flat gate
        # vector and reshaping afterwards would instead interleave the gates of
        # different decision makers within each row.
        return tf.tile(tf.expand_dims(gates, axis=-1), [1, 1, self.N_CLASSES])

    def call(self, inputs, label, training=None):
        """Gate `inputs` and build the target mask.

        Args:
            inputs: per-decision-maker class probabilities; used here as
                (batch, n_decision_makers, N_CLASSES).
            label: integer class labels, one column per sample.
            training: truthy to derive the target mask from the losses.

        Returns:
            Tuple `(gated, target_mask, mask_pred)` where `gated` and
            `target_mask` have the shape of `inputs` and `mask_pred` is the
            gate grid flattened to (batch, n_decision_makers * N_CLASSES),
            laid out like the flattened `target_mask`.
        """
        gate_grid = self._gate_grid(inputs)
        gated = tf.multiply(gate_grid, inputs)
        mask_pred = tf.reshape(gate_grid, [tf.shape(inputs)[0], -1])

        if training:
            scce = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
            # Give every decision maker its own copy of the sample's label.
            label_per_dm = tf.tile(tf.transpose([label], perm=[1, 2, 0]), [1, self.n_decision_makers, 1])
            loss = scce(label_per_dm, inputs)
            threshold = tfp.stats.percentile(loss, q=self.rate*100)
            keep = (loss <= threshold)  # lowest-loss (1 - rate) fraction
            target_mask = tf.cast(tf.tile(tf.expand_dims(keep, axis=-1), [1, 1, self.N_CLASSES]), 'float32')
        else:
            # Follow the actual batch dimension instead of a global batch size.
            target_mask = tf.ones_like(inputs, dtype=tf.float32)

        return gated, target_mask, mask_pred

In [None]:
n_decision_makers = 4  # number of decision makers gated by CustomDropout; 100 is presumably an alternative setting -- TODO confirm
class MyModel(tf.keras.Model):
    """Two conv blocks with local classifier heads, gated decision makers and a final head.

    `call` takes `[images, label]` and returns, in order: the flattened target
    mask, the predicted mask, the block-1 head output, the block-2 head
    output, the gated decision-maker outputs and the final class
    probabilities.

    NOTE: keep the attribute assignment order in `__init__` stable; any code
    that addresses `model.layers` or `model.trainable_variables` by position
    depends on it.
    """

    @staticmethod
    def _conv3x3(filters):
        """Build a 3x3 same-padded ReLU convolution with L1 kernel regularisation."""
        return tf.keras.layers.Conv2D(
            filters,
            3,
            activation='relu',
            padding='same',
            kernel_regularizer=tf.keras.regularizers.l1(l=0.01),
            kernel_initializer='he_uniform',
        )

    def __init__(self,**kwargs):
        super(MyModel,self).__init__(**kwargs)

        # Shape utilities (flat3, flat6 and reshape2 are not used by `call`).
        self.flat1 = tf.keras.layers.Flatten()
        self.flat2 = tf.keras.layers.Flatten()
        self.flat3 = tf.keras.layers.Flatten()
        self.flat4 = tf.keras.layers.Flatten()
        self.flat5 = tf.keras.layers.Flatten()
        self.flat6 = tf.keras.layers.Flatten()
        self.reshape1 = tf.keras.layers.Reshape((n_decision_makers,10))
        self.reshape2 = tf.keras.layers.Reshape((n_decision_makers,10))

        # Loss-guided gate over the decision makers.
        self.dropout1 = CustomDropout(0.7,n_decision_makers*10)

        # Regular dropout / pooling for the two conv blocks.
        self.dropout4 = tf.keras.layers.Dropout(0.2)
        self.dropout5 = tf.keras.layers.Dropout(0.2)
        self.pool1 = tf.keras.layers.MaxPooling2D((2, 2))
        self.pool2 = tf.keras.layers.MaxPooling2D((2, 2))

        # Block 1 and its local classifier head.
        self.conv1 = self._conv3x3(64)
        self.conv11 = self._conv3x3(128)
        self.dense1 = tf.keras.layers.Dense(10,activation=tf.nn.softmax)
        self.batchnorm1 = tf.keras.layers.BatchNormalization()

        # Block 2 and its local classifier head.
        self.conv2 = self._conv3x3(128)
        self.conv22 = self._conv3x3(64)
        self.dense2 = tf.keras.layers.Dense(10,activation=tf.nn.softmax)
        self.batchnorm2 = tf.keras.layers.BatchNormalization()

        # Decision makers (one 10-way score vector each) and the final head.
        self.dense5 = tf.keras.layers.Dense(n_decision_makers*10,activation=tf.nn.relu)
        self.dense7 = tf.keras.layers.Dense(10, activation=tf.nn.softmax)

    def call(self, input):
        images, label = input

        # Block 1: conv11 -> conv1 -> pool -> batch norm -> dropout.
        block1 = self.conv11(images)
        block1 = self.conv1(block1)
        block1 = self.pool1(block1)
        block1 = self.batchnorm1(block1)
        block1 = self.dropout4(block1)
        block1_probs = self.dense1(self.flat4(block1))

        # Block 2: conv22 -> conv2 -> pool -> batch norm -> dropout.
        block2 = self.conv22(block1)
        block2 = self.conv2(block2)
        block2 = self.pool2(block2)
        block2 = self.batchnorm2(block2)
        block2 = self.dropout5(block2)
        block2_flat = self.flat5(block2)
        block2_probs = self.dense2(block2_flat)

        # Decision makers: softmax over the last axis of the reshaped scores.
        dm_scores = self.reshape1(self.dense5(block2_flat))
        dm_probs = tf.nn.softmax(dm_scores)
        gated, true_mask, pred_mask = self.dropout1(dm_probs, label)

        # Final ("leader") prediction from the gated decision-maker outputs.
        leader_probs = self.dense7(self.flat1(gated))

        return self.flat2(true_mask), pred_mask, block1_probs, block2_probs, gated, leader_probs

In [None]:
# Instantiate the model and run one forward pass on all-zero inputs before
# printing the summary.
model = MyModel()
dummy_images = tf.zeros((batch_size, 28, 28, 1))
dummy_labels = tf.zeros((batch_size, 1))
model([dummy_images, dummy_labels])
model.summary()

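local_loss: losses, optimizers, and the train/test steps used to train each part of the model on its own (local) loss.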

In [None]:
from keras import backend as K


# Running means of the per-batch losses.
train_loss = tf.keras.metrics.Mean()
test_loss = tf.keras.metrics.Mean()
mask_loss = tf.keras.metrics.Mean()

# Running classification accuracies.
train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
test_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

# Staircase exponential-decay schedule.
# NOTE(review): neither optimizer created here is given this schedule.
initial_learning_rate = 1e-4
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=initial_learning_rate,
    decay_steps=100000,
    decay_rate=0.96,
    staircase=True,
)

# One optimizer for the mask predictor, one for everything else.
optimizer_mask = tf.keras.optimizers.Adam()
optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=5e-4)

# Classification loss, and the mask loss (binary cross-entropy despite the name `cce`).
scce = tf.keras.losses.SparseCategoricalCrossentropy()
cce = tf.keras.losses.BinaryCrossentropy()

def compute_loss(hidden1_true_mask, hidden1_pred_mask, conv1, conv2, hidden1, y, output):
  """Compute the final loss, the mask loss and the three local losses.

  Args:
    hidden1_true_mask: target mask for the decision makers.
    hidden1_pred_mask: predicted mask, compared to the target with `cce`.
    conv1: class probabilities from the first block's head.
    conv2: class probabilities from the second block's head.
    hidden1: per-decision-maker outputs.
    y: integer labels; tiled `n_decision_makers` times along axis 1 for `hidden1`.
    output: final class probabilities.

  Returns:
    Tuple `(loss, loss_mask, loss_local_conv1, loss_local_conv2, loss_local_hidden1)`.
  """
  # Every decision maker is scored against the same label.
  y_per_decision_maker = tf.tile(y, [1, n_decision_makers])

  final_loss = scce(y, output)
  mask_bce = cce(hidden1_true_mask, hidden1_pred_mask)
  conv1_loss = scce(y, conv1)
  conv2_loss = scce(y, conv2)
  decision_maker_loss = scce(y_per_decision_maker, hidden1)

  return final_loss, mask_bce, conv1_loss, conv2_loss, decision_maker_loss

def compute_loss_mask(hidden1_true_mask, hidden1_pred_mask):
  """Return the `cce` (binary cross-entropy) loss between target and predicted masks."""
  return cce(hidden1_true_mask, hidden1_pred_mask)

def compute_acc(model, x, y):
  """Return the per-sample sparse categorical accuracy of the model's final output.

  Args:
    model: callable taking `[x, y]` and returning six outputs, the last being
      the final class probabilities.
    x: input batch.
    y: integer labels.
  """
  _, _, _, _, _, final_probs = model([x, y])
  return tf.keras.metrics.sparse_categorical_accuracy(y, final_probs)

def _trainable_variables_of(*layers):
  """Concatenate the trainable variables of the given layers, in order."""
  collected = []
  for layer in layers:
    collected.extend(layer.trainable_variables)
  return collected


def train_step(model, x, y, optimizer):
  """Run one training step with block-local losses, then refine the mask predictor.

  A single forward pass yields five losses; each one updates only its own
  group of variables:

  * final loss           -> `dense7`
  * decision-maker loss  -> `dense5`
  * block-2 local loss   -> `conv2`, `conv22`, `dense2`, `batchnorm2`
  * block-1 local loss   -> `conv1`, `conv11`, `dense1`, `batchnorm1`
  * mask loss            -> `dropout1` (via the module-level `optimizer_mask`)

  The mask predictor is then trained for three extra forward/backward passes.
  Updates the module-level `train_acc`, `train_loss` and `mask_loss` metrics.

  Args:
    model: a `MyModel` instance (layers are addressed by attribute name).
    x: image batch.
    y: integer label batch.
    optimizer: optimizer used for every group except the mask predictor.
  """
  # Address layers by name rather than by position in `model.layers` /
  # `model.trainable_variables`: a positional slice is easy to get off by one
  # (block 2 owns eight variables, including both batch-norm parameters).
  head_vars = _trainable_variables_of(model.dense7)
  decision_maker_vars = _trainable_variables_of(model.dense5)
  block2_vars = _trainable_variables_of(model.conv2, model.conv22, model.dense2, model.batchnorm2)
  block1_vars = _trainable_variables_of(model.conv1, model.conv11, model.dense1, model.batchnorm1)
  mask_vars = _trainable_variables_of(model.dropout1)

  # Persistent: several gradients are taken from the same forward pass.
  with tf.GradientTape(persistent=True) as tape:
    hidden1_true_mask, hidden1_pred_mask, conv1, conv2, hidden1, output = model([x,y], training=True)
    loss, loss_mask, loss_local_conv1, loss_local_conv2, loss_local_hidden1 = compute_loss(
        hidden1_true_mask, hidden1_pred_mask, conv1, conv2, hidden1, y, output)

  updates = (
      (loss, head_vars, optimizer),
      (loss_local_hidden1, decision_maker_vars, optimizer),
      (loss_local_conv2, block2_vars, optimizer),
      (loss_local_conv1, block1_vars, optimizer),
      (loss_mask, mask_vars, optimizer_mask),
  )
  for objective, variables, opt in updates:
    gradients = tape.gradient(objective, variables)
    opt.apply_gradients(zip(gradients, variables))
  del tape  # release the persistent tape's resources explicitly

  train_acc(y,output)
  train_loss(loss)

  # Extra passes that train only the mask predictor.
  for _ in range(3):
    with tf.GradientTape() as tape:
      hidden1_true_mask, hidden1_pred_mask, _, _, _, _ = model([x,y], training=True)
      loss_mask = compute_loss_mask(hidden1_true_mask, hidden1_pred_mask)
    gradients = tape.gradient(loss_mask, mask_vars)
    optimizer_mask.apply_gradients(zip(gradients, mask_vars))

  # Records the mask loss of the last refinement pass.
  mask_loss(loss_mask)

def test_step(model, x, y):
  """Evaluate one batch and update the module-level `test_loss` / `test_acc` metrics.

  Args:
    model: callable taking `[x, y]` and returning six outputs, the last being
      the final class probabilities.
    x: image batch.
    y: integer label batch.
  """
  _, _, _, _, _, final_probs = model([x, y])
  batch_loss = tf.keras.losses.SparseCategoricalCrossentropy()(y, final_probs)

  test_loss(batch_loss)
  test_acc(y, final_probs)

In [None]:
# Load the TensorBoard notebook extension into this IPython session.
%load_ext tensorboard

In [None]:
import tensorflow as tf
import datetime


# WARNING: shell command that recursively deletes everything under ./logs/,
# i.e. the logs of all previous runs, each time this cell is executed.
!rm -rf ./logs/

In [None]:
# Run directory: timestamp immediately followed by the number of decision
# makers (no separator), with separate train/ and test/ sub-directories.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
run_prefix = 'logs/gradient_tape/' + current_time + str(n_decision_makers)
train_log_dir = run_prefix + 'dropout/train'
test_log_dir = run_prefix + 'dropout/test'

train_summary_writer = tf.summary.create_file_writer(train_log_dir)
test_summary_writer = tf.summary.create_file_writer(test_log_dir)

In [None]:
EPOCHS = 50
from tqdm.notebook import tqdm

import time

# Train for EPOCHS epochs; after each epoch, evaluate on the test set, write
# the epoch metrics to TensorBoard, print them, and reset the metrics.
for epoch in range(EPOCHS):
  start = time.time()

  for train_x, train_y in tqdm(train_dataset):
    train_step(model, train_x, train_y, optimizer)

  with train_summary_writer.as_default():
    tf.summary.scalar('loss', train_loss.result(), step=epoch)
    # The mask loss gets its own tag; writing it under 'loss' too would put
    # two different series on the same tag of the same run.
    tf.summary.scalar('mask_loss', mask_loss.result(), step=epoch)
    tf.summary.scalar('accuracy', train_acc.result(), step=epoch)

  for test_x, test_y in test_dataset:
    test_step(model, test_x, test_y)

  with test_summary_writer.as_default():
    tf.summary.scalar('loss', test_loss.result(), step=epoch)
    tf.summary.scalar('accuracy', test_acc.result(), step=epoch)

  template = 'Epoch {}, Loss: {}, Accuracy: {}, MaskLoss: {}, Test Loss: {}, Test Accuracy: {}'
  print(template.format(epoch+1,
                         train_loss.result(),
                         train_acc.result()*100,
                         mask_loss.result(),
                         test_loss.result(),
                         test_acc.result()*100))

  # Start the next epoch with fresh accumulators.
  for metric in (train_loss, test_loss, mask_loss, train_acc, test_acc):
    metric.reset_states()

  print("Time elapsed: ", time.time()-start)