# Stochastic Network Depth with AutoGraph



<table class="tfo-notebook-buttons" align="left"><td>
<a target="_blank"  href="https://colab.research.google.com/github/tensorflow/tensorflow/tree/master/tensorflow/contrib/autograph/examples/notebooks/stocastic_depth.ipynb">
    <img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>  
</td><td>
<a target="_blank"  href="https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/autograph/examples/notebooks/stocastic_depth.ipynb"><img width=32px src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a></td></table>

An AutoGraph implementation of [Deep Networks with Stochastic Depth](https://arxiv.org/abs/1603.09382) as a Keras model.

In [0]:
# Install a nightly TensorFlow build (this notebook uses `tf.contrib.autograph`).
!pip install tf-nightly

In [0]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.contrib import autograph

import numpy as np

# Display the installed TensorFlow version.
tf.VERSION

To debug the model, it can be convenient to turn AutoGraph off and enable eager execution instead. The cell below replaces `autograph.convert` with a no-op decorator when running eagerly.



In [0]:
if tf.executing_eagerly():
  def convert(*args, **kwargs):
    """No-op stand-in for `autograph.convert` used when running eagerly.

    Accepts and ignores any arguments, so call sites written as
    `@convert(...)` work unchanged. Returns a decorator that gives back the
    decorated function untouched.
    """
    del args, kwargs  # Unused: eager mode runs the plain Python code.

    def decorate(func):
      return func
    return decorate
else:
  convert = autograph.convert

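The core of the implementation is the `StocasticNetworkDepth` model. It works like a `tf.keras.Sequential` model, except that during training each layer is randomly skipped. The probability of keeping a layer decreases linearly from `pfirst` for the first layer to `plast` for the last.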

This is easily expressed in Python, and converted for graph execution with `tf.contrib.autograph`:

In [0]:
import tensorflow.keras.backend as K

class StocasticNetworkDepth(tf.keras.Sequential):
  """A `Sequential` model whose layers are randomly skipped while training.

  On each training call, layer `i` is kept with survival probability
  `plims[i]`. These probabilities decrease linearly from `pfirst` (first
  layer) to `plast` (last layer). At inference no layer is skipped.

  Every layer added to this model must accept a boolean `skip` keyword
  argument (a Python bool or a scalar boolean tensor).

  NOTE(review): the stochastic-depth paper also rescales each residual branch
  by its survival probability at test time; that is not done here.

  Args:
    pfirst: survival probability of the first layer.
    plast: survival probability of the last layer.
    **kwargs: forwarded to `tf.keras.Sequential`.
  """

  def __init__(self, pfirst=1.0, plast=0.5, **kwargs):
    # Initialize the Keras base class before assigning attributes on `self`,
    # so Keras' attribute tracking is set up first.
    super(StocasticNetworkDepth, self).__init__(**kwargs)
    self.pfirst = pfirst
    self.plast = plast

  @convert()
  def call(self, inputs, training=None):
    if training is None:
      training = K.learning_phase()
    # `training` may be a Python bool/int or a boolean tensor. Normalize it to
    # a bool tensor so the train/inference choice is made with tensor ops
    # instead of a data-dependent `if` with an early return.
    training = tf.cast(training, tf.bool)

    depth = len(self.layers)
    # Survival probability of each layer, decreasing linearly.
    plims = tf.lin_space(self.pfirst, self.plast, depth)

    # One uniform draw per layer per call: layer i is skipped when
    # p[i] >= plims[i], i.e. it survives with probability plims[i].
    p = tf.random_uniform((depth,), dtype=tf.float32)
    # Outside of training every skip flag is False.
    skips = tf.logical_and(training, p >= plims)

    x = inputs
    for i, layer in enumerate(self.layers):
      x = layer(x, skip=skips[i])

    return x
  

Input sizes (image sizes or sequence lengths) _can_ vary, as long as the output head is insensitive to them (use global pooling, not a "fully connected" layer).

The layers of this model are residual blocks, each containing two convolutions with supporting batch-norm layers, plus a bypass connection:

In [0]:
class PadChannels(layers.Layer):
  """Zero-pads the last (channel) axis of its input up to `channels`.

  Args:
    channels: number of channels in the output. Must be at least the input's
      channel count (otherwise the zero block gets a negative size).
  """

  def __init__(self, channels):
    super(PadChannels, self).__init__()
    self.channels = channels

  def call(self, inputs):
    in_shape = tf.shape(inputs)
    missing_channels = self.channels - in_shape[-1]
    zeros_shape = tf.concat([in_shape[:-1], missing_channels[tf.newaxis]],
                            axis=0)
    zeros = tf.zeros(shape=zeros_shape, dtype=inputs.dtype)
    outputs = tf.concat([inputs, zeros], axis=-1)
    # `zeros` only has a run-time shape, so `tf.concat` cannot infer the
    # output's channel count statically. Restore it so downstream layers see
    # a known number of channels.
    outputs.set_shape(inputs.shape[:-1].concatenate([self.channels]))
    return outputs

  def compute_output_shape(self, input_shape):
    input_shape = tf.TensorShape(input_shape)
    return input_shape[:-1].concatenate([self.channels])

In [0]:
class NoOp(layers.Layer):
  """Identity layer: returns its input unchanged."""

  def call(self, inputs):
    return inputs

  def compute_output_shape(self, input_shape):
    # The output is the input itself, so its shape is unchanged.
    return input_shape

In [0]:
class SkippableResBlock(tf.keras.Model):
  """Two-convolution residual block whose residual branch can be skipped.

  Computes `relu(bn2(conv2(relu(bn1(conv1(x))))) + bypass(x))`. When called
  with a true `skip`, the residual branch is multiplied by zero, leaving
  `relu(bypass(x))`.

  Args:
    channels: number of output channels of both convolutions.
    downsample: if True, `conv1` uses stride 2, and the bypass average-pools
      by 2 and zero-pads the channels up to `channels`. Otherwise the bypass
      is the identity.
  """

  def __init__(self, channels, downsample=False):
    super(SkippableResBlock, self).__init__()

    self.downsample = downsample

    if self.downsample:
      strides = (2, 2)
      # 'same' padding gives a ceil(H/2) x ceil(W/2) output, matching the
      # stride-2 'same' convolution below even for odd input sizes ('valid'
      # would give floor(H/2) and the Add would fail).
      self.bypass1 = layers.AveragePooling2D(pool_size=strides,
                                             strides=strides,
                                             padding='same')
      self.bypass2 = PadChannels(channels)
    else:
      strides = (1, 1)
      self.bypass1 = NoOp()
      self.bypass2 = NoOp()

    self.conv1 = layers.Conv2D(channels, (3, 3), padding='same',
                               strides=strides,
                               kernel_initializer='he_normal', use_bias=False)
    self.bn1 = layers.BatchNormalization()
    self.relu1 = layers.Activation('relu')
    self.conv2 = layers.Conv2D(channels, (3, 3), padding='same',
                               kernel_initializer='he_normal', use_bias=False)
    self.bn2 = layers.BatchNormalization()
    self.add = layers.Add()
    self.relu2 = layers.Activation('relu')

  @convert()
  def call(self, inputs, skip=False):
    """Applies the block.

    Args:
      inputs: input feature map (presumably NHWC, given the 2D layers).
      skip: Python bool or scalar boolean tensor. When true, the residual
        branch contributes nothing to the output.
    """
    bypass = self.bypass1(inputs)
    bypass = self.bypass2(bypass)

    result = self.conv1(inputs)
    result = self.bn1(result)
    result = self.relu1(result)
    result = self.conv2(result)
    result = self.bn2(result)

    # Gate the residual branch with a 0/1 multiplier instead of an early
    # return. `skip` is normally a tensor, and the gate avoids a
    # data-dependent branch (and creating layer variables inside one).
    # The branch is still computed when skipped; its weights just receive
    # zero gradient for that step.
    keep = 1.0 - tf.cast(skip, result.dtype)
    result = self.add([result * keep, bypass])
    result = self.relu2(result)

    return result
  
  

In [0]:
def make_model(res_blocks_per_group, start_channels=16, groups=3,
               pfirst=1.0, plast=0.5, num_classes=10):
  """Builds a stochastic-depth residual network classifier.

  Layout: a dtype-conversion `Lambda`, a 3x3 convolution stem, then a
  `StocasticNetworkDepth` holding `groups` groups of `res_blocks_per_group`
  `SkippableResBlock`s, then global average pooling, batch norm and a
  softmax `Dense` layer. The first block of every group except the first
  downsamples. The channel count doubles from one group to the next.

  Args:
    res_blocks_per_group: number of residual blocks in each group.
    start_channels: channels of the stem and of the first group.
    groups: number of block groups.
    pfirst: survival probability of the first residual block.
    plast: survival probability of the last residual block.
    num_classes: number of output classes.

  Returns:
    A `tf.keras.Sequential` model; the inner stochastic part is already
    built, the outer model is not.
  """
  channels = start_channels

  model = tf.keras.Sequential([
      # Convert inputs to float32 (integer images are rescaled to [0, 1]).
      layers.Lambda(
          lambda x: tf.image.convert_image_dtype(x, dtype=tf.float32)),
      layers.Conv2D(channels, (3, 3), padding='same',
                    kernel_initializer='he_normal'),
  ])

  stocastic = StocasticNetworkDepth(pfirst, plast)

  for group in range(groups):
    for block in range(res_blocks_per_group):
      # Downsample at the start of every group after the first.
      downsample = block == 0 and group != 0
      stocastic.add(SkippableResBlock(channels, downsample=downsample))
    channels *= 2

  stocastic.build([None, None, None, start_channels])
  model.add(stocastic)
  model.add(layers.GlobalAvgPool2D())
  model.add(layers.BatchNormalization())
  model.add(layers.Dense(num_classes, activation='softmax'))

  return model

In [0]:
# Total training epochs, matching the `model.fit(..., epochs=500)` call and
# the learning-rate drops at epochs 250 and 375 in `lr_schedule`.
EPOCHS = 500

train, test = tf.keras.datasets.cifar10.load_data()

train_data, train_labels = train

test_data, test_labels = test

# `cifar10.load_data` returns uint8 images. Scale them to float32 in [0, 1]
# up front so the model's input range does not depend on whether Keras
# hands the model uint8 tensors (which `convert_image_dtype` would rescale)
# or casts them to float first (which it would not rescale).
train_data = train_data.astype(np.float32) / 255.0
test_data = test_data.astype(np.float32) / 255.0

In [0]:
# 18 blocks per group x 3 groups (make_model's default) x 2 convolutions,
# plus the stem convolution and the classifier: a 110-layer network.
model = make_model(res_blocks_per_group=18)

In [0]:
# Weights-only checkpoints; `{epoch:04d}` is filled in by Keras so each
# saved epoch gets its own zero-padded file name under `training/`.
checkpoint_path = "training/cp-{epoch:04d}.ckpt"
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    verbose=1,
    save_weights_only=True,
)

In [0]:
def lr_schedule(epoch):
  """Step learning-rate schedule.

  Starts at 0.1 and divides the rate by 10 after 250 epochs and again after
  375 epochs.

  Args:
    epoch: zero-based index of the epoch about to start, as passed by
      `tf.keras.callbacks.LearningRateScheduler`.

  Returns:
    The learning rate to use for that epoch.
  """
  lr = 0.1

  # Epoch indices are 0-based, so "after 250 epochs" is index 250; the
  # previous `> 250` kept the higher rate for one extra epoch.
  if epoch >= 250:
    lr /= 10

  if epoch >= 375:
    lr /= 10

  return lr

In [0]:
# Sets the optimizer's learning rate from `lr_schedule` at the start of
# every epoch.
lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_schedule)

In [0]:
# Callbacks belong to `fit`, not `compile`.
callbacks = [cp_callback, lr_callback]

# Keras' SGD `decay` argument is a per-update learning-rate decay, not L2
# weight decay. Combined with `lr_callback` it silently shrinks the scheduled
# rate, so it is not used here. If L2 weight decay (1e-4) is wanted, add it
# as a kernel regularizer on the model's layers instead.
# The learning rate itself is set by `lr_callback` at each epoch start.
optimizer = tf.keras.optimizers.SGD(momentum=0.9, nesterov=True)
model.compile(optimizer, loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

In [0]:
# Build for RGB images of any batch size, height and width (3 channels).
model.build(input_shape=[None, None, None, 3])

In [0]:
# Print a layer-by-layer summary of the outer model.
model.summary()

In [0]:
# Index 2 is the StocasticNetworkDepth added by make_model (after the Lambda
# and the stem Conv2D).
stochastic_part = model.layers[2]
stochastic_part.summary()

In [0]:
# The final SkippableResBlock inside the stochastic-depth part of the model.
last_block = model.layers[2].layers[-1]
last_block.summary()

In [0]:
# Pass the checkpoint and learning-rate callbacks here: `fit` is where Keras
# runs callbacks. Batch size 128 is the setting the 0.1 initial learning
# rate in `lr_schedule` was chosen for in the stochastic-depth paper.
# Held-out loss/accuracy are reported after every epoch.
model.fit(train_data, train_labels,
          batch_size=128,
          epochs=500,
          callbacks=callbacks,
          validation_data=(test_data, test_labels))